diff --git a/.asf.yaml b/.asf.yaml index 0588a300a5ca..805bb52456f4 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -62,6 +62,12 @@ github: branch-51: required_pull_request_reviews: required_approving_review_count: 1 + branch-52: + required_pull_request_reviews: + required_approving_review_count: 1 + branch-53: + required_pull_request_reviews: + required_approving_review_count: 1 pull_requests: # enable updating head branches of pull requests allow_update_branch: true diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 9d1d77d44c37..2cd4bdfdd792 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -23,6 +23,7 @@ updates: interval: weekly target-branch: main labels: [auto-dependencies] + open-pull-requests-limit: 15 ignore: # major version bumps of arrow* and parquet are handled manually - dependency-name: "arrow*" @@ -44,10 +45,27 @@ updates: patterns: - "prost*" - "pbjson*" + + # Catch-all: group only minor/patch into a single PR, + # excluding deps we want always separate (and excluding arrow/parquet which have their own group) + all-other-cargo-deps: + applies-to: version-updates + patterns: + - "*" + exclude-patterns: + - "arrow*" + - "parquet" + - "object_store" + - "sqlparser" + - "prost*" + - "pbjson*" + update-types: + - "minor" + - "patch" - package-ecosystem: "github-actions" directory: "/" schedule: - interval: "daily" + interval: "weekly" open-pull-requests-limit: 10 labels: [auto-dependencies] - package-ecosystem: "pip" diff --git a/.github/workflows/audit.yml b/.github/workflows/audit.yml index 066151babc91..691fd4f685e1 100644 --- a/.github/workflows/audit.yml +++ b/.github/workflows/audit.yml @@ -40,10 +40,12 @@ jobs: security_audit: runs-on: ubuntu-latest steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Install cargo-audit - uses: taiki-e/install-action@0e76c5c569f13f7eb21e8e5b26fe710062b57b62 # v2.65.13 + 
uses: taiki-e/install-action@cfdb446e391c69574ebc316dfb7d7849ec12b940 # v2.68.8 with: tool: cargo-audit - name: Run audit check + # Note: you can ignore specific RUSTSEC issues using the `--ignore` flag ,for example: + # run: cargo audit --ignore RUSTSEC-2026-0001 run: cargo audit diff --git a/.github/workflows/dependencies.yml b/.github/workflows/dependencies.yml index fef65870b697..3b2cc243d496 100644 --- a/.github/workflows/dependencies.yml +++ b/.github/workflows/dependencies.yml @@ -44,7 +44,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -62,8 +62,8 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Install cargo-machete run: cargo install cargo-machete --version ^0.9 --locked - name: Detect unused dependencies - run: cargo machete --with-metadata \ No newline at end of file + run: cargo machete --with-metadata diff --git a/.github/workflows/dev.yml b/.github/workflows/dev.yml index 1ec7c16b488f..2fec34365091 100644 --- a/.github/workflows/dev.yml +++ b/.github/workflows/dev.yml @@ -32,7 +32,7 @@ jobs: runs-on: ubuntu-latest name: Check License Header steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Install HawkEye # This CI job is bound by installation time, use `--profile dev` to speed it up run: cargo install hawkeye --version 6.2.0 --locked --profile dev @@ -43,8 +43,8 @@ jobs: name: Use prettier to check formatting of documents runs-on: ubuntu-latest steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 - - uses: 
actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # v6.1.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + - uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # v6.2.0 with: node-version: "20" - name: Prettier check @@ -55,7 +55,7 @@ jobs: name: Spell Check with Typos runs-on: ubuntu-latest steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false # Version fixed on purpose. It uses heuristics to detect typos, so upgrading diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 3e2c48643c36..529c6099fa31 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -32,32 +32,31 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout docs sources - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Checkout asf-site branch - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: asf-site path: asf-site - - name: Setup Python - uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 - with: - python-version: "3.12" + - name: Setup uv + uses: astral-sh/setup-uv@eac588ad8def6316056a12d4907a9d4d84ff7a3b # v7.3.0 - name: Install dependencies + run: uv sync --package datafusion-docs + - name: Install dependency graph tooling run: | set -x - python3 -m venv venv - source venv/bin/activate - pip install -r docs/requirements.txt + sudo apt-get update + sudo apt-get install -y graphviz + cargo install cargo-depgraph --version ^1.6 --locked - name: Build docs run: | set -x - source venv/bin/activate cd docs - ./build.sh + uv run --package datafusion-docs ./build.sh - name: Copy & push the generated HTML run: | diff --git 
a/.github/workflows/docs_pr.yaml b/.github/workflows/docs_pr.yaml index 81eeb4039ba9..63b87c2e6dd9 100644 --- a/.github/workflows/docs_pr.yaml +++ b/.github/workflows/docs_pr.yaml @@ -40,24 +40,22 @@ jobs: name: Test doc build runs-on: ubuntu-latest steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 - - name: Setup Python - uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 - with: - python-version: "3.12" + - name: Setup uv + uses: astral-sh/setup-uv@eac588ad8def6316056a12d4907a9d4d84ff7a3b # v7.3.0 - name: Install doc dependencies + run: uv sync --package datafusion-docs + - name: Install dependency graph tooling run: | set -x - python3 -m venv venv - source venv/bin/activate - pip install -r docs/requirements.txt + sudo apt-get update + sudo apt-get install -y graphviz + cargo install cargo-depgraph --version ^1.6 --locked - name: Build docs html and check for warnings run: | set -x - source venv/bin/activate cd docs - ./build.sh # fails on errors - + uv run --package datafusion-docs ./build.sh # fails on errors diff --git a/.github/workflows/extended.yml b/.github/workflows/extended.yml index 01de0d5b77a7..8f8597554b98 100644 --- a/.github/workflows/extended.yml +++ b/.github/workflows/extended.yml @@ -66,10 +66,11 @@ jobs: # Check crate compiles and base cargo check passes linux-build-lib: name: linux build test - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a,cpu=8,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} # note: do not use amd/rust container to preserve disk space steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 + - uses: 
actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: ${{ github.event.inputs.pr_head_sha }} # will be empty if triggered by push submodules: true @@ -80,7 +81,9 @@ jobs: source $HOME/.cargo/env rustup toolchain install - name: Install Protobuf Compiler - run: sudo apt-get install -y protobuf-compiler + run: | + sudo apt-get update + sudo apt-get install -y protobuf-compiler - name: Prepare cargo build run: | cargo check --profile ci --all-targets @@ -90,10 +93,12 @@ jobs: linux-test-extended: name: cargo test 'extended_tests' (amd64) needs: [linux-build-lib] - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a,cpu=32,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion,spot=false', github.run_id) || 'ubuntu-latest' }} + # spot=false because the tests are long, https://runs-on.com/configuration/spot-instances/#disable-spot-pricing # note: do not use amd/rust container to preserve disk space steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: ${{ github.event.inputs.pr_head_sha }} # will be empty if triggered by push submodules: true @@ -106,7 +111,9 @@ jobs: source $HOME/.cargo/env rustup toolchain install - name: Install Protobuf Compiler - run: sudo apt-get install -y protobuf-compiler + run: | + sudo apt-get update + sudo apt-get install -y protobuf-compiler # For debugging, test binaries can be large. 
- name: Show available disk space run: | @@ -133,11 +140,12 @@ jobs: # Check answers are correct when hash values collide hash-collisions: name: cargo test hash collisions (amd64) - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: ${{ github.event.inputs.pr_head_sha }} # will be empty if triggered by push submodules: true @@ -154,19 +162,21 @@ jobs: sqllogictest-sqlite: name: "Run sqllogictests with the sqlite test suite" - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a,cpu=48,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion,spot=false', github.run_id) || 'ubuntu-latest' }} + # spot=false because the tests are long, https://runs-on.com/configuration/spot-instances/#disable-spot-pricing container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: ${{ github.event.inputs.pr_head_sha }} # will be empty if triggered by push submodules: true fetch-depth: 1 - - name: Setup Rust toolchain - uses: ./.github/actions/setup-builder - with: - rust-version: stable + # Don't use setup-builder to avoid configuring RUST_BACKTRACE which is expensive + - name: Install protobuf compiler + run: | + apt-get update && apt-get install -y protobuf-compiler - name: Run sqllogictest run: | cargo test --features backtrace,parquet_encryption --profile release-nonlto 
--test sqllogictests -- --include-sqlite diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index 01e21115010f..06c58cd802e5 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -39,7 +39,7 @@ jobs: contents: read pull-requests: write steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Assign GitHub labels if: | diff --git a/.github/workflows/labeler/labeler-config.yml b/.github/workflows/labeler/labeler-config.yml index 38d88059dab7..0e492b6f3f6d 100644 --- a/.github/workflows/labeler/labeler-config.yml +++ b/.github/workflows/labeler/labeler-config.yml @@ -62,7 +62,7 @@ datasource: functions: - changed-files: - - any-glob-to-any-file: ['datafusion/functions/**/*', 'datafusion/functions-aggregate/**/*', 'datafusion/functions-aggregate-common', 'datafusion/functions-nested', 'datafusion/functions-table/**/*', 'datafusion/functions-window/**/*', 'datafusion/functions-window-common/**/*'] + - any-glob-to-any-file: ['datafusion/functions/**/*', 'datafusion/functions-aggregate/**/*', 'datafusion/functions-aggregate-common/**/*', 'datafusion/functions-nested/**/*', 'datafusion/functions-table/**/*', 'datafusion/functions-window/**/*', 'datafusion/functions-window-common/**/*'] optimizer: diff --git a/.github/workflows/large_files.yml b/.github/workflows/large_files.yml index b96b8cd4544e..12b7bae76ab3 100644 --- a/.github/workflows/large_files.yml +++ b/.github/workflows/large_files.yml @@ -29,7 +29,7 @@ jobs: check-files: runs-on: ubuntu-latest steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: fetch-depth: 0 - name: Check size of new Git objects diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index fd00e25c03a5..6194262e40f3 100644 --- a/.github/workflows/rust.yml 
+++ b/.github/workflows/rust.yml @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. +# For some actions, we use Runs-On to run them on ASF infrastructure: https://datafusion.apache.org/contributor-guide/#ci-runners + name: Rust concurrency: @@ -45,11 +47,12 @@ jobs: # Check crate compiles and base cargo check passes linux-build-lib: name: linux build test - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a,cpu=8,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -77,7 +80,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -98,11 +101,11 @@ jobs: linux-datafusion-substrait-features: name: cargo check datafusion-substrait features needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -135,11 +138,12 @@ jobs: linux-datafusion-proto-features: name: cargo check datafusion-proto features needs: linux-build-lib - 
runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -166,11 +170,12 @@ jobs: linux-cargo-check-datafusion: name: cargo check datafusion features needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -235,7 +240,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -266,13 +271,14 @@ jobs: linux-test: name: cargo test (amd64) needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust volumes: - /usr/local:/host/usr/local steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: 
runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -316,9 +322,10 @@ jobs: linux-test-datafusion-cli: name: cargo test datafusion-cli (amd64) needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -345,11 +352,12 @@ jobs: linux-test-example: name: cargo examples (amd64) needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -375,11 +383,12 @@ jobs: linux-test-doc: name: cargo test doc (amd64) needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # 
v6.0.2 with: submodules: true fetch-depth: 1 @@ -396,11 +405,12 @@ jobs: linux-rustdoc: name: cargo doc needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -412,7 +422,7 @@ jobs: name: build and run with wasm-pack runs-on: ubuntu-24.04 steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup for wasm32 run: | rustup target add wasm32-unknown-unknown @@ -421,7 +431,7 @@ jobs: sudo apt-get update -qq sudo apt-get install -y -qq clang - name: Setup wasm-pack - uses: taiki-e/install-action@0e76c5c569f13f7eb21e8e5b26fe710062b57b62 # v2.65.13 + uses: taiki-e/install-action@cfdb446e391c69574ebc316dfb7d7849ec12b940 # v2.68.8 with: tool: wasm-pack - name: Run tests with headless mode @@ -436,11 +446,12 @@ jobs: verify-benchmark-results: name: verify benchmark results (amd64) needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -469,7 +480,7 @@ 
jobs: sqllogictest-postgres: name: "Run sqllogictest with Postgres runner" needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust services: @@ -487,7 +498,8 @@ jobs: --health-timeout 5s --health-retries 5 steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -507,11 +519,12 @@ jobs: sqllogictest-substrait: name: "Run sqllogictest in Substrait round-trip mode" needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -547,9 +560,9 @@ jobs: macos-aarch64: name: cargo test (macos-aarch64) - runs-on: macos-14 + runs-on: macos-15 steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -565,7 +578,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -582,7 +595,7 @@ jobs: container: image: 
amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -637,11 +650,12 @@ jobs: clippy: name: clippy needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -666,7 +680,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -683,11 +697,12 @@ jobs: config-docs-check: name: check configs.md and ***_functions.md is up-to-date needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -695,7 +710,7 @@ jobs: uses: ./.github/actions/setup-builder with: rust-version: stable - - uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # v6.1.0 + - uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # v6.2.0 
with: node-version: "20" - name: Check if configs.md has been modified @@ -709,6 +724,11 @@ jobs: ./dev/update_function_docs.sh git diff --exit-code +# This job ensures `datafusion-examples/README.md` stays in sync with the source code: +# 1. Generates README automatically using the Rust examples docs generator +# (parsing documentation from `examples//main.rs`) +# 2. Formats the generated Markdown using DataFusion's standard Prettier setup +# 3. Compares the result against the committed README.md and fails if out-of-date examples-docs-check: name: check example README is up-to-date needs: linux-build-lib @@ -717,10 +737,20 @@ jobs: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 + + - name: Mark repository as safe for git + # Required for git commands inside container (avoids "dubious ownership" error) + run: git config --global --add safe.directory "$GITHUB_WORKSPACE" + + - name: Set up Node.js (required for prettier) + # doc_prettier_check.sh uses npx to run prettier for Markdown formatting + uses: actions/setup-node@v6 + with: + node-version: '18' - name: Run examples docs check script run: | @@ -737,11 +767,11 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - name: Install cargo-msrv - uses: taiki-e/install-action@0e76c5c569f13f7eb21e8e5b26fe710062b57b62 # v2.65.13 + uses: taiki-e/install-action@cfdb446e391c69574ebc316dfb7d7849ec12b940 # v2.68.8 with: tool: cargo-msrv @@ -778,4 +808,4 @@ jobs: run: cargo msrv --output-format json --log-target stdout verify - name: Check datafusion-proto working-directory: datafusion/proto - run: cargo msrv --output-format json --log-target stdout 
verify \ No newline at end of file + run: cargo msrv --output-format json --log-target stdout verify diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 2aba1085b832..ec7f54ec24db 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -27,7 +27,7 @@ jobs: issues: write pull-requests: write steps: - - uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # v10.1.1 + - uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10.2.0 with: stale-pr-message: "Thank you for your contribution. Unfortunately, this pull request is stale because it has been open 60 days with no activity. Please remove the stale label or comment or this will be closed in 7 days." days-before-pr-stale: 60 diff --git a/Cargo.lock b/Cargo.lock index 8dcfbc65c21b..5a85136554ce 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -14,7 +14,7 @@ dependencies = [ "core_extensions", "crossbeam-channel", "generational-arena", - "libloading 0.7.4", + "libloading", "lock_api", "parking_lot", "paste", @@ -56,17 +56,6 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" -[[package]] -name = "ahash" -version = "0.7.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "891477e0c6a8957309ee5c45a6368af3ae14bb510732d2684ffa19af310920f9" -dependencies = [ - "getrandom 0.2.16", - "once_cell", - "version_check", -] - [[package]] name = "ahash" version = "0.8.12" @@ -83,9 +72,9 @@ dependencies = [ [[package]] name = "aho-corasick" -version = "1.1.3" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" dependencies = [ "memchr", ] @@ -137,12 +126,27 @@ checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" [[package]] 
name = "anstream" -version = "0.6.20" +version = "0.6.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43d5b281e737544384e969a5ccad3f1cdd24b48086a0fc1b2a5262a26b8f4f4a" +dependencies = [ + "anstyle", + "anstyle-parse 0.2.7", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstream" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ae563653d1938f79b1ab1b5e668c87c76a9930414574a6583a7b7e11a8e6192" +checksum = "824a212faf96e9acacdbd09febd34438f8f711fb84e09a8916013cd7815ca28d" dependencies = [ "anstyle", - "anstyle-parse", + "anstyle-parse 1.0.0", "anstyle-query", "anstyle-wincon", "colorchoice", @@ -152,9 +156,9 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.11" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "862ed96ca487e809f1c8e5a8447f6ee2cf102f846893800b20cebdf541fc6bbd" +checksum = "940b3a0ca603d1eade50a4846a2afffd5ef57a9feac2c0e2ec2e14f9ead76000" [[package]] name = "anstyle-parse" @@ -165,31 +169,40 @@ dependencies = [ "utf8parse", ] +[[package]] +name = "anstyle-parse" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52ce7f38b242319f7cabaa6813055467063ecdc9d355bbb4ce0c68908cd8130e" +dependencies = [ + "utf8parse", +] + [[package]] name = "anstyle-query" -version = "1.1.4" +version = "1.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e231f6134f61b71076a3eab506c379d4f36122f2af15a9ff04415ea4c3339e2" +checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" dependencies = [ - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] name = "anstyle-wincon" -version = "3.0.10" +version = "3.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e0633414522a32ffaac8ac6cc8f748e090c5717661fddeea04219e2344f5f2a" 
+checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" dependencies = [ "anstyle", "once_cell_polyfill", - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] name = "anyhow" -version = "1.0.100" +version = "1.0.102" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" +checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" [[package]] name = "apache-avro" @@ -220,6 +233,15 @@ dependencies = [ "zstd", ] +[[package]] +name = "ar_archive_writer" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7eb93bbb63b9c227414f6eb3a0adfddca591a8ce1e9b60661bb08969b87e340b" +dependencies = [ + "object", +] + [[package]] name = "arrayref" version = "0.3.9" @@ -234,9 +256,9 @@ checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "arrow" -version = "57.1.0" +version = "58.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb372a7cbcac02a35d3fb7b3fc1f969ec078e871f9bb899bf00a2e1809bec8a3" +checksum = "602268ce9f569f282cedb9a9f6bac569b680af47b9b077d515900c03c5d190da" dependencies = [ "arrow-arith", "arrow-array", @@ -257,9 +279,9 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "57.1.0" +version = "58.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f377dcd19e440174596d83deb49cd724886d91060c07fec4f67014ef9d54049" +checksum = "cd53c6bf277dea91f136ae8e3a5d7041b44b5e489e244e637d00ae302051f56f" dependencies = [ "arrow-array", "arrow-buffer", @@ -271,11 +293,11 @@ dependencies = [ [[package]] name = "arrow-array" -version = "57.1.0" +version = "58.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a23eaff85a44e9fa914660fb0d0bb00b79c4a3d888b5334adb3ea4330c84f002" +checksum = "e53796e07a6525edaf7dc28b540d477a934aff14af97967ad1d5550878969b9e" 
dependencies = [ - "ahash 0.8.12", + "ahash", "arrow-buffer", "arrow-data", "arrow-schema", @@ -290,9 +312,9 @@ dependencies = [ [[package]] name = "arrow-buffer" -version = "57.1.0" +version = "58.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2819d893750cb3380ab31ebdc8c68874dd4429f90fd09180f3c93538bd21626" +checksum = "f2c1a85bb2e94ee10b76531d8bc3ce9b7b4c0d508cabfb17d477f63f2617bd20" dependencies = [ "bytes", "half", @@ -302,9 +324,9 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "57.1.0" +version = "58.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3d131abb183f80c450d4591dc784f8d7750c50c6e2bc3fcaad148afc8361271" +checksum = "89fb245db6b0e234ed8e15b644edb8664673fefe630575e94e62cd9d489a8a26" dependencies = [ "arrow-array", "arrow-buffer", @@ -324,9 +346,9 @@ dependencies = [ [[package]] name = "arrow-csv" -version = "57.1.0" +version = "58.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2275877a0e5e7e7c76954669366c2aa1a829e340ab1f612e647507860906fb6b" +checksum = "d374882fb465a194462527c0c15a93aa19a554cf690a6b77a26b2a02539937a7" dependencies = [ "arrow-array", "arrow-cast", @@ -339,9 +361,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "57.1.0" +version = "58.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05738f3d42cb922b9096f7786f606fcb8669260c2640df8490533bb2fa38c9d3" +checksum = "189d210bc4244c715fa3ed9e6e22864673cccb73d5da28c2723fb2e527329b33" dependencies = [ "arrow-buffer", "arrow-schema", @@ -352,9 +374,9 @@ dependencies = [ [[package]] name = "arrow-flight" -version = "57.1.0" +version = "58.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b5f57c3d39d1b1b7c1376a772ea86a131e7da310aed54ebea9363124bb885e3" +checksum = "b4f5cdf00ee0003ba0768d3575d0afc47d736b29673b14c3c228fdffa9a3fb29" dependencies = [ "arrow-arith", "arrow-array", @@ -380,9 +402,9 @@ 
dependencies = [ [[package]] name = "arrow-ipc" -version = "57.1.0" +version = "58.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d09446e8076c4b3f235603d9ea7c5494e73d441b01cd61fb33d7254c11964b3" +checksum = "7968c2e5210c41f4909b2ef76f6e05e172b99021c2def5edf3cc48fdd39d1d6c" dependencies = [ "arrow-array", "arrow-buffer", @@ -396,9 +418,9 @@ dependencies = [ [[package]] name = "arrow-json" -version = "57.1.0" +version = "58.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "371ffd66fa77f71d7628c63f209c9ca5341081051aa32f9c8020feb0def787c0" +checksum = "92111dba5bf900f443488e01f00d8c4ddc2f47f5c50039d18120287b580baa22" dependencies = [ "arrow-array", "arrow-buffer", @@ -407,7 +429,7 @@ dependencies = [ "arrow-schema", "chrono", "half", - "indexmap 2.12.1", + "indexmap 2.13.0", "itoa", "lexical-core", "memchr", @@ -420,9 +442,9 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "57.1.0" +version = "58.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cbc94fc7adec5d1ba9e8cd1b1e8d6f72423b33fe978bf1f46d970fafab787521" +checksum = "211136cb253577ee1a6665f741a13136d4e563f64f5093ffd6fb837af90b9495" dependencies = [ "arrow-array", "arrow-buffer", @@ -433,9 +455,9 @@ dependencies = [ [[package]] name = "arrow-row" -version = "57.1.0" +version = "58.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "169676f317157dc079cc5def6354d16db63d8861d61046d2f3883268ced6f99f" +checksum = "8e0f20145f9f5ea3fe383e2ba7a7487bf19be36aa9dbf5dd6a1f92f657179663" dependencies = [ "arrow-array", "arrow-buffer", @@ -446,9 +468,9 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "57.1.0" +version = "58.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d27609cd7dd45f006abae27995c2729ef6f4b9361cde1ddd019dc31a5aa017e0" +checksum = "1b47e0ca91cc438d2c7879fe95e0bca5329fff28649e30a88c6f760b1faeddcb" dependencies = [ 
"bitflags", "serde", @@ -458,11 +480,11 @@ dependencies = [ [[package]] name = "arrow-select" -version = "57.1.0" +version = "58.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae980d021879ea119dd6e2a13912d81e64abed372d53163e804dfe84639d8010" +checksum = "750a7d1dda177735f5e82a314485b6915c7cccdbb278262ac44090f4aba4a325" dependencies = [ - "ahash 0.8.12", + "ahash", "arrow-array", "arrow-buffer", "arrow-data", @@ -472,9 +494,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "57.1.0" +version = "58.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf35e8ef49dcf0c5f6d175edee6b8af7b45611805333129c541a8b89a0fc0534" +checksum = "e1eab1208bc4fe55d768cdc9b9f3d9df5a794cdb3ee2586bf89f9b30dc31ad8c" dependencies = [ "arrow-array", "arrow-buffer", @@ -501,9 +523,9 @@ dependencies = [ [[package]] name = "astral-tokio-tar" -version = "0.5.6" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec179a06c1769b1e42e1e2cbe74c7dcdb3d6383c838454d063eaac5bbb7ebbe5" +checksum = "3c23f3af104b40a3430ccb90ed5f7bd877a8dc5c26fc92fde51a22b40890dcf9" dependencies = [ "filetime", "futures-core", @@ -517,13 +539,12 @@ dependencies = [ [[package]] name = "async-compression" -version = "0.4.35" +version = "0.4.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07a926debf178f2d355197f9caddb08e54a9329d44748034bba349c5848cb519" +checksum = "d0f9ee0f6e02ffd7ad5816e9464499fba7b3effd01123b515c41d1697c43dad1" dependencies = [ "compression-codecs", "compression-core", - "futures-core", "pin-project-lite", "tokio", ] @@ -545,7 +566,7 @@ checksum = "3b43422f69d8ff38f95f1b2bb76517c91589a924d1559a0e935d7c8ce0274c11" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -567,7 +588,7 @@ checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" dependencies = [ "proc-macro2", "quote", - "syn 
2.0.113", + "syn 2.0.117", ] [[package]] @@ -578,7 +599,7 @@ checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -604,9 +625,9 @@ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" [[package]] name = "aws-config" -version = "1.8.12" +version = "1.8.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96571e6996817bf3d58f6b569e4b9fd2e9d2fcf9f7424eed07b2ce9bb87535e5" +checksum = "11493b0bad143270fb8ad284a096dd529ba91924c5409adeac856cc1bf047dbc" dependencies = [ "aws-credential-types", "aws-runtime", @@ -623,8 +644,8 @@ dependencies = [ "bytes", "fastrand", "hex", - "http 1.3.1", - "ring", + "http 1.4.0", + "sha1", "time", "tokio", "tracing", @@ -634,9 +655,9 @@ dependencies = [ [[package]] name = "aws-credential-types" -version = "1.2.11" +version = "1.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3cd362783681b15d136480ad555a099e82ecd8e2d10a841e14dfd0078d67fee3" +checksum = "8f20799b373a1be121fe3005fba0c2090af9411573878f224df44b42727fcaf7" dependencies = [ "aws-smithy-async", "aws-smithy-runtime-api", @@ -646,9 +667,9 @@ dependencies = [ [[package]] name = "aws-lc-rs" -version = "1.14.0" +version = "1.16.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94b8ff6c09cd57b16da53641caa860168b88c172a5ee163b0288d3d6eea12786" +checksum = "a054912289d18629dc78375ba2c3726a3afe3ff71b4edba9dedfca0e3446d1fc" dependencies = [ "aws-lc-sys", "zeroize", @@ -656,11 +677,10 @@ dependencies = [ [[package]] name = "aws-lc-sys" -version = "0.31.0" +version = "0.39.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e44d16778acaf6a9ec9899b92cebd65580b83f685446bf2e1f5d3d732f99dcd" +checksum = "83a25cf98105baa966497416dbd42565ce3a8cf8dbfd59803ec9ad46f3126399" dependencies = [ - "bindgen", "cc", "cmake", "dunce", @@ -669,9 
+689,9 @@ dependencies = [ [[package]] name = "aws-runtime" -version = "1.5.17" +version = "1.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d81b5b2898f6798ad58f484856768bca817e3cd9de0974c24ae0f1113fe88f1b" +checksum = "5fc0651c57e384202e47153c1260b84a9936e19803d747615edf199dc3b98d17" dependencies = [ "aws-credential-types", "aws-sigv4", @@ -682,9 +702,10 @@ dependencies = [ "aws-smithy-types", "aws-types", "bytes", + "bytes-utils", "fastrand", - "http 0.2.12", - "http-body 0.4.6", + "http 1.4.0", + "http-body 1.0.1", "percent-encoding", "pin-project-lite", "tracing", @@ -693,15 +714,16 @@ dependencies = [ [[package]] name = "aws-sdk-sso" -version = "1.91.0" +version = "1.96.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ee6402a36f27b52fe67661c6732d684b2635152b676aa2babbfb5204f99115d" +checksum = "f64a6eded248c6b453966e915d32aeddb48ea63ad17932682774eb026fbef5b1" dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", "aws-smithy-http", "aws-smithy-json", + "aws-smithy-observability", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", @@ -709,21 +731,23 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", + "http 1.4.0", "regex-lite", "tracing", ] [[package]] name = "aws-sdk-ssooidc" -version = "1.93.0" +version = "1.98.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a45a7f750bbd170ee3677671ad782d90b894548f4e4ae168302c57ec9de5cb3e" +checksum = "db96d720d3c622fcbe08bae1c4b04a72ce6257d8b0584cb5418da00ae20a344f" dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", "aws-smithy-http", "aws-smithy-json", + "aws-smithy-observability", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", @@ -731,21 +755,23 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", + "http 1.4.0", "regex-lite", "tracing", ] [[package]] name = "aws-sdk-sts" -version = "1.95.0" +version = "1.100.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "55542378e419558e6b1f398ca70adb0b2088077e79ad9f14eb09441f2f7b2164" +checksum = "fafbdda43b93f57f699c5dfe8328db590b967b8a820a13ccdd6687355dfcc7ca" dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", "aws-smithy-http", "aws-smithy-json", + "aws-smithy-observability", "aws-smithy-query", "aws-smithy-runtime", "aws-smithy-runtime-api", @@ -754,15 +780,16 @@ dependencies = [ "aws-types", "fastrand", "http 0.2.12", + "http 1.4.0", "regex-lite", "tracing", ] [[package]] name = "aws-sigv4" -version = "1.3.7" +version = "1.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69e523e1c4e8e7e8ff219d732988e22bfeae8a1cafdbe6d9eca1546fa080be7c" +checksum = "b0b660013a6683ab23797778e21f1f854744fdf05f68204b4cca4c8c04b5d1f4" dependencies = [ "aws-credential-types", "aws-smithy-http", @@ -773,7 +800,7 @@ dependencies = [ "hex", "hmac", "http 0.2.12", - "http 1.3.1", + "http 1.4.0", "percent-encoding", "sha2", "time", @@ -782,9 +809,9 @@ dependencies = [ [[package]] name = "aws-smithy-async" -version = "1.2.7" +version = "1.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ee19095c7c4dda59f1697d028ce704c24b2d33c6718790c7f1d5a3015b4107c" +checksum = "2ffcaf626bdda484571968400c326a244598634dc75fd451325a54ad1a59acfc" dependencies = [ "futures-util", "pin-project-lite", @@ -793,9 +820,9 @@ dependencies = [ [[package]] name = "aws-smithy-http" -version = "0.62.6" +version = "0.63.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "826141069295752372f8203c17f28e30c464d22899a43a0c9fd9c458d469c88b" +checksum = "ba1ab2dc1c2c3749ead27180d333c42f11be8b0e934058fb4b2258ee8dbe5231" dependencies = [ "aws-smithy-runtime-api", "aws-smithy-types", @@ -803,9 +830,9 @@ dependencies = [ "bytes-utils", "futures-core", "futures-util", - "http 0.2.12", - "http 1.3.1", - "http-body 0.4.6", + "http 1.4.0", + "http-body 1.0.1", 
+ "http-body-util", "percent-encoding", "pin-project-lite", "pin-utils", @@ -814,15 +841,15 @@ dependencies = [ [[package]] name = "aws-smithy-http-client" -version = "1.1.5" +version = "1.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59e62db736db19c488966c8d787f52e6270be565727236fd5579eaa301e7bc4a" +checksum = "6a2f165a7feee6f263028b899d0a181987f4fa7179a6411a32a439fba7c5f769" dependencies = [ "aws-smithy-async", "aws-smithy-runtime-api", "aws-smithy-types", "h2", - "http 1.3.1", + "http 1.4.0", "hyper", "hyper-rustls", "hyper-util", @@ -838,27 +865,27 @@ dependencies = [ [[package]] name = "aws-smithy-json" -version = "0.61.9" +version = "0.62.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49fa1213db31ac95288d981476f78d05d9cbb0353d22cdf3472cc05bb02f6551" +checksum = "9648b0bb82a2eedd844052c6ad2a1a822d1f8e3adee5fbf668366717e428856a" dependencies = [ "aws-smithy-types", ] [[package]] name = "aws-smithy-observability" -version = "0.1.5" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17f616c3f2260612fe44cede278bafa18e73e6479c4e393e2c4518cf2a9a228a" +checksum = "a06c2315d173edbf1920da8ba3a7189695827002e4c0fc961973ab1c54abca9c" dependencies = [ "aws-smithy-runtime-api", ] [[package]] name = "aws-smithy-query" -version = "0.60.9" +version = "0.60.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae5d689cf437eae90460e944a58b5668530d433b4ff85789e69d2f2a556e057d" +checksum = "1a56d79744fb3edb5d722ef79d86081e121d3b9422cb209eb03aea6aa4f21ebd" dependencies = [ "aws-smithy-types", "urlencoding", @@ -866,9 +893,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.9.6" +version = "1.10.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65fda37911905ea4d3141a01364bc5509a0f32ae3f3b22d6e330c0abfb62d247" +checksum = 
"028999056d2d2fd58a697232f9eec4a643cf73a71cf327690a7edad1d2af2110" dependencies = [ "aws-smithy-async", "aws-smithy-http", @@ -879,9 +906,10 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", - "http 1.3.1", + "http 1.4.0", "http-body 0.4.6", "http-body 1.0.1", + "http-body-util", "pin-project-lite", "pin-utils", "tokio", @@ -890,15 +918,15 @@ dependencies = [ [[package]] name = "aws-smithy-runtime-api" -version = "1.9.3" +version = "1.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab0d43d899f9e508300e587bf582ba54c27a452dd0a9ea294690669138ae14a2" +checksum = "876ab3c9c29791ba4ba02b780a3049e21ec63dabda09268b175272c3733a79e6" dependencies = [ "aws-smithy-async", "aws-smithy-types", "bytes", "http 0.2.12", - "http 1.3.1", + "http 1.4.0", "pin-project-lite", "tokio", "tracing", @@ -907,15 +935,15 @@ dependencies = [ [[package]] name = "aws-smithy-types" -version = "1.3.5" +version = "1.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "905cb13a9895626d49cf2ced759b062d913834c7482c38e49557eac4e6193f01" +checksum = "9d73dbfbaa8e4bc57b9045137680b958d274823509a360abfd8e1d514d40c95c" dependencies = [ "base64-simd", "bytes", "bytes-utils", "http 0.2.12", - "http 1.3.1", + "http 1.4.0", "http-body 0.4.6", "http-body 1.0.1", "http-body-util", @@ -930,18 +958,18 @@ dependencies = [ [[package]] name = "aws-smithy-xml" -version = "0.60.13" +version = "0.60.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11b2f670422ff42bf7065031e72b45bc52a3508bd089f743ea90731ca2b6ea57" +checksum = "0ce02add1aa3677d022f8adf81dcbe3046a95f17a1b1e8979c145cd21d3d22b3" dependencies = [ "xmlparser", ] [[package]] name = "aws-types" -version = "1.3.11" +version = "1.3.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d980627d2dd7bfc32a3c025685a033eeab8d365cc840c631ef59d1b8f428164" +checksum = "47c8323699dd9b3c8d5b3c13051ae9cdef58fd179957c882f8374dd8725962d9" 
dependencies = [ "aws-credential-types", "aws-smithy-async", @@ -953,14 +981,14 @@ dependencies = [ [[package]] name = "axum" -version = "0.8.4" +version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "021e862c184ae977658b36c4500f7feac3221ca5da43e3f25bd04ab6c79a29b5" +checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8" dependencies = [ "axum-core", "bytes", "futures-util", - "http 1.3.1", + "http 1.4.0", "http-body 1.0.1", "http-body-util", "itoa", @@ -969,8 +997,7 @@ dependencies = [ "mime", "percent-encoding", "pin-project-lite", - "rustversion", - "serde", + "serde_core", "sync_wrapper", "tower", "tower-layer", @@ -979,18 +1006,17 @@ dependencies = [ [[package]] name = "axum-core" -version = "0.5.2" +version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68464cd0412f486726fb3373129ef5d2993f90c34bc2bc1c1e9943b2f4fc7ca6" +checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1" dependencies = [ "bytes", "futures-core", - "http 1.3.1", + "http 1.4.0", "http-body 1.0.1", "http-body-util", "mime", "pin-project-lite", - "rustversion", "sync_wrapper", "tower-layer", "tower-service", @@ -1020,9 +1046,9 @@ dependencies = [ [[package]] name = "bigdecimal" -version = "0.4.9" +version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "560f42649de9fa436b73517378a147ec21f6c997a546581df4b4b31677828934" +checksum = "4d6867f1565b3aad85681f1015055b087fcfd840d6aeee6eee7f2da317603695" dependencies = [ "autocfg", "libm", @@ -1032,43 +1058,11 @@ dependencies = [ "serde", ] -[[package]] -name = "bindgen" -version = "0.72.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "993776b509cfb49c750f11b8f07a46fa23e0a1386ffc01fb1e7d343efc387895" -dependencies = [ - "bitflags", - "cexpr", - "clang-sys", - "itertools 0.13.0", - "log", - "prettyplease", - "proc-macro2", - "quote", - "regex", - 
"rustc-hash", - "shlex", - "syn 2.0.113", -] - [[package]] name = "bitflags" -version = "2.9.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2261d10cca569e4643e526d8dc2e62e433cc8aba21ab764233731f8d369bf394" - -[[package]] -name = "bitvec" -version = "1.0.1" +version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" -dependencies = [ - "funty", - "radium", - "tap", - "wyz", -] +checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" [[package]] name = "blake2" @@ -1081,15 +1075,16 @@ dependencies = [ [[package]] name = "blake3" -version = "1.8.2" +version = "1.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3888aaa89e4b2a40fca9848e400f6a658a5a3978de7be858e209cafa8be9a4a0" +checksum = "2468ef7d57b3fb7e16b576e8377cdbde2320c60e1491e961d11da40fc4f02a2d" dependencies = [ "arrayref", "arrayvec", "cc", "cfg-if", "constant_time_eq", + "cpufeatures", ] [[package]] @@ -1103,9 +1098,9 @@ dependencies = [ [[package]] name = "bollard" -version = "0.19.4" +version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87a52479c9237eb04047ddb94788c41ca0d26eaff8b697ecfbb4c32f7fdc3b1b" +checksum = "ee04c4c84f1f811b017f2fbb7dd8815c976e7ca98593de9c1e2afad0f636bff4" dependencies = [ "async-stream", "base64 0.22.1", @@ -1113,12 +1108,11 @@ dependencies = [ "bollard-buildkit-proto", "bollard-stubs", "bytes", - "chrono", "futures-core", "futures-util", "hex", "home", - "http 1.3.1", + "http 1.4.0", "http-body-util", "hyper", "hyper-named-pipe", @@ -1131,14 +1125,13 @@ dependencies = [ "rand 0.9.2", "rustls", "rustls-native-certs", - "rustls-pemfile", "rustls-pki-types", "serde", "serde_derive", "serde_json", - "serde_repr", "serde_urlencoded", "thiserror", + "time", "tokio", "tokio-stream", "tokio-util", @@ -1163,26 +1156,25 @@ dependencies = [ [[package]] name = 
"bollard-stubs" -version = "1.49.1-rc.28.4.0" +version = "1.52.1-rc.29.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5731fe885755e92beff1950774068e0cae67ea6ec7587381536fca84f1779623" +checksum = "0f0a8ca8799131c1837d1282c3f81f31e76ceb0ce426e04a7fe1ccee3287c066" dependencies = [ "base64 0.22.1", "bollard-buildkit-proto", "bytes", - "chrono", "prost", "serde", "serde_json", "serde_repr", - "serde_with", + "time", ] [[package]] name = "bon" -version = "3.8.1" +version = "3.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ebeb9aaf9329dff6ceb65c689ca3db33dbf15f324909c60e4e5eef5701ce31b1" +checksum = "f47dbe92550676ee653353c310dfb9cf6ba17ee70396e1f7cf0a2020ad49b2fe" dependencies = [ "bon-macros", "rustversion", @@ -1190,9 +1182,9 @@ dependencies = [ [[package]] name = "bon-macros" -version = "3.8.1" +version = "3.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77e9d642a7e3a318e37c2c9427b5a6a48aa1ad55dcd986f3034ab2239045a645" +checksum = "519bd3116aeeb42d5372c29d982d16d0170d3d4a5ed85fc7dd91642ffff3c67c" dependencies = [ "darling", "ident_case", @@ -1200,30 +1192,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.113", -] - -[[package]] -name = "borsh" -version = "1.5.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad8646f98db542e39fc66e68a20b2144f6a732636df7c2354e74645faaa433ce" -dependencies = [ - "borsh-derive", - "cfg_aliases", -] - -[[package]] -name = "borsh-derive" -version = "1.5.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdd1d3c0c2f5833f22386f252fe8ed005c7f59fdcddeef025c01b4c3b9fd9ac3" -dependencies = [ - "once_cell", - "proc-macro-crate", - "proc-macro2", - "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -1249,9 +1218,9 @@ dependencies = [ [[package]] name = "bstr" -version = "1.12.0" +version = "1.12.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "234113d19d0d7d613b40e86fb654acf958910802bcceab913a4f9e7cda03b1a4" +checksum = "63044e1ae8e69f3b5a92c736ca6269b8d12fa7efe39bf34ddb06d102cf0e2cab" dependencies = [ "memchr", "serde", @@ -1259,31 +1228,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.19.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" - -[[package]] -name = "bytecheck" -version = "0.6.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23cdc57ce23ac53c931e88a43d06d070a6fd142f2617be5855eb75efc9beb1c2" -dependencies = [ - "bytecheck_derive", - "ptr_meta", - "simdutf8", -] - -[[package]] -name = "bytecheck_derive" -version = "0.6.12" +version = "3.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3db406d29fbcd95542e92559bed4d8ad92636d1ca8b3b72ede10b4bcc010e659" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] +checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" [[package]] name = "byteorder" @@ -1293,9 +1240,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.11.0" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" [[package]] name = "bytes-utils" @@ -1324,9 +1271,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.38" +version = "1.2.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80f41ae168f955c12fb8960b057d70d0ca153fb83182b57d86380443527be7e9" +checksum = "7a0dd1ca384932ff3641c8718a02769f1698e7563dc6974ffd03346116310423" dependencies = [ 
"find-msvc-tools", "jobserver", @@ -1334,20 +1281,11 @@ dependencies = [ "shlex", ] -[[package]] -name = "cexpr" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" -dependencies = [ - "nom", -] - [[package]] name = "cfg-if" -version = "1.0.3" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" [[package]] name = "cfg_aliases" @@ -1357,16 +1295,16 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] name = "chrono" -version = "0.4.42" +version = "0.4.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "145052bdd345b87320e369255277e3fb5152762ad123a901ef5c262dd38fe8d2" +checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0" dependencies = [ "iana-time-zone", "js-sys", "num-traits", "serde", "wasm-bindgen", - "windows-link 0.2.0", + "windows-link", ] [[package]] @@ -1406,22 +1344,11 @@ dependencies = [ "half", ] -[[package]] -name = "clang-sys" -version = "1.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" -dependencies = [ - "glob", - "libc", - "libloading 0.8.9", -] - [[package]] name = "clap" -version = "4.5.53" +version = "4.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9e340e012a1bf4935f5282ed1436d1489548e8f72308207ea5df0e23d2d03f8" +checksum = "b193af5b67834b676abd72466a96c1024e6a6ad978a1f484bd90b85c94041351" dependencies = [ "clap_builder", "clap_derive", @@ -1429,11 +1356,11 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.53" +version = "4.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "d76b5d13eaa18c901fd2f7fca939fefe3a0727a953561fefdf3b2922b8569d00" +checksum = "714a53001bf66416adb0e2ef5ac857140e7dc3a0c48fb28b2f10762fc4b5069f" dependencies = [ - "anstream", + "anstream 1.0.0", "anstyle", "clap_lex", "strsim", @@ -1441,21 +1368,21 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.49" +version = "4.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a0b5487afeab2deb2ff4e03a807ad1a03ac532ff5a2cee5d86884440c7f7671" +checksum = "1110bd8a634a1ab8cb04345d8d878267d57c3cf1b38d91b71af6686408bbca6a" dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] name = "clap_lex" -version = "0.7.5" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675" +checksum = "c8d4a3bb8b1e0c1050499d1815f5ab16d04f0959b233085fb31653fbfc9d98f9" [[package]] name = "clipboard-win" @@ -1468,35 +1395,34 @@ dependencies = [ [[package]] name = "cmake" -version = "0.1.54" +version = "0.1.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0" +checksum = "75443c44cd6b379beb8c5b45d85d0773baf31cce901fe7bb252f4eff3008ef7d" dependencies = [ "cc", ] [[package]] name = "colorchoice" -version = "1.0.4" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" +checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570" [[package]] name = "comfy-table" -version = "7.1.2" +version = "7.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0d05af1e006a2407bedef5af410552494ce5be9090444dbbcb57258c1af3d56" +checksum = "958c5d6ecf1f214b4c2bbbbf6ab9523a864bd136dcf71a7e8904799acfe1ad47" dependencies = [ - "strum 0.26.3", - "strum_macros 0.26.4", - 
"unicode-width 0.2.1", + "unicode-segmentation", + "unicode-width 0.2.2", ] [[package]] name = "compression-codecs" -version = "0.4.34" +version = "0.4.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34a3cbbb8b6eca96f3a5c4bf6938d5b27ced3675d69f95bb51948722870bc323" +checksum = "eb7b51a7d9c967fc26773061ba86150f19c50c0d65c887cb1fbe295fd16619b7" dependencies = [ "bzip2", "compression-core", @@ -1527,15 +1453,14 @@ dependencies = [ [[package]] name = "console" -version = "0.16.1" +version = "0.16.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b430743a6eb14e9764d4260d4c0d8123087d504eeb9c48f2b2a5e810dd369df4" +checksum = "d64e8af5551369d19cf50138de61f1c42074ab970f74e99be916646777f8fc87" dependencies = [ "encode_unicode", "libc", - "once_cell", - "unicode-width 0.2.1", - "windows-sys 0.61.0", + "unicode-width 0.2.2", + "windows-sys 0.61.2", ] [[package]] @@ -1563,7 +1488,7 @@ version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" dependencies = [ - "getrandom 0.2.16", + "getrandom 0.2.17", "once_cell", "tiny-keccak", ] @@ -1579,9 +1504,9 @@ dependencies = [ [[package]] name = "constant_time_eq" -version = "0.3.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" +checksum = "3d52eff69cd5e647efe296129160853a42795992097e8af39800e1060caeea9b" [[package]] name = "core-foundation" @@ -1634,9 +1559,9 @@ dependencies = [ [[package]] name = "criterion" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d883447757bb0ee46f233e9dc22eb84d93a9508c9b868687b274fc431d886bf" +checksum = "950046b2aa2492f9a536f5f4f9a3de7b9e2476e575e05bd6c333371add4d98f3" dependencies = [ "alloca", "anes", @@ -1661,9 +1586,9 @@ dependencies = [ [[package]] name 
= "criterion-plot" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed943f81ea2faa8dcecbbfa50164acf95d555afec96a27871663b300e387b2e4" +checksum = "d8d80a2f4f5b554395e47b5d8305bc3d27813bacb73493eb1001e8f76dae29ea" dependencies = [ "cast", "itertools 0.13.0", @@ -1711,9 +1636,9 @@ checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" [[package]] name = "crypto-common" -version = "0.1.6" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" dependencies = [ "generic-array", "typenum", @@ -1721,21 +1646,21 @@ dependencies = [ [[package]] name = "csv" -version = "1.3.1" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "acdc4883a9c96732e4733212c01447ebd805833b7275a73ca3ee080fd77afdaf" +checksum = "52cd9d68cf7efc6ddfaaee42e7288d3a99d613d4b50f76ce9827ae0c6e14f938" dependencies = [ "csv-core", "itoa", "ryu", - "serde", + "serde_core", ] [[package]] name = "csv-core" -version = "0.1.12" +version = "0.1.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d02f3b0da4c6504f86e9cd789d8dbafab48c2321be74e9987593de5a894d93d" +checksum = "704a3c26996a80471189265814dbc2c257598b96b8a7feae2d31ace646bb9782" dependencies = [ "memchr", ] @@ -1764,9 +1689,9 @@ checksum = "b365fabc795046672053e29c954733ec3b05e4be654ab130fe8f1f94d7051f35" [[package]] name = "darling" -version = "0.21.3" +version = "0.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cdf337090841a411e2a7f3deb9187445851f91b309c0c0a29e05f74a00a48c0" +checksum = "25ae13da2f202d56bd7f91c25fba009e7717a1e4a1cc98a76d844b65ae912e9d" dependencies = [ "darling_core", "darling_macro", @@ -1774,27 +1699,26 @@ dependencies = [ [[package]] name = 
"darling_core" -version = "0.21.3" +version = "0.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1247195ecd7e3c85f83c8d2a366e4210d588e802133e1e355180a9870b517ea4" +checksum = "9865a50f7c335f53564bb694ef660825eb8610e0a53d3e11bf1b0d3df31e03b0" dependencies = [ - "fnv", "ident_case", "proc-macro2", "quote", "strsim", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] name = "darling_macro" -version = "0.21.3" +version = "0.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d38308df82d1080de0afee5d069fa14b0326a88c14f15c5ccda35b4a6c414c81" +checksum = "ac3984ec7bd6cfa798e62b4a642426a5be0e68f9401cfc2a01e3fa9ea2fcdb8d" dependencies = [ "darling_core", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -1813,7 +1737,7 @@ dependencies = [ [[package]] name = "datafusion" -version = "51.0.0" +version = "53.1.0" dependencies = [ "arrow", "arrow-schema", @@ -1862,13 +1786,15 @@ dependencies = [ "itertools 0.14.0", "liblzma", "log", - "nix", + "nix 0.31.2", "object_store", "parking_lot", "parquet", "paste", + "pretty_assertions", "rand 0.9.2", "rand_distr", + "recursive", "regex", "rstest", "serde", @@ -1885,7 +1811,7 @@ dependencies = [ [[package]] name = "datafusion-benchmarks" -version = "51.0.0" +version = "53.1.0" dependencies = [ "arrow", "clap", @@ -1910,7 +1836,7 @@ dependencies = [ [[package]] name = "datafusion-catalog" -version = "51.0.0" +version = "53.1.0" dependencies = [ "arrow", "async-trait", @@ -1933,7 +1859,7 @@ dependencies = [ [[package]] name = "datafusion-catalog-listing" -version = "51.0.0" +version = "53.1.0" dependencies = [ "arrow", "async-trait", @@ -1955,7 +1881,7 @@ dependencies = [ [[package]] name = "datafusion-cli" -version = "51.0.0" +version = "53.1.0" dependencies = [ "arrow", "async-trait", @@ -1986,9 +1912,9 @@ dependencies = [ [[package]] name = "datafusion-common" -version = "51.0.0" +version = "53.1.0" dependencies = [ - "ahash 0.8.12", + "ahash", "apache-avro", 
"arrow", "arrow-ipc", @@ -1997,8 +1923,9 @@ dependencies = [ "half", "hashbrown 0.16.1", "hex", - "indexmap 2.12.1", + "indexmap 2.13.0", "insta", + "itertools 0.14.0", "libc", "log", "object_store", @@ -2013,7 +1940,7 @@ dependencies = [ [[package]] name = "datafusion-common-runtime" -version = "51.0.0" +version = "53.1.0" dependencies = [ "futures", "log", @@ -2022,7 +1949,7 @@ dependencies = [ [[package]] name = "datafusion-datasource" -version = "51.0.0" +version = "53.1.0" dependencies = [ "arrow", "async-compression", @@ -2057,7 +1984,7 @@ dependencies = [ [[package]] name = "datafusion-datasource-arrow" -version = "51.0.0" +version = "53.1.0" dependencies = [ "arrow", "arrow-ipc", @@ -2080,7 +2007,7 @@ dependencies = [ [[package]] name = "datafusion-datasource-avro" -version = "51.0.0" +version = "53.1.0" dependencies = [ "apache-avro", "arrow", @@ -2099,7 +2026,7 @@ dependencies = [ [[package]] name = "datafusion-datasource-csv" -version = "51.0.0" +version = "53.1.0" dependencies = [ "arrow", "async-trait", @@ -2120,7 +2047,7 @@ dependencies = [ [[package]] name = "datafusion-datasource-json" -version = "51.0.0" +version = "53.1.0" dependencies = [ "arrow", "async-trait", @@ -2135,23 +2062,27 @@ dependencies = [ "datafusion-session", "futures", "object_store", + "serde_json", "tokio", + "tokio-stream", ] [[package]] name = "datafusion-datasource-parquet" -version = "51.0.0" +version = "53.1.0" dependencies = [ "arrow", "async-trait", "bytes", "chrono", + "criterion", "datafusion-common", "datafusion-common-runtime", "datafusion-datasource", "datafusion-execution", "datafusion-expr", "datafusion-functions-aggregate-common", + "datafusion-functions-nested", "datafusion-physical-expr", "datafusion-physical-expr-adapter", "datafusion-physical-expr-common", @@ -2164,16 +2095,17 @@ dependencies = [ "object_store", "parking_lot", "parquet", + "tempfile", "tokio", ] [[package]] name = "datafusion-doc" -version = "51.0.0" +version = "53.1.0" [[package]] name = 
"datafusion-examples" -version = "51.0.0" +version = "53.1.0" dependencies = [ "arrow", "arrow-flight", @@ -2193,13 +2125,15 @@ dependencies = [ "insta", "log", "mimalloc", - "nix", + "nix 0.31.2", + "nom", "object_store", "prost", "rand 0.9.2", + "serde", "serde_json", - "strum 0.27.2", - "strum_macros 0.27.2", + "strum 0.28.0", + "strum_macros 0.28.0", "tempfile", "test-utils", "tokio", @@ -2212,14 +2146,16 @@ dependencies = [ [[package]] name = "datafusion-execution" -version = "51.0.0" +version = "53.1.0" dependencies = [ "arrow", + "arrow-buffer", "async-trait", "chrono", "dashmap", "datafusion-common", "datafusion-expr", + "datafusion-physical-expr-common", "futures", "insta", "log", @@ -2233,7 +2169,7 @@ dependencies = [ [[package]] name = "datafusion-expr" -version = "51.0.0" +version = "53.1.0" dependencies = [ "arrow", "async-trait", @@ -2246,7 +2182,7 @@ dependencies = [ "datafusion-functions-window-common", "datafusion-physical-expr-common", "env_logger", - "indexmap 2.12.1", + "indexmap 2.13.0", "insta", "itertools 0.14.0", "paste", @@ -2257,18 +2193,18 @@ dependencies = [ [[package]] name = "datafusion-expr-common" -version = "51.0.0" +version = "53.1.0" dependencies = [ "arrow", "datafusion-common", - "indexmap 2.12.1", + "indexmap 2.13.0", "itertools 0.14.0", "paste", ] [[package]] name = "datafusion-ffi" -version = "51.0.0" +version = "53.1.0" dependencies = [ "abi_stable", "arrow", @@ -2302,7 +2238,7 @@ dependencies = [ [[package]] name = "datafusion-functions" -version = "51.0.0" +version = "53.1.0" dependencies = [ "arrow", "arrow-buffer", @@ -2324,6 +2260,7 @@ dependencies = [ "itertools 0.14.0", "log", "md-5", + "memchr", "num-traits", "rand 0.9.2", "regex", @@ -2335,9 +2272,9 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate" -version = "51.0.0" +version = "53.1.0" dependencies = [ - "ahash 0.8.12", + "ahash", "arrow", "criterion", "datafusion-common", @@ -2350,15 +2287,16 @@ dependencies = [ 
"datafusion-physical-expr-common", "half", "log", + "num-traits", "paste", "rand 0.9.2", ] [[package]] name = "datafusion-functions-aggregate-common" -version = "51.0.0" +version = "53.1.0" dependencies = [ - "ahash 0.8.12", + "ahash", "arrow", "criterion", "datafusion-common", @@ -2369,7 +2307,7 @@ dependencies = [ [[package]] name = "datafusion-functions-nested" -version = "51.0.0" +version = "53.1.0" dependencies = [ "arrow", "arrow-ord", @@ -2384,7 +2322,9 @@ dependencies = [ "datafusion-functions-aggregate-common", "datafusion-macros", "datafusion-physical-expr-common", + "hashbrown 0.16.1", "itertools 0.14.0", + "itoa", "log", "paste", "rand 0.9.2", @@ -2392,7 +2332,7 @@ dependencies = [ [[package]] name = "datafusion-functions-table" -version = "51.0.0" +version = "53.1.0" dependencies = [ "arrow", "async-trait", @@ -2406,9 +2346,10 @@ dependencies = [ [[package]] name = "datafusion-functions-window" -version = "51.0.0" +version = "53.1.0" dependencies = [ "arrow", + "criterion", "datafusion-common", "datafusion-doc", "datafusion-expr", @@ -2422,7 +2363,7 @@ dependencies = [ [[package]] name = "datafusion-functions-window-common" -version = "51.0.0" +version = "53.1.0" dependencies = [ "datafusion-common", "datafusion-physical-expr-common", @@ -2430,16 +2371,16 @@ dependencies = [ [[package]] name = "datafusion-macros" -version = "51.0.0" +version = "53.1.0" dependencies = [ "datafusion-doc", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] name = "datafusion-optimizer" -version = "51.0.0" +version = "53.1.0" dependencies = [ "arrow", "async-trait", @@ -2455,7 +2396,7 @@ dependencies = [ "datafusion-physical-expr", "datafusion-sql", "env_logger", - "indexmap 2.12.1", + "indexmap 2.13.0", "insta", "itertools 0.14.0", "log", @@ -2466,9 +2407,9 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" -version = "51.0.0" +version = "53.1.0" dependencies = [ - "ahash 0.8.12", + "ahash", "arrow", "criterion", "datafusion-common", @@ -2479,12 
+2420,12 @@ dependencies = [ "datafusion-physical-expr-common", "half", "hashbrown 0.16.1", - "indexmap 2.12.1", + "indexmap 2.13.0", "insta", "itertools 0.14.0", "parking_lot", "paste", - "petgraph 0.8.3", + "petgraph", "rand 0.9.2", "recursive", "rstest", @@ -2493,7 +2434,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr-adapter" -version = "51.0.0" +version = "53.1.0" dependencies = [ "arrow", "datafusion-common", @@ -2506,22 +2447,22 @@ dependencies = [ [[package]] name = "datafusion-physical-expr-common" -version = "51.0.0" +version = "53.1.0" dependencies = [ - "ahash 0.8.12", + "ahash", "arrow", "chrono", "datafusion-common", "datafusion-expr-common", "hashbrown 0.16.1", - "indexmap 2.12.1", + "indexmap 2.13.0", "itertools 0.14.0", "parking_lot", ] [[package]] name = "datafusion-physical-optimizer" -version = "51.0.0" +version = "53.1.0" dependencies = [ "arrow", "datafusion-common", @@ -2541,9 +2482,9 @@ dependencies = [ [[package]] name = "datafusion-physical-plan" -version = "51.0.0" +version = "53.1.0" dependencies = [ - "ahash 0.8.12", + "ahash", "arrow", "arrow-ord", "arrow-schema", @@ -2563,10 +2504,11 @@ dependencies = [ "futures", "half", "hashbrown 0.16.1", - "indexmap 2.12.1", + "indexmap 2.13.0", "insta", "itertools 0.14.0", "log", + "num-traits", "parking_lot", "pin-project-lite", "rand 0.9.2", @@ -2577,7 +2519,7 @@ dependencies = [ [[package]] name = "datafusion-proto" -version = "51.0.0" +version = "53.1.0" dependencies = [ "arrow", "async-trait", @@ -2604,9 +2546,10 @@ dependencies = [ "datafusion-proto-common", "doc-comment", "object_store", - "pbjson", + "pbjson 0.9.0", "pretty_assertions", "prost", + "rand 0.9.2", "serde", "serde_json", "tokio", @@ -2614,19 +2557,19 @@ dependencies = [ [[package]] name = "datafusion-proto-common" -version = "51.0.0" +version = "53.1.0" dependencies = [ "arrow", "datafusion-common", "doc-comment", - "pbjson", + "pbjson 0.9.0", "prost", "serde", ] [[package]] name = "datafusion-pruning" 
-version = "51.0.0" +version = "53.1.0" dependencies = [ "arrow", "datafusion-common", @@ -2644,7 +2587,7 @@ dependencies = [ [[package]] name = "datafusion-session" -version = "51.0.0" +version = "53.1.0" dependencies = [ "async-trait", "datafusion-common", @@ -2656,29 +2599,33 @@ dependencies = [ [[package]] name = "datafusion-spark" -version = "51.0.0" +version = "53.1.0" dependencies = [ "arrow", "bigdecimal", "chrono", "crc32fast", "criterion", + "datafusion", "datafusion-catalog", "datafusion-common", "datafusion-execution", "datafusion-expr", "datafusion-functions", + "datafusion-functions-aggregate", "datafusion-functions-nested", "log", "percent-encoding", "rand 0.9.2", + "serde_json", "sha1", + "sha2", "url", ] [[package]] name = "datafusion-sql" -version = "51.0.0" +version = "53.1.0" dependencies = [ "arrow", "bigdecimal", @@ -2691,7 +2638,7 @@ dependencies = [ "datafusion-functions-nested", "datafusion-functions-window", "env_logger", - "indexmap 2.12.1", + "indexmap 2.13.0", "insta", "itertools 0.14.0", "log", @@ -2704,7 +2651,7 @@ dependencies = [ [[package]] name = "datafusion-sqllogictest" -version = "51.0.0" +version = "53.1.0" dependencies = [ "arrow", "async-trait", @@ -2722,10 +2669,8 @@ dependencies = [ "itertools 0.14.0", "log", "object_store", - "postgres-protocol", "postgres-types", "regex", - "rust_decimal", "sqllogictest", "sqlparser", "tempfile", @@ -2737,7 +2682,7 @@ dependencies = [ [[package]] name = "datafusion-substrait" -version = "51.0.0" +version = "53.1.0" dependencies = [ "async-recursion", "async-trait", @@ -2754,13 +2699,13 @@ dependencies = [ "substrait", "tokio", "url", - "uuid", ] [[package]] name = "datafusion-wasmtest" -version = "51.0.0" +version = "53.1.0" dependencies = [ + "bytes", "chrono", "console_error_panic_hook", "datafusion", @@ -2770,6 +2715,7 @@ dependencies = [ "datafusion-optimizer", "datafusion-physical-plan", "datafusion-sql", + "futures", "getrandom 0.3.4", "object_store", "tokio", @@ -2780,12 +2726,12 
@@ dependencies = [ [[package]] name = "deranged" -version = "0.5.3" +version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d630bccd429a5bb5a64b5e94f693bfc48c9f8566418fda4c494cc94f911f87cc" +checksum = "7cd812cc2bc1d69d4764bd80df88b4317eaef9e773c75226407d9bc0876b211c" dependencies = [ "powerfmt", - "serde", + "serde_core", ] [[package]] @@ -2823,7 +2769,7 @@ dependencies = [ "libc", "option-ext", "redox_users", - "windows-sys 0.61.0", + "windows-sys 0.59.0", ] [[package]] @@ -2834,14 +2780,14 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] name = "doc-comment" -version = "0.3.3" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" +checksum = "780955b8b195a21ab8e4ac6b60dd1dbdcec1dc6c51c0617964b08c81785e12c9" [[package]] name = "docker_credential" @@ -2890,7 +2836,7 @@ dependencies = [ "enum-ordinalize", "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -2913,29 +2859,29 @@ checksum = "c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d" [[package]] name = "enum-ordinalize" -version = "4.3.0" +version = "4.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fea0dcfa4e54eeb516fe454635a95753ddd39acda650ce703031c6973e315dd5" +checksum = "4a1091a7bb1f8f2c4b28f1fe2cef4980ca2d410a3d727d67ecc3178c9b0800f0" dependencies = [ "enum-ordinalize-derive", ] [[package]] name = "enum-ordinalize-derive" -version = "4.3.1" +version = "4.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d28318a75d4aead5c4db25382e8ef717932d0346600cacae6357eb5941bc5ff" +checksum = "8ca9601fb2d62598ee17836250842873a413586e5d7ed88b356e38ddbb0ec631" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] 
name = "env_filter" -version = "0.1.3" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "186e05a59d4c50738528153b83b0b0194d3a29507dfec16eccd4b342903397d0" +checksum = "7a1c3cc8e57274ec99de65301228b537f1e4eedc1b8e0f9411c6caac8ae7308f" dependencies = [ "log", "regex", @@ -2943,11 +2889,11 @@ dependencies = [ [[package]] name = "env_logger" -version = "0.11.8" +version = "0.11.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13c863f0904021b108aa8b2f55046443e6b1ebde8fd4a15c399893aae4fa069f" +checksum = "b2daee4ea451f429a58296525ddf28b45a3b64f1acf6587e2067437bb11e218d" dependencies = [ - "anstream", + "anstream 0.6.21", "anstyle", "env_filter", "jiff", @@ -2967,7 +2913,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.61.0", + "windows-sys 0.59.0", ] [[package]] @@ -2989,7 +2935,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "de48cc4d1c1d97a20fd819def54b890cadde72ed3ad0c614822a0a433361be96" dependencies = [ "cfg-if", - "windows-sys 0.61.0", + "windows-sys 0.61.2", ] [[package]] @@ -3058,21 +3004,20 @@ dependencies = [ [[package]] name = "filetime" -version = "0.2.26" +version = "0.2.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc0505cd1b6fa6580283f6bdf70a73fcf4aba1184038c90902b92b3dd0df63ed" +checksum = "f98844151eee8917efc50bd9e8318cb963ae8b297431495d3f758616ea5c57db" dependencies = [ "cfg-if", "libc", "libredox", - "windows-sys 0.60.2", ] [[package]] name = "find-msvc-tools" -version = "0.1.2" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ced73b1dacfc750a6db6c0a0c3a3853c8b41997e2e2c563dc90804ae6867959" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" [[package]] name = "fixedbitset" @@ -3082,9 +3027,9 
@@ checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" [[package]] name = "flatbuffers" -version = "25.2.10" +version = "25.12.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1045398c1bfd89168b5fd3f1fc11f6e70b34f6f66300c87d44d3de849463abf1" +checksum = "35f6839d7b3b98adde531effaf34f0c2badc6f4735d26fe74709d8e513a96ef3" dependencies = [ "bitflags", "rustc_version", @@ -3092,13 +3037,13 @@ dependencies = [ [[package]] name = "flate2" -version = "1.1.5" +version = "1.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfe33edd8e85a12a67454e37f8c75e730830d83e313556ab9ebf9ee7fbeb3bfb" +checksum = "843fba2746e448b37e26a819579957415c8cef339bf08564fe8b7ddbd959573c" dependencies = [ "crc32fast", - "libz-rs-sys", "miniz_oxide", + "zlib-rs", ] [[package]] @@ -3130,9 +3075,9 @@ dependencies = [ [[package]] name = "fs-err" -version = "3.1.2" +version = "3.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44f150ffc8782f35521cec2b23727707cb4045706ba3c854e86bef66b3a8cdbd" +checksum = "73fde052dbfc920003cfd2c8e2c6e6d4cc7c1091538c3a24226cec0665ab08c0" dependencies = [ "autocfg", ] @@ -3143,17 +3088,11 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" -[[package]] -name = "funty" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" - [[package]] name = "futures" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +checksum = "8b147ee9d1f6d097cef9ce628cd2ee62288d963e16fb287bd9286455b241382d" dependencies = [ "futures-channel", "futures-core", @@ -3166,9 +3105,9 @@ dependencies = [ [[package]] name = 
"futures-channel" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" +checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d" dependencies = [ "futures-core", "futures-sink", @@ -3176,15 +3115,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" +checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" [[package]] name = "futures-executor" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +checksum = "baf29c38818342a3b26b5b923639e7b1f4a61fc5e76102d4b1981c6dc7a7579d" dependencies = [ "futures-core", "futures-task", @@ -3193,32 +3132,32 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" +checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718" [[package]] name = "futures-macro" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] name = "futures-sink" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" +checksum = 
"c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893" [[package]] name = "futures-task" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" +checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" [[package]] name = "futures-timer" @@ -3228,9 +3167,9 @@ checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" [[package]] name = "futures-util" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" dependencies = [ "futures-channel", "futures-core", @@ -3240,7 +3179,6 @@ dependencies = [ "futures-task", "memchr", "pin-project-lite", - "pin-utils", "slab", ] @@ -3248,7 +3186,7 @@ dependencies = [ name = "gen" version = "0.1.0" dependencies = [ - "pbjson-build", + "pbjson-build 0.9.0", "prost-build", ] @@ -3256,7 +3194,7 @@ dependencies = [ name = "gen-common" version = "0.1.0" dependencies = [ - "pbjson-build", + "pbjson-build 0.9.0", "prost-build", ] @@ -3281,14 +3219,14 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" dependencies = [ "cfg-if", "js-sys", "libc", - "wasi", + "wasi 0.11.1+wasi-snapshot-preview1", "wasm-bindgen", ] @@ -3301,11 +3239,24 @@ dependencies = [ "cfg-if", "js-sys", "libc", - "r-efi", + "r-efi 5.3.0", "wasip2", "wasm-bindgen", ] +[[package]] +name = "getrandom" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"0de51e6874e94e7bf76d726fc5d13ba782deca734ff60d5bb2fb2607c7406555" +dependencies = [ + "cfg-if", + "libc", + "r-efi 6.0.0", + "wasip2", + "wasip3", +] + [[package]] name = "glob" version = "0.3.3" @@ -3314,9 +3265,9 @@ checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" [[package]] name = "globset" -version = "0.4.16" +version = "0.4.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54a1028dfc5f5df5da8a56a73e6c153c9a9708ec57232470703592a3f18e49f5" +checksum = "52dfc19153a48bde0cbd630453615c8151bce3a5adfac7a0aebfbf0a1e1f57e3" dependencies = [ "aho-corasick", "bstr", @@ -3327,17 +3278,17 @@ dependencies = [ [[package]] name = "h2" -version = "0.4.12" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3c0b69cfcb4e1b9f1bf2f53f95f766e4661169728ec61cd3fe5a0166f2d1386" +checksum = "2f44da3a8150a6703ed5d34e164b875fd14c2cdab9af1252a9a1020bde2bdc54" dependencies = [ "atomic-waker", "bytes", "fnv", "futures-core", "futures-sink", - "http 1.3.1", - "indexmap 2.12.1", + "http 1.4.0", + "indexmap 2.13.0", "slab", "tokio", "tokio-util", @@ -3353,6 +3304,8 @@ dependencies = [ "cfg-if", "crunchy", "num-traits", + "rand 0.9.2", + "rand_distr", "zerocopy", ] @@ -3361,9 +3314,6 @@ name = "hashbrown" version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" -dependencies = [ - "ahash 0.7.8", -] [[package]] name = "hashbrown" @@ -3377,8 +3327,6 @@ version = "0.15.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" dependencies = [ - "allocator-api2", - "equivalent", "foldhash 0.1.5", ] @@ -3416,11 +3364,11 @@ dependencies = [ [[package]] name = "home" -version = "0.5.11" +version = "0.5.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf" +checksum = "cc627f471c528ff0c4a49e1d5e60450c8f6461dd6d10ba9dcd3a61d3dff7728d" dependencies = [ - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -3436,12 +3384,11 @@ dependencies = [ [[package]] name = "http" -version = "1.3.1" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565" +checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" dependencies = [ "bytes", - "fnv", "itoa", ] @@ -3463,7 +3410,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", - "http 1.3.1", + "http 1.4.0", ] [[package]] @@ -3474,7 +3421,7 @@ checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" dependencies = [ "bytes", "futures-core", - "http 1.3.1", + "http 1.4.0", "http-body 1.0.1", "pin-project-lite", ] @@ -3499,16 +3446,16 @@ checksum = "135b12329e5e3ce057a9f972339ea52bc954fe1e9358ef27f95e89716fbc5424" [[package]] name = "hyper" -version = "1.7.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb3aa54a13a0dfe7fbe3a59e0c76093041720fdc77b110cc0fc260fafb4dc51e" +checksum = "2ab2d4f250c3d7b1c9fcdff1cece94ea4e2dfbec68614f7b87cb205f24ca9d11" dependencies = [ "atomic-waker", "bytes", "futures-channel", "futures-core", "h2", - "http 1.3.1", + "http 1.4.0", "http-body 1.0.1", "httparse", "httpdate", @@ -3541,7 +3488,7 @@ version = "0.27.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" dependencies = [ - "http 1.3.1", + "http 1.4.0", "hyper", "hyper-util", "rustls", @@ -3567,16 +3514,15 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.17" +version = "0.1.20" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c6995591a8f1380fcb4ba966a252a4b29188d51d2b89e3a252f5305be65aea8" +checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0" dependencies = [ "base64 0.22.1", "bytes", "futures-channel", - "futures-core", "futures-util", - "http 1.3.1", + "http 1.4.0", "http-body 1.0.1", "hyper", "ipnet", @@ -3606,9 +3552,9 @@ dependencies = [ [[package]] name = "iana-time-zone" -version = "0.1.64" +version = "0.1.65" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33e57f83510bb73707521ebaffa789ec8caf86f9657cad665b092b581d40e9fb" +checksum = "e31bc9ad994ba00e440a8aa5c9ef0ec67d5cb5e5cb0cc7f8b744a35b389cc470" dependencies = [ "android_system_properties", "core-foundation-sys", @@ -3630,9 +3576,9 @@ dependencies = [ [[package]] name = "icu_collections" -version = "2.0.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "200072f5d0e3614556f94a9930d5dc3e0662a652823904c3a75dc3b0af7fee47" +checksum = "4c6b649701667bbe825c3b7e6388cb521c23d88644678e83c0c4d0a621a34b43" dependencies = [ "displaydoc", "potential_utf", @@ -3643,9 +3589,9 @@ dependencies = [ [[package]] name = "icu_locale_core" -version = "2.0.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0cde2700ccaed3872079a65fb1a78f6c0a36c91570f28755dda67bc8f7d9f00a" +checksum = "edba7861004dd3714265b4db54a3c390e880ab658fec5f7db895fae2046b5bb6" dependencies = [ "displaydoc", "litemap", @@ -3656,11 +3602,10 @@ dependencies = [ [[package]] name = "icu_normalizer" -version = "2.0.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "436880e8e18df4d7bbc06d58432329d6458cc84531f7ac5f024e93deadb37979" +checksum = "5f6c8828b67bf8908d82127b2054ea1b4427ff0230ee9141c54251934ab1b599" dependencies = [ - "displaydoc", "icu_collections", "icu_normalizer_data", "icu_properties", @@ -3671,42 
+3616,38 @@ dependencies = [ [[package]] name = "icu_normalizer_data" -version = "2.0.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00210d6893afc98edb752b664b8890f0ef174c8adbb8d0be9710fa66fbbf72d3" +checksum = "7aedcccd01fc5fe81e6b489c15b247b8b0690feb23304303a9e560f37efc560a" [[package]] name = "icu_properties" -version = "2.0.1" +version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "016c619c1eeb94efb86809b015c58f479963de65bdb6253345c1a1276f22e32b" +checksum = "020bfc02fe870ec3a66d93e677ccca0562506e5872c650f893269e08615d74ec" dependencies = [ - "displaydoc", "icu_collections", "icu_locale_core", "icu_properties_data", "icu_provider", - "potential_utf", "zerotrie", "zerovec", ] [[package]] name = "icu_properties_data" -version = "2.0.1" +version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "298459143998310acd25ffe6810ed544932242d3f07083eee1084d83a71bd632" +checksum = "616c294cf8d725c6afcd8f55abc17c56464ef6211f9ed59cccffe534129c77af" [[package]] name = "icu_provider" -version = "2.0.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03c80da27b5f4187909049ee2d72f276f0d9f99a42c306bd0131ecfe04d8e5af" +checksum = "85962cf0ce02e1e0a629cc34e7ca3e373ce20dda4c4d7294bbd0bf1fdb59e614" dependencies = [ "displaydoc", "icu_locale_core", - "stable_deref_trait", - "tinystr", "writeable", "yoke", "zerofrom", @@ -3714,6 +3655,12 @@ dependencies = [ "zerovec", ] +[[package]] +name = "id-arena" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" + [[package]] name = "ident_case" version = "1.0.1" @@ -3754,9 +3701,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.12.1" +version = "2.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"0ad4bb2b565bca0645f4d68c5c9af97fba094e9791da685bf83cb5f3ce74acf2" +checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" dependencies = [ "equivalent", "hashbrown 0.16.1", @@ -3766,22 +3713,22 @@ dependencies = [ [[package]] name = "indicatif" -version = "0.18.3" +version = "0.18.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9375e112e4b463ec1b1c6c011953545c65a30164fbab5b581df32b3abf0dcb88" +checksum = "25470f23803092da7d239834776d653104d551bc4d7eacaf31e6837854b8e9eb" dependencies = [ - "console 0.16.1", + "console 0.16.3", "portable-atomic", - "unicode-width 0.2.1", + "unicode-width 0.2.2", "unit-prefix", "web-time", ] [[package]] name = "insta" -version = "1.46.0" +version = "1.46.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b66886d14d18d420ab5052cbff544fc5d34d0b2cdd35eb5976aaa10a4a472e5" +checksum = "e82db8c87c7f1ccecb34ce0c24399b8a73081427f3c7c50a5d597925356115e4" dependencies = [ "console 0.15.11", "globset", @@ -3812,15 +3759,15 @@ checksum = "8bb03732005da905c88227371639bf1ad885cc712789c011c31c5fb3ab3ccf02" [[package]] name = "ipnet" -version = "2.11.0" +version = "2.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" +checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2" [[package]] name = "iri-string" -version = "0.7.8" +version = "0.7.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbc5ebe9c3a1a7a5127f920a418f7585e9e758e911d0466ed004f393b0e380b2" +checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a" dependencies = [ "memchr", "serde", @@ -3828,9 +3775,9 @@ dependencies = [ [[package]] name = "is_terminal_polyfill" -version = "1.70.1" +version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" +checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" [[package]] name = "itertools" @@ -3852,32 +3799,32 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.15" +version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" +checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" [[package]] name = "jiff" -version = "0.2.15" +version = "0.2.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be1f93b8b1eb69c77f24bbb0afdf66f54b632ee39af40ca21c4365a1d7347e49" +checksum = "1a3546dc96b6d42c5f24902af9e2538e82e39ad350b0c766eb3fbf2d8f3d8359" dependencies = [ "jiff-static", "log", "portable-atomic", "portable-atomic-util", - "serde", + "serde_core", ] [[package]] name = "jiff-static" -version = "0.2.15" +version = "0.2.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03343451ff899767262ec32146f6d559dd759fdadf42ff0e227c7c48f72594b4" +checksum = "2a8c8b344124222efd714b73bb41f8b5120b27a7cc1c75593a6ff768d9d05aa4" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -3892,9 +3839,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.83" +version = "0.3.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "464a3709c7f55f1f721e5389aa6ea4e3bc6aba669353300af094b29ffbdde1d8" +checksum = "b49715b7073f385ba4bc528e5747d02e66cb39c6146efb66b781f131f0fb399c" dependencies = [ "once_cell", "wasm-bindgen", @@ -3906,6 +3853,12 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +[[package]] +name = "leb128fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" + [[package]] name = "lexical-core" version = "1.0.6" @@ -3971,9 +3924,9 @@ checksum = "2c4a545a15244c7d945065b5d392b2d2d7f21526fba56ce51467b06ed445e8f7" [[package]] name = "libc" -version = "0.2.177" +version = "0.2.183" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2874a2af47a2325c2001a6e6fad9b16a53b802102b528163885171cf92b15976" +checksum = "b5b646652bf6661599e1da8901b3b9522896f01e736bad5f723fe7a3a27f899d" [[package]] name = "libloading" @@ -3985,30 +3938,20 @@ dependencies = [ "winapi", ] -[[package]] -name = "libloading" -version = "0.8.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55" -dependencies = [ - "cfg-if", - "windows-link 0.2.0", -] - [[package]] name = "liblzma" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73c36d08cad03a3fbe2c4e7bb3a9e84c57e4ee4135ed0b065cade3d98480c648" +checksum = "b6033b77c21d1f56deeae8014eb9fbe7bdf1765185a6c508b5ca82eeaed7f899" dependencies = [ "liblzma-sys", ] [[package]] name = "liblzma-sys" -version = "0.4.4" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01b9596486f6d60c3bbe644c0e1be1aa6ccc472ad630fe8927b456973d7cb736" +checksum = "9f2db66f3268487b5033077f266da6777d057949b8f93c8ad82e441df25e6186" dependencies = [ "cc", "libc", @@ -4017,9 +3960,9 @@ dependencies = [ [[package]] name = "libm" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" +checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981" [[package]] name = "libmimalloc-sys" @@ -4034,55 +3977,46 @@ dependencies = [ [[package]] name = "libredox" -version = "0.1.10" +version = "0.1.14" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "416f7e718bdb06000964960ffa43b4335ad4012ae8b99060261aa4a8088d5ccb" +checksum = "1744e39d1d6a9948f4f388969627434e31128196de472883b39f148769bfe30a" dependencies = [ "bitflags", "libc", - "redox_syscall", + "plain", + "redox_syscall 0.7.3", ] [[package]] name = "libtest-mimic" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5297962ef19edda4ce33aaa484386e0a5b3d7f2f4e037cbeee00503ef6b29d33" +checksum = "14e6ba06f0ade6e504aff834d7c34298e5155c6baca353cc6a4aaff2f9fd7f33" dependencies = [ - "anstream", + "anstream 1.0.0", "anstyle", "clap", "escape8259", ] -[[package]] -name = "libz-rs-sys" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "840db8cf39d9ec4dd794376f38acc40d0fc65eec2a8f484f7fd375b84602becd" -dependencies = [ - "zlib-rs", -] - [[package]] name = "linux-raw-sys" -version = "0.11.0" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" +checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53" [[package]] name = "litemap" -version = "0.8.0" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956" +checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77" [[package]] name = "lock_api" -version = "0.4.13" +version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96936507f153605bddfcda068dd804796c84324ed2510809e5b2a624c81da765" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" dependencies = [ - "autocfg", "scopeguard", ] @@ -4100,9 +4034,9 @@ checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" [[package]] name = "lz4_flex" -version = 
"0.12.0" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab6473172471198271ff72e9379150e9dfd70d8e533e0752a27e515b48dd375e" +checksum = "98c23545df7ecf1b16c303910a69b079e8e251d60f7dd2cc9b4177f2afaf1746" dependencies = [ "twox-hash", ] @@ -4125,9 +4059,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.7.5" +version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32a282da65faaf38286cf3be983213fcf1d2e2a58700e808f83f4ea9a4804bc0" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" [[package]] name = "mimalloc" @@ -4146,20 +4080,14 @@ checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" [[package]] name = "minicov" -version = "0.3.7" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f27fe9f1cc3c22e1687f9446c2083c4c5fc7f0bcf1c7a86bdbded14985895b4b" +checksum = "4869b6a491569605d66d3952bcdf03df789e5b536e5f0cf7758a7f08a55ae24d" dependencies = [ "cc", "walkdir", ] -[[package]] -name = "minimal-lexical" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" - [[package]] name = "miniz_oxide" version = "0.8.9" @@ -4172,13 +4100,13 @@ dependencies = [ [[package]] name = "mio" -version = "1.0.4" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78bed444cc8a2160f01cbcf811ef18cac863ad68ae8ca62092e8db51d51c761c" +checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc" dependencies = [ "libc", - "wasi", - "windows-sys 0.59.0", + "wasi 0.11.1+wasi-snapshot-preview1", + "windows-sys 0.61.2", ] [[package]] @@ -4208,32 +4136,43 @@ dependencies = [ "libc", ] +[[package]] +name = "nix" +version = "0.31.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"5d6d0705320c1e6ba1d912b5e37cf18071b6c2e9b7fa8215a1e8a7651966f5d3" +dependencies = [ + "bitflags", + "cfg-if", + "cfg_aliases", + "libc", +] + [[package]] name = "nom" -version = "7.1.3" +version = "8.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +checksum = "df9761775871bdef83bee530e60050f7e54b1105350d6884eb0fb4f46c2f9405" dependencies = [ "memchr", - "minimal-lexical", ] [[package]] name = "ntapi" -version = "0.4.1" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8a3895c6391c39d7fe7ebc444a87eb2991b2a0bc718fdabd071eec617fc68e4" +checksum = "c3b335231dfd352ffb0f8017f3b6027a4917f7df785ea2143d8af2adc66980ae" dependencies = [ "winapi", ] [[package]] name = "nu-ansi-term" -version = "0.50.1" +version = "0.50.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4a28e057d01f97e61255210fcff094d74ed0466038633e95017f5beb68e4399" +checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -4272,9 +4211,9 @@ dependencies = [ [[package]] name = "num-conv" -version = "0.1.0" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" +checksum = "cf97ec579c3c42f953ef76dbf8d55ac91fb219dde70e49aa4a6b7d74e9919050" [[package]] name = "num-integer" @@ -4319,28 +4258,46 @@ dependencies = [ [[package]] name = "objc2-core-foundation" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c10c2894a6fed806ade6027bcd50662746363a9589d3ec9d9bef30a4e4bc166" +checksum = "2a180dd8642fa45cdb7dd721cd4c11b1cadd4929ce112ebd8b9f5803cc79d536" dependencies = [ "bitflags", ] [[package]] name = "objc2-io-kit" -version = "0.3.1" +version = "0.3.2" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "71c1c64d6120e51cd86033f67176b1cb66780c2efe34dec55176f77befd93c0a" +checksum = "33fafba39597d6dc1fb709123dfa8289d39406734be322956a69f0931c73bb15" dependencies = [ "libc", "objc2-core-foundation", ] +[[package]] +name = "objc2-system-configuration" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7216bd11cbda54ccabcab84d523dc93b858ec75ecfb3a7d89513fa22464da396" +dependencies = [ + "objc2-core-foundation", +] + +[[package]] +name = "object" +version = "0.37.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff76201f031d8863c38aa7f905eca4f53abbfa15f609db4277d44cd8938f33fe" +dependencies = [ + "memchr", +] + [[package]] name = "object_store" -version = "0.12.4" +version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c1be0c6c22ec0817cdc77d3842f721a17fd30ab6965001415b5402a74e6b740" +checksum = "c2858065e55c148d294a9f3aae3b0fa9458edadb41a108397094566f4e3c0dfb" dependencies = [ "async-trait", "base64 0.22.1", @@ -4348,7 +4305,7 @@ dependencies = [ "chrono", "form_urlencoded", "futures", - "http 1.3.1", + "http 1.4.0", "http-body-util", "humantime", "hyper", @@ -4360,7 +4317,7 @@ dependencies = [ "rand 0.9.2", "reqwest", "ring", - "rustls-pemfile", + "rustls-pki-types", "serde", "serde_json", "serde_urlencoded", @@ -4375,15 +4332,15 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.21.3" +version = "1.21.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" [[package]] name = "once_cell_polyfill" -version = "1.70.1" +version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4895175b425cb1f87721b59f0f286c2092bd4af812243672510e1ac53e2e0ad" +checksum = 
"384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" [[package]] name = "oorandom" @@ -4393,9 +4350,9 @@ checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" [[package]] name = "openssl-probe" -version = "0.1.6" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" +checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" [[package]] name = "option-ext" @@ -4420,9 +4377,9 @@ checksum = "1a80800c0488c3a21695ea981a54918fbb37abf04f4d0720c453632255e2ff0e" [[package]] name = "owo-colors" -version = "4.2.2" +version = "4.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48dd4f4a2c8405440fd0462561f0e5806bd0f77e86f51c761481bdd4018b545e" +checksum = "d211803b9b6b570f68772237e415a029d5a50c65d382910b879fb19d3271f94d" [[package]] name = "page_size" @@ -4436,9 +4393,9 @@ dependencies = [ [[package]] name = "parking_lot" -version = "0.12.4" +version = "0.12.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70d58bf43669b5795d1576d0641cfb6fbb2057bf629506267a92807158584a13" +checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" dependencies = [ "lock_api", "parking_lot_core", @@ -4446,27 +4403,26 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.11" +version = "0.9.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc838d2a56b5b1a6c25f55575dfc605fabb63bb2365f6c2353ef9159aa69e4a5" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" dependencies = [ "cfg-if", "libc", - "redox_syscall", + "redox_syscall 0.5.18", "smallvec", - "windows-targets 0.52.6", + "windows-link", ] [[package]] name = "parquet" -version = "57.1.0" +version = "58.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"be3e4f6d320dd92bfa7d612e265d7d08bba0a240bab86af3425e1d255a511d89" +checksum = "3f491d0ef1b510194426ee67ddc18a9b747ef3c42050c19322a2cd2e1666c29b" dependencies = [ - "ahash 0.8.12", + "ahash", "arrow-array", "arrow-buffer", - "arrow-cast", "arrow-data", "arrow-ipc", "arrow-schema", @@ -4517,7 +4473,7 @@ dependencies = [ "regex", "regex-syntax", "structmeta", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -4536,6 +4492,16 @@ dependencies = [ "serde", ] +[[package]] +name = "pbjson" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8edd1efdd8ab23ba9cb9ace3d9987a72663d5d7c9f74fa00b51d6213645cf6c" +dependencies = [ + "base64 0.22.1", + "serde", +] + [[package]] name = "pbjson-build" version = "0.8.0" @@ -4548,6 +4514,18 @@ dependencies = [ "prost-types", ] +[[package]] +name = "pbjson-build" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ed4d5c6ae95e08ac768883c8401cf0e8deb4e6e1d6a4e1fd3d2ec4f0ec63200" +dependencies = [ + "heck", + "itertools 0.14.0", + "prost", + "prost-types", +] + [[package]] name = "pbjson-types" version = "0.8.0" @@ -4556,8 +4534,8 @@ checksum = "8e748e28374f10a330ee3bb9f29b828c0ac79831a32bab65015ad9b661ead526" dependencies = [ "bytes", "chrono", - "pbjson", - "pbjson-build", + "pbjson 0.8.0", + "pbjson-build 0.8.0", "prost", "prost-build", "serde", @@ -4569,16 +4547,6 @@ version = "2.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" -[[package]] -name = "petgraph" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" -dependencies = [ - "fixedbitset", - "indexmap 2.12.1", -] - [[package]] name = "petgraph" version = "0.8.3" @@ -4587,7 +4555,7 @@ checksum = "8701b58ea97060d5e5b155d383a69952a60943f0e6dfe30b04c287beb0b27455" dependencies = [ 
"fixedbitset", "hashbrown 0.15.5", - "indexmap 2.12.1", + "indexmap 2.13.0", "serde", ] @@ -4630,29 +4598,29 @@ dependencies = [ [[package]] name = "pin-project" -version = "1.1.10" +version = "1.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "677f1add503faace112b9f1373e43e9e054bfdd22ff1a63c1bc485eaec6a6a8a" +checksum = "f1749c7ed4bcaf4c3d0a3efc28538844fb29bcdd7d2b67b2be7e20ba861ff517" dependencies = [ "pin-project-internal", ] [[package]] name = "pin-project-internal" -version = "1.1.10" +version = "1.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861" +checksum = "d9b20ed30f105399776b9c883e68e536ef602a16ae6f596d2c473591d6ad64c6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] name = "pin-project-lite" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" +checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" [[package]] name = "pin-utils" @@ -4666,6 +4634,12 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" +[[package]] +name = "plain" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4596b6d070b27117e987119b4dac604f3c58cfb0b191112e24771b2faeac1a6" + [[package]] name = "plotters" version = "0.3.7" @@ -4696,15 +4670,15 @@ dependencies = [ [[package]] name = "portable-atomic" -version = "1.13.0" +version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f89776e4d69bb58bc6993e99ffa1d11f228b839984854c7daeb5d37f87cbe950" +checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" [[package]] name = "portable-atomic-util" 
-version = "0.2.4" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +checksum = "091397be61a01d4be58e7841595bd4bfedb15f1cd54977d79b8271e94ed799a3" dependencies = [ "portable-atomic", ] @@ -4718,14 +4692,14 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] name = "postgres-protocol" -version = "0.6.9" +version = "0.6.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbef655056b916eb868048276cfd5d6a7dea4f81560dfd047f97c8c6fe3fcfd4" +checksum = "3ee9dd5fe15055d2b6806f4736aa0c9637217074e224bbec46d4041b91bb9491" dependencies = [ "base64 0.22.1", "byteorder", @@ -4741,9 +4715,9 @@ dependencies = [ [[package]] name = "postgres-types" -version = "0.2.11" +version = "0.2.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef4605b7c057056dd35baeb6ac0c0338e4975b1f2bef0f65da953285eb007095" +checksum = "54b858f82211e84682fecd373f68e1ceae642d8d751a1ebd13f33de6257b3e20" dependencies = [ "bytes", "chrono", @@ -4754,9 +4728,9 @@ dependencies = [ [[package]] name = "potential_utf" -version = "0.1.3" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84df19adbe5b5a0782edcab45899906947ab039ccf4573713735ee7de1e6b08a" +checksum = "b73949432f5e2a09657003c25bca5e19a0e9c84f8058ca374f49e0ebe605af77" dependencies = [ "zerovec", ] @@ -4793,32 +4767,32 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" dependencies = [ "proc-macro2", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] name = "proc-macro-crate" -version = "3.4.0" +version = "3.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "219cb19e96be00ab2e37d6e299658a0cfa83e52429179969b0f0121b4ac46983" +checksum = 
"e67ba7e9b2b56446f1d419b1d807906278ffa1a658a8a5d8a39dcb1f5a78614f" dependencies = [ "toml_edit", ] [[package]] name = "proc-macro2" -version = "1.0.101" +version = "1.0.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" dependencies = [ "unicode-ident", ] [[package]] name = "prost" -version = "0.14.1" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7231bd9b3d3d33c86b58adbac74b5ec0ad9f496b19d22801d773636feaa95f3d" +checksum = "d2ea70524a2f82d518bce41317d0fae74151505651af45faf1ffbd6fd33f0568" dependencies = [ "bytes", "prost-derive", @@ -4826,42 +4800,41 @@ dependencies = [ [[package]] name = "prost-build" -version = "0.14.1" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac6c3320f9abac597dcbc668774ef006702672474aad53c6d596b62e487b40b1" +checksum = "343d3bd7056eda839b03204e68deff7d1b13aba7af2b2fd16890697274262ee7" dependencies = [ "heck", "itertools 0.14.0", "log", "multimap", - "once_cell", - "petgraph 0.7.1", + "petgraph", "prettyplease", "prost", "prost-types", "regex", - "syn 2.0.113", + "syn 2.0.117", "tempfile", ] [[package]] name = "prost-derive" -version = "0.14.1" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9120690fafc389a67ba3803df527d0ec9cbbc9cc45e4cc20b332996dfb672425" +checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b" dependencies = [ "anyhow", "itertools 0.14.0", "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] name = "prost-types" -version = "0.14.1" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9b4db3d6da204ed77bb26ba83b6122a73aeb2e87e25fbf7ad2e84c4ccbf8f72" +checksum = 
"8991c4cbdb8bc5b11f0b074ffe286c30e523de90fee5ba8132f1399f23cb3dd7" dependencies = [ "prost", ] @@ -4877,33 +4850,14 @@ dependencies = [ [[package]] name = "psm" -version = "0.1.26" +version = "0.1.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e944464ec8536cd1beb0bbfd96987eb5e3b72f2ecdafdc5c769a37f1fa2ae1f" +checksum = "3852766467df634d74f0b2d7819bf8dc483a0eb2e3b0f50f756f9cfe8b0d18d8" dependencies = [ + "ar_archive_writer", "cc", ] -[[package]] -name = "ptr_meta" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0738ccf7ea06b608c10564b31debd4f5bc5e197fc8bfe088f68ae5ce81e7a4f1" -dependencies = [ - "ptr_meta_derive", -] - -[[package]] -name = "ptr_meta_derive" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16b845dbfca988fa33db069c0e230574d15a3088f147a87b64c7589eb662c9ac" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "quad-rand" version = "0.2.3" @@ -4912,9 +4866,9 @@ checksum = "5a651516ddc9168ebd67b24afd085a718be02f8858fe406591b013d101ce2f40" [[package]] name = "quick-xml" -version = "0.38.3" +version = "0.38.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42a232e7487fc2ef313d96dde7948e7a3c05101870d8985e4fd8d26aedd27b89" +checksum = "b66c2058c55a409d601666cffe35f04333cf1013010882cec174a7467cd4e21c" dependencies = [ "memchr", "serde", @@ -4942,9 +4896,9 @@ dependencies = [ [[package]] name = "quinn-proto" -version = "0.11.13" +version = "0.11.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31" +checksum = "434b42fec591c96ef50e21e886936e66d3cc3f737104fdb9b737c40ffb94c098" dependencies = [ "bytes", "getrandom 0.3.4", @@ -4972,14 +4926,14 @@ dependencies = [ "once_cell", "socket2", "tracing", - "windows-sys 0.60.2", + "windows-sys 0.59.0", ] [[package]] name = "quote" 
-version = "1.0.41" +version = "1.0.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce25767e7b499d1b604768e7cde645d14cc8584231ea6b295e9c9eb22c02e1d1" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" dependencies = [ "proc-macro2", ] @@ -4991,10 +4945,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" [[package]] -name = "radium" -version = "0.7.0" +name = "r-efi" +version = "6.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" +checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" [[package]] name = "radix_trie" @@ -5024,7 +4978,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" dependencies = [ "rand_chacha 0.9.0", - "rand_core 0.9.3", + "rand_core 0.9.5", ] [[package]] @@ -5044,7 +4998,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" dependencies = [ "ppv-lite86", - "rand_core 0.9.3", + "rand_core 0.9.5", ] [[package]] @@ -5053,14 +5007,14 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom 0.2.16", + "getrandom 0.2.17", ] [[package]] name = "rand_core" -version = "0.9.3" +version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" +checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" dependencies = [ "getrandom 0.3.4", ] @@ -5112,14 +5066,23 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" dependencies = [ "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] name = "redox_syscall" -version = "0.5.17" +version = "0.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5407465600fb0548f1442edf71dd20683c6ed326200ace4b1ef0763521bb3b77" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" +dependencies = [ + "bitflags", +] + +[[package]] +name = "redox_syscall" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce70a74e890531977d37e532c34d45e9055d2409ed08ddba14529471ed0be16" dependencies = [ "bitflags", ] @@ -5130,36 +5093,36 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a4e608c6638b9c18977b00b475ac1f28d14e84b27d8d42f70e0bf1e3dec127ac" dependencies = [ - "getrandom 0.2.16", + "getrandom 0.2.17", "libredox", "thiserror", ] [[package]] name = "ref-cast" -version = "1.0.24" +version = "1.0.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a0ae411dbe946a674d89546582cea4ba2bb8defac896622d6496f14c23ba5cf" +checksum = "f354300ae66f76f1c85c5f84693f0ce81d747e2c3f21a45fef496d89c960bf7d" dependencies = [ "ref-cast-impl", ] [[package]] name = "ref-cast-impl" -version = "1.0.24" +version = "1.0.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1165225c21bff1f3bbce98f5a1f889949bc902d3575308cc7b0de30b4f6d27c7" +checksum = "b7186006dcb21920990093f30e3dea63b7d6e977bf1256be20c3563a5db070da" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] name = "regex" -version = "1.12.2" +version = "1.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "843bc0191f75f3e22651ae5f1e72939ab2f72a4bc30fa80a066bd66edefc24d4" +checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" dependencies = [ "aho-corasick", 
"memchr", @@ -5169,9 +5132,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.13" +version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5276caf25ac86c8d810222b3dbb938e512c55c6831a10f3e6ed1c93b84041f1c" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" dependencies = [ "aho-corasick", "memchr", @@ -5180,23 +5143,23 @@ dependencies = [ [[package]] name = "regex-lite" -version = "0.1.8" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d942b98df5e658f56f20d592c7f868833fe38115e65c33003d8cd224b0155da" +checksum = "cab834c73d247e67f4fae452806d17d3c7501756d98c8808d7c9c7aa7d18f973" [[package]] name = "regex-syntax" -version = "0.8.6" +version = "0.8.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "caf4aa5b0f434c91fe5c7f1ecb6a5ece2130b02ad2a590589dda5146df959001" +checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" [[package]] name = "regress" -version = "0.10.4" +version = "0.10.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "145bb27393fe455dd64d6cbc8d059adfa392590a45eadf079c01b11857e7b010" +checksum = "2057b2325e68a893284d1538021ab90279adac1139957ca2a74426c6f118fb48" dependencies = [ - "hashbrown 0.15.5", + "hashbrown 0.16.1", "memchr", ] @@ -5206,15 +5169,6 @@ version = "1.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2" -[[package]] -name = "rend" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71fe3824f5629716b1589be05dacd749f6aa084c87e00e016714a8cdfccc997c" -dependencies = [ - "bytecheck", -] - [[package]] name = "repr_offset" version = "0.2.2" @@ -5226,16 +5180,16 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.12.23" +version = "0.12.28" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "d429f34c8092b2d42c7c93cec323bb4adeb7c67698f70839adec842ec10c7ceb" +checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147" dependencies = [ "base64 0.22.1", "bytes", "futures-core", "futures-util", "h2", - "http 1.3.1", + "http 1.4.0", "http-body 1.0.1", "http-body-util", "hyper", @@ -5274,41 +5228,12 @@ checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" dependencies = [ "cc", "cfg-if", - "getrandom 0.2.16", + "getrandom 0.2.17", "libc", "untrusted", "windows-sys 0.52.0", ] -[[package]] -name = "rkyv" -version = "0.7.45" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9008cd6385b9e161d8229e1f6549dd23c3d022f132a2ea37ac3a10ac4935779b" -dependencies = [ - "bitvec", - "bytecheck", - "bytes", - "hashbrown 0.12.3", - "ptr_meta", - "rend", - "rkyv_derive", - "seahash", - "tinyvec", - "uuid", -] - -[[package]] -name = "rkyv_derive" -version = "0.7.45" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "503d1d27590a2b0a3a4ca4c94755aa2875657196ecbf401a42eff41d7de532c0" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "rstest" version = "0.26.1" @@ -5334,7 +5259,7 @@ dependencies = [ "regex", "relative-path", "rustc_version", - "syn 2.0.113", + "syn 2.0.117", "unicode-ident", ] @@ -5346,24 +5271,7 @@ checksum = "b3a8fb4672e840a587a66fc577a5491375df51ddb88f2a2c2a792598c326fe14" dependencies = [ "quote", "rand 0.8.5", - "syn 2.0.113", -] - -[[package]] -name = "rust_decimal" -version = "1.38.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8975fc98059f365204d635119cf9c5a60ae67b841ed49b5422a9a7e56cdfac0" -dependencies = [ - "arrayvec", - "borsh", - "bytes", - "num-traits", - "postgres-types", - "rand 0.8.5", - "rkyv", - "serde", - "serde_json", + "syn 2.0.117", ] [[package]] @@ -5383,22 +5291,22 @@ dependencies = [ [[package]] name = 
"rustix" -version = "1.1.2" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd15f8a2c5551a84d56efdc1cd049089e409ac19a3072d5037a17fd70719ff3e" +checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190" dependencies = [ "bitflags", "errno", "libc", "linux-raw-sys", - "windows-sys 0.61.0", + "windows-sys 0.59.0", ] [[package]] name = "rustls" -version = "0.23.32" +version = "0.23.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd3c25631629d034ce7cd9940adc9d45762d46de2b0f57193c4443b92c6d4d40" +checksum = "758025cb5fccfd3bc2fd74708fd4682be41d99e5dff73c377c0646c6012c73a4" dependencies = [ "aws-lc-rs", "log", @@ -5412,9 +5320,9 @@ dependencies = [ [[package]] name = "rustls-native-certs" -version = "0.8.1" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fcff2dd52b58a8d98a70243663a0d234c4e2b79235637849d15913394a247d3" +checksum = "612460d5f7bea540c490b2b6395d8e34a953e52b491accd6c86c8164c5932a63" dependencies = [ "openssl-probe", "rustls-pki-types", @@ -5422,20 +5330,11 @@ dependencies = [ "security-framework", ] -[[package]] -name = "rustls-pemfile" -version = "2.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" -dependencies = [ - "rustls-pki-types", -] - [[package]] name = "rustls-pki-types" -version = "1.12.0" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "229a4a4c221013e7e1f1a043678c5cc39fe5171437c88fb47151a21e6f5b5c79" +checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd" dependencies = [ "web-time", "zeroize", @@ -5443,9 +5342,9 @@ dependencies = [ [[package]] name = "rustls-webpki" -version = "0.103.6" +version = "0.103.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"8572f3c2cb9934231157b45499fc41e1f58c589fdfb81a844ba873265e80f8eb" +checksum = "20a6af516fea4b20eccceaf166e8aa666ac996208e8a644ce3ef5aa783bc7cd4" dependencies = [ "aws-lc-rs", "ring", @@ -5473,19 +5372,19 @@ dependencies = [ "libc", "log", "memchr", - "nix", + "nix 0.30.1", "radix_trie", "unicode-segmentation", - "unicode-width 0.2.1", + "unicode-width 0.2.2", "utf8parse", "windows-sys 0.60.2", ] [[package]] name = "ryu" -version = "1.0.20" +version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" +checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" [[package]] name = "same-file" @@ -5498,11 +5397,11 @@ dependencies = [ [[package]] name = "schannel" -version = "0.1.28" +version = "0.1.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1" +checksum = "91c1b7e4904c873ef0710c1f407dde2e6287de2bebc1bbbf7d430bb7cbffd939" dependencies = [ - "windows-sys 0.61.0", + "windows-sys 0.61.2", ] [[package]] @@ -5531,9 +5430,9 @@ dependencies = [ [[package]] name = "schemars" -version = "1.0.4" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82d20c4491bc164fa2f6c5d44565947a52ad80b9505d8e36f8d54c27c739fcd0" +checksum = "a2b42f36aa1cd011945615b92222f6bf73c599a102a300334cd7f8dbeec726cc" dependencies = [ "dyn-clone", "ref-cast", @@ -5550,7 +5449,7 @@ dependencies = [ "proc-macro2", "quote", "serde_derive_internals", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -5559,17 +5458,11 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" -[[package]] -name = "seahash" -version = "4.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"1c107b6f4780854c8b126e228ea8869f4d7b71260f962fefb57b996b8959ba6b" - [[package]] name = "security-framework" -version = "3.5.0" +version = "3.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc198e42d9b7510827939c9a15f5062a0c913f3371d765977e586d2fe6c16f4a" +checksum = "b7f4bc775c73d9a02cde8bf7b2ec4c9d12743edf609006c7facc23998404cd1d" dependencies = [ "bitflags", "core-foundation", @@ -5580,9 +5473,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.15.0" +version = "2.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc1f0cbffaac4852523ce30d8bd3c5cdc873501d96ff467ca09b6767bb8cd5c0" +checksum = "6ce2691df843ecc5d231c0b14ece2acc3efb62c0a398c7e1d875f3983ce020e3" dependencies = [ "core-foundation-sys", "libc", @@ -5641,7 +5534,7 @@ checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -5652,20 +5545,20 @@ checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] name = "serde_json" -version = "1.0.145" +version = "1.0.149" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" dependencies = [ "itoa", "memchr", - "ryu", "serde", "serde_core", + "zmij", ] [[package]] @@ -5676,19 +5569,19 @@ checksum = "175ee3e80ae9982737ca543e96133087cbd9a485eecc3bc4de9c1a37b47ea59c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] name = "serde_tokenstream" -version = "0.2.2" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64060d864397305347a78851c51588fd283767e7e7589829e8121d65512340f1" +checksum = 
"d7c49585c52c01f13c5c2ebb333f14f6885d76daa768d8a037d28017ec538c69" dependencies = [ "proc-macro2", "quote", "serde", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -5705,19 +5598,18 @@ dependencies = [ [[package]] name = "serde_with" -version = "3.14.1" +version = "3.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c522100790450cf78eeac1507263d0a350d4d5b30df0c8e1fe051a10c22b376e" +checksum = "dd5414fad8e6907dbdd5bc441a50ae8d6e26151a03b1de04d89a5576de61d01f" dependencies = [ "base64 0.22.1", "chrono", "hex", "indexmap 1.9.3", - "indexmap 2.12.1", + "indexmap 2.13.0", "schemars 0.9.0", - "schemars 1.0.4", - "serde", - "serde_derive", + "schemars 1.2.1", + "serde_core", "serde_json", "serde_with_macros", "time", @@ -5725,14 +5617,14 @@ dependencies = [ [[package]] name = "serde_with_macros" -version = "3.14.1" +version = "3.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "327ada00f7d64abaac1e55a6911e90cf665aa051b9a561c7006c157f4633135e" +checksum = "d3db8978e608f1fe7357e211969fd9abdcae80bac1ba7a3369bb7eb6b404eb65" dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -5741,7 +5633,7 @@ version = "0.9.34+deprecated" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" dependencies = [ - "indexmap 2.12.1", + "indexmap 2.13.0", "itoa", "ryu", "serde", @@ -5787,18 +5679,19 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" [[package]] name = "signal-hook-registry" -version = "1.4.6" +version = "1.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2a4719bff48cee6b39d12c020eeb490953ad2443b7055bd0b21fca26bd8c28b" +checksum = "c4db69cba1110affc0e9f7bcd48bbf87b3f4fc7c61fc9155afd4c469eb3d6c1b" dependencies = [ + "errno", "libc", ] [[package]] name = "simd-adler32" -version = "0.3.7" +version = "0.3.8" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" +checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" [[package]] name = "simdutf8" @@ -5814,15 +5707,15 @@ checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" [[package]] name = "siphasher" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" +checksum = "b2aa850e253778c88a04c3d7323b043aeda9d3e30d5971937c1855769763678e" [[package]] name = "slab" -version = "0.4.11" +version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a2ae44ef20feb57a68b23d846850f861394c2e02dc425a50098ae8c90267589" +checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" [[package]] name = "smallvec" @@ -5856,19 +5749,19 @@ dependencies = [ [[package]] name = "socket2" -version = "0.6.0" +version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "233504af464074f9d066d7b5416c5f9b894a5862a6506e306f7b816cdd6f1807" +checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" dependencies = [ "libc", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] name = "sqllogictest" -version = "0.28.4" +version = "0.29.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3566426f72a13e393aa34ca3d542c5b0eb86da4c0db137ee9b5cfccc6179e52d" +checksum = "d03b2262a244037b0b510edbd25a8e6c9fb8d73ee0237fc6cc95a54c16f94a82" dependencies = [ "async-trait", "educe", @@ -5891,9 +5784,9 @@ dependencies = [ [[package]] name = "sqlparser" -version = "0.59.0" +version = "0.61.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4591acadbcf52f0af60eafbb2c003232b2b4cd8de5f0e9437cb8b1b59046cc0f" +checksum = 
"dbf5ea8d4d7c808e1af1cbabebca9a2abe603bcefc22294c5b95018d53200cb7" dependencies = [ "log", "recursive", @@ -5902,26 +5795,26 @@ dependencies = [ [[package]] name = "sqlparser_derive" -version = "0.3.0" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da5fc6819faabb412da764b99d3b713bb55083c11e7e0c00144d386cd6a1939c" +checksum = "a6dd45d8fc1c79299bfbb7190e42ccbbdf6a5f52e4a6ad98d92357ea965bd289" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] name = "stable_deref_trait" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" [[package]] name = "stacker" -version = "0.1.21" +version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cddb07e32ddb770749da91081d8d0ac3a16f1a569a18b20348cd371f5dead06b" +checksum = "08d74a23609d509411d10e2176dc2a4346e3b4aea2e7b1869f19fdedbc71c013" dependencies = [ "cc", "cfg-if", @@ -5956,7 +5849,7 @@ dependencies = [ "proc-macro2", "quote", "structmeta-derive", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -5967,44 +5860,43 @@ checksum = "152a0b65a590ff6c3da95cabe2353ee04e6167c896b28e3b14478c2636c922fc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] name = "strum" -version = "0.26.3" +version = "0.27.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" +checksum = "af23d6f6c1a224baef9d3f61e287d2761385a5b88fdab4eb4c6f11aeb54c4bcf" [[package]] name = "strum" -version = "0.27.2" +version = "0.28.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af23d6f6c1a224baef9d3f61e287d2761385a5b88fdab4eb4c6f11aeb54c4bcf" +checksum = 
"9628de9b8791db39ceda2b119bbe13134770b56c138ec1d3af810d045c04f9bd" [[package]] name = "strum_macros" -version = "0.26.4" +version = "0.27.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" +checksum = "7695ce3845ea4b33927c055a39dc438a45b059f7c1b3d91d38d10355fb8cbca7" dependencies = [ "heck", "proc-macro2", "quote", - "rustversion", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] name = "strum_macros" -version = "0.27.2" +version = "0.28.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7695ce3845ea4b33927c055a39dc438a45b059f7c1b3d91d38d10355fb8cbca7" +checksum = "ab85eea0270ee17587ed4156089e10b9e6880ee688791d45a905f5b1ca36f664" dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -6024,8 +5916,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "62fc4b483a129b9772ccb9c3f7945a472112fdd9140da87f8a4e7f1d44e045d0" dependencies = [ "heck", - "pbjson", - "pbjson-build", + "pbjson 0.8.0", + "pbjson-build 0.8.0", "pbjson-types", "prettyplease", "prost", @@ -6038,7 +5930,7 @@ dependencies = [ "serde", "serde_json", "serde_yaml", - "syn 2.0.113", + "syn 2.0.117", "typify", "walkdir", ] @@ -6062,9 +5954,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.113" +version = "2.0.117" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "678faa00651c9eb72dd2020cbdf275d92eccb2400d568e419efdd64838145cb4" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" dependencies = [ "proc-macro2", "quote", @@ -6088,14 +5980,14 @@ checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] name = "sysinfo" -version = "0.37.2" +version = "0.38.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"16607d5caffd1c07ce073528f9ed972d88db15dd44023fa57142963be3feb11f" +checksum = "92ab6a2f8bfe508deb3c6406578252e491d299cbbf3bc0529ecc3313aee4a52f" dependencies = [ "libc", "memchr", @@ -6105,23 +5997,17 @@ dependencies = [ "windows", ] -[[package]] -name = "tap" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" - [[package]] name = "tempfile" -version = "3.23.0" +version = "3.27.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d31c77bdf42a745371d260a26ca7163f1e0924b64afa0b688e61b5a9fa02f16" +checksum = "32497e9a4c7b38532efcdebeef879707aa9f794296a4f0244f6f69e9bc8574bd" dependencies = [ "fastrand", - "getrandom 0.3.4", + "getrandom 0.4.2", "once_cell", "rustix", - "windows-sys 0.61.0", + "windows-sys 0.59.0", ] [[package]] @@ -6137,9 +6023,9 @@ dependencies = [ [[package]] name = "testcontainers" -version = "0.26.3" +version = "0.27.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a81ec0158db5fbb9831e09d1813fe5ea9023a2b5e6e8e0a5fe67e2a820733629" +checksum = "0bd36b06a2a6c0c3c81a83be1ab05fe86460d054d4d51bf513bc56b3e15bdc22" dependencies = [ "astral-tokio-tar", "async-trait", @@ -6150,6 +6036,7 @@ dependencies = [ "etcetera", "ferroid", "futures", + "http 1.4.0", "itertools 0.14.0", "log", "memchr", @@ -6167,31 +6054,31 @@ dependencies = [ [[package]] name = "testcontainers-modules" -version = "0.14.0" +version = "0.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e75e78ff453128a2c7da9a5d5a3325ea34ea214d4bf51eab3417de23a4e5147" +checksum = "e5985fde5befe4ffa77a052e035e16c2da86e8bae301baa9f9904ad3c494d357" dependencies = [ "testcontainers", ] [[package]] name = "thiserror" -version = "2.0.17" +version = "2.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f63587ca0f12b72a0600bcba1d40081f830876000bb46dd2337a3051618f4fc8" 
+checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "2.0.17" +version = "2.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -6216,30 +6103,30 @@ dependencies = [ [[package]] name = "time" -version = "0.3.44" +version = "0.3.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91e7d9e3bb61134e77bde20dd4825b97c010155709965fedf0f49bb138e52a9d" +checksum = "743bd48c283afc0388f9b8827b976905fb217ad9e647fae3a379a9283c4def2c" dependencies = [ "deranged", "itoa", "num-conv", "powerfmt", - "serde", + "serde_core", "time-core", "time-macros", ] [[package]] name = "time-core" -version = "0.1.6" +version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40868e7c1d2f0b8d73e4a8c7f0ff63af4f6d19be117e90bd73eb1d62cf831c6b" +checksum = "7694e1cfe791f8d31026952abf09c69ca6f6fa4e1a1229e18988f06a04a12dca" [[package]] name = "time-macros" -version = "0.2.24" +version = "0.2.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30cfb0125f12d9c277f35663a0a33f8c30190f4e4574868a330595412d34ebf3" +checksum = "2e70e4c5a0e0a8a4823ad65dfe1a6930e4f4d756dcd9dd7939022b5e8c501215" dependencies = [ "num-conv", "time-core", @@ -6256,9 +6143,9 @@ dependencies = [ [[package]] name = "tinystr" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d4f6d1145dcb577acf783d4e601bc1d76a13337bb54e6233add580b07344c8b" +checksum = "42d3e9c45c09de15d06dd8acf5f4e0e399e85927b7f00711024eb7ae10fa4869" dependencies = [ "displaydoc", "zerovec", @@ -6276,9 +6163,9 @@ dependencies = [ 
[[package]] name = "tinyvec" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa" +checksum = "3e61e67053d25a4e82c844e8424039d9745781b3fc4f32b8d55ed50f5f667ef3" dependencies = [ "tinyvec_macros", ] @@ -6291,9 +6178,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.48.0" +version = "1.50.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff360e02eab121e0bc37a2d3b4d4dc622e6eda3a8e5253d5435ecf5bd4c68408" +checksum = "27ad5e34374e03cfffefc301becb44e9dc3c17584f414349ebe29ed26661822d" dependencies = [ "bytes", "libc", @@ -6303,25 +6190,25 @@ dependencies = [ "signal-hook-registry", "socket2", "tokio-macros", - "windows-sys 0.61.0", + "windows-sys 0.61.2", ] [[package]] name = "tokio-macros" -version = "2.6.0" +version = "2.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af407857209536a95c8e56f8231ef2c2e2aff839b22e07a1ffcbc617e9db9fa5" +checksum = "5c55a2eff8b69ce66c84f85e1da1c233edc36ceb85a2058d11b0d6a3c7e7569c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] name = "tokio-postgres" -version = "0.7.14" +version = "0.7.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a156efe7fff213168257853e1dfde202eed5f487522cbbbf7d219941d753d853" +checksum = "dcea47c8f71744367793f16c2db1f11cb859d28f436bdb4ca9193eb1f787ee42" dependencies = [ "async-trait", "byteorder", @@ -6345,9 +6232,9 @@ dependencies = [ [[package]] name = "tokio-rustls" -version = "0.26.3" +version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05f63835928ca123f1bef57abbcd23bb2ba0ac9ae1235f1e65bda0d06e7786bd" +checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" dependencies = [ "rustls", "tokio", @@ -6355,20 
+6242,21 @@ dependencies = [ [[package]] name = "tokio-stream" -version = "0.1.17" +version = "0.1.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047" +checksum = "32da49809aab5c3bc678af03902d4ccddea2a87d028d86392a4b1560c6906c70" dependencies = [ "futures-core", "pin-project-lite", "tokio", + "tokio-util", ] [[package]] name = "tokio-util" -version = "0.7.17" +version = "0.7.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2efa149fe76073d6e8fd97ef4f4eca7b67f599660115591483572e406e165594" +checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" dependencies = [ "bytes", "futures-core", @@ -6379,20 +6267,20 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "0.7.2" +version = "1.0.0+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32f1085dec27c2b6632b04c80b3bb1b4300d6495d1e129693bdda7d91e72eec1" +checksum = "32c2555c699578a4f59f0cc68e5116c8d7cabbd45e1409b989d4be085b53f13e" dependencies = [ "serde_core", ] [[package]] name = "toml_edit" -version = "0.23.6" +version = "0.25.4+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3effe7c0e86fdff4f69cdd2ccc1b96f933e24811c5441d44904e8683e27184b" +checksum = "7193cbd0ce53dc966037f54351dbbcf0d5a642c7f0038c382ef9e677ce8c13f2" dependencies = [ - "indexmap 2.12.1", + "indexmap 2.13.0", "toml_datetime", "toml_parser", "winnow", @@ -6400,25 +6288,25 @@ dependencies = [ [[package]] name = "toml_parser" -version = "1.0.3" +version = "1.0.9+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4cf893c33be71572e0e9aa6dd15e6677937abd686b066eac3f8cd3531688a627" +checksum = "702d4415e08923e7e1ef96cd5727c0dfed80b4d2fa25db9647fe5eb6f7c5a4c4" dependencies = [ "winnow", ] [[package]] name = "tonic" -version = "0.14.2" +version = "0.14.5" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb7613188ce9f7df5bfe185db26c5814347d110db17920415cf2fbcad85e7203" +checksum = "fec7c61a0695dc1887c1b53952990f3ad2e3a31453e1f49f10e75424943a93ec" dependencies = [ "async-trait", "axum", "base64 0.22.1", "bytes", "h2", - "http 1.3.1", + "http 1.4.0", "http-body 1.0.1", "http-body-util", "hyper", @@ -6438,9 +6326,9 @@ dependencies = [ [[package]] name = "tonic-prost" -version = "0.14.2" +version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "66bd50ad6ce1252d87ef024b3d64fe4c3cf54a86fb9ef4c631fdd0ded7aeaa67" +checksum = "a55376a0bbaa4975a3f10d009ad763d8f4108f067c7c2e74f3001fb49778d309" dependencies = [ "bytes", "prost", @@ -6449,13 +6337,13 @@ dependencies = [ [[package]] name = "tower" -version = "0.5.2" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" +checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4" dependencies = [ "futures-core", "futures-util", - "indexmap 2.12.1", + "indexmap 2.13.0", "pin-project-lite", "slab", "sync_wrapper", @@ -6468,14 +6356,14 @@ dependencies = [ [[package]] name = "tower-http" -version = "0.6.6" +version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adc82fd73de2a9722ac5da747f12383d2bfdb93591ee6c58486e0097890f05f2" +checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" dependencies = [ "bitflags", "bytes", "futures-util", - "http 1.3.1", + "http 1.4.0", "http-body 1.0.1", "iri-string", "pin-project-lite", @@ -6515,7 +6403,7 @@ checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -6541,9 +6429,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.22" +version = "0.3.23" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e" +checksum = "cb7f578e5945fb242538965c2d0b04418d38ec25c79d160cd279bf0731c8d319" dependencies = [ "nu-ansi-term", "sharded-slab", @@ -6588,9 +6476,9 @@ checksum = "6af6ae20167a9ece4bcb41af5b80f8a1f1df981f6391189ce00fd257af04126a" [[package]] name = "typenum" -version = "1.18.0" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" +checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" [[package]] name = "typewit" @@ -6623,7 +6511,7 @@ dependencies = [ "semver", "serde", "serde_json", - "syn 2.0.113", + "syn 2.0.117", "thiserror", "unicode-ident", ] @@ -6641,7 +6529,7 @@ dependencies = [ "serde", "serde_json", "serde_tokenstream", - "syn 2.0.113", + "syn 2.0.117", "typify-impl", ] @@ -6653,24 +6541,24 @@ checksum = "5c1cb5db39152898a79168971543b1cb5020dff7fe43c8dc468b0885f5e29df5" [[package]] name = "unicode-ident" -version = "1.0.19" +version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f63a545481291138910575129486daeaf8ac54aee4387fe7906919f7830c7d9d" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" [[package]] name = "unicode-normalization" -version = "0.1.24" +version = "0.1.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5033c97c4262335cded6d6fc3e5c18ab755e1a3dc96376350f3d8e9f009ad956" +checksum = "5fd4f6878c9cb28d874b009da9e8d183b5abc80117c40bbd187a1fde336be6e8" dependencies = [ "tinyvec", ] [[package]] name = "unicode-properties" -version = "0.1.3" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e70f2a8b45122e719eb623c01822704c4e0907e7e426a05927e1a1cfff5b75d0" +checksum = "7df058c713841ad818f1dc5d3fd88063241cc61f49f5fbea4b951e8cf5a8d71d" 
[[package]] name = "unicode-segmentation" @@ -6686,15 +6574,21 @@ checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" [[package]] name = "unicode-width" -version = "0.2.1" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a1a07cc7db3810833284e8d372ccdc6da29741639ecc70c9ec107df0fa6154c" +checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" + +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" [[package]] name = "unit-prefix" -version = "0.5.1" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "323402cff2dd658f39ca17c789b502021b3f18707c91cdf22e3838e1b4023817" +checksum = "81e544489bf3d8ef66c953931f56617f423cd4b5494be343d9b9d3dda037b9a3" [[package]] name = "unsafe-libyaml" @@ -6710,43 +6604,42 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "ureq" -version = "3.1.2" +version = "3.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99ba1025f18a4a3fc3e9b48c868e9beb4f24f4b4b1a325bada26bd4119f46537" +checksum = "fdc97a28575b85cfedf2a7e7d3cc64b3e11bd8ac766666318003abbacc7a21fc" dependencies = [ "base64 0.22.1", "log", "percent-encoding", "rustls", - "rustls-pemfile", "rustls-pki-types", "ureq-proto", "utf-8", - "webpki-roots", ] [[package]] name = "ureq-proto" -version = "0.5.2" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60b4531c118335662134346048ddb0e54cc86bd7e81866757873055f0e38f5d2" +checksum = "d81f9efa9df032be5934a46a068815a10a042b494b6a58cb0a1a97bb5467ed6f" dependencies = [ "base64 0.22.1", - "http 1.3.1", + "http 1.4.0", "httparse", "log", ] [[package]] name = "url" -version = "2.5.7" +version = "2.5.8" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "08bc136a29a3d1758e07a9cca267be308aeebf5cfd5a10f3f67ab2097683ef5b" +checksum = "ff67a8a4397373c3ef660812acab3268222035010ab8680ec4215f38ba3d0eed" dependencies = [ "form_urlencoded", "idna", "percent-encoding", "serde", + "serde_derive", ] [[package]] @@ -6775,11 +6668,11 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.19.0" +version = "1.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2e054861b4bd027cd373e18e8d8d8e6548085000e41290d95ce0c373a654b4a" +checksum = "a68d3c8f01c0cfa54a75291d83601161799e4a89a39e0929f4b0354d88757a37" dependencies = [ - "getrandom 0.3.4", + "getrandom 0.4.2", "js-sys", "serde_core", "wasm-bindgen", @@ -6828,26 +6721,47 @@ version = "0.11.1+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" +[[package]] +name = "wasi" +version = "0.14.7+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "883478de20367e224c0090af9cf5f9fa85bed63a95c1abf3afc5c083ebc06e8c" +dependencies = [ + "wasip2", +] + [[package]] name = "wasip2" -version = "1.0.1+wasi-0.2.4" +version = "1.0.2+wasi-0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0562428422c63773dad2c345a1882263bbf4d65cf3f42e90921f787ef5ad58e7" +checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wasip3" +version = "0.4.0+wasi-0.3.0-rc-2026-01-06" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" dependencies = [ "wit-bindgen", ] [[package]] name = "wasite" -version = "0.1.0" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b" +checksum = "66fe902b4a6b8028a753d5424909b764ccf79b7a209eac9bf97e59cda9f71a42" +dependencies = [ + "wasi 0.14.7+wasi-0.2.4", +] [[package]] name = "wasm-bindgen" -version = "0.2.106" +version = "0.2.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d759f433fa64a2d763d1340820e46e111a7a5ab75f993d1852d70b03dbb80fd" +checksum = "6532f9a5c1ece3798cb1c2cfdba640b9b3ba884f5db45973a6f442510a87d38e" dependencies = [ "cfg-if", "once_cell", @@ -6858,11 +6772,12 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.56" +version = "0.4.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "836d9622d604feee9e5de25ac10e3ea5f2d65b41eac0d9ce72eb5deae707ce7c" +checksum = "e9c5522b3a28661442748e09d40924dfb9ca614b21c00d3fd135720e48b67db8" dependencies = [ "cfg-if", + "futures-util", "js-sys", "once_cell", "wasm-bindgen", @@ -6871,9 +6786,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.106" +version = "0.2.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48cb0d2638f8baedbc542ed444afc0644a29166f1595371af4fecf8ce1e7eeb3" +checksum = "18a2d50fcf105fb33bb15f00e7a77b772945a2ee45dcf454961fd843e74c18e6" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -6881,31 +6796,31 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.106" +version = "0.2.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cefb59d5cd5f92d9dcf80e4683949f15ca4b511f4ac0a6e14d4e1ac60c6ecd40" +checksum = "03ce4caeaac547cdf713d280eda22a730824dd11e6b8c3ca9e42247b25c631e3" dependencies = [ "bumpalo", "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.106" +version = "0.2.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum 
= "cbc538057e648b67f72a982e708d485b2efa771e1ac05fec311f9f63e5800db4" +checksum = "75a326b8c223ee17883a4251907455a2431acc2791c98c26279376490c378c16" dependencies = [ "unicode-ident", ] [[package]] name = "wasm-bindgen-test" -version = "0.3.56" +version = "0.3.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25e90e66d265d3a1efc0e72a54809ab90b9c0c515915c67cdf658689d2c22c6c" +checksum = "6311c867385cc7d5602463b31825d454d0837a3aba7cdb5e56d5201792a3f7fe" dependencies = [ "async-trait", "cast", @@ -6920,17 +6835,46 @@ dependencies = [ "wasm-bindgen", "wasm-bindgen-futures", "wasm-bindgen-test-macro", + "wasm-bindgen-test-shared", ] [[package]] name = "wasm-bindgen-test-macro" -version = "0.3.56" +version = "0.3.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7150335716dce6028bead2b848e72f47b45e7b9422f64cccdc23bedca89affc1" +checksum = "67008cdde4769831958536b0f11b3bdd0380bde882be17fff9c2f34bb4549abd" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", +] + +[[package]] +name = "wasm-bindgen-test-shared" +version = "0.2.114" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfe29135b180b72b04c74aa97b2b4a2ef275161eff9a6c7955ea9eaedc7e1d4e" + +[[package]] +name = "wasm-encoder" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" +dependencies = [ + "anyhow", + "indexmap 2.13.0", + "wasm-encoder", + "wasmparser", ] [[package]] @@ -6946,11 +6890,23 @@ dependencies = [ "web-sys", ] +[[package]] +name = "wasmparser" +version = "0.244.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" +dependencies = [ + "bitflags", + "hashbrown 0.15.5", + "indexmap 2.13.0", + "semver", +] + [[package]] name = "web-sys" -version = "0.3.83" +version = "0.3.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b32828d774c412041098d182a8b38b16ea816958e07cf40eec2bc080ae137ac" +checksum = "854ba17bb104abfb26ba36da9729addc7ce7f06f5c0f90f3c391f8461cca21f9" dependencies = [ "js-sys", "wasm-bindgen", @@ -6966,22 +6922,15 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "webpki-roots" -version = "1.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32b130c0d2d49f8b6889abc456e795e82525204f27c42cf767cf0d7734e089b8" -dependencies = [ - "rustls-pki-types", -] - [[package]] name = "whoami" -version = "1.6.1" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d4a4db5077702ca3015d3d02d74974948aba2ad9e12ab7df718ee64ccd7e97d" +checksum = "d6a5b12f9df4f978d2cfdb1bd3bac52433f44393342d7ee9c25f5a1c14c0f45d" dependencies = [ + "libc", "libredox", + "objc2-system-configuration", "wasite", "web-sys", ] @@ -7008,7 +6957,7 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.61.0", + "windows-sys 0.59.0", ] [[package]] @@ -7019,110 +6968,103 @@ checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] name = "windows" -version = "0.61.3" +version = "0.62.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9babd3a767a4c1aef6900409f85f5d53ce2544ccdfaa86dad48c91782c6d6893" +checksum = "527fadee13e0c05939a6a05d5bd6eec6cd2e3dbd648b9f8e447c6518133d8580" dependencies = [ "windows-collections", "windows-core", "windows-future", - "windows-link 
0.1.3", "windows-numerics", ] [[package]] name = "windows-collections" -version = "0.2.0" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3beeceb5e5cfd9eb1d76b381630e82c4241ccd0d27f1a39ed41b2760b255c5e8" +checksum = "23b2d95af1a8a14a3c7367e1ed4fc9c20e0a26e79551b1454d72583c97cc6610" dependencies = [ "windows-core", ] [[package]] name = "windows-core" -version = "0.61.2" +version = "0.62.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3" +checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" dependencies = [ "windows-implement", "windows-interface", - "windows-link 0.1.3", + "windows-link", "windows-result", "windows-strings", ] [[package]] name = "windows-future" -version = "0.2.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc6a41e98427b19fe4b73c550f060b59fa592d7d686537eebf9385621bfbad8e" +checksum = "e1d6f90251fe18a279739e78025bd6ddc52a7e22f921070ccdc67dde84c605cb" dependencies = [ "windows-core", - "windows-link 0.1.3", + "windows-link", "windows-threading", ] [[package]] name = "windows-implement" -version = "0.60.0" +version = "0.60.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836" +checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] name = "windows-interface" -version = "0.59.1" +version = "0.59.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8" +checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] name = "windows-link" -version 
= "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" - -[[package]] -name = "windows-link" -version = "0.2.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45e46c0661abb7180e7b9c281db115305d49ca1709ab8242adf09666d2173c65" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" [[package]] name = "windows-numerics" -version = "0.2.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9150af68066c4c5c07ddc0ce30421554771e528bde427614c61038bc2c92c2b1" +checksum = "6e2e40844ac143cdb44aead537bbf727de9b044e107a0f1220392177d15b0f26" dependencies = [ "windows-core", - "windows-link 0.1.3", + "windows-link", ] [[package]] name = "windows-result" -version = "0.3.4" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" +checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" dependencies = [ - "windows-link 0.1.3", + "windows-link", ] [[package]] name = "windows-strings" -version = "0.4.2" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57" +checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" dependencies = [ - "windows-link 0.1.3", + "windows-link", ] [[package]] @@ -7149,16 +7091,16 @@ version = "0.60.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" dependencies = [ - "windows-targets 0.53.3", + "windows-targets 0.53.5", ] [[package]] name = "windows-sys" -version = "0.61.0" +version = "0.61.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"e201184e40b2ede64bc2ea34968b28e33622acdbbf37104f0e4a33f7abe657aa" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" dependencies = [ - "windows-link 0.2.0", + "windows-link", ] [[package]] @@ -7179,28 +7121,28 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.53.3" +version = "0.53.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5fe6031c4041849d7c496a8ded650796e7b6ecc19df1a431c1a363342e5dc91" +checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" dependencies = [ - "windows-link 0.1.3", - "windows_aarch64_gnullvm 0.53.0", - "windows_aarch64_msvc 0.53.0", - "windows_i686_gnu 0.53.0", - "windows_i686_gnullvm 0.53.0", - "windows_i686_msvc 0.53.0", - "windows_x86_64_gnu 0.53.0", - "windows_x86_64_gnullvm 0.53.0", - "windows_x86_64_msvc 0.53.0", + "windows-link", + "windows_aarch64_gnullvm 0.53.1", + "windows_aarch64_msvc 0.53.1", + "windows_i686_gnu 0.53.1", + "windows_i686_gnullvm 0.53.1", + "windows_i686_msvc 0.53.1", + "windows_x86_64_gnu 0.53.1", + "windows_x86_64_gnullvm 0.53.1", + "windows_x86_64_msvc 0.53.1", ] [[package]] name = "windows-threading" -version = "0.1.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b66463ad2e0ea3bbf808b7f1d371311c80e115c0b71d60efc142cafbcfb057a6" +checksum = "3949bd5b99cafdf1c7ca86b43ca564028dfe27d66958f2470940f73d86d75b37" dependencies = [ - "windows-link 0.1.3", + "windows-link", ] [[package]] @@ -7211,9 +7153,9 @@ checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" [[package]] name = "windows_aarch64_gnullvm" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" +checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" [[package]] name = "windows_aarch64_msvc" @@ -7223,9 +7165,9 @@ checksum = 
"09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" [[package]] name = "windows_aarch64_msvc" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" +checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" [[package]] name = "windows_i686_gnu" @@ -7235,9 +7177,9 @@ checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" [[package]] name = "windows_i686_gnu" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3" +checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" [[package]] name = "windows_i686_gnullvm" @@ -7247,9 +7189,9 @@ checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" [[package]] name = "windows_i686_gnullvm" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" +checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" [[package]] name = "windows_i686_msvc" @@ -7259,9 +7201,9 @@ checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" [[package]] name = "windows_i686_msvc" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" +checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" [[package]] name = "windows_x86_64_gnu" @@ -7271,9 +7213,9 @@ checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" [[package]] name = "windows_x86_64_gnu" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" +checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" [[package]] name = "windows_x86_64_gnullvm" @@ -7283,9 +7225,9 @@ checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" [[package]] name = "windows_x86_64_gnullvm" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" +checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" [[package]] name = "windows_x86_64_msvc" @@ -7295,40 +7237,113 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "windows_x86_64_msvc" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" +checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" [[package]] name = "winnow" -version = "0.7.13" +version = "0.7.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21a0236b59786fed61e2a80582dd500fe61f18b5dca67a4a067d0bc9039339cf" +checksum = "df79d97927682d2fd8adb29682d1140b343be4ac0f08fd68b7765d9c059d3945" dependencies = [ "memchr", ] [[package]] name = "wit-bindgen" -version = "0.46.0" +version = "0.51.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" +dependencies = [ + "wit-bindgen-rust-macro", +] [[package]] -name = "writeable" -version = "0.6.1" +name = "wit-bindgen-core" +version = "0.51.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea2f10b9bb0928dfb1b42b65e1f9e36f7f54dbdf08457afefb38afcdec4fa2bb" +checksum = 
"ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" +dependencies = [ + "anyhow", + "heck", + "wit-parser", +] [[package]] -name = "wyz" -version = "0.5.1" +name = "wit-bindgen-rust" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" +dependencies = [ + "anyhow", + "heck", + "indexmap 2.13.0", + "prettyplease", + "syn 2.0.117", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn 2.0.117", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" +dependencies = [ + "anyhow", + "bitflags", + "indexmap 2.13.0", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.244.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" +checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" dependencies = [ - "tap", + "anyhow", + "id-arena", + "indexmap 2.13.0", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", ] +[[package]] +name = "writeable" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9" + [[package]] name = "xattr" version = "1.6.1" @@ -7353,11 +7368,10 @@ 
checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" [[package]] name = "yoke" -version = "0.8.0" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f41bb01b8226ef4bfd589436a297c53d118f65921786300e427be8d487695cc" +checksum = "72d6e5c6afb84d73944e5cedb052c4680d5657337201555f9f2a16b7406d4954" dependencies = [ - "serde", "stable_deref_trait", "yoke-derive", "zerofrom", @@ -7365,34 +7379,34 @@ dependencies = [ [[package]] name = "yoke-derive" -version = "0.8.0" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38da3c9736e16c5d3c8c597a9aaa5d1fa565d0532ae05e27c24aa62fb32c0ab6" +checksum = "b659052874eb698efe5b9e8cf382204678a0086ebf46982b79d6ca3182927e5d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", "synstructure", ] [[package]] name = "zerocopy" -version = "0.8.27" +version = "0.8.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0894878a5fa3edfd6da3f88c4805f4c8558e2b996227a3d864f47fe11e38282c" +checksum = "f2578b716f8a7a858b7f02d5bd870c14bf4ddbbcf3a4c05414ba6503640505e3" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.27" +version = "0.8.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88d2b8d9c68ad2b9e4340d7832716a4d21a22a1154777ad56ea55c51a9cf3831" +checksum = "7e6cc098ea4d3bd6246687de65af3f920c430e236bee1e3bf2e441463f08a02f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -7412,21 +7426,21 @@ checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", "synstructure", ] [[package]] name = "zeroize" -version = "1.8.1" +version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" 
+checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" [[package]] name = "zerotrie" -version = "0.2.2" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36f0bbd478583f79edad978b407914f61b2972f5af6fa089686016be8f9af595" +checksum = "2a59c17a5562d507e4b54960e8569ebee33bee890c70aa3fe7b97e85a9fd7851" dependencies = [ "displaydoc", "yoke", @@ -7435,9 +7449,9 @@ dependencies = [ [[package]] name = "zerovec" -version = "0.11.4" +version = "0.11.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7aa2bd55086f1ab526693ecbe444205da57e25f4489879da80635a46d90e73b" +checksum = "6c28719294829477f525be0186d13efa9a3c602f7ec202ca9e353d310fb9a002" dependencies = [ "yoke", "zerofrom", @@ -7446,20 +7460,26 @@ dependencies = [ [[package]] name = "zerovec-derive" -version = "0.11.1" +version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b96237efa0c878c64bd89c436f661be4e46b2f3eff1ebb976f7ef2321d2f58f" +checksum = "eadce39539ca5cb3985590102671f2567e659fca9666581ad3411d59207951f3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] name = "zlib-rs" -version = "0.5.2" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3be3d40e40a133f9c916ee3f9f4fa2d9d63435b5fbe1bfc6d9dae0aa0ada1513" + +[[package]] +name = "zmij" +version = "1.0.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f06ae92f42f5e5c42443fd094f245eb656abf56dd7cce9b8b263236565e00f2" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" [[package]] name = "zstd" diff --git a/Cargo.toml b/Cargo.toml index b9d8b1a69ef6..d8a7e424873c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -79,7 +79,7 @@ repository = "https://github.com/apache/datafusion" # Define Minimum Supported Rust Version (MSRV) rust-version = "1.88.0" # Define DataFusion version -version = 
"51.0.0" +version = "53.1.0" [workspace.dependencies] # We turn off default-features for some dependencies here so the workspaces which inherit them can @@ -91,89 +91,91 @@ ahash = { version = "0.8", default-features = false, features = [ "runtime-rng", ] } apache-avro = { version = "0.21", default-features = false } -arrow = { version = "57.1.0", features = [ +arrow = { version = "58.0.0", features = [ "prettyprint", "chrono-tz", ] } -arrow-buffer = { version = "57.1.0", default-features = false } -arrow-flight = { version = "57.1.0", features = [ +arrow-buffer = { version = "58.0.0", default-features = false } +arrow-flight = { version = "58.0.0", features = [ "flight-sql-experimental", ] } -arrow-ipc = { version = "57.1.0", default-features = false, features = [ +arrow-ipc = { version = "58.0.0", default-features = false, features = [ "lz4", ] } -arrow-ord = { version = "57.1.0", default-features = false } -arrow-schema = { version = "57.1.0", default-features = false } +arrow-ord = { version = "58.0.0", default-features = false } +arrow-schema = { version = "58.0.0", default-features = false } async-trait = "0.1.89" bigdecimal = "0.4.8" bytes = "1.11" bzip2 = "0.6.1" -chrono = { version = "0.4.42", default-features = false } +chrono = { version = "0.4.44", default-features = false } criterion = "0.8" ctor = "0.6.3" dashmap = "6.0.1" -datafusion = { path = "datafusion/core", version = "51.0.0", default-features = false } -datafusion-catalog = { path = "datafusion/catalog", version = "51.0.0" } -datafusion-catalog-listing = { path = "datafusion/catalog-listing", version = "51.0.0" } -datafusion-common = { path = "datafusion/common", version = "51.0.0", default-features = false } -datafusion-common-runtime = { path = "datafusion/common-runtime", version = "51.0.0" } -datafusion-datasource = { path = "datafusion/datasource", version = "51.0.0", default-features = false } -datafusion-datasource-arrow = { path = "datafusion/datasource-arrow", version = "51.0.0", 
default-features = false } -datafusion-datasource-avro = { path = "datafusion/datasource-avro", version = "51.0.0", default-features = false } -datafusion-datasource-csv = { path = "datafusion/datasource-csv", version = "51.0.0", default-features = false } -datafusion-datasource-json = { path = "datafusion/datasource-json", version = "51.0.0", default-features = false } -datafusion-datasource-parquet = { path = "datafusion/datasource-parquet", version = "51.0.0", default-features = false } -datafusion-doc = { path = "datafusion/doc", version = "51.0.0" } -datafusion-execution = { path = "datafusion/execution", version = "51.0.0", default-features = false } -datafusion-expr = { path = "datafusion/expr", version = "51.0.0", default-features = false } -datafusion-expr-common = { path = "datafusion/expr-common", version = "51.0.0" } -datafusion-ffi = { path = "datafusion/ffi", version = "51.0.0" } -datafusion-functions = { path = "datafusion/functions", version = "51.0.0" } -datafusion-functions-aggregate = { path = "datafusion/functions-aggregate", version = "51.0.0" } -datafusion-functions-aggregate-common = { path = "datafusion/functions-aggregate-common", version = "51.0.0" } -datafusion-functions-nested = { path = "datafusion/functions-nested", version = "51.0.0", default-features = false } -datafusion-functions-table = { path = "datafusion/functions-table", version = "51.0.0" } -datafusion-functions-window = { path = "datafusion/functions-window", version = "51.0.0" } -datafusion-functions-window-common = { path = "datafusion/functions-window-common", version = "51.0.0" } -datafusion-macros = { path = "datafusion/macros", version = "51.0.0" } -datafusion-optimizer = { path = "datafusion/optimizer", version = "51.0.0", default-features = false } -datafusion-physical-expr = { path = "datafusion/physical-expr", version = "51.0.0", default-features = false } -datafusion-physical-expr-adapter = { path = "datafusion/physical-expr-adapter", version = "51.0.0", 
default-features = false } -datafusion-physical-expr-common = { path = "datafusion/physical-expr-common", version = "51.0.0", default-features = false } -datafusion-physical-optimizer = { path = "datafusion/physical-optimizer", version = "51.0.0" } -datafusion-physical-plan = { path = "datafusion/physical-plan", version = "51.0.0" } -datafusion-proto = { path = "datafusion/proto", version = "51.0.0" } -datafusion-proto-common = { path = "datafusion/proto-common", version = "51.0.0" } -datafusion-pruning = { path = "datafusion/pruning", version = "51.0.0" } -datafusion-session = { path = "datafusion/session", version = "51.0.0" } -datafusion-spark = { path = "datafusion/spark", version = "51.0.0" } -datafusion-sql = { path = "datafusion/sql", version = "51.0.0" } -datafusion-substrait = { path = "datafusion/substrait", version = "51.0.0" } +datafusion = { path = "datafusion/core", version = "53.1.0", default-features = false } +datafusion-catalog = { path = "datafusion/catalog", version = "53.1.0" } +datafusion-catalog-listing = { path = "datafusion/catalog-listing", version = "53.1.0" } +datafusion-common = { path = "datafusion/common", version = "53.1.0", default-features = false } +datafusion-common-runtime = { path = "datafusion/common-runtime", version = "53.1.0" } +datafusion-datasource = { path = "datafusion/datasource", version = "53.1.0", default-features = false } +datafusion-datasource-arrow = { path = "datafusion/datasource-arrow", version = "53.1.0", default-features = false } +datafusion-datasource-avro = { path = "datafusion/datasource-avro", version = "53.1.0", default-features = false } +datafusion-datasource-csv = { path = "datafusion/datasource-csv", version = "53.1.0", default-features = false } +datafusion-datasource-json = { path = "datafusion/datasource-json", version = "53.1.0", default-features = false } +datafusion-datasource-parquet = { path = "datafusion/datasource-parquet", version = "53.1.0", default-features = false } +datafusion-doc = 
{ path = "datafusion/doc", version = "53.1.0" } +datafusion-execution = { path = "datafusion/execution", version = "53.1.0", default-features = false } +datafusion-expr = { path = "datafusion/expr", version = "53.1.0", default-features = false } +datafusion-expr-common = { path = "datafusion/expr-common", version = "53.1.0" } +datafusion-ffi = { path = "datafusion/ffi", version = "53.1.0" } +datafusion-functions = { path = "datafusion/functions", version = "53.1.0" } +datafusion-functions-aggregate = { path = "datafusion/functions-aggregate", version = "53.1.0" } +datafusion-functions-aggregate-common = { path = "datafusion/functions-aggregate-common", version = "53.1.0" } +datafusion-functions-nested = { path = "datafusion/functions-nested", version = "53.1.0", default-features = false } +datafusion-functions-table = { path = "datafusion/functions-table", version = "53.1.0" } +datafusion-functions-window = { path = "datafusion/functions-window", version = "53.1.0" } +datafusion-functions-window-common = { path = "datafusion/functions-window-common", version = "53.1.0" } +datafusion-macros = { path = "datafusion/macros", version = "53.1.0" } +datafusion-optimizer = { path = "datafusion/optimizer", version = "53.1.0", default-features = false } +datafusion-physical-expr = { path = "datafusion/physical-expr", version = "53.1.0", default-features = false } +datafusion-physical-expr-adapter = { path = "datafusion/physical-expr-adapter", version = "53.1.0", default-features = false } +datafusion-physical-expr-common = { path = "datafusion/physical-expr-common", version = "53.1.0", default-features = false } +datafusion-physical-optimizer = { path = "datafusion/physical-optimizer", version = "53.1.0" } +datafusion-physical-plan = { path = "datafusion/physical-plan", version = "53.1.0" } +datafusion-proto = { path = "datafusion/proto", version = "53.1.0" } +datafusion-proto-common = { path = "datafusion/proto-common", version = "53.1.0" } +datafusion-pruning = { path = 
"datafusion/pruning", version = "53.1.0" } +datafusion-session = { path = "datafusion/session", version = "53.1.0" } +datafusion-spark = { path = "datafusion/spark", version = "53.1.0" } +datafusion-sql = { path = "datafusion/sql", version = "53.1.0" } +datafusion-substrait = { path = "datafusion/substrait", version = "53.1.0" } doc-comment = "0.3" env_logger = "0.11" -flate2 = "1.1.5" +flate2 = "1.1.9" futures = "0.3" glob = "0.3.0" half = { version = "2.7.0", default-features = false } hashbrown = { version = "0.16.1" } hex = { version = "0.4.3" } -indexmap = "2.12.1" -insta = { version = "1.46.0", features = ["glob", "filters"] } +indexmap = "2.13.0" +insta = { version = "1.46.3", features = ["glob", "filters"] } itertools = "0.14" -liblzma = { version = "0.4.4", features = ["static"] } +itoa = "1.0" +liblzma = { version = "0.4.6", features = ["static"] } log = "^0.4" +memchr = "2.8.0" num-traits = { version = "0.2" } -object_store = { version = "0.12.4", default-features = false } +object_store = { version = "0.13.1", default-features = false } parking_lot = "0.12" -parquet = { version = "57.1.0", default-features = false, features = [ +parquet = { version = "58.0.0", default-features = false, features = [ "arrow", "async", "object_store", ] } paste = "1.0.15" -pbjson = { version = "0.8.0" } -pbjson-types = "0.8" +pbjson = { version = "0.9.0" } +pbjson-types = "0.9" # Should match arrow-flight's version of prost. 
prost = "0.14.1" rand = "0.9" @@ -181,13 +183,17 @@ recursive = "0.1.1" regex = "1.12" rstest = "0.26.1" serde_json = "1" -sqlparser = { version = "0.59.0", default-features = false, features = ["std", "visitor"] } -strum = "0.27.2" -strum_macros = "0.27.2" +sha2 = "^0.10.9" +sqlparser = { version = "0.61.0", default-features = false, features = ["std", "visitor"] } +strum = "0.28.0" +strum_macros = "0.28.0" tempfile = "3" -testcontainers-modules = { version = "0.14" } +testcontainers-modules = { version = "0.15" } tokio = { version = "1.48", features = ["macros", "rt", "sync"] } +tokio-stream = "0.1" +tokio-util = "0.7" url = "2.5.7" +uuid = "1.21" zstd = { version = "0.13", default-features = false } [workspace.lints.clippy] @@ -200,6 +206,8 @@ uninlined_format_args = "warn" inefficient_to_string = "warn" # https://github.com/apache/datafusion/issues/18503 needless_pass_by_value = "warn" +# https://github.com/apache/datafusion/issues/18881 +allow_attributes = "warn" [workspace.lints.rust] unexpected_cfgs = { level = "warn", check-cfg = [ diff --git a/NOTICE.txt b/NOTICE.txt index 7f3c80d606c0..0bd2d52368fe 100644 --- a/NOTICE.txt +++ b/NOTICE.txt @@ -1,5 +1,5 @@ Apache DataFusion -Copyright 2019-2025 The Apache Software Foundation +Copyright 2019-2026 The Apache Software Foundation This product includes software developed at The Apache Software Foundation (http://www.apache.org/). diff --git a/README.md b/README.md index 880adfb3ac39..630d4295bd42 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,7 @@ DataFusion is an extensible query engine written in [Rust] that uses [Apache Arrow] as its in-memory format. This crate provides libraries and binaries for developers building fast and -feature rich database and analytic systems, customized to particular workloads. +feature-rich database and analytic systems, customized for particular workloads. See [use cases] for examples. 
The following related subprojects target end users: - [DataFusion Python](https://github.com/apache/datafusion-python/) offers a Python interface for SQL and DataFrame @@ -67,7 +67,7 @@ See [use cases] for examples. The following related subprojects target end users DataFusion. "Out of the box," -DataFusion offers [SQL](https://datafusion.apache.org/user-guide/sql/index.html) and [Dataframe](https://datafusion.apache.org/user-guide/dataframe.html) APIs, excellent [performance], +DataFusion offers [SQL](https://datafusion.apache.org/user-guide/sql/index.html) and [DataFrame](https://datafusion.apache.org/user-guide/dataframe.html) APIs, excellent [performance], built-in support for CSV, Parquet, JSON, and Avro, extensive customization, and a great community. @@ -84,7 +84,7 @@ See the [Architecture] section for more details. [performance]: https://benchmark.clickhouse.com/ [architecture]: https://datafusion.apache.org/contributor-guide/architecture.html -Here are links to some important information +Here are links to important resources: - [Project Site](https://datafusion.apache.org/) - [Installation](https://datafusion.apache.org/user-guide/cli/installation.html) @@ -97,8 +97,8 @@ Here are links to some important information ## What can you do with this crate? -DataFusion is great for building projects such as domain specific query engines, new database platforms and data pipelines, query languages and more. -It lets you start quickly from a fully working engine, and then customize those features specific to your use. [Click Here](https://datafusion.apache.org/user-guide/introduction.html#known-users) to see a list known users. +DataFusion is great for building projects such as domain-specific query engines, new database platforms and data pipelines, query languages and more. +It lets you start quickly from a fully working engine, and then customize those features specific to your needs. 
See the [list of known users](https://datafusion.apache.org/user-guide/introduction.html#known-users). ## Contributing to DataFusion @@ -115,15 +115,15 @@ This crate has several [features] which can be specified in your `Cargo.toml`. Default features: -- `nested_expressions`: functions for working with nested type function such as `array_to_string` +- `nested_expressions`: functions for working with nested types such as `array_to_string` - `compression`: reading files compressed with `xz2`, `bzip2`, `flate2`, and `zstd` - `crypto_expressions`: cryptographic functions such as `md5` and `sha256` - `datetime_expressions`: date and time functions such as `to_timestamp` - `encoding_expressions`: `encode` and `decode` functions - `parquet`: support for reading the [Apache Parquet] format -- `sql`: Support for sql parsing / planning +- `sql`: support for SQL parsing and planning - `regex_expressions`: regular expression functions, such as `regexp_match` -- `unicode_expressions`: Include unicode aware functions such as `character_length` +- `unicode_expressions`: include Unicode-aware functions such as `character_length` - `unparser`: enables support to reverse LogicalPlans back into SQL - `recursive_protection`: uses [recursive](https://docs.rs/recursive/latest/recursive/) for stack overflow protection. 
diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml index df04f56235ec..cb4a308ceb51 100644 --- a/benchmarks/Cargo.toml +++ b/benchmarks/Cargo.toml @@ -40,7 +40,7 @@ mimalloc_extended = ["libmimalloc-sys/extended"] [dependencies] arrow = { workspace = true } -clap = { version = "4.5.53", features = ["derive"] } +clap = { version = "4.5.60", features = ["derive"] } datafusion = { workspace = true, default-features = true } datafusion-common = { workspace = true, default-features = true } env_logger = { workspace = true } diff --git a/benchmarks/bench.sh b/benchmarks/bench.sh index 6679405623d0..761efa6d591a 100755 --- a/benchmarks/bench.sh +++ b/benchmarks/bench.sh @@ -42,7 +42,6 @@ DATAFUSION_DIR=${DATAFUSION_DIR:-$SCRIPT_DIR/..} DATA_DIR=${DATA_DIR:-$SCRIPT_DIR/data} CARGO_COMMAND=${CARGO_COMMAND:-"cargo run --release"} PREFER_HASH_JOIN=${PREFER_HASH_JOIN:-true} -VIRTUAL_ENV=${VIRTUAL_ENV:-$SCRIPT_DIR/venv} usage() { echo " @@ -53,7 +52,6 @@ $0 data [benchmark] $0 run [benchmark] [query] $0 compare $0 compare_detail -$0 venv ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ Examples: @@ -71,7 +69,6 @@ data: Generates or downloads data needed for benchmarking run: Runs the named benchmark compare: Compares fastest results from benchmark runs compare_detail: Compares minimum, average (±stddev), and maximum results from benchmark runs -venv: Creates new venv (unless already exists) and installs compare's requirements into it ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ Benchmarks @@ -144,7 +141,6 @@ CARGO_COMMAND command that runs the benchmark binary DATAFUSION_DIR directory to use (default $DATAFUSION_DIR) RESULTS_NAME folder where the benchmark files are stored PREFER_HASH_JOIN Prefer hash join algorithm (default true) -VENV_PATH Python venv to use for compare and venv commands (default ./venv, override by /bin/activate) DATAFUSION_* Set the given datafusion configuration " exit 1 @@ -322,8 +318,7 @@ main() { echo "NLJ benchmark does not 
require data generation" ;; hj) - # hj uses range() function, no data generation needed - echo "HJ benchmark does not require data generation" + data_tpch "10" "parquet" ;; smj) # smj uses range() function, no data generation needed @@ -543,9 +538,6 @@ main() { compare_detail) compare_benchmarks "$ARG2" "$ARG3" "--detailed" ;; - venv) - setup_venv - ;; "") usage ;; @@ -684,7 +676,7 @@ run_tpch_mem() { # Runs the tpcds benchmark run_tpcds() { - TPCDS_DIR="${DATA_DIR}" + TPCDS_DIR="${DATA_DIR}/tpcds_sf1" # Check if TPCDS data directory and representative file exists if [ ! -f "${TPCDS_DIR}/web_site.parquet" ]; then @@ -709,7 +701,7 @@ run_compile_profile() { local data_path="${DATA_DIR}/tpch_sf1" echo "Running compile profile benchmark..." - local cmd=(python3 "${runner}" --data "${data_path}") + local cmd=(uv run python3 "${runner}" --data "${data_path}") if [ ${#profiles[@]} -gt 0 ]; then cmd+=(--profiles "${profiles[@]}") fi @@ -924,75 +916,13 @@ data_h2o() { SIZE=${1:-"SMALL"} DATA_FORMAT=${2:-"CSV"} - # Function to compare Python versions - version_ge() { - [ "$(printf '%s\n' "$1" "$2" | sort -V | head -n1)" = "$2" ] - } - - export PYO3_USE_ABI3_FORWARD_COMPATIBILITY=1 - - # Find the highest available Python version (3.10 or higher) - REQUIRED_VERSION="3.10" - PYTHON_CMD=$(command -v python3 || true) - - if [ -n "$PYTHON_CMD" ]; then - PYTHON_VERSION=$($PYTHON_CMD -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')") - if version_ge "$PYTHON_VERSION" "$REQUIRED_VERSION"; then - echo "Found Python version $PYTHON_VERSION, which is suitable." - else - echo "Python version $PYTHON_VERSION found, but version $REQUIRED_VERSION or higher is required." 
- PYTHON_CMD="" - fi - fi - - # Search for suitable Python versions if the default is unsuitable - if [ -z "$PYTHON_CMD" ]; then - # Loop through all available Python3 commands on the system - for CMD in $(compgen -c | grep -E '^python3(\.[0-9]+)?$'); do - if command -v "$CMD" &> /dev/null; then - PYTHON_VERSION=$($CMD -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')") - if version_ge "$PYTHON_VERSION" "$REQUIRED_VERSION"; then - PYTHON_CMD="$CMD" - echo "Found suitable Python version: $PYTHON_VERSION ($CMD)" - break - fi - fi - done - fi - - # If no suitable Python version found, exit with an error - if [ -z "$PYTHON_CMD" ]; then - echo "Python 3.10 or higher is required. Please install it." - return 1 - fi - - echo "Using Python command: $PYTHON_CMD" - - # Install falsa and other dependencies - echo "Installing falsa..." - - # Set virtual environment directory - VIRTUAL_ENV="${PWD}/venv" - - # Create a virtual environment using the detected Python command - $PYTHON_CMD -m venv "$VIRTUAL_ENV" - - # Activate the virtual environment and install dependencies - source "$VIRTUAL_ENV/bin/activate" - - # Ensure 'falsa' is installed (avoid unnecessary reinstall) - pip install --quiet --upgrade falsa - # Create directory if it doesn't exist H2O_DIR="${DATA_DIR}/h2o" mkdir -p "${H2O_DIR}" # Generate h2o test data echo "Generating h2o test data in ${H2O_DIR} with size=${SIZE} and format=${DATA_FORMAT}" - falsa groupby --path-prefix="${H2O_DIR}" --size "${SIZE}" --data-format "${DATA_FORMAT}" - - # Deactivate virtual environment after completion - deactivate + uv run falsa groupby --path-prefix="${H2O_DIR}" --size "${SIZE}" --data-format "${DATA_FORMAT}" } data_h2o_join() { @@ -1000,75 +930,13 @@ data_h2o_join() { SIZE=${1:-"SMALL"} DATA_FORMAT=${2:-"CSV"} - # Function to compare Python versions - version_ge() { - [ "$(printf '%s\n' "$1" "$2" | sort -V | head -n1)" = "$2" ] - } - - export PYO3_USE_ABI3_FORWARD_COMPATIBILITY=1 - - # Find the highest 
available Python version (3.10 or higher) - REQUIRED_VERSION="3.10" - PYTHON_CMD=$(command -v python3 || true) - - if [ -n "$PYTHON_CMD" ]; then - PYTHON_VERSION=$($PYTHON_CMD -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')") - if version_ge "$PYTHON_VERSION" "$REQUIRED_VERSION"; then - echo "Found Python version $PYTHON_VERSION, which is suitable." - else - echo "Python version $PYTHON_VERSION found, but version $REQUIRED_VERSION or higher is required." - PYTHON_CMD="" - fi - fi - - # Search for suitable Python versions if the default is unsuitable - if [ -z "$PYTHON_CMD" ]; then - # Loop through all available Python3 commands on the system - for CMD in $(compgen -c | grep -E '^python3(\.[0-9]+)?$'); do - if command -v "$CMD" &> /dev/null; then - PYTHON_VERSION=$($CMD -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')") - if version_ge "$PYTHON_VERSION" "$REQUIRED_VERSION"; then - PYTHON_CMD="$CMD" - echo "Found suitable Python version: $PYTHON_VERSION ($CMD)" - break - fi - fi - done - fi - - # If no suitable Python version found, exit with an error - if [ -z "$PYTHON_CMD" ]; then - echo "Python 3.10 or higher is required. Please install it." - return 1 - fi - - echo "Using Python command: $PYTHON_CMD" - - # Install falsa and other dependencies - echo "Installing falsa..." 
- - # Set virtual environment directory - VIRTUAL_ENV="${PWD}/venv" - - # Create a virtual environment using the detected Python command - $PYTHON_CMD -m venv "$VIRTUAL_ENV" - - # Activate the virtual environment and install dependencies - source "$VIRTUAL_ENV/bin/activate" - - # Ensure 'falsa' is installed (avoid unnecessary reinstall) - pip install --quiet --upgrade falsa - # Create directory if it doesn't exist H2O_DIR="${DATA_DIR}/h2o" mkdir -p "${H2O_DIR}" # Generate h2o test data echo "Generating h2o test data in ${H2O_DIR} with size=${SIZE} and format=${DATA_FORMAT}" - falsa join --path-prefix="${H2O_DIR}" --size "${SIZE}" --data-format "${DATA_FORMAT}" - - # Deactivate virtual environment after completion - deactivate + uv run falsa join --path-prefix="${H2O_DIR}" --size "${SIZE}" --data-format "${DATA_FORMAT}" } # Runner for h2o groupby benchmark @@ -1228,10 +1096,11 @@ run_nlj() { # Runs the hj benchmark run_hj() { + TPCH_DIR="${DATA_DIR}/tpch_sf10" RESULTS_FILE="${RESULTS_DIR}/hj.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running hj benchmark..." 
- debug_run $CARGO_COMMAND --bin dfbench -- hj --iterations 5 -o "${RESULTS_FILE}" ${QUERY_ARG} + debug_run $CARGO_COMMAND --bin dfbench -- hj --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" ${QUERY_ARG} } # Runs the smj benchmark @@ -1269,7 +1138,7 @@ compare_benchmarks() { echo "--------------------" echo "Benchmark ${BENCH}" echo "--------------------" - PATH=$VIRTUAL_ENV/bin:$PATH python3 "${SCRIPT_DIR}"/compare.py $OPTS "${RESULTS_FILE1}" "${RESULTS_FILE2}" + uv run python3 "${SCRIPT_DIR}"/compare.py $OPTS "${RESULTS_FILE1}" "${RESULTS_FILE2}" else echo "Note: Skipping ${RESULTS_FILE1} as ${RESULTS_FILE2} does not exist" fi @@ -1384,10 +1253,6 @@ run_clickbench_sorted() { ${QUERY_ARG} } -setup_venv() { - python3 -m venv "$VIRTUAL_ENV" - PATH=$VIRTUAL_ENV/bin:$PATH python3 -m pip install -r requirements.txt -} # And start the process up main diff --git a/benchmarks/compare.py b/benchmarks/compare.py index 7e51a38a92c2..9ad1de980abe 100755 --- a/benchmarks/compare.py +++ b/benchmarks/compare.py @@ -154,17 +154,17 @@ def compare( baseline = BenchmarkRun.load_from_file(baseline_path) comparison = BenchmarkRun.load_from_file(comparison_path) - console = Console() + console = Console(width=200) # use basename as the column names - baseline_header = baseline_path.parent.stem - comparison_header = comparison_path.parent.stem + baseline_header = baseline_path.parent.name + comparison_header = comparison_path.parent.name table = Table(show_header=True, header_style="bold magenta") - table.add_column("Query", style="dim", width=12) - table.add_column(baseline_header, justify="right", style="dim") - table.add_column(comparison_header, justify="right", style="dim") - table.add_column("Change", justify="right", style="dim") + table.add_column("Query", style="dim", no_wrap=True) + table.add_column(baseline_header, justify="right", style="dim", no_wrap=True) + table.add_column(comparison_header, justify="right", style="dim", no_wrap=True) + table.add_column("Change", 
justify="right", style="dim", no_wrap=True) faster_count = 0 slower_count = 0 @@ -175,12 +175,12 @@ def compare( for baseline_result, comparison_result in zip(baseline.queries, comparison.queries): assert baseline_result.query == comparison_result.query - + base_failed = not baseline_result.success - comp_failed = not comparison_result.success + comp_failed = not comparison_result.success # If a query fails, its execution time is excluded from the performance comparison if base_failed or comp_failed: - change_text = "incomparable" + change_text = "incomparable" failure_count += 1 table.add_row( f"Q{baseline_result.query}", diff --git a/benchmarks/pyproject.toml b/benchmarks/pyproject.toml new file mode 100644 index 000000000000..e6a60582148c --- /dev/null +++ b/benchmarks/pyproject.toml @@ -0,0 +1,6 @@ +[project] +name = "datafusion-benchmarks" +version = "0.1.0" +requires-python = ">=3.11" +# typing_extensions is an undeclared dependency of falsa +dependencies = ["rich", "falsa", "typing_extensions"] diff --git a/benchmarks/src/clickbench.rs b/benchmarks/src/clickbench.rs index a9da57b02ae3..c0f911c566f4 100644 --- a/benchmarks/src/clickbench.rs +++ b/benchmarks/src/clickbench.rs @@ -29,6 +29,16 @@ use datafusion::{ use datafusion_common::exec_datafusion_err; use datafusion_common::instant::Instant; +/// SQL to create the hits view with proper EventDate casting. +/// +/// ClickBench stores EventDate as UInt16 (days since 1970-01-01) for +/// storage efficiency (2 bytes vs 4-8 bytes for date types). +/// This view transforms it to SQL DATE type for query compatibility. 
+const HITS_VIEW_DDL: &str = r#"CREATE VIEW hits AS +SELECT * EXCEPT ("EventDate"), + CAST(CAST("EventDate" AS INTEGER) AS DATE) AS "EventDate" +FROM hits_raw"#; + /// Driver program to run the ClickBench benchmark /// /// The ClickBench[1] benchmarks are widely cited in the industry and @@ -295,7 +305,7 @@ impl RunOpt { // Build CREATE EXTERNAL TABLE DDL with WITH ORDER clause // Schema will be automatically inferred from the Parquet file let create_table_sql = format!( - "CREATE EXTERNAL TABLE hits \ + "CREATE EXTERNAL TABLE hits_raw \ STORED AS PARQUET \ LOCATION '{}' \ WITH ORDER ({} {})", @@ -308,20 +318,34 @@ impl RunOpt { // Execute the CREATE EXTERNAL TABLE statement ctx.sql(&create_table_sql).await?.collect().await?; - - Ok(()) } else { // Original registration without sort order let options = Default::default(); - ctx.register_parquet("hits", path, options) + ctx.register_parquet("hits_raw", path, options) .await .map_err(|e| { DataFusionError::Context( - format!("Registering 'hits' as {path}"), + format!("Registering 'hits_raw' as {path}"), Box::new(e), ) - }) + })?; } + + // Create the hits view with EventDate transformation + Self::create_hits_view(ctx).await + } + + /// Creates the hits view with EventDate transformation from UInt16 to DATE. + /// + /// ClickBench encodes EventDate as UInt16 days since epoch (1970-01-01). 
+ async fn create_hits_view(ctx: &SessionContext) -> Result<()> { + ctx.sql(HITS_VIEW_DDL).await?.collect().await.map_err(|e| { + DataFusionError::Context( + "Creating 'hits' view with EventDate transformation".to_string(), + Box::new(e), + ) + })?; + Ok(()) } fn iterations(&self) -> usize { diff --git a/benchmarks/src/hj.rs b/benchmarks/src/hj.rs index ddb2d268e601..6eb828a3aedf 100644 --- a/benchmarks/src/hj.rs +++ b/benchmarks/src/hj.rs @@ -21,6 +21,7 @@ use datafusion::physical_plan::execute_stream; use datafusion::{error::Result, prelude::SessionContext}; use datafusion_common::instant::Instant; use datafusion_common::{DataFusionError, exec_datafusion_err, exec_err}; +use std::path::PathBuf; use futures::StreamExt; @@ -35,7 +36,7 @@ use futures::StreamExt; #[derive(Debug, Args, Clone)] #[command(verbatim_doc_comment)] pub struct RunOpt { - /// Query number (between 1 and 12). If not specified, runs all queries + /// Query number. If not specified, runs all queries #[arg(short, long)] query: Option, @@ -43,128 +44,265 @@ pub struct RunOpt { #[command(flatten)] common: CommonOpt, + /// Path to TPC-H SF10 data + #[arg(short = 'p', long = "path")] + path: Option, + /// If present, write results json here #[arg(short = 'o', long = "output")] - output_path: Option, + output_path: Option, +} + +struct HashJoinQuery { + sql: &'static str, + density: f64, + prob_hit: f64, + build_size: &'static str, + probe_size: &'static str, } /// Inline SQL queries for Hash Join benchmarks -/// -/// Each query's comment includes: -/// - Left row count × Right row count -/// - Join predicate selectivity (approximate output fraction). -/// - Q11 and Q12 selectivity is relative to cartesian product while the others are -/// relative to probe side. 
-const HASH_QUERIES: &[&str] = &[ - // Q1: INNER 10 x 10K | LOW ~0.1% - // equality on key + cheap filter to downselect - r#" - SELECT t1.value, t2.value - FROM generate_series(0, 9000, 1000) AS t1(value) - JOIN range(10000) AS t2 - ON t1.value = t2.value; - "#, - // Q2: INNER 10 x 10K | LOW ~0.1% - r#" - SELECT t1.value, t2.value - FROM generate_series(0, 9000, 1000) AS t1 - JOIN range(10000) AS t2 - ON t1.value = t2.value - WHERE t1.value % 5 = 0 - "#, - // Q3: INNER 10K x 10K | HIGH ~90% - r#" - SELECT t1.value, t2.value - FROM range(10000) AS t1 - JOIN range(10000) AS t2 - ON t1.value = t2.value - WHERE t1.value % 10 <> 0 - "#, - // Q4: INNER 30 x 30K | LOW ~0.1% - r#" - SELECT t1.value, t2.value - FROM generate_series(0, 29000, 1000) AS t1 - JOIN range(30000) AS t2 - ON t1.value = t2.value - WHERE t1.value % 5 = 0 - "#, - // Q5: INNER 10 x 200K | VERY LOW ~0.005% (small to large) - r#" - SELECT t1.value, t2.value - FROM generate_series(0, 9000, 1000) AS t1 - JOIN range(200000) AS t2 - ON t1.value = t2.value - WHERE t1.value % 1000 = 0 - "#, - // Q6: INNER 200K x 10 | VERY LOW ~0.005% (large to small) - r#" - SELECT t1.value, t2.value - FROM range(200000) AS t1 - JOIN generate_series(0, 9000, 1000) AS t2 - ON t1.value = t2.value - WHERE t1.value % 1000 = 0 - "#, - // Q7: RIGHT OUTER 10 x 200K | LOW ~0.1% - // Outer join still uses HashJoin for equi-keys; the extra filter reduces matches - r#" - SELECT t1.value AS l, t2.value AS r - FROM generate_series(0, 9000, 1000) AS t1 - RIGHT JOIN range(200000) AS t2 - ON t1.value = t2.value - WHERE t2.value % 1000 = 0 - "#, - // Q8: LEFT OUTER 200K x 10 | LOW ~0.1% - r#" - SELECT t1.value AS l, t2.value AS r - FROM range(200000) AS t1 - LEFT JOIN generate_series(0, 9000, 1000) AS t2 - ON t1.value = t2.value - WHERE t1.value % 1000 = 0 - "#, - // Q9: FULL OUTER 30 x 30K | LOW ~0.1% - r#" - SELECT t1.value AS l, t2.value AS r - FROM generate_series(0, 29000, 1000) AS t1 - FULL JOIN range(30000) AS t2 - ON t1.value = 
t2.value - WHERE COALESCE(t1.value, t2.value) % 1000 = 0 - "#, - // Q10: FULL OUTER 30 x 30K | HIGH ~90% - r#" - SELECT t1.value AS l, t2.value AS r - FROM generate_series(0, 29000, 1000) AS t1 - FULL JOIN range(30000) AS t2 - ON t1.value = t2.value - WHERE COALESCE(t1.value, t2.value) % 10 <> 0 - "#, - // Q11: INNER 30 x 30K | MEDIUM ~50% | cheap predicate on parity - r#" - SELECT t1.value, t2.value - FROM generate_series(0, 29000, 1000) AS t1 - INNER JOIN range(30000) AS t2 - ON (t1.value % 2) = (t2.value % 2) - "#, - // Q12: FULL OUTER 30 x 30K | MEDIUM ~50% | expression key - r#" - SELECT t1.value AS l, t2.value AS r - FROM generate_series(0, 29000, 1000) AS t1 - FULL JOIN range(30000) AS t2 - ON (t1.value % 2) = (t2.value % 2) - "#, - // Q13: INNER 30 x 30K | LOW 0.1% | modulo with adding values - r#" - SELECT t1.value, t2.value - FROM generate_series(0, 29000, 1000) AS t1 - INNER JOIN range(30000) AS t2 - ON (t1.value = t2.value) AND ((t1.value + t2.value) % 10 < 1) - "#, - // Q14: FULL OUTER 30 x 30K | ALL ~100% | modulo - r#" - SELECT t1.value AS l, t2.value AS r - FROM generate_series(0, 29000, 1000) AS t1 - FULL JOIN range(30000) AS t2 - ON (t1.value = t2.value) AND ((t1.value + t2.value) % 10 = 0) - "#, +const HASH_QUERIES: &[HashJoinQuery] = &[ + // Q1: Very Small Build Side (Dense) + // Build Side: nation (25 rows) | Probe Side: customer (1.5M rows) + HashJoinQuery { + sql: r###"SELECT n_nationkey FROM nation JOIN customer ON c_nationkey = n_nationkey"###, + density: 1.0, + prob_hit: 1.0, + build_size: "25", + probe_size: "1.5M", + }, + // Q2: Very Small Build Side (Sparse, range < 1024) + // Build Side: nation (25 rows, range 961) | Probe Side: customer (1.5M rows) + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT c_nationkey * 40 as k + FROM customer + ) l + JOIN ( + SELECT n_nationkey * 40 as k FROM nation + ) s ON l.k = s.k"###, + density: 0.026, + prob_hit: 1.0, + build_size: "25", + probe_size: "1.5M", + }, + // Q3: 100% Density, 100% 
Hit rate + HashJoinQuery { + sql: r###"SELECT s_suppkey FROM supplier JOIN lineitem ON s_suppkey = l_suppkey"###, + density: 1.0, + prob_hit: 1.0, + build_size: "100K", + probe_size: "60M", + }, + // Q4: 100% Density, 10% Hit rate + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT CASE WHEN l_suppkey % 10 = 0 THEN l_suppkey ELSE l_suppkey + 1000000 END as k + FROM lineitem + ) l + JOIN ( + SELECT s_suppkey as k FROM supplier + ) s ON l.k = s.k"###, + density: 1.0, + prob_hit: 0.1, + build_size: "100K", + probe_size: "60M", + }, + // Q5: 75% Density, 100% Hit rate + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT l_suppkey * 4 / 3 as k + FROM lineitem + ) l + JOIN ( + SELECT s_suppkey * 4 / 3 as k FROM supplier + ) s ON l.k = s.k"###, + density: 0.75, + prob_hit: 1.0, + build_size: "100K", + probe_size: "60M", + }, + // Q6: 75% Density, 10% Hit rate + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT CASE + WHEN l_suppkey % 10 = 0 THEN l_suppkey * 4 / 3 + WHEN l_suppkey % 10 < 9 THEN (l_suppkey * 4 / 3 / 4) * 4 + 3 + ELSE l_suppkey * 4 / 3 + 1000000 + END as k + FROM lineitem + ) l + JOIN ( + SELECT s_suppkey * 4 / 3 as k FROM supplier + ) s ON l.k = s.k"###, + density: 0.75, + prob_hit: 0.1, + build_size: "100K", + probe_size: "60M", + }, + // Q7: 50% Density, 100% Hit rate + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT l_suppkey * 2 as k + FROM lineitem + ) l + JOIN ( + SELECT s_suppkey * 2 as k FROM supplier + ) s ON l.k = s.k"###, + density: 0.5, + prob_hit: 1.0, + build_size: "100K", + probe_size: "60M", + }, + // Q8: 50% Density, 10% Hit rate + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT CASE + WHEN l_suppkey % 10 = 0 THEN l_suppkey * 2 + WHEN l_suppkey % 10 < 9 THEN l_suppkey * 2 + 1 + ELSE l_suppkey * 2 + 1000000 + END as k + FROM lineitem + ) l + JOIN ( + SELECT s_suppkey * 2 as k FROM supplier + ) s ON l.k = s.k"###, + density: 0.5, + prob_hit: 0.1, + build_size: "100K", + probe_size: "60M", + }, + // Q9: 
20% Density, 100% Hit rate + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT l_suppkey * 5 as k + FROM lineitem + ) l + JOIN ( + SELECT s_suppkey * 5 as k FROM supplier + ) s ON l.k = s.k"###, + density: 0.2, + prob_hit: 1.0, + build_size: "100K", + probe_size: "60M", + }, + // Q10: 20% Density, 10% Hit rate + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT CASE + WHEN l_suppkey % 10 = 0 THEN l_suppkey * 5 + WHEN l_suppkey % 10 < 9 THEN l_suppkey * 5 + 1 + ELSE l_suppkey * 5 + 1000000 + END as k + FROM lineitem + ) l + JOIN ( + SELECT s_suppkey * 5 as k FROM supplier + ) s ON l.k = s.k"###, + density: 0.2, + prob_hit: 0.1, + build_size: "100K", + probe_size: "60M", + }, + // Q11: 10% Density, 100% Hit rate + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT l_suppkey * 10 as k + FROM lineitem + ) l + JOIN ( + SELECT s_suppkey * 10 as k FROM supplier + ) s ON l.k = s.k"###, + density: 0.1, + prob_hit: 1.0, + build_size: "100K", + probe_size: "60M", + }, + // Q12: 10% Density, 10% Hit rate + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT CASE + WHEN l_suppkey % 10 = 0 THEN l_suppkey * 10 + WHEN l_suppkey % 10 < 9 THEN l_suppkey * 10 + 1 + ELSE l_suppkey * 10 + 1000000 + END as k + FROM lineitem + ) l + JOIN ( + SELECT s_suppkey * 10 as k FROM supplier + ) s ON l.k = s.k"###, + density: 0.1, + prob_hit: 0.1, + build_size: "100K", + probe_size: "60M", + }, + // Q13: 1% Density, 100% Hit rate + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT l_suppkey * 100 as k + FROM lineitem + ) l + JOIN ( + SELECT s_suppkey * 100 as k FROM supplier + ) s ON l.k = s.k"###, + density: 0.01, + prob_hit: 1.0, + build_size: "100K", + probe_size: "60M", + }, + // Q14: 1% Density, 10% Hit rate + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT CASE + WHEN l_suppkey % 10 = 0 THEN l_suppkey * 100 + WHEN l_suppkey % 10 < 9 THEN l_suppkey * 100 + 1 + ELSE l_suppkey * 100 + 11000000 -- oob + END as k + FROM lineitem + ) l + JOIN ( + SELECT 
s_suppkey * 100 as k FROM supplier + ) s ON l.k = s.k"###, + density: 0.01, + prob_hit: 0.1, + build_size: "100K", + probe_size: "60M", + }, + // Q15: 20% Density, 10% Hit rate, 20% Duplicates in Build Side + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT CASE + WHEN l_suppkey % 10 = 0 THEN ((l_suppkey % 80000) + 1) * 25 / 4 + ELSE ((l_suppkey % 80000) + 1) * 25 / 4 + 1 + END as k + FROM lineitem + ) l + JOIN ( + SELECT CASE + WHEN s_suppkey <= 80000 THEN (s_suppkey * 25) / 4 + ELSE ((s_suppkey - 80000) * 25) / 4 + END as k + FROM supplier + ) s ON l.k = s.k"###, + density: 0.2, + prob_hit: 0.1, + build_size: "100K_(20%_dups)", + probe_size: "60M", + }, ]; impl RunOpt { @@ -189,14 +327,44 @@ impl RunOpt { let rt_builder = self.common.runtime_env_builder()?; let ctx = SessionContext::new_with_config_rt(config, rt_builder.build_arc()?); + if let Some(path) = &self.path { + for table in &["lineitem", "supplier", "nation", "customer"] { + let table_path = path.join(table); + if !table_path.exists() { + return exec_err!( + "TPC-H table {} not found at {:?}", + table, + table_path + ); + } + ctx.register_parquet( + *table, + table_path.to_str().unwrap(), + Default::default(), + ) + .await?; + } + } + let mut benchmark_run = BenchmarkRun::new(); for query_id in query_range { let query_index = query_id - 1; - let sql = HASH_QUERIES[query_index]; + let query = &HASH_QUERIES[query_index]; + + let case_name = format!( + "Query {}_density={}_prob_hit={}_{}*{}", + query_id, + query.density, + query.prob_hit, + query.build_size, + query.probe_size + ); + benchmark_run.start_new_case(&case_name); - benchmark_run.start_new_case(&format!("Query {query_id}")); - let query_run = self.benchmark_query(sql, &query_id.to_string(), &ctx).await; + let query_run = self + .benchmark_query(query.sql, &query_id.to_string(), &ctx) + .await; match query_run { Ok(query_results) => { for iter in query_results { diff --git a/benchmarks/src/util/options.rs b/benchmarks/src/util/options.rs 
index 6f7267eabb83..add8ff17fbf8 100644 --- a/benchmarks/src/util/options.rs +++ b/benchmarks/src/util/options.rs @@ -50,12 +50,12 @@ pub struct CommonOpt { /// Memory limit (e.g. '100M', '1.5G'). If not specified, run all pre-defined memory limits for given query /// if there's any, otherwise run with no memory limit. - #[arg(long = "memory-limit", value_parser = parse_memory_limit)] + #[arg(long = "memory-limit", value_parser = parse_capacity_limit)] pub memory_limit: Option, /// The amount of memory to reserve for sort spill operations. DataFusion's default value will be used /// if not specified. - #[arg(long = "sort-spill-reservation-bytes", value_parser = parse_memory_limit)] + #[arg(long = "sort-spill-reservation-bytes", value_parser = parse_capacity_limit)] pub sort_spill_reservation_bytes: Option, /// Activate debug mode to see more details @@ -116,20 +116,26 @@ impl CommonOpt { } } -/// Parse memory limit from string to number of bytes -/// e.g. '1.5G', '100M' -> 1572864 -fn parse_memory_limit(limit: &str) -> Result { +/// Parse capacity limit from string to number of bytes by allowing units: K, M and G. 
+/// Supports formats like '1.5G' -> 1610612736, '100M' -> 104857600 +fn parse_capacity_limit(limit: &str) -> Result { + if limit.trim().is_empty() { + return Err("Capacity limit cannot be empty".to_string()); + } let (number, unit) = limit.split_at(limit.len() - 1); let number: f64 = number .parse() - .map_err(|_| format!("Failed to parse number from memory limit '{limit}'"))?; + .map_err(|_| format!("Failed to parse number from capacity limit '{limit}'"))?; + if number.is_sign_negative() || number.is_infinite() { + return Err("Limit value should be positive finite number".to_string()); + } match unit { "K" => Ok((number * 1024.0) as usize), "M" => Ok((number * 1024.0 * 1024.0) as usize), "G" => Ok((number * 1024.0 * 1024.0 * 1024.0) as usize), _ => Err(format!( - "Unsupported unit '{unit}' in memory limit '{limit}'" + "Unsupported unit '{unit}' in capacity limit '{limit}'. Unit must be one of: 'K', 'M', 'G'" )), } } @@ -139,16 +145,25 @@ mod tests { use super::*; #[test] - fn test_parse_memory_limit_all() { + fn test_parse_capacity_limit_all() { // Test valid inputs - assert_eq!(parse_memory_limit("100K").unwrap(), 102400); - assert_eq!(parse_memory_limit("1.5M").unwrap(), 1572864); - assert_eq!(parse_memory_limit("2G").unwrap(), 2147483648); + assert_eq!(parse_capacity_limit("100K").unwrap(), 102400); + assert_eq!(parse_capacity_limit("1.5M").unwrap(), 1572864); + assert_eq!(parse_capacity_limit("2G").unwrap(), 2147483648); // Test invalid unit - assert!(parse_memory_limit("500X").is_err()); + assert!(parse_capacity_limit("500X").is_err()); // Test invalid number - assert!(parse_memory_limit("abcM").is_err()); + assert!(parse_capacity_limit("abcM").is_err()); + + // Test negative number + assert!(parse_capacity_limit("-1M").is_err()); + + // Test infinite number + assert!(parse_capacity_limit("infM").is_err()); + + // Test negative infinite number + assert!(parse_capacity_limit("-infM").is_err()); } } diff --git a/ci/scripts/check_examples_docs.sh 
b/ci/scripts/check_examples_docs.sh index 37b0cc088df4..62308b323b53 100755 --- a/ci/scripts/check_examples_docs.sh +++ b/ci/scripts/check_examples_docs.sh @@ -17,48 +17,61 @@ # specific language governing permissions and limitations # under the License. -set -euo pipefail - -EXAMPLES_DIR="datafusion-examples/examples" -README="datafusion-examples/README.md" +# Generates documentation for DataFusion examples using the Rust-based +# documentation generator and verifies that the committed README.md +# is up to date. +# +# The README is generated from documentation comments in: +# datafusion-examples/examples//main.rs +# +# This script is intended to be run in CI to ensure that example +# documentation stays in sync with the code. +# +# To update the README locally, run this script and replace README.md +# with the generated output. -# ffi examples are skipped because they were not part of the recent example -# consolidation work and do not follow the new grouping and execution pattern. -# They are not documented in the README using the new structure, so including -# them here would cause false CI failures. 
-SKIP_LIST=("ffi") +set -euo pipefail -missing=0 +ROOT_DIR="$(git rev-parse --show-toplevel)" -skip() { - local value="$1" - for item in "${SKIP_LIST[@]}"; do - if [[ "$item" == "$value" ]]; then - return 0 - fi - done - return 1 -} +# Load centralized tool versions +source "${ROOT_DIR}/ci/scripts/utils/tool_versions.sh" -# collect folder names -folders=$(find "$EXAMPLES_DIR" -mindepth 1 -maxdepth 1 -type d -exec basename {} \;) +EXAMPLES_DIR="$ROOT_DIR/datafusion-examples" +README="$EXAMPLES_DIR/README.md" +README_NEW="$EXAMPLES_DIR/README-NEW.md" -# collect group names from README headers -groups=$(grep "^### Group:" "$README" | sed -E 's/^### Group: `([^`]+)`.*/\1/') +echo "▶ Generating examples README (Rust generator)…" +cargo run --quiet \ + --manifest-path "$EXAMPLES_DIR/Cargo.toml" \ + --bin examples-docs \ + > "$README_NEW" -for folder in $folders; do - if skip "$folder"; then - echo "Skipped group: $folder" - continue - fi +echo "▶ Formatting generated README with prettier ${PRETTIER_VERSION}…" +npx "prettier@${PRETTIER_VERSION}" \ + --parser markdown \ + --write "$README_NEW" - if ! echo "$groups" | grep -qx "$folder"; then - echo "Missing README entry for example group: $folder" - missing=1 - fi -done +echo "▶ Comparing generated README with committed version…" -if [[ $missing -eq 1 ]]; then - echo "README is out of sync with examples" - exit 1 +if ! diff -u "$README" "$README_NEW" > /tmp/examples-readme.diff; then + echo "" + echo "❌ Examples README is out of date." 
+ echo "" + echo "The examples documentation is generated automatically from:" + echo " - datafusion-examples/examples//main.rs" + echo "" + echo "To update the README locally, run:" + echo "" + echo " cargo run --bin examples-docs \\" + echo " | npx prettier@${PRETTIER_VERSION} --parser markdown --write \\" + echo " > datafusion-examples/README.md" + echo "" + echo "Diff:" + echo "------------------------------------------------------------" + cat /tmp/examples-readme.diff + echo "------------------------------------------------------------" + exit 1 fi + +echo "✅ Examples README is up-to-date." diff --git a/ci/scripts/doc_prettier_check.sh b/ci/scripts/doc_prettier_check.sh index d94a0d1c9617..95332eb65aaf 100755 --- a/ci/scripts/doc_prettier_check.sh +++ b/ci/scripts/doc_prettier_check.sh @@ -17,41 +17,70 @@ # specific language governing permissions and limitations # under the License. -SCRIPT_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)/$(basename "${BASH_SOURCE[0]}")" - -MODE="--check" -ACTION="Checking" -if [ $# -gt 0 ]; then - if [ "$1" = "--write" ]; then - MODE="--write" - ACTION="Formatting" - else - echo "Usage: $0 [--write]" >&2 - exit 1 - fi +set -euo pipefail + +ROOT_DIR="$(git rev-parse --show-toplevel)" +SCRIPT_NAME="$(basename "${BASH_SOURCE[0]}")" + +# Load shared utilities and tool versions +source "${ROOT_DIR}/ci/scripts/utils/tool_versions.sh" +source "${ROOT_DIR}/ci/scripts/utils/git.sh" + +PRETTIER_TARGETS=( + '{datafusion,datafusion-cli,datafusion-examples,dev,docs}/**/*.md' + '!datafusion/CHANGELOG.md' + README.md + CONTRIBUTING.md +) + +MODE="check" +ALLOW_DIRTY=0 + +usage() { + cat >&2 </dev/null 2>&1; then echo "npx is required to run the prettier check. Install Node.js (e.g., brew install node) and re-run." 
>&2 exit 1 fi - -# Ignore subproject CHANGELOG.md because it is machine generated -npx prettier@2.7.1 $MODE \ - '{datafusion,datafusion-cli,datafusion-examples,dev,docs}/**/*.md' \ - '!datafusion/CHANGELOG.md' \ - README.md \ - CONTRIBUTING.md -status=$? - -if [ $status -ne 0 ]; then - if [ "$MODE" = "--check" ]; then - echo "Prettier check failed. Re-run with --write (e.g., ./ci/scripts/doc_prettier_check.sh --write) to format files, commit the changes, and re-run the check." >&2 - else - echo "Prettier format failed. Files may have been modified; commit any changes and re-run." >&2 - fi - exit $status + +PRETTIER_MODE=(--check) +if [[ "$MODE" == "write" ]]; then + PRETTIER_MODE=(--write) fi + +# Ignore subproject CHANGELOG.md because it is machine generated +npx "prettier@${PRETTIER_VERSION}" "${PRETTIER_MODE[@]}" "${PRETTIER_TARGETS[@]}" diff --git a/ci/scripts/license_header.sh b/ci/scripts/license_header.sh index 5345728f9cdf..7ab8c9637598 100755 --- a/ci/scripts/license_header.sh +++ b/ci/scripts/license_header.sh @@ -17,6 +17,62 @@ # specific language governing permissions and limitations # under the License. -# Check Apache license header -set -ex -hawkeye check --config licenserc.toml +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +SCRIPT_NAME="$(basename "${BASH_SOURCE[0]}")" + +source "${SCRIPT_DIR}/utils/git.sh" + +MODE="check" +ALLOW_DIRTY=0 +HAWKEYE_CONFIG="licenserc.toml" + +usage() { + cat >&2 <&2 <&2 <&2 <&2 <&2 + return 1 + fi +} diff --git a/benchmarks/requirements.txt b/ci/scripts/utils/tool_versions.sh similarity index 78% rename from benchmarks/requirements.txt rename to ci/scripts/utils/tool_versions.sh index 20a5a2bddbf2..ac731ed0d534 100644 --- a/benchmarks/requirements.txt +++ b/ci/scripts/utils/tool_versions.sh @@ -1,3 +1,5 @@ +#!/usr/bin/env bash +# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. 
See the NOTICE file # distributed with this work for additional information @@ -15,4 +17,7 @@ # specific language governing permissions and limitations # under the License. -rich +# This file defines centralized tool versions used by CI and development scripts. +# It is intended to be sourced by other scripts and should not be executed directly. + +PRETTIER_VERSION="2.7.1" diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index 31941d87165a..3fe6be964c3f 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -37,10 +37,10 @@ backtrace = ["datafusion/backtrace"] [dependencies] arrow = { workspace = true } async-trait = { workspace = true } -aws-config = "1.8.12" -aws-credential-types = "1.2.7" +aws-config = "1.8.14" +aws-credential-types = "1.2.13" chrono = { workspace = true } -clap = { version = "4.5.53", features = ["cargo", "derive"] } +clap = { version = "4.5.60", features = ["cargo", "derive"] } datafusion = { workspace = true, features = [ "avro", "compression", @@ -69,6 +69,9 @@ rustyline = "17.0" tokio = { workspace = true, features = ["macros", "parking_lot", "rt", "rt-multi-thread", "signal", "sync"] } url = { workspace = true } +[lints] +workspace = true + [dev-dependencies] ctor = { workspace = true } insta = { workspace = true } diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index 2b8385ac2d89..09347d6d7dc2 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -196,6 +196,7 @@ pub async fn exec_from_repl( } Err(ReadlineError::Interrupted) => { println!("^C"); + rl.helper().unwrap().reset_hint(); continue; } Err(ReadlineError::Eof) => { @@ -269,7 +270,7 @@ impl StatementExecutor { let options = task_ctx.session_config().options(); // Track memory usage for the query result if it's bounded - let mut reservation = + let reservation = MemoryConsumer::new("DataFusion-Cli").register(task_ctx.memory_pool()); if physical_plan.boundedness().is_unbounded() { @@ -300,7 +301,7 @@ impl 
StatementExecutor { let curr_num_rows = batch.num_rows(); // Stop collecting results if the number of rows exceeds the limit // results batch should include the last batch that exceeds the limit - if row_count < max_rows + curr_num_rows { + if row_count < max_rows.saturating_add(curr_num_rows) { // Try to grow the reservation to accommodate the batch in memory reservation.try_grow(get_record_batch_memory_size(&batch))?; results.push(batch); @@ -521,6 +522,7 @@ mod tests { use datafusion::common::plan_err; use datafusion::prelude::SessionContext; + use datafusion_common::assert_contains; use url::Url; async fn create_external_table_test(location: &str, sql: &str) -> Result<()> { @@ -714,7 +716,7 @@ mod tests { let err = create_external_table_test(location, &sql) .await .unwrap_err(); - assert!(err.to_string().contains("os error 2")); + assert_contains!(err.to_string(), "os error 2"); // for service_account_key let sql = format!( @@ -722,9 +724,8 @@ mod tests { ); let err = create_external_table_test(location, &sql) .await - .unwrap_err() - .to_string(); - assert!(err.contains("No RSA key found in pem file"), "{err}"); + .unwrap_err(); + assert_contains!(err.to_string(), "Error reading pem file: no items found"); // for application_credentials_path let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET @@ -732,7 +733,7 @@ mod tests { let err = create_external_table_test(location, &sql) .await .unwrap_err(); - assert!(err.to_string().contains("os error 2")); + assert_contains!(err.to_string(), "os error 2"); Ok(()) } diff --git a/datafusion-cli/src/functions.rs b/datafusion-cli/src/functions.rs index a45d57e8e952..67f3dc28269e 100644 --- a/datafusion-cli/src/functions.rs +++ b/datafusion-cli/src/functions.rs @@ -17,13 +17,18 @@ //! 
Functions that are query-able and searchable via the `\h` command +use datafusion_common::instant::Instant; use std::fmt; use std::fs::File; use std::str::FromStr; use std::sync::Arc; -use arrow::array::{Int64Array, StringArray, TimestampMillisecondArray, UInt64Array}; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; +use arrow::array::{ + DurationMillisecondArray, GenericListArray, Int64Array, StringArray, StructArray, + TimestampMillisecondArray, UInt64Array, +}; +use arrow::buffer::{Buffer, OffsetBuffer, ScalarBuffer}; +use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef, TimeUnit}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; use datafusion::catalog::{Session, TableFunctionImpl}; @@ -228,7 +233,7 @@ impl TableProvider for ParquetMetadataTable { self } - fn schema(&self) -> arrow::datatypes::SchemaRef { + fn schema(&self) -> SchemaRef { self.schema.clone() } @@ -421,7 +426,7 @@ impl TableFunctionImpl for ParquetMetadataFunc { compression_arr.push(format!("{:?}", column.compression())); // need to collect into Vec to format let encodings: Vec<_> = column.encodings().collect(); - encodings_arr.push(format!("{:?}", encodings)); + encodings_arr.push(format!("{encodings:?}")); index_page_offset_arr.push(column.index_page_offset()); dictionary_page_offset_arr.push(column.dictionary_page_offset()); data_page_offset_arr.push(column.data_page_offset()); @@ -477,7 +482,7 @@ impl TableProvider for MetadataCacheTable { self } - fn schema(&self) -> arrow::datatypes::SchemaRef { + fn schema(&self) -> SchemaRef { self.schema.clone() } @@ -595,7 +600,7 @@ impl TableProvider for StatisticsCacheTable { self } - fn schema(&self) -> arrow::datatypes::SchemaRef { + fn schema(&self) -> SchemaRef { self.schema.clone() } @@ -697,3 +702,182 @@ impl TableFunctionImpl for StatisticsCacheFunc { Ok(Arc::new(statistics_cache)) } } + +/// Implementation of the `list_files_cache` table function in datafusion-cli. 
+/// +/// This function returns the cached results of running a LIST command on a +/// particular object store path for a table. The object metadata is returned as +/// a List of Structs, with one Struct for each object. DataFusion uses these +/// cached results to plan queries against external tables. +/// +/// # Schema +/// ```sql +/// > describe select * from list_files_cache(); +/// +---------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-------------+ +/// | column_name | data_type | is_nullable | +/// +---------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-------------+ +/// | table | Utf8 | NO | +/// | path | Utf8 | NO | +/// | metadata_size_bytes | UInt64 | NO | +/// | expires_in | Duration(ms) | YES | +/// | metadata_list | List(Struct("file_path": non-null Utf8, "file_modified": non-null Timestamp(ms), "file_size_bytes": non-null UInt64, "e_tag": Utf8, "version": Utf8), field: 'metadata') | YES | +/// +---------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-------------+ +/// ``` +#[derive(Debug)] +struct ListFilesCacheTable { + schema: SchemaRef, + batch: RecordBatch, +} + +#[async_trait] +impl TableProvider for ListFilesCacheTable { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> datafusion::logical_expr::TableType { + datafusion::logical_expr::TableType::Base + } + + async fn scan( + &self, + _state: &dyn Session, + projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + Ok(MemorySourceConfig::try_new_exec( + 
&[vec![self.batch.clone()]], + TableProvider::schema(self), + projection.cloned(), + )?) + } +} + +#[derive(Debug)] +pub struct ListFilesCacheFunc { + cache_manager: Arc, +} + +impl ListFilesCacheFunc { + pub fn new(cache_manager: Arc) -> Self { + Self { cache_manager } + } +} + +impl TableFunctionImpl for ListFilesCacheFunc { + fn call(&self, exprs: &[Expr]) -> Result> { + if !exprs.is_empty() { + return plan_err!("list_files_cache should have no arguments"); + } + + let nested_fields = Fields::from(vec![ + Field::new("file_path", DataType::Utf8, false), + Field::new( + "file_modified", + DataType::Timestamp(TimeUnit::Millisecond, None), + false, + ), + Field::new("file_size_bytes", DataType::UInt64, false), + Field::new("e_tag", DataType::Utf8, true), + Field::new("version", DataType::Utf8, true), + ]); + + let metadata_field = + Field::new("metadata", DataType::Struct(nested_fields.clone()), true); + + let schema = Arc::new(Schema::new(vec![ + Field::new("table", DataType::Utf8, true), + Field::new("path", DataType::Utf8, false), + Field::new("metadata_size_bytes", DataType::UInt64, false), + // expires field in ListFilesEntry has type Instant when set, from which we cannot get "the number of seconds", hence using Duration instead of Timestamp as data type. 
+ Field::new( + "expires_in", + DataType::Duration(TimeUnit::Millisecond), + true, + ), + Field::new( + "metadata_list", + DataType::List(Arc::new(metadata_field.clone())), + true, + ), + ])); + + let mut table_arr = vec![]; + let mut path_arr = vec![]; + let mut metadata_size_bytes_arr = vec![]; + let mut expires_arr = vec![]; + + let mut file_path_arr = vec![]; + let mut file_modified_arr = vec![]; + let mut file_size_bytes_arr = vec![]; + let mut etag_arr = vec![]; + let mut version_arr = vec![]; + let mut offsets: Vec = vec![0]; + + if let Some(list_files_cache) = self.cache_manager.get_list_files_cache() { + let now = Instant::now(); + let mut current_offset: i32 = 0; + + for (path, entry) in list_files_cache.list_entries() { + table_arr.push(path.table.map(|t| t.to_string())); + path_arr.push(path.path.to_string()); + metadata_size_bytes_arr.push(entry.size_bytes as u64); + // calculates time left before entry expires + expires_arr.push( + entry + .expires + .map(|t| t.duration_since(now).as_millis() as i64), + ); + + for meta in entry.metas.files.iter() { + file_path_arr.push(meta.location.to_string()); + file_modified_arr.push(meta.last_modified.timestamp_millis()); + file_size_bytes_arr.push(meta.size); + etag_arr.push(meta.e_tag.clone()); + version_arr.push(meta.version.clone()); + } + current_offset += entry.metas.files.len() as i32; + offsets.push(current_offset); + } + } + + let struct_arr = StructArray::new( + nested_fields, + vec![ + Arc::new(StringArray::from(file_path_arr)), + Arc::new(TimestampMillisecondArray::from(file_modified_arr)), + Arc::new(UInt64Array::from(file_size_bytes_arr)), + Arc::new(StringArray::from(etag_arr)), + Arc::new(StringArray::from(version_arr)), + ], + None, + ); + + let offsets_buffer: OffsetBuffer = + OffsetBuffer::new(ScalarBuffer::from(Buffer::from_vec(offsets))); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(table_arr)), + Arc::new(StringArray::from(path_arr)), + 
Arc::new(UInt64Array::from(metadata_size_bytes_arr)), + Arc::new(DurationMillisecondArray::from(expires_arr)), + Arc::new(GenericListArray::new( + Arc::new(metadata_field), + offsets_buffer, + Arc::new(struct_arr), + None, + )), + ], + )?; + + let list_files_cache = ListFilesCacheTable { schema, batch }; + Ok(Arc::new(list_files_cache)) + } +} diff --git a/datafusion-cli/src/helper.rs b/datafusion-cli/src/helper.rs index df7afc14048b..f01d0891b964 100644 --- a/datafusion-cli/src/helper.rs +++ b/datafusion-cli/src/helper.rs @@ -19,8 +19,9 @@ //! and auto-completion for file name during creating external table. use std::borrow::Cow; +use std::cell::Cell; -use crate::highlighter::{NoSyntaxHighlighter, SyntaxHighlighter}; +use crate::highlighter::{Color, NoSyntaxHighlighter, SyntaxHighlighter}; use datafusion::sql::parser::{DFParser, Statement}; use datafusion::sql::sqlparser::dialect::dialect_from_str; @@ -33,10 +34,17 @@ use rustyline::hint::Hinter; use rustyline::validate::{ValidationContext, ValidationResult, Validator}; use rustyline::{Context, Helper, Result}; +/// Default suggestion shown when the input line is empty. +const DEFAULT_HINT_SUGGESTION: &str = " \\? for help, \\q to quit"; + pub struct CliHelper { completer: FilenameCompleter, dialect: Dialect, highlighter: Box, + /// Tracks whether to show the default hint. Set to `false` once the user + /// types anything, so the hint doesn't reappear after deleting back to + /// an empty line. Reset to `true` when the line is submitted. + show_hint: Cell, } impl CliHelper { @@ -50,6 +58,7 @@ impl CliHelper { completer: FilenameCompleter::new(), dialect: *dialect, highlighter, + show_hint: Cell::new(true), } } @@ -59,6 +68,11 @@ impl CliHelper { } } + /// Re-enable the default hint for the next prompt. 
+ pub fn reset_hint(&self) { + self.show_hint.set(true); + } + fn validate_input(&self, input: &str) -> Result { if let Some(sql) = input.strip_suffix(';') { let dialect = match dialect_from_str(self.dialect) { @@ -114,6 +128,14 @@ impl Highlighter for CliHelper { impl Hinter for CliHelper { type Hint = String; + + fn hint(&self, line: &str, _pos: usize, _ctx: &Context<'_>) -> Option { + if !line.is_empty() { + self.show_hint.set(false); + } + (self.show_hint.get() && line.trim().is_empty()) + .then(|| Color::gray(DEFAULT_HINT_SUGGESTION)) + } } /// returns true if the current position is after the open quote for @@ -121,12 +143,9 @@ impl Hinter for CliHelper { fn is_open_quote_for_location(line: &str, pos: usize) -> bool { let mut sql = line[..pos].to_string(); sql.push('\''); - if let Ok(stmts) = DFParser::parse_sql(&sql) - && let Some(Statement::CreateExternalTable(_)) = stmts.back() - { - return true; - } - false + DFParser::parse_sql(&sql).is_ok_and(|stmts| { + matches!(stmts.back(), Some(Statement::CreateExternalTable(_))) + }) } impl Completer for CliHelper { @@ -149,7 +168,9 @@ impl Completer for CliHelper { impl Validator for CliHelper { fn validate(&self, ctx: &mut ValidationContext<'_>) -> Result { let input = ctx.input().trim_end(); - self.validate_input(input) + let result = self.validate_input(input); + self.reset_hint(); + result } } diff --git a/datafusion-cli/src/highlighter.rs b/datafusion-cli/src/highlighter.rs index 912a13916a5b..adcb135bb401 100644 --- a/datafusion-cli/src/highlighter.rs +++ b/datafusion-cli/src/highlighter.rs @@ -38,7 +38,8 @@ pub struct SyntaxHighlighter { impl SyntaxHighlighter { pub fn new(dialect: &config::Dialect) -> Self { - let dialect = dialect_from_str(dialect).unwrap_or(Box::new(GenericDialect {})); + let dialect = + dialect_from_str(dialect).unwrap_or_else(|| Box::new(GenericDialect {})); Self { dialect } } } @@ -80,16 +81,20 @@ impl Highlighter for SyntaxHighlighter { } /// Convenient utility to return strings with 
[ANSI color](https://gist.github.com/JBlond/2fea43a3049b38287e5e9cefc87b2124). -struct Color {} +pub(crate) struct Color {} impl Color { - fn green(s: impl Display) -> String { + pub(crate) fn green(s: impl Display) -> String { format!("\x1b[92m{s}\x1b[0m") } - fn red(s: impl Display) -> String { + pub(crate) fn red(s: impl Display) -> String { format!("\x1b[91m{s}\x1b[0m") } + + pub(crate) fn gray(s: impl Display) -> String { + format!("\x1b[90m{s}\x1b[0m") + } } #[cfg(test)] diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs index 8f69ae477904..6bfe1160ecdd 100644 --- a/datafusion-cli/src/main.rs +++ b/datafusion-cli/src/main.rs @@ -32,7 +32,7 @@ use datafusion::logical_expr::ExplainFormat; use datafusion::prelude::SessionContext; use datafusion_cli::catalog::DynamicObjectStoreCatalog; use datafusion_cli::functions::{ - MetadataCacheFunc, ParquetMetadataFunc, StatisticsCacheFunc, + ListFilesCacheFunc, MetadataCacheFunc, ParquetMetadataFunc, StatisticsCacheFunc, }; use datafusion_cli::object_storage::instrumented::{ InstrumentedObjectStoreMode, InstrumentedObjectStoreRegistry, @@ -253,6 +253,13 @@ async fn main_inner() -> Result<()> { )), ); + ctx.register_udtf( + "list_files_cache", + Arc::new(ListFilesCacheFunc::new( + ctx.task_ctx().runtime_env().cache_manager.clone(), + )), + ); + let mut print_options = PrintOptions { format: args.format, quiet: args.quiet, @@ -431,15 +438,20 @@ pub fn extract_disk_limit(size: &str) -> Result { #[cfg(test)] mod tests { + use std::time::Duration; + use super::*; use datafusion::{ common::test_util::batches_to_string, execution::cache::{ - cache_manager::CacheManagerConfig, cache_unit::DefaultFileStatisticsCache, + DefaultListFilesCache, cache_manager::CacheManagerConfig, + cache_unit::DefaultFileStatisticsCache, }, - prelude::ParquetReadOptions, + prelude::{ParquetReadOptions, col, lit, split_part}, }; use insta::assert_snapshot; + use object_store::memory::InMemory; + use url::Url; fn 
assert_conversion(input: &str, expected: Result) { let result = extract_memory_pool_size(input); @@ -605,8 +617,8 @@ mod tests { | filename | file_size_bytes | metadata_size_bytes | hits | extra | +-----------------------------------+-----------------+---------------------+------+------------------+ | alltypes_plain.parquet | 1851 | 8882 | 2 | page_index=false | - | alltypes_tiny_pages.parquet | 454233 | 269266 | 2 | page_index=true | - | lz4_raw_compressed_larger.parquet | 380836 | 1347 | 2 | page_index=false | + | alltypes_tiny_pages.parquet | 454233 | 269074 | 2 | page_index=true | + | lz4_raw_compressed_larger.parquet | 380836 | 1339 | 2 | page_index=false | +-----------------------------------+-----------------+---------------------+------+------------------+ "); @@ -636,8 +648,8 @@ mod tests { | filename | file_size_bytes | metadata_size_bytes | hits | extra | +-----------------------------------+-----------------+---------------------+------+------------------+ | alltypes_plain.parquet | 1851 | 8882 | 5 | page_index=false | - | alltypes_tiny_pages.parquet | 454233 | 269266 | 2 | page_index=true | - | lz4_raw_compressed_larger.parquet | 380836 | 1347 | 3 | page_index=false | + | alltypes_tiny_pages.parquet | 454233 | 269074 | 2 | page_index=true | + | lz4_raw_compressed_larger.parquet | 380836 | 1339 | 3 | page_index=false | +-----------------------------------+-----------------+---------------------+------+------------------+ "); @@ -741,4 +753,99 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_list_files_cache() -> Result<(), DataFusionError> { + let list_files_cache = Arc::new(DefaultListFilesCache::new( + 1024, + Some(Duration::from_secs(1)), + )); + + let rt = RuntimeEnvBuilder::new() + .with_cache_manager( + CacheManagerConfig::default() + .with_list_files_cache(Some(list_files_cache)), + ) + .build_arc() + .unwrap(); + + let ctx = SessionContext::new_with_config_rt(SessionConfig::default(), rt); + + ctx.register_object_store( + 
&Url::parse("mem://test_table").unwrap(), + Arc::new(InMemory::new()), + ); + + ctx.register_udtf( + "list_files_cache", + Arc::new(ListFilesCacheFunc::new( + ctx.task_ctx().runtime_env().cache_manager.clone(), + )), + ); + + ctx.sql( + "CREATE EXTERNAL TABLE src_table + STORED AS PARQUET + LOCATION '../parquet-testing/data/alltypes_plain.parquet'", + ) + .await? + .collect() + .await?; + + ctx.sql("COPY (SELECT * FROM src_table) TO 'mem://test_table/0.parquet' STORED AS PARQUET").await?.collect().await?; + + ctx.sql("COPY (SELECT * FROM src_table) TO 'mem://test_table/1.parquet' STORED AS PARQUET").await?.collect().await?; + + ctx.sql( + "CREATE EXTERNAL TABLE test_table + STORED AS PARQUET + LOCATION 'mem://test_table/' + ", + ) + .await? + .collect() + .await?; + + let sql = "SELECT metadata_size_bytes, expires_in, metadata_list FROM list_files_cache()"; + let df = ctx + .sql(sql) + .await? + .unnest_columns(&["metadata_list"])? + .with_column_renamed("metadata_list", "metadata")? + .unnest_columns(&["metadata"])?; + + assert_eq!( + 2, + df.clone() + .filter(col("expires_in").is_not_null())? + .count() + .await? + ); + + let df = df + .with_column_renamed(r#""metadata.file_size_bytes""#, "file_size_bytes")? + .with_column_renamed(r#""metadata.e_tag""#, "etag")? + .with_column( + "filename", + split_part(col(r#""metadata.file_path""#), lit("/"), lit(-1)), + )? + .select_columns(&[ + "metadata_size_bytes", + "filename", + "file_size_bytes", + "etag", + ])? 
+ .sort(vec![col("filename").sort(true, false)])?; + let rbs = df.collect().await?; + assert_snapshot!(batches_to_string(&rbs),@r" + +---------------------+-----------+-----------------+------+ + | metadata_size_bytes | filename | file_size_bytes | etag | + +---------------------+-----------+-----------------+------+ + | 212 | 0.parquet | 3642 | 0 | + | 212 | 1.parquet | 3642 | 1 | + +---------------------+-----------+-----------------+------+ + "); + + Ok(()) + } } diff --git a/datafusion-cli/src/object_storage.rs b/datafusion-cli/src/object_storage.rs index 3cee78a5b17c..34787838929f 100644 --- a/datafusion-cli/src/object_storage.rs +++ b/datafusion-cli/src/object_storage.rs @@ -64,6 +64,21 @@ pub async fn get_s3_object_store_builder( url: &Url, aws_options: &AwsOptions, resolve_region: bool, +) -> Result { + // Box the inner future to reduce the future size of this async function, + // which is deeply nested in the CLI's async call chain. + Box::pin(get_s3_object_store_builder_inner( + url, + aws_options, + resolve_region, + )) + .await +} + +async fn get_s3_object_store_builder_inner( + url: &Url, + aws_options: &AwsOptions, + resolve_region: bool, ) -> Result { let AwsOptions { access_key_id, @@ -209,7 +224,7 @@ impl CredentialsFromConfig { #[derive(Debug)] struct S3CredentialProvider { - credentials: aws_credential_types::provider::SharedCredentialsProvider, + credentials: SharedCredentialsProvider, } #[async_trait] @@ -749,7 +764,6 @@ mod tests { eprintln!("{e}"); return Ok(()); } - let expected_region = "eu-central-1"; let location = "s3://test-bucket/path/file.parquet"; // Set it to a non-existent file to avoid reading the default configuration file unsafe { @@ -766,9 +780,10 @@ mod tests { get_s3_object_store_builder(table_url.as_ref(), &aws_options, false).await?; // Verify that the region was auto-detected in test environment - assert_eq!( - builder.get_config_value(&AmazonS3ConfigKey::Region), - Some(expected_region.to_string()) + assert!( + builder + 
.get_config_value(&AmazonS3ConfigKey::Region) + .is_some() ); Ok(()) diff --git a/datafusion-cli/src/object_storage/instrumented.rs b/datafusion-cli/src/object_storage/instrumented.rs index 0d5e9dc2c5a8..b4f1a043ac8d 100644 --- a/datafusion-cli/src/object_storage/instrumented.rs +++ b/datafusion-cli/src/object_storage/instrumented.rs @@ -36,10 +36,11 @@ use datafusion::{ execution::object_store::{DefaultObjectStoreRegistry, ObjectStoreRegistry}, }; use futures::stream::{BoxStream, Stream}; +use futures::{StreamExt, TryStreamExt}; use object_store::{ - GetOptions, GetRange, GetResult, ListResult, MultipartUpload, ObjectMeta, - ObjectStore, PutMultipartOptions, PutOptions, PutPayload, PutResult, Result, - path::Path, + CopyOptions, GetOptions, GetRange, GetResult, ListResult, MultipartUpload, + ObjectMeta, ObjectStore, ObjectStoreExt, PutMultipartOptions, PutOptions, PutPayload, + PutResult, Result, path::Path, }; use parking_lot::{Mutex, RwLock}; use url::Url; @@ -110,7 +111,7 @@ pub enum InstrumentedObjectStoreMode { } impl fmt::Display for InstrumentedObjectStoreMode { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{self:?}") } } @@ -230,16 +231,26 @@ impl InstrumentedObjectStore { let timestamp = Utc::now(); let range = options.range.clone(); + let head = options.head; let start = Instant::now(); let ret = self.inner.get_opts(location, options).await?; let elapsed = start.elapsed(); + let (op, size) = if head { + (Operation::Head, None) + } else { + ( + Operation::Get, + Some((ret.range.end - ret.range.start) as usize), + ) + }; + self.requests.lock().push(RequestDetails { - op: Operation::Get, + op, path: location.clone(), timestamp, duration: Some(elapsed), - size: Some((ret.range.end - ret.range.start) as usize), + size, range, extra_display: None, }); @@ -247,23 +258,30 @@ impl InstrumentedObjectStore { Ok(ret) } - async fn instrumented_delete(&self, location: 
&Path) -> Result<()> { + fn instrumented_delete_stream( + &self, + locations: BoxStream<'static, Result>, + ) -> BoxStream<'static, Result> { + let requests_captured = Arc::clone(&self.requests); + let timestamp = Utc::now(); let start = Instant::now(); - self.inner.delete(location).await?; - let elapsed = start.elapsed(); - - self.requests.lock().push(RequestDetails { - op: Operation::Delete, - path: location.clone(), - timestamp, - duration: Some(elapsed), - size: None, - range: None, - extra_display: None, - }); - - Ok(()) + self.inner + .delete_stream(locations) + .and_then(move |location| { + let elapsed = start.elapsed(); + requests_captured.lock().push(RequestDetails { + op: Operation::Delete, + path: location.clone(), + timestamp, + duration: Some(elapsed), + size: None, + range: None, + extra_display: None, + }); + futures::future::ok(location) + }) + .boxed() } fn instrumented_list( @@ -361,29 +379,10 @@ impl InstrumentedObjectStore { Ok(()) } - - async fn instrumented_head(&self, location: &Path) -> Result { - let timestamp = Utc::now(); - let start = Instant::now(); - let ret = self.inner.head(location).await?; - let elapsed = start.elapsed(); - - self.requests.lock().push(RequestDetails { - op: Operation::Head, - path: location.clone(), - timestamp, - duration: Some(elapsed), - size: None, - range: None, - extra_display: None, - }); - - Ok(ret) - } } impl fmt::Display for InstrumentedObjectStore { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let mode: InstrumentedObjectStoreMode = self.instrument_mode.load(Ordering::Relaxed).into(); write!( @@ -429,12 +428,15 @@ impl ObjectStore for InstrumentedObjectStore { self.inner.get_opts(location, options).await } - async fn delete(&self, location: &Path) -> Result<()> { + fn delete_stream( + &self, + locations: BoxStream<'static, Result>, + ) -> BoxStream<'static, Result> { if self.enabled() { - return 
self.instrumented_delete(location).await; + return self.instrumented_delete_stream(locations); } - self.inner.delete(location).await + self.inner.delete_stream(locations) } fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, Result> { @@ -453,28 +455,24 @@ impl ObjectStore for InstrumentedObjectStore { self.inner.list_with_delimiter(prefix).await } - async fn copy(&self, from: &Path, to: &Path) -> Result<()> { - if self.enabled() { - return self.instrumented_copy(from, to).await; - } - - self.inner.copy(from, to).await - } - - async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { - if self.enabled() { - return self.instrumented_copy_if_not_exists(from, to).await; - } - - self.inner.copy_if_not_exists(from, to).await - } - - async fn head(&self, location: &Path) -> Result { + async fn copy_opts( + &self, + from: &Path, + to: &Path, + options: CopyOptions, + ) -> Result<()> { if self.enabled() { - return self.instrumented_head(location).await; + return match options.mode { + object_store::CopyMode::Create => { + self.instrumented_copy_if_not_exists(from, to).await + } + object_store::CopyMode::Overwrite => { + self.instrumented_copy(from, to).await + } + }; } - self.inner.head(location).await + self.inner.copy_opts(from, to, options).await } } @@ -490,7 +488,7 @@ pub enum Operation { } impl fmt::Display for Operation { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{self:?}") } } @@ -508,7 +506,7 @@ pub struct RequestDetails { } impl fmt::Display for RequestDetails { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let mut output_parts = vec![format!( "{} operation={:?}", self.timestamp.to_rfc3339(), diff --git a/datafusion-cli/src/print_format.rs b/datafusion-cli/src/print_format.rs index cfb8a32ffcfe..6a6a0370b08a 100644 --- a/datafusion-cli/src/print_format.rs 
+++ b/datafusion-cli/src/print_format.rs @@ -259,7 +259,7 @@ mod tests { fn print_csv_no_header() { let output = PrintBatchesTest::new() .with_format(PrintFormat::Csv) - .with_batches(split_batch(three_column_batch())) + .with_batches(split_batch(&three_column_batch())) .with_header(WithHeader::No) .run(); assert_snapshot!(output, @r" @@ -273,7 +273,7 @@ mod tests { fn print_csv_with_header() { let output = PrintBatchesTest::new() .with_format(PrintFormat::Csv) - .with_batches(split_batch(three_column_batch())) + .with_batches(split_batch(&three_column_batch())) .with_header(WithHeader::Yes) .run(); assert_snapshot!(output, @r" @@ -288,7 +288,7 @@ mod tests { fn print_tsv_no_header() { let output = PrintBatchesTest::new() .with_format(PrintFormat::Tsv) - .with_batches(split_batch(three_column_batch())) + .with_batches(split_batch(&three_column_batch())) .with_header(WithHeader::No) .run(); assert_snapshot!(output, @r" @@ -302,7 +302,7 @@ mod tests { fn print_tsv_with_header() { let output = PrintBatchesTest::new() .with_format(PrintFormat::Tsv) - .with_batches(split_batch(three_column_batch())) + .with_batches(split_batch(&three_column_batch())) .with_header(WithHeader::Yes) .run(); assert_snapshot!(output, @r" @@ -317,7 +317,7 @@ mod tests { fn print_table() { let output = PrintBatchesTest::new() .with_format(PrintFormat::Table) - .with_batches(split_batch(three_column_batch())) + .with_batches(split_batch(&three_column_batch())) .with_header(WithHeader::Ignored) .run(); assert_snapshot!(output, @r" @@ -334,7 +334,7 @@ mod tests { fn print_json() { let output = PrintBatchesTest::new() .with_format(PrintFormat::Json) - .with_batches(split_batch(three_column_batch())) + .with_batches(split_batch(&three_column_batch())) .with_header(WithHeader::Ignored) .run(); assert_snapshot!(output, @r#"[{"a":1,"b":4,"c":7},{"a":2,"b":5,"c":8},{"a":3,"b":6,"c":9}]"#); @@ -344,7 +344,7 @@ mod tests { fn print_ndjson() { let output = PrintBatchesTest::new() 
.with_format(PrintFormat::NdJson) - .with_batches(split_batch(three_column_batch())) + .with_batches(split_batch(&three_column_batch())) .with_header(WithHeader::Ignored) .run(); assert_snapshot!(output, @r#" @@ -358,7 +358,7 @@ mod tests { fn print_automatic_no_header() { let output = PrintBatchesTest::new() .with_format(PrintFormat::Automatic) - .with_batches(split_batch(three_column_batch())) + .with_batches(split_batch(&three_column_batch())) .with_header(WithHeader::No) .run(); assert_snapshot!(output, @r" @@ -371,7 +371,7 @@ mod tests { fn print_automatic_with_header() { let output = PrintBatchesTest::new() .with_format(PrintFormat::Automatic) - .with_batches(split_batch(three_column_batch())) + .with_batches(split_batch(&three_column_batch())) .with_header(WithHeader::Yes) .run(); assert_snapshot!(output, @r" @@ -633,7 +633,7 @@ mod tests { } /// Slice the record batch into 2 batches - fn split_batch(batch: RecordBatch) -> Vec { + fn split_batch(batch: &RecordBatch) -> Vec { assert!(batch.num_rows() > 1); let split = batch.num_rows() / 2; vec![ diff --git a/datafusion-cli/src/print_options.rs b/datafusion-cli/src/print_options.rs index 5fbe27d805db..d0810cb034df 100644 --- a/datafusion-cli/src/print_options.rs +++ b/datafusion-cli/src/print_options.rs @@ -115,7 +115,7 @@ impl PrintOptions { row_count: usize, format_options: &FormatOptions, ) -> Result<()> { - let stdout = std::io::stdout(); + let stdout = io::stdout(); let mut writer = stdout.lock(); self.format.print_batches( @@ -137,7 +137,7 @@ impl PrintOptions { query_start_time, ); - self.write_output(&mut writer, formatted_exec_details) + self.write_output(&mut writer, &formatted_exec_details) } /// Print the stream to stdout using the specified format @@ -153,7 +153,7 @@ impl PrintOptions { )); }; - let stdout = std::io::stdout(); + let stdout = io::stdout(); let mut writer = stdout.lock(); let mut row_count = 0_usize; @@ -179,13 +179,13 @@ impl PrintOptions { query_start_time, ); - 
self.write_output(&mut writer, formatted_exec_details) + self.write_output(&mut writer, &formatted_exec_details) } fn write_output( &self, writer: &mut W, - formatted_exec_details: String, + formatted_exec_details: &str, ) -> Result<()> { if !self.quiet { writeln!(writer, "{formatted_exec_details}")?; @@ -237,11 +237,11 @@ mod tests { let mut print_output: Vec = Vec::new(); let exec_out = String::from("Formatted Exec Output"); - print_options.write_output(&mut print_output, exec_out.clone())?; + print_options.write_output(&mut print_output, &exec_out)?; assert!(print_output.is_empty()); print_options.quiet = false; - print_options.write_output(&mut print_output, exec_out.clone())?; + print_options.write_output(&mut print_output, &exec_out)?; let out_str: String = print_output .clone() .try_into() @@ -253,7 +253,7 @@ mod tests { print_options .instrumented_registry .set_instrument_mode(InstrumentedObjectStoreMode::Trace); - print_options.write_output(&mut print_output, exec_out.clone())?; + print_options.write_output(&mut print_output, &exec_out)?; let out_str: String = print_output .clone() .try_into() diff --git a/datafusion-cli/tests/cli_integration.rs b/datafusion-cli/tests/cli_integration.rs index 8b8b786d652e..99fc2d527eea 100644 --- a/datafusion-cli/tests/cli_integration.rs +++ b/datafusion-cli/tests/cli_integration.rs @@ -44,7 +44,7 @@ fn make_settings() -> Settings { settings } -async fn setup_minio_container() -> ContainerAsync { +async fn setup_minio_container() -> Result, String> { const MINIO_ROOT_USER: &str = "TEST-DataFusionLogin"; const MINIO_ROOT_PASSWORD: &str = "TEST-DataFusionPassword"; @@ -99,27 +99,23 @@ async fn setup_minio_container() -> ContainerAsync { let stdout = container.stdout_to_vec().await.unwrap_or_default(); let stderr = container.stderr_to_vec().await.unwrap_or_default(); - panic!( + return Err(format!( "Failed to execute command: {}\nError: {}\nStdout: {:?}\nStderr: {:?}", cmd_ref, e, String::from_utf8_lossy(&stdout), 
String::from_utf8_lossy(&stderr) - ); + )); } } - container + Ok(container) } - Err(TestcontainersError::Client(e)) => { - panic!( - "Failed to start MinIO container. Ensure Docker is running and accessible: {e}" - ); - } - Err(e) => { - panic!("Failed to start MinIO container: {e}"); - } + Err(TestcontainersError::Client(e)) => Err(format!( + "Failed to start MinIO container. Ensure Docker is running and accessible: {e}" + )), + Err(e) => Err(format!("Failed to start MinIO container: {e}")), } } @@ -253,7 +249,14 @@ async fn test_cli() { return; } - let container = setup_minio_container().await; + let container = match setup_minio_container().await { + Ok(c) => c, + Err(e) if e.contains("toomanyrequests") => { + eprintln!("Skipping test: Docker pull rate limit reached: {e}"); + return; + } + e @ Err(_) => e.unwrap(), + }; let settings = make_settings(); let _bound = settings.bind_to_scope(); @@ -286,7 +289,14 @@ async fn test_aws_options() { let settings = make_settings(); let _bound = settings.bind_to_scope(); - let container = setup_minio_container().await; + let container = match setup_minio_container().await { + Ok(c) => c, + Err(e) if e.contains("toomanyrequests") => { + eprintln!("Skipping test: Docker pull rate limit reached: {e}"); + return; + } + e @ Err(_) => e.unwrap(), + }; let port = container.get_host_port_ipv4(9000).await.unwrap(); let input = format!( @@ -377,7 +387,14 @@ async fn test_s3_url_fallback() { return; } - let container = setup_minio_container().await; + let container = match setup_minio_container().await { + Ok(c) => c, + Err(e) if e.contains("toomanyrequests") => { + eprintln!("Skipping test: Docker pull rate limit reached: {e}"); + return; + } + e @ Err(_) => e.unwrap(), + }; let mut settings = make_settings(); settings.set_snapshot_suffix("s3_url_fallback"); @@ -407,8 +424,14 @@ async fn test_object_store_profiling() { return; } - let container = setup_minio_container().await; - + let container = match setup_minio_container().await { 
+ Ok(c) => c, + Err(e) if e.contains("toomanyrequests") => { + eprintln!("Skipping test: Docker pull rate limit reached: {e}"); + return; + } + e @ Err(_) => e.unwrap(), + }; let mut settings = make_settings(); // as the object store profiling contains timestamps and durations, we must diff --git a/datafusion-cli/tests/snapshots/cli_top_memory_consumers@no_track.snap b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@no_track.snap index 89b646a531f8..fe454595eb4b 100644 --- a/datafusion-cli/tests/snapshots/cli_top_memory_consumers@no_track.snap +++ b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@no_track.snap @@ -14,7 +14,7 @@ success: false exit_code: 1 ----- stdout ----- [CLI_VERSION] -Error: Not enough memory to continue external sort. Consider increasing the memory limit, or decreasing sort_spill_reservation_bytes +Error: Not enough memory to continue external sort. Consider increasing the memory limit config: 'datafusion.runtime.memory_limit', or decreasing the config: 'datafusion.execution.sort_spill_reservation_bytes'. caused by Resources exhausted: Failed to allocate diff --git a/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top2.snap b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top2.snap index 62f864b3adb6..bb30e387166b 100644 --- a/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top2.snap +++ b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top2.snap @@ -14,7 +14,7 @@ success: false exit_code: 1 ----- stdout ----- [CLI_VERSION] -Error: Not enough memory to continue external sort. Consider increasing the memory limit, or decreasing sort_spill_reservation_bytes +Error: Not enough memory to continue external sort. Consider increasing the memory limit config: 'datafusion.runtime.memory_limit', or decreasing the config: 'datafusion.execution.sort_spill_reservation_bytes'. 
caused by Resources exhausted: Additional allocation failed for ExternalSorter[0] with top memory consumers (across reservations) as: Consumer(can spill: bool) consumed XB, peak XB, diff --git a/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top3_default.snap b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top3_default.snap index 9845d095c918..891d72e3cc63 100644 --- a/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top3_default.snap +++ b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top3_default.snap @@ -12,7 +12,7 @@ success: false exit_code: 1 ----- stdout ----- [CLI_VERSION] -Error: Not enough memory to continue external sort. Consider increasing the memory limit, or decreasing sort_spill_reservation_bytes +Error: Not enough memory to continue external sort. Consider increasing the memory limit config: 'datafusion.runtime.memory_limit', or decreasing the config: 'datafusion.execution.sort_spill_reservation_bytes'. caused by Resources exhausted: Additional allocation failed for ExternalSorter[0] with top memory consumers (across reservations) as: Consumer(can spill: bool) consumed XB, peak XB, diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index b0190dadf3c3..e56f5ad6b8ca 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -35,18 +35,22 @@ rust-version = { workspace = true } [lints] workspace = true -[dev-dependencies] +[dependencies] arrow = { workspace = true } -# arrow_schema is required for record_batch! 
macro :sad: -arrow-flight = { workspace = true } arrow-schema = { workspace = true } +datafusion = { workspace = true, default-features = true, features = ["parquet_encryption"] } +datafusion-common = { workspace = true } +nom = "8.0.0" +tempfile = { workspace = true } +tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot", "fs"] } + +[dev-dependencies] +arrow-flight = { workspace = true } async-trait = { workspace = true } bytes = { workspace = true } dashmap = { workspace = true } # note only use main datafusion crate for examples base64 = "0.22.1" -datafusion = { workspace = true, default-features = true, features = ["parquet_encryption"] } -datafusion-common = { workspace = true } datafusion-expr = { workspace = true } datafusion-physical-expr-adapter = { workspace = true } datafusion-proto = { workspace = true } @@ -59,17 +63,16 @@ mimalloc = { version = "0.1", default-features = false } object_store = { workspace = true, features = ["aws", "http"] } prost = { workspace = true } rand = { workspace = true } +serde = { version = "1", features = ["derive"] } serde_json = { workspace = true } strum = { workspace = true } strum_macros = { workspace = true } -tempfile = { workspace = true } test-utils = { path = "../test-utils" } -tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot"] } tonic = "0.14" tracing = { version = "0.1" } tracing-subscriber = { version = "0.3" } url = { workspace = true } -uuid = "1.19" +uuid = { workspace = true } [target.'cfg(not(target_os = "windows"))'.dev-dependencies] -nix = { version = "0.30.1", features = ["fs"] } +nix = { version = "0.31.1", features = ["fs"] } diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index 8f38b3899036..2cf0ec52409f 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -71,15 +71,16 @@ cargo run --example dataframe -- dataframe #### Category: Single Process -| Subcommand | File Path | Description | -| 
--------------------- | ----------------------------------------------------------------------------------------------------- | --------------------------------------------- | -| csv_sql_streaming | [`custom_data_source/csv_sql_streaming.rs`](examples/custom_data_source/csv_sql_streaming.rs) | Run a streaming SQL query against CSV data | -| csv_json_opener | [`custom_data_source/csv_json_opener.rs`](examples/custom_data_source/csv_json_opener.rs) | Use low-level FileOpener APIs for CSV/JSON | -| custom_datasource | [`custom_data_source/custom_datasource.rs`](examples/custom_data_source/custom_datasource.rs) | Query a custom TableProvider | -| custom_file_casts | [`custom_data_source/custom_file_casts.rs`](examples/custom_data_source/custom_file_casts.rs) | Implement custom casting rules | -| custom_file_format | [`custom_data_source/custom_file_format.rs`](examples/custom_data_source/custom_file_format.rs) | Write to a custom file format | -| default_column_values | [`custom_data_source/default_column_values.rs`](examples/custom_data_source/default_column_values.rs) | Custom default values using metadata | -| file_stream_provider | [`custom_data_source/file_stream_provider.rs`](examples/custom_data_source/file_stream_provider.rs) | Read/write via FileStreamProvider for streams | +| Subcommand | File Path | Description | +| --------------------- | ----------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------- | +| adapter_serialization | [`custom_data_source/adapter_serialization.rs`](examples/custom_data_source/adapter_serialization.rs) | Preserve custom PhysicalExprAdapter information during plan serialization using PhysicalExtensionCodec interception | +| csv_json_opener | [`custom_data_source/csv_json_opener.rs`](examples/custom_data_source/csv_json_opener.rs) | Use low-level FileOpener APIs for CSV/JSON | 
+| csv_sql_streaming | [`custom_data_source/csv_sql_streaming.rs`](examples/custom_data_source/csv_sql_streaming.rs) | Run a streaming SQL query against CSV data | +| custom_datasource | [`custom_data_source/custom_datasource.rs`](examples/custom_data_source/custom_datasource.rs) | Query a custom TableProvider | +| custom_file_casts | [`custom_data_source/custom_file_casts.rs`](examples/custom_data_source/custom_file_casts.rs) | Implement custom casting rules | +| custom_file_format | [`custom_data_source/custom_file_format.rs`](examples/custom_data_source/custom_file_format.rs) | Write to a custom file format | +| default_column_values | [`custom_data_source/default_column_values.rs`](examples/custom_data_source/default_column_values.rs) | Custom default values using metadata | +| file_stream_provider | [`custom_data_source/file_stream_provider.rs`](examples/custom_data_source/file_stream_provider.rs) | Read/write via FileStreamProvider for streams | ## Data IO Examples @@ -106,10 +107,11 @@ cargo run --example dataframe -- dataframe #### Category: Single Process -| Subcommand | File Path | Description | -| --------------------- | ----------------------------------------------------------------------------------- | ------------------------------------------------------ | -| dataframe | [`dataframe/dataframe.rs`](examples/dataframe/dataframe.rs) | Query DataFrames from various sources and write output | -| deserialize_to_struct | [`dataframe/deserialize_to_struct.rs`](examples/dataframe/deserialize_to_struct.rs) | Convert Arrow arrays into Rust structs | +| Subcommand | File Path | Description | +| --------------------- | ----------------------------------------------------------------------------------- | ------------------------------------------------------- | +| cache_factory | [`dataframe/cache_factory.rs`](examples/dataframe/cache_factory.rs) | Custom lazy caching for DataFrames using `CacheFactory` | +| dataframe | 
[`dataframe/dataframe.rs`](examples/dataframe/dataframe.rs) | Query DataFrames from various sources and write output | +| deserialize_to_struct | [`dataframe/deserialize_to_struct.rs`](examples/dataframe/deserialize_to_struct.rs) | Convert Arrow arrays into Rust structs | ## Execution Monitoring Examples @@ -142,8 +144,8 @@ cargo run --example dataframe -- dataframe | Subcommand | File Path | Description | | ---------- | ------------------------------------------------------- | ------------------------------------------------------ | -| server | [`flight/server.rs`](examples/flight/server.rs) | Run DataFusion server accepting FlightSQL/JDBC queries | | client | [`flight/client.rs`](examples/flight/client.rs) | Execute SQL queries via Arrow Flight protocol | +| server | [`flight/server.rs`](examples/flight/server.rs) | Run DataFusion server accepting FlightSQL/JDBC queries | | sql_server | [`flight/sql_server.rs`](examples/flight/sql_server.rs) | Standalone SQL server for JDBC clients | ## Proto Examples @@ -152,9 +154,10 @@ cargo run --example dataframe -- dataframe #### Category: Single Process -| Subcommand | File Path | Description | -| ------------------------ | --------------------------------------------------------------------------------- | --------------------------------------------------------------- | -| composed_extension_codec | [`proto/composed_extension_codec.rs`](examples/proto/composed_extension_codec.rs) | Use multiple extension codecs for serialization/deserialization | +| Subcommand | File Path | Description | +| ------------------------ | --------------------------------------------------------------------------------- | ----------------------------------------------------------------------------- | +| composed_extension_codec | [`proto/composed_extension_codec.rs`](examples/proto/composed_extension_codec.rs) | Use multiple extension codecs for serialization/deserialization | +| expression_deduplication | 
[`proto/expression_deduplication.rs`](examples/proto/expression_deduplication.rs) | Example of expression caching/deduplication using the codec decorator pattern | ## Query Planning Examples diff --git a/datafusion-examples/data/README.md b/datafusion-examples/data/README.md new file mode 100644 index 000000000000..e8296a8856e6 --- /dev/null +++ b/datafusion-examples/data/README.md @@ -0,0 +1,25 @@ + + +## Example datasets + +| Filename | Path | Description | +| ----------- | --------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| `cars.csv` | [`data/csv/cars.csv`](./csv/cars.csv) | Time-series–like dataset containing car identifiers, speed values, and timestamps. Used in window function and time-based query examples (e.g. ordering, window frames). | +| `regex.csv` | [`data/csv/regex.csv`](./csv/regex.csv) | Dataset for regular expression examples. Contains input values, regex patterns, replacement strings, and optional flags. Covers ASCII, Unicode, and locale-specific text processing. 
| diff --git a/datafusion-examples/data/csv/cars.csv b/datafusion-examples/data/csv/cars.csv new file mode 100644 index 000000000000..bc40f3b01e7a --- /dev/null +++ b/datafusion-examples/data/csv/cars.csv @@ -0,0 +1,26 @@ +car,speed,time +red,20.0,1996-04-12T12:05:03.000000000 +red,20.3,1996-04-12T12:05:04.000000000 +red,21.4,1996-04-12T12:05:05.000000000 +red,21.5,1996-04-12T12:05:06.000000000 +red,19.0,1996-04-12T12:05:07.000000000 +red,18.0,1996-04-12T12:05:08.000000000 +red,17.0,1996-04-12T12:05:09.000000000 +red,7.0,1996-04-12T12:05:10.000000000 +red,7.1,1996-04-12T12:05:11.000000000 +red,7.2,1996-04-12T12:05:12.000000000 +red,3.0,1996-04-12T12:05:13.000000000 +red,1.0,1996-04-12T12:05:14.000000000 +red,0.0,1996-04-12T12:05:15.000000000 +green,10.0,1996-04-12T12:05:03.000000000 +green,10.3,1996-04-12T12:05:04.000000000 +green,10.4,1996-04-12T12:05:05.000000000 +green,10.5,1996-04-12T12:05:06.000000000 +green,11.0,1996-04-12T12:05:07.000000000 +green,12.0,1996-04-12T12:05:08.000000000 +green,14.0,1996-04-12T12:05:09.000000000 +green,15.0,1996-04-12T12:05:10.000000000 +green,15.1,1996-04-12T12:05:11.000000000 +green,15.2,1996-04-12T12:05:12.000000000 +green,8.0,1996-04-12T12:05:13.000000000 +green,2.0,1996-04-12T12:05:14.000000000 diff --git a/datafusion-examples/data/csv/regex.csv b/datafusion-examples/data/csv/regex.csv new file mode 100644 index 000000000000..b249c39522b6 --- /dev/null +++ b/datafusion-examples/data/csv/regex.csv @@ -0,0 +1,12 @@ +values,patterns,replacement,flags +abc,^(a),bb\1bb,i +ABC,^(A).*,B,i +aBc,(b|d),e,i +AbC,(B|D),e, +aBC,^(b|c),d, +4000,\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b,xyz, +4010,\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b,xyz, +Düsseldorf,[\p{Letter}-]+,München, +Москва,[\p{L}-]+,Moscow, +Köln,[a-zA-Z]ö[a-zA-Z]{2},Koln, +اليوم,^\p{Arabic}+$,Today, \ No newline at end of file diff --git a/datafusion-examples/examples/builtin_functions/function_factory.rs b/datafusion-examples/examples/builtin_functions/function_factory.rs index 
7eff0d0b5c48..106c53cdf7f1 100644 --- a/datafusion-examples/examples/builtin_functions/function_factory.rs +++ b/datafusion-examples/examples/builtin_functions/function_factory.rs @@ -24,7 +24,7 @@ use datafusion::error::Result; use datafusion::execution::context::{ FunctionFactory, RegisterFunction, SessionContext, SessionState, }; -use datafusion::logical_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion::logical_expr::simplify::{ExprSimplifyResult, SimplifyContext}; use datafusion::logical_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion::logical_expr::{ ColumnarValue, CreateFunction, Expr, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, @@ -145,7 +145,7 @@ impl ScalarUDFImpl for ScalarFunctionWrapper { fn simplify( &self, args: Vec, - _info: &dyn SimplifyInfo, + _info: &SimplifyContext, ) -> Result { let replacement = Self::replacement(&self.expr, &args)?; diff --git a/datafusion-examples/examples/builtin_functions/main.rs b/datafusion-examples/examples/builtin_functions/main.rs index 638f56dfbe46..42ca15f91935 100644 --- a/datafusion-examples/examples/builtin_functions/main.rs +++ b/datafusion-examples/examples/builtin_functions/main.rs @@ -26,9 +26,15 @@ //! //! Each subcommand runs a corresponding example: //! - `all` — run all examples included in this module -//! - `date_time` — examples of date-time related functions and queries -//! - `function_factory` — register `CREATE FUNCTION` handler to implement SQL macros -//! - `regexp` — examples of using regular expression functions +//! +//! - `date_time` +//! (file: date_time.rs, desc: Examples of date-time related functions and queries) +//! +//! - `function_factory` +//! (file: function_factory.rs, desc: Register `CREATE FUNCTION` handler to implement SQL macros) +//! +//! - `regexp` +//! 
(file: regexp.rs, desc: Examples of using regular expression functions) mod date_time; mod function_factory; diff --git a/datafusion-examples/examples/builtin_functions/regexp.rs b/datafusion-examples/examples/builtin_functions/regexp.rs index e8376cd0c94e..97dc71b94e93 100644 --- a/datafusion-examples/examples/builtin_functions/regexp.rs +++ b/datafusion-examples/examples/builtin_functions/regexp.rs @@ -1,5 +1,4 @@ // Licensed to the Apache Software Foundation (ASF) under one -// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file @@ -18,12 +17,10 @@ //! See `main.rs` for how to run it. -use std::{fs::File, io::Write}; - use datafusion::common::{assert_batches_eq, assert_contains}; use datafusion::error::Result; use datafusion::prelude::*; -use tempfile::tempdir; +use datafusion_examples::utils::datasets::ExampleDataset; /// This example demonstrates how to use the regexp_* functions /// @@ -35,29 +32,9 @@ use tempfile::tempdir; /// https://docs.rs/regex/latest/regex/#grouping-and-flags pub async fn regexp() -> Result<()> { let ctx = SessionContext::new(); - // content from file 'datafusion/physical-expr/tests/data/regex.csv' - let csv_data = r#"values,patterns,replacement,flags -abc,^(a),bb\1bb,i -ABC,^(A).*,B,i -aBc,(b|d),e,i -AbC,(B|D),e, -aBC,^(b|c),d, -4000,\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b,xyz, -4010,\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b,xyz, -Düsseldorf,[\p{Letter}-]+,München, -Москва,[\p{L}-]+,Moscow, -Köln,[a-zA-Z]ö[a-zA-Z]{2},Koln, -اليوم,^\p{Arabic}+$,Today,"#; - let dir = tempdir()?; - let file_path = dir.path().join("regex.csv"); - { - let mut file = File::create(&file_path)?; - // write CSV data - file.write_all(csv_data.as_bytes())?; - } // scope closes the file - let file_path = file_path.to_str().unwrap(); - - ctx.register_csv("examples", file_path, 
CsvReadOptions::new()) + let dataset = ExampleDataset::Regex; + + ctx.register_csv("examples", dataset.path_str()?, CsvReadOptions::new()) .await?; // diff --git a/datafusion-examples/examples/custom_data_source/adapter_serialization.rs b/datafusion-examples/examples/custom_data_source/adapter_serialization.rs new file mode 100644 index 000000000000..a2cd187fee06 --- /dev/null +++ b/datafusion-examples/examples/custom_data_source/adapter_serialization.rs @@ -0,0 +1,519 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! See `main.rs` for how to run it. +//! +//! This example demonstrates how to use the `PhysicalExtensionCodec` trait's +//! interception methods (`serialize_physical_plan` and `deserialize_physical_plan`) +//! to implement custom serialization logic. +//! +//! The key insight is that `FileScanConfig::expr_adapter_factory` is NOT serialized by +//! default. This example shows how to: +//! 1. Detect plans with custom adapters during serialization +//! 2. Wrap them as Extension nodes with JSON-serialized adapter metadata +//! 3. Store the inner DataSourceExec (without adapter) as a child in the extension's inputs field +//! 4. Unwrap and restore the adapter during deserialization +//! +//! 
This demonstrates nested serialization (protobuf outer, JSON inner) and the power +//! of the `PhysicalExtensionCodec` interception pattern. Both plan and expression +//! serialization route through the codec, enabling interception at every node in the tree. + +use std::fmt::Debug; +use std::sync::Arc; + +use arrow::array::record_batch; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion::assert_batches_eq; +use datafusion::common::{Result, not_impl_err}; +use datafusion::datasource::listing::{ + ListingTable, ListingTableConfig, ListingTableConfigExt, ListingTableUrl, +}; +use datafusion::datasource::physical_plan::{FileScanConfig, FileScanConfigBuilder}; +use datafusion::datasource::source::DataSourceExec; +use datafusion::execution::TaskContext; +use datafusion::execution::context::SessionContext; +use datafusion::execution::object_store::ObjectStoreUrl; +use datafusion::parquet::arrow::ArrowWriter; +use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::SessionConfig; +use datafusion_physical_expr_adapter::{ + DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory, +}; +use datafusion_proto::bytes::{ + physical_plan_from_bytes_with_proto_converter, + physical_plan_to_bytes_with_proto_converter, +}; +use datafusion_proto::physical_plan::from_proto::parse_physical_expr_with_converter; +use datafusion_proto::physical_plan::to_proto::serialize_physical_expr_with_converter; +use datafusion_proto::physical_plan::{ + PhysicalExtensionCodec, PhysicalProtoConverterExtension, +}; +use datafusion_proto::protobuf::physical_plan_node::PhysicalPlanType; +use datafusion_proto::protobuf::{ + PhysicalExprNode, PhysicalExtensionNode, PhysicalPlanNode, +}; +use object_store::memory::InMemory; +use object_store::path::Path; +use object_store::{ObjectStore, ObjectStoreExt, PutPayload}; +use serde::{Deserialize, Serialize}; + +/// Example showing how to preserve custom 
adapter information during plan serialization. +/// +/// This demonstrates: +/// 1. Creating a custom PhysicalExprAdapter with metadata +/// 2. Using PhysicalExtensionCodec to intercept serialization +/// 3. Wrapping adapter info as Extension nodes +/// 4. Restoring adapters during deserialization +pub async fn adapter_serialization() -> Result<()> { + println!("=== PhysicalExprAdapter Serialization Example ===\n"); + + // Step 1: Create sample Parquet data in memory + println!("Step 1: Creating sample Parquet data..."); + let store = Arc::new(InMemory::new()) as Arc; + let batch = record_batch!(("id", Int32, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))?; + let path = Path::from("data.parquet"); + write_parquet(&store, &path, &batch).await?; + + // Step 2: Set up session with custom adapter + println!("Step 2: Setting up session with custom adapter..."); + let logical_schema = + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); + + let mut cfg = SessionConfig::new(); + cfg.options_mut().execution.parquet.pushdown_filters = true; + let ctx = SessionContext::new_with_config(cfg); + ctx.runtime_env().register_object_store( + ObjectStoreUrl::parse("memory://")?.as_ref(), + Arc::clone(&store), + ); + + // Create a table with our custom MetadataAdapterFactory + let adapter_factory = Arc::new(MetadataAdapterFactory::new("v1")); + let listing_config = + ListingTableConfig::new(ListingTableUrl::parse("memory:///data.parquet")?) + .infer_options(&ctx.state()) + .await? 
+ .with_schema(logical_schema) + .with_expr_adapter_factory( + Arc::clone(&adapter_factory) as Arc + ); + let table = ListingTable::try_new(listing_config)?; + ctx.register_table("my_table", Arc::new(table))?; + + // Step 3: Create physical plan with filter + println!("Step 3: Creating physical plan with filter..."); + let df = ctx.sql("SELECT * FROM my_table WHERE id > 5").await?; + let original_plan = df.create_physical_plan().await?; + + // Verify adapter is present in original plan + let has_adapter_before = verify_adapter_in_plan(&original_plan, "original"); + println!(" Original plan has adapter: {has_adapter_before}"); + + // Step 4: Serialize with our custom codec + println!("\nStep 4: Serializing plan with AdapterPreservingCodec..."); + let codec = AdapterPreservingCodec; + let bytes = physical_plan_to_bytes_with_proto_converter( + Arc::clone(&original_plan), + &codec, + &codec, + )?; + println!(" Serialized {} bytes", bytes.len()); + println!(" (DataSourceExec with adapter was wrapped as PhysicalExtensionNode)"); + + // Step 5: Deserialize with our custom codec + println!("\nStep 5: Deserializing plan with AdapterPreservingCodec..."); + let task_ctx = ctx.task_ctx(); + let restored_plan = + physical_plan_from_bytes_with_proto_converter(&bytes, &task_ctx, &codec, &codec)?; + + // Verify adapter is restored + let has_adapter_after = verify_adapter_in_plan(&restored_plan, "restored"); + println!(" Restored plan has adapter: {has_adapter_after}"); + + // Step 6: Execute and compare results + println!("\nStep 6: Executing plans and comparing results..."); + let original_results = + datafusion::physical_plan::collect(Arc::clone(&original_plan), task_ctx.clone()) + .await?; + let restored_results = + datafusion::physical_plan::collect(restored_plan, task_ctx).await?; + + #[rustfmt::skip] + let expected = [ + "+----+", + "| id |", + "+----+", + "| 6 |", + "| 7 |", + "| 8 |", + "| 9 |", + "| 10 |", + "+----+", + ]; + + println!("\n Original plan results:"); + 
arrow::util::pretty::print_batches(&original_results)?; + assert_batches_eq!(expected, &original_results); + + println!("\n Restored plan results:"); + arrow::util::pretty::print_batches(&restored_results)?; + assert_batches_eq!(expected, &restored_results); + + println!("\n=== Example Complete! ==="); + println!("Key takeaways:"); + println!( + " 1. PhysicalExtensionCodec provides serialize_physical_plan/deserialize_physical_plan hooks" + ); + println!(" 2. Custom metadata can be wrapped as PhysicalExtensionNode"); + println!(" 3. Nested serialization (protobuf + JSON) works seamlessly"); + println!( + " 4. Both plans produce identical results despite serialization round-trip" + ); + println!(" 5. Adapters are fully preserved through the serialization round-trip"); + + Ok(()) +} + +// ============================================================================ +// MetadataAdapter - A simple custom adapter with a tag +// ============================================================================ + +/// A custom PhysicalExprAdapter that wraps another adapter. +/// The tag metadata is stored in the factory, not the adapter itself. +#[derive(Debug)] +struct MetadataAdapter { + inner: Arc, +} + +impl PhysicalExprAdapter for MetadataAdapter { + fn rewrite(&self, expr: Arc) -> Result> { + // Simply delegate to inner adapter + self.inner.rewrite(expr) + } +} + +// ============================================================================ +// MetadataAdapterFactory - Factory for creating MetadataAdapter instances +// ============================================================================ + +/// Factory for creating MetadataAdapter instances. +/// The tag is stored in the factory and extracted via Debug formatting in `extract_adapter_tag`. +#[derive(Debug)] +struct MetadataAdapterFactory { + // Note: This field is read via Debug formatting in `extract_adapter_tag`. + // Rust's dead code analysis doesn't recognize Debug-based field access. 
+ // In PR #19234, this field is used by `with_partition_values`, but that method + // doesn't exist in upstream DataFusion's PhysicalExprAdapter trait. + #[expect(dead_code)] + tag: String, +} + +impl MetadataAdapterFactory { + fn new(tag: impl Into) -> Self { + Self { tag: tag.into() } + } +} + +impl PhysicalExprAdapterFactory for MetadataAdapterFactory { + fn create( + &self, + logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, + ) -> Result> { + let inner = DefaultPhysicalExprAdapterFactory + .create(logical_file_schema, physical_file_schema)?; + Ok(Arc::new(MetadataAdapter { inner })) + } +} + +// ============================================================================ +// AdapterPreservingCodec - Custom codec that preserves adapters +// ============================================================================ + +/// Extension payload structure for serializing adapter info +#[derive(Serialize, Deserialize)] +struct ExtensionPayload { + /// Marker to identify this is our custom extension + marker: String, + /// JSON-serialized adapter metadata + adapter_metadata: AdapterMetadata, +} + +/// Metadata about the adapter to recreate it during deserialization +#[derive(Serialize, Deserialize)] +struct AdapterMetadata { + /// The adapter tag (e.g., "v1") + tag: String, +} + +const EXTENSION_MARKER: &str = "adapter_preserving_extension_v1"; + +/// A codec that intercepts serialization to preserve adapter information. 
+#[derive(Debug)] +struct AdapterPreservingCodec; + +impl PhysicalExtensionCodec for AdapterPreservingCodec { + // Required method: decode custom extension nodes + fn try_decode( + &self, + buf: &[u8], + inputs: &[Arc], + _ctx: &TaskContext, + ) -> Result> { + // Try to parse as our extension payload + if let Ok(payload) = serde_json::from_slice::(buf) + && payload.marker == EXTENSION_MARKER + { + if inputs.len() != 1 { + return Err(datafusion::error::DataFusionError::Plan(format!( + "Extension node expected exactly 1 child, got {}", + inputs.len() + ))); + } + let inner_plan = inputs[0].clone(); + + // Recreate the adapter factory + let adapter_factory = create_adapter_factory(&payload.adapter_metadata.tag); + + // Inject adapter into the plan + return inject_adapter_into_plan(inner_plan, adapter_factory); + } + + not_impl_err!("Unknown extension type") + } + + // Required method: encode custom execution plans + fn try_encode( + &self, + _node: Arc, + _buf: &mut Vec, + ) -> Result<()> { + // We don't need this for the example - we use serialize_physical_plan instead + not_impl_err!( + "try_encode not used - adapter wrapping happens in serialize_physical_plan" + ) + } +} + +impl PhysicalProtoConverterExtension for AdapterPreservingCodec { + fn execution_plan_to_proto( + &self, + plan: &Arc, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + // Check if this is a DataSourceExec with adapter + if let Some(exec) = plan.as_any().downcast_ref::() + && let Some(config) = + exec.data_source().as_any().downcast_ref::() + && let Some(adapter_factory) = &config.expr_adapter_factory + && let Some(tag) = extract_adapter_tag(adapter_factory.as_ref()) + { + // Try to extract our MetadataAdapterFactory's tag + println!(" [Serialize] Found DataSourceExec with adapter tag: {tag}"); + + // 1. Create adapter metadata + let adapter_metadata = AdapterMetadata { tag }; + + // 2. 
Serialize the inner plan to protobuf + // Note that this will drop the custom adapter since the default serialization cannot handle it + let inner_proto = PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(plan), + extension_codec, + self, + )?; + + // 3. Create extension payload to wrap the plan + // so that the custom adapter gets re-attached during deserialization + // The choice of JSON is arbitrary; other formats could be used. + let payload = ExtensionPayload { + marker: EXTENSION_MARKER.to_string(), + adapter_metadata, + }; + let payload_bytes = serde_json::to_vec(&payload).map_err(|e| { + datafusion::error::DataFusionError::Plan(format!( + "Failed to serialize payload: {e}" + )) + })?; + + // 4. Return as PhysicalExtensionNode with child plan in inputs + return Ok(PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Extension( + PhysicalExtensionNode { + node: payload_bytes, + inputs: vec![inner_proto], + }, + )), + }); + } + + // No adapter found, not a DataSourceExec, etc. 
- use default serialization + PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(plan), + extension_codec, + self, + ) + } + + // Interception point: override deserialization to unwrap adapters + fn proto_to_execution_plan( + &self, + ctx: &TaskContext, + extension_codec: &dyn PhysicalExtensionCodec, + proto: &PhysicalPlanNode, + ) -> Result> { + // Check if this is our custom extension wrapper + if let Some(PhysicalPlanType::Extension(extension)) = &proto.physical_plan_type + && let Ok(payload) = + serde_json::from_slice::(&extension.node) + && payload.marker == EXTENSION_MARKER + { + println!( + " [Deserialize] Found adapter extension with tag: {}", + payload.adapter_metadata.tag + ); + + // Get the inner plan proto from inputs field + if extension.inputs.is_empty() { + return Err(datafusion::error::DataFusionError::Plan( + "Extension node missing child plan in inputs".to_string(), + )); + } + let inner_proto = &extension.inputs[0]; + + // Deserialize the inner plan + let inner_plan = inner_proto.try_into_physical_plan_with_converter( + ctx, + extension_codec, + self, + )?; + + // Recreate the adapter factory + let adapter_factory = create_adapter_factory(&payload.adapter_metadata.tag); + + // Inject adapter into the plan + return inject_adapter_into_plan(inner_plan, adapter_factory); + } + + // Not our extension - use default deserialization + proto.try_into_physical_plan_with_converter(ctx, extension_codec, self) + } + + fn proto_to_physical_expr( + &self, + proto: &PhysicalExprNode, + ctx: &TaskContext, + input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + parse_physical_expr_with_converter(proto, ctx, input_schema, codec, self) + } + + fn physical_expr_to_proto( + &self, + expr: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result { + serialize_physical_expr_with_converter(expr, codec, self) + } +} + +// ============================================================================ +// Helper functions +// 
============================================================================ + +/// Write a RecordBatch to Parquet in the object store +async fn write_parquet( + store: &dyn ObjectStore, + path: &Path, + batch: &arrow::record_batch::RecordBatch, +) -> Result<()> { + let mut buf = vec![]; + let mut writer = ArrowWriter::try_new(&mut buf, batch.schema(), None)?; + writer.write(batch)?; + writer.close()?; + + let payload = PutPayload::from_bytes(buf.into()); + store.put(path, payload).await?; + Ok(()) +} + +/// Extract the tag from a MetadataAdapterFactory. +/// +/// Note: Since `PhysicalExprAdapterFactory` doesn't provide `as_any()` for downcasting, +/// we parse the Debug output. In a production system, you might add a dedicated trait +/// method for metadata extraction. +fn extract_adapter_tag(factory: &dyn PhysicalExprAdapterFactory) -> Option { + let debug_str = format!("{factory:?}"); + if debug_str.contains("MetadataAdapterFactory") { + // Extract tag from debug output: MetadataAdapterFactory { tag: "v1" } + if let Some(start) = debug_str.find("tag: \"") { + let after_tag = &debug_str[start + 6..]; + if let Some(end) = after_tag.find('"') { + return Some(after_tag[..end].to_string()); + } + } + } + None +} + +/// Create an adapter factory from a tag +fn create_adapter_factory(tag: &str) -> Arc { + Arc::new(MetadataAdapterFactory::new(tag)) +} + +/// Inject an adapter into a plan (assumes plan is a DataSourceExec with FileScanConfig) +fn inject_adapter_into_plan( + plan: Arc, + adapter_factory: Arc, +) -> Result> { + if let Some(exec) = plan.as_any().downcast_ref::() + && let Some(config) = exec.data_source().as_any().downcast_ref::() + { + let new_config = FileScanConfigBuilder::from(config.clone()) + .with_expr_adapter(Some(adapter_factory)) + .build(); + return Ok(DataSourceExec::from_data_source(new_config)); + } + // If not a DataSourceExec with FileScanConfig, return as-is + Ok(plan) +} + +/// Helper to verify if a plan has an adapter (for 
testing/validation) +fn verify_adapter_in_plan(plan: &Arc, label: &str) -> bool { + // Walk the plan tree to find DataSourceExec with adapter + fn check_plan(plan: &dyn ExecutionPlan) -> bool { + if let Some(exec) = plan.as_any().downcast_ref::() + && let Some(config) = + exec.data_source().as_any().downcast_ref::() + && config.expr_adapter_factory.is_some() + { + return true; + } + // Check children + for child in plan.children() { + if check_plan(child.as_ref()) { + return true; + } + } + false + } + + let has_adapter = check_plan(plan.as_ref()); + println!(" [Verify] {label} plan adapter check: {has_adapter}"); + has_adapter +} diff --git a/datafusion-examples/examples/custom_data_source/csv_json_opener.rs b/datafusion-examples/examples/custom_data_source/csv_json_opener.rs index 7b2e32136263..35f36ea8bc0c 100644 --- a/datafusion-examples/examples/custom_data_source/csv_json_opener.rs +++ b/datafusion-examples/examples/custom_data_source/csv_json_opener.rs @@ -31,12 +31,12 @@ use datafusion::{ }, error::Result, physical_plan::metrics::ExecutionPlanMetricsSet, - test_util::aggr_test_schema, }; use datafusion::datasource::physical_plan::FileScanConfigBuilder; +use datafusion_examples::utils::datasets::ExampleDataset; use futures::StreamExt; -use object_store::{ObjectStore, local::LocalFileSystem, memory::InMemory}; +use object_store::{ObjectStoreExt, local::LocalFileSystem, memory::InMemory}; /// This example demonstrates using the low level [`FileStream`] / [`FileOpener`] APIs to directly /// read data from (CSV/JSON) into Arrow RecordBatches. 
@@ -50,12 +50,10 @@ pub async fn csv_json_opener() -> Result<()> { async fn csv_opener() -> Result<()> { let object_store = Arc::new(LocalFileSystem::new()); - let schema = aggr_test_schema(); - let testdata = datafusion::test_util::arrow_test_data(); - let path = format!("{testdata}/csv/aggregate_test_100.csv"); - - let path = std::path::Path::new(&path).canonicalize()?; + let dataset = ExampleDataset::Cars; + let csv_path = dataset.path(); + let schema = dataset.schema(); let options = CsvOptions { has_header: Some(true), @@ -71,9 +69,9 @@ async fn csv_opener() -> Result<()> { let scan_config = FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), source) - .with_projection_indices(Some(vec![12, 0]))? + .with_projection_indices(Some(vec![0, 1]))? .with_limit(Some(5)) - .with_file(PartitionedFile::new(path.display().to_string(), 10)) + .with_file(PartitionedFile::new(csv_path.display().to_string(), 10)) .build(); let opener = @@ -89,15 +87,15 @@ async fn csv_opener() -> Result<()> { } assert_batches_eq!( &[ - "+--------------------------------+----+", - "| c13 | c1 |", - "+--------------------------------+----+", - "| 6WfVFBVGJSQb7FhA7E0lBwdvjfZnSW | c |", - "| C2GT5KVyOPZpgKVl110TyZO0NcJ434 | d |", - "| AyYVExXK6AR2qUTxNZ7qRHQOVGMLcz | b |", - "| 0keZ5G8BffGwgF2RwQD59TFzMStxCB | a |", - "| Ig1QcuKsjHXkproePdERo2w0mYzIqd | b |", - "+--------------------------------+----+", + "+-----+-------+", + "| car | speed |", + "+-----+-------+", + "| red | 20.0 |", + "| red | 20.3 |", + "| red | 21.4 |", + "| red | 21.5 |", + "| red | 19.0 |", + "+-----+-------+", ], &result ); @@ -127,6 +125,7 @@ async fn json_opener() -> Result<()> { projected, FileCompressionType::UNCOMPRESSED, Arc::new(object_store), + true, ); let scan_config = FileScanConfigBuilder::new( diff --git a/datafusion-examples/examples/custom_data_source/csv_sql_streaming.rs b/datafusion-examples/examples/custom_data_source/csv_sql_streaming.rs index 554382ea9549..4692086a10b2 100644 --- 
a/datafusion-examples/examples/custom_data_source/csv_sql_streaming.rs +++ b/datafusion-examples/examples/custom_data_source/csv_sql_streaming.rs @@ -17,9 +17,9 @@ //! See `main.rs` for how to run it. -use datafusion::common::test_util::datafusion_test_data; use datafusion::error::Result; use datafusion::prelude::*; +use datafusion_examples::utils::datasets::ExampleDataset; /// This example demonstrates executing a simple query against an Arrow data source (CSV) and /// fetching results with streaming aggregation and streaming window @@ -27,33 +27,34 @@ pub async fn csv_sql_streaming() -> Result<()> { // create local execution context let ctx = SessionContext::new(); - let testdata = datafusion_test_data(); + let dataset = ExampleDataset::Cars; + let csv_path = dataset.path(); - // Register a table source and tell DataFusion the file is ordered by `ts ASC`. + // Register a table source and tell DataFusion the file is ordered by `car ASC`. // Note it is the responsibility of the user to make sure // that file indeed satisfies this condition or else incorrect answers may be produced. let asc = true; let nulls_first = true; - let sort_expr = vec![col("ts").sort(asc, nulls_first)]; + let sort_expr = vec![col("car").sort(asc, nulls_first)]; // register csv file with the execution context ctx.register_csv( "ordered_table", - &format!("{testdata}/window_1.csv"), + csv_path.to_str().unwrap(), CsvReadOptions::new().file_sort_order(vec![sort_expr]), ) .await?; // execute the query - // Following query can be executed with unbounded sources because group by expressions (e.g ts) is + // Following query can be executed with unbounded sources because group by expressions (e.g car) is // already ordered at the source. // // Unbounded sources means that if the input came from a "never ending" source (such as a FIFO // file on unix) the query could produce results incrementally as data was read. 
let df = ctx .sql( - "SELECT ts, MIN(inc_col), MAX(inc_col) \ + "SELECT car, MIN(speed), MAX(speed) \ FROM ordered_table \ - GROUP BY ts", + GROUP BY car", ) .await?; @@ -64,7 +65,7 @@ pub async fn csv_sql_streaming() -> Result<()> { // its result in streaming fashion, because its required ordering is already satisfied at the source. let df = ctx .sql( - "SELECT ts, SUM(inc_col) OVER(ORDER BY ts ASC) \ + "SELECT car, SUM(speed) OVER(ORDER BY car ASC) \ FROM ordered_table", ) .await?; diff --git a/datafusion-examples/examples/custom_data_source/custom_datasource.rs b/datafusion-examples/examples/custom_data_source/custom_datasource.rs index b276ae32cf24..7abb39e1a713 100644 --- a/datafusion-examples/examples/custom_data_source/custom_datasource.rs +++ b/datafusion-examples/examples/custom_data_source/custom_datasource.rs @@ -192,7 +192,7 @@ impl TableProvider for CustomDataSource { struct CustomExec { db: CustomDataSource, projected_schema: SchemaRef, - cache: PlanProperties, + cache: Arc, } impl CustomExec { @@ -207,7 +207,7 @@ impl CustomExec { Self { db, projected_schema, - cache, + cache: Arc::new(cache), } } @@ -238,7 +238,7 @@ impl ExecutionPlan for CustomExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } diff --git a/datafusion-examples/examples/custom_data_source/custom_file_casts.rs b/datafusion-examples/examples/custom_data_source/custom_file_casts.rs index 895b6f52b6e1..6b37db653e35 100644 --- a/datafusion-examples/examples/custom_data_source/custom_file_casts.rs +++ b/datafusion-examples/examples/custom_data_source/custom_file_casts.rs @@ -40,7 +40,7 @@ use datafusion_physical_expr_adapter::{ }; use object_store::memory::InMemory; use object_store::path::Path; -use object_store::{ObjectStore, PutPayload}; +use object_store::{ObjectStore, ObjectStoreExt, PutPayload}; // Example showing how to implement custom casting rules to adapt file schemas. 
// This example enforces that casts must be strictly widening: if the file type is Int64 and the table type is Int32, it will error @@ -49,9 +49,9 @@ use object_store::{ObjectStore, PutPayload}; pub async fn custom_file_casts() -> Result<()> { println!("=== Creating example data ==="); - // Create a logical / table schema with an Int32 column + // Create a logical / table schema with an Int32 column (nullable) let logical_schema = - Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, true)])); // Create some data that can be cast (Int16 -> Int32 is widening) and some that cannot (Int64 -> Int32 is narrowing) let store = Arc::new(InMemory::new()) as Arc; @@ -156,14 +156,14 @@ impl PhysicalExprAdapterFactory for CustomCastPhysicalExprAdapterFactory { &self, logical_file_schema: SchemaRef, physical_file_schema: SchemaRef, - ) -> Arc { + ) -> Result> { let inner = self .inner - .create(logical_file_schema, Arc::clone(&physical_file_schema)); - Arc::new(CustomCastsPhysicalExprAdapter { + .create(logical_file_schema, Arc::clone(&physical_file_schema))?; + Ok(Arc::new(CustomCastsPhysicalExprAdapter { physical_file_schema, inner, - }) + })) } } diff --git a/datafusion-examples/examples/custom_data_source/default_column_values.rs b/datafusion-examples/examples/custom_data_source/default_column_values.rs index 81d74cfbecab..40c8836c1f82 100644 --- a/datafusion-examples/examples/custom_data_source/default_column_values.rs +++ b/datafusion-examples/examples/custom_data_source/default_column_values.rs @@ -48,7 +48,7 @@ use datafusion_physical_expr_adapter::{ use futures::StreamExt; use object_store::memory::InMemory; use object_store::path::Path; -use object_store::{ObjectStore, PutPayload}; +use object_store::{ObjectStore, ObjectStoreExt, PutPayload}; // Metadata key for storing default values in field metadata const DEFAULT_VALUE_METADATA_KEY: &str = "example.default_value"; @@ -79,7 +79,7 @@ pub 
async fn default_column_values() -> Result<()> { let mut buf = vec![]; let props = WriterProperties::builder() - .set_max_row_group_size(2) + .set_max_row_group_row_count(Some(2)) .build(); let mut writer = @@ -278,18 +278,18 @@ impl PhysicalExprAdapterFactory for DefaultValuePhysicalExprAdapterFactory { &self, logical_file_schema: SchemaRef, physical_file_schema: SchemaRef, - ) -> Arc { + ) -> Result> { let default_factory = DefaultPhysicalExprAdapterFactory; let default_adapter = default_factory.create( Arc::clone(&logical_file_schema), Arc::clone(&physical_file_schema), - ); + )?; - Arc::new(DefaultValuePhysicalExprAdapter { + Ok(Arc::new(DefaultValuePhysicalExprAdapter { logical_file_schema, physical_file_schema, default_adapter, - }) + })) } } diff --git a/datafusion-examples/examples/custom_data_source/file_stream_provider.rs b/datafusion-examples/examples/custom_data_source/file_stream_provider.rs index 936da0a33d47..5b43072d43f8 100644 --- a/datafusion-examples/examples/custom_data_source/file_stream_provider.rs +++ b/datafusion-examples/examples/custom_data_source/file_stream_provider.rs @@ -22,7 +22,7 @@ /// /// On non-Windows systems, this example creates a named pipe (FIFO) and /// writes rows into it asynchronously while DataFusion reads the data -/// through a `FileStreamProvider`. +/// through a `FileStreamProvider`. /// /// This illustrates how to integrate dynamically updated data sources /// with DataFusion without needing to reload the entire dataset each time. 
@@ -126,7 +126,6 @@ mod non_windows { let broken_pipe_timeout = Duration::from_secs(10); let sa = file_path; // Spawn a new thread to write to the FIFO file - #[allow(clippy::disallowed_methods)] // spawn allowed only in tests tasks.spawn_blocking(move || { let file = OpenOptions::new().write(true).open(sa).unwrap(); // Reference time to use when deciding to fail the test diff --git a/datafusion-examples/examples/custom_data_source/main.rs b/datafusion-examples/examples/custom_data_source/main.rs index 5846626d8138..0d21a6259112 100644 --- a/datafusion-examples/examples/custom_data_source/main.rs +++ b/datafusion-examples/examples/custom_data_source/main.rs @@ -26,14 +26,32 @@ //! //! Each subcommand runs a corresponding example: //! - `all` — run all examples included in this module -//! - `csv_json_opener` — use low level FileOpener APIs to read CSV/JSON into Arrow RecordBatches -//! - `csv_sql_streaming` — build and run a streaming query plan from a SQL statement against a local CSV file -//! - `custom_datasource` — run queries against a custom datasource (TableProvider) -//! - `custom_file_casts` — implement custom casting rules to adapt file schemas -//! - `custom_file_format` — write data to a custom file format -//! - `default_column_values` — implement custom default value handling for missing columns using field metadata and PhysicalExprAdapter -//! - `file_stream_provider` — run a query on FileStreamProvider which implements StreamProvider for reading and writing to arbitrary stream sources/sinks +//! +//! - `adapter_serialization` +//! (file: adapter_serialization.rs, desc: Preserve custom PhysicalExprAdapter information during plan serialization using PhysicalExtensionCodec interception) +//! +//! - `csv_json_opener` +//! (file: csv_json_opener.rs, desc: Use low-level FileOpener APIs for CSV/JSON) +//! +//! - `csv_sql_streaming` +//! (file: csv_sql_streaming.rs, desc: Run a streaming SQL query against CSV data) +//! +//! - `custom_datasource` +//! 
(file: custom_datasource.rs, desc: Query a custom TableProvider) +//! +//! - `custom_file_casts` +//! (file: custom_file_casts.rs, desc: Implement custom casting rules) +//! +//! - `custom_file_format` +//! (file: custom_file_format.rs, desc: Write to a custom file format) +//! +//! - `default_column_values` +//! (file: default_column_values.rs, desc: Custom default values using metadata) +//! +//! - `file_stream_provider` +//! (file: file_stream_provider.rs, desc: Read/write via FileStreamProvider for streams) +mod adapter_serialization; mod csv_json_opener; mod csv_sql_streaming; mod custom_datasource; @@ -50,6 +68,7 @@ use strum_macros::{Display, EnumIter, EnumString, VariantNames}; #[strum(serialize_all = "snake_case")] enum ExampleKind { All, + AdapterSerialization, CsvJsonOpener, CsvSqlStreaming, CustomDatasource, @@ -74,6 +93,9 @@ impl ExampleKind { Box::pin(example.run()).await?; } } + ExampleKind::AdapterSerialization => { + adapter_serialization::adapter_serialization().await? + } ExampleKind::CsvJsonOpener => csv_json_opener::csv_json_opener().await?, ExampleKind::CsvSqlStreaming => { csv_sql_streaming::csv_sql_streaming().await? diff --git a/datafusion-examples/examples/data_io/catalog.rs b/datafusion-examples/examples/data_io/catalog.rs index d2ddff82e32d..9781a93374ea 100644 --- a/datafusion-examples/examples/data_io/catalog.rs +++ b/datafusion-examples/examples/data_io/catalog.rs @@ -140,7 +140,6 @@ struct DirSchemaOpts<'a> { /// Schema where every file with extension `ext` in a given `dir` is a table. #[derive(Debug)] struct DirSchema { - ext: String, tables: RwLock>>, } @@ -173,14 +172,8 @@ impl DirSchema { } Ok(Arc::new(Self { tables: RwLock::new(tables), - ext: ext.to_string(), })) } - - #[allow(unused)] - fn name(&self) -> &str { - &self.ext - } } #[async_trait] @@ -217,7 +210,6 @@ impl SchemaProvider for DirSchema { /// If supported by the implementation, removes an existing table from this schema and returns it. 
/// If no table of that name exists, returns Ok(None). - #[allow(unused_variables)] fn deregister_table(&self, name: &str) -> Result>> { let mut tables = self.tables.write().unwrap(); log::info!("dropping table {name}"); diff --git a/datafusion-examples/examples/data_io/json_shredding.rs b/datafusion-examples/examples/data_io/json_shredding.rs index d2ffacc9464c..ca1513f62624 100644 --- a/datafusion-examples/examples/data_io/json_shredding.rs +++ b/datafusion-examples/examples/data_io/json_shredding.rs @@ -47,7 +47,7 @@ use datafusion_physical_expr_adapter::{ }; use object_store::memory::InMemory; use object_store::path::Path; -use object_store::{ObjectStore, PutPayload}; +use object_store::{ObjectStoreExt, PutPayload}; // Example showing how to implement custom filter rewriting for JSON shredding. // @@ -76,7 +76,7 @@ pub async fn json_shredding() -> Result<()> { let mut buf = vec![]; let props = WriterProperties::builder() - .set_max_row_group_size(2) + .set_max_row_group_row_count(Some(2)) .build(); let mut writer = ArrowWriter::try_new(&mut buf, batch.schema(), Some(props)) @@ -275,17 +275,17 @@ impl PhysicalExprAdapterFactory for ShreddedJsonRewriterFactory { &self, logical_file_schema: SchemaRef, physical_file_schema: SchemaRef, - ) -> Arc { + ) -> Result> { let default_factory = DefaultPhysicalExprAdapterFactory; let default_adapter = default_factory.create( Arc::clone(&logical_file_schema), Arc::clone(&physical_file_schema), - ); + )?; - Arc::new(ShreddedJsonRewriter { + Ok(Arc::new(ShreddedJsonRewriter { physical_file_schema, default_adapter, - }) + })) } } diff --git a/datafusion-examples/examples/data_io/main.rs b/datafusion-examples/examples/data_io/main.rs index 0b2bd03f7ea9..0039585d15b6 100644 --- a/datafusion-examples/examples/data_io/main.rs +++ b/datafusion-examples/examples/data_io/main.rs @@ -26,16 +26,36 @@ //! //! Each subcommand runs a corresponding example: //! - `all` — run all examples included in this module -//! 
- `catalog` — register the table into a custom catalog -//! - `json_shredding` — shows how to implement custom filter rewriting for JSON shredding -//! - `parquet_adv_idx` — create a detailed secondary index that covers the contents of several parquet files -//! - `parquet_emb_idx` — store a custom index inside a Parquet file and use it to speed up queries -//! - `parquet_enc_with_kms` — read and write encrypted Parquet files using an encryption factory -//! - `parquet_enc` — read and write encrypted Parquet files using DataFusion -//! - `parquet_exec_visitor` — extract statistics by visiting an ExecutionPlan after execution -//! - `parquet_idx` — create an secondary index over several parquet files and use it to speed up queries -//! - `query_http_csv` — configure `object_store` and run a query against files via HTTP -//! - `remote_catalog` — interfacing with a remote catalog (e.g. over a network) +//! +//! - `catalog` +//! (file: catalog.rs, desc: Register tables into a custom catalog) +//! +//! - `json_shredding` +//! (file: json_shredding.rs, desc: Implement filter rewriting for JSON shredding) +//! +//! - `parquet_adv_idx` +//! (file: parquet_advanced_index.rs, desc: Create a secondary index across multiple parquet files) +//! +//! - `parquet_emb_idx` +//! (file: parquet_embedded_index.rs, desc: Store a custom index inside Parquet files) +//! +//! - `parquet_enc` +//! (file: parquet_encrypted.rs, desc: Read & write encrypted Parquet files) +//! +//! - `parquet_enc_with_kms` +//! (file: parquet_encrypted_with_kms.rs, desc: Encrypted Parquet I/O using a KMS-backed factory) +//! +//! - `parquet_exec_visitor` +//! (file: parquet_exec_visitor.rs, desc: Extract statistics by visiting an ExecutionPlan) +//! +//! - `parquet_idx` +//! (file: parquet_index.rs, desc: Create a secondary index) +//! +//! - `query_http_csv` +//! (file: query_http_csv.rs, desc: Query CSV files via HTTP) +//! +//! - `remote_catalog` +//! 
(file: remote_catalog.rs, desc: Interact with a remote catalog) mod catalog; mod json_shredding; diff --git a/datafusion-examples/examples/data_io/parquet_advanced_index.rs b/datafusion-examples/examples/data_io/parquet_advanced_index.rs index 3f4ebe7a9205..f02b01354b78 100644 --- a/datafusion-examples/examples/data_io/parquet_advanced_index.rs +++ b/datafusion-examples/examples/data_io/parquet_advanced_index.rs @@ -43,7 +43,7 @@ use datafusion::parquet::arrow::arrow_reader::{ ArrowReaderOptions, ParquetRecordBatchReaderBuilder, RowSelection, RowSelector, }; use datafusion::parquet::arrow::async_reader::{AsyncFileReader, ParquetObjectReader}; -use datafusion::parquet::file::metadata::ParquetMetaData; +use datafusion::parquet::file::metadata::{PageIndexPolicy, ParquetMetaData}; use datafusion::parquet::file::properties::{EnabledStatistics, WriterProperties}; use datafusion::parquet::schema::types::ColumnPath; use datafusion::physical_expr::PhysicalExpr; @@ -410,7 +410,7 @@ impl IndexedFile { let options = ArrowReaderOptions::new() // Load the page index when reading metadata to cache // so it is available to interpret row selections - .with_page_index(true); + .with_page_index_policy(PageIndexPolicy::Required); let reader = ParquetRecordBatchReaderBuilder::try_new_with_options(file, options)?; let metadata = reader.metadata().clone(); @@ -567,7 +567,7 @@ impl ParquetFileReaderFactory for CachedParquetFileReaderFactory { .object_meta .location .parts() - .last() + .next_back() .expect("No path in location") .as_ref() .to_string(); @@ -659,7 +659,7 @@ fn make_demo_file(path: impl AsRef, value_range: Range) -> Result<()> // enable page statistics for the tag column, // for everything else. 
let props = WriterProperties::builder() - .set_max_row_group_size(100) + .set_max_row_group_row_count(Some(100)) // compute column chunk (per row group) statistics by default .set_statistics_enabled(EnabledStatistics::Chunk) // compute column page statistics for the tag column diff --git a/datafusion-examples/examples/data_io/parquet_encrypted.rs b/datafusion-examples/examples/data_io/parquet_encrypted.rs index f88ab91321e9..26361e9b52be 100644 --- a/datafusion-examples/examples/data_io/parquet_encrypted.rs +++ b/datafusion-examples/examples/data_io/parquet_encrypted.rs @@ -17,6 +17,8 @@ //! See `main.rs` for how to run it. +use std::sync::Arc; + use datafusion::common::DataFusionError; use datafusion::config::{ConfigFileEncryptionProperties, TableParquetOptions}; use datafusion::dataframe::{DataFrame, DataFrameWriteOptions}; @@ -24,7 +26,7 @@ use datafusion::logical_expr::{col, lit}; use datafusion::parquet::encryption::decrypt::FileDecryptionProperties; use datafusion::parquet::encryption::encrypt::FileEncryptionProperties; use datafusion::prelude::{ParquetReadOptions, SessionContext}; -use std::sync::Arc; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; use tempfile::TempDir; /// Read and write encrypted Parquet files using DataFusion @@ -32,13 +34,13 @@ pub async fn parquet_encrypted() -> datafusion::common::Result<()> { // The SessionContext is the main high level API for interacting with DataFusion let ctx = SessionContext::new(); - // Find the local path of "alltypes_plain.parquet" - let testdata = datafusion::test_util::parquet_test_data(); - let filename = &format!("{testdata}/alltypes_plain.parquet"); + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; // Read the sample parquet file let parquet_df = ctx - .read_parquet(filename, ParquetReadOptions::default()) + 
.read_parquet(parquet_temp.path_str()?, ParquetReadOptions::default()) .await?; // Show information from the dataframe @@ -52,27 +54,28 @@ pub async fn parquet_encrypted() -> datafusion::common::Result<()> { let (encrypt, decrypt) = setup_encryption(&parquet_df)?; // Create a temporary file location for the encrypted parquet file - let tmp_dir = TempDir::new()?; - let tempfile = tmp_dir.path().join("alltypes_plain-encrypted.parquet"); - let tempfile_str = tempfile.into_os_string().into_string().unwrap(); + let tmp_source = TempDir::new()?; + let tempfile = tmp_source.path().join("cars_encrypted.parquet"); // Write encrypted parquet let mut options = TableParquetOptions::default(); options.crypto.file_encryption = Some(ConfigFileEncryptionProperties::from(&encrypt)); parquet_df .write_parquet( - tempfile_str.as_str(), + tempfile.to_str().unwrap(), DataFrameWriteOptions::new().with_single_file_output(true), Some(options), ) .await?; - // Read encrypted parquet + // Read encrypted parquet back as a DataFrame using matching decryption config let ctx: SessionContext = SessionContext::new(); let read_options = ParquetReadOptions::default().file_decryption_properties((&decrypt).into()); - let encrypted_parquet_df = ctx.read_parquet(tempfile_str, read_options).await?; + let encrypted_parquet_df = ctx + .read_parquet(tempfile.to_str().unwrap(), read_options) + .await?; // Show information from the dataframe println!( @@ -91,11 +94,12 @@ async fn query_dataframe(df: &DataFrame) -> Result<(), DataFusionError> { df.clone().describe().await?.show().await?; // Select three columns and filter the results - // so that only rows where id > 1 are returned + // so that only rows where speed > 5 are returned + // select car, speed, time from t where speed > 5 println!("\nSelected rows and columns:"); df.clone() - .select_columns(&["id", "bool_col", "timestamp_col"])? - .filter(col("id").gt(lit(5)))? + .select_columns(&["car", "speed", "time"])? + .filter(col("speed").gt(lit(5)))? 
.show() .await?; diff --git a/datafusion-examples/examples/data_io/parquet_exec_visitor.rs b/datafusion-examples/examples/data_io/parquet_exec_visitor.rs index d38fe9e17120..47caf9480df9 100644 --- a/datafusion-examples/examples/data_io/parquet_exec_visitor.rs +++ b/datafusion-examples/examples/data_io/parquet_exec_visitor.rs @@ -29,28 +29,32 @@ use datafusion::physical_plan::metrics::MetricValue; use datafusion::physical_plan::{ ExecutionPlan, ExecutionPlanVisitor, execute_stream, visit_execution_plan, }; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; use futures::StreamExt; /// Example of collecting metrics after execution by visiting the `ExecutionPlan` pub async fn parquet_exec_visitor() -> datafusion::common::Result<()> { let ctx = SessionContext::new(); - let test_data = datafusion::test_util::parquet_test_data(); + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; // Configure listing options let file_format = ParquetFormat::default().with_enable_pruning(true); let listing_options = ListingOptions::new(Arc::new(file_format)); + let table_path = parquet_temp.file_uri()?; + // First example were we use an absolute path, which requires no additional setup. 
- let _ = ctx - .register_listing_table( - "my_table", - &format!("file://{test_data}/alltypes_plain.parquet"), - listing_options.clone(), - None, - None, - ) - .await; + ctx.register_listing_table( + "my_table", + &table_path, + listing_options.clone(), + None, + None, + ) + .await?; let df = ctx.sql("SELECT * FROM my_table").await?; let plan = df.create_physical_plan().await?; diff --git a/datafusion-examples/examples/dataframe/cache_factory.rs b/datafusion-examples/examples/dataframe/cache_factory.rs index a6c465720c62..a92c3dc4ce26 100644 --- a/datafusion-examples/examples/dataframe/cache_factory.rs +++ b/datafusion-examples/examples/dataframe/cache_factory.rs @@ -19,31 +19,26 @@ use std::fmt::Debug; use std::hash::Hash; -use std::sync::Arc; -use std::sync::RwLock; +use std::sync::{Arc, RwLock}; use arrow::array::RecordBatch; use async_trait::async_trait; use datafusion::catalog::memory::MemorySourceConfig; use datafusion::common::DFSchemaRef; use datafusion::error::Result; -use datafusion::execution::SessionState; -use datafusion::execution::SessionStateBuilder; use datafusion::execution::context::QueryPlanner; use datafusion::execution::session_state::CacheFactory; -use datafusion::logical_expr::Extension; -use datafusion::logical_expr::LogicalPlan; -use datafusion::logical_expr::UserDefinedLogicalNode; -use datafusion::logical_expr::UserDefinedLogicalNodeCore; -use datafusion::physical_plan::ExecutionPlan; -use datafusion::physical_plan::collect_partitioned; -use datafusion::physical_planner::DefaultPhysicalPlanner; -use datafusion::physical_planner::ExtensionPlanner; -use datafusion::physical_planner::PhysicalPlanner; -use datafusion::prelude::ParquetReadOptions; -use datafusion::prelude::SessionContext; +use datafusion::execution::{SessionState, SessionStateBuilder}; +use datafusion::logical_expr::{ + Extension, LogicalPlan, UserDefinedLogicalNode, UserDefinedLogicalNodeCore, +}; +use datafusion::physical_plan::{ExecutionPlan, collect_partitioned}; +use 
datafusion::physical_planner::{ + DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner, +}; use datafusion::prelude::*; use datafusion_common::HashMap; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; /// This example demonstrates how to leverage [CacheFactory] to implement custom caching strategies for dataframes in DataFusion. /// By default, [DataFrame::cache] in Datafusion is eager and creates an in-memory table. This example shows a basic alternative implementation for lazy caching. @@ -53,28 +48,29 @@ use datafusion_common::HashMap; /// - A [CacheNodeQueryPlanner] that installs [CacheNodePlanner]. /// - A simple in-memory [CacheManager] that stores cached [RecordBatch]es. Note that the implementation for this example is very naive and only implements put, but for real production use cases cache eviction and drop should also be implemented. pub async fn cache_dataframe_with_custom_logic() -> Result<()> { - let testdata = datafusion::test_util::parquet_test_data(); - let filename = &format!("{testdata}/alltypes_plain.parquet"); - let session_state = SessionStateBuilder::new() .with_cache_factory(Some(Arc::new(CustomCacheFactory {}))) .with_query_planner(Arc::new(CacheNodeQueryPlanner::default())) .build(); let ctx = SessionContext::new_with_state(session_state); + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; + // Read the parquet files and show its schema using 'describe' let parquet_df = ctx - .read_parquet(filename, ParquetReadOptions::default()) + .read_parquet(parquet_temp.path_str()?, ParquetReadOptions::default()) .await?; let df_cached = parquet_df - .select_columns(&["id", "bool_col", "timestamp_col"])? - .filter(col("id").gt(lit(1)))? + .select_columns(&["car", "speed", "time"])? + .filter(col("speed").gt(lit(1.0)))? 
.cache() .await?; - let df1 = df_cached.clone().filter(col("bool_col").is_true())?; - let df2 = df1.clone().sort(vec![col("id").sort(true, false)])?; + let df1 = df_cached.clone().filter(col("car").eq(lit("red")))?; + let df2 = df1.clone().sort(vec![col("car").sort(true, false)])?; // should see log for caching only once df_cached.show().await?; diff --git a/datafusion-examples/examples/dataframe/dataframe.rs b/datafusion-examples/examples/dataframe/dataframe.rs index 94653e80c869..dde19cb476f1 100644 --- a/datafusion-examples/examples/dataframe/dataframe.rs +++ b/datafusion-examples/examples/dataframe/dataframe.rs @@ -17,6 +17,10 @@ //! See `main.rs` for how to run it. +use std::fs::File; +use std::io::Write; +use std::sync::Arc; + use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray, StringViewArray}; use datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::catalog::MemTable; @@ -28,10 +32,9 @@ use datafusion::error::Result; use datafusion::functions_aggregate::average::avg; use datafusion::functions_aggregate::min_max::max; use datafusion::prelude::*; -use std::fs::{File, create_dir_all}; -use std::io::Write; -use std::sync::Arc; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; use tempfile::{TempDir, tempdir}; +use tokio::fs::create_dir_all; /// This example demonstrates using DataFusion's DataFrame API /// @@ -64,8 +67,8 @@ pub async fn dataframe_example() -> Result<()> { read_memory(&ctx).await?; read_memory_macro().await?; write_out(&ctx).await?; - register_aggregate_test_data("t1", &ctx).await?; - register_aggregate_test_data("t2", &ctx).await?; + register_cars_test_data("t1", &ctx).await?; + register_cars_test_data("t2", &ctx).await?; where_scalar_subquery(&ctx).await?; where_in_subquery(&ctx).await?; where_exist_subquery(&ctx).await?; @@ -77,23 +80,24 @@ pub async fn dataframe_example() -> Result<()> { /// 2. Show the schema /// 3. 
Select columns and rows async fn read_parquet(ctx: &SessionContext) -> Result<()> { - // Find the local path of "alltypes_plain.parquet" - let testdata = datafusion::test_util::parquet_test_data(); - let filename = &format!("{testdata}/alltypes_plain.parquet"); + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(ctx, &dataset.path()).await?; // Read the parquet files and show its schema using 'describe' let parquet_df = ctx - .read_parquet(filename, ParquetReadOptions::default()) + .read_parquet(parquet_temp.path_str()?, ParquetReadOptions::default()) .await?; // show its schema using 'describe' parquet_df.clone().describe().await?.show().await?; // Select three columns and filter the results - // so that only rows where id > 1 are returned + // so that only rows where speed > 1 are returned + // select car, speed, time from t where speed > 1 parquet_df - .select_columns(&["id", "bool_col", "timestamp_col"])? - .filter(col("id").gt(lit(1)))? + .select_columns(&["car", "speed", "time"])? + .filter(col("speed").gt(lit(1)))? 
.show() .await?; @@ -211,15 +215,15 @@ async fn write_out(ctx: &SessionContext) -> Result<()> { // Create a single temp root with subdirectories let tmp_root = TempDir::new()?; let examples_root = tmp_root.path().join("datafusion-examples"); - create_dir_all(&examples_root)?; + create_dir_all(&examples_root).await?; let table_dir = examples_root.join("test_table"); let parquet_dir = examples_root.join("test_parquet"); let csv_dir = examples_root.join("test_csv"); let json_dir = examples_root.join("test_json"); - create_dir_all(&table_dir)?; - create_dir_all(&parquet_dir)?; - create_dir_all(&csv_dir)?; - create_dir_all(&json_dir)?; + create_dir_all(&table_dir).await?; + create_dir_all(&parquet_dir).await?; + create_dir_all(&csv_dir).await?; + create_dir_all(&json_dir).await?; let create_sql = format!( "CREATE EXTERNAL TABLE test(tablecol1 varchar) @@ -266,7 +270,7 @@ async fn write_out(ctx: &SessionContext) -> Result<()> { } /// Use the DataFrame API to execute the following subquery: -/// select c1,c2 from t1 where (select avg(t2.c2) from t2 where t1.c1 = t2.c1)>0 limit 3; +/// select car, speed from t1 where (select avg(t2.speed) from t2 where t1.car = t2.car) > 0 limit 3; async fn where_scalar_subquery(ctx: &SessionContext) -> Result<()> { ctx.table("t1") .await? @@ -274,14 +278,14 @@ async fn where_scalar_subquery(ctx: &SessionContext) -> Result<()> { scalar_subquery(Arc::new( ctx.table("t2") .await? - .filter(out_ref_col(DataType::Utf8, "t1.c1").eq(col("t2.c1")))? - .aggregate(vec![], vec![avg(col("t2.c2"))])? - .select(vec![avg(col("t2.c2"))])? + .filter(out_ref_col(DataType::Utf8, "t1.car").eq(col("t2.car")))? + .aggregate(vec![], vec![avg(col("t2.speed"))])? + .select(vec![avg(col("t2.speed"))])? .into_unoptimized_plan(), )) - .gt(lit(0u8)), + .gt(lit(0.0)), )? - .select(vec![col("t1.c1"), col("t1.c2")])? + .select(vec![col("t1.car"), col("t1.speed")])? .limit(0, Some(3))? 
.show() .await?; @@ -289,22 +293,24 @@ async fn where_scalar_subquery(ctx: &SessionContext) -> Result<()> { } /// Use the DataFrame API to execute the following subquery: -/// select t1.c1, t1.c2 from t1 where t1.c2 in (select max(t2.c2) from t2 where t2.c1 > 0 ) limit 3; +/// select t1.car, t1.speed from t1 where t1.speed in (select max(t2.speed) from t2 where t2.car = 'red') limit 3; async fn where_in_subquery(ctx: &SessionContext) -> Result<()> { ctx.table("t1") .await? .filter(in_subquery( - col("t1.c2"), + col("t1.speed"), Arc::new( ctx.table("t2") .await? - .filter(col("t2.c1").gt(lit(ScalarValue::UInt8(Some(0)))))? - .aggregate(vec![], vec![max(col("t2.c2"))])? - .select(vec![max(col("t2.c2"))])? + .filter( + col("t2.car").eq(lit(ScalarValue::Utf8(Some("red".to_string())))), + )? + .aggregate(vec![], vec![max(col("t2.speed"))])? + .select(vec![max(col("t2.speed"))])? .into_unoptimized_plan(), ), ))? - .select(vec![col("t1.c1"), col("t1.c2")])? + .select(vec![col("t1.car"), col("t1.speed")])? .limit(0, Some(3))? .show() .await?; @@ -312,31 +318,27 @@ async fn where_in_subquery(ctx: &SessionContext) -> Result<()> { } /// Use the DataFrame API to execute the following subquery: -/// select t1.c1, t1.c2 from t1 where exists (select t2.c2 from t2 where t1.c1 = t2.c1) limit 3; +/// select t1.car, t1.speed from t1 where exists (select t2.speed from t2 where t1.car = t2.car) limit 3; async fn where_exist_subquery(ctx: &SessionContext) -> Result<()> { ctx.table("t1") .await? .filter(exists(Arc::new( ctx.table("t2") .await? - .filter(out_ref_col(DataType::Utf8, "t1.c1").eq(col("t2.c1")))? - .select(vec![col("t2.c2")])? + .filter(out_ref_col(DataType::Utf8, "t1.car").eq(col("t2.car")))? + .select(vec![col("t2.speed")])? .into_unoptimized_plan(), )))? - .select(vec![col("t1.c1"), col("t1.c2")])? + .select(vec![col("t1.car"), col("t1.speed")])? .limit(0, Some(3))? 
.show() .await?; Ok(()) } -async fn register_aggregate_test_data(name: &str, ctx: &SessionContext) -> Result<()> { - let testdata = datafusion::test_util::arrow_test_data(); - ctx.register_csv( - name, - &format!("{testdata}/csv/aggregate_test_100.csv"), - CsvReadOptions::default(), - ) - .await?; +async fn register_cars_test_data(name: &str, ctx: &SessionContext) -> Result<()> { + let dataset = ExampleDataset::Cars; + ctx.register_csv(name, dataset.path_str()?, CsvReadOptions::default()) + .await?; Ok(()) } diff --git a/datafusion-examples/examples/dataframe/deserialize_to_struct.rs b/datafusion-examples/examples/dataframe/deserialize_to_struct.rs index e19d45554131..b031225dc9b6 100644 --- a/datafusion-examples/examples/dataframe/deserialize_to_struct.rs +++ b/datafusion-examples/examples/dataframe/deserialize_to_struct.rs @@ -17,11 +17,11 @@ //! See `main.rs` for how to run it. -use arrow::array::{AsArray, PrimitiveArray}; -use arrow::datatypes::{Float64Type, Int32Type}; +use arrow::array::{Array, Float64Array, StringViewArray}; use datafusion::common::assert_batches_eq; use datafusion::error::Result; use datafusion::prelude::*; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; use futures::StreamExt; /// This example shows how to convert query results into Rust structs by using @@ -34,63 +34,103 @@ use futures::StreamExt; pub async fn deserialize_to_struct() -> Result<()> { // Run a query that returns two columns of data let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); + + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; + ctx.register_parquet( - "alltypes_plain", - &format!("{testdata}/alltypes_plain.parquet"), + "cars", + parquet_temp.path_str()?, ParquetReadOptions::default(), ) .await?; + let df = ctx - .sql("SELECT int_col, double_col FROM 
alltypes_plain") + .sql("SELECT car, speed FROM cars ORDER BY speed LIMIT 50") .await?; - // print out the results showing we have an int32 and a float64 column + // print out the results showing we have car and speed columns and a deterministic ordering let results = df.clone().collect().await?; assert_batches_eq!( [ - "+---------+------------+", - "| int_col | double_col |", - "+---------+------------+", - "| 0 | 0.0 |", - "| 1 | 10.1 |", - "| 0 | 0.0 |", - "| 1 | 10.1 |", - "| 0 | 0.0 |", - "| 1 | 10.1 |", - "| 0 | 0.0 |", - "| 1 | 10.1 |", - "+---------+------------+", + "+-------+-------+", + "| car | speed |", + "+-------+-------+", + "| red | 0.0 |", + "| red | 1.0 |", + "| green | 2.0 |", + "| red | 3.0 |", + "| red | 7.0 |", + "| red | 7.1 |", + "| red | 7.2 |", + "| green | 8.0 |", + "| green | 10.0 |", + "| green | 10.3 |", + "| green | 10.4 |", + "| green | 10.5 |", + "| green | 11.0 |", + "| green | 12.0 |", + "| green | 14.0 |", + "| green | 15.0 |", + "| green | 15.1 |", + "| green | 15.2 |", + "| red | 17.0 |", + "| red | 18.0 |", + "| red | 19.0 |", + "| red | 20.0 |", + "| red | 20.3 |", + "| red | 21.4 |", + "| red | 21.5 |", + "+-------+-------+", ], &results ); // We will now convert the query results into a Rust struct let mut stream = df.execute_stream().await?; - let mut list = vec![]; + let mut list: Vec = vec![]; // DataFusion produces data in chunks called `RecordBatch`es which are // typically 8000 rows each. This loop processes each `RecordBatch` as it is // produced by the query plan and adds it to the list - while let Some(b) = stream.next().await.transpose()? { + while let Some(batch) = stream.next().await.transpose()? { // Each `RecordBatch` has one or more columns. Each column is stored as // an `ArrayRef`. To interact with data using Rust native types we need to // convert these `ArrayRef`s into concrete array types using APIs from // the arrow crate. 
// In this case, we know that each batch has two columns of the Arrow - // types Int32 and Float64, so first we cast the two columns to the + // types StringView and Float64, so first we cast the two columns to the // appropriate Arrow PrimitiveArray (this is a fast / zero-copy cast).: - let int_col: &PrimitiveArray = b.column(0).as_primitive(); - let float_col: &PrimitiveArray = b.column(1).as_primitive(); + let car_col = batch + .column(0) + .as_any() + .downcast_ref::() + .expect("car column must be Utf8View"); + + let speed_col = batch + .column(1) + .as_any() + .downcast_ref::() + .expect("speed column must be Float64"); // With PrimitiveArrays, we can access to the values as native Rust - // types i32 and f64, and forming the desired `Data` structs - for (i, f) in int_col.values().iter().zip(float_col.values()) { - list.push(Data { - int_col: *i, - double_col: *f, - }) + // types String and f64, and forming the desired `Data` structs + for i in 0..batch.num_rows() { + let car = if car_col.is_null(i) { + None + } else { + Some(car_col.value(i).to_string()) + }; + + let speed = if speed_col.is_null(i) { + None + } else { + Some(speed_col.value(i)) + }; + + list.push(Data { car, speed }); } } @@ -100,45 +140,220 @@ pub async fn deserialize_to_struct() -> Result<()> { res, r#"[ Data { - int_col: 0, - double_col: 0.0, + car: Some( + "red", + ), + speed: Some( + 0.0, + ), + }, + Data { + car: Some( + "red", + ), + speed: Some( + 1.0, + ), + }, + Data { + car: Some( + "green", + ), + speed: Some( + 2.0, + ), + }, + Data { + car: Some( + "red", + ), + speed: Some( + 3.0, + ), + }, + Data { + car: Some( + "red", + ), + speed: Some( + 7.0, + ), + }, + Data { + car: Some( + "red", + ), + speed: Some( + 7.1, + ), + }, + Data { + car: Some( + "red", + ), + speed: Some( + 7.2, + ), + }, + Data { + car: Some( + "green", + ), + speed: Some( + 8.0, + ), + }, + Data { + car: Some( + "green", + ), + speed: Some( + 10.0, + ), + }, + Data { + car: Some( + "green", + ), + speed: 
Some( + 10.3, + ), + }, + Data { + car: Some( + "green", + ), + speed: Some( + 10.4, + ), + }, + Data { + car: Some( + "green", + ), + speed: Some( + 10.5, + ), + }, + Data { + car: Some( + "green", + ), + speed: Some( + 11.0, + ), + }, + Data { + car: Some( + "green", + ), + speed: Some( + 12.0, + ), + }, + Data { + car: Some( + "green", + ), + speed: Some( + 14.0, + ), + }, + Data { + car: Some( + "green", + ), + speed: Some( + 15.0, + ), + }, + Data { + car: Some( + "green", + ), + speed: Some( + 15.1, + ), + }, + Data { + car: Some( + "green", + ), + speed: Some( + 15.2, + ), }, Data { - int_col: 1, - double_col: 10.1, + car: Some( + "red", + ), + speed: Some( + 17.0, + ), }, Data { - int_col: 0, - double_col: 0.0, + car: Some( + "red", + ), + speed: Some( + 18.0, + ), }, Data { - int_col: 1, - double_col: 10.1, + car: Some( + "red", + ), + speed: Some( + 19.0, + ), }, Data { - int_col: 0, - double_col: 0.0, + car: Some( + "red", + ), + speed: Some( + 20.0, + ), }, Data { - int_col: 1, - double_col: 10.1, + car: Some( + "red", + ), + speed: Some( + 20.3, + ), }, Data { - int_col: 0, - double_col: 0.0, + car: Some( + "red", + ), + speed: Some( + 21.4, + ), }, Data { - int_col: 1, - double_col: 10.1, + car: Some( + "red", + ), + speed: Some( + 21.5, + ), }, ]"# ); - // Use the fields in the struct to avoid clippy complaints - let int_sum = list.iter().fold(0, |acc, x| acc + x.int_col); - let double_sum = list.iter().fold(0.0, |acc, x| acc + x.double_col); - assert_eq!(int_sum, 4); - assert_eq!(double_sum, 40.4); + let speed_green_sum: f64 = list + .iter() + .filter(|data| data.car.as_deref() == Some("green")) + .filter_map(|data| data.speed) + .sum(); + let speed_red_sum: f64 = list + .iter() + .filter(|data| data.car.as_deref() == Some("red")) + .filter_map(|data| data.speed) + .sum(); + assert_eq!(speed_green_sum, 133.5); + assert_eq!(speed_red_sum, 162.5); Ok(()) } @@ -146,6 +361,6 @@ pub async fn deserialize_to_struct() -> Result<()> { /// This is target 
struct where we want the query results. #[derive(Debug)] struct Data { - int_col: i32, - double_col: f64, + car: Option, + speed: Option, } diff --git a/datafusion-examples/examples/dataframe/main.rs b/datafusion-examples/examples/dataframe/main.rs index 7f2b2d02aeff..25b5377d3823 100644 --- a/datafusion-examples/examples/dataframe/main.rs +++ b/datafusion-examples/examples/dataframe/main.rs @@ -21,13 +21,20 @@ //! //! ## Usage //! ```bash -//! cargo run --example dataframe -- [all|dataframe|deserialize_to_struct] +//! cargo run --example dataframe -- [all|dataframe|deserialize_to_struct|cache_factory] //! ``` //! //! Each subcommand runs a corresponding example: //! - `all` — run all examples included in this module -//! - `dataframe` — run a query using a DataFrame API against parquet files, csv files, and in-memory data, including multiple subqueries -//! - `deserialize_to_struct` — convert query results (Arrow ArrayRefs) into Rust structs +//! +//! - `cache_factory` +//! (file: cache_factory.rs, desc: Custom lazy caching for DataFrames using `CacheFactory`) +// +//! - `dataframe` +//! (file: dataframe.rs, desc: Query DataFrames from various sources and write output) +//! +//! - `deserialize_to_struct` +//! (file: deserialize_to_struct.rs, desc: Convert Arrow arrays into Rust structs) mod cache_factory; mod dataframe; diff --git a/datafusion-examples/examples/execution_monitoring/main.rs b/datafusion-examples/examples/execution_monitoring/main.rs index 07de57f6b80e..8f80c36929ca 100644 --- a/datafusion-examples/examples/execution_monitoring/main.rs +++ b/datafusion-examples/examples/execution_monitoring/main.rs @@ -26,9 +26,15 @@ //! //! Each subcommand runs a corresponding example: //! - `all` — run all examples included in this module -//! - `mem_pool_exec_plan` — shows how to implement memory-aware ExecutionPlan with memory reservation and spilling -//! 
- `mem_pool_tracking` — demonstrates TrackConsumersPool for memory tracking and debugging with enhanced error messages -//! - `tracing` — demonstrates the tracing injection feature for the DataFusion runtime +//! +//! - `mem_pool_exec_plan` +//! (file: memory_pool_execution_plan.rs, desc: Memory-aware ExecutionPlan with spilling) +//! +//! - `mem_pool_tracking` +//! (file: memory_pool_tracking.rs, desc: Demonstrates memory tracking) +//! +//! - `tracing` +//! (file: tracing.rs, desc: Demonstrates tracing integration) mod memory_pool_execution_plan; mod memory_pool_tracking; diff --git a/datafusion-examples/examples/execution_monitoring/memory_pool_execution_plan.rs b/datafusion-examples/examples/execution_monitoring/memory_pool_execution_plan.rs index 48475acbb154..4c05cd2fb1fb 100644 --- a/datafusion-examples/examples/execution_monitoring/memory_pool_execution_plan.rs +++ b/datafusion-examples/examples/execution_monitoring/memory_pool_execution_plan.rs @@ -38,7 +38,7 @@ use datafusion::execution::{SendableRecordBatchStream, TaskContext}; use datafusion::logical_expr::LogicalPlanBuilder; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, Statistics, + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, }; use datafusion::prelude::*; use futures::stream::{StreamExt, TryStreamExt}; @@ -199,7 +199,7 @@ impl ExternalBatchBufferer { struct BufferingExecutionPlan { schema: SchemaRef, input: Arc, - properties: PlanProperties, + properties: Arc, } impl BufferingExecutionPlan { @@ -233,7 +233,7 @@ impl ExecutionPlan for BufferingExecutionPlan { self.schema.clone() } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.properties } @@ -296,8 +296,4 @@ impl ExecutionPlan for BufferingExecutionPlan { }), ))) } - - fn statistics(&self) -> Result { - Ok(Statistics::new_unknown(&self.schema)) - } } diff --git 
a/datafusion-examples/examples/execution_monitoring/memory_pool_tracking.rs b/datafusion-examples/examples/execution_monitoring/memory_pool_tracking.rs index 8d6e5dd7e444..af3031c690fa 100644 --- a/datafusion-examples/examples/execution_monitoring/memory_pool_tracking.rs +++ b/datafusion-examples/examples/execution_monitoring/memory_pool_tracking.rs @@ -110,7 +110,8 @@ async fn automatic_usage_example() -> Result<()> { println!("✓ Expected memory limit error during data processing:"); println!("Error: {e}"); /* Example error message: - Error: Not enough memory to continue external sort. Consider increasing the memory limit, or decreasing sort_spill_reservation_bytes + Error: Not enough memory to continue external sort. Consider increasing the memory limit config: 'datafusion.runtime.memory_limit', + or decreasing the config: 'datafusion.execution.sort_spill_reservation_bytes'. caused by Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: ExternalSorterMerge[3]#112(can spill: false) consumed 10.0 MB, peak 10.0 MB, diff --git a/datafusion-examples/examples/execution_monitoring/tracing.rs b/datafusion-examples/examples/execution_monitoring/tracing.rs index 5fa759f2d541..172c1ca83b3b 100644 --- a/datafusion-examples/examples/execution_monitoring/tracing.rs +++ b/datafusion-examples/examples/execution_monitoring/tracing.rs @@ -51,16 +51,17 @@ //! 10:29:40.809 INFO main ThreadId(01) tracing: ***** WITH tracer: Non-main tasks DID inherit the `run_instrumented_query` span ***** //! 
``` +use std::any::Any; +use std::sync::Arc; + use datafusion::common::runtime::{JoinSetTracer, set_join_set_tracer}; use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::ListingOptions; use datafusion::error::Result; use datafusion::prelude::*; -use datafusion::test_util::parquet_test_data; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; use futures::FutureExt; use futures::future::BoxFuture; -use std::any::Any; -use std::sync::Arc; use tracing::{Instrument, Level, Span, info, instrument}; /// Demonstrates the tracing injection feature for the DataFusion runtime @@ -126,18 +127,27 @@ async fn run_instrumented_query() -> Result<()> { info!("Starting query execution"); let ctx = SessionContext::new(); - let test_data = parquet_test_data(); + + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; + let file_format = ParquetFormat::default().with_enable_pruning(true); - let listing_options = ListingOptions::new(Arc::new(file_format)) - .with_file_extension("alltypes_tiny_pages_plain.parquet"); + let listing_options = + ListingOptions::new(Arc::new(file_format)).with_file_extension(".parquet"); - let table_path = format!("file://{test_data}/"); - info!("Registering table 'alltypes' from {}", table_path); - ctx.register_listing_table("alltypes", &table_path, listing_options, None, None) - .await - .expect("Failed to register table"); + info!("Registering table 'cars' from {}", parquet_temp.path_str()?); + ctx.register_listing_table( + "cars", + parquet_temp.path_str()?, + listing_options, + None, + None, + ) + .await + .expect("Failed to register table"); - let sql = "SELECT COUNT(*), string_col FROM alltypes GROUP BY string_col"; + let sql = "SELECT COUNT(*), car, sum(speed) FROM cars GROUP BY car"; info!(sql, "Executing SQL query"); let result = 
ctx.sql(sql).await?.collect().await?; info!("Query complete: {} batches returned", result.len()); diff --git a/datafusion-examples/examples/external_dependency/main.rs b/datafusion-examples/examples/external_dependency/main.rs index 0a9a2cd2372d..447e7d38bdd5 100644 --- a/datafusion-examples/examples/external_dependency/main.rs +++ b/datafusion-examples/examples/external_dependency/main.rs @@ -26,8 +26,12 @@ //! //! Each subcommand runs a corresponding example: //! - `all` — run all examples included in this module -//! - `dataframe_to_s3` — run a query using a DataFrame against a parquet file from AWS S3 and writing back to AWS S3 -//! - `query_aws_s3` — configure `object_store` and run a query against files stored in AWS S3 +//! +//! - `dataframe_to_s3` +//! (file: dataframe_to_s3.rs, desc: Query DataFrames and write results to S3) +//! +//! - `query_aws_s3` +//! (file: query_aws_s3.rs, desc: Query S3-backed data using object_store) mod dataframe_to_s3; mod query_aws_s3; diff --git a/datafusion-examples/examples/ffi/ffi_example_table_provider/Cargo.toml b/datafusion-examples/examples/ffi/ffi_example_table_provider/Cargo.toml index e9c0c5b43d68..e2d0e3fa6744 100644 --- a/datafusion-examples/examples/ffi/ffi_example_table_provider/Cargo.toml +++ b/datafusion-examples/examples/ffi/ffi_example_table_provider/Cargo.toml @@ -28,6 +28,9 @@ datafusion = { workspace = true } datafusion-ffi = { workspace = true } ffi_module_interface = { path = "../ffi_module_interface" } +[lints] +workspace = true + [lib] name = "ffi_example_table_provider" crate-type = ["cdylib", 'rlib'] diff --git a/datafusion-examples/examples/ffi/ffi_module_interface/Cargo.toml b/datafusion-examples/examples/ffi/ffi_module_interface/Cargo.toml index f393b2971e45..fe4902711241 100644 --- a/datafusion-examples/examples/ffi/ffi_module_interface/Cargo.toml +++ b/datafusion-examples/examples/ffi/ffi_module_interface/Cargo.toml @@ -21,6 +21,9 @@ version = "0.1.0" edition = "2024" publish = false +[lints] 
+workspace = true + [dependencies] abi_stable = "0.11.3" datafusion-ffi = { workspace = true } diff --git a/datafusion-examples/examples/ffi/ffi_module_loader/Cargo.toml b/datafusion-examples/examples/ffi/ffi_module_loader/Cargo.toml index 823c9afddee2..8d7434dca211 100644 --- a/datafusion-examples/examples/ffi/ffi_module_loader/Cargo.toml +++ b/datafusion-examples/examples/ffi/ffi_module_loader/Cargo.toml @@ -21,6 +21,9 @@ version = "0.1.0" edition = "2024" publish = false +[lints] +workspace = true + [dependencies] abi_stable = "0.11.3" datafusion = { workspace = true } diff --git a/datafusion-examples/examples/flight/client.rs b/datafusion-examples/examples/flight/client.rs index 484576975a6f..8f6856a4e484 100644 --- a/datafusion-examples/examples/flight/client.rs +++ b/datafusion-examples/examples/flight/client.rs @@ -19,21 +19,26 @@ use std::collections::HashMap; use std::sync::Arc; -use tonic::transport::Endpoint; - -use datafusion::arrow::datatypes::Schema; use arrow_flight::flight_descriptor; use arrow_flight::flight_service_client::FlightServiceClient; use arrow_flight::utils::flight_data_to_arrow_batch; use arrow_flight::{FlightDescriptor, Ticket}; +use datafusion::arrow::datatypes::Schema; use datafusion::arrow::util::pretty; +use datafusion::prelude::SessionContext; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; +use tonic::transport::Endpoint; /// This example shows how to wrap DataFusion with `FlightService` to support looking up schema information for /// Parquet files and executing SQL queries against them on a remote server. /// This example is run along-side the example `flight_server`. 
pub async fn client() -> Result<(), Box> { - let testdata = datafusion::test_util::parquet_test_data(); + let ctx = SessionContext::new(); + + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; // Create Flight client let endpoint = Endpoint::new("http://localhost:50051")?; @@ -44,7 +49,7 @@ pub async fn client() -> Result<(), Box> { let request = tonic::Request::new(FlightDescriptor { r#type: flight_descriptor::DescriptorType::Path as i32, cmd: Default::default(), - path: vec![format!("{testdata}/alltypes_plain.parquet")], + path: vec![format!("{}", parquet_temp.path_str()?)], }); let schema_result = client.get_schema(request).await?.into_inner(); @@ -53,7 +58,7 @@ pub async fn client() -> Result<(), Box> { // Call do_get to execute a SQL query and receive results let request = tonic::Request::new(Ticket { - ticket: "SELECT id FROM alltypes_plain".into(), + ticket: "SELECT car FROM cars".into(), }); let mut stream = client.do_get(request).await?.into_inner(); diff --git a/datafusion-examples/examples/flight/main.rs b/datafusion-examples/examples/flight/main.rs index 6f20f576d3a7..426e806486f7 100644 --- a/datafusion-examples/examples/flight/main.rs +++ b/datafusion-examples/examples/flight/main.rs @@ -29,9 +29,15 @@ //! Note: The Flight server must be started in a separate process //! before running the `client` example. Therefore, running `all` will //! not produce a full server+client workflow automatically. -//! - `client` — run DataFusion as a standalone process and execute SQL queries from a client using the Flight protocol -//! - `server` — run DataFusion as a standalone process and execute SQL queries from a client using the Flight protocol -//! - `sql_server` — run DataFusion as a standalone process and execute SQL queries from JDBC clients +//! +//! - `client` +//! 
(file: client.rs, desc: Execute SQL queries via Arrow Flight protocol) +//! +//! - `server` +//! (file: server.rs, desc: Run DataFusion server accepting FlightSQL/JDBC queries) +//! +//! - `sql_server` +//! (file: sql_server.rs, desc: Standalone SQL server for JDBC clients) mod client; mod server; diff --git a/datafusion-examples/examples/flight/server.rs b/datafusion-examples/examples/flight/server.rs index aad82e28b15e..b73c81dd7d2c 100644 --- a/datafusion-examples/examples/flight/server.rs +++ b/datafusion-examples/examples/flight/server.rs @@ -17,25 +17,24 @@ //! See `main.rs` for how to run it. -use arrow::ipc::writer::{CompressionContext, DictionaryTracker, IpcDataGenerator}; use std::sync::Arc; +use arrow::ipc::writer::{CompressionContext, DictionaryTracker, IpcDataGenerator}; +use arrow_flight::{ + Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, + HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, Ticket, + flight_service_server::FlightService, flight_service_server::FlightServiceServer, +}; use arrow_flight::{PollInfo, SchemaAsIpc}; use datafusion::arrow::error::ArrowError; use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::{ListingOptions, ListingTableUrl}; +use datafusion::prelude::*; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; use futures::stream::BoxStream; use tonic::transport::Server; use tonic::{Request, Response, Status, Streaming}; -use datafusion::prelude::*; - -use arrow_flight::{ - Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, - HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, Ticket, - flight_service_server::FlightService, flight_service_server::FlightServiceServer, -}; - #[derive(Clone)] pub struct FlightServiceImpl {} @@ -85,16 +84,21 @@ impl FlightService for FlightServiceImpl { // create local execution context let ctx = SessionContext::new(); - let testdata = 
datafusion::test_util::parquet_test_data(); + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()) + .await + .map_err(|e| { + Status::internal(format!("Error writing csv to parquet: {e}")) + })?; + let parquet_path = parquet_temp.path_str().map_err(|e| { + Status::internal(format!("Error getting parquet path: {e}")) + })?; // register parquet file with the execution context - ctx.register_parquet( - "alltypes_plain", - &format!("{testdata}/alltypes_plain.parquet"), - ParquetReadOptions::default(), - ) - .await - .map_err(to_tonic_err)?; + ctx.register_parquet("cars", parquet_path, ParquetReadOptions::default()) + .await + .map_err(to_tonic_err)?; // create the DataFrame let df = ctx.sql(sql).await.map_err(to_tonic_err)?; diff --git a/datafusion-examples/examples/flight/sql_server.rs b/datafusion-examples/examples/flight/sql_server.rs index 435e05ffc0ce..e55aaa7250ea 100644 --- a/datafusion-examples/examples/flight/sql_server.rs +++ b/datafusion-examples/examples/flight/sql_server.rs @@ -17,6 +17,9 @@ //! See `main.rs` for how to run it. 
+use std::pin::Pin; +use std::sync::Arc; + use arrow::array::{ArrayRef, StringArray}; use arrow::datatypes::{DataType, Field, Schema}; use arrow::ipc::writer::IpcWriteOptions; @@ -38,12 +41,11 @@ use arrow_flight::{ use dashmap::DashMap; use datafusion::logical_expr::LogicalPlan; use datafusion::prelude::{DataFrame, ParquetReadOptions, SessionConfig, SessionContext}; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; use futures::{Stream, StreamExt, TryStreamExt}; use log::info; use mimalloc::MiMalloc; use prost::Message; -use std::pin::Pin; -use std::sync::Arc; use tonic::metadata::MetadataValue; use tonic::transport::Server; use tonic::{Request, Response, Status, Streaming}; @@ -100,22 +102,24 @@ impl FlightSqlServiceImpl { .with_information_schema(true); let ctx = Arc::new(SessionContext::new_with_config(session_config)); - let testdata = datafusion::test_util::parquet_test_data(); + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()) + .await + .map_err(|e| status!("Error writing csv to parquet", e))?; + let parquet_path = parquet_temp + .path_str() + .map_err(|e| status!("Error getting parquet path", e))?; // register parquet file with the execution context - ctx.register_parquet( - "alltypes_plain", - &format!("{testdata}/alltypes_plain.parquet"), - ParquetReadOptions::default(), - ) - .await - .map_err(|e| status!("Error registering table", e))?; + ctx.register_parquet("cars", parquet_path, ParquetReadOptions::default()) + .await + .map_err(|e| status!("Error registering table", e))?; self.contexts.insert(uuid.clone(), ctx); Ok(uuid) } - #[allow(clippy::result_large_err)] fn get_ctx(&self, req: &Request) -> Result, Status> { // get the token from the authorization header on Request let auth = req @@ -141,7 +145,6 @@ impl FlightSqlServiceImpl { } } - #[allow(clippy::result_large_err)] fn get_plan(&self, 
handle: &str) -> Result { if let Some(plan) = self.statements.get(handle) { Ok(plan.clone()) @@ -150,7 +153,6 @@ impl FlightSqlServiceImpl { } } - #[allow(clippy::result_large_err)] fn get_result(&self, handle: &str) -> Result, Status> { if let Some(result) = self.results.get(handle) { Ok(result.clone()) @@ -198,13 +200,11 @@ impl FlightSqlServiceImpl { .unwrap() } - #[allow(clippy::result_large_err)] fn remove_plan(&self, handle: &str) -> Result<(), Status> { self.statements.remove(&handle.to_string()); Ok(()) } - #[allow(clippy::result_large_err)] fn remove_result(&self, handle: &str) -> Result<(), Status> { self.results.remove(&handle.to_string()); Ok(()) diff --git a/datafusion-examples/examples/proto/composed_extension_codec.rs b/datafusion-examples/examples/proto/composed_extension_codec.rs index f3910d461b6a..b4f3d4f09899 100644 --- a/datafusion-examples/examples/proto/composed_extension_codec.rs +++ b/datafusion-examples/examples/proto/composed_extension_codec.rs @@ -106,7 +106,7 @@ impl ExecutionPlan for ParentExec { self } - fn properties(&self) -> &datafusion::physical_plan::PlanProperties { + fn properties(&self) -> &Arc { unreachable!() } @@ -182,7 +182,7 @@ impl ExecutionPlan for ChildExec { self } - fn properties(&self) -> &datafusion::physical_plan::PlanProperties { + fn properties(&self) -> &Arc { unreachable!() } diff --git a/datafusion-examples/examples/proto/expression_deduplication.rs b/datafusion-examples/examples/proto/expression_deduplication.rs new file mode 100644 index 000000000000..0dec807f8043 --- /dev/null +++ b/datafusion-examples/examples/proto/expression_deduplication.rs @@ -0,0 +1,275 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! See `main.rs` for how to run it. +//! +//! This example demonstrates how to use the `PhysicalExtensionCodec` trait's +//! interception methods to implement expression deduplication during deserialization. +//! +//! This pattern is inspired by PR #18192, which introduces expression caching +//! to reduce memory usage when deserializing plans with duplicate expressions. +//! +//! The key insight is that identical expressions serialize to identical protobuf bytes. +//! By caching deserialized expressions keyed by their protobuf bytes, we can: +//! 1. Return the same Arc for duplicate expressions +//! 2. Reduce memory allocation during deserialization +//! 3. Enable downstream optimizations that rely on Arc pointer equality +//! +//! This demonstrates the decorator pattern enabled by the `PhysicalExtensionCodec` trait, +//! where all expression serialization/deserialization routes through the codec methods. 
+ +use std::collections::HashMap; +use std::fmt::Debug; +use std::sync::{Arc, RwLock}; + +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion::common::Result; +use datafusion::execution::TaskContext; +use datafusion::logical_expr::Operator; +use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::physical_plan::expressions::{BinaryExpr, col}; +use datafusion::physical_plan::filter::FilterExec; +use datafusion::physical_plan::placeholder_row::PlaceholderRowExec; +use datafusion::prelude::SessionContext; +use datafusion_proto::physical_plan::from_proto::parse_physical_expr_with_converter; +use datafusion_proto::physical_plan::to_proto::serialize_physical_expr_with_converter; +use datafusion_proto::physical_plan::{ + DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, + PhysicalProtoConverterExtension, +}; +use datafusion_proto::protobuf::{PhysicalExprNode, PhysicalPlanNode}; +use prost::Message; + +/// Example showing how to implement expression deduplication using the codec decorator pattern. +/// +/// This demonstrates: +/// 1. Creating a CachingCodec that caches expressions by their protobuf bytes +/// 2. Intercepting deserialization to return cached Arcs for duplicate expressions +/// 3. Verifying that duplicate expressions share the same Arc after deserialization +/// +/// Deduplication is keyed by the protobuf bytes representing the expression, +/// in reality deduplication could be done based on e.g. the pointer address of the +/// serialized expression in memory, but this is simpler to demonstrate. +/// +/// In this case our expression is trivial and just for demonstration purposes. +/// In real scenarios, expressions can be much more complex, e.g. a large InList +/// expression could be megabytes in size, so deduplication can save significant memory +/// in addition to more correctly representing the original plan structure. 
+pub async fn expression_deduplication() -> Result<()> { + println!("=== Expression Deduplication Example ===\n"); + + // Create a schema for our test expressions + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Boolean, false)])); + + // Step 1: Create expressions with duplicates + println!("Step 1: Creating expressions with duplicates..."); + + // Create expression: col("a") + let a = col("a", &schema)?; + + // Create a clone to show duplicates + let a_clone = Arc::clone(&a); + + // Combine: a OR a_clone + let combined_expr = + Arc::new(BinaryExpr::new(a, Operator::Or, a_clone)) as Arc; + println!(" Created expression: a OR a with duplicates"); + println!(" Note: a appears twice in the expression tree\n"); + // Step 2: Create a filter plan with this expression + println!("Step 2: Creating physical plan with the expression..."); + + let input = Arc::new(PlaceholderRowExec::new(Arc::clone(&schema))); + let filter_plan: Arc = + Arc::new(FilterExec::try_new(combined_expr, input)?); + + println!(" Created FilterExec with duplicate sub-expressions\n"); + + // Step 3: Serialize with the caching codec + println!("Step 3: Serializing plan..."); + + let extension_codec = DefaultPhysicalExtensionCodec {}; + let caching_converter = CachingCodec::new(); + let proto = + caching_converter.execution_plan_to_proto(&filter_plan, &extension_codec)?; + + // Serialize to bytes + let mut bytes = Vec::new(); + proto.encode(&mut bytes).unwrap(); + println!(" Serialized plan to {} bytes\n", bytes.len()); + + // Step 4: Deserialize with the caching codec + println!("Step 4: Deserializing plan with CachingCodec..."); + + let ctx = SessionContext::new(); + let deserialized_plan = proto.try_into_physical_plan_with_converter( + &ctx.task_ctx(), + &extension_codec, + &caching_converter, + )?; + + // Step 5: check that we deduplicated expressions + println!("Step 5: Checking for deduplicated expressions..."); + let Some(filter_exec) = 
deserialized_plan.as_any().downcast_ref::() + else { + panic!("Deserialized plan is not a FilterExec"); + }; + let predicate = Arc::clone(filter_exec.predicate()); + let binary_expr = predicate + .as_any() + .downcast_ref::() + .expect("Predicate is not a BinaryExpr"); + let left = &binary_expr.left(); + let right = &binary_expr.right(); + // Check if left and right point to the same Arc + let deduplicated = Arc::ptr_eq(left, right); + if deduplicated { + println!(" Success: Duplicate expressions were deduplicated!"); + println!( + " Cache Stats: hits={}, misses={}", + caching_converter.stats.read().unwrap().cache_hits, + caching_converter.stats.read().unwrap().cache_misses, + ); + } else { + println!(" Failure: Duplicate expressions were NOT deduplicated."); + } + + Ok(()) +} + +// ============================================================================ +// CachingCodec - Implements expression deduplication +// ============================================================================ + +/// Statistics for cache performance monitoring +#[derive(Debug, Default)] +struct CacheStats { + cache_hits: usize, + cache_misses: usize, +} + +/// A codec that caches deserialized expressions to enable deduplication. +/// +/// When deserializing, if we've already seen the same protobuf bytes, +/// we return the cached Arc instead of creating a new allocation. 
+#[derive(Debug, Default)] +struct CachingCodec { + /// Cache mapping protobuf bytes -> deserialized expression + expr_cache: RwLock, Arc>>, + /// Statistics for demonstration + stats: RwLock, +} + +impl CachingCodec { + fn new() -> Self { + Self::default() + } +} + +impl PhysicalExtensionCodec for CachingCodec { + // Required: decode custom extension nodes + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[Arc], + _ctx: &TaskContext, + ) -> Result> { + datafusion::common::not_impl_err!("No custom extension nodes") + } + + // Required: encode custom execution plans + fn try_encode( + &self, + _node: Arc, + _buf: &mut Vec, + ) -> Result<()> { + datafusion::common::not_impl_err!("No custom extension nodes") + } +} + +impl PhysicalProtoConverterExtension for CachingCodec { + fn proto_to_execution_plan( + &self, + ctx: &TaskContext, + extension_codec: &dyn PhysicalExtensionCodec, + proto: &PhysicalPlanNode, + ) -> Result> { + proto.try_into_physical_plan_with_converter(ctx, extension_codec, self) + } + + fn execution_plan_to_proto( + &self, + plan: &Arc, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(plan), + extension_codec, + self, + ) + } + + // CACHING IMPLEMENTATION: Intercept expression deserialization + fn proto_to_physical_expr( + &self, + proto: &PhysicalExprNode, + ctx: &TaskContext, + input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + // Create cache key from protobuf bytes + let mut key = Vec::new(); + proto.encode(&mut key).map_err(|e| { + datafusion::error::DataFusionError::Internal(format!( + "Failed to encode proto for cache key: {e}" + )) + })?; + + // Check cache first + { + let cache = self.expr_cache.read().unwrap(); + if let Some(cached) = cache.get(&key) { + // Cache hit! 
Update stats and return cached Arc + let mut stats = self.stats.write().unwrap(); + stats.cache_hits += 1; + return Ok(Arc::clone(cached)); + } + } + + // Cache miss - deserialize and store + let expr = + parse_physical_expr_with_converter(proto, ctx, input_schema, codec, self)?; + + // Store in cache + { + let mut cache = self.expr_cache.write().unwrap(); + cache.insert(key, Arc::clone(&expr)); + let mut stats = self.stats.write().unwrap(); + stats.cache_misses += 1; + } + + Ok(expr) + } + + fn physical_expr_to_proto( + &self, + expr: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result { + serialize_physical_expr_with_converter(expr, codec, self) + } +} diff --git a/datafusion-examples/examples/proto/main.rs b/datafusion-examples/examples/proto/main.rs index f56078b31997..3f525b5d46af 100644 --- a/datafusion-examples/examples/proto/main.rs +++ b/datafusion-examples/examples/proto/main.rs @@ -21,14 +21,20 @@ //! //! ## Usage //! ```bash -//! cargo run --example proto -- [all|composed_extension_codec] +//! cargo run --example proto -- [all|composed_extension_codec|expression_deduplication] //! ``` //! //! Each subcommand runs a corresponding example: //! - `all` — run all examples included in this module -//! - `composed_extension_codec` — example of using multiple extension codecs for serialization / deserialization +//! +//! - `composed_extension_codec` +//! (file: composed_extension_codec.rs, desc: Use multiple extension codecs for serialization/deserialization) +//! +//! - `expression_deduplication` +//! 
(file: expression_deduplication.rs, desc: Example of expression caching/deduplication using the codec decorator pattern) mod composed_extension_codec; +mod expression_deduplication; use datafusion::error::{DataFusionError, Result}; use strum::{IntoEnumIterator, VariantNames}; @@ -39,6 +45,7 @@ use strum_macros::{Display, EnumIter, EnumString, VariantNames}; enum ExampleKind { All, ComposedExtensionCodec, + ExpressionDeduplication, } impl ExampleKind { @@ -59,6 +66,9 @@ impl ExampleKind { ExampleKind::ComposedExtensionCodec => { composed_extension_codec::composed_extension_codec().await? } + ExampleKind::ExpressionDeduplication => { + expression_deduplication::expression_deduplication().await? + } } Ok(()) } diff --git a/datafusion-examples/examples/query_planning/expr_api.rs b/datafusion-examples/examples/query_planning/expr_api.rs index 47de669023f7..386273c72817 100644 --- a/datafusion-examples/examples/query_planning/expr_api.rs +++ b/datafusion-examples/examples/query_planning/expr_api.rs @@ -175,8 +175,9 @@ fn simplify_demo() -> Result<()> { // the ExecutionProps carries information needed to simplify // expressions, such as the current time (to evaluate `now()` // correctly) - let props = ExecutionProps::new(); - let context = SimplifyContext::new(&props).with_schema(schema); + let context = SimplifyContext::default() + .with_schema(schema) + .with_current_time(); let simplifier = ExprSimplifier::new(context); // And then call the simplify_expr function: @@ -191,7 +192,9 @@ fn simplify_demo() -> Result<()> { // here are some other examples of what DataFusion is capable of let schema = Schema::new(vec![make_field("i", DataType::Int64)]).to_dfschema_ref()?; - let context = SimplifyContext::new(&props).with_schema(schema.clone()); + let context = SimplifyContext::default() + .with_schema(Arc::clone(&schema)) + .with_current_time(); let simplifier = ExprSimplifier::new(context); // basic arithmetic simplification @@ -551,7 +554,9 @@ fn type_coercion_demo() -> 
Result<()> { assert!(physical_expr.evaluate(&batch).is_ok()); // 2. Type coercion with `ExprSimplifier::coerce`. - let context = SimplifyContext::new(&props).with_schema(Arc::new(df_schema.clone())); + let context = SimplifyContext::default() + .with_schema(Arc::new(df_schema.clone())) + .with_current_time(); let simplifier = ExprSimplifier::new(context); let coerced_expr = simplifier.coerce(expr.clone(), &df_schema)?; let physical_expr = datafusion::physical_expr::create_physical_expr( diff --git a/datafusion-examples/examples/query_planning/main.rs b/datafusion-examples/examples/query_planning/main.rs index ec21c3ea5a76..d3f99aedceb3 100644 --- a/datafusion-examples/examples/query_planning/main.rs +++ b/datafusion-examples/examples/query_planning/main.rs @@ -26,14 +26,30 @@ //! //! Each subcommand runs a corresponding example: //! - `all` — run all examples included in this module -//! - `analyzer_rule` — use a custom AnalyzerRule to change a query's semantics (row level access control) -//! - `expr_api` — create, execute, simplify, analyze and coerce `Expr`s -//! - `optimizer_rule` — use a custom OptimizerRule to replace certain predicates -//! - `parse_sql_expr` — parse SQL text into DataFusion `Expr` -//! - `plan_to_sql` — generate SQL from DataFusion `Expr` and `LogicalPlan` -//! - `planner_api` — APIs to manipulate logical and physical plans -//! - `pruning` — APIs to manipulate logical and physical plans -//! - `thread_pools` — demonstrate TrackConsumersPool for memory tracking and debugging with enhanced error messages and shows how to implement memory-aware ExecutionPlan with memory reservation and spilling +//! +//! - `analyzer_rule` +//! (file: analyzer_rule.rs, desc: Custom AnalyzerRule to change query semantics) +//! +//! - `expr_api` +//! (file: expr_api.rs, desc: Create, execute, analyze, and coerce Exprs) +//! +//! - `optimizer_rule` +//! (file: optimizer_rule.rs, desc: Replace predicates via a custom OptimizerRule) +//! +//! 
- `parse_sql_expr` +//! (file: parse_sql_expr.rs, desc: Parse SQL into DataFusion Expr) +//! +//! - `plan_to_sql` +//! (file: plan_to_sql.rs, desc: Generate SQL from expressions or plans) +//! +//! - `planner_api` +//! (file: planner_api.rs, desc: APIs for logical and physical plan manipulation) +//! +//! - `pruning` +//! (file: pruning.rs, desc: Use pruning to skip irrelevant files) +//! +//! - `thread_pools` +//! (file: thread_pools.rs, desc: Configure custom thread pools for DataFusion execution) mod analyzer_rule; mod expr_api; diff --git a/datafusion-examples/examples/query_planning/parse_sql_expr.rs b/datafusion-examples/examples/query_planning/parse_sql_expr.rs index 376120de9d49..74072b8480f9 100644 --- a/datafusion-examples/examples/query_planning/parse_sql_expr.rs +++ b/datafusion-examples/examples/query_planning/parse_sql_expr.rs @@ -19,6 +19,7 @@ use arrow::datatypes::{DataType, Field, Schema}; use datafusion::common::DFSchema; +use datafusion::common::ScalarValue; use datafusion::logical_expr::{col, lit}; use datafusion::sql::unparser::Unparser; use datafusion::{ @@ -26,6 +27,7 @@ use datafusion::{ error::Result, prelude::{ParquetReadOptions, SessionContext}, }; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; /// This example demonstrates the programmatic parsing of SQL expressions using /// the DataFusion [`SessionContext::parse_sql_expr`] API or the [`DataFrame::parse_sql_expr`] API. @@ -70,18 +72,19 @@ fn simple_session_context_parse_sql_expr_demo() -> Result<()> { /// DataFusion can parse a SQL text to an logical expression using schema at [`DataFrame`]. 
async fn simple_dataframe_parse_sql_expr_demo() -> Result<()> { - let sql = "int_col < 5 OR double_col = 8.0"; - let expr = col("int_col") - .lt(lit(5_i64)) - .or(col("double_col").eq(lit(8.0_f64))); + let sql = "car = 'red' OR speed > 1.0"; + let expr = col("car") + .eq(lit(ScalarValue::Utf8(Some("red".to_string())))) + .or(col("speed").gt(lit(1.0_f64))); let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); + + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; + let df = ctx - .read_parquet( - &format!("{testdata}/alltypes_plain.parquet"), - ParquetReadOptions::default(), - ) + .read_parquet(parquet_temp.path_str()?, ParquetReadOptions::default()) .await?; let parsed_expr = df.parse_sql_expr(sql)?; @@ -93,39 +96,37 @@ async fn simple_dataframe_parse_sql_expr_demo() -> Result<()> { async fn query_parquet_demo() -> Result<()> { let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); + + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; + let df = ctx - .read_parquet( - &format!("{testdata}/alltypes_plain.parquet"), - ParquetReadOptions::default(), - ) + .read_parquet(parquet_temp.path_str()?, ParquetReadOptions::default()) .await?; let df = df .clone() - .select(vec![ - df.parse_sql_expr("int_col")?, - df.parse_sql_expr("double_col")?, - ])? - .filter(df.parse_sql_expr("int_col < 5 OR double_col = 8.0")?)? + .select(vec![df.parse_sql_expr("car")?, df.parse_sql_expr("speed")?])? + .filter(df.parse_sql_expr("car = 'red' OR speed > 1.0")?)? 
.aggregate( - vec![df.parse_sql_expr("double_col")?], - vec![df.parse_sql_expr("SUM(int_col) as sum_int_col")?], + vec![df.parse_sql_expr("car")?], + vec![df.parse_sql_expr("SUM(speed) as sum_speed")?], )? // Directly parsing the SQL text into a sort expression is not supported yet, so // construct it programmatically - .sort(vec![col("double_col").sort(false, false)])? + .sort(vec![col("car").sort(false, false)])? .limit(0, Some(1))?; let result = df.collect().await?; assert_batches_eq!( &[ - "+------------+-------------+", - "| double_col | sum_int_col |", - "+------------+-------------+", - "| 10.1 | 4 |", - "+------------+-------------+", + "+-----+--------------------+", + "| car | sum_speed |", + "+-----+--------------------+", + "| red | 162.49999999999997 |", + "+-----+--------------------+" ], &result ); @@ -135,15 +136,16 @@ async fn query_parquet_demo() -> Result<()> { /// DataFusion can parse a SQL text and convert it back to SQL using [`Unparser`]. async fn round_trip_parse_sql_expr_demo() -> Result<()> { - let sql = "((int_col < 5) OR (double_col = 8))"; + let sql = "((car = 'red') OR (speed > 1.0))"; let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); + + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; + let df = ctx - .read_parquet( - &format!("{testdata}/alltypes_plain.parquet"), - ParquetReadOptions::default(), - ) + .read_parquet(parquet_temp.path_str()?, ParquetReadOptions::default()) .await?; let parsed_expr = df.parse_sql_expr(sql)?; @@ -158,7 +160,7 @@ async fn round_trip_parse_sql_expr_demo() -> Result<()> { // difference in precedence rules between DataFusion and target engines. 
let unparser = Unparser::default().with_pretty(true); - let pretty = "int_col < 5 OR double_col = 8"; + let pretty = "car = 'red' OR speed > 1.0"; let pretty_round_trip_sql = unparser.expr_to_sql(&parsed_expr)?.to_string(); assert_eq!(pretty, pretty_round_trip_sql); diff --git a/datafusion-examples/examples/query_planning/plan_to_sql.rs b/datafusion-examples/examples/query_planning/plan_to_sql.rs index 756cc80b8f3c..86aebbc0b2c3 100644 --- a/datafusion-examples/examples/query_planning/plan_to_sql.rs +++ b/datafusion-examples/examples/query_planning/plan_to_sql.rs @@ -17,7 +17,11 @@ //! See `main.rs` for how to run it. +use std::fmt; +use std::sync::Arc; + use datafusion::common::DFSchemaRef; +use datafusion::common::ScalarValue; use datafusion::error::Result; use datafusion::logical_expr::sqlparser::ast::Statement; use datafusion::logical_expr::{ @@ -35,8 +39,7 @@ use datafusion::sql::unparser::extension_unparser::{ UnparseToStatementResult, UnparseWithinStatementResult, }; use datafusion::sql::unparser::{Unparser, plan_to_sql}; -use std::fmt; -use std::sync::Arc; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; /// This example demonstrates the programmatic construction of SQL strings using /// the DataFusion Expr [`Expr`] and LogicalPlan [`LogicalPlan`] API. @@ -114,21 +117,21 @@ fn simple_expr_to_sql_demo_escape_mysql_style() -> Result<()> { async fn simple_plan_to_sql_demo() -> Result<()> { let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; + let df = ctx - .read_parquet( - &format!("{testdata}/alltypes_plain.parquet"), - ParquetReadOptions::default(), - ) + .read_parquet(parquet_temp.path_str()?, ParquetReadOptions::default()) .await? 
- .select_columns(&["id", "int_col", "double_col", "date_string_col"])?; + .select_columns(&["car", "speed", "time"])?; // Convert the data frame to a SQL string let sql = plan_to_sql(df.logical_plan())?.to_string(); assert_eq!( sql, - r#"SELECT "?table?".id, "?table?".int_col, "?table?".double_col, "?table?".date_string_col FROM "?table?""# + r#"SELECT "?table?".car, "?table?".speed, "?table?"."time" FROM "?table?""# ); Ok(()) @@ -139,35 +142,35 @@ async fn simple_plan_to_sql_demo() -> Result<()> { async fn round_trip_plan_to_sql_demo() -> Result<()> { let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; // register parquet file with the execution context ctx.register_parquet( - "alltypes_plain", - &format!("{testdata}/alltypes_plain.parquet"), + "cars", + parquet_temp.path_str()?, ParquetReadOptions::default(), ) .await?; // create a logical plan from a SQL string and then programmatically add new filters + // select car, speed, time from cars where speed > 1 and car = 'red' let df = ctx // Use SQL to read some data from the parquet file - .sql( - "SELECT int_col, double_col, CAST(date_string_col as VARCHAR) \ - FROM alltypes_plain", - ) + .sql("SELECT car, speed, time FROM cars") .await? 
- // Add id > 1 and tinyint_col < double_col filter + // Add speed > 1 and car = 'red' filter .filter( - col("id") + col("speed") .gt(lit(1)) - .and(col("tinyint_col").lt(col("double_col"))), + .and(col("car").eq(lit(ScalarValue::Utf8(Some("red".to_string()))))), )?; let sql = plan_to_sql(df.logical_plan())?.to_string(); assert_eq!( sql, - r#"SELECT alltypes_plain.int_col, alltypes_plain.double_col, CAST(alltypes_plain.date_string_col AS VARCHAR) FROM alltypes_plain WHERE ((alltypes_plain.id > 1) AND (alltypes_plain.tinyint_col < alltypes_plain.double_col))"# + r#"SELECT cars.car, cars.speed, cars."time" FROM cars WHERE ((cars.speed > 1) AND (cars.car = 'red'))"# ); Ok(()) @@ -211,6 +214,7 @@ impl UserDefinedLogicalNodeCore for MyLogicalPlan { } struct PlanToStatement {} + impl UserDefinedLogicalNodeUnparser for PlanToStatement { fn unparse_to_statement( &self, @@ -231,14 +235,15 @@ impl UserDefinedLogicalNodeUnparser for PlanToStatement { /// It can be unparse as a statement that reads from the same parquet file. async fn unparse_my_logical_plan_as_statement() -> Result<()> { let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); + + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; + let inner_plan = ctx - .read_parquet( - &format!("{testdata}/alltypes_plain.parquet"), - ParquetReadOptions::default(), - ) + .read_parquet(parquet_temp.path_str()?, ParquetReadOptions::default()) .await? - .select_columns(&["id", "int_col", "double_col", "date_string_col"])? + .select_columns(&["car", "speed", "time"])? 
.into_unoptimized_plan(); let node = Arc::new(MyLogicalPlan { input: inner_plan }); @@ -249,7 +254,7 @@ async fn unparse_my_logical_plan_as_statement() -> Result<()> { let sql = unparser.plan_to_sql(&my_plan)?.to_string(); assert_eq!( sql, - r#"SELECT "?table?".id, "?table?".int_col, "?table?".double_col, "?table?".date_string_col FROM "?table?""# + r#"SELECT "?table?".car, "?table?".speed, "?table?"."time" FROM "?table?""# ); Ok(()) } @@ -284,14 +289,15 @@ impl UserDefinedLogicalNodeUnparser for PlanToSubquery { /// It can be unparse as a subquery that reads from the same parquet file, with some columns projected. async fn unparse_my_logical_plan_as_subquery() -> Result<()> { let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); + + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; + let inner_plan = ctx - .read_parquet( - &format!("{testdata}/alltypes_plain.parquet"), - ParquetReadOptions::default(), - ) + .read_parquet(parquet_temp.path_str()?, ParquetReadOptions::default()) .await? - .select_columns(&["id", "int_col", "double_col", "date_string_col"])? + .select_columns(&["car", "speed", "time"])? .into_unoptimized_plan(); let node = Arc::new(MyLogicalPlan { input: inner_plan }); @@ -299,8 +305,8 @@ async fn unparse_my_logical_plan_as_subquery() -> Result<()> { let my_plan = LogicalPlan::Extension(Extension { node }); let plan = LogicalPlanBuilder::from(my_plan) .project(vec![ - col("id").alias("my_id"), - col("int_col").alias("my_int"), + col("car").alias("my_car"), + col("speed").alias("my_speed"), ])? 
.build()?; let unparser = @@ -308,8 +314,8 @@ async fn unparse_my_logical_plan_as_subquery() -> Result<()> { let sql = unparser.plan_to_sql(&plan)?.to_string(); assert_eq!( sql, - "SELECT \"?table?\".id AS my_id, \"?table?\".int_col AS my_int FROM \ - (SELECT \"?table?\".id, \"?table?\".int_col, \"?table?\".double_col, \"?table?\".date_string_col FROM \"?table?\")", + "SELECT \"?table?\".car AS my_car, \"?table?\".speed AS my_speed FROM \ + (SELECT \"?table?\".car, \"?table?\".speed, \"?table?\".\"time\" FROM \"?table?\")", ); Ok(()) } diff --git a/datafusion-examples/examples/query_planning/planner_api.rs b/datafusion-examples/examples/query_planning/planner_api.rs index 9b8aa1c2fe64..8b2c09f4aecb 100644 --- a/datafusion-examples/examples/query_planning/planner_api.rs +++ b/datafusion-examples/examples/query_planning/planner_api.rs @@ -22,6 +22,7 @@ use datafusion::logical_expr::LogicalPlan; use datafusion::physical_plan::displayable; use datafusion::physical_planner::DefaultPhysicalPlanner; use datafusion::prelude::*; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; /// This example demonstrates the process of converting logical plan /// into physical execution plans using DataFusion. 
@@ -37,25 +38,23 @@ use datafusion::prelude::*; pub async fn planner_api() -> Result<()> { // Set up a DataFusion context and load a Parquet file let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); + + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; + let df = ctx - .read_parquet( - &format!("{testdata}/alltypes_plain.parquet"), - ParquetReadOptions::default(), - ) + .read_parquet(parquet_temp.path_str()?, ParquetReadOptions::default()) .await?; // Construct the input logical plan using DataFrame API let df = df .clone() - .select(vec![ - df.parse_sql_expr("int_col")?, - df.parse_sql_expr("double_col")?, - ])? - .filter(df.parse_sql_expr("int_col < 5 OR double_col = 8.0")?)? + .select(vec![df.parse_sql_expr("car")?, df.parse_sql_expr("speed")?])? + .filter(df.parse_sql_expr("car = 'red' OR speed > 1.0")?)? .aggregate( - vec![df.parse_sql_expr("double_col")?], - vec![df.parse_sql_expr("SUM(int_col) as sum_int_col")?], + vec![df.parse_sql_expr("car")?], + vec![df.parse_sql_expr("SUM(speed) as sum_speed")?], )? .limit(0, Some(1))?; let logical_plan = df.logical_plan().clone(); diff --git a/datafusion-examples/examples/query_planning/thread_pools.rs b/datafusion-examples/examples/query_planning/thread_pools.rs index 6fc7d51e91c1..2ff73a77c402 100644 --- a/datafusion-examples/examples/query_planning/thread_pools.rs +++ b/datafusion-examples/examples/query_planning/thread_pools.rs @@ -37,15 +37,17 @@ //! //! 
[Architecture section]: https://docs.rs/datafusion/latest/datafusion/index.html#thread-scheduling-cpu--io-thread-pools-and-tokio-runtimes +use std::sync::Arc; + use arrow::util::pretty::pretty_format_batches; use datafusion::common::runtime::JoinSet; use datafusion::error::Result; use datafusion::execution::SendableRecordBatchStream; use datafusion::prelude::*; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; use futures::stream::StreamExt; use object_store::client::SpawnedReqwestConnector; use object_store::http::HttpBuilder; -use std::sync::Arc; use tokio::runtime::Handle; use tokio::sync::Notify; use url::Url; @@ -70,10 +72,12 @@ pub async fn thread_pools() -> Result<()> { // The first two examples read local files. Enabling the URL table feature // lets us treat filenames as tables in SQL. let ctx = SessionContext::new().enable_url_table(); - let sql = format!( - "SELECT * FROM '{}/alltypes_plain.parquet'", - datafusion::test_util::parquet_test_data() - ); + + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; + + let sql = format!("SELECT * FROM '{}'", parquet_temp.path_str()?); // Run a query on the current runtime. Calling `await` means the future // (in this case the `async` function and all spawned work in DataFusion diff --git a/datafusion-examples/examples/relation_planner/main.rs b/datafusion-examples/examples/relation_planner/main.rs index 15079f644612..babc0d3714f7 100644 --- a/datafusion-examples/examples/relation_planner/main.rs +++ b/datafusion-examples/examples/relation_planner/main.rs @@ -27,9 +27,15 @@ //! //! Each subcommand runs a corresponding example: //! - `all` — run all examples included in this module -//! - `match_recognize` — MATCH_RECOGNIZE pattern matching on event streams -//! - `pivot_unpivot` — PIVOT and UNPIVOT operations for reshaping data -//! 
- `table_sample` — TABLESAMPLE clause for sampling rows from tables +//! +//! - `match_recognize` +//! (file: match_recognize.rs, desc: Implement MATCH_RECOGNIZE pattern matching) +//! +//! - `pivot_unpivot` +//! (file: pivot_unpivot.rs, desc: Implement PIVOT / UNPIVOT) +//! +//! - `table_sample` +//! (file: table_sample.rs, desc: Implement TABLESAMPLE) //! //! ## Snapshot Testing //! diff --git a/datafusion-examples/examples/relation_planner/match_recognize.rs b/datafusion-examples/examples/relation_planner/match_recognize.rs index 60baf9bd61a6..c4b3d522efc1 100644 --- a/datafusion-examples/examples/relation_planner/match_recognize.rs +++ b/datafusion-examples/examples/relation_planner/match_recognize.rs @@ -362,7 +362,7 @@ impl RelationPlanner for MatchRecognizePlanner { .. } = relation else { - return Ok(RelationPlanning::Original(relation)); + return Ok(RelationPlanning::Original(Box::new(relation))); }; // Plan the input table @@ -401,6 +401,8 @@ impl RelationPlanner for MatchRecognizePlanner { node: Arc::new(node), }); - Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias))) + Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new( + plan, alias, + )))) } } diff --git a/datafusion-examples/examples/relation_planner/pivot_unpivot.rs b/datafusion-examples/examples/relation_planner/pivot_unpivot.rs index 86a6cb955500..2e1696956bf6 100644 --- a/datafusion-examples/examples/relation_planner/pivot_unpivot.rs +++ b/datafusion-examples/examples/relation_planner/pivot_unpivot.rs @@ -339,7 +339,7 @@ impl RelationPlanner for PivotUnpivotPlanner { alias, ), - other => Ok(RelationPlanning::Original(other)), + other => Ok(RelationPlanning::Original(Box::new(other))), } } } @@ -459,7 +459,9 @@ fn plan_pivot( .aggregate(group_by_cols, pivot_exprs)? 
.build()?; - Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias))) + Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new( + plan, alias, + )))) } // ============================================================================ @@ -540,7 +542,9 @@ fn plan_unpivot( .build()?; } - Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias))) + Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new( + plan, alias, + )))) } // ============================================================================ diff --git a/datafusion-examples/examples/relation_planner/table_sample.rs b/datafusion-examples/examples/relation_planner/table_sample.rs index 362d35dcf4ca..895f2fdd4ff3 100644 --- a/datafusion-examples/examples/relation_planner/table_sample.rs +++ b/datafusion-examples/examples/relation_planner/table_sample.rs @@ -331,7 +331,7 @@ impl RelationPlanner for TableSamplePlanner { index_hints, } = relation else { - return Ok(RelationPlanning::Original(relation)); + return Ok(RelationPlanning::Original(Box::new(relation))); }; // Extract sample spec (handles both before/after alias positions) @@ -401,7 +401,9 @@ impl RelationPlanner for TableSamplePlanner { let fraction = bucket_num as f64 / total as f64; let plan = TableSamplePlanNode::new(input, fraction, seed).into_plan(); - return Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias))); + return Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new( + plan, alias, + )))); } // Handle quantity-based sampling @@ -422,7 +424,9 @@ impl RelationPlanner for TableSamplePlanner { let plan = LogicalPlanBuilder::from(input) .limit(0, Some(rows as usize))? 
.build()?; - Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias))) + Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new( + plan, alias, + )))) } // TABLESAMPLE (N PERCENT) - percentage sampling @@ -430,7 +434,9 @@ impl RelationPlanner for TableSamplePlanner { let percent: f64 = parse_literal::(&quantity_value_expr)?; let fraction = percent / 100.0; let plan = TableSamplePlanNode::new(input, fraction, seed).into_plan(); - Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias))) + Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new( + plan, alias, + )))) } // TABLESAMPLE (N) - fraction if <1.0, row limit if >=1.0 @@ -448,7 +454,9 @@ impl RelationPlanner for TableSamplePlanner { // Interpret as fraction TableSamplePlanNode::new(input, value, seed).into_plan() }; - Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias))) + Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new( + plan, alias, + )))) } } } @@ -610,7 +618,7 @@ pub struct SampleExec { upper_bound: f64, seed: u64, metrics: ExecutionPlanMetricsSet, - cache: PlanProperties, + cache: Arc, } impl SampleExec { @@ -648,7 +656,7 @@ impl SampleExec { upper_bound, seed, metrics: ExecutionPlanMetricsSet::new(), - cache, + cache: Arc::new(cache), }) } @@ -678,7 +686,7 @@ impl ExecutionPlan for SampleExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } diff --git a/datafusion-examples/examples/sql_ops/main.rs b/datafusion-examples/examples/sql_ops/main.rs index aaab7778be0e..ce7be8fa2bad 100644 --- a/datafusion-examples/examples/sql_ops/main.rs +++ b/datafusion-examples/examples/sql_ops/main.rs @@ -26,10 +26,18 @@ //! //! Each subcommand runs a corresponding example: //! - `all` — run all examples included in this module -//! - `analysis` — analyse SQL queries with DataFusion structures -//! - `custom_sql_parser` — implementing a custom SQL parser to extend DataFusion -//! 
- `frontend` — create LogicalPlans (only) from sql strings -//! - `query` — query data using SQL (in memory RecordBatches, local Parquet files) +//! +//! - `analysis` +//! (file: analysis.rs, desc: Analyze SQL queries) +//! +//! - `custom_sql_parser` +//! (file: custom_sql_parser.rs, desc: Implement a custom SQL parser to extend DataFusion) +//! +//! - `frontend` +//! (file: frontend.rs, desc: Build LogicalPlans from SQL) +//! +//! - `query` +//! (file: query.rs, desc: Query data using SQL) mod analysis; mod custom_sql_parser; diff --git a/datafusion-examples/examples/sql_ops/query.rs b/datafusion-examples/examples/sql_ops/query.rs index 90d0c3ca34a0..60b47c36b9ae 100644 --- a/datafusion-examples/examples/sql_ops/query.rs +++ b/datafusion-examples/examples/sql_ops/query.rs @@ -17,18 +17,19 @@ //! See `main.rs` for how to run it. +use std::sync::Arc; + use datafusion::arrow::array::{UInt8Array, UInt64Array}; use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::arrow::record_batch::RecordBatch; +use datafusion::catalog::MemTable; use datafusion::common::{assert_batches_eq, exec_datafusion_err}; -use datafusion::datasource::MemTable; use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::ListingOptions; use datafusion::error::{DataFusionError, Result}; use datafusion::prelude::*; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; use object_store::local::LocalFileSystem; -use std::path::Path; -use std::sync::Arc; /// Examples of various ways to execute queries using SQL /// @@ -113,32 +114,33 @@ async fn query_parquet() -> Result<()> { // create local execution context let ctx = SessionContext::new(); - let test_data = datafusion::test_util::parquet_test_data(); + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; // 
Configure listing options let file_format = ParquetFormat::default().with_enable_pruning(true); - let listing_options = ListingOptions::new(Arc::new(file_format)) - // This is a workaround for this example since `test_data` contains - // many different parquet different files, - // in practice use FileType::PARQUET.get_ext(). - .with_file_extension("alltypes_plain.parquet"); + let listing_options = + ListingOptions::new(Arc::new(file_format)).with_file_extension(".parquet"); + + let table_path = parquet_temp.file_uri()?; // First example were we use an absolute path, which requires no additional setup. ctx.register_listing_table( "my_table", - &format!("file://{test_data}/"), + &table_path, listing_options.clone(), None, None, ) - .await - .unwrap(); + .await?; // execute the query let df = ctx .sql( "SELECT * \ FROM my_table \ + ORDER BY speed \ LIMIT 1", ) .await?; @@ -147,21 +149,22 @@ async fn query_parquet() -> Result<()> { let results = df.collect().await?; assert_batches_eq!( [ - "+----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+", - "| id | bool_col | tinyint_col | smallint_col | int_col | bigint_col | float_col | double_col | date_string_col | string_col | timestamp_col |", - "+----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+", - "| 4 | true | 0 | 0 | 0 | 0 | 0.0 | 0.0 | 30332f30312f3039 | 30 | 2009-03-01T00:00:00 |", - "+----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+", + "+-----+-------+---------------------+", + "| car | speed | time |", + "+-----+-------+---------------------+", + "| red | 0.0 | 1996-04-12T12:05:15 |", + "+-----+-------+---------------------+", ], &results ); - // Second example were we temporarily move into the test data's parent directory 
and - // simulate a relative path, this requires registering an ObjectStore. + // Second example where we change the current working directory and explicitly + // register a local filesystem object store. This demonstrates how listing tables + // resolve paths via an ObjectStore, even when using filesystem-backed data. let cur_dir = std::env::current_dir()?; - - let test_data_path = Path::new(&test_data); - let test_data_path_parent = test_data_path + let test_data_path_parent = parquet_temp + .tmp_dir + .path() .parent() .ok_or(exec_datafusion_err!("test_data path needs a parent"))?; @@ -169,15 +172,15 @@ async fn query_parquet() -> Result<()> { let local_fs = Arc::new(LocalFileSystem::default()); - let u = url::Url::parse("file://./") + let url = url::Url::parse("file://./") .map_err(|e| DataFusionError::External(Box::new(e)))?; - ctx.register_object_store(&u, local_fs); + ctx.register_object_store(&url, local_fs); // Register a listing table - this will use all files in the directory as data sources // for the query ctx.register_listing_table( "relative_table", - "./data", + parquet_temp.path_str()?, listing_options.clone(), None, None, @@ -189,6 +192,7 @@ async fn query_parquet() -> Result<()> { .sql( "SELECT * \ FROM relative_table \ + ORDER BY speed \ LIMIT 1", ) .await?; @@ -197,11 +201,11 @@ async fn query_parquet() -> Result<()> { let results = df.collect().await?; assert_batches_eq!( [ - "+----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+", - "| id | bool_col | tinyint_col | smallint_col | int_col | bigint_col | float_col | double_col | date_string_col | string_col | timestamp_col |", - "+----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+", - "| 4 | true | 0 | 0 | 0 | 0 | 0.0 | 0.0 | 30332f30312f3039 | 30 | 2009-03-01T00:00:00 |", - 
"+----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+", + "+-----+-------+---------------------+", + "| car | speed | time |", + "+-----+-------+---------------------+", + "| red | 0.0 | 1996-04-12T12:05:15 |", + "+-----+-------+---------------------+", ], &results ); diff --git a/datafusion-examples/examples/udf/advanced_udaf.rs b/datafusion-examples/examples/udf/advanced_udaf.rs index fbb9e652486c..89f621d30e18 100644 --- a/datafusion-examples/examples/udf/advanced_udaf.rs +++ b/datafusion-examples/examples/udf/advanced_udaf.rs @@ -34,7 +34,7 @@ use datafusion::logical_expr::{ Accumulator, AggregateUDF, AggregateUDFImpl, EmitTo, GroupsAccumulator, Signature, expr::AggregateFunction, function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs}, - simplify::SimplifyInfo, + simplify::SimplifyContext, }; use datafusion::prelude::*; @@ -314,12 +314,16 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator { let prods = emit_to.take_needed(&mut self.prods); let nulls = self.null_state.build(emit_to); - assert_eq!(nulls.len(), prods.len()); + if let Some(nulls) = &nulls { + assert_eq!(nulls.len(), counts.len()); + } assert_eq!(counts.len(), prods.len()); // don't evaluate geometric mean with null inputs to avoid errors on null values - let array: PrimitiveArray = if nulls.null_count() > 0 { + let array: PrimitiveArray = if let Some(nulls) = &nulls + && nulls.null_count() > 0 + { let mut builder = PrimitiveBuilder::::with_capacity(nulls.len()); let iter = prods.into_iter().zip(counts).zip(nulls.iter()); @@ -337,7 +341,7 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator { .zip(counts) .map(|(prod, count)| prod.powf(1.0 / count as f64)) .collect::>(); - PrimitiveArray::new(geo_mean.into(), Some(nulls)) // no copy + PrimitiveArray::new(geo_mean.into(), nulls) // no copy .with_data_type(self.return_data_type.clone()) }; @@ -347,7 +351,6 @@ impl 
GroupsAccumulator for GeometricMeanGroupsAccumulator { // return arrays for counts and prods fn state(&mut self, emit_to: EmitTo) -> Result> { let nulls = self.null_state.build(emit_to); - let nulls = Some(nulls); let counts = emit_to.take_needed(&mut self.counts); let counts = UInt32Array::new(counts.into(), nulls.clone()); // zero copy @@ -421,7 +424,7 @@ impl AggregateUDFImpl for SimplifiedGeoMeanUdaf { /// Optionally replaces a UDAF with another expression during query optimization. fn simplify(&self) -> Option { - let simplify = |aggregate_function: AggregateFunction, _: &dyn SimplifyInfo| { + let simplify = |aggregate_function: AggregateFunction, _: &SimplifyContext| { // Replaces the UDAF with `GeoMeanUdaf` as a placeholder example to demonstrate the `simplify` method. // In real-world scenarios, you might create UDFs from built-in expressions. Ok(Expr::AggregateFunction(AggregateFunction::new_udf( diff --git a/datafusion-examples/examples/udf/advanced_udwf.rs b/datafusion-examples/examples/udf/advanced_udwf.rs index e8d3a75b29de..615d099c2854 100644 --- a/datafusion-examples/examples/udf/advanced_udwf.rs +++ b/datafusion-examples/examples/udf/advanced_udwf.rs @@ -17,7 +17,7 @@ //! See `main.rs` for how to run it. 
-use std::{any::Any, fs::File, io::Write, sync::Arc}; +use std::{any::Any, sync::Arc}; use arrow::datatypes::Field; use arrow::{ @@ -32,7 +32,7 @@ use datafusion::logical_expr::expr::{WindowFunction, WindowFunctionParams}; use datafusion::logical_expr::function::{ PartitionEvaluatorArgs, WindowFunctionSimplification, WindowUDFFieldArgs, }; -use datafusion::logical_expr::simplify::SimplifyInfo; +use datafusion::logical_expr::simplify::SimplifyContext; use datafusion::logical_expr::{ Expr, LimitEffect, PartitionEvaluator, Signature, WindowFrame, WindowFunctionDefinition, WindowUDF, WindowUDFImpl, @@ -40,7 +40,7 @@ use datafusion::logical_expr::{ use datafusion::physical_expr::PhysicalExpr; use datafusion::prelude::*; use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; -use tempfile::tempdir; +use datafusion_examples::utils::datasets::ExampleDataset; /// This example shows how to use the full WindowUDFImpl API to implement a user /// defined window function. As in the `simple_udwf.rs` example, this struct implements @@ -198,7 +198,7 @@ impl WindowUDFImpl for SimplifySmoothItUdf { /// this function will simplify `SimplifySmoothItUdf` to `AggregateUDF` for `Avg` /// default implementation will not be called (left as `todo!()`) fn simplify(&self) -> Option { - let simplify = |window_function: WindowFunction, _: &dyn SimplifyInfo| { + let simplify = |window_function: WindowFunction, _: &SimplifyContext| { Ok(Expr::from(WindowFunction { fun: WindowFunctionDefinition::AggregateUDF(avg_udaf()), params: WindowFunctionParams { @@ -230,44 +230,9 @@ async fn create_context() -> Result { // declare a new context. 
In spark API, this corresponds to a new spark SQL session let ctx = SessionContext::new(); - // content from file 'datafusion/core/tests/data/cars.csv' - let csv_data = r#"car,speed,time -red,20.0,1996-04-12T12:05:03.000000000 -red,20.3,1996-04-12T12:05:04.000000000 -red,21.4,1996-04-12T12:05:05.000000000 -red,21.5,1996-04-12T12:05:06.000000000 -red,19.0,1996-04-12T12:05:07.000000000 -red,18.0,1996-04-12T12:05:08.000000000 -red,17.0,1996-04-12T12:05:09.000000000 -red,7.0,1996-04-12T12:05:10.000000000 -red,7.1,1996-04-12T12:05:11.000000000 -red,7.2,1996-04-12T12:05:12.000000000 -red,3.0,1996-04-12T12:05:13.000000000 -red,1.0,1996-04-12T12:05:14.000000000 -red,0.0,1996-04-12T12:05:15.000000000 -green,10.0,1996-04-12T12:05:03.000000000 -green,10.3,1996-04-12T12:05:04.000000000 -green,10.4,1996-04-12T12:05:05.000000000 -green,10.5,1996-04-12T12:05:06.000000000 -green,11.0,1996-04-12T12:05:07.000000000 -green,12.0,1996-04-12T12:05:08.000000000 -green,14.0,1996-04-12T12:05:09.000000000 -green,15.0,1996-04-12T12:05:10.000000000 -green,15.1,1996-04-12T12:05:11.000000000 -green,15.2,1996-04-12T12:05:12.000000000 -green,8.0,1996-04-12T12:05:13.000000000 -green,2.0,1996-04-12T12:05:14.000000000 -"#; - let dir = tempdir()?; - let file_path = dir.path().join("cars.csv"); - { - let mut file = File::create(&file_path)?; - // write CSV data - file.write_all(csv_data.as_bytes())?; - } // scope closes the file - let file_path = file_path.to_str().unwrap(); - - ctx.register_csv("cars", file_path, CsvReadOptions::new()) + let dataset = ExampleDataset::Cars; + + ctx.register_csv("cars", dataset.path_str()?, CsvReadOptions::new()) .await?; Ok(ctx) diff --git a/datafusion-examples/examples/udf/async_udf.rs b/datafusion-examples/examples/udf/async_udf.rs index c31e8290ccce..3d8faf623d43 100644 --- a/datafusion-examples/examples/udf/async_udf.rs +++ b/datafusion-examples/examples/udf/async_udf.rs @@ -102,8 +102,7 @@ pub async fn async_udf() -> Result<()> { "| physical_plan | FilterExec: 
__async_fn_0@2, projection=[id@0, name@1] |", "| | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 |", "| | AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=ask_llm(CAST(name@1 AS Utf8View), Is this animal furry?))] |", - "| | CoalesceBatchesExec: target_batch_size=8192 |", - "| | DataSourceExec: partitions=1, partition_sizes=[1] |", + "| | DataSourceExec: partitions=1, partition_sizes=[1] |", "| | |", "+---------------+------------------------------------------------------------------------------------------------------------------------------+", ], diff --git a/datafusion-examples/examples/udf/main.rs b/datafusion-examples/examples/udf/main.rs index aff20e775429..e024e466ab07 100644 --- a/datafusion-examples/examples/udf/main.rs +++ b/datafusion-examples/examples/udf/main.rs @@ -26,14 +26,30 @@ //! //! Each subcommand runs a corresponding example: //! - `all` — run all examples included in this module -//! - `adv_udaf` — user defined aggregate function example -//! - `adv_udf` — user defined scalar function example -//! - `adv_udwf` — user defined window function example -//! - `async_udf` — asynchronous user defined function example -//! - `udaf` — simple user defined aggregate function example -//! - `udf` — simple user defined scalar function example -//! - `udtf` — simple user defined table function example -//! - `udwf` — simple user defined window function example +//! +//! - `adv_udaf` +//! (file: advanced_udaf.rs, desc: Advanced User Defined Aggregate Function (UDAF)) +//! +//! - `adv_udf` +//! (file: advanced_udf.rs, desc: Advanced User Defined Scalar Function (UDF)) +//! +//! - `adv_udwf` +//! (file: advanced_udwf.rs, desc: Advanced User Defined Window Function (UDWF)) +//! +//! - `async_udf` +//! (file: async_udf.rs, desc: Asynchronous User Defined Scalar Function) +//! +//! - `udaf` +//! (file: simple_udaf.rs, desc: Simple UDAF example) +//! +//! - `udf` +//! (file: simple_udf.rs, desc: Simple UDF example) +//! +//! 
- `udtf` +//! (file: simple_udtf.rs, desc: Simple UDTF example) +//! +//! - `udwf` +//! (file: simple_udwf.rs, desc: Simple UDWF example) mod advanced_udaf; mod advanced_udf; diff --git a/datafusion-examples/examples/udf/simple_udtf.rs b/datafusion-examples/examples/udf/simple_udtf.rs index 087b8ba73af5..ee2615c4a5ac 100644 --- a/datafusion-examples/examples/udf/simple_udtf.rs +++ b/datafusion-examples/examples/udf/simple_udtf.rs @@ -17,27 +17,28 @@ //! See `main.rs` for how to run it. +use std::fs::File; +use std::io::Seek; +use std::path::Path; +use std::sync::Arc; + use arrow::csv::ReaderBuilder; use arrow::csv::reader::Format; use async_trait::async_trait; use datafusion::arrow::datatypes::SchemaRef; use datafusion::arrow::record_batch::RecordBatch; -use datafusion::catalog::Session; -use datafusion::catalog::TableFunctionImpl; +use datafusion::catalog::{Session, TableFunctionImpl}; use datafusion::common::{ScalarValue, plan_err}; use datafusion::datasource::TableProvider; use datafusion::datasource::memory::MemorySourceConfig; use datafusion::error::Result; -use datafusion::execution::context::ExecutionProps; use datafusion::logical_expr::simplify::SimplifyContext; use datafusion::logical_expr::{Expr, TableType}; use datafusion::optimizer::simplify_expressions::ExprSimplifier; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::*; -use std::fs::File; -use std::io::Seek; -use std::path::Path; -use std::sync::Arc; +use datafusion_examples::utils::datasets::ExampleDataset; + // To define your own table function, you only need to do the following 3 things: // 1. Implement your own [`TableProvider`] // 2. 
Implement your own [`TableFunctionImpl`] and return your [`TableProvider`] @@ -51,18 +52,19 @@ pub async fn simple_udtf() -> Result<()> { // register the table function that will be called in SQL statements by `read_csv` ctx.register_udtf("read_csv", Arc::new(LocalCsvTableFunc {})); - let testdata = datafusion::test_util::arrow_test_data(); - let csv_file = format!("{testdata}/csv/aggregate_test_100.csv"); + let dataset = ExampleDataset::Cars; // Pass 2 arguments, read csv with at most 2 rows (simplify logic makes 1+1 --> 2) let df = ctx - .sql(format!("SELECT * FROM read_csv('{csv_file}', 1 + 1);").as_str()) + .sql( + format!("SELECT * FROM read_csv('{}', 1 + 1);", dataset.path_str()?).as_str(), + ) .await?; df.show().await?; // just run, return all rows let df = ctx - .sql(format!("SELECT * FROM read_csv('{csv_file}');").as_str()) + .sql(format!("SELECT * FROM read_csv('{}');", dataset.path_str()?).as_str()) .await?; df.show().await?; @@ -142,8 +144,7 @@ impl TableFunctionImpl for LocalCsvTableFunc { .get(1) .map(|expr| { // try to simplify the expression, so 1+2 becomes 3, for example - let execution_props = ExecutionProps::new(); - let info = SimplifyContext::new(&execution_props); + let info = SimplifyContext::default(); let expr = ExprSimplifier::new(info).simplify(expr.clone())?; if let Expr::Literal(ScalarValue::Int64(Some(limit)), _) = expr { diff --git a/datafusion-examples/src/bin/examples-docs.rs b/datafusion-examples/src/bin/examples-docs.rs new file mode 100644 index 000000000000..7efcf4da15d2 --- /dev/null +++ b/datafusion-examples/src/bin/examples-docs.rs @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Generates Markdown documentation for DataFusion example groups. +//! +//! This binary scans `datafusion-examples/examples`, extracts structured +//! documentation from each group's `main.rs` file, and renders a README-style +//! Markdown document. +//! +//! By default, documentation is generated for all example groups. If a group +//! name is provided as the first CLI argument, only that group is rendered. +//! +//! ## Usage +//! +//! ```bash +//! # Generate docs for all example groups +//! cargo run --bin examples-docs +//! +//! # Generate docs for a single group +//! cargo run --bin examples-docs -- dataframe +//! ``` + +use datafusion_examples::utils::example_metadata::{ + RepoLayout, generate_examples_readme, +}; + +fn main() -> Result<(), Box> { + let layout = RepoLayout::detect()?; + let group = std::env::args().nth(1); + let markdown = generate_examples_readme(&layout, group.as_deref())?; + print!("{markdown}"); + Ok(()) +} diff --git a/datafusion-examples/src/lib.rs b/datafusion-examples/src/lib.rs new file mode 100644 index 000000000000..7f334aedaafe --- /dev/null +++ b/datafusion-examples/src/lib.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Internal utilities shared by the DataFusion examples. + +pub mod utils; diff --git a/datafusion-examples/src/utils/csv_to_parquet.rs b/datafusion-examples/src/utils/csv_to_parquet.rs new file mode 100644 index 000000000000..1fbf2930e904 --- /dev/null +++ b/datafusion-examples/src/utils/csv_to_parquet.rs @@ -0,0 +1,244 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::path::{Path, PathBuf}; + +use datafusion::dataframe::DataFrameWriteOptions; +use datafusion::error::{DataFusionError, Result}; +use datafusion::prelude::{CsvReadOptions, SessionContext}; +use tempfile::TempDir; +use tokio::fs::create_dir_all; + +/// Temporary Parquet directory that is deleted when dropped. +#[derive(Debug)] +pub struct ParquetTemp { + pub tmp_dir: TempDir, + pub parquet_dir: PathBuf, +} + +impl ParquetTemp { + pub fn path(&self) -> &Path { + &self.parquet_dir + } + + pub fn path_str(&self) -> Result<&str> { + self.parquet_dir.to_str().ok_or_else(|| { + DataFusionError::Execution(format!( + "Parquet directory path is not valid UTF-8: {}", + self.parquet_dir.display() + )) + }) + } + + pub fn file_uri(&self) -> Result { + Ok(format!("file://{}", self.path_str()?)) + } +} + +/// Helper for examples: load a CSV file and materialize it as Parquet +/// in a temporary directory. +/// +/// # Example +/// ``` +/// use std::path::PathBuf; +/// use datafusion::prelude::*; +/// use datafusion_examples::utils::write_csv_to_parquet; +/// # use datafusion::assert_batches_eq; +/// # use datafusion::error::Result; +/// # #[tokio::main] +/// # async fn main() -> Result<()> { +/// let ctx = SessionContext::new(); +/// let csv_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")) +/// .join("data") +/// .join("csv") +/// .join("cars.csv"); +/// let parquet_dir = write_csv_to_parquet(&ctx, &csv_path).await?; +/// let df = ctx.read_parquet(parquet_dir.path_str()?, ParquetReadOptions::default()).await?; +/// let rows = df +/// .sort(vec![col("speed").sort(true, true)])? 
+/// .limit(0, Some(5))?; +/// assert_batches_eq!( +/// &[ +/// "+-------+-------+---------------------+", +/// "| car | speed | time |", +/// "+-------+-------+---------------------+", +/// "| red | 0.0 | 1996-04-12T12:05:15 |", +/// "| red | 1.0 | 1996-04-12T12:05:14 |", +/// "| green | 2.0 | 1996-04-12T12:05:14 |", +/// "| red | 3.0 | 1996-04-12T12:05:13 |", +/// "| red | 7.0 | 1996-04-12T12:05:10 |", +/// "+-------+-------+---------------------+", +/// ], +/// &rows.collect().await? +/// ); +/// # Ok(()) +/// # } +/// ``` +pub async fn write_csv_to_parquet( + ctx: &SessionContext, + csv_path: &Path, +) -> Result { + if !csv_path.is_file() { + return Err(DataFusionError::Execution(format!( + "CSV file does not exist: {}", + csv_path.display() + ))); + } + + let csv_path = csv_path.to_str().ok_or_else(|| { + DataFusionError::Execution("CSV path is not valid UTF-8".to_string()) + })?; + + let csv_df = ctx.read_csv(csv_path, CsvReadOptions::default()).await?; + + let tmp_dir = TempDir::new()?; + let parquet_dir = tmp_dir.path().join("parquet_source"); + create_dir_all(&parquet_dir).await?; + + let path = parquet_dir.to_str().ok_or_else(|| { + DataFusionError::Execution("Failed processing tmp directory path".to_string()) + })?; + + csv_df + .write_parquet(path, DataFrameWriteOptions::default(), None) + .await?; + + Ok(ParquetTemp { + tmp_dir, + parquet_dir, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::path::PathBuf; + + use datafusion::assert_batches_eq; + use datafusion::prelude::*; + + #[tokio::test] + async fn test_write_csv_to_parquet_with_cars_data() -> Result<()> { + let ctx = SessionContext::new(); + let csv_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("data") + .join("csv") + .join("cars.csv"); + + let parquet_dir = write_csv_to_parquet(&ctx, &csv_path).await?; + let df = ctx + .read_parquet(parquet_dir.path_str()?, ParquetReadOptions::default()) + .await?; + + let rows = df.sort(vec![col("speed").sort(true, true)])?; + 
assert_batches_eq!( + &[ + "+-------+-------+---------------------+", + "| car | speed | time |", + "+-------+-------+---------------------+", + "| red | 0.0 | 1996-04-12T12:05:15 |", + "| red | 1.0 | 1996-04-12T12:05:14 |", + "| green | 2.0 | 1996-04-12T12:05:14 |", + "| red | 3.0 | 1996-04-12T12:05:13 |", + "| red | 7.0 | 1996-04-12T12:05:10 |", + "| red | 7.1 | 1996-04-12T12:05:11 |", + "| red | 7.2 | 1996-04-12T12:05:12 |", + "| green | 8.0 | 1996-04-12T12:05:13 |", + "| green | 10.0 | 1996-04-12T12:05:03 |", + "| green | 10.3 | 1996-04-12T12:05:04 |", + "| green | 10.4 | 1996-04-12T12:05:05 |", + "| green | 10.5 | 1996-04-12T12:05:06 |", + "| green | 11.0 | 1996-04-12T12:05:07 |", + "| green | 12.0 | 1996-04-12T12:05:08 |", + "| green | 14.0 | 1996-04-12T12:05:09 |", + "| green | 15.0 | 1996-04-12T12:05:10 |", + "| green | 15.1 | 1996-04-12T12:05:11 |", + "| green | 15.2 | 1996-04-12T12:05:12 |", + "| red | 17.0 | 1996-04-12T12:05:09 |", + "| red | 18.0 | 1996-04-12T12:05:08 |", + "| red | 19.0 | 1996-04-12T12:05:07 |", + "| red | 20.0 | 1996-04-12T12:05:03 |", + "| red | 20.3 | 1996-04-12T12:05:04 |", + "| red | 21.4 | 1996-04-12T12:05:05 |", + "| red | 21.5 | 1996-04-12T12:05:06 |", + "+-------+-------+---------------------+", + ], + &rows.collect().await? 
+ ); + + Ok(()) + } + + #[tokio::test] + async fn test_write_csv_to_parquet_with_regex_data() -> Result<()> { + let ctx = SessionContext::new(); + let csv_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("data") + .join("csv") + .join("regex.csv"); + + let parquet_dir = write_csv_to_parquet(&ctx, &csv_path).await?; + let df = ctx + .read_parquet(parquet_dir.path_str()?, ParquetReadOptions::default()) + .await?; + + let rows = df.sort(vec![col("values").sort(true, true)])?; + assert_batches_eq!( + &[ + "+------------+--------------------------------------+-------------+-------+", + "| values | patterns | replacement | flags |", + "+------------+--------------------------------------+-------------+-------+", + "| 4000 | \\b4([1-9]\\d\\d|\\d[1-9]\\d|\\d\\d[1-9])\\b | xyz | |", + "| 4010 | \\b4([1-9]\\d\\d|\\d[1-9]\\d|\\d\\d[1-9])\\b | xyz | |", + "| ABC | ^(A).* | B | i |", + "| AbC | (B|D) | e | |", + "| Düsseldorf | [\\p{Letter}-]+ | München | |", + "| Köln | [a-zA-Z]ö[a-zA-Z]{2} | Koln | |", + "| aBC | ^(b|c) | d | |", + "| aBc | (b|d) | e | i |", + "| abc | ^(a) | bb\\1bb | i |", + "| Москва | [\\p{L}-]+ | Moscow | |", + "| اليوم | ^\\p{Arabic}+$ | Today | |", + "+------------+--------------------------------------+-------------+-------+", + ], + &rows.collect().await? 
+ ); + + Ok(()) + } + + #[tokio::test] + async fn test_write_csv_to_parquet_error() { + let ctx = SessionContext::new(); + let csv_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("data") + .join("csv") + .join("file-does-not-exist.csv"); + + let err = write_csv_to_parquet(&ctx, &csv_path).await.unwrap_err(); + match err { + DataFusionError::Execution(msg) => { + assert!( + msg.contains("CSV file does not exist"), + "unexpected error message: {msg}" + ); + } + other => panic!("unexpected error variant: {other:?}"), + } + } +} diff --git a/datafusion/sqllogictest/src/engines/postgres_engine/types.rs b/datafusion-examples/src/utils/datasets/cars.rs similarity index 53% rename from datafusion/sqllogictest/src/engines/postgres_engine/types.rs rename to datafusion-examples/src/utils/datasets/cars.rs index 510462befb08..2d8547c16d68 100644 --- a/datafusion/sqllogictest/src/engines/postgres_engine/types.rs +++ b/datafusion-examples/src/utils/datasets/cars.rs @@ -15,31 +15,19 @@ // specific language governing permissions and limitations // under the License. -use postgres_types::Type; -use std::fmt::Display; -use tokio_postgres::types::FromSql; +use std::sync::Arc; -pub struct PgRegtype { - value: String, -} - -impl<'a> FromSql<'a> for PgRegtype { - fn from_sql( - _: &Type, - buf: &'a [u8], - ) -> Result> { - let oid = postgres_protocol::types::oid_from_sql(buf)?; - let value = Type::from_oid(oid).ok_or("bad type")?.to_string(); - Ok(PgRegtype { value }) - } - - fn accepts(ty: &Type) -> bool { - matches!(*ty, Type::REGTYPE) - } -} +use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; -impl Display for PgRegtype { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.value) - } +/// Schema for the `data/csv/cars.csv` example dataset. 
+pub fn schema() -> Arc { + Arc::new(Schema::new(vec![ + Field::new("car", DataType::Utf8, false), + Field::new("speed", DataType::Float64, false), + Field::new( + "time", + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + ), + ])) } diff --git a/datafusion-examples/src/utils/datasets/mod.rs b/datafusion-examples/src/utils/datasets/mod.rs new file mode 100644 index 000000000000..1857e6af9b55 --- /dev/null +++ b/datafusion-examples/src/utils/datasets/mod.rs @@ -0,0 +1,139 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::path::PathBuf; + +use arrow_schema::SchemaRef; +use datafusion::error::{DataFusionError, Result}; + +pub mod cars; +pub mod regex; + +/// Describes example datasets used across DataFusion examples. +/// +/// This enum provides a single, discoverable place to define +/// dataset-specific metadata such as file paths and schemas. 
+#[derive(Debug)] +pub enum ExampleDataset { + Cars, + Regex, +} + +impl ExampleDataset { + pub fn file_stem(&self) -> &'static str { + match self { + Self::Cars => "cars", + Self::Regex => "regex", + } + } + + pub fn path(&self) -> PathBuf { + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("data") + .join("csv") + .join(format!("{}.csv", self.file_stem())) + } + + pub fn path_str(&self) -> Result { + let path = self.path(); + path.to_str().map(String::from).ok_or_else(|| { + DataFusionError::Execution(format!( + "CSV directory path is not valid UTF-8: {}", + path.display() + )) + }) + } + + pub fn schema(&self) -> SchemaRef { + match self { + Self::Cars => cars::schema(), + Self::Regex => regex::schema(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use arrow::datatypes::{DataType, TimeUnit}; + + #[test] + fn example_dataset_file_stem() { + assert_eq!(ExampleDataset::Cars.file_stem(), "cars"); + assert_eq!(ExampleDataset::Regex.file_stem(), "regex"); + } + + #[test] + fn example_dataset_path_points_to_csv() { + let path = ExampleDataset::Cars.path(); + assert!(path.ends_with("data/csv/cars.csv")); + + let path = ExampleDataset::Regex.path(); + assert!(path.ends_with("data/csv/regex.csv")); + } + + #[test] + fn example_dataset_path_str_is_valid_utf8() { + let path = ExampleDataset::Cars.path_str().unwrap(); + assert!(path.ends_with("cars.csv")); + + let path = ExampleDataset::Regex.path_str().unwrap(); + assert!(path.ends_with("regex.csv")); + } + + #[test] + fn cars_schema_is_stable() { + let schema = ExampleDataset::Cars.schema(); + + let fields: Vec<_> = schema + .fields() + .iter() + .map(|f| (f.name().as_str(), f.data_type().clone())) + .collect(); + + assert_eq!( + fields, + vec![ + ("car", DataType::Utf8), + ("speed", DataType::Float64), + ("time", DataType::Timestamp(TimeUnit::Nanosecond, None)), + ] + ); + } + + #[test] + fn regex_schema_is_stable() { + let schema = ExampleDataset::Regex.schema(); + + let fields: Vec<_> = schema + 
.fields() + .iter() + .map(|f| (f.name().as_str(), f.data_type().clone())) + .collect(); + + assert_eq!( + fields, + vec![ + ("values", DataType::Utf8), + ("patterns", DataType::Utf8), + ("replacement", DataType::Utf8), + ("flags", DataType::Utf8), + ] + ); + } +} diff --git a/datafusion-examples/src/utils/datasets/regex.rs b/datafusion-examples/src/utils/datasets/regex.rs new file mode 100644 index 000000000000..d44582126a05 --- /dev/null +++ b/datafusion-examples/src/utils/datasets/regex.rs @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field, Schema}; + +/// Schema for the `data/csv/regex.csv` example dataset. 
+pub fn schema() -> Arc { + Arc::new(Schema::new(vec![ + Field::new("values", DataType::Utf8, false), + Field::new("patterns", DataType::Utf8, false), + Field::new("replacement", DataType::Utf8, false), + Field::new("flags", DataType::Utf8, true), + ])) +} diff --git a/datafusion-examples/src/utils/example_metadata/discover.rs b/datafusion-examples/src/utils/example_metadata/discover.rs new file mode 100644 index 000000000000..1ba5f6d29a14 --- /dev/null +++ b/datafusion-examples/src/utils/example_metadata/discover.rs @@ -0,0 +1,103 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Utilities for discovering example groups in the repository filesystem. +//! +//! An example group is defined as a directory containing a `main.rs` file +//! under the examples root. This module is intentionally filesystem-focused +//! and does not perform any parsing or rendering. +//! Discovery fails if no valid example groups are found. + +use std::fs; +use std::path::{Path, PathBuf}; + +use datafusion::common::exec_err; +use datafusion::error::Result; + +/// Discovers all example group directories under the given root. +/// +/// A directory is considered an example group if it contains a `main.rs` file. 
+pub fn discover_example_groups(root: &Path) -> Result> { + let mut groups = Vec::new(); + for entry in fs::read_dir(root)? { + let entry = entry?; + let path = entry.path(); + + if path.is_dir() && path.join("main.rs").is_file() { + groups.push(path); + } + } + + if groups.is_empty() { + return exec_err!("No example groups found under: {}", root.display()); + } + + groups.sort(); + Ok(groups) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::utils::example_metadata::test_utils::assert_exec_err_contains; + + use std::fs::{self, File}; + + use tempfile::TempDir; + + #[test] + fn discover_example_groups_finds_dirs_with_main_rs() -> Result<()> { + let tmp = TempDir::new()?; + let root = tmp.path(); + + // valid example group + let group1 = root.join("group1"); + fs::create_dir(&group1)?; + File::create(group1.join("main.rs"))?; + + // not an example group + let group2 = root.join("group2"); + fs::create_dir(&group2)?; + + let groups = discover_example_groups(root)?; + assert_eq!(groups.len(), 1); + assert_eq!(groups[0], group1); + Ok(()) + } + + #[test] + fn discover_example_groups_errors_if_main_rs_is_a_directory() -> Result<()> { + let tmp = TempDir::new()?; + let root = tmp.path(); + let group = root.join("group"); + fs::create_dir(&group)?; + fs::create_dir(group.join("main.rs"))?; + + let err = discover_example_groups(root).unwrap_err(); + assert_exec_err_contains(err, "No example groups found"); + Ok(()) + } + + #[test] + fn discover_example_groups_errors_if_none_found() -> Result<()> { + let tmp = TempDir::new()?; + let err = discover_example_groups(tmp.path()).unwrap_err(); + assert_exec_err_contains(err, "No example groups found"); + Ok(()) + } +} diff --git a/datafusion-examples/src/utils/example_metadata/layout.rs b/datafusion-examples/src/utils/example_metadata/layout.rs new file mode 100644 index 000000000000..ee6fad89855f --- /dev/null +++ b/datafusion-examples/src/utils/example_metadata/layout.rs @@ -0,0 +1,113 @@ +// Licensed to the Apache 
Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Repository layout utilities. +//! +//! This module provides a small helper (`RepoLayout`) that encapsulates +//! knowledge about the DataFusion repository structure, in particular +//! where example groups are located relative to the repository root. + +use std::path::{Path, PathBuf}; + +use datafusion::error::{DataFusionError, Result}; + +/// Describes the layout of a DataFusion repository. +/// +/// This type centralizes knowledge about where example-related +/// directories live relative to the repository root. +#[derive(Debug, Clone)] +pub struct RepoLayout { + root: PathBuf, +} + +impl From<&Path> for RepoLayout { + fn from(path: &Path) -> Self { + Self { + root: path.to_path_buf(), + } + } +} + +impl RepoLayout { + /// Creates a layout from an explicit repository root. + pub fn from_root(root: PathBuf) -> Self { + Self { root } + } + + /// Detects the repository root based on `CARGO_MANIFEST_DIR`. + /// + /// This is intended for use from binaries inside the workspace. 
+ pub fn detect() -> Result { + let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + + let root = manifest_dir.parent().ok_or_else(|| { + DataFusionError::Execution( + "CARGO_MANIFEST_DIR does not have a parent".to_string(), + ) + })?; + + Ok(Self { + root: root.to_path_buf(), + }) + } + + /// Returns the repository root directory. + pub fn root(&self) -> &Path { + &self.root + } + + /// Returns the `datafusion-examples/examples` directory. + pub fn examples_root(&self) -> PathBuf { + self.root.join("datafusion-examples").join("examples") + } + + /// Returns the directory for a single example group. + /// + /// Example: `examples/udf` + pub fn example_group_dir(&self, group: &str) -> PathBuf { + self.examples_root().join(group) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn detect_sets_non_empty_root() -> Result<()> { + let layout = RepoLayout::detect()?; + assert!(!layout.root().as_os_str().is_empty()); + Ok(()) + } + + #[test] + fn examples_root_is_under_repo_root() -> Result<()> { + let layout = RepoLayout::detect()?; + let examples_root = layout.examples_root(); + assert!(examples_root.starts_with(layout.root())); + assert!(examples_root.ends_with("datafusion-examples/examples")); + Ok(()) + } + + #[test] + fn example_group_dir_appends_group_name() -> Result<()> { + let layout = RepoLayout::detect()?; + let group_dir = layout.example_group_dir("foo"); + assert!(group_dir.ends_with("datafusion-examples/examples/foo")); + Ok(()) + } +} diff --git a/datafusion-examples/src/utils/example_metadata/mod.rs b/datafusion-examples/src/utils/example_metadata/mod.rs new file mode 100644 index 000000000000..ab4c8e4a8e4c --- /dev/null +++ b/datafusion-examples/src/utils/example_metadata/mod.rs @@ -0,0 +1,67 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Documentation generator for DataFusion examples. +//! +//! # Design goals +//! +//! - Keep README.md in sync with runnable examples +//! - Fail fast on malformed documentation +//! +//! # Overview +//! +//! Each example group corresponds to a directory under +//! `datafusion-examples/examples/` containing a `main.rs` file. +//! Documentation is extracted from structured `//!` comments in that file. +//! +//! For each example group, the generator produces: +//! +//! ```text +//! ## Examples +//! ### Group: `` +//! #### Category: Single Process | Distributed +//! +//! | Subcommand | File Path | Description | +//! ``` +//! +//! # Usage +//! +//! Generate documentation for a single group only: +//! +//! ```bash +//! cargo run --bin examples-docs -- dataframe +//! ``` +//! +//! Generate documentation for all examples: +//! +//! ```bash +//! cargo run --bin examples-docs +//! 
``` + +pub mod discover; +pub mod layout; +pub mod model; +pub mod parser; +pub mod render; + +#[cfg(test)] +pub mod test_utils; + +pub use layout::RepoLayout; +pub use model::{Category, ExampleEntry, ExampleGroup, GroupName}; +pub use parser::parse_main_rs_docs; +pub use render::generate_examples_readme; diff --git a/datafusion-examples/src/utils/example_metadata/model.rs b/datafusion-examples/src/utils/example_metadata/model.rs new file mode 100644 index 000000000000..11416d141eb7 --- /dev/null +++ b/datafusion-examples/src/utils/example_metadata/model.rs @@ -0,0 +1,418 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Domain model for DataFusion example documentation. +//! +//! This module defines the core data structures used to represent +//! example groups, individual examples, and their categorization +//! as parsed from `main.rs` documentation comments. + +use std::path::Path; + +use datafusion::error::{DataFusionError, Result}; + +use crate::utils::example_metadata::parse_main_rs_docs; + +/// Well-known abbreviations used to preserve correct capitalization +/// when generating human-readable documentation titles. 
+const ABBREVIATIONS: &[(&str, &str)] = &[ + ("dataframe", "DataFrame"), + ("io", "IO"), + ("sql", "SQL"), + ("udf", "UDF"), +]; + +/// A group of related examples (e.g. `builtin_functions`, `udf`). +/// +/// Each group corresponds to a directory containing a `main.rs` file +/// with structured documentation comments. +#[derive(Debug)] +pub struct ExampleGroup { + pub name: GroupName, + pub examples: Vec, + pub category: Category, +} + +impl ExampleGroup { + /// Parses an example group from its directory. + /// + /// The group name is derived from the directory name, and example + /// entries are extracted from `main.rs`. + pub fn from_dir(dir: &Path, category: Category) -> Result { + let raw_name = dir + .file_name() + .and_then(|s| s.to_str()) + .ok_or_else(|| { + DataFusionError::Execution("Invalid example group dir".to_string()) + })? + .to_string(); + + let name = GroupName::from_dir_name(raw_name); + let main_rs = dir.join("main.rs"); + let examples = parse_main_rs_docs(&main_rs)?; + + Ok(Self { + name, + examples, + category, + }) + } +} + +/// Represents an example group name in both raw and human-readable forms. +/// +/// For example: +/// - raw: `builtin_functions` +/// - title: `Builtin Functions` +#[derive(Debug)] +pub struct GroupName { + raw: String, + title: String, +} + +impl GroupName { + /// Creates a group name from a directory name. + pub fn from_dir_name(raw: String) -> Self { + let title = raw + .split('_') + .map(format_part) + .collect::>() + .join(" "); + + Self { raw, title } + } + + /// Returns the raw group name (directory name). + pub fn raw(&self) -> &str { + &self.raw + } + + /// Returns a title-cased name for documentation. + pub fn title(&self) -> &str { + &self.title + } +} + +/// A single runnable example within a group. +/// +/// Each entry corresponds to a subcommand documented in `main.rs`. +#[derive(Debug)] +pub struct ExampleEntry { + /// CLI subcommand name. + pub subcommand: String, + /// Rust source file name. 
+ pub file: String, + /// Human-readable description. + pub desc: String, +} + +/// Execution category of an example group. +#[derive(Debug, Default)] +pub enum Category { + /// Runs in a single process. + #[default] + SingleProcess, + /// Requires a distributed setup. + Distributed, +} + +impl Category { + /// Returns the display name used in documentation. + pub fn name(&self) -> &str { + match self { + Self::SingleProcess => "Single Process", + Self::Distributed => "Distributed", + } + } + + /// Determines the category for a group by name. + pub fn for_group(name: &str) -> Self { + match name { + "flight" => Category::Distributed, + _ => Category::SingleProcess, + } + } +} + +/// Formats a single group-name segment for display. +/// +/// This function applies DataFusion-specific capitalization rules: +/// - Known abbreviations (e.g. `sql`, `io`, `udf`) are rendered in all caps +/// - All other segments fall back to standard Title Case +fn format_part(part: &str) -> String { + let lower = part.to_ascii_lowercase(); + + if let Some((_, replacement)) = ABBREVIATIONS.iter().find(|(k, _)| *k == lower) { + return replacement.to_string(); + } + + let mut chars = part.chars(); + match chars.next() { + Some(first) => first.to_uppercase().collect::() + chars.as_str(), + None => String::new(), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::utils::example_metadata::test_utils::{ + assert_exec_err_contains, example_group_from_docs, + }; + + use std::fs; + + use tempfile::TempDir; + + #[test] + fn category_for_group_works() { + assert!(matches!( + Category::for_group("flight"), + Category::Distributed + )); + assert!(matches!( + Category::for_group("anything_else"), + Category::SingleProcess + )); + } + + #[test] + fn all_subcommand_is_ignored() -> Result<()> { + let group = example_group_from_docs( + r#" + //! - `all` — run all examples included in this module + //! + //! - `foo` + //! 
(file: foo.rs, desc: foo example) + "#, + )?; + assert_eq!(group.examples.len(), 1); + assert_eq!(group.examples[0].subcommand, "foo"); + Ok(()) + } + + #[test] + fn metadata_without_subcommand_fails() { + let err = example_group_from_docs("//! (file: foo.rs, desc: missing subcommand)") + .unwrap_err(); + assert_exec_err_contains(err, "Metadata without preceding subcommand"); + } + + #[test] + fn group_name_handles_abbreviations() { + assert_eq!( + GroupName::from_dir_name("dataframe".to_string()).title(), + "DataFrame" + ); + assert_eq!( + GroupName::from_dir_name("data_io".to_string()).title(), + "Data IO" + ); + assert_eq!( + GroupName::from_dir_name("sql_ops".to_string()).title(), + "SQL Ops" + ); + assert_eq!(GroupName::from_dir_name("udf".to_string()).title(), "UDF"); + } + + #[test] + fn group_name_title_cases() { + let cases = [ + ("very_long_group_name", "Very Long Group Name"), + ("foo", "Foo"), + ("dataframe", "DataFrame"), + ("data_io", "Data IO"), + ("sql_ops", "SQL Ops"), + ("udf", "UDF"), + ]; + for (input, expected) in cases { + let name = GroupName::from_dir_name(input.to_string()); + assert_eq!(name.title(), expected); + } + } + + #[test] + fn parse_group_example_works() -> Result<()> { + let tmp = TempDir::new().unwrap(); + + // Simulate: examples/builtin_functions/ + let group_dir = tmp.path().join("builtin_functions"); + fs::create_dir(&group_dir)?; + + // Write a fake main.rs with docs + let main_rs = group_dir.join("main.rs"); + fs::write( + &main_rs, + r#" + // Licensed to the Apache Software Foundation (ASF) under one + // or more contributor license agreements. See the NOTICE file + // distributed with this work for additional information + // regarding copyright ownership. The ASF licenses this file + // to you under the Apache License, Version 2.0 (the + // "License"); you may not use this file except in compliance + // with the License. 
You may obtain a copy of the License at + // + // http://www.apache.org/licenses/LICENSE-2.0 + // + // Unless required by applicable law or agreed to in writing, + // software distributed under the License is distributed on an + // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + // KIND, either express or implied. See the License for the + // specific language governing permissions and limitations + // under the License. + // + //! # These are miscellaneous function-related examples + //! + //! These examples demonstrate miscellaneous function-related features. + //! + //! ## Usage + //! ```bash + //! cargo run --example builtin_functions -- [all|date_time|function_factory|regexp] + //! ``` + //! + //! Each subcommand runs a corresponding example: + //! - `all` — run all examples included in this module + //! + //! - `date_time` + //! (file: date_time.rs, desc: Examples of date-time related functions and queries) + //! + //! - `function_factory` + //! (file: function_factory.rs, desc: Register `CREATE FUNCTION` handler to implement SQL macros) + //! + //! - `regexp` + //! 
(file: regexp.rs, desc: Examples of using regular expression functions) + "#, + )?; + + let group = ExampleGroup::from_dir(&group_dir, Category::SingleProcess)?; + + // Assert group-level data + assert_eq!(group.name.title(), "Builtin Functions"); + assert_eq!(group.examples.len(), 3); + + // Assert 1 example + assert_eq!(group.examples[0].subcommand, "date_time"); + assert_eq!(group.examples[0].file, "date_time.rs"); + assert_eq!( + group.examples[0].desc, + "Examples of date-time related functions and queries" + ); + + // Assert 2 example + assert_eq!(group.examples[1].subcommand, "function_factory"); + assert_eq!(group.examples[1].file, "function_factory.rs"); + assert_eq!( + group.examples[1].desc, + "Register `CREATE FUNCTION` handler to implement SQL macros" + ); + + // Assert 3 example + assert_eq!(group.examples[2].subcommand, "regexp"); + assert_eq!(group.examples[2].file, "regexp.rs"); + assert_eq!( + group.examples[2].desc, + "Examples of using regular expression functions" + ); + + Ok(()) + } + + #[test] + fn duplicate_metadata_without_repeating_subcommand_fails() { + let err = example_group_from_docs( + r#" + //! - `foo` + //! (file: a.rs, desc: first) + //! (file: b.rs, desc: second) + "#, + ) + .unwrap_err(); + assert_exec_err_contains(err, "Metadata without preceding subcommand"); + } + + #[test] + fn duplicate_metadata_for_same_subcommand_fails() { + let err = example_group_from_docs( + r#" + //! - `foo` + //! (file: a.rs, desc: first) + //! + //! - `foo` + //! (file: b.rs, desc: second) + "#, + ) + .unwrap_err(); + assert_exec_err_contains(err, "Duplicate metadata for subcommand `foo`"); + } + + #[test] + fn metadata_must_follow_subcommand() { + let err = example_group_from_docs( + r#" + //! - `foo` + //! some unrelated comment + //! 
(file: foo.rs, desc: test) + "#, + ) + .unwrap_err(); + assert_exec_err_contains(err, "Metadata without preceding subcommand"); + } + + #[test] + fn preserves_example_order_from_main_rs() -> Result<()> { + let group = example_group_from_docs( + r#" + //! - `second` + //! (file: second.rs, desc: second example) + //! + //! - `first` + //! (file: first.rs, desc: first example) + //! + //! - `third` + //! (file: third.rs, desc: third example) + "#, + )?; + + let subcommands: Vec<&str> = group + .examples + .iter() + .map(|e| e.subcommand.as_str()) + .collect(); + + assert_eq!( + subcommands, + vec!["second", "first", "third"], + "examples must preserve the order defined in main.rs" + ); + + Ok(()) + } + + #[test] + fn metadata_can_follow_blank_doc_line() -> Result<()> { + let group = example_group_from_docs( + r#" + //! - `foo` + //! + //! (file: foo.rs, desc: test) + "#, + )?; + assert_eq!(group.examples.len(), 1); + Ok(()) + } +} diff --git a/datafusion-examples/src/utils/example_metadata/parser.rs b/datafusion-examples/src/utils/example_metadata/parser.rs new file mode 100644 index 000000000000..4ead3e5a2ae9 --- /dev/null +++ b/datafusion-examples/src/utils/example_metadata/parser.rs @@ -0,0 +1,267 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Parser for example metadata embedded in `main.rs` documentation comments. +//! +//! This module scans `//!` doc comments to extract example subcommands +//! and their associated metadata (file name and description), enforcing +//! a strict ordering and structure to avoid ambiguous documentation. + +use std::{collections::HashSet, fs, path::Path}; + +use datafusion::common::exec_err; +use datafusion::error::Result; +use nom::{ + Err, IResult, Parser, + bytes::complete::{tag, take_until, take_while}, + character::complete::multispace0, + combinator::all_consuming, + error::{Error, ErrorKind}, + sequence::{delimited, preceded}, +}; + +use crate::utils::example_metadata::ExampleEntry; + +/// Parsing state machine used while scanning `main.rs` docs. +/// +/// This makes the "subcommand - metadata" relationship explicit: +/// metadata is only valid immediately after a subcommand has been seen. +enum ParserState<'a> { + /// Not currently expecting metadata. + Idle, + /// A subcommand was just parsed; the next valid metadata (if any) + /// must belong to this subcommand. + SeenSubcommand(&'a str), +} + +/// Parses a subcommand declaration line from `main.rs` docs. +/// +/// Expected format: +/// ```text +/// //! - `` +/// ``` +fn parse_subcommand_line(input: &str) -> IResult<&str, &str> { + let parser = preceded( + multispace0, + delimited(tag("//! - `"), take_until("`"), tag("`")), + ); + all_consuming(parser).parse(input) +} + +/// Parses example metadata (file name and description) from `main.rs` docs. +/// +/// Expected format: +/// ```text +/// //! 
(file: .rs, desc: ) +/// ``` +fn parse_metadata_line(input: &str) -> IResult<&str, (&str, &str)> { + let parser = preceded( + multispace0, + preceded(tag("//!"), preceded(multispace0, take_while(|_| true))), + ); + let (rest, payload) = all_consuming(parser).parse(input)?; + + let content = payload + .strip_prefix("(") + .and_then(|s| s.strip_suffix(")")) + .ok_or_else(|| Err::Error(Error::new(payload, ErrorKind::Tag)))?; + + let (file, desc) = content + .strip_prefix("file:") + .ok_or_else(|| Err::Error(Error::new(payload, ErrorKind::Tag)))? + .split_once(", desc:") + .ok_or_else(|| Err::Error(Error::new(payload, ErrorKind::Tag)))?; + + Ok((rest, (file.trim(), desc.trim()))) +} + +/// Parses example entries from a group's `main.rs` file. +pub fn parse_main_rs_docs(path: &Path) -> Result> { + let content = fs::read_to_string(path)?; + let mut entries = vec![]; + let mut state = ParserState::Idle; + let mut seen_subcommands = HashSet::new(); + + for (line_no, raw_line) in content.lines().enumerate() { + let line = raw_line.trim(); + + // Try parsing subcommand, excluding `all` because it's not used in README + if let Ok((_, sub)) = parse_subcommand_line(line) { + state = if sub == "all" { + ParserState::Idle + } else { + ParserState::SeenSubcommand(sub) + }; + continue; + } + + // Try parsing metadata + if let Ok((_, (file, desc))) = parse_metadata_line(line) { + let subcommand = match state { + ParserState::SeenSubcommand(s) => s, + ParserState::Idle => { + return exec_err!( + "Metadata without preceding subcommand at {}:{}", + path.display(), + line_no + 1 + ); + } + }; + + if !seen_subcommands.insert(subcommand) { + return exec_err!("Duplicate metadata for subcommand `{subcommand}`"); + } + + entries.push(ExampleEntry { + subcommand: subcommand.to_string(), + file: file.to_string(), + desc: desc.to_string(), + }); + + state = ParserState::Idle; + continue; + } + + // If a non-blank doc line interrupts a pending subcommand, reset the state + if let 
ParserState::SeenSubcommand(_) = state + && is_non_blank_doc_line(line) + { + state = ParserState::Idle; + } + } + + Ok(entries) +} + +/// Returns `true` for non-blank Rust doc comment lines (`//!`). +/// +/// Used to detect when a subcommand is interrupted by unrelated documentation, +/// so metadata is only accepted immediately after a subcommand (blank doc lines +/// are allowed in between). +fn is_non_blank_doc_line(line: &str) -> bool { + line.starts_with("//!") && !line.trim_start_matches("//!").trim().is_empty() +} + +#[cfg(test)] +mod tests { + use super::*; + + use tempfile::TempDir; + + #[test] + fn parse_subcommand_line_accepts_valid_input() { + let line = "//! - `date_time`"; + let sub = parse_subcommand_line(line); + assert_eq!(sub, Ok(("", "date_time"))); + } + + #[test] + fn parse_subcommand_line_invalid_inputs() { + let err_lines = [ + "//! - ", + "//! - foo", + "//! - `foo` bar", + "//! --", + "//!-", + "//!--", + "//!", + "//", + "/", + "", + ]; + for line in err_lines { + assert!( + parse_subcommand_line(line).is_err(), + "expected error for input: {line}" + ); + } + } + + #[test] + fn parse_metadata_line_accepts_valid_input() { + let line = + "//! (file: date_time.rs, desc: Examples of date-time related functions)"; + let res = parse_metadata_line(line); + assert_eq!( + res, + Ok(( + "", + ("date_time.rs", "Examples of date-time related functions") + )) + ); + + let line = "//! (file: foo.rs, desc: Foo, bar, baz)"; + let res = parse_metadata_line(line); + assert_eq!(res, Ok(("", ("foo.rs", "Foo, bar, baz")))); + + let line = "//! (file: foo.rs, desc: Foo(FOO))"; + let res = parse_metadata_line(line); + assert_eq!(res, Ok(("", ("foo.rs", "Foo(FOO)")))); + } + + #[test] + fn parse_metadata_line_invalid_inputs() { + let bad_lines = [ + "//! (file: foo.rs)", + "//! (desc: missing file)", + "//! file: foo.rs, desc: test", + "//! file: foo.rs,desc: test", + "//! (file: foo.rs desc: test)", + "//! (file: foo.rs,desc: test)", + "//! 
(desc: test, file: foo.rs)", + "//! ()", + "//! (file: foo.rs, desc: test) extra", + "", + ]; + for line in bad_lines { + assert!( + parse_metadata_line(line).is_err(), + "expected error for input: {line}" + ); + } + } + + #[test] + fn parse_main_rs_docs_extracts_entries() -> Result<()> { + let tmp = TempDir::new().unwrap(); + let main_rs = tmp.path().join("main.rs"); + + fs::write( + &main_rs, + r#" + //! - `foo` + //! (file: foo.rs, desc: first example) + //! + //! - `bar` + //! (file: bar.rs, desc: second example) + "#, + )?; + + let entries = parse_main_rs_docs(&main_rs)?; + + assert_eq!(entries.len(), 2); + + assert_eq!(entries[0].subcommand, "foo"); + assert_eq!(entries[0].file, "foo.rs"); + assert_eq!(entries[0].desc, "first example"); + + assert_eq!(entries[1].subcommand, "bar"); + assert_eq!(entries[1].file, "bar.rs"); + assert_eq!(entries[1].desc, "second example"); + Ok(()) + } +} diff --git a/datafusion-examples/src/utils/example_metadata/render.rs b/datafusion-examples/src/utils/example_metadata/render.rs new file mode 100644 index 000000000000..a4ea620e7835 --- /dev/null +++ b/datafusion-examples/src/utils/example_metadata/render.rs @@ -0,0 +1,203 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +//! Markdown renderer for DataFusion example documentation. +//! +//! This module takes parsed example metadata and generates the +//! `README.md` content for `datafusion-examples`, including group +//! sections and example tables. + +use std::path::PathBuf; + +use datafusion::error::{DataFusionError, Result}; + +use crate::utils::example_metadata::discover::discover_example_groups; +use crate::utils::example_metadata::model::ExampleGroup; +use crate::utils::example_metadata::{Category, RepoLayout}; + +const STATIC_HEADER: &str = r#" + +# DataFusion Examples + +This crate includes end to end, highly commented examples of how to use +various DataFusion APIs to help you get started. + +## Prerequisites + +Run `git submodule update --init` to init test files. + +## Running Examples + +To run an example, use the `cargo run` command, such as: + +```bash +git clone https://github.com/apache/datafusion +cd datafusion +# Download test data +git submodule update --init + +# Change to the examples directory +cd datafusion-examples/examples + +# Run all examples in a group +cargo run --example -- all + +# Run a specific example within a group +cargo run --example -- + +# Run all examples in the `dataframe` group +cargo run --example dataframe -- all + +# Run a single example from the `dataframe` group +# (apply the same pattern for any other group) +cargo run --example dataframe -- dataframe +``` +"#; + +/// Generates Markdown documentation for DataFusion examples. +/// +/// If `group` is `None`, documentation is generated for all example groups. +/// If `group` is `Some`, only that group is rendered. 
+/// +/// # Errors +/// +/// Returns an error if: +/// - the requested group does not exist +/// - a `main.rs` file is missing +/// - documentation comments are malformed +pub fn generate_examples_readme( + layout: &RepoLayout, + group: Option<&str>, +) -> Result { + let examples_root = layout.examples_root(); + + let mut out = String::new(); + out.push_str(STATIC_HEADER); + + let group_dirs: Vec = match group { + Some(name) => { + let dir = examples_root.join(name); + if !dir.is_dir() { + return Err(DataFusionError::Execution(format!( + "Example group `{name}` does not exist" + ))); + } + vec![dir] + } + None => discover_example_groups(&examples_root)?, + }; + + for group_dir in group_dirs { + let raw_name = + group_dir + .file_name() + .and_then(|s| s.to_str()) + .ok_or_else(|| { + DataFusionError::Execution("Invalid example group dir".to_string()) + })?; + + let category = Category::for_group(raw_name); + let group = ExampleGroup::from_dir(&group_dir, category)?; + + out.push_str(&group.render_markdown()); + } + + Ok(out) +} + +impl ExampleGroup { + /// Renders this example group as a Markdown section for the README. 
+ pub fn render_markdown(&self) -> String { + let mut out = String::new(); + out.push_str(&format!("\n## {} Examples\n\n", self.name.title())); + out.push_str(&format!("### Group: `{}`\n\n", self.name.raw())); + out.push_str(&format!("#### Category: {}\n\n", self.category.name())); + out.push_str("| Subcommand | File Path | Description |\n"); + out.push_str("| --- | --- | --- |\n"); + + for example in &self.examples { + out.push_str(&format!( + "| {} | [`{}/{}`](examples/{}/{}) | {} |\n", + example.subcommand, + self.name.raw(), + example.file, + self.name.raw(), + example.file, + example.desc + )); + } + + out + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::utils::example_metadata::test_utils::assert_exec_err_contains; + + use std::fs; + + use tempfile::TempDir; + + #[test] + fn single_group_generation_works() { + let tmp = TempDir::new().unwrap(); + // Fake repo root + let layout = RepoLayout::from_root(tmp.path().to_path_buf()); + + // Create: datafusion-examples/examples/builtin_functions + let examples_dir = layout.example_group_dir("builtin_functions"); + fs::create_dir_all(&examples_dir).unwrap(); + + fs::write( + examples_dir.join("main.rs"), + "//! - `x`\n//! 
(file: foo.rs, desc: test)", + ) + .unwrap(); + + let out = generate_examples_readme(&layout, Some("builtin_functions")).unwrap(); + assert!(out.contains("Builtin Functions")); + assert!(out.contains("| x | [`builtin_functions/foo.rs`]")); + } + + #[test] + fn single_group_generation_fails_if_group_missing() { + let tmp = TempDir::new().unwrap(); + let layout = RepoLayout::from_root(tmp.path().to_path_buf()); + let err = generate_examples_readme(&layout, Some("missing_group")).unwrap_err(); + assert_exec_err_contains(err, "Example group `missing_group` does not exist"); + } +} diff --git a/datafusion-examples/src/utils/example_metadata/test_utils.rs b/datafusion-examples/src/utils/example_metadata/test_utils.rs new file mode 100644 index 000000000000..d6ab3b06ba06 --- /dev/null +++ b/datafusion-examples/src/utils/example_metadata/test_utils.rs @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Test helpers for example metadata parsing and validation. +//! +//! This module provides small, focused utilities to reduce duplication +//! and keep tests readable across the example metadata submodules. 
+ +use std::fs; + +use datafusion::error::{DataFusionError, Result}; +use tempfile::TempDir; + +use crate::utils::example_metadata::{Category, ExampleGroup}; + +/// Asserts that an `Execution` error contains the expected message fragment. +/// +/// Keeps tests focused on semantic error causes without coupling them +/// to full error string formatting. +pub fn assert_exec_err_contains(err: DataFusionError, needle: &str) { + match err { + DataFusionError::Execution(msg) => { + assert!( + msg.contains(needle), + "expected '{needle}' in error message, got: {msg}" + ); + } + other => panic!("expected Execution error, got: {other:?}"), + } +} + +/// Helper for grammar-focused tests. +/// +/// Creates a minimal temporary example group with a single `main.rs` +/// containing the provided docs. Intended for testing parsing and +/// validation rules, not full integration behavior. +pub fn example_group_from_docs(docs: &str) -> Result { + let tmp = TempDir::new().map_err(|e| { + DataFusionError::Execution(format!("Failed initializing temp dir: {e}")) + })?; + let dir = tmp.path().join("group"); + fs::create_dir(&dir).map_err(|e| { + DataFusionError::Execution(format!("Failed creating temp dir: {e}")) + })?; + fs::write(dir.join("main.rs"), docs).map_err(|e| { + DataFusionError::Execution(format!("Failed writing to temp file: {e}")) + })?; + ExampleGroup::from_dir(&dir, Category::SingleProcess) +} diff --git a/datafusion-examples/src/utils/mod.rs b/datafusion-examples/src/utils/mod.rs new file mode 100644 index 000000000000..da96724a49cb --- /dev/null +++ b/datafusion-examples/src/utils/mod.rs @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod csv_to_parquet; +pub mod datasets; +pub mod example_metadata; + +pub use csv_to_parquet::write_csv_to_parquet; diff --git a/datafusion/catalog-listing/src/helpers.rs b/datafusion/catalog-listing/src/helpers.rs index ea016015cebd..031b2ebfb810 100644 --- a/datafusion/catalog-listing/src/helpers.rs +++ b/datafusion/catalog-listing/src/helpers.rs @@ -83,6 +83,7 @@ pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { | Expr::Exists(_) | Expr::InSubquery(_) | Expr::ScalarSubquery(_) + | Expr::SetComparison(_) | Expr::GroupingSet(_) | Expr::Case(_) => Ok(TreeNodeRecursion::Continue), diff --git a/datafusion/catalog-listing/src/mod.rs b/datafusion/catalog-listing/src/mod.rs index 28bd880ea01f..9efb5aa96267 100644 --- a/datafusion/catalog-listing/src/mod.rs +++ b/datafusion/catalog-listing/src/mod.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-#![deny(clippy::allow_attributes)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] #![doc( html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", diff --git a/datafusion/catalog-listing/src/table.rs b/datafusion/catalog-listing/src/table.rs index 9fb2dd2dce29..a5de79b052a4 100644 --- a/datafusion/catalog-listing/src/table.rs +++ b/datafusion/catalog-listing/src/table.rs @@ -28,12 +28,13 @@ use datafusion_common::{ use datafusion_datasource::file::FileSource; use datafusion_datasource::file_groups::FileGroup; use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder}; -use datafusion_datasource::file_sink_config::FileSinkConfig; +use datafusion_datasource::file_sink_config::{FileOutputMode, FileSinkConfig}; #[expect(deprecated)] use datafusion_datasource::schema_adapter::SchemaAdapterFactory; use datafusion_datasource::{ ListingTableUrl, PartitionedFile, TableSchema, compute_all_files_statistics, }; +use datafusion_execution::cache::TableScopedPath; use datafusion_execution::cache::cache_manager::FileStatisticsCache; use datafusion_execution::cache::cache_unit::DefaultFileStatisticsCache; use datafusion_expr::dml::InsertOp; @@ -336,16 +337,103 @@ impl ListingTable { self.options.format.file_source(table_schema) } - /// If file_sort_order is specified, creates the appropriate physical expressions + /// Creates output ordering from user-specified file_sort_order or derives + /// from file orderings when user doesn't specify. + /// + /// If user specified `file_sort_order`, that takes precedence. + /// Otherwise, attempts to derive common ordering from file orderings in + /// the provided file groups. 
pub fn try_create_output_ordering( &self, execution_props: &ExecutionProps, + file_groups: &[FileGroup], ) -> datafusion_common::Result> { - create_lex_ordering( - &self.table_schema, - &self.options.file_sort_order, - execution_props, - ) + // If user specified sort order, use that + if !self.options.file_sort_order.is_empty() { + return create_lex_ordering( + &self.table_schema, + &self.options.file_sort_order, + execution_props, + ); + } + if let Some(ordering) = derive_common_ordering_from_files(file_groups) { + return Ok(vec![ordering]); + } + Ok(vec![]) + } +} + +/// Derives a common ordering from file orderings across all file groups. +/// +/// Returns the common ordering if all files have compatible orderings, +/// otherwise returns None. +/// +/// The function finds the longest common prefix among all file orderings. +/// For example, if files have orderings `[a, b, c]` and `[a, b]`, the common +/// ordering is `[a, b]`. +fn derive_common_ordering_from_files(file_groups: &[FileGroup]) -> Option { + enum CurrentOrderingState { + /// Initial state before processing any files + FirstFile, + /// Some common ordering found so far + SomeOrdering(LexOrdering), + /// No files have ordering + NoOrdering, + } + let mut state = CurrentOrderingState::FirstFile; + + // Collect file orderings and track counts + for group in file_groups { + for file in group.iter() { + state = match (&state, &file.ordering) { + // If this is the first file with ordering, set it as current + (CurrentOrderingState::FirstFile, Some(ordering)) => { + CurrentOrderingState::SomeOrdering(ordering.clone()) + } + (CurrentOrderingState::FirstFile, None) => { + CurrentOrderingState::NoOrdering + } + // If we have an existing ordering, find common prefix with new ordering + (CurrentOrderingState::SomeOrdering(current), Some(ordering)) => { + // Find common prefix between current and new ordering + let prefix_len = current + .as_ref() + .iter() + .zip(ordering.as_ref().iter()) + .take_while(|(a, b)| 
a == b) + .count(); + if prefix_len == 0 { + log::trace!( + "Cannot derive common ordering: no common prefix between orderings {current:?} and {ordering:?}" + ); + return None; + } else { + let ordering = + LexOrdering::new(current.as_ref()[..prefix_len].to_vec()) + .expect("prefix_len > 0, so ordering must be valid"); + CurrentOrderingState::SomeOrdering(ordering) + } + } + // If one file has ordering and another doesn't, no common ordering + // Return None and log a trace message explaining why + (CurrentOrderingState::SomeOrdering(ordering), None) + | (CurrentOrderingState::NoOrdering, Some(ordering)) => { + log::trace!( + "Cannot derive common ordering: some files have ordering {ordering:?}, others don't" + ); + return None; + } + // Both have no ordering, remain in NoOrdering state + (CurrentOrderingState::NoOrdering, None) => { + CurrentOrderingState::NoOrdering + } + }; + } + } + + match state { + CurrentOrderingState::SomeOrdering(ordering) => Some(ordering), + _ => None, } } @@ -438,7 +526,10 @@ impl TableProvider for ListingTable { return Ok(ScanResult::new(Arc::new(EmptyExec::new(projected_schema)))); } - let output_ordering = self.try_create_output_ordering(state.execution_props())?; + let output_ordering = self.try_create_output_ordering( + state.execution_props(), + &partitioned_file_lists, + )?; match state .config_options() .execution @@ -565,7 +656,11 @@ impl TableProvider for ListingTable { // Invalidate cache entries for this table if they exist if let Some(lfc) = state.runtime_env().cache_manager.get_list_files_cache() { - let _ = lfc.remove(table_path.prefix()); + let key = TableScopedPath { + table: table_path.get_table_ref().clone(), + path: table_path.prefix().clone(), + }; + let _ = lfc.remove(&key); } // Sink related option, apart from format @@ -579,9 +674,11 @@ impl TableProvider for ListingTable { insert_op, keep_partition_by_columns, file_extension: self.options().format.get_ext(), + file_output_mode: FileOutputMode::Automatic, }; - 
let orderings = self.try_create_output_ordering(state.execution_props())?; + // For writes, we only use user-specified ordering (no file groups to derive from) + let orderings = self.try_create_output_ordering(state.execution_props(), &[])?; // It is sufficient to pass only one of the equivalent orderings: let order_requirements = orderings.into_iter().next().map(Into::into); @@ -630,16 +727,19 @@ impl ListingTable { let meta_fetch_concurrency = ctx.config_options().execution.meta_fetch_concurrency; let file_list = stream::iter(file_list).flatten_unordered(meta_fetch_concurrency); - // collect the statistics if required by the config + // collect the statistics and ordering if required by the config let files = file_list .map(|part_file| async { let part_file = part_file?; - let statistics = if self.options.collect_stat { - self.do_collect_statistics(ctx, &store, &part_file).await? + let (statistics, ordering) = if self.options.collect_stat { + self.do_collect_statistics_and_ordering(ctx, &store, &part_file) + .await? } else { - Arc::new(Statistics::new_unknown(&self.file_schema)) + (Arc::new(Statistics::new_unknown(&self.file_schema)), None) }; - Ok(part_file.with_statistics(statistics)) + Ok(part_file + .with_statistics(statistics) + .with_ordering(ordering)) }) .boxed() .buffer_unordered(ctx.config_options().execution.meta_fetch_concurrency); @@ -694,42 +794,50 @@ impl ListingTable { }) } - /// Collects statistics for a given partitioned file. + /// Collects statistics and ordering for a given partitioned file. /// - /// This method first checks if the statistics for the given file are already cached. - /// If they are, it returns the cached statistics. - /// If they are not, it infers the statistics from the file and stores them in the cache. - async fn do_collect_statistics( + /// This method checks if statistics are cached. If cached, it returns the + /// cached statistics and infers ordering separately. 
If not cached, it infers + /// both statistics and ordering in a single metadata read for efficiency. + async fn do_collect_statistics_and_ordering( &self, ctx: &dyn Session, store: &Arc, part_file: &PartitionedFile, - ) -> datafusion_common::Result> { - match self - .collected_statistics - .get_with_extra(&part_file.object_meta.location, &part_file.object_meta) + ) -> datafusion_common::Result<(Arc, Option)> { + use datafusion_execution::cache::cache_manager::CachedFileMetadata; + + let path = &part_file.object_meta.location; + let meta = &part_file.object_meta; + + // Check cache first - if we have valid cached statistics and ordering + if let Some(cached) = self.collected_statistics.get(path) + && cached.is_valid_for(meta) { - Some(statistics) => Ok(statistics), - None => { - let statistics = self - .options - .format - .infer_stats( - ctx, - store, - Arc::clone(&self.file_schema), - &part_file.object_meta, - ) - .await?; - let statistics = Arc::new(statistics); - self.collected_statistics.put_with_extra( - &part_file.object_meta.location, - Arc::clone(&statistics), - &part_file.object_meta, - ); - Ok(statistics) - } + // Return cached statistics and ordering + return Ok((Arc::clone(&cached.statistics), cached.ordering.clone())); } + + // Cache miss or invalid: fetch both statistics and ordering in a single metadata read + let file_meta = self + .options + .format + .infer_stats_and_ordering(ctx, store, Arc::clone(&self.file_schema), meta) + .await?; + + let statistics = Arc::new(file_meta.statistics); + + // Store in cache + self.collected_statistics.put( + path, + CachedFileMetadata::new( + meta.clone(), + Arc::clone(&statistics), + file_meta.ordering.clone(), + ), + ); + + Ok((statistics, file_meta.ordering)) } } @@ -805,3 +913,146 @@ async fn get_files_with_limit( let inexact_stats = all_files.next().await.is_some(); Ok((file_group, inexact_stats)) } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::compute::SortOptions; + use 
datafusion_physical_expr::expressions::Column; + use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; + use std::sync::Arc; + + /// Helper to create a PhysicalSortExpr + fn sort_expr( + name: &str, + idx: usize, + descending: bool, + nulls_first: bool, + ) -> PhysicalSortExpr { + PhysicalSortExpr::new( + Arc::new(Column::new(name, idx)), + SortOptions { + descending, + nulls_first, + }, + ) + } + + /// Helper to create a LexOrdering (unwraps the Option) + fn lex_ordering(exprs: Vec) -> LexOrdering { + LexOrdering::new(exprs).expect("expected non-empty ordering") + } + + /// Helper to create a PartitionedFile with optional ordering + fn create_file(name: &str, ordering: Option) -> PartitionedFile { + PartitionedFile::new(name.to_string(), 1024).with_ordering(ordering) + } + + #[test] + fn test_derive_common_ordering_all_files_same_ordering() { + // All files have the same ordering -> returns that ordering + let ordering = lex_ordering(vec![ + sort_expr("a", 0, false, true), + sort_expr("b", 1, true, false), + ]); + + let file_groups = vec![ + FileGroup::new(vec![ + create_file("f1.parquet", Some(ordering.clone())), + create_file("f2.parquet", Some(ordering.clone())), + ]), + FileGroup::new(vec![create_file("f3.parquet", Some(ordering.clone()))]), + ]; + + let result = derive_common_ordering_from_files(&file_groups); + assert_eq!(result, Some(ordering)); + } + + #[test] + fn test_derive_common_ordering_common_prefix() { + // Files have different orderings but share a common prefix + let ordering_abc = lex_ordering(vec![ + sort_expr("a", 0, false, true), + sort_expr("b", 1, false, true), + sort_expr("c", 2, false, true), + ]); + let ordering_ab = lex_ordering(vec![ + sort_expr("a", 0, false, true), + sort_expr("b", 1, false, true), + ]); + + let file_groups = vec![FileGroup::new(vec![ + create_file("f1.parquet", Some(ordering_abc)), + create_file("f2.parquet", Some(ordering_ab.clone())), + ])]; + + let result = 
derive_common_ordering_from_files(&file_groups); + assert_eq!(result, Some(ordering_ab)); + } + + #[test] + fn test_derive_common_ordering_no_common_prefix() { + // Files have completely different orderings -> returns None + let ordering_a = lex_ordering(vec![sort_expr("a", 0, false, true)]); + let ordering_b = lex_ordering(vec![sort_expr("b", 1, false, true)]); + + let file_groups = vec![FileGroup::new(vec![ + create_file("f1.parquet", Some(ordering_a)), + create_file("f2.parquet", Some(ordering_b)), + ])]; + + let result = derive_common_ordering_from_files(&file_groups); + assert_eq!(result, None); + } + + #[test] + fn test_derive_common_ordering_mixed_with_none() { + // Some files have ordering, some don't -> returns None + let ordering = lex_ordering(vec![sort_expr("a", 0, false, true)]); + + let file_groups = vec![FileGroup::new(vec![ + create_file("f1.parquet", Some(ordering)), + create_file("f2.parquet", None), + ])]; + + let result = derive_common_ordering_from_files(&file_groups); + assert_eq!(result, None); + } + + #[test] + fn test_derive_common_ordering_all_none() { + // No files have ordering -> returns None + let file_groups = vec![FileGroup::new(vec![ + create_file("f1.parquet", None), + create_file("f2.parquet", None), + ])]; + + let result = derive_common_ordering_from_files(&file_groups); + assert_eq!(result, None); + } + + #[test] + fn test_derive_common_ordering_empty_groups() { + // Empty file groups -> returns None + let file_groups: Vec = vec![]; + let result = derive_common_ordering_from_files(&file_groups); + assert_eq!(result, None); + } + + #[test] + fn test_derive_common_ordering_single_file() { + // Single file with ordering -> returns that ordering + let ordering = lex_ordering(vec![ + sort_expr("a", 0, false, true), + sort_expr("b", 1, true, false), + ]); + + let file_groups = vec![FileGroup::new(vec![create_file( + "f1.parquet", + Some(ordering.clone()), + )])]; + + let result = derive_common_ordering_from_files(&file_groups); + 
assert_eq!(result, Some(ordering)); + } +} diff --git a/datafusion/catalog/src/information_schema.rs b/datafusion/catalog/src/information_schema.rs index 52bfeca3d428..ea93dc21a3f5 100644 --- a/datafusion/catalog/src/information_schema.rs +++ b/datafusion/catalog/src/information_schema.rs @@ -24,7 +24,7 @@ use crate::{CatalogProviderList, SchemaProvider, TableProvider}; use arrow::array::builder::{BooleanBuilder, UInt8Builder}; use arrow::{ array::{StringBuilder, UInt64Builder}, - datatypes::{DataType, Field, Schema, SchemaRef}, + datatypes::{DataType, Field, FieldRef, Schema, SchemaRef}, record_batch::RecordBatch, }; use async_trait::async_trait; @@ -34,7 +34,10 @@ use datafusion_common::error::Result; use datafusion_common::types::NativeType; use datafusion_execution::TaskContext; use datafusion_execution::runtime_env::RuntimeEnv; -use datafusion_expr::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF}; +use datafusion_expr::function::WindowUDFFieldArgs; +use datafusion_expr::{ + AggregateUDF, ReturnFieldArgs, ScalarUDF, Signature, TypeSignature, WindowUDF, +}; use datafusion_expr::{TableType, Volatility}; use datafusion_physical_plan::SendableRecordBatchStream; use datafusion_physical_plan::stream::RecordBatchStreamAdapter; @@ -421,10 +424,24 @@ fn get_udf_args_and_return_types( Ok(arg_types .into_iter() .map(|arg_types| { - // only handle the function which implemented [`ScalarUDFImpl::return_type`] method + let arg_fields: Vec = arg_types + .iter() + .enumerate() + .map(|(i, t)| { + Arc::new(Field::new(format!("arg_{i}"), t.clone(), true)) + }) + .collect(); + let scalar_arguments = vec![None; arg_fields.len()]; let return_type = udf - .return_type(&arg_types) - .map(|t| remove_native_type_prefix(&NativeType::from(t))) + .return_field_from_args(ReturnFieldArgs { + arg_fields: &arg_fields, + scalar_arguments: &scalar_arguments, + }) + .map(|f| { + remove_native_type_prefix(&NativeType::from( + f.data_type().clone(), + )) + }) .ok(); let arg_types = 
arg_types .into_iter() @@ -447,11 +464,21 @@ fn get_udaf_args_and_return_types( Ok(arg_types .into_iter() .map(|arg_types| { - // only handle the function which implemented [`ScalarUDFImpl::return_type`] method + let arg_fields: Vec = arg_types + .iter() + .enumerate() + .map(|(i, t)| { + Arc::new(Field::new(format!("arg_{i}"), t.clone(), true)) + }) + .collect(); let return_type = udaf - .return_type(&arg_types) - .ok() - .map(|t| remove_native_type_prefix(&NativeType::from(t))); + .return_field(&arg_fields) + .map(|f| { + remove_native_type_prefix(&NativeType::from( + f.data_type().clone(), + )) + }) + .ok(); let arg_types = arg_types .into_iter() .map(|t| remove_native_type_prefix(&NativeType::from(t))) @@ -473,12 +500,26 @@ fn get_udwf_args_and_return_types( Ok(arg_types .into_iter() .map(|arg_types| { - // only handle the function which implemented [`ScalarUDFImpl::return_type`] method + let arg_fields: Vec = arg_types + .iter() + .enumerate() + .map(|(i, t)| { + Arc::new(Field::new(format!("arg_{i}"), t.clone(), true)) + }) + .collect(); + let return_type = udwf + .field(WindowUDFFieldArgs::new(&arg_fields, udwf.name())) + .map(|f| { + remove_native_type_prefix(&NativeType::from( + f.data_type().clone(), + )) + }) + .ok(); let arg_types = arg_types .into_iter() .map(|t| remove_native_type_prefix(&NativeType::from(t))) .collect::>(); - (arg_types, None) + (arg_types, return_type) }) .collect::>()) } diff --git a/datafusion/catalog/src/lib.rs b/datafusion/catalog/src/lib.rs index d1cd3998fecf..931941e8fdfa 100644 --- a/datafusion/catalog/src/lib.rs +++ b/datafusion/catalog/src/lib.rs @@ -24,7 +24,6 @@ // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] -#![deny(clippy::allow_attributes)] //! Interfaces and default implementations of catalogs and schemas. //! 
diff --git a/datafusion/catalog/src/memory/table.rs b/datafusion/catalog/src/memory/table.rs index 7865eb016bee..484b5f805e54 100644 --- a/datafusion/catalog/src/memory/table.rs +++ b/datafusion/catalog/src/memory/table.rs @@ -549,7 +549,7 @@ fn evaluate_filters_to_mask( struct DmlResultExec { rows_affected: u64, schema: SchemaRef, - properties: PlanProperties, + properties: Arc, } impl DmlResultExec { @@ -570,7 +570,7 @@ impl DmlResultExec { Self { rows_affected, schema, - properties, + properties: Arc::new(properties), } } } @@ -604,7 +604,7 @@ impl ExecutionPlan for DmlResultExec { Arc::clone(&self.schema) } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.properties } diff --git a/datafusion/catalog/src/streaming.rs b/datafusion/catalog/src/streaming.rs index 31669171b291..db9596b420b7 100644 --- a/datafusion/catalog/src/streaming.rs +++ b/datafusion/catalog/src/streaming.rs @@ -20,19 +20,18 @@ use std::any::Any; use std::sync::Arc; -use crate::Session; -use crate::TableProvider; - use arrow::datatypes::SchemaRef; +use async_trait::async_trait; use datafusion_common::{DFSchema, Result, plan_err}; use datafusion_expr::{Expr, SortExpr, TableType}; +use datafusion_physical_expr::equivalence::project_ordering; use datafusion_physical_expr::{LexOrdering, create_physical_sort_exprs}; use datafusion_physical_plan::ExecutionPlan; use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; - -use async_trait::async_trait; use log::debug; +use crate::{Session, TableProvider}; + /// A [`TableProvider`] that streams a set of [`PartitionStream`] #[derive(Debug)] pub struct StreamingTable { @@ -105,7 +104,22 @@ impl TableProvider for StreamingTable { let df_schema = DFSchema::try_from(Arc::clone(&self.schema))?; let eqp = state.execution_props(); - create_physical_sort_exprs(&self.sort_order, &df_schema, eqp)? 
+ let original_sort_exprs = + create_physical_sort_exprs(&self.sort_order, &df_schema, eqp)?; + + if let Some(p) = projection { + // When performing a projection, the output columns will not match + // the original physical sort expression indices. Also the sort columns + // may not be in the output projection. To correct for these issues + // we need to project the ordering based on the output schema. + let schema = Arc::new(self.schema.project(p)?); + LexOrdering::new(original_sort_exprs) + .and_then(|lex_ordering| project_ordering(&lex_ordering, &schema)) + .map(|lex_ordering| lex_ordering.to_vec()) + .unwrap_or_default() + } else { + original_sort_exprs + } } else { vec![] }; diff --git a/datafusion/catalog/src/table.rs b/datafusion/catalog/src/table.rs index 1f223852c2b9..f31d4d52ce88 100644 --- a/datafusion/catalog/src/table.rs +++ b/datafusion/catalog/src/table.rs @@ -353,6 +353,14 @@ pub trait TableProvider: Debug + Sync + Send { ) -> Result> { not_impl_err!("UPDATE not supported for {} table", self.table_type()) } + + /// Remove all rows from the table. + /// + /// Should return an [ExecutionPlan] producing a single row with count (UInt64), + /// representing the number of rows removed. + async fn truncate(&self, _state: &dyn Session) -> Result> { + not_impl_err!("TRUNCATE not supported for {} table", self.table_type()) + } } /// Arguments for scanning a table with [`TableProvider::scan_with_args`]. diff --git a/datafusion/common-runtime/src/lib.rs b/datafusion/common-runtime/src/lib.rs index fdbfe7f2390c..cf45ccf3ef63 100644 --- a/datafusion/common-runtime/src/lib.rs +++ b/datafusion/common-runtime/src/lib.rs @@ -16,7 +16,6 @@ // under the License. 
#![cfg_attr(test, allow(clippy::needless_pass_by_value))] -#![deny(clippy::allow_attributes)] #![doc( html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index 262f50839563..e4ba71e45c66 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -57,6 +57,10 @@ sql = ["sqlparser"] harness = false name = "with_hashes" +[[bench]] +harness = false +name = "scalar_to_array" + [dependencies] ahash = { workspace = true } apache-avro = { workspace = true, features = [ @@ -72,7 +76,8 @@ half = { workspace = true } hashbrown = { workspace = true } hex = { workspace = true, optional = true } indexmap = { workspace = true } -libc = "0.2.177" +itertools = { workspace = true } +libc = "0.2.180" log = { workspace = true } object_store = { workspace = true, optional = true } parquet = { workspace = true, optional = true, default-features = true } diff --git a/datafusion/common/benches/scalar_to_array.rs b/datafusion/common/benches/scalar_to_array.rs new file mode 100644 index 000000000000..90a152e515fe --- /dev/null +++ b/datafusion/common/benches/scalar_to_array.rs @@ -0,0 +1,107 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Benchmarks for `ScalarValue::to_array_of_size`, focusing on List +//! scalars. + +use arrow::array::{Array, ArrayRef, AsArray, StringViewBuilder}; +use arrow::datatypes::{DataType, Field}; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::utils::SingleRowListArrayBuilder; +use std::sync::Arc; + +/// Build a `ScalarValue::List` of `num_elements` Utf8View strings whose +/// inner StringViewArray has `num_buffers` data buffers. +fn make_list_scalar(num_elements: usize, num_buffers: usize) -> ScalarValue { + let elements_per_buffer = num_elements.div_ceil(num_buffers); + + let mut small_arrays: Vec = Vec::new(); + let mut remaining = num_elements; + for buf_idx in 0..num_buffers { + let count = remaining.min(elements_per_buffer); + if count == 0 { + break; + } + let start = buf_idx * elements_per_buffer; + let mut builder = StringViewBuilder::with_capacity(count); + for i in start..start + count { + builder.append_value(format!("{i:024x}")); + } + small_arrays.push(Arc::new(builder.finish()) as ArrayRef); + remaining -= count; + } + + let refs: Vec<&dyn Array> = small_arrays.iter().map(|a| a.as_ref()).collect(); + let concated = arrow::compute::concat(&refs).unwrap(); + + let list_array = SingleRowListArrayBuilder::new(concated) + .with_field(&Field::new_list_field(DataType::Utf8View, true)) + .build_list_array(); + ScalarValue::List(Arc::new(list_array)) +} + +/// We want to measure the cost of doing the conversion and then also accessing +/// the results, to 
model what would happen during query evaluation. +fn consume_list_array(arr: &ArrayRef) { + let list_arr = arr.as_list::(); + let mut total_len: usize = 0; + for i in 0..list_arr.len() { + let inner = list_arr.value(i); + let sv = inner.as_string_view(); + for j in 0..sv.len() { + total_len += sv.value(j).len(); + } + } + std::hint::black_box(total_len); +} + +fn bench_list_to_array_of_size(c: &mut Criterion) { + let mut group = c.benchmark_group("list_to_array_of_size"); + + let num_elements = 1245; + let scalar_1buf = make_list_scalar(num_elements, 1); + let scalar_50buf = make_list_scalar(num_elements, 50); + + for batch_size in [256, 1024] { + group.bench_with_input( + BenchmarkId::new("1_buffer", batch_size), + &batch_size, + |b, &sz| { + b.iter(|| { + let arr = scalar_1buf.to_array_of_size(sz).unwrap(); + consume_list_array(&arr); + }); + }, + ); + group.bench_with_input( + BenchmarkId::new("50_buffers", batch_size), + &batch_size, + |b, &sz| { + b.iter(|| { + let arr = scalar_50buf.to_array_of_size(sz).unwrap(); + consume_list_array(&arr); + }); + }, + ); + } + + group.finish(); +} + +criterion_group!(benches, bench_list_to_array_of_size); +criterion_main!(benches); diff --git a/datafusion/common/benches/with_hashes.rs b/datafusion/common/benches/with_hashes.rs index 8154c20df88f..9ee31d9c4bef 100644 --- a/datafusion/common/benches/with_hashes.rs +++ b/datafusion/common/benches/with_hashes.rs @@ -19,11 +19,14 @@ use ahash::RandomState; use arrow::array::{ - Array, ArrayRef, ArrowPrimitiveType, DictionaryArray, GenericStringArray, - NullBufferBuilder, OffsetSizeTrait, PrimitiveArray, StringViewArray, make_array, + Array, ArrayRef, ArrowPrimitiveType, DictionaryArray, GenericStringArray, Int32Array, + Int64Array, ListArray, MapArray, NullBufferBuilder, OffsetSizeTrait, PrimitiveArray, + RunArray, StringViewArray, StructArray, UnionArray, make_array, +}; +use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}; +use arrow::datatypes::{ + 
ArrowDictionaryKeyType, DataType, Field, Fields, Int32Type, Int64Type, UnionFields, }; -use arrow::buffer::NullBuffer; -use arrow::datatypes::{ArrowDictionaryKeyType, Int32Type, Int64Type}; use criterion::{Bencher, Criterion, criterion_group, criterion_main}; use datafusion_common::hash_utils::with_hashes; use rand::Rng; @@ -37,6 +40,8 @@ const BATCH_SIZE: usize = 8192; struct BenchData { name: &'static str, array: ArrayRef, + /// Union arrays can't have null bitmasks added + supports_nulls: bool, } fn criterion_benchmark(c: &mut Criterion) { @@ -47,50 +52,93 @@ fn criterion_benchmark(c: &mut Criterion) { BenchData { name: "int64", array: primitive_array::(BATCH_SIZE), + supports_nulls: true, }, BenchData { name: "utf8", array: pool.string_array::(BATCH_SIZE), + supports_nulls: true, }, BenchData { name: "large_utf8", array: pool.string_array::(BATCH_SIZE), + supports_nulls: true, }, BenchData { name: "utf8_view", array: pool.string_view_array(BATCH_SIZE), + supports_nulls: true, }, BenchData { name: "utf8_view (small)", array: small_pool.string_view_array(BATCH_SIZE), + supports_nulls: true, }, BenchData { name: "dictionary_utf8_int32", array: pool.dictionary_array::(BATCH_SIZE), + supports_nulls: true, + }, + BenchData { + name: "list_array", + array: list_array(BATCH_SIZE), + supports_nulls: true, + }, + BenchData { + name: "map_array", + array: map_array(BATCH_SIZE), + supports_nulls: true, + }, + BenchData { + name: "sparse_union", + array: sparse_union_array(BATCH_SIZE), + supports_nulls: false, + }, + BenchData { + name: "dense_union", + array: dense_union_array(BATCH_SIZE), + supports_nulls: false, + }, + BenchData { + name: "struct_array", + array: create_struct_array(&pool, BATCH_SIZE), + supports_nulls: true, + }, + BenchData { + name: "run_array_int32", + array: create_run_array::(BATCH_SIZE), + supports_nulls: true, }, ]; - for BenchData { name, array } in cases { - // with_hash has different code paths for single vs multiple arrays and nulls vs no 
nulls - let nullable_array = add_nulls(&array); + for BenchData { + name, + array, + supports_nulls, + } in cases + { c.bench_function(&format!("{name}: single, no nulls"), |b| { do_hash_test(b, std::slice::from_ref(&array)); }); - c.bench_function(&format!("{name}: single, nulls"), |b| { - do_hash_test(b, std::slice::from_ref(&nullable_array)); - }); c.bench_function(&format!("{name}: multiple, no nulls"), |b| { let arrays = vec![array.clone(), array.clone(), array.clone()]; do_hash_test(b, &arrays); }); - c.bench_function(&format!("{name}: multiple, nulls"), |b| { - let arrays = vec![ - nullable_array.clone(), - nullable_array.clone(), - nullable_array.clone(), - ]; - do_hash_test(b, &arrays); - }); + // Union arrays can't have null bitmasks + if supports_nulls { + let nullable_array = add_nulls(&array); + c.bench_function(&format!("{name}: single, nulls"), |b| { + do_hash_test(b, std::slice::from_ref(&nullable_array)); + }); + c.bench_function(&format!("{name}: multiple, nulls"), |b| { + let arrays = vec![ + nullable_array.clone(), + nullable_array.clone(), + nullable_array.clone(), + ]; + do_hash_test(b, &arrays); + }); + } } } @@ -122,16 +170,51 @@ where builder.finish().expect("should be nulls in buffer") } -// Returns an new array that is the same as array, but with nulls +// Returns a new array that is the same as array, but with nulls +// Handles the special case of RunArray where nulls must be in the values array fn add_nulls(array: &ArrayRef) -> ArrayRef { - let array_data = array - .clone() - .into_data() - .into_builder() - .nulls(Some(create_null_mask(array.len()))) - .build() - .unwrap(); - make_array(array_data) + use arrow::datatypes::DataType; + + match array.data_type() { + DataType::RunEndEncoded(_, _) => { + // RunArray can't have top-level nulls, so apply nulls to the values array + let run_array = array + .as_any() + .downcast_ref::>() + .expect("Expected RunArray"); + + let run_ends_buffer = run_array.run_ends().inner().clone(); + let 
run_ends_array = PrimitiveArray::::new(run_ends_buffer, None); + let values = run_array.values().clone(); + + // Add nulls to the values array + let values_with_nulls = { + let array_data = values + .clone() + .into_data() + .into_builder() + .nulls(Some(create_null_mask(values.len()))) + .build() + .unwrap(); + make_array(array_data) + }; + + Arc::new( + RunArray::try_new(&run_ends_array, values_with_nulls.as_ref()) + .expect("Failed to create RunArray with null values"), + ) + } + _ => { + let array_data = array + .clone() + .into_data() + .into_builder() + .nulls(Some(create_null_mask(array.len()))) + .build() + .unwrap(); + make_array(array_data) + } + } } pub fn make_rng() -> StdRng { @@ -205,5 +288,282 @@ where Arc::new(array) } -criterion_group!(benches, criterion_benchmark); +/// Benchmark sliced arrays to demonstrate the optimization for when an array is +/// sliced, the underlying buffer may be much larger than what's referenced by +/// the slice. The optimization avoids hashing unreferenced elements. 
+fn sliced_array_benchmark(c: &mut Criterion) { + // Test with different slice ratios: slice_size / total_size + // Smaller ratio = more potential savings from the optimization + let slice_ratios = [10, 5, 2]; // 1/10, 1/5, 1/2 of total + + for ratio in slice_ratios { + let total_rows = BATCH_SIZE * ratio; + let slice_offset = BATCH_SIZE * (ratio / 2); // Take from middle + let slice_len = BATCH_SIZE; + + // Sliced ListArray + { + let full_array = list_array(total_rows); + let sliced: ArrayRef = Arc::new( + full_array + .as_any() + .downcast_ref::() + .unwrap() + .slice(slice_offset, slice_len), + ); + c.bench_function( + &format!("list_array_sliced: 1/{ratio} of {total_rows} rows"), + |b| { + do_hash_test_with_len(b, std::slice::from_ref(&sliced), slice_len); + }, + ); + } + + // Sliced MapArray + { + let full_array = map_array(total_rows); + let sliced: ArrayRef = Arc::new( + full_array + .as_any() + .downcast_ref::() + .unwrap() + .slice(slice_offset, slice_len), + ); + c.bench_function( + &format!("map_array_sliced: 1/{ratio} of {total_rows} rows"), + |b| { + do_hash_test_with_len(b, std::slice::from_ref(&sliced), slice_len); + }, + ); + } + + // Sliced Sparse UnionArray + { + let full_array = sparse_union_array(total_rows); + let sliced: ArrayRef = Arc::new( + full_array + .as_any() + .downcast_ref::() + .unwrap() + .slice(slice_offset, slice_len), + ); + c.bench_function( + &format!("sparse_union_sliced: 1/{ratio} of {total_rows} rows"), + |b| { + do_hash_test_with_len(b, std::slice::from_ref(&sliced), slice_len); + }, + ); + } + } +} + +fn do_hash_test_with_len(b: &mut Bencher, arrays: &[ArrayRef], expected_len: usize) { + let state = RandomState::new(); + b.iter(|| { + with_hashes(arrays, &state, |hashes| { + assert_eq!(hashes.len(), expected_len); + Ok(()) + }) + .unwrap(); + }); +} + +fn list_array(num_rows: usize) -> ArrayRef { + let mut rng = make_rng(); + let elements_per_row = 5; + let total_elements = num_rows * elements_per_row; + + let values: 
Int64Array = (0..total_elements) + .map(|_| Some(rng.random::())) + .collect(); + let offsets: Vec = (0..=num_rows) + .map(|i| (i * elements_per_row) as i32) + .collect(); + + Arc::new(ListArray::new( + Arc::new(Field::new("item", DataType::Int64, true)), + OffsetBuffer::new(ScalarBuffer::from(offsets)), + Arc::new(values), + None, + )) +} + +fn map_array(num_rows: usize) -> ArrayRef { + let mut rng = make_rng(); + let entries_per_row = 5; + let total_entries = num_rows * entries_per_row; + + let keys: Int32Array = (0..total_entries) + .map(|_| Some(rng.random::())) + .collect(); + let values: Int64Array = (0..total_entries) + .map(|_| Some(rng.random::())) + .collect(); + let offsets: Vec = (0..=num_rows) + .map(|i| (i * entries_per_row) as i32) + .collect(); + + let entries = StructArray::try_new( + Fields::from(vec![ + Field::new("keys", DataType::Int32, false), + Field::new("values", DataType::Int64, true), + ]), + vec![Arc::new(keys), Arc::new(values)], + None, + ) + .unwrap(); + + Arc::new(MapArray::new( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("keys", DataType::Int32, false), + Field::new("values", DataType::Int64, true), + ])), + false, + )), + OffsetBuffer::new(ScalarBuffer::from(offsets)), + entries, + None, + false, + )) +} + +fn sparse_union_array(num_rows: usize) -> ArrayRef { + let mut rng = make_rng(); + let num_types = 5; + + let type_ids: Vec = (0..num_rows) + .map(|_| rng.random_range(0..num_types) as i8) + .collect(); + let (fields, children): (Vec<_>, Vec<_>) = (0..num_types) + .map(|i| { + ( + ( + i as i8, + Arc::new(Field::new(format!("f{i}"), DataType::Int64, true)), + ), + primitive_array::(num_rows), + ) + }) + .unzip(); + + Arc::new( + UnionArray::try_new( + UnionFields::from_iter(fields), + ScalarBuffer::from(type_ids), + None, + children, + ) + .unwrap(), + ) +} + +fn dense_union_array(num_rows: usize) -> ArrayRef { + let mut rng = make_rng(); + let num_types = 5; + let type_ids: Vec = 
(0..num_rows) + .map(|_| rng.random_range(0..num_types) as i8) + .collect(); + + let mut type_counts = vec![0i32; num_types]; + for &tid in &type_ids { + type_counts[tid as usize] += 1; + } + + let mut current_offsets = vec![0i32; num_types]; + let offsets: Vec = type_ids + .iter() + .map(|&tid| { + let offset = current_offsets[tid as usize]; + current_offsets[tid as usize] += 1; + offset + }) + .collect(); + + let (fields, children): (Vec<_>, Vec<_>) = (0..num_types) + .map(|i| { + ( + ( + i as i8, + Arc::new(Field::new(format!("f{i}"), DataType::Int64, true)), + ), + primitive_array::(type_counts[i] as usize), + ) + }) + .unzip(); + + Arc::new( + UnionArray::try_new( + UnionFields::from_iter(fields), + ScalarBuffer::from(type_ids), + Some(ScalarBuffer::from(offsets)), + children, + ) + .unwrap(), + ) +} + +fn boolean_array(array_len: usize) -> ArrayRef { + let mut rng = make_rng(); + Arc::new( + (0..array_len) + .map(|_| Some(rng.random::())) + .collect::(), + ) +} + +/// Create a StructArray with multiple columns +fn create_struct_array(pool: &StringPool, array_len: usize) -> ArrayRef { + let bool_array = boolean_array(array_len); + let int32_array = primitive_array::(array_len); + let int64_array = primitive_array::(array_len); + let str_array = pool.string_array::(array_len); + + let fields = Fields::from(vec![ + Field::new("bool_col", DataType::Boolean, false), + Field::new("int32_col", DataType::Int32, false), + Field::new("int64_col", DataType::Int64, false), + Field::new("string_col", DataType::Utf8, false), + ]); + + Arc::new(StructArray::new( + fields, + vec![bool_array, int32_array, int64_array, str_array], + None, + )) +} + +/// Create a RunArray to test run array hashing. 
+fn create_run_array(array_len: usize) -> ArrayRef +where + T: ArrowPrimitiveType, + StandardUniform: Distribution, +{ + let mut rng = make_rng(); + + // Create runs of varying lengths + let mut run_ends = Vec::new(); + let mut values = Vec::new(); + let mut current_end = 0; + + while current_end < array_len { + // Random run length between 1 and 50 + let run_length = rng.random_range(1..=50).min(array_len - current_end); + current_end += run_length; + run_ends.push(current_end as i32); + values.push(Some(rng.random::())); + } + + let run_ends_array = Arc::new(PrimitiveArray::::from(run_ends)); + let values_array: Arc = + Arc::new(values.into_iter().collect::>()); + + Arc::new( + RunArray::try_new(&run_ends_array, values_array.as_ref()) + .expect("Failed to create RunArray"), + ) +} + +criterion_group!(benches, criterion_benchmark, sliced_array_benchmark); criterion_main!(benches); diff --git a/datafusion/common/src/alias.rs b/datafusion/common/src/alias.rs index 2ee2cb4dc7ad..99f6447a6acd 100644 --- a/datafusion/common/src/alias.rs +++ b/datafusion/common/src/alias.rs @@ -37,6 +37,16 @@ impl AliasGenerator { Self::default() } + /// Advance the counter to at least `min_id`, ensuring future aliases + /// won't collide with already-existing ones. + /// + /// For example, if the query already contains an alias `alias_42`, then calling + /// `update_min_id(42)` will ensure that future aliases generated by this + /// [`AliasGenerator`] will start from `alias_43`. 
+ pub fn update_min_id(&self, min_id: usize) { + self.next_id.fetch_max(min_id + 1, Ordering::Relaxed); + } + /// Return a unique alias with the provided prefix pub fn next(&self, prefix: &str) -> String { let id = self.next_id.fetch_add(1, Ordering::Relaxed); diff --git a/datafusion/common/src/cast.rs b/datafusion/common/src/cast.rs index 29082cc303a7..bc4313ed9566 100644 --- a/datafusion/common/src/cast.rs +++ b/datafusion/common/src/cast.rs @@ -25,8 +25,9 @@ use arrow::array::{ BinaryViewArray, Decimal32Array, Decimal64Array, DurationMicrosecondArray, DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray, Float16Array, Int8Array, Int16Array, LargeBinaryArray, LargeListViewArray, LargeStringArray, - ListViewArray, StringViewArray, UInt16Array, + ListViewArray, RunArray, StringViewArray, UInt16Array, }; +use arrow::datatypes::RunEndIndexType; use arrow::{ array::{ Array, BinaryArray, BooleanArray, Date32Array, Date64Array, Decimal128Array, @@ -334,3 +335,8 @@ pub fn as_list_view_array(array: &dyn Array) -> Result<&ListViewArray> { pub fn as_large_list_view_array(array: &dyn Array) -> Result<&LargeListViewArray> { Ok(downcast_value!(array, LargeListViewArray)) } + +// Downcast Array to RunArray +pub fn as_run_array(array: &dyn Array) -> Result<&RunArray> { + Ok(downcast_value!(array, RunArray, T)) +} diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 95a02147438b..d71af206c78d 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -469,6 +469,25 @@ config_namespace! { /// metadata memory consumption pub batch_size: usize, default = 8192 + /// A perfect hash join (see `HashJoinExec` for more details) will be considered + /// if the range of keys (max - min) on the build side is < this threshold. + /// This provides a fast path for joins with very small key ranges, + /// bypassing the density check. + /// + /// Currently only supports cases where build_side.num_rows() < u32::MAX. 
+ /// Support for build_side.num_rows() >= u32::MAX will be added in the future. + pub perfect_hash_join_small_build_threshold: usize, default = 1024 + + /// The minimum required density of join keys on the build side to consider a + /// perfect hash join (see `HashJoinExec` for more details). Density is calculated as: + /// `(number of rows) / (max_key - min_key + 1)`. + /// A perfect hash join may be used if the actual key density > this + /// value. + /// + /// Currently only supports cases where build_side.num_rows() < u32::MAX. + /// Support for build_side.num_rows() >= u32::MAX will be added in the future. + pub perfect_hash_join_min_key_density: f64, default = 0.15 + /// When set to true, record batches will be examined between each operator and /// small batches will be coalesced into larger batches. This is helpful when there /// are highly selective filters or joins that could produce tiny output batches. The @@ -738,7 +757,7 @@ config_namespace! { /// (writing) Sets best effort maximum size of data page in bytes pub data_pagesize_limit: usize, default = 1024 * 1024 - /// (writing) Sets write_batch_size in bytes + /// (writing) Sets write_batch_size in rows pub write_batch_size: usize, default = 1024 /// (writing) Sets parquet writer version @@ -753,7 +772,7 @@ config_namespace! { /// (writing) Sets default parquet compression codec. /// Valid values are: uncompressed, snappy, gzip(level), - /// lzo, brotli(level), lz4, zstd(level), and lz4_raw. + /// brotli(level), lz4, zstd(level), and lz4_raw. /// These values are not case sensitive. If NULL, uses /// default parquet writer setting /// @@ -1123,6 +1142,12 @@ config_namespace! { /// /// Default: true pub enable_sort_pushdown: bool, default = true + + /// When set to true, the optimizer will extract leaf expressions + /// (such as `get_field`) from filter/sort/join nodes into projections + /// closer to the leaf table scans, and push those projections down + /// towards the leaf nodes. 
+ pub enable_leaf_expression_pushdown: bool, default = true } } @@ -1231,7 +1256,7 @@ impl<'a> TryInto> for &'a FormatOptions } /// A key value pair, with a corresponding description -#[derive(Debug, Hash, PartialEq, Eq)] +#[derive(Debug, Clone, Hash, PartialEq, Eq)] pub struct ConfigEntry { /// A unique string to identify this config value pub key: String, @@ -1327,6 +1352,10 @@ impl ConfigField for ConfigOptions { } } +/// This namespace is reserved for interacting with Foreign Function Interface +/// (FFI) based configuration extensions. +pub const DATAFUSION_FFI_CONFIG_NAMESPACE: &str = "datafusion_ffi"; + impl ConfigOptions { /// Creates a new [`ConfigOptions`] with default values pub fn new() -> Self { @@ -1341,12 +1370,12 @@ impl ConfigOptions { /// Set a configuration option pub fn set(&mut self, key: &str, value: &str) -> Result<()> { - let Some((prefix, key)) = key.split_once('.') else { + let Some((mut prefix, mut inner_key)) = key.split_once('.') else { return _config_err!("could not find config namespace for key \"{key}\""); }; if prefix == "datafusion" { - if key == "optimizer.enable_dynamic_filter_pushdown" { + if inner_key == "optimizer.enable_dynamic_filter_pushdown" { let bool_value = value.parse::().map_err(|e| { DataFusionError::Configuration(format!( "Failed to parse '{value}' as bool: {e}", @@ -1361,13 +1390,23 @@ impl ConfigOptions { } return Ok(()); } - return ConfigField::set(self, key, value); + return ConfigField::set(self, inner_key, value); + } + + if !self.extensions.0.contains_key(prefix) + && self + .extensions + .0 + .contains_key(DATAFUSION_FFI_CONFIG_NAMESPACE) + { + inner_key = key; + prefix = DATAFUSION_FFI_CONFIG_NAMESPACE; } let Some(e) = self.extensions.0.get_mut(prefix) else { return _config_err!("Could not find config namespace \"{prefix}\""); }; - e.0.set(key, value) + e.0.set(inner_key, value) } /// Create new [`ConfigOptions`], taking values from environment variables @@ -2132,7 +2171,7 @@ impl TableOptions { /// /// A 
result indicating success or failure in setting the configuration option. pub fn set(&mut self, key: &str, value: &str) -> Result<()> { - let Some((prefix, _)) = key.split_once('.') else { + let Some((mut prefix, _)) = key.split_once('.') else { return _config_err!("could not find config namespace for key \"{key}\""); }; @@ -2144,6 +2183,15 @@ impl TableOptions { return Ok(()); } + if !self.extensions.0.contains_key(prefix) + && self + .extensions + .0 + .contains_key(DATAFUSION_FFI_CONFIG_NAMESPACE) + { + prefix = DATAFUSION_FFI_CONFIG_NAMESPACE; + } + let Some(e) = self.extensions.0.get_mut(prefix) else { return _config_err!("Could not find config namespace \"{prefix}\""); }; @@ -2229,7 +2277,7 @@ impl TableOptions { /// Options that control how Parquet files are read, including global options /// that apply to all columns and optional column-specific overrides /// -/// Closely tied to [`ParquetWriterOptions`](crate::file_options::parquet_writer::ParquetWriterOptions). +/// Closely tied to `ParquetWriterOptions` (see `crate::file_options::parquet_writer::ParquetWriterOptions` when the "parquet" feature is enabled). /// Properties not included in [`TableParquetOptions`] may not be configurable at the external API /// (e.g. sorting_columns). #[derive(Clone, Default, Debug, PartialEq)] @@ -2480,7 +2528,7 @@ config_namespace_with_hashmap! { /// Sets default parquet compression codec for the column path. /// Valid values are: uncompressed, snappy, gzip(level), - /// lzo, brotli(level), lz4, zstd(level), and lz4_raw. + /// brotli(level), lz4, zstd(level), and lz4_raw. /// These values are not case-sensitive. If NULL, uses /// default parquet options pub compression: Option, transform = str::to_lowercase, default = None @@ -3046,6 +3094,22 @@ config_namespace! { /// If not specified, the default level for the compression algorithm is used. 
pub compression_level: Option, default = None pub schema_infer_max_rec: Option, default = None + /// The JSON format to use when reading files. + /// + /// When `true` (default), expects newline-delimited JSON (NDJSON): + /// ```text + /// {"key1": 1, "key2": "val"} + /// {"key1": 2, "key2": "vals"} + /// ``` + /// + /// When `false`, expects JSON array format: + /// ```text + /// [ + /// {"key1": 1, "key2": "val"}, + /// {"key1": 2, "key2": "vals"} + /// ] + /// ``` + pub newline_delimited: bool, default = true } } diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 55a031d87012..de0aacf9e8bc 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -698,10 +698,12 @@ impl DFSchema { // check nested fields match (dt1, dt2) { (DataType::Dictionary(_, v1), DataType::Dictionary(_, v2)) => { - v1.as_ref() == v2.as_ref() + Self::datatype_is_logically_equal(v1.as_ref(), v2.as_ref()) + } + (DataType::Dictionary(_, v1), othertype) + | (othertype, DataType::Dictionary(_, v1)) => { + Self::datatype_is_logically_equal(v1.as_ref(), othertype) } - (DataType::Dictionary(_, v1), othertype) => v1.as_ref() == othertype, - (othertype, DataType::Dictionary(_, v1)) => v1.as_ref() == othertype, (DataType::List(f1), DataType::List(f2)) | (DataType::LargeList(f1), DataType::LargeList(f2)) | (DataType::FixedSizeList(f1, _), DataType::FixedSizeList(f2, _)) => { @@ -1134,6 +1136,12 @@ impl TryFrom for DFSchema { } } +impl From for SchemaRef { + fn from(dfschema: DFSchema) -> Self { + Arc::clone(&dfschema.inner) + } +} + // Hashing refers to a subset of fields considered in PartialEq. 
impl Hash for DFSchema { fn hash(&self, state: &mut H) { @@ -1792,6 +1800,27 @@ mod tests { &DataType::Utf8, &DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)) )); + + // Dictionary is logically equal to the logically equivalent value type + assert!(DFSchema::datatype_is_logically_equal( + &DataType::Utf8View, + &DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)) + )); + + assert!(DFSchema::datatype_is_logically_equal( + &DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::List( + Field::new("element", DataType::Utf8, false).into() + )) + ), + &DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::List( + Field::new("element", DataType::Utf8View, false).into() + )) + ) + )); } #[test] diff --git a/datafusion/common/src/file_options/mod.rs b/datafusion/common/src/file_options/mod.rs index c7374949ecef..5d2abd23172e 100644 --- a/datafusion/common/src/file_options/mod.rs +++ b/datafusion/common/src/file_options/mod.rs @@ -84,7 +84,7 @@ mod tests { .build(); // Verify the expected options propagated down to parquet crate WriterProperties struct - assert_eq!(properties.max_row_group_size(), 123); + assert_eq!(properties.max_row_group_row_count(), Some(123)); assert_eq!(properties.data_page_size_limit(), 123); assert_eq!(properties.write_batch_size(), 123); assert_eq!(properties.writer_version(), WriterVersion::PARQUET_2_0); diff --git a/datafusion/common/src/file_options/parquet_writer.rs b/datafusion/common/src/file_options/parquet_writer.rs index 196cb96f3832..a7a1fc6d0bb6 100644 --- a/datafusion/common/src/file_options/parquet_writer.rs +++ b/datafusion/common/src/file_options/parquet_writer.rs @@ -222,7 +222,7 @@ impl ParquetOptions { .and_then(|s| parse_statistics_string(s).ok()) .unwrap_or(DEFAULT_STATISTICS_ENABLED), ) - .set_max_row_group_size(*max_row_group_size) + .set_max_row_group_row_count(Some(*max_row_group_size)) .set_created_by(created_by.clone()) 
.set_column_index_truncate_length(*column_index_truncate_length) .set_statistics_truncate_length(*statistics_truncate_length) @@ -341,10 +341,6 @@ pub fn parse_compression_string( level, )?)) } - "lzo" => { - check_level_is_none(codec, &level)?; - Ok(parquet::basic::Compression::LZO) - } "brotli" => { let level = require_level(codec, level)?; Ok(parquet::basic::Compression::BROTLI(BrotliLevel::try_new( @@ -368,7 +364,7 @@ pub fn parse_compression_string( _ => Err(DataFusionError::Configuration(format!( "Unknown or unsupported parquet compression: \ {str_setting}. Valid values are: uncompressed, snappy, gzip(level), \ - lzo, brotli(level), lz4, zstd(level), and lz4_raw." + brotli(level), lz4, zstd(level), and lz4_raw." ))), } } @@ -397,7 +393,7 @@ mod tests { use parquet::basic::Compression; use parquet::file::properties::{ BloomFilterProperties, DEFAULT_BLOOM_FILTER_FPP, DEFAULT_BLOOM_FILTER_NDV, - EnabledStatistics, + DEFAULT_MAX_ROW_GROUP_ROW_COUNT, EnabledStatistics, }; use std::collections::HashMap; @@ -540,7 +536,9 @@ mod tests { write_batch_size: props.write_batch_size(), writer_version: props.writer_version().into(), dictionary_page_size_limit: props.dictionary_page_size_limit(), - max_row_group_size: props.max_row_group_size(), + max_row_group_size: props + .max_row_group_row_count() + .unwrap_or(DEFAULT_MAX_ROW_GROUP_ROW_COUNT), created_by: props.created_by().to_string(), column_index_truncate_length: props.column_index_truncate_length(), statistics_truncate_length: props.statistics_truncate_length(), diff --git a/datafusion/common/src/hash_utils.rs b/datafusion/common/src/hash_utils.rs index 98dd1f235aee..3be6118c55ff 100644 --- a/datafusion/common/src/hash_utils.rs +++ b/datafusion/common/src/hash_utils.rs @@ -20,15 +20,19 @@ use ahash::RandomState; use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano}; use arrow::array::*; +use arrow::compute::take; use arrow::datatypes::*; #[cfg(not(feature = "force_hash_collisions"))] use 
arrow::{downcast_dictionary_array, downcast_primitive_array}; +use itertools::Itertools; +use std::collections::HashMap; #[cfg(not(feature = "force_hash_collisions"))] use crate::cast::{ as_binary_view_array, as_boolean_array, as_fixed_size_list_array, - as_generic_binary_array, as_large_list_array, as_list_array, as_map_array, - as_string_array, as_string_view_array, as_struct_array, as_union_array, + as_generic_binary_array, as_large_list_array, as_large_list_view_array, + as_list_array, as_list_view_array, as_map_array, as_string_array, + as_string_view_array, as_struct_array, as_union_array, }; use crate::error::Result; use crate::error::{_internal_datafusion_err, _internal_err}; @@ -390,33 +394,22 @@ fn hash_generic_byte_view_array( } } -/// Helper function to update hash for a dictionary key if the value is valid -#[cfg(not(feature = "force_hash_collisions"))] -#[inline] -fn update_hash_for_dict_key( - hash: &mut u64, - dict_hashes: &[u64], - dict_values: &dyn Array, - idx: usize, - multi_col: bool, -) { - if dict_values.is_valid(idx) { - if multi_col { - *hash = combine_hashes(dict_hashes[idx], *hash); - } else { - *hash = dict_hashes[idx]; - } - } - // no update for invalid dictionary value -} - -/// Hash the values in a dictionary array -#[cfg(not(feature = "force_hash_collisions"))] -fn hash_dictionary( +/// Hash dictionary array with compile-time specialization for null handling. 
+/// +/// Uses const generics to eliminate runtim branching in the hot loop: +/// - `HAS_NULL_KEYS`: Whether to check for null dictionary keys +/// - `HAS_NULL_VALUES`: Whether to check for null dictionary values +/// - `MULTI_COL`: Whether to combine with existing hash (true) or initialize (false) +#[inline(never)] +fn hash_dictionary_inner< + K: ArrowDictionaryKeyType, + const HAS_NULL_KEYS: bool, + const HAS_NULL_VALUES: bool, + const MULTI_COL: bool, +>( array: &DictionaryArray, random_state: &RandomState, hashes_buffer: &mut [u64], - multi_col: bool, ) -> Result<()> { // Hash each dictionary value once, and then use that computed // hash for each key value to avoid a potentially expensive @@ -425,22 +418,91 @@ fn hash_dictionary( let mut dict_hashes = vec![0; dict_values.len()]; create_hashes([dict_values], random_state, &mut dict_hashes)?; - // combine hash for each index in values - for (hash, key) in hashes_buffer.iter_mut().zip(array.keys().iter()) { - if let Some(key) = key { + if HAS_NULL_KEYS { + for (hash, key) in hashes_buffer.iter_mut().zip(array.keys().iter()) { + if let Some(key) = key { + let idx = key.as_usize(); + if !HAS_NULL_VALUES || dict_values.is_valid(idx) { + if MULTI_COL { + *hash = combine_hashes(dict_hashes[idx], *hash); + } else { + *hash = dict_hashes[idx]; + } + } + } + } + } else { + for (hash, key) in hashes_buffer.iter_mut().zip(array.keys().values()) { let idx = key.as_usize(); - update_hash_for_dict_key( - hash, - &dict_hashes, - dict_values.as_ref(), - idx, - multi_col, - ); - } // no update for Null key + if !HAS_NULL_VALUES || dict_values.is_valid(idx) { + if MULTI_COL { + *hash = combine_hashes(dict_hashes[idx], *hash); + } else { + *hash = dict_hashes[idx]; + } + } + } } Ok(()) } +/// Hash the values in a dictionary array +#[cfg(not(feature = "force_hash_collisions"))] +fn hash_dictionary( + array: &DictionaryArray, + random_state: &RandomState, + hashes_buffer: &mut [u64], + multi_col: bool, +) -> Result<()> { + let 
has_null_keys = array.keys().null_count() != 0; + let has_null_values = array.values().null_count() != 0; + + // Dispatcher based on null presence and multi-column mode + // Should reduce branching within hot loops + match (has_null_keys, has_null_values, multi_col) { + (false, false, false) => hash_dictionary_inner::( + array, + random_state, + hashes_buffer, + ), + (false, false, true) => hash_dictionary_inner::( + array, + random_state, + hashes_buffer, + ), + (false, true, false) => hash_dictionary_inner::( + array, + random_state, + hashes_buffer, + ), + (false, true, true) => hash_dictionary_inner::( + array, + random_state, + hashes_buffer, + ), + (true, false, false) => hash_dictionary_inner::( + array, + random_state, + hashes_buffer, + ), + (true, false, true) => hash_dictionary_inner::( + array, + random_state, + hashes_buffer, + ), + (true, true, false) => hash_dictionary_inner::( + array, + random_state, + hashes_buffer, + ), + (true, true, true) => hash_dictionary_inner::( + array, + random_state, + hashes_buffer, + ), + } +} + #[cfg(not(feature = "force_hash_collisions"))] fn hash_struct_array( array: &StructArray, @@ -450,19 +512,21 @@ fn hash_struct_array( let nulls = array.nulls(); let row_len = array.len(); - let valid_row_indices: Vec = if let Some(nulls) = nulls { - nulls.valid_indices().collect() - } else { - (0..row_len).collect() - }; - // Create hashes for each row that combines the hashes over all the column at that row. 
let mut values_hashes = vec![0u64; row_len]; create_hashes(array.columns(), random_state, &mut values_hashes)?; - for i in valid_row_indices { - let hash = &mut hashes_buffer[i]; - *hash = combine_hashes(*hash, values_hashes[i]); + // Separate paths to avoid allocating Vec when there are no nulls + if let Some(nulls) = nulls { + for i in nulls.valid_indices() { + let hash = &mut hashes_buffer[i]; + *hash = combine_hashes(*hash, values_hashes[i]); + } + } else { + for i in 0..row_len { + let hash = &mut hashes_buffer[i]; + *hash = combine_hashes(*hash, values_hashes[i]); + } } Ok(()) @@ -479,15 +543,29 @@ fn hash_map_array( let offsets = array.offsets(); // Create hashes for each entry in each row - let mut values_hashes = vec![0u64; array.entries().len()]; - create_hashes(array.entries().columns(), random_state, &mut values_hashes)?; + let first_offset = offsets.first().copied().unwrap_or_default() as usize; + let last_offset = offsets.last().copied().unwrap_or_default() as usize; + let entries_len = last_offset - first_offset; + + // Only hash the entries that are actually referenced + let mut values_hashes = vec![0u64; entries_len]; + let entries = array.entries(); + let sliced_columns: Vec = entries + .columns() + .iter() + .map(|col| col.slice(first_offset, entries_len)) + .collect(); + create_hashes(&sliced_columns, random_state, &mut values_hashes)?; // Combine the hashes for entries on each row with each other and previous hash for that row + // Adjust indices by first_offset since values_hashes is sliced starting from first_offset if let Some(nulls) = nulls { for (i, (start, stop)) in offsets.iter().zip(offsets.iter().skip(1)).enumerate() { if nulls.is_valid(i) { let hash = &mut hashes_buffer[i]; - for values_hash in &values_hashes[start.as_usize()..stop.as_usize()] { + for values_hash in &values_hashes + [start.as_usize() - first_offset..stop.as_usize() - first_offset] + { *hash = combine_hashes(*hash, *values_hash); } } @@ -495,7 +573,9 @@ fn 
hash_map_array( } else { for (i, (start, stop)) in offsets.iter().zip(offsets.iter().skip(1)).enumerate() { let hash = &mut hashes_buffer[i]; - for values_hash in &values_hashes[start.as_usize()..stop.as_usize()] { + for values_hash in &values_hashes + [start.as_usize() - first_offset..stop.as_usize() - first_offset] + { *hash = combine_hashes(*hash, *values_hash); } } @@ -510,27 +590,83 @@ fn hash_list_array( random_state: &RandomState, hashes_buffer: &mut [u64], ) -> Result<()> +where + OffsetSize: OffsetSizeTrait, +{ + // In case values is sliced, hash only the bytes used by the offsets of this ListArray + let first_offset = array.value_offsets().first().cloned().unwrap_or_default(); + let last_offset = array.value_offsets().last().cloned().unwrap_or_default(); + let value_bytes_len = (last_offset - first_offset).as_usize(); + let mut values_hashes = vec![0u64; value_bytes_len]; + create_hashes( + [array + .values() + .slice(first_offset.as_usize(), value_bytes_len)], + random_state, + &mut values_hashes, + )?; + + if array.null_count() > 0 { + for (i, (start, stop)) in array.value_offsets().iter().tuple_windows().enumerate() + { + if array.is_valid(i) { + let hash = &mut hashes_buffer[i]; + for values_hash in &values_hashes[(*start - first_offset).as_usize() + ..(*stop - first_offset).as_usize()] + { + *hash = combine_hashes(*hash, *values_hash); + } + } + } + } else { + for ((start, stop), hash) in array + .value_offsets() + .iter() + .tuple_windows() + .zip(hashes_buffer.iter_mut()) + { + for values_hash in &values_hashes + [(*start - first_offset).as_usize()..(*stop - first_offset).as_usize()] + { + *hash = combine_hashes(*hash, *values_hash); + } + } + } + Ok(()) +} + +#[cfg(not(feature = "force_hash_collisions"))] +fn hash_list_view_array( + array: &GenericListViewArray, + random_state: &RandomState, + hashes_buffer: &mut [u64], +) -> Result<()> where OffsetSize: OffsetSizeTrait, { let values = array.values(); let offsets = array.value_offsets(); + let 
sizes = array.value_sizes(); let nulls = array.nulls(); let mut values_hashes = vec![0u64; values.len()]; create_hashes([values], random_state, &mut values_hashes)?; if let Some(nulls) = nulls { - for (i, (start, stop)) in offsets.iter().zip(offsets.iter().skip(1)).enumerate() { + for (i, (offset, size)) in offsets.iter().zip(sizes.iter()).enumerate() { if nulls.is_valid(i) { let hash = &mut hashes_buffer[i]; - for values_hash in &values_hashes[start.as_usize()..stop.as_usize()] { + let start = offset.as_usize(); + let end = start + size.as_usize(); + for values_hash in &values_hashes[start..end] { *hash = combine_hashes(*hash, *values_hash); } } } } else { - for (i, (start, stop)) in offsets.iter().zip(offsets.iter().skip(1)).enumerate() { + for (i, (offset, size)) in offsets.iter().zip(sizes.iter()).enumerate() { let hash = &mut hashes_buffer[i]; - for values_hash in &values_hashes[start.as_usize()..stop.as_usize()] { + let start = offset.as_usize(); + let end = start + size.as_usize(); + for values_hash in &values_hashes[start..end] { *hash = combine_hashes(*hash, *values_hash); } } @@ -544,14 +680,42 @@ fn hash_union_array( random_state: &RandomState, hashes_buffer: &mut [u64], ) -> Result<()> { - use std::collections::HashMap; - let DataType::Union(union_fields, _mode) = array.data_type() else { unreachable!() }; - let mut child_hashes = HashMap::with_capacity(union_fields.len()); + if array.is_dense() { + // Dense union: children only contain values of their type, so they're already compact. + // Use the default hashing approach which is efficient for dense unions. + hash_union_array_default(array, union_fields, random_state, hashes_buffer) + } else { + // Sparse union: each child has the same length as the union array. + // Optimization: only hash the elements that are actually referenced by type_ids, + // instead of hashing all K*N elements (where K = num types, N = array length). 
+ hash_sparse_union_array(array, union_fields, random_state, hashes_buffer) + } +} + +/// Default hashing for union arrays - hashes all elements of each child array fully. +/// +/// This approach works for both dense and sparse union arrays: +/// - Dense unions: children are compact (each child only contains values of that type) +/// - Sparse unions: children have the same length as the union array +/// +/// For sparse unions with 3+ types, the optimized take/scatter approach in +/// `hash_sparse_union_array` is more efficient, but for 1-2 types or dense unions, +/// this simpler approach is preferred. +#[cfg(not(feature = "force_hash_collisions"))] +fn hash_union_array_default( + array: &UnionArray, + union_fields: &UnionFields, + random_state: &RandomState, + hashes_buffer: &mut [u64], +) -> Result<()> { + let mut child_hashes: HashMap> = + HashMap::with_capacity(union_fields.len()); + // Hash each child array fully for (type_id, _field) in union_fields.iter() { let child = array.child(type_id); let mut child_hash_buffer = vec![0; child.len()]; @@ -560,6 +724,9 @@ fn hash_union_array( child_hashes.insert(type_id, child_hash_buffer); } + // Combine hashes for each row using the appropriate child offset + // For dense unions: value_offset points to the actual position in the child + // For sparse unions: value_offset equals the row index #[expect(clippy::needless_range_loop)] for i in 0..array.len() { let type_id = array.type_id(i); @@ -572,6 +739,69 @@ fn hash_union_array( Ok(()) } +/// Hash a sparse union array. +/// Sparse unions have child arrays with the same length as the union array. +/// For 3+ types, we optimize by only hashing the N elements that are actually used +/// (via take/scatter), instead of hashing all K*N elements. +/// +/// For 1-2 types, the overhead of take/scatter outweighs the benefit, so we use +/// the default approach of hashing all children (same as dense unions). 
+#[cfg(not(feature = "force_hash_collisions"))] +fn hash_sparse_union_array( + array: &UnionArray, + union_fields: &UnionFields, + random_state: &RandomState, + hashes_buffer: &mut [u64], +) -> Result<()> { + use std::collections::HashMap; + + // For 1-2 types, the take/scatter overhead isn't worth it. + // Fall back to the default approach (same as dense union). + if union_fields.len() <= 2 { + return hash_union_array_default( + array, + union_fields, + random_state, + hashes_buffer, + ); + } + + let type_ids = array.type_ids(); + + // Group indices by type_id + let mut indices_by_type: HashMap> = HashMap::new(); + for (i, &type_id) in type_ids.iter().enumerate() { + indices_by_type.entry(type_id).or_default().push(i as u32); + } + + // For each type, extract only the needed elements, hash them, and scatter back + for (type_id, _field) in union_fields.iter() { + if let Some(indices) = indices_by_type.get(&type_id) { + if indices.is_empty() { + continue; + } + + let child = array.child(type_id); + let indices_array = UInt32Array::from(indices.clone()); + + // Extract only the elements we need using take() + let filtered = take(child.as_ref(), &indices_array, None)?; + + // Hash the filtered array + let mut filtered_hashes = vec![0u64; filtered.len()]; + create_hashes([&filtered], random_state, &mut filtered_hashes)?; + + // Scatter hashes back to correct positions + for (hash, &idx) in filtered_hashes.iter().zip(indices.iter()) { + hashes_buffer[idx as usize] = + combine_hashes(hashes_buffer[idx as usize], *hash); + } + } + } + + Ok(()) +} + #[cfg(not(feature = "force_hash_collisions"))] fn hash_fixed_list_array( array: &FixedSizeListArray, @@ -605,12 +835,17 @@ fn hash_fixed_list_array( Ok(()) } +/// Inner hash function for RunArray +#[inline(never)] #[cfg(not(feature = "force_hash_collisions"))] -fn hash_run_array( +fn hash_run_array_inner< + R: RunEndIndexType, + const HAS_NULL_VALUES: bool, + const REHASH: bool, +>( array: &RunArray, random_state: &RandomState, 
hashes_buffer: &mut [u64], - rehash: bool, ) -> Result<()> { // We find the relevant runs that cover potentially sliced arrays, so we can only hash those // values. Then we find the runs that refer to the original runs and ensure that we apply @@ -648,25 +883,23 @@ fn hash_run_array( .iter() .enumerate() { - let is_null_value = sliced_values.is_null(adjusted_physical_index); let absolute_run_end = absolute_run_end.as_usize(); - let end_in_slice = (absolute_run_end - array_offset).min(array_len); - if rehash { - if !is_null_value { - let value_hash = values_hashes[adjusted_physical_index]; - for hash in hashes_buffer - .iter_mut() - .take(end_in_slice) - .skip(start_in_slice) - { - *hash = combine_hashes(value_hash, *hash); - } + if HAS_NULL_VALUES && sliced_values.is_null(adjusted_physical_index) { + start_in_slice = end_in_slice; + continue; + } + + let value_hash = values_hashes[adjusted_physical_index]; + let run_slice = &mut hashes_buffer[start_in_slice..end_in_slice]; + + if REHASH { + for hash in run_slice.iter_mut() { + *hash = combine_hashes(value_hash, *hash); } } else { - let value_hash = values_hashes[adjusted_physical_index]; - hashes_buffer[start_in_slice..end_in_slice].fill(value_hash); + run_slice.fill(value_hash); } start_in_slice = end_in_slice; @@ -675,6 +908,31 @@ fn hash_run_array( Ok(()) } +#[cfg(not(feature = "force_hash_collisions"))] +fn hash_run_array( + array: &RunArray, + random_state: &RandomState, + hashes_buffer: &mut [u64], + rehash: bool, +) -> Result<()> { + let has_null_values = array.values().null_count() != 0; + + match (has_null_values, rehash) { + (false, false) => { + hash_run_array_inner::(array, random_state, hashes_buffer) + } + (false, true) => { + hash_run_array_inner::(array, random_state, hashes_buffer) + } + (true, false) => { + hash_run_array_inner::(array, random_state, hashes_buffer) + } + (true, true) => { + hash_run_array_inner::(array, random_state, hashes_buffer) + } + } +} + /// Internal helper function that 
hashes a single array and either initializes or combines /// the hash values in the buffer. #[cfg(not(feature = "force_hash_collisions"))] @@ -714,6 +972,14 @@ fn hash_single_array( let array = as_large_list_array(array)?; hash_list_array(array, random_state, hashes_buffer)?; } + DataType::ListView(_) => { + let array = as_list_view_array(array)?; + hash_list_view_array(array, random_state, hashes_buffer)?; + } + DataType::LargeListView(_) => { + let array = as_large_list_view_array(array)?; + hash_list_view_array(array, random_state, hashes_buffer)?; + } DataType::Map(_, _) => { let array = as_map_array(array)?; hash_map_array(array, random_state, hashes_buffer)?; @@ -1128,6 +1394,130 @@ mod tests { assert_eq!(hashes[1], hashes[6]); // null vs empty list } + #[test] + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_sliced_list_arrays() { + let data = vec![ + Some(vec![Some(0), Some(1), Some(2)]), + None, + // Slice from here + Some(vec![Some(3), None, Some(5)]), + Some(vec![Some(3), None, Some(5)]), + None, + // To here + Some(vec![Some(0), Some(1), Some(2)]), + Some(vec![]), + ]; + let list_array = + Arc::new(ListArray::from_iter_primitive::(data)) as ArrayRef; + let list_array = list_array.slice(2, 3); + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let mut hashes = vec![0; list_array.len()]; + create_hashes(&[list_array], &random_state, &mut hashes).unwrap(); + assert_eq!(hashes[0], hashes[1]); + assert_ne!(hashes[1], hashes[2]); + } + + #[test] + // Tests actual values of hashes, which are different if forcing collisions + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_list_view_arrays() { + use arrow::buffer::{NullBuffer, ScalarBuffer}; + + // Create values array: [0, 1, 2, 3, null, 5] + let values = Arc::new(Int32Array::from(vec![ + Some(0), + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef; + let field = Arc::new(Field::new("item", DataType::Int32, true)); + + // Create ListView 
with the following logical structure: + // Row 0: [0, 1, 2] (offset=0, size=3) + // Row 1: null (null bit set) + // Row 2: [3, null, 5] (offset=3, size=3) + // Row 3: [3, null, 5] (offset=3, size=3) - same as row 2 + // Row 4: null (null bit set) + // Row 5: [0, 1, 2] (offset=0, size=3) - same as row 0 + // Row 6: [] (offset=0, size=0) - empty list + let offsets = ScalarBuffer::from(vec![0i32, 0, 3, 3, 0, 0, 0]); + let sizes = ScalarBuffer::from(vec![3i32, 0, 3, 3, 0, 3, 0]); + let nulls = Some(NullBuffer::from(vec![ + true, false, true, true, false, true, true, + ])); + + let list_view_array = + Arc::new(ListViewArray::new(field, offsets, sizes, values, nulls)) + as ArrayRef; + + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let mut hashes = vec![0; list_view_array.len()]; + create_hashes(&[list_view_array], &random_state, &mut hashes).unwrap(); + + assert_eq!(hashes[0], hashes[5]); // same content [0, 1, 2] + assert_eq!(hashes[1], hashes[4]); // both null + assert_eq!(hashes[2], hashes[3]); // same content [3, null, 5] + assert_eq!(hashes[1], hashes[6]); // null vs empty list + + // Negative tests: different content should produce different hashes + assert_ne!(hashes[0], hashes[2]); // [0, 1, 2] vs [3, null, 5] + assert_ne!(hashes[0], hashes[6]); // [0, 1, 2] vs [] + assert_ne!(hashes[2], hashes[6]); // [3, null, 5] vs [] + } + + #[test] + // Tests actual values of hashes, which are different if forcing collisions + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_large_list_view_arrays() { + use arrow::buffer::{NullBuffer, ScalarBuffer}; + + // Create values array: [0, 1, 2, 3, null, 5] + let values = Arc::new(Int32Array::from(vec![ + Some(0), + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef; + let field = Arc::new(Field::new("item", DataType::Int32, true)); + + // Create LargeListView with the following logical structure: + // Row 0: [0, 1, 2] (offset=0, size=3) + // Row 1: null (null bit set) + // Row 2: 
[3, null, 5] (offset=3, size=3) + // Row 3: [3, null, 5] (offset=3, size=3) - same as row 2 + // Row 4: null (null bit set) + // Row 5: [0, 1, 2] (offset=0, size=3) - same as row 0 + // Row 6: [] (offset=0, size=0) - empty list + let offsets = ScalarBuffer::from(vec![0i64, 0, 3, 3, 0, 0, 0]); + let sizes = ScalarBuffer::from(vec![3i64, 0, 3, 3, 0, 3, 0]); + let nulls = Some(NullBuffer::from(vec![ + true, false, true, true, false, true, true, + ])); + + let large_list_view_array = Arc::new(LargeListViewArray::new( + field, offsets, sizes, values, nulls, + )) as ArrayRef; + + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let mut hashes = vec![0; large_list_view_array.len()]; + create_hashes(&[large_list_view_array], &random_state, &mut hashes).unwrap(); + + assert_eq!(hashes[0], hashes[5]); // same content [0, 1, 2] + assert_eq!(hashes[1], hashes[4]); // both null + assert_eq!(hashes[2], hashes[3]); // same content [3, null, 5] + assert_eq!(hashes[1], hashes[6]); // null vs empty list + + // Negative tests: different content should produce different hashes + assert_ne!(hashes[0], hashes[2]); // [0, 1, 2] vs [3, null, 5] + assert_ne!(hashes[0], hashes[6]); // [0, 1, 2] vs [] + assert_ne!(hashes[2], hashes[6]); // [3, null, 5] vs [] + } + #[test] // Tests actual values of hashes, which are different if forcing collisions #[cfg(not(feature = "force_hash_collisions"))] diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index df6659c6f843..fdd04f752455 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -24,7 +24,6 @@ // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] -#![deny(clippy::allow_attributes)] mod column; mod dfschema; diff --git a/datafusion/common/src/metadata.rs b/datafusion/common/src/metadata.rs index eb687bde07d0..d6d8fb7b0ed0 100644 --- a/datafusion/common/src/metadata.rs +++ 
b/datafusion/common/src/metadata.rs @@ -171,6 +171,10 @@ pub fn format_type_and_metadata( /// // Add any metadata from `FieldMetadata` to `Field` /// let updated_field = metadata.add_to_field(field); /// ``` +/// +/// For more background, please also see the [Implementing User Defined Types and Custom Metadata in DataFusion blog] +/// +/// [Implementing User Defined Types and Custom Metadata in DataFusion blog]: https://datafusion.apache.org/blog/2025/09/21/custom-types-using-metadata #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct FieldMetadata { /// The inner metadata of a literal expression, which is a map of string diff --git a/datafusion/common/src/nested_struct.rs b/datafusion/common/src/nested_struct.rs index 086d96e85230..bf2558f31306 100644 --- a/datafusion/common/src/nested_struct.rs +++ b/datafusion/common/src/nested_struct.rs @@ -19,9 +19,9 @@ use crate::error::{_plan_err, Result}; use arrow::{ array::{Array, ArrayRef, StructArray, new_null_array}, compute::{CastOptions, cast_with_options}, - datatypes::{DataType::Struct, Field, FieldRef}, + datatypes::{DataType, DataType::Struct, Field, FieldRef}, }; -use std::sync::Arc; +use std::{collections::HashSet, sync::Arc}; /// Cast a struct column to match target struct fields, handling nested structs recursively. 
/// @@ -31,6 +31,7 @@ use std::sync::Arc; /// /// ## Field Matching Strategy /// - **By Name**: Source struct fields are matched to target fields by name (case-sensitive) +/// - **No Positional Mapping**: Structs with no overlapping field names are rejected /// - **Type Adaptation**: When a matching field is found, it is recursively cast to the target field's type /// - **Missing Fields**: Target fields not present in the source are filled with null values /// - **Extra Fields**: Source fields not present in the target are ignored @@ -54,16 +55,30 @@ fn cast_struct_column( target_fields: &[Arc], cast_options: &CastOptions, ) -> Result { - if let Some(source_struct) = source_col.as_any().downcast_ref::() { - validate_struct_compatibility(source_struct.fields(), target_fields)?; + if source_col.data_type() == &DataType::Null + || (!source_col.is_empty() && source_col.null_count() == source_col.len()) + { + return Ok(new_null_array( + &Struct(target_fields.to_vec().into()), + source_col.len(), + )); + } + if let Some(source_struct) = source_col.as_any().downcast_ref::() { + let source_fields = source_struct.fields(); + validate_struct_compatibility(source_fields, target_fields)?; let mut fields: Vec> = Vec::with_capacity(target_fields.len()); let mut arrays: Vec = Vec::with_capacity(target_fields.len()); let num_rows = source_col.len(); - for target_child_field in target_fields { + // Iterate target fields and pick source child by name when present. + for target_child_field in target_fields.iter() { fields.push(Arc::clone(target_child_field)); - match source_struct.column_by_name(target_child_field.name()) { + + let source_child_opt = + source_struct.column_by_name(target_child_field.name()); + + match source_child_opt { Some(source_child_col) => { let adapted_child = cast_column(source_child_col, target_child_field, cast_options) @@ -200,10 +215,20 @@ pub fn cast_column( /// // Target: {a: binary} /// // Result: Err(...) 
- string cannot cast to binary /// ``` +/// pub fn validate_struct_compatibility( source_fields: &[FieldRef], target_fields: &[FieldRef], ) -> Result<()> { + let has_overlap = has_one_of_more_common_fields(source_fields, target_fields); + if !has_overlap { + return _plan_err!( + "Cannot cast struct with {} fields to {} fields because there is no field name overlap", + source_fields.len(), + target_fields.len() + ); + } + // Check compatibility for each target field for target_field in target_fields { // Look for matching field in source by name @@ -211,53 +236,102 @@ pub fn validate_struct_compatibility( .iter() .find(|f| f.name() == target_field.name()) { - // Ensure nullability is compatible. It is invalid to cast a nullable - // source field to a non-nullable target field as this may discard - // null values. - if source_field.is_nullable() && !target_field.is_nullable() { + validate_field_compatibility(source_field, target_field)?; + } else { + // Target field is missing from source + // If it's non-nullable, we cannot fill it with NULL + if !target_field.is_nullable() { return _plan_err!( - "Cannot cast nullable struct field '{}' to non-nullable field", + "Cannot cast struct: target field '{}' is non-nullable but missing from source. 
\ + Cannot fill with NULL.", target_field.name() ); } - // Check if the matching field types are compatible - match (source_field.data_type(), target_field.data_type()) { - // Recursively validate nested structs - (Struct(source_nested), Struct(target_nested)) => { - validate_struct_compatibility(source_nested, target_nested)?; - } - // For non-struct types, use the existing castability check - _ => { - if !arrow::compute::can_cast_types( - source_field.data_type(), - target_field.data_type(), - ) { - return _plan_err!( - "Cannot cast struct field '{}' from type {} to type {}", - target_field.name(), - source_field.data_type(), - target_field.data_type() - ); - } - } - } } - // Missing fields in source are OK - they'll be filled with nulls } // Extra fields in source are OK - they'll be ignored Ok(()) } +fn validate_field_compatibility( + source_field: &Field, + target_field: &Field, +) -> Result<()> { + if source_field.data_type() == &DataType::Null { + // Validate that target allows nulls before returning early. + // It is invalid to cast a NULL source field to a non-nullable target field. + if !target_field.is_nullable() { + return _plan_err!( + "Cannot cast NULL struct field '{}' to non-nullable field '{}'", + source_field.name(), + target_field.name() + ); + } + return Ok(()); + } + + // Ensure nullability is compatible. It is invalid to cast a nullable + // source field to a non-nullable target field as this may discard + // null values. 
+ if source_field.is_nullable() && !target_field.is_nullable() { + return _plan_err!( + "Cannot cast nullable struct field '{}' to non-nullable field", + target_field.name() + ); + } + + // Check if the matching field types are compatible + match (source_field.data_type(), target_field.data_type()) { + // Recursively validate nested structs + (Struct(source_nested), Struct(target_nested)) => { + validate_struct_compatibility(source_nested, target_nested)?; + } + // For non-struct types, use the existing castability check + _ => { + if !arrow::compute::can_cast_types( + source_field.data_type(), + target_field.data_type(), + ) { + return _plan_err!( + "Cannot cast struct field '{}' from type {} to type {}", + target_field.name(), + source_field.data_type(), + target_field.data_type() + ); + } + } + } + + Ok(()) +} + +/// Check if two field lists have at least one common field by name. +/// +/// This is useful for validating struct compatibility when casting between structs, +/// ensuring that source and target fields have overlapping names. 
+pub fn has_one_of_more_common_fields( + source_fields: &[FieldRef], + target_fields: &[FieldRef], +) -> bool { + let source_names: HashSet<&str> = source_fields + .iter() + .map(|field| field.name().as_str()) + .collect(); + target_fields + .iter() + .any(|field| source_names.contains(field.name().as_str())) +} + #[cfg(test)] mod tests { use super::*; - use crate::format::DEFAULT_CAST_OPTIONS; + use crate::{assert_contains, format::DEFAULT_CAST_OPTIONS}; use arrow::{ array::{ BinaryArray, Int32Array, Int32Builder, Int64Array, ListArray, MapArray, - MapBuilder, StringArray, StringBuilder, + MapBuilder, NullArray, StringArray, StringBuilder, }, buffer::NullBuffer, datatypes::{DataType, Field, FieldRef, Int32Type}, @@ -428,11 +502,14 @@ mod tests { #[test] fn test_validate_struct_compatibility_missing_field_in_source() { - // Source struct: {field2: String} (missing field1) - let source_fields = vec![arc_field("field2", DataType::Utf8)]; + // Source struct: {field1: Int32} (missing field2) + let source_fields = vec![arc_field("field1", DataType::Int32)]; - // Target struct: {field1: Int32} - let target_fields = vec![arc_field("field1", DataType::Int32)]; + // Target struct: {field1: Int32, field2: Utf8} + let target_fields = vec![ + arc_field("field1", DataType::Int32), + arc_field("field2", DataType::Utf8), + ]; // Should be OK - missing fields will be filled with nulls let result = validate_struct_compatibility(&source_fields, &target_fields); @@ -455,6 +532,20 @@ mod tests { assert!(result.is_ok()); } + #[test] + fn test_validate_struct_compatibility_no_overlap_mismatch_len() { + let source_fields = vec![ + arc_field("left", DataType::Int32), + arc_field("right", DataType::Int32), + ]; + let target_fields = vec![arc_field("alpha", DataType::Int32)]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert_contains!(error_msg, "no field name overlap"); + } 
+ #[test] fn test_cast_struct_parent_nulls_retained() { let a_array = Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef; @@ -525,6 +616,117 @@ mod tests { assert!(error_msg.contains("non-nullable")); } + #[test] + fn test_validate_struct_compatibility_by_name() { + // Source struct: {field1: Int32, field2: String} + let source_fields = vec![ + arc_field("field1", DataType::Int32), + arc_field("field2", DataType::Utf8), + ]; + + // Target struct: {field2: String, field1: Int64} + let target_fields = vec![ + arc_field("field2", DataType::Utf8), + arc_field("field1", DataType::Int64), + ]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_ok()); + } + + #[test] + fn test_validate_struct_compatibility_by_name_with_type_mismatch() { + // Source struct: {field1: Binary} + let source_fields = vec![arc_field("field1", DataType::Binary)]; + + // Target struct: {field1: Int32} (incompatible type) + let target_fields = vec![arc_field("field1", DataType::Int32)]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert_contains!( + error_msg, + "Cannot cast struct field 'field1' from type Binary to type Int32" + ); + } + + #[test] + fn test_validate_struct_compatibility_no_overlap_equal_len() { + let source_fields = vec![ + arc_field("left", DataType::Int32), + arc_field("right", DataType::Utf8), + ]; + + let target_fields = vec![ + arc_field("alpha", DataType::Int32), + arc_field("beta", DataType::Utf8), + ]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert_contains!(error_msg, "no field name overlap"); + } + + #[test] + fn test_validate_struct_compatibility_mixed_name_overlap() { + // Source struct: {a: Int32, b: String, extra: Boolean} + let source_fields = vec![ + arc_field("a", 
DataType::Int32), + arc_field("b", DataType::Utf8), + arc_field("extra", DataType::Boolean), + ]; + + // Target struct: {b: String, a: Int64, c: Float32} + // Name overlap with a and b, missing c (nullable) + let target_fields = vec![ + arc_field("b", DataType::Utf8), + arc_field("a", DataType::Int64), + arc_field("c", DataType::Float32), + ]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_ok()); + } + + #[test] + fn test_validate_struct_compatibility_by_name_missing_required_field() { + // Source struct: {field1: Int32} (missing field2) + let source_fields = vec![arc_field("field1", DataType::Int32)]; + + // Target struct: {field1: Int32, field2: Int32 non-nullable} + let target_fields = vec![ + arc_field("field1", DataType::Int32), + Arc::new(non_null_field("field2", DataType::Int32)), + ]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert_contains!( + error_msg, + "Cannot cast struct: target field 'field2' is non-nullable but missing from source. Cannot fill with NULL." 
+ ); + } + + #[test] + fn test_validate_struct_compatibility_partial_name_overlap_with_count_mismatch() { + // Source struct: {a: Int32} (only one field) + let source_fields = vec![arc_field("a", DataType::Int32)]; + + // Target struct: {a: Int32, b: String} (two fields, but 'a' overlaps) + let target_fields = vec![ + arc_field("a", DataType::Int32), + arc_field("b", DataType::Utf8), + ]; + + // This should succeed - partial overlap means by-name mapping + // and missing field 'b' is nullable + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_ok()); + } + #[test] fn test_cast_nested_struct_with_extra_and_missing_fields() { // Source inner struct has fields a, b, extra @@ -585,6 +787,33 @@ mod tests { assert!(missing.is_null(1)); } + #[test] + fn test_cast_null_struct_field_to_nested_struct() { + let null_inner = Arc::new(NullArray::new(2)) as ArrayRef; + let source_struct = StructArray::from(vec![( + arc_field("inner", DataType::Null), + Arc::clone(&null_inner), + )]); + let source_col = Arc::new(source_struct) as ArrayRef; + + let target_field = struct_field( + "outer", + vec![struct_field("inner", vec![field("a", DataType::Int32)])], + ); + + let result = + cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + let outer = result.as_any().downcast_ref::().unwrap(); + let inner = get_column_as!(&outer, "inner", StructArray); + assert_eq!(inner.len(), 2); + assert!(inner.is_null(0)); + assert!(inner.is_null(1)); + + let inner_a = get_column_as!(inner, "a", Int32Array); + assert!(inner_a.is_null(0)); + assert!(inner_a.is_null(1)); + } + #[test] fn test_cast_struct_with_array_and_map_fields() { // Array field with second row null @@ -704,4 +933,81 @@ mod tests { assert_eq!(a_col.value(0), 1); assert_eq!(a_col.value(1), 2); } + + #[test] + fn test_cast_struct_no_overlap_rejected() { + let first = Arc::new(Int32Array::from(vec![Some(10), Some(20)])) as ArrayRef; + let second = + 
Arc::new(StringArray::from(vec![Some("alpha"), Some("beta")])) as ArrayRef; + + let source_struct = StructArray::from(vec![ + (arc_field("left", DataType::Int32), first), + (arc_field("right", DataType::Utf8), second), + ]); + let source_col = Arc::new(source_struct) as ArrayRef; + + let target_field = struct_field( + "s", + vec![field("a", DataType::Int64), field("b", DataType::Utf8)], + ); + + let result = cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert_contains!(error_msg, "no field name overlap"); + } + + #[test] + fn test_cast_struct_missing_non_nullable_field_fails() { + // Source has only field 'a' + let a = Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef; + let source_struct = StructArray::from(vec![(arc_field("a", DataType::Int32), a)]); + let source_col = Arc::new(source_struct) as ArrayRef; + + // Target has fields 'a' (nullable) and 'b' (non-nullable) + let target_field = struct_field( + "s", + vec![ + field("a", DataType::Int32), + non_null_field("b", DataType::Int32), + ], + ); + + // Should fail because 'b' is non-nullable but missing from source + let result = cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS); + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!( + err.to_string() + .contains("target field 'b' is non-nullable but missing from source"), + "Unexpected error: {err}" + ); + } + + #[test] + fn test_cast_struct_missing_nullable_field_succeeds() { + // Source has only field 'a' + let a = Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef; + let source_struct = StructArray::from(vec![(arc_field("a", DataType::Int32), a)]); + let source_col = Arc::new(source_struct) as ArrayRef; + + // Target has fields 'a' and 'b' (both nullable) + let target_field = struct_field( + "s", + vec![field("a", DataType::Int32), field("b", DataType::Int32)], + ); + + // Should succeed - 'b' is nullable 
so can be filled with NULL + let result = + cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + let struct_array = result.as_any().downcast_ref::().unwrap(); + + let a_col = get_column_as!(&struct_array, "a", Int32Array); + assert_eq!(a_col.value(0), 1); + assert_eq!(a_col.value(1), 2); + + let b_col = get_column_as!(&struct_array, "b", Int32Array); + assert!(b_col.is_null(0)); + assert!(b_col.is_null(1)); + } } diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index e4e048ad3c0d..c21d3e21f007 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -43,7 +43,7 @@ use crate::cast::{ as_float16_array, as_float32_array, as_float64_array, as_int8_array, as_int16_array, as_int32_array, as_int64_array, as_interval_dt_array, as_interval_mdn_array, as_interval_ym_array, as_large_binary_array, as_large_list_array, - as_large_string_array, as_string_array, as_string_view_array, + as_large_string_array, as_run_array, as_string_array, as_string_view_array, as_time32_millisecond_array, as_time32_second_array, as_time64_microsecond_array, as_time64_nanosecond_array, as_timestamp_microsecond_array, as_timestamp_millisecond_array, as_timestamp_nanosecond_array, @@ -56,21 +56,20 @@ use crate::hash_utils::create_hashes; use crate::utils::SingleRowListArrayBuilder; use crate::{_internal_datafusion_err, arrow_datafusion_err}; use arrow::array::{ - Array, ArrayData, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, AsArray, - BinaryArray, BinaryViewArray, BinaryViewBuilder, BooleanArray, Date32Array, + Array, ArrayData, ArrayDataBuilder, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, + AsArray, BinaryArray, BinaryViewArray, BinaryViewBuilder, BooleanArray, Date32Array, Date64Array, Decimal32Array, Decimal64Array, Decimal128Array, Decimal256Array, DictionaryArray, DurationMicrosecondArray, DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray, FixedSizeBinaryArray, - 
FixedSizeBinaryBuilder, FixedSizeListArray, Float16Array, Float32Array, Float64Array, - GenericListArray, Int8Array, Int16Array, Int32Array, Int64Array, - IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray, - LargeBinaryArray, LargeListArray, LargeStringArray, ListArray, MapArray, - MutableArrayData, OffsetSizeTrait, PrimitiveArray, Scalar, StringArray, - StringViewArray, StringViewBuilder, StructArray, Time32MillisecondArray, - Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, - TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, - TimestampSecondArray, UInt8Array, UInt16Array, UInt32Array, UInt64Array, UnionArray, - new_empty_array, new_null_array, + FixedSizeListArray, Float16Array, Float32Array, Float64Array, GenericListArray, + Int8Array, Int16Array, Int32Array, Int64Array, IntervalDayTimeArray, + IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray, LargeListArray, + LargeStringArray, ListArray, MapArray, MutableArrayData, OffsetSizeTrait, + PrimitiveArray, RunArray, Scalar, StringArray, StringViewArray, StringViewBuilder, + StructArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, + Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, UInt8Array, UInt16Array, UInt32Array, + UInt64Array, UnionArray, downcast_run_array, new_empty_array, new_null_array, }; use arrow::buffer::{BooleanBuffer, ScalarBuffer}; use arrow::compute::kernels::cast::{CastOptions, cast_with_options}; @@ -80,11 +79,12 @@ use arrow::compute::kernels::numeric::{ use arrow::datatypes::{ ArrowDictionaryKeyType, ArrowNativeType, ArrowTimestampType, DataType, Date32Type, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, DecimalType, Field, - Float32Type, Int8Type, Int16Type, Int32Type, Int64Type, IntervalDayTime, + FieldRef, Float32Type, Int8Type, Int16Type, Int32Type, Int64Type, IntervalDayTime, 
IntervalDayTimeType, IntervalMonthDayNano, IntervalMonthDayNanoType, IntervalUnit, - IntervalYearMonthType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, - TimestampNanosecondType, TimestampSecondType, UInt8Type, UInt16Type, UInt32Type, - UInt64Type, UnionFields, UnionMode, i256, validate_decimal_precision_and_scale, + IntervalYearMonthType, RunEndIndexType, TimeUnit, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt8Type, + UInt16Type, UInt32Type, UInt64Type, UnionFields, UnionMode, i256, + validate_decimal_precision_and_scale, }; use arrow::util::display::{ArrayFormatter, FormatOptions, array_value_to_string}; use cache::{get_or_create_cached_key_array, get_or_create_cached_null_array}; @@ -429,6 +429,8 @@ pub enum ScalarValue { Union(Option<(i8, Box)>, UnionFields, UnionMode), /// Dictionary type: index type and value Dictionary(Box, Box), + /// (run-ends field, value field, value) + RunEndEncoded(FieldRef, FieldRef, Box), } impl Hash for Fl { @@ -558,6 +560,10 @@ impl PartialEq for ScalarValue { (Union(_, _, _), _) => false, (Dictionary(k1, v1), Dictionary(k2, v2)) => k1.eq(k2) && v1.eq(v2), (Dictionary(_, _), _) => false, + (RunEndEncoded(rf1, vf1, v1), RunEndEncoded(rf2, vf2, v2)) => { + rf1.eq(rf2) && vf1.eq(vf2) && v1.eq(v2) + } + (RunEndEncoded(_, _, _), _) => false, (Null, Null) => true, (Null, _) => false, } @@ -723,6 +729,15 @@ impl PartialOrd for ScalarValue { if k1 == k2 { v1.partial_cmp(v2) } else { None } } (Dictionary(_, _), _) => None, + (RunEndEncoded(rf1, vf1, v1), RunEndEncoded(rf2, vf2, v2)) => { + // Don't compare if the run ends fields don't match (it is effectively a different datatype) + if rf1 == rf2 && vf1 == vf2 { + v1.partial_cmp(v2) + } else { + None + } + } + (RunEndEncoded(_, _, _), _) => None, (Null, Null) => Some(Ordering::Equal), (Null, _) => None, } @@ -966,6 +981,11 @@ impl Hash for ScalarValue { k.hash(state); v.hash(state); } + RunEndEncoded(rf, vf, v) => 
{ + rf.hash(state); + vf.hash(state); + v.hash(state); + } // stable hash for Null value Null => 1.hash(state), } @@ -1244,6 +1264,13 @@ impl ScalarValue { index_type.clone(), Box::new(value_type.as_ref().try_into()?), ), + DataType::RunEndEncoded(run_ends_field, value_field) => { + ScalarValue::RunEndEncoded( + Arc::clone(run_ends_field), + Arc::clone(value_field), + Box::new(value_field.data_type().try_into()?), + ) + } // `ScalarValue::List` contains single element `ListArray`. DataType::List(field_ref) => ScalarValue::List(Arc::new( GenericListArray::new_null(Arc::clone(field_ref), 1), @@ -1574,6 +1601,8 @@ impl ScalarValue { | DataType::Float16 | DataType::Float32 | DataType::Float64 + | DataType::Decimal32(_, _) + | DataType::Decimal64(_, _) | DataType::Decimal128(_, _) | DataType::Decimal256(_, _) | DataType::Timestamp(_, _) @@ -1642,6 +1671,14 @@ impl ScalarValue { Box::new(ScalarValue::new_default(value_type)?), )), + DataType::RunEndEncoded(run_ends_field, value_field) => { + Ok(ScalarValue::RunEndEncoded( + Arc::clone(run_ends_field), + Arc::clone(value_field), + Box::new(ScalarValue::new_default(value_field.data_type())?), + )) + } + // Map types DataType::Map(field, _) => Ok(ScalarValue::Map(Arc::new(MapArray::from( ArrayData::new_empty(field.data_type()), @@ -1661,8 +1698,7 @@ impl ScalarValue { } } - // Unsupported types for now - _ => { + DataType::ListView(_) | DataType::LargeListView(_) => { _not_impl_err!( "Default value for data_type \"{datatype}\" is not implemented yet" ) @@ -1953,6 +1989,12 @@ impl ScalarValue { ScalarValue::Dictionary(k, v) => { DataType::Dictionary(k.clone(), Box::new(v.data_type())) } + ScalarValue::RunEndEncoded(run_ends_field, value_field, _) => { + DataType::RunEndEncoded( + Arc::clone(run_ends_field), + Arc::clone(value_field), + ) + } ScalarValue::Null => DataType::Null, } } @@ -2231,6 +2273,7 @@ impl ScalarValue { None => true, }, ScalarValue::Dictionary(_, v) => v.is_null(), + ScalarValue::RunEndEncoded(_, _, v) => 
v.is_null(), } } @@ -2598,6 +2641,94 @@ impl ScalarValue { _ => unreachable!("Invalid dictionary keys type: {}", key_type), } } + DataType::RunEndEncoded(run_ends_field, value_field) => { + fn make_run_array( + scalars: impl IntoIterator, + run_ends_field: &FieldRef, + values_field: &FieldRef, + ) -> Result { + let mut scalars = scalars.into_iter(); + + let mut run_ends = vec![]; + let mut value_scalars = vec![]; + + let mut len = R::Native::ONE; + let mut current = + if let Some(ScalarValue::RunEndEncoded(_, _, scalar)) = + scalars.next() + { + *scalar + } else { + // We are guaranteed to have one element of correct + // type because we peeked above + unreachable!() + }; + for scalar in scalars { + let scalar = match scalar { + ScalarValue::RunEndEncoded( + inner_run_ends_field, + inner_value_field, + scalar, + ) if &inner_run_ends_field == run_ends_field + && &inner_value_field == values_field => + { + *scalar + } + _ => { + return _exec_err!( + "Expected RunEndEncoded scalar with run-ends field {run_ends_field} but got: {scalar:?}" + ); + } + }; + + // new run + if scalar != current { + run_ends.push(len); + value_scalars.push(current); + current = scalar; + } + + len = len.add_checked(R::Native::ONE).map_err(|_| { + DataFusionError::Execution(format!( + "Cannot construct RunArray: Overflows run-ends type {}", + run_ends_field.data_type() + )) + })?; + } + + run_ends.push(len); + value_scalars.push(current); + + let run_ends = PrimitiveArray::::from_iter_values(run_ends); + let values = ScalarValue::iter_to_array(value_scalars)?; + + // Using ArrayDataBuilder so we can maintain the fields + let dt = DataType::RunEndEncoded( + Arc::clone(run_ends_field), + Arc::clone(values_field), + ); + let builder = ArrayDataBuilder::new(dt) + .len(RunArray::logical_len(&run_ends)) + .add_child_data(run_ends.to_data()) + .add_child_data(values.to_data()); + let run_array = RunArray::::from(builder.build()?); + + Ok(Arc::new(run_array)) + } + + match run_ends_field.data_type() 
{ + DataType::Int16 => { + make_run_array::(scalars, run_ends_field, value_field)? + } + DataType::Int32 => { + make_run_array::(scalars, run_ends_field, value_field)? + } + DataType::Int64 => { + make_run_array::(scalars, run_ends_field, value_field)? + } + dt => unreachable!("Invalid run-ends type: {dt}"), + } + } DataType::FixedSizeBinary(size) => { let array = scalars .map(|sv| { @@ -2626,7 +2757,6 @@ impl ScalarValue { | DataType::Time32(TimeUnit::Nanosecond) | DataType::Time64(TimeUnit::Second) | DataType::Time64(TimeUnit::Millisecond) - | DataType::RunEndEncoded(_, _) | DataType::ListView(_) | DataType::LargeListView(_) => { return _not_impl_err!( @@ -2878,7 +3008,7 @@ impl ScalarValue { /// /// Errors if `self` is /// - a decimal that fails be converted to a decimal array of size - /// - a `FixedsizeList` that fails to be concatenated into an array of size + /// - a `FixedSizeList` that fails to be concatenated into an array of size /// - a `List` that fails to be concatenated into an array of size /// - a `Dictionary` that fails be converted to a dictionary array of size pub fn to_array_of_size(&self, size: usize) -> Result { @@ -2989,13 +3119,8 @@ impl ScalarValue { }, ScalarValue::Utf8View(e) => match e { Some(value) => { - let mut builder = - StringViewBuilder::with_capacity(size).with_deduplicate_strings(); - // Replace with upstream arrow-rs code when available: - // https://github.com/apache/arrow-rs/issues/9034 - for _ in 0..size { - builder.append_value(value); - } + let mut builder = StringViewBuilder::with_capacity(size); + builder.try_append_value_n(value, size)?; let array = builder.finish(); Arc::new(array) } @@ -3013,11 +3138,8 @@ impl ScalarValue { }, ScalarValue::BinaryView(e) => match e { Some(value) => { - let mut builder = - BinaryViewBuilder::with_capacity(size).with_deduplicate_strings(); - for _ in 0..size { - builder.append_value(value); - } + let mut builder = BinaryViewBuilder::with_capacity(size); + 
builder.try_append_value_n(value, size)?; let array = builder.finish(); Arc::new(array) } @@ -3031,14 +3153,7 @@ impl ScalarValue { ) .unwrap(), ), - None => { - // TODO: Replace with FixedSizeBinaryArray::new_null once a fix for - // https://github.com/apache/arrow-rs/issues/8900 is in the used arrow-rs - // version. - let mut builder = FixedSizeBinaryBuilder::new(*s); - builder.append_nulls(size); - Arc::new(builder.finish()) - } + None => Arc::new(FixedSizeBinaryArray::new_null(*s, size)), }, ScalarValue::LargeBinary(e) => match e { Some(value) => { @@ -3218,6 +3333,54 @@ impl ScalarValue { _ => unreachable!("Invalid dictionary keys type: {}", key_type), } } + ScalarValue::RunEndEncoded(run_ends_field, values_field, value) => { + fn make_run_array( + run_ends_field: &Arc, + values_field: &Arc, + value: &ScalarValue, + size: usize, + ) -> Result { + let size_native = R::Native::from_usize(size) + .ok_or_else(|| DataFusionError::Execution(format!("Cannot construct RunArray of size {size}: Overflows run-ends type {}", R::DATA_TYPE)))?; + let values = value.to_array_of_size(1)?; + let run_ends = + PrimitiveArray::::new(vec![size_native].into(), None); + + // Using ArrayDataBuilder so we can maintain the fields + let dt = DataType::RunEndEncoded( + Arc::clone(run_ends_field), + Arc::clone(values_field), + ); + let builder = ArrayDataBuilder::new(dt) + .len(size) + .add_child_data(run_ends.to_data()) + .add_child_data(values.to_data()); + let run_array = RunArray::::from(builder.build()?); + + Ok(Arc::new(run_array)) + } + match run_ends_field.data_type() { + DataType::Int16 => make_run_array::( + run_ends_field, + values_field, + value, + size, + )?, + DataType::Int32 => make_run_array::( + run_ends_field, + values_field, + value, + size, + )?, + DataType::Int64 => make_run_array::( + run_ends_field, + values_field, + value, + size, + )?, + dt => unreachable!("Invalid run-ends type: {dt}"), + } + } ScalarValue::Null => get_or_create_cached_null_array(size), }) } @@ 
-3271,13 +3434,22 @@ impl ScalarValue { } } + /// Repeats the rows of `arr` `size` times, producing an array with + /// `arr.len() * size` total rows. fn list_to_array_of_size(arr: &dyn Array, size: usize) -> Result { - let arrays = repeat_n(arr, size).collect::>(); - let ret = match !arrays.is_empty() { - true => arrow::compute::concat(arrays.as_slice())?, - false => arr.slice(0, 0), - }; - Ok(ret) + if size == 0 { + return Ok(arr.slice(0, 0)); + } + + // Examples: given `arr = [[A, B, C]]` and `size = 3`, `indices = [0, 0, 0]` and + // the result is `[[A, B, C], [A, B, C], [A, B, C]]`. + // + // Given `arr = [[A, B], [C]]` and `size = 2`, `indices = [0, 1, 0, 1]` and the + // result is `[[A, B], [C], [A, B], [C]]`. (But in practice, we are always called + // with `arr.len() == 1`.) + let n = arr.len() as u32; + let indices = UInt32Array::from_iter_values((0..size).flat_map(|_| 0..n)); + Ok(arrow::compute::take(arr, &indices, None)?) } /// Retrieve ScalarValue for each row in `array` @@ -3425,7 +3597,7 @@ impl ScalarValue { /// Converts a value in `array` at `index` into a ScalarValue pub fn try_from_array(array: &dyn Array, index: usize) -> Result { // handle NULL value - if !array.is_valid(index) { + if array.is_null(index) { return array.data_type().try_into(); } @@ -3568,6 +3740,28 @@ impl ScalarValue { Self::Dictionary(key_type.clone(), Box::new(value)) } + DataType::RunEndEncoded(run_ends_field, value_field) => { + // Explicitly check length here since get_physical_index() doesn't + // bound check for us + if index > array.len() { + return _exec_err!( + "Index {index} out of bounds for array of length {}", + array.len() + ); + } + let scalar = downcast_run_array!( + array => { + let index = array.get_physical_index(index); + ScalarValue::try_from_array(array.values(), index)? 
+ }, + dt => unreachable!("Invalid run-ends type: {dt}") + ); + Self::RunEndEncoded( + Arc::clone(run_ends_field), + Arc::clone(value_field), + Box::new(scalar), + ) + } DataType::Struct(_) => { let a = array.slice(index, 1); Self::Struct(Arc::new(a.as_struct().to_owned())) @@ -3680,6 +3874,7 @@ impl ScalarValue { ScalarValue::LargeUtf8(v) => v, ScalarValue::Utf8View(v) => v, ScalarValue::Dictionary(_, v) => return v.try_as_str(), + ScalarValue::RunEndEncoded(_, _, v) => return v.try_as_str(), _ => return None, }; Some(v.as_ref().map(|v| v.as_str())) @@ -3704,7 +3899,23 @@ impl ScalarValue { } let scalar_array = self.to_array()?; - let cast_arr = cast_with_options(&scalar_array, target_type, cast_options)?; + + // For struct types, use name-based casting logic that matches fields by name + // and recursively casts nested structs. The field name wrapper is arbitrary + // since cast_column only uses the DataType::Struct field definitions inside. + let cast_arr = match target_type { + DataType::Struct(_) => { + // Field name is unused; only the struct's inner field names matter + let target_field = Field::new("_", target_type.clone(), true); + crate::nested_struct::cast_column( + &scalar_array, + &target_field, + cast_options, + )? + } + _ => cast_with_options(&scalar_array, target_type, cast_options)?, + }; + ScalarValue::try_from_array(&cast_arr, 0) } @@ -4008,6 +4219,34 @@ impl ScalarValue { None => v.is_null(), } } + ScalarValue::RunEndEncoded(run_ends_field, _, value) => { + // Explicitly check length here since get_physical_index() doesn't + // bound check for us + if index > array.len() { + return _exec_err!( + "Index {index} out of bounds for array of length {}", + array.len() + ); + } + match run_ends_field.data_type() { + DataType::Int16 => { + let array = as_run_array::(array)?; + let index = array.get_physical_index(index); + value.eq_array(array.values(), index)? 
+ } + DataType::Int32 => { + let array = as_run_array::(array)?; + let index = array.get_physical_index(index); + value.eq_array(array.values(), index)? + } + DataType::Int64 => { + let array = as_run_array::(array)?; + let index = array.get_physical_index(index); + value.eq_array(array.values(), index)? + } + dt => unreachable!("Invalid run-ends type: {dt}"), + } + } ScalarValue::Null => array.is_null(index), }) } @@ -4097,6 +4336,7 @@ impl ScalarValue { // `dt` and `sv` are boxed, so they are NOT already included in `self` dt.size() + sv.size() } + ScalarValue::RunEndEncoded(rf, vf, v) => rf.size() + vf.size() + v.size(), } } @@ -4212,6 +4452,9 @@ impl ScalarValue { ScalarValue::Dictionary(_, value) => { value.compact(); } + ScalarValue::RunEndEncoded(_, _, value) => { + value.compact(); + } } } @@ -4843,6 +5086,7 @@ impl fmt::Display for ScalarValue { None => write!(f, "NULL")?, }, ScalarValue::Dictionary(_k, v) => write!(f, "{v}")?, + ScalarValue::RunEndEncoded(_, _, v) => write!(f, "{v}")?, ScalarValue::Null => write!(f, "NULL")?, }; Ok(()) @@ -5021,6 +5265,9 @@ impl fmt::Debug for ScalarValue { None => write!(f, "Union(NULL)"), }, ScalarValue::Dictionary(k, v) => write!(f, "Dictionary({k:?}, {v:?})"), + ScalarValue::RunEndEncoded(rf, vf, v) => { + write!(f, "RunEndEncoded({rf:?}, {vf:?}, {v:?})") + } ScalarValue::Null => write!(f, "NULL"), } } @@ -5294,6 +5541,79 @@ mod tests { assert_eq!(empty_array.len(), 0); } + #[test] + fn test_to_array_of_size_list_size_one() { + // size=1 takes the fast path (Arc::clone) + let arr = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(10), + Some(20), + ])]); + let sv = ScalarValue::List(Arc::new(arr.clone())); + let result = sv.to_array_of_size(1).unwrap(); + assert_eq!(result.as_list::(), &arr); + } + + #[test] + fn test_to_array_of_size_list_empty_inner() { + // A list scalar containing an empty list: [[]] + let arr = ListArray::from_iter_primitive::(vec![Some(vec![])]); + let sv = 
ScalarValue::List(Arc::new(arr)); + let result = sv.to_array_of_size(3).unwrap(); + let result_list = result.as_list::(); + assert_eq!(result_list.len(), 3); + for i in 0..3 { + assert_eq!(result_list.value(i).len(), 0); + } + } + + #[test] + fn test_to_array_of_size_large_list() { + let arr = + LargeListArray::from_iter_primitive::(vec![Some(vec![ + Some(100), + Some(200), + ])]); + let sv = ScalarValue::LargeList(Arc::new(arr)); + let result = sv.to_array_of_size(3).unwrap(); + let expected = LargeListArray::from_iter_primitive::(vec![ + Some(vec![Some(100), Some(200)]), + Some(vec![Some(100), Some(200)]), + Some(vec![Some(100), Some(200)]), + ]); + assert_eq!(result.as_list::(), &expected); + } + + #[test] + fn test_list_to_array_of_size_multi_row() { + // Call list_to_array_of_size directly with arr.len() > 1 + let arr = Int32Array::from(vec![Some(10), None, Some(30)]); + let result = ScalarValue::list_to_array_of_size(&arr, 3).unwrap(); + let result = result.as_primitive::(); + assert_eq!( + result.iter().collect::>(), + vec![ + Some(10), + None, + Some(30), + Some(10), + None, + Some(30), + Some(10), + None, + Some(30), + ] + ); + } + + #[test] + fn test_to_array_of_size_null_list() { + let dt = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))); + let sv = ScalarValue::try_from(&dt).unwrap(); + let result = sv.to_array_of_size(3).unwrap(); + assert_eq!(result.len(), 3); + assert_eq!(result.null_count(), 3); + } + /// See https://github.com/apache/datafusion/issues/18870 #[test] fn test_to_array_of_size_for_none_fsb() { @@ -7256,6 +7576,31 @@ mod tests { } } + #[test] + fn roundtrip_run_array() { + // Comparison logic in round_trip_through_scalar doesn't work for RunArrays + // so we have a custom test for them + // TODO: https://github.com/apache/arrow-rs/pull/9213 might fix this ^ + let run_ends = Int16Array::from(vec![2, 3]); + let values = Int64Array::from(vec![Some(1), None]); + let run_array = RunArray::try_new(&run_ends, 
&values).unwrap(); + let run_array = run_array.downcast::().unwrap(); + + let expected_values = run_array.into_iter().collect::>(); + + for i in 0..run_array.len() { + let scalar = ScalarValue::try_from_array(&run_array, i).unwrap(); + let array = scalar.to_array_of_size(1).unwrap(); + assert_eq!(array.data_type(), run_array.data_type()); + let array = array.as_run::(); + let array = array.downcast::().unwrap(); + assert_eq!( + array.into_iter().collect::>(), + expected_values[i..i + 1] + ); + } + } + #[test] fn test_scalar_union_sparse() { let field_a = Arc::new(Field::new("A", DataType::Int32, true)); @@ -8868,7 +9213,7 @@ mod tests { .unwrap(), ScalarValue::try_new_null(&DataType::Map(map_field_ref, false)).unwrap(), ScalarValue::try_new_null(&DataType::Union( - UnionFields::new(vec![42], vec![field_ref]), + UnionFields::try_new(vec![42], vec![field_ref]).unwrap(), UnionMode::Dense, )) .unwrap(), @@ -8971,13 +9316,14 @@ mod tests { } // Test union type - let union_fields = UnionFields::new( + let union_fields = UnionFields::try_new( vec![0, 1], vec![ Field::new("i32", DataType::Int32, false), Field::new("f64", DataType::Float64, false), ], - ); + ) + .unwrap(); let union_result = ScalarValue::new_default(&DataType::Union( union_fields.clone(), UnionMode::Sparse, @@ -9227,6 +9573,175 @@ mod tests { assert_eq!(value.len(), buffers[0].len()); } + #[test] + fn test_to_array_of_size_run_end_encoded() { + fn run_test() { + let value = Box::new(ScalarValue::Float32(Some(1.0))); + let size = 5; + let scalar = ScalarValue::RunEndEncoded( + Field::new("run_ends", R::DATA_TYPE, false).into(), + Field::new("values", DataType::Float32, true).into(), + value.clone(), + ); + let array = scalar.to_array_of_size(size).unwrap(); + let array = array.as_run::(); + let array = array.downcast::().unwrap(); + assert_eq!(vec![Some(1.0); size], array.into_iter().collect::>()); + assert_eq!(1, array.values().len()); + } + + run_test::(); + run_test::(); + run_test::(); + + let scalar = 
ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int16, false).into(), + Field::new("values", DataType::Float32, true).into(), + Box::new(ScalarValue::Float32(Some(1.0))), + ); + let err = scalar.to_array_of_size(i16::MAX as usize + 10).unwrap_err(); + assert_eq!( + "Execution error: Cannot construct RunArray of size 32777: Overflows run-ends type Int16", + err.to_string() + ) + } + + #[test] + fn test_eq_array_run_end_encoded() { + let run_ends = Int16Array::from(vec![1, 3]); + let values = Float32Array::from(vec![None, Some(1.0)]); + let run_array = + Arc::new(RunArray::try_new(&run_ends, &values).unwrap()) as ArrayRef; + + let scalar = ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int16, false).into(), + Field::new("values", DataType::Float32, true).into(), + Box::new(ScalarValue::Float32(None)), + ); + assert!(scalar.eq_array(&run_array, 0).unwrap()); + + let scalar = ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int16, false).into(), + Field::new("values", DataType::Float32, true).into(), + Box::new(ScalarValue::Float32(Some(1.0))), + ); + assert!(scalar.eq_array(&run_array, 1).unwrap()); + assert!(scalar.eq_array(&run_array, 2).unwrap()); + + // value types must match + let scalar = ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int16, false).into(), + Field::new("values", DataType::Float64, true).into(), + Box::new(ScalarValue::Float64(Some(1.0))), + ); + let err = scalar.eq_array(&run_array, 1).unwrap_err(); + let expected = "Internal error: could not cast array of type Float32 to arrow_array::array::primitive_array::PrimitiveArray"; + assert!(err.to_string().starts_with(expected)); + + // run ends type must match + let scalar = ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int32, false).into(), + Field::new("values", DataType::Float32, true).into(), + Box::new(ScalarValue::Float32(None)), + ); + let err = scalar.eq_array(&run_array, 0).unwrap_err(); + let expected = "Internal 
error: could not cast array of type RunEndEncoded(\"run_ends\": non-null Int16, \"values\": Float32) to arrow_array::array::run_array::RunArray"; + assert!(err.to_string().starts_with(expected)); + } + + #[test] + fn test_iter_to_array_run_end_encoded() { + let run_ends_field = Arc::new(Field::new("run_ends", DataType::Int16, false)); + let values_field = Arc::new(Field::new("values", DataType::Int64, true)); + let scalars = vec![ + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(1))), + ), + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(1))), + ), + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(None)), + ), + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(2))), + ), + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(2))), + ), + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(2))), + ), + ]; + + let run_array = ScalarValue::iter_to_array(scalars).unwrap(); + let expected = RunArray::try_new( + &Int16Array::from(vec![2, 3, 6]), + &Int64Array::from(vec![Some(1), None, Some(2)]), + ) + .unwrap(); + assert_eq!(&expected as &dyn Array, run_array.as_ref()); + + // inconsistent run-ends type + let scalars = vec![ + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(1))), + ), + ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int32, false).into(), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(1))), + ), + ]; + let err = ScalarValue::iter_to_array(scalars).unwrap_err(); + let expected = "Execution error: Expected RunEndEncoded 
scalar with run-ends field Field { \"run_ends\": Int16 } but got: RunEndEncoded(Field { name: \"run_ends\", data_type: Int32 }, Field { name: \"values\", data_type: Int64, nullable: true }, Int64(1))"; + assert!(err.to_string().starts_with(expected)); + + // inconsistent value type + let scalars = vec![ + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(1))), + ), + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Field::new("values", DataType::Int32, true).into(), + Box::new(ScalarValue::Int32(Some(1))), + ), + ]; + let err = ScalarValue::iter_to_array(scalars).unwrap_err(); + let expected = "Execution error: Expected RunEndEncoded scalar with run-ends field Field { \"run_ends\": Int16 } but got: RunEndEncoded(Field { name: \"run_ends\", data_type: Int16 }, Field { name: \"values\", data_type: Int32, nullable: true }, Int32(1))"; + assert!(err.to_string().starts_with(expected)); + + // inconsistent scalars type + let scalars = vec![ + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(1))), + ), + ScalarValue::Int64(Some(1)), + ]; + let err = ScalarValue::iter_to_array(scalars).unwrap_err(); + let expected = "Execution error: Expected RunEndEncoded scalar with run-ends field Field { \"run_ends\": Int16 } but got: Int64(1)"; + assert!(err.to_string().starts_with(expected)); + } + #[test] fn test_convert_array_to_scalar_vec() { // 1: Regular ListArray diff --git a/datafusion/common/src/stats.rs b/datafusion/common/src/stats.rs index ba13ef392d91..3d4d9b6c6c4a 100644 --- a/datafusion/common/src/stats.rs +++ b/datafusion/common/src/stats.rs @@ -391,8 +391,13 @@ impl Statistics { /// For example, if we had statistics for columns `{"a", "b", "c"}`, /// projecting to `vec![2, 1]` would return statistics for columns `{"c", /// "b"}`. 
- pub fn project(mut self, projection: Option<&Vec>) -> Self { - let Some(projection) = projection else { + pub fn project(self, projection: Option<&impl AsRef<[usize]>>) -> Self { + let projection = projection.map(AsRef::as_ref); + self.project_impl(projection) + } + + fn project_impl(mut self, projection: Option<&[usize]>) -> Self { + let Some(projection) = projection.map(AsRef::as_ref) else { return self; }; @@ -410,7 +415,7 @@ impl Statistics { .map(Slot::Present) .collect(); - for idx in projection { + for idx in projection.iter() { let next_idx = self.column_statistics.len(); let slot = std::mem::replace( columns.get_mut(*idx).expect("projection out of bounds"), @@ -1066,7 +1071,7 @@ mod tests { #[test] fn test_project_none() { - let projection = None; + let projection: Option> = None; let stats = make_stats(vec![10, 20, 30]).project(projection.as_ref()); assert_eq!(stats, make_stats(vec![10, 20, 30])); } @@ -1260,7 +1265,7 @@ mod tests { col_stats.min_value, Precision::Inexact(ScalarValue::Int32(Some(-10))) ); - assert!(matches!(col_stats.sum_value, Precision::Absent)); + assert_eq!(col_stats.sum_value, Precision::Absent); } #[test] diff --git a/datafusion/common/src/types/native.rs b/datafusion/common/src/types/native.rs index 766c50441613..65b6a5a15fc8 100644 --- a/datafusion/common/src/types/native.rs +++ b/datafusion/common/src/types/native.rs @@ -186,7 +186,57 @@ pub enum NativeType { impl Display for NativeType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{self:?}") // TODO: nicer formatting + // Match the format used by arrow::datatypes::DataType's Display impl + match self { + Self::Null => write!(f, "Null"), + Self::Boolean => write!(f, "Boolean"), + Self::Int8 => write!(f, "Int8"), + Self::Int16 => write!(f, "Int16"), + Self::Int32 => write!(f, "Int32"), + Self::Int64 => write!(f, "Int64"), + Self::UInt8 => write!(f, "UInt8"), + Self::UInt16 => write!(f, "UInt16"), + Self::UInt32 => write!(f, "UInt32"), + 
Self::UInt64 => write!(f, "UInt64"), + Self::Float16 => write!(f, "Float16"), + Self::Float32 => write!(f, "Float32"), + Self::Float64 => write!(f, "Float64"), + Self::Timestamp(unit, Some(tz)) => write!(f, "Timestamp({unit}, {tz:?})"), + Self::Timestamp(unit, None) => write!(f, "Timestamp({unit})"), + Self::Date => write!(f, "Date"), + Self::Time(unit) => write!(f, "Time({unit})"), + Self::Duration(unit) => write!(f, "Duration({unit})"), + Self::Interval(unit) => write!(f, "Interval({unit:?})"), + Self::Binary => write!(f, "Binary"), + Self::FixedSizeBinary(size) => write!(f, "FixedSizeBinary({size})"), + Self::String => write!(f, "String"), + Self::List(field) => write!(f, "List({})", field.logical_type), + Self::FixedSizeList(field, size) => { + write!(f, "FixedSizeList({size} x {})", field.logical_type) + } + Self::Struct(fields) => { + write!(f, "Struct(")?; + for (i, field) in fields.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{:?}: {}", field.name, field.logical_type)?; + } + write!(f, ")") + } + Self::Union(fields) => { + write!(f, "Union(")?; + for (i, (type_id, field)) in fields.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{type_id}: ({:?}: {})", field.name, field.logical_type)?; + } + write!(f, ")") + } + Self::Decimal(precision, scale) => write!(f, "Decimal({precision}, {scale})"), + Self::Map(field) => write!(f, "Map({})", field.logical_type), + } } } @@ -449,7 +499,7 @@ impl NativeType { #[inline] pub fn is_date(&self) -> bool { - matches!(self, NativeType::Date) + *self == NativeType::Date } #[inline] @@ -474,7 +524,7 @@ impl NativeType { #[inline] pub fn is_null(&self) -> bool { - matches!(self, NativeType::Null) + *self == NativeType::Null } #[inline] diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index 03310a7bde19..73e8ba6c7003 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -39,7 +39,7 @@ use std::cmp::{Ordering, 
min}; use std::collections::HashSet; use std::num::NonZero; use std::ops::Range; -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; use std::thread::available_parallelism; /// Applies an optional projection to a [`SchemaRef`], returning the @@ -70,10 +70,10 @@ use std::thread::available_parallelism; /// ``` pub fn project_schema( schema: &SchemaRef, - projection: Option<&Vec>, + projection: Option<&impl AsRef<[usize]>>, ) -> Result { let schema = match projection { - Some(columns) => Arc::new(schema.project(columns)?), + Some(columns) => Arc::new(schema.project(columns.as_ref())?), None => Arc::clone(schema), }; Ok(schema) @@ -516,6 +516,7 @@ impl SingleRowListArrayBuilder { /// ); /// /// assert_eq!(list_arr, expected); +/// ``` pub fn arrays_into_list_array( arr: impl IntoIterator, ) -> Result { @@ -587,6 +588,7 @@ pub enum ListCoercion { /// let base_type = DataType::Float64; /// let coerced_type = coerced_type_with_base_type_only(&data_type, &base_type, None); /// assert_eq!(coerced_type, DataType::List(Arc::new(Field::new_list_field(DataType::Float64, true)))); +/// ``` pub fn coerced_type_with_base_type_only( data_type: &DataType, base_type: &DataType, @@ -920,10 +922,15 @@ pub fn combine_limit( /// /// This is a wrapper around `std::thread::available_parallelism`, providing a default value /// of `1` if the system's parallelism cannot be determined. +/// +/// The result is cached after the first call. 
pub fn get_available_parallelism() -> usize { - available_parallelism() - .unwrap_or(NonZero::new(1).expect("literal value `1` shouldn't be zero")) - .get() + static PARALLELISM: LazyLock = LazyLock::new(|| { + available_parallelism() + .unwrap_or(NonZero::new(1).expect("literal value `1` shouldn't be zero")) + .get() + }); + *PARALLELISM } /// Converts a collection of function arguments into a fixed-size array of length N diff --git a/datafusion/common/src/utils/proxy.rs b/datafusion/common/src/utils/proxy.rs index fddf83491254..846c928515d6 100644 --- a/datafusion/common/src/utils/proxy.rs +++ b/datafusion/common/src/utils/proxy.rs @@ -121,6 +121,8 @@ pub trait HashTableAllocExt { /// /// Returns the bucket where the element was inserted. /// Note that allocation counts capacity, not size. + /// Panics: + /// Assumes the element is not already present, and may panic if it does /// /// # Example: /// ``` @@ -134,7 +136,7 @@ pub trait HashTableAllocExt { /// assert_eq!(allocated, 64); /// /// // insert more values - /// for i in 0..100 { + /// for i in 2..100 { /// table.insert_accounted(i, hash_fn, &mut allocated); /// } /// assert_eq!(allocated, 400); @@ -161,22 +163,24 @@ where ) { let hash = hasher(&x); - // NOTE: `find_entry` does NOT grow! 
- match self.find_entry(hash, |y| y == &x) { - Ok(_occupied) => {} - Err(_absent) => { - if self.len() == self.capacity() { - // need to request more memory - let bump_elements = self.capacity().max(16); - let bump_size = bump_elements * size_of::(); - *accounting = (*accounting).checked_add(bump_size).expect("overflow"); + if cfg!(debug_assertions) { + // In debug mode, check that the element is not already present + debug_assert!( + self.find_entry(hash, |y| y == &x).is_err(), + "attempted to insert duplicate element into HashTableAllocExt::insert_accounted" + ); + } - self.reserve(bump_elements, &hasher); - } + if self.len() == self.capacity() { + // need to request more memory + let bump_elements = self.capacity().max(16); + let bump_size = bump_elements * size_of::(); + *accounting = (*accounting).checked_add(bump_size).expect("overflow"); - // still need to insert the element since first try failed - self.entry(hash, |y| y == &x, hasher).insert(x); - } + self.reserve(bump_elements, &hasher); } + + // We assume the element is not already present + self.insert_unique(hash, x, hasher); } } diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index bd88ed3b9ca1..8965948a0f4e 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -88,8 +88,8 @@ recursive_protection = [ "datafusion-optimizer/recursive_protection", "datafusion-physical-optimizer/recursive_protection", "datafusion-physical-expr/recursive_protection", - "datafusion-sql/recursive_protection", - "sqlparser/recursive-protection", + "datafusion-sql?/recursive_protection", + "sqlparser?/recursive-protection", ] serde = [ "dep:serde", @@ -158,7 +158,7 @@ sqlparser = { workspace = true, optional = true } tempfile = { workspace = true } tokio = { workspace = true } url = { workspace = true } -uuid = { version = "1.19", features = ["v4", "js"] } +uuid = { workspace = true, features = ["v4", "js"] } zstd = { workspace = true, optional = true } [dev-dependencies] @@ -175,12 
+175,14 @@ env_logger = { workspace = true } glob = { workspace = true } insta = { workspace = true } paste = { workspace = true } +pretty_assertions = "1.0" rand = { workspace = true, features = ["small_rng"] } rand_distr = "0.5" +recursive = { workspace = true } regex = { workspace = true } rstest = { workspace = true } serde_json = { workspace = true } -sysinfo = "0.37.2" +sysinfo = "0.38.2" test-utils = { path = "../../test-utils" } tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot", "fs"] } @@ -188,7 +190,7 @@ tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot", "fs"] ignored = ["datafusion-doc", "datafusion-macros", "dashmap"] [target.'cfg(not(target_os = "windows"))'.dev-dependencies] -nix = { version = "0.30.1", features = ["fs"] } +nix = { version = "0.31.1", features = ["fs"] } [[bench]] harness = false @@ -239,6 +241,11 @@ harness = false name = "parquet_query_sql" required-features = ["parquet"] +[[bench]] +harness = false +name = "parquet_struct_query" +required-features = ["parquet"] + [[bench]] harness = false name = "range_and_generate_series" @@ -280,3 +287,7 @@ name = "spm" harness = false name = "preserve_file_partitioning" required-features = ["parquet"] + +[[bench]] +harness = false +name = "reset_plan_states" diff --git a/datafusion/core/benches/aggregate_query_sql.rs b/datafusion/core/benches/aggregate_query_sql.rs index 4aa667504e45..402ac9c7176b 100644 --- a/datafusion/core/benches/aggregate_query_sql.rs +++ b/datafusion/core/benches/aggregate_query_sql.rs @@ -15,14 +15,9 @@ // specific language governing permissions and limitations // under the License. 
-#[macro_use] -extern crate criterion; -extern crate arrow; -extern crate datafusion; - mod data_utils; -use crate::criterion::Criterion; +use criterion::{Criterion, criterion_group, criterion_main}; use data_utils::create_table_provider; use datafusion::error::Result; use datafusion::execution::context::SessionContext; @@ -256,6 +251,50 @@ fn criterion_benchmark(c: &mut Criterion) { ) }) }); + + c.bench_function("array_agg_query_group_by_few_groups", |b| { + b.iter(|| { + query( + ctx.clone(), + &rt, + "SELECT u64_narrow, array_agg(f64) \ + FROM t GROUP BY u64_narrow", + ) + }) + }); + + c.bench_function("array_agg_query_group_by_mid_groups", |b| { + b.iter(|| { + query( + ctx.clone(), + &rt, + "SELECT u64_mid, array_agg(f64) \ + FROM t GROUP BY u64_mid", + ) + }) + }); + + c.bench_function("array_agg_query_group_by_many_groups", |b| { + b.iter(|| { + query( + ctx.clone(), + &rt, + "SELECT u64_wide, array_agg(f64) \ + FROM t GROUP BY u64_wide", + ) + }) + }); + + c.bench_function("array_agg_struct_query_group_by_mid_groups", |b| { + b.iter(|| { + query( + ctx.clone(), + &rt, + "SELECT u64_mid, array_agg(named_struct('market', dict10, 'price', f64)) \ + FROM t GROUP BY u64_mid", + ) + }) + }); } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/core/benches/csv_load.rs b/datafusion/core/benches/csv_load.rs index 228457947fd5..13843dadddd0 100644 --- a/datafusion/core/benches/csv_load.rs +++ b/datafusion/core/benches/csv_load.rs @@ -15,14 +15,9 @@ // specific language governing permissions and limitations // under the License. 
-#[macro_use] -extern crate criterion; -extern crate arrow; -extern crate datafusion; - mod data_utils; -use crate::criterion::Criterion; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion::error::Result; use datafusion::execution::context::SessionContext; use datafusion::prelude::CsvReadOptions; diff --git a/datafusion/core/benches/data_utils/mod.rs b/datafusion/core/benches/data_utils/mod.rs index 630bc056600b..728c6490c72b 100644 --- a/datafusion/core/benches/data_utils/mod.rs +++ b/datafusion/core/benches/data_utils/mod.rs @@ -20,8 +20,9 @@ use arrow::array::{ ArrayRef, Float32Array, Float64Array, RecordBatch, StringArray, StringViewBuilder, UInt64Array, - builder::{Int64Builder, StringBuilder}, + builder::{Int64Builder, StringBuilder, StringDictionaryBuilder}, }; +use arrow::datatypes::Int32Type; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::datasource::MemTable; use datafusion::error::Result; @@ -36,6 +37,7 @@ use std::sync::Arc; /// create an in-memory table given the partition len, array len, and batch size, /// and the result table will be of array_len in total, and then partitioned, and batched. +#[expect(clippy::allow_attributes)] // some issue where expect(dead_code) doesn't fire properly #[allow(dead_code)] pub fn create_table_provider( partitions_len: usize, @@ -44,7 +46,7 @@ pub fn create_table_provider( ) -> Result> { let schema = Arc::new(create_schema()); let partitions = - create_record_batches(schema.clone(), array_len, partitions_len, batch_size); + create_record_batches(&schema, array_len, partitions_len, batch_size); // declare a table in memory. In spark API, this corresponds to createDataFrame(...). 
MemTable::try_new(schema, partitions).map(Arc::new) } @@ -55,21 +57,24 @@ pub fn create_schema() -> Schema { Field::new("utf8", DataType::Utf8, false), Field::new("f32", DataType::Float32, false), Field::new("f64", DataType::Float64, true), - // This field will contain integers randomly selected from a large - // range of values, i.e. [0, u64::MAX], such that there are none (or - // very few) repeated values. - Field::new("u64_wide", DataType::UInt64, true), - // This field will contain integers randomly selected from a narrow - // range of values such that there are a few distinct values, but they - // are repeated often. + // Integers randomly selected from a wide range of values, i.e. [0, + // u64::MAX], such that there are ~no repeated values. + Field::new("u64_wide", DataType::UInt64, false), + // Integers randomly selected from a mid-range of values [0, 1000), + // providing ~1000 distinct groups. + Field::new("u64_mid", DataType::UInt64, false), + // Integers randomly selected from a narrow range of values such that + // there are a few distinct values, but they are repeated often. 
Field::new("u64_narrow", DataType::UInt64, false), + Field::new( + "dict10", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + ), ]) } -fn create_data(size: usize, null_density: f64) -> Vec> { - // use random numbers to avoid spurious compiler optimizations wrt to branching - let mut rng = StdRng::seed_from_u64(42); - +fn create_data(rng: &mut StdRng, size: usize, null_density: f64) -> Vec> { (0..size) .map(|_| { if rng.random::() > null_density { @@ -81,57 +86,54 @@ fn create_data(size: usize, null_density: f64) -> Vec> { .collect() } -fn create_integer_data( - rng: &mut StdRng, - size: usize, - value_density: f64, -) -> Vec> { - (0..size) - .map(|_| { - if rng.random::() > value_density { - None - } else { - Some(rng.random::()) - } - }) - .collect() -} - fn create_record_batch( schema: SchemaRef, rng: &mut StdRng, batch_size: usize, - i: usize, + batch_index: usize, ) -> RecordBatch { - // the 4 here is the number of different keys. - // a higher number increase sparseness - let vs = [0, 1, 2, 3]; - let keys: Vec = (0..batch_size) - .map( - // use random numbers to avoid spurious compiler optimizations wrt to branching - |_| format!("hi{:?}", vs.choose(rng)), - ) - .collect(); - let keys: Vec<&str> = keys.iter().map(|e| &**e).collect(); + // Randomly choose from 4 distinct key values; a higher number increases sparseness. + let key_suffixes = [0, 1, 2, 3]; + let keys = StringArray::from_iter_values( + (0..batch_size).map(|_| format!("hi{}", key_suffixes.choose(rng).unwrap())), + ); - let values = create_data(batch_size, 0.5); + let values = create_data(rng, batch_size, 0.5); // Integer values between [0, u64::MAX]. - let integer_values_wide = create_integer_data(rng, batch_size, 9.0); + let integer_values_wide = (0..batch_size) + .map(|_| rng.random::()) + .collect::>(); - // Integer values between [0, 9]. + // Integer values between [0, 1000). 
+ let integer_values_mid = (0..batch_size) + .map(|_| rng.random_range(0..1000)) + .collect::>(); + + // Integer values between [0, 10). let integer_values_narrow = (0..batch_size) - .map(|_| rng.random_range(0_u64..10)) + .map(|_| rng.random_range(0..10)) .collect::>(); + let mut dict_builder = StringDictionaryBuilder::::new(); + for _ in 0..batch_size { + if rng.random::() > 0.9 { + dict_builder.append_null(); + } else { + dict_builder.append_value(format!("market_{}", rng.random_range(0..10))); + } + } + RecordBatch::try_new( schema, vec![ - Arc::new(StringArray::from(keys)), - Arc::new(Float32Array::from(vec![i as f32; batch_size])), + Arc::new(keys), + Arc::new(Float32Array::from(vec![batch_index as f32; batch_size])), Arc::new(Float64Array::from(values)), Arc::new(UInt64Array::from(integer_values_wide)), + Arc::new(UInt64Array::from(integer_values_mid)), Arc::new(UInt64Array::from(integer_values_narrow)), + Arc::new(dict_builder.finish()), ], ) .unwrap() @@ -139,21 +141,29 @@ fn create_record_batch( /// Create record batches of `partitions_len` partitions and `batch_size` for each batch, /// with a total number of `array_len` records -#[expect(clippy::needless_pass_by_value)] pub fn create_record_batches( - schema: SchemaRef, + schema: &SchemaRef, array_len: usize, partitions_len: usize, batch_size: usize, ) -> Vec> { let mut rng = StdRng::seed_from_u64(42); - (0..partitions_len) - .map(|_| { - (0..array_len / batch_size / partitions_len) - .map(|i| create_record_batch(schema.clone(), &mut rng, batch_size, i)) - .collect::>() - }) - .collect::>() + let mut partitions = Vec::with_capacity(partitions_len); + let batches_per_partition = array_len / batch_size / partitions_len; + + for _ in 0..partitions_len { + let mut batches = Vec::with_capacity(batches_per_partition); + for batch_index in 0..batches_per_partition { + batches.push(create_record_batch( + schema.clone(), + &mut rng, + batch_size, + batch_index, + )); + } + partitions.push(batches); + } + 
partitions } /// An enum that wraps either a regular StringBuilder or a GenericByteViewBuilder @@ -183,6 +193,7 @@ impl TraceIdBuilder { /// Create time series data with `partition_cnt` partitions and `sample_cnt` rows per partition /// in ascending order, if `asc` is true, otherwise randomly sampled using a Pareto distribution +#[expect(clippy::allow_attributes)] // some issue where expect(dead_code) doesn't fire properly #[allow(dead_code)] pub(crate) fn make_data( partition_cnt: i32, diff --git a/datafusion/core/benches/dataframe.rs b/datafusion/core/benches/dataframe.rs index 726187ab5e92..5aeade315cc7 100644 --- a/datafusion/core/benches/dataframe.rs +++ b/datafusion/core/benches/dataframe.rs @@ -15,13 +15,8 @@ // specific language governing permissions and limitations // under the License. -extern crate arrow; -#[macro_use] -extern crate criterion; -extern crate datafusion; - use arrow_schema::{DataType, Field, Schema}; -use criterion::Criterion; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion::datasource::MemTable; use datafusion::prelude::SessionContext; use datafusion_expr::col; diff --git a/datafusion/core/benches/distinct_query_sql.rs b/datafusion/core/benches/distinct_query_sql.rs index 0e638e293d8c..d389b1b3d6a2 100644 --- a/datafusion/core/benches/distinct_query_sql.rs +++ b/datafusion/core/benches/distinct_query_sql.rs @@ -15,13 +15,9 @@ // specific language governing permissions and limitations // under the License. 
-#[macro_use] -extern crate criterion; -extern crate arrow; -extern crate datafusion; - mod data_utils; -use crate::criterion::Criterion; + +use criterion::{Criterion, criterion_group, criterion_main}; use data_utils::{create_table_provider, make_data}; use datafusion::execution::context::SessionContext; use datafusion::physical_plan::{ExecutionPlan, collect}; diff --git a/datafusion/core/benches/math_query_sql.rs b/datafusion/core/benches/math_query_sql.rs index 4d1d4abb6783..f5df56e95a2d 100644 --- a/datafusion/core/benches/math_query_sql.rs +++ b/datafusion/core/benches/math_query_sql.rs @@ -15,18 +15,13 @@ // specific language governing permissions and limitations // under the License. -#[macro_use] -extern crate criterion; -use criterion::Criterion; +use criterion::{Criterion, criterion_group, criterion_main}; use parking_lot::Mutex; use std::sync::Arc; use tokio::runtime::Runtime; -extern crate arrow; -extern crate datafusion; - use arrow::{ array::{Float32Array, Float64Array}, datatypes::{DataType, Field, Schema}, diff --git a/datafusion/core/benches/parquet_query_sql.rs b/datafusion/core/benches/parquet_query_sql.rs index e44524127bf1..f09913797359 100644 --- a/datafusion/core/benches/parquet_query_sql.rs +++ b/datafusion/core/benches/parquet_query_sql.rs @@ -45,7 +45,7 @@ const NUM_BATCHES: usize = 2048; /// The number of rows in each record batch to write const WRITE_RECORD_BATCH_SIZE: usize = 1024; /// The number of rows in a row group -const ROW_GROUP_SIZE: usize = 1024 * 1024; +const ROW_GROUP_ROW_COUNT: usize = 1024 * 1024; /// The number of row groups expected const EXPECTED_ROW_GROUPS: usize = 2; @@ -154,7 +154,7 @@ fn generate_file() -> NamedTempFile { let properties = WriterProperties::builder() .set_writer_version(WriterVersion::PARQUET_2_0) - .set_max_row_group_size(ROW_GROUP_SIZE) + .set_max_row_group_row_count(Some(ROW_GROUP_ROW_COUNT)) .build(); let mut writer = diff --git a/datafusion/core/benches/parquet_struct_query.rs 
b/datafusion/core/benches/parquet_struct_query.rs new file mode 100644 index 000000000000..e7e91f0dd0e1 --- /dev/null +++ b/datafusion/core/benches/parquet_struct_query.rs @@ -0,0 +1,312 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Benchmarks of SQL queries on struct columns in parquet data + +use arrow::array::{ArrayRef, Int32Array, StringArray, StructArray}; +use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; +use arrow::record_batch::RecordBatch; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion::prelude::SessionContext; +use datafusion_common::instant::Instant; +use parquet::arrow::ArrowWriter; +use parquet::file::properties::{WriterProperties, WriterVersion}; +use rand::distr::Alphanumeric; +use rand::prelude::*; +use rand::rng; +use std::hint::black_box; +use std::ops::Range; +use std::path::Path; +use std::sync::Arc; +use tempfile::NamedTempFile; +use tokio::runtime::Runtime; + +/// The number of batches to write +const NUM_BATCHES: usize = 128; +/// The number of rows in each record batch to write +const WRITE_RECORD_BATCH_SIZE: usize = 4096; +/// The number of rows in a row group +const ROW_GROUP_ROW_COUNT: usize = 65536; +/// The number of row groups expected +const EXPECTED_ROW_GROUPS: usize = 8; +/// The range for random string lengths +const STRING_LENGTH_RANGE: Range = 50..200; + +fn schema() -> SchemaRef { + let struct_fields = Fields::from(vec![ + Field::new("id", DataType::Int32, false), + Field::new("value", DataType::Utf8, false), + ]); + let struct_type = DataType::Struct(struct_fields); + + Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("s", struct_type, false), + ])) +} + +fn generate_strings(len: usize) -> ArrayRef { + let mut rng = rng(); + Arc::new(StringArray::from_iter((0..len).map(|_| { + let string_len = rng.random_range(STRING_LENGTH_RANGE.clone()); + Some( + (0..string_len) + .map(|_| char::from(rng.sample(Alphanumeric))) + .collect::(), + ) + }))) +} + +fn generate_batch(batch_id: usize) -> RecordBatch { + let schema = schema(); + let len = WRITE_RECORD_BATCH_SIZE; + + // Generate sequential IDs based on batch_id for uniqueness + let base_id = (batch_id * len) as i32; + let 
id_values: Vec = (0..len).map(|i| base_id + i as i32).collect(); + let id_array = Arc::new(Int32Array::from(id_values.clone())); + + // Create struct id array (matching top-level id) + let struct_id_array = Arc::new(Int32Array::from(id_values)); + + // Generate random strings for struct value field + let value_array = generate_strings(len); + + // Construct StructArray + let struct_array = StructArray::from(vec![ + ( + Arc::new(Field::new("id", DataType::Int32, false)), + struct_id_array as ArrayRef, + ), + ( + Arc::new(Field::new("value", DataType::Utf8, false)), + value_array, + ), + ]); + + RecordBatch::try_new(schema, vec![id_array, Arc::new(struct_array)]).unwrap() +} + +fn generate_file() -> NamedTempFile { + let now = Instant::now(); + let mut named_file = tempfile::Builder::new() + .prefix("parquet_struct_query") + .suffix(".parquet") + .tempfile() + .unwrap(); + + println!("Generating parquet file - {}", named_file.path().display()); + let schema = schema(); + + let properties = WriterProperties::builder() + .set_writer_version(WriterVersion::PARQUET_2_0) + .set_max_row_group_row_count(Some(ROW_GROUP_ROW_COUNT)) + .build(); + + let mut writer = + ArrowWriter::try_new(&mut named_file, schema, Some(properties)).unwrap(); + + for batch_id in 0..NUM_BATCHES { + let batch = generate_batch(batch_id); + writer.write(&batch).unwrap(); + } + + let metadata = writer.close().unwrap(); + let file_metadata = metadata.file_metadata(); + let expected_rows = WRITE_RECORD_BATCH_SIZE * NUM_BATCHES; + assert_eq!( + file_metadata.num_rows() as usize, + expected_rows, + "Expected {} rows but got {}", + expected_rows, + file_metadata.num_rows() + ); + assert_eq!( + metadata.row_groups().len(), + EXPECTED_ROW_GROUPS, + "Expected {} row groups but got {}", + EXPECTED_ROW_GROUPS, + metadata.row_groups().len() + ); + + println!( + "Generated parquet file with {} rows and {} row groups in {} seconds", + file_metadata.num_rows(), + metadata.row_groups().len(), + 
now.elapsed().as_secs_f32() + ); + + named_file +} + +fn create_context(file_path: &str) -> SessionContext { + let ctx = SessionContext::new(); + let rt = Runtime::new().unwrap(); + rt.block_on(ctx.register_parquet("t", file_path, Default::default())) + .unwrap(); + ctx +} + +fn query(ctx: &SessionContext, rt: &Runtime, sql: &str) { + let ctx = ctx.clone(); + let sql = sql.to_string(); + let df = rt.block_on(ctx.sql(&sql)).unwrap(); + black_box(rt.block_on(df.collect()).unwrap()); +} + +fn criterion_benchmark(c: &mut Criterion) { + let (file_path, temp_file) = match std::env::var("PARQUET_FILE") { + Ok(file) => (file, None), + Err(_) => { + let temp_file = generate_file(); + (temp_file.path().display().to_string(), Some(temp_file)) + } + }; + + assert!(Path::new(&file_path).exists(), "path not found"); + println!("Using parquet file {file_path}"); + + let ctx = create_context(&file_path); + let rt = Runtime::new().unwrap(); + + // Basic struct access + c.bench_function("struct_access", |b| { + b.iter(|| query(&ctx, &rt, "select id, s['id'] from t")) + }); + + // Filter queries + c.bench_function("filter_struct_field_eq", |b| { + b.iter(|| query(&ctx, &rt, "select id from t where s['id'] = 5")) + }); + + c.bench_function("filter_struct_field_with_select", |b| { + b.iter(|| query(&ctx, &rt, "select id, s['id'] from t where s['id'] = 5")) + }); + + c.bench_function("filter_top_level_with_struct_select", |b| { + b.iter(|| query(&ctx, &rt, "select s['id'] from t where id = 5")) + }); + + c.bench_function("filter_struct_string_length", |b| { + b.iter(|| query(&ctx, &rt, "select id from t where length(s['value']) > 100")) + }); + + c.bench_function("filter_struct_range", |b| { + b.iter(|| { + query( + &ctx, + &rt, + "select id from t where s['id'] > 100 and s['id'] < 200", + ) + }) + }); + + // Join queries (limited with WHERE id < 1000 for performance) + c.bench_function("join_struct_to_struct", |b| { + b.iter(|| query( + &ctx, + &rt, + "select t1.id from t t1 join t t2 
on t1.s['id'] = t2.s['id'] where t1.id < 1000" + )) + }); + + c.bench_function("join_struct_to_toplevel", |b| { + b.iter(|| query( + &ctx, + &rt, + "select t1.id from t t1 join t t2 on t1.s['id'] = t2.id where t1.id < 1000" + )) + }); + + c.bench_function("join_toplevel_to_struct", |b| { + b.iter(|| query( + &ctx, + &rt, + "select t1.id from t t1 join t t2 on t1.id = t2.s['id'] where t1.id < 1000" + )) + }); + + c.bench_function("join_struct_to_struct_with_top_level", |b| { + b.iter(|| query( + &ctx, + &rt, + "select t1.id from t t1 join t t2 on t1.s['id'] = t2.s['id'] and t1.id = t2.id where t1.id < 1000" + )) + }); + + c.bench_function("join_struct_and_struct_value", |b| { + b.iter(|| query( + &ctx, + &rt, + "select t1.s['id'], t2.s['value'] from t t1 join t t2 on t1.id = t2.id where t1.id < 1000" + )) + }); + + // Group by queries + c.bench_function("group_by_struct_field", |b| { + b.iter(|| query(&ctx, &rt, "select s['id'] from t group by s['id']")) + }); + + c.bench_function("group_by_struct_select_toplevel", |b| { + b.iter(|| query(&ctx, &rt, "select max(id) from t group by s['id']")) + }); + + c.bench_function("group_by_toplevel_select_struct", |b| { + b.iter(|| query(&ctx, &rt, "select max(s['id']) from t group by id")) + }); + + c.bench_function("group_by_struct_with_count", |b| { + b.iter(|| { + query( + &ctx, + &rt, + "select s['id'], count(*) from t group by s['id']", + ) + }) + }); + + c.bench_function("group_by_multiple_with_count", |b| { + b.iter(|| { + query( + &ctx, + &rt, + "select id, s['id'], count(*) from t group by id, s['id']", + ) + }) + }); + + // Additional queries + c.bench_function("order_by_struct_limit", |b| { + b.iter(|| { + query( + &ctx, + &rt, + "select id, s['id'] from t order by s['id'] limit 1000", + ) + }) + }); + + c.bench_function("distinct_struct_field", |b| { + b.iter(|| query(&ctx, &rt, "select distinct s['id'] from t")) + }); + + // Temporary file must outlive the benchmarks, it is deleted when dropped + drop(temp_file); 
+} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/core/benches/physical_plan.rs b/datafusion/core/benches/physical_plan.rs index e6763b4761c2..7b66996b0592 100644 --- a/datafusion/core/benches/physical_plan.rs +++ b/datafusion/core/benches/physical_plan.rs @@ -15,11 +15,7 @@ // specific language governing permissions and limitations // under the License. -#[macro_use] -extern crate criterion; -use criterion::{BatchSize, Criterion}; -extern crate arrow; -extern crate datafusion; +use criterion::{BatchSize, Criterion, criterion_group, criterion_main}; use std::sync::Arc; diff --git a/datafusion/core/benches/preserve_file_partitioning.rs b/datafusion/core/benches/preserve_file_partitioning.rs index 17ebca52cd1d..9b1f59adc682 100644 --- a/datafusion/core/benches/preserve_file_partitioning.rs +++ b/datafusion/core/benches/preserve_file_partitioning.rs @@ -322,7 +322,7 @@ async fn save_plans( } } -#[allow(clippy::too_many_arguments)] +#[expect(clippy::too_many_arguments)] fn run_benchmark( c: &mut Criterion, rt: &Runtime, diff --git a/datafusion/core/benches/push_down_filter.rs b/datafusion/core/benches/push_down_filter.rs index 3c2199c708de..d41085907dbc 100644 --- a/datafusion/core/benches/push_down_filter.rs +++ b/datafusion/core/benches/push_down_filter.rs @@ -25,9 +25,9 @@ use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_optimizer::filter_pushdown::FilterPushdown; use datafusion_physical_plan::ExecutionPlan; -use object_store::ObjectStore; use object_store::memory::InMemory; use object_store::path::Path; +use object_store::{ObjectStore, ObjectStoreExt}; use parquet::arrow::ArrowWriter; use std::sync::Arc; diff --git a/datafusion/core/benches/range_and_generate_series.rs b/datafusion/core/benches/range_and_generate_series.rs index 2b1463a21062..10d560df0813 100644 --- a/datafusion/core/benches/range_and_generate_series.rs +++ 
b/datafusion/core/benches/range_and_generate_series.rs @@ -15,13 +15,9 @@ // specific language governing permissions and limitations // under the License. -#[macro_use] -extern crate criterion; -extern crate datafusion; - mod data_utils; -use crate::criterion::Criterion; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion::execution::context::SessionContext; use parking_lot::Mutex; use std::hint::black_box; diff --git a/datafusion/core/benches/reset_plan_states.rs b/datafusion/core/benches/reset_plan_states.rs new file mode 100644 index 000000000000..5afae7f43242 --- /dev/null +++ b/datafusion/core/benches/reset_plan_states.rs @@ -0,0 +1,200 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::sync::{Arc, LazyLock}; + +use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion::prelude::SessionContext; +use datafusion_catalog::MemTable; +use datafusion_physical_plan::ExecutionPlan; +use datafusion_physical_plan::displayable; +use datafusion_physical_plan::execution_plan::reset_plan_states; +use tokio::runtime::Runtime; + +const NUM_FIELDS: usize = 1000; +const PREDICATE_LEN: usize = 50; + +static SCHEMA: LazyLock = LazyLock::new(|| { + Arc::new(Schema::new( + (0..NUM_FIELDS) + .map(|i| Arc::new(Field::new(format!("x_{i}"), DataType::Int64, false))) + .collect::(), + )) +}); + +fn col_name(i: usize) -> String { + format!("x_{i}") +} + +fn aggr_name(i: usize) -> String { + format!("aggr_{i}") +} + +fn physical_plan( + ctx: &SessionContext, + rt: &Runtime, + sql: &str, +) -> Arc { + rt.block_on(async { + ctx.sql(sql) + .await + .unwrap() + .create_physical_plan() + .await + .unwrap() + }) +} + +fn predicate(col_name: impl Fn(usize) -> String, len: usize) -> String { + let mut predicate = String::new(); + for i in 0..len { + if i > 0 { + predicate.push_str(" AND "); + } + predicate.push_str(&col_name(i)); + predicate.push_str(" = "); + predicate.push_str(&i.to_string()); + } + predicate +} + +/// Returns a typical plan for the query like: +/// +/// ```sql +/// SELECT aggr1(col1) as aggr1, aggr2(col2) as aggr2 FROM t +/// WHERE p1 +/// HAVING p2 +/// ``` +/// +/// Where `p1` and `p2` some long predicates. 
+/// +fn query1() -> String { + let mut query = String::new(); + query.push_str("SELECT "); + for i in 0..NUM_FIELDS { + if i > 0 { + query.push_str(", "); + } + query.push_str("AVG("); + query.push_str(&col_name(i)); + query.push_str(") AS "); + query.push_str(&aggr_name(i)); + } + query.push_str(" FROM t WHERE "); + query.push_str(&predicate(col_name, PREDICATE_LEN)); + query.push_str(" HAVING "); + query.push_str(&predicate(aggr_name, PREDICATE_LEN)); + query +} + +/// Returns a typical plan for the query like: +/// +/// ```sql +/// SELECT projection FROM t JOIN v ON t.a = v.a +/// WHERE p1 +/// ``` +/// +fn query2() -> String { + let mut query = String::new(); + query.push_str("SELECT "); + for i in (0..NUM_FIELDS).step_by(2) { + if i > 0 { + query.push_str(", "); + } + if (i / 2) % 2 == 0 { + query.push_str(&format!("t.{}", col_name(i))); + } else { + query.push_str(&format!("v.{}", col_name(i))); + } + } + query.push_str(" FROM t JOIN v ON t.x_0 = v.x_0 WHERE "); + + fn qualified_name(i: usize) -> String { + format!("t.{}", col_name(i)) + } + + query.push_str(&predicate(qualified_name, PREDICATE_LEN)); + query +} + +/// Returns a typical plan for the query like: +/// +/// ```sql +/// SELECT projection FROM t +/// WHERE p +/// ``` +/// +fn query3() -> String { + let mut query = String::new(); + query.push_str("SELECT "); + + // Create non-trivial projection. 
+ for i in 0..NUM_FIELDS / 2 { + if i > 0 { + query.push_str(", "); + } + query.push_str(&col_name(i * 2)); + query.push_str(" + "); + query.push_str(&col_name(i * 2 + 1)); + } + + query.push_str(" FROM t WHERE "); + query.push_str(&predicate(col_name, PREDICATE_LEN)); + query +} + +fn run_reset_states(b: &mut criterion::Bencher, plan: &Arc) { + b.iter(|| std::hint::black_box(reset_plan_states(Arc::clone(plan)).unwrap())); +} + +/// Benchmark is intended to measure overhead of actions, required to perform +/// making an independent instance of the execution plan to re-execute it, avoiding +/// re-planning stage. +fn bench_reset_plan_states(c: &mut Criterion) { + env_logger::init(); + + let rt = Runtime::new().unwrap(); + let ctx = SessionContext::new(); + ctx.register_table( + "t", + Arc::new(MemTable::try_new(Arc::clone(&SCHEMA), vec![vec![], vec![]]).unwrap()), + ) + .unwrap(); + + ctx.register_table( + "v", + Arc::new(MemTable::try_new(Arc::clone(&SCHEMA), vec![vec![], vec![]]).unwrap()), + ) + .unwrap(); + + macro_rules! bench_query { + ($query_producer: expr) => {{ + let sql = $query_producer(); + let plan = physical_plan(&ctx, &rt, &sql); + log::debug!("plan:\n{}", displayable(plan.as_ref()).indent(true)); + move |b| run_reset_states(b, &plan) + }}; + } + + c.bench_function("query1", bench_query!(query1)); + c.bench_function("query2", bench_query!(query2)); + c.bench_function("query3", bench_query!(query3)); +} + +criterion_group!(benches, bench_reset_plan_states); +criterion_main!(benches); diff --git a/datafusion/core/benches/sort_limit_query_sql.rs b/datafusion/core/benches/sort_limit_query_sql.rs index c18070fb7725..54cd9a0bcd54 100644 --- a/datafusion/core/benches/sort_limit_query_sql.rs +++ b/datafusion/core/benches/sort_limit_query_sql.rs @@ -15,9 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-#[macro_use] -extern crate criterion; -use criterion::Criterion; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion::datasource::file_format::csv::CsvFormat; use datafusion::datasource::listing::{ ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, @@ -27,9 +25,6 @@ use datafusion::prelude::SessionConfig; use parking_lot::Mutex; use std::sync::Arc; -extern crate arrow; -extern crate datafusion; - use arrow::datatypes::{DataType, Field, Schema}; use datafusion::datasource::MemTable; diff --git a/datafusion/core/benches/spm.rs b/datafusion/core/benches/spm.rs index 9db1306d2bd1..afd384f7b170 100644 --- a/datafusion/core/benches/spm.rs +++ b/datafusion/core/benches/spm.rs @@ -66,10 +66,9 @@ fn generate_spm_for_round_robin_tie_breaker( RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap() }; - let rbs = (0..batch_count).map(|_| rb.clone()).collect::>(); - let partitions = vec![rbs.clone(); partition_count]; - let schema = rb.schema(); + let rbs = std::iter::repeat_n(rb, batch_count).collect::>(); + let partitions = vec![rbs.clone(); partition_count]; let sort = [ PhysicalSortExpr { expr: col("b", &schema).unwrap(), diff --git a/datafusion/core/benches/sql_planner.rs b/datafusion/core/benches/sql_planner.rs index 7cce7e0bd7db..59502da98790 100644 --- a/datafusion/core/benches/sql_planner.rs +++ b/datafusion/core/benches/sql_planner.rs @@ -15,20 +15,15 @@ // specific language governing permissions and limitations // under the License. 
-extern crate arrow; -#[macro_use] -extern crate criterion; -extern crate datafusion; - mod data_utils; -use crate::criterion::Criterion; use arrow::array::PrimitiveArray; use arrow::array::{ArrayRef, RecordBatch}; use arrow::datatypes::ArrowNativeTypeOp; use arrow::datatypes::ArrowPrimitiveType; use arrow::datatypes::{DataType, Field, Fields, Schema}; use criterion::Bencher; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion::datasource::MemTable; use datafusion::execution::context::SessionContext; use datafusion_common::{ScalarValue, config::Dialect}; @@ -78,6 +73,21 @@ fn create_table_provider(column_prefix: &str, num_columns: usize) -> Arc Arc { + let struct_fields = Fields::from(vec![ + Field::new("value", DataType::Int32, true), + Field::new("label", DataType::Utf8, true), + ]); + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, true), + Field::new("props", DataType::Struct(struct_fields), true), + ])); + MemTable::try_new(schema, vec![vec![]]) + .map(Arc::new) + .unwrap() +} + fn create_context() -> SessionContext { let ctx = SessionContext::new(); ctx.register_table("t1", create_table_provider("a", 200)) @@ -88,6 +98,10 @@ fn create_context() -> SessionContext { .unwrap(); ctx.register_table("t1000", create_table_provider("d", 1000)) .unwrap(); + ctx.register_table("struct_t1", create_struct_table_provider()) + .unwrap(); + ctx.register_table("struct_t2", create_struct_table_provider()) + .unwrap(); ctx } @@ -118,6 +132,11 @@ fn register_clickbench_hits_table(rt: &Runtime) -> SessionContext { let sql = format!("CREATE EXTERNAL TABLE hits STORED AS PARQUET LOCATION '{path}'"); + // ClickBench partitioned dataset was written by an ancient version of pyarrow that + // that wrote strings with the wrong logical type. To read it correctly, we must + // automatically convert binary to string. 
+ rt.block_on(ctx.sql("SET datafusion.execution.parquet.binary_as_string = true;")) + .unwrap(); rt.block_on(ctx.sql(&sql)).unwrap(); let count = @@ -419,6 +438,25 @@ fn criterion_benchmark(c: &mut Criterion) { }); }); + let struct_agg_sort_query = "SELECT \ + struct_t1.props['label'], \ + SUM(struct_t1.props['value']), \ + MAX(struct_t2.props['value']), \ + COUNT(*) \ + FROM struct_t1 \ + JOIN struct_t2 ON struct_t1.id = struct_t2.id \ + WHERE struct_t1.props['value'] > 50 \ + GROUP BY struct_t1.props['label'] \ + ORDER BY SUM(struct_t1.props['value']) DESC"; + + // -- Struct column benchmarks -- + c.bench_function("logical_plan_struct_join_agg_sort", |b| { + b.iter(|| logical_plan(&ctx, &rt, struct_agg_sort_query)) + }); + c.bench_function("physical_plan_struct_join_agg_sort", |b| { + b.iter(|| physical_plan(&ctx, &rt, struct_agg_sort_query)) + }); + // -- Sorted Queries -- // 100, 200 && 300 is taking too long - https://github.com/apache/datafusion/issues/18366 // Logical Plan for datatype Int64 and UInt64 differs, UInt64 Logical Plan's Union are wrapped diff --git a/datafusion/core/benches/sql_query_with_io.rs b/datafusion/core/benches/sql_query_with_io.rs index 0c188f7ba104..fc8caf31acd1 100644 --- a/datafusion/core/benches/sql_query_with_io.rs +++ b/datafusion/core/benches/sql_query_with_io.rs @@ -31,7 +31,7 @@ use datafusion::{ use datafusion_execution::runtime_env::RuntimeEnv; use itertools::Itertools; use object_store::{ - ObjectStore, + ObjectStore, ObjectStoreExt, memory::InMemory, path::Path, throttle::{ThrottleConfig, ThrottledStore}, diff --git a/datafusion/core/benches/topk_aggregate.rs b/datafusion/core/benches/topk_aggregate.rs index be193f873713..f71cf1087be7 100644 --- a/datafusion/core/benches/topk_aggregate.rs +++ b/datafusion/core/benches/topk_aggregate.rs @@ -17,6 +17,9 @@ mod data_utils; +use arrow::array::Int64Builder; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::record_batch::RecordBatch; use 
arrow::util::pretty::pretty_format_batches; use criterion::{Criterion, criterion_group, criterion_main}; use data_utils::make_data; @@ -24,10 +27,53 @@ use datafusion::physical_plan::{collect, displayable}; use datafusion::prelude::SessionContext; use datafusion::{datasource::MemTable, error::Result}; use datafusion_execution::config::SessionConfig; +use rand::SeedableRng; +use rand::seq::SliceRandom; use std::hint::black_box; use std::sync::Arc; use tokio::runtime::Runtime; +const LIMIT: usize = 10; + +/// Create deterministic data for DISTINCT benchmarks with predictable trace_ids +/// This ensures consistent results across benchmark runs +fn make_distinct_data( + partition_cnt: i32, + sample_cnt: i32, +) -> Result<(Arc, Vec>)> { + let mut rng = rand::rngs::SmallRng::from_seed([42; 32]); + let total_samples = partition_cnt as usize * sample_cnt as usize; + let mut ids = Vec::new(); + for i in 0..total_samples { + ids.push(i as i64); + } + ids.shuffle(&mut rng); + + let mut global_idx = 0; + let schema = test_distinct_schema(); + let mut partitions = vec![]; + for _ in 0..partition_cnt { + let mut id_builder = Int64Builder::new(); + + for _ in 0..sample_cnt { + let id = ids[global_idx]; + id_builder.append_value(id); + global_idx += 1; + } + + let id_col = Arc::new(id_builder.finish()); + let batch = RecordBatch::try_new(schema.clone(), vec![id_col])?; + partitions.push(vec![batch]); + } + + Ok((schema, partitions)) +} + +/// Returns a Schema for distinct benchmarks with i64 trace_id +fn test_distinct_schema() -> SchemaRef { + Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)])) +} + async fn create_context( partition_cnt: i32, sample_cnt: i32, @@ -48,10 +94,45 @@ async fn create_context( Ok(ctx) } +async fn create_context_distinct( + partition_cnt: i32, + sample_cnt: i32, + use_topk: bool, +) -> Result { + // Use deterministic data generation for DISTINCT queries to ensure consistent results + let (schema, parts) = 
make_distinct_data(partition_cnt, sample_cnt).unwrap(); + let mem_table = Arc::new(MemTable::try_new(schema, parts).unwrap()); + + // Create the DataFrame + let mut cfg = SessionConfig::new(); + let opts = cfg.options_mut(); + opts.optimizer.enable_topk_aggregation = use_topk; + let ctx = SessionContext::new_with_config(cfg); + let _ = ctx.register_table("traces", mem_table)?; + + Ok(ctx) +} + fn run(rt: &Runtime, ctx: SessionContext, limit: usize, use_topk: bool, asc: bool) { black_box(rt.block_on(async { aggregate(ctx, limit, use_topk, asc).await })).unwrap(); } +fn run_string(rt: &Runtime, ctx: SessionContext, limit: usize, use_topk: bool) { + black_box(rt.block_on(async { aggregate_string(ctx, limit, use_topk).await })) + .unwrap(); +} + +fn run_distinct( + rt: &Runtime, + ctx: SessionContext, + limit: usize, + use_topk: bool, + asc: bool, +) { + black_box(rt.block_on(async { aggregate_distinct(ctx, limit, use_topk, asc).await })) + .unwrap(); +} + async fn aggregate( ctx: SessionContext, limit: usize, @@ -72,7 +153,7 @@ async fn aggregate( let batches = collect(plan, ctx.task_ctx()).await?; assert_eq!(batches.len(), 1); let batch = batches.first().unwrap(); - assert_eq!(batch.num_rows(), 10); + assert_eq!(batch.num_rows(), LIMIT); let actual = format!("{}", pretty_format_batches(&batches)?).to_lowercase(); let expected_asc = r#" @@ -99,9 +180,114 @@ async fn aggregate( Ok(()) } +/// Benchmark for string aggregate functions with topk optimization. +/// This tests grouping by a numeric column (timestamp_ms) and aggregating +/// a string column (trace_id) with Utf8 or Utf8View data types. 
+async fn aggregate_string( + ctx: SessionContext, + limit: usize, + use_topk: bool, +) -> Result<()> { + let sql = format!( + "select max(trace_id) from traces group by timestamp_ms order by max(trace_id) desc limit {limit};" + ); + let df = ctx.sql(sql.as_str()).await?; + let plan = df.create_physical_plan().await?; + let actual_phys_plan = displayable(plan.as_ref()).indent(true).to_string(); + assert_eq!( + actual_phys_plan.contains(&format!("lim=[{limit}]")), + use_topk + ); + + let batches = collect(plan, ctx.task_ctx()).await?; + assert_eq!(batches.len(), 1); + let batch = batches.first().unwrap(); + assert_eq!(batch.num_rows(), LIMIT); + + Ok(()) +} + +async fn aggregate_distinct( + ctx: SessionContext, + limit: usize, + use_topk: bool, + asc: bool, +) -> Result<()> { + let order_direction = if asc { "asc" } else { "desc" }; + let sql = format!( + "select id from traces group by id order by id {order_direction} limit {limit};" + ); + let df = ctx.sql(sql.as_str()).await?; + let plan = df.create_physical_plan().await?; + let actual_phys_plan = displayable(plan.as_ref()).indent(true).to_string(); + assert_eq!( + actual_phys_plan.contains(&format!("lim=[{limit}]")), + use_topk + ); + let batches = collect(plan, ctx.task_ctx()).await?; + assert_eq!(batches.len(), 1); + let batch = batches.first().unwrap(); + assert_eq!(batch.num_rows(), LIMIT); + + let actual = format!("{}", pretty_format_batches(&batches)?).to_lowercase(); + + let expected_asc = r#" ++----+ +| id | ++----+ +| 0 | +| 1 | +| 2 | +| 3 | +| 4 | +| 5 | +| 6 | +| 7 | +| 8 | +| 9 | ++----+ +"# + .trim(); + + let expected_desc = r#" ++---------+ +| id | ++---------+ +| 9999999 | +| 9999998 | +| 9999997 | +| 9999996 | +| 9999995 | +| 9999994 | +| 9999993 | +| 9999992 | +| 9999991 | +| 9999990 | ++---------+ +"# + .trim(); + + // Verify exact results match expected values + if asc { + assert_eq!( + actual.trim(), + expected_asc, + "Ascending DISTINCT results do not match expected values" + ); + } else { 
+ assert_eq!( + actual.trim(), + expected_desc, + "Descending DISTINCT results do not match expected values" + ); + } + + Ok(()) +} + fn criterion_benchmark(c: &mut Criterion) { let rt = Runtime::new().unwrap(); - let limit = 10; + let limit = LIMIT; let partitions = 10; let samples = 1_000_000; @@ -170,6 +356,86 @@ fn criterion_benchmark(c: &mut Criterion) { .as_str(), |b| b.iter(|| run(&rt, ctx.clone(), limit, true, true)), ); + + // String aggregate benchmarks - grouping by timestamp, aggregating string column + let ctx = rt + .block_on(create_context(partitions, samples, false, true, false)) + .unwrap(); + c.bench_function( + format!( + "top k={limit} string aggregate {} time-series rows [Utf8]", + partitions * samples + ) + .as_str(), + |b| b.iter(|| run_string(&rt, ctx.clone(), limit, true)), + ); + + let ctx = rt + .block_on(create_context(partitions, samples, true, true, false)) + .unwrap(); + c.bench_function( + format!( + "top k={limit} string aggregate {} worst-case rows [Utf8]", + partitions * samples + ) + .as_str(), + |b| b.iter(|| run_string(&rt, ctx.clone(), limit, true)), + ); + + let ctx = rt + .block_on(create_context(partitions, samples, false, true, true)) + .unwrap(); + c.bench_function( + format!( + "top k={limit} string aggregate {} time-series rows [Utf8View]", + partitions * samples + ) + .as_str(), + |b| b.iter(|| run_string(&rt, ctx.clone(), limit, true)), + ); + + let ctx = rt + .block_on(create_context(partitions, samples, true, true, true)) + .unwrap(); + c.bench_function( + format!( + "top k={limit} string aggregate {} worst-case rows [Utf8View]", + partitions * samples + ) + .as_str(), + |b| b.iter(|| run_string(&rt, ctx.clone(), limit, true)), + ); + + // DISTINCT benchmarks + let ctx = rt.block_on(async { + create_context_distinct(partitions, samples, false) + .await + .unwrap() + }); + c.bench_function( + format!("distinct {} rows desc [no TopK]", partitions * samples).as_str(), + |b| b.iter(|| run_distinct(&rt, ctx.clone(), 
limit, false, false)), + ); + + c.bench_function( + format!("distinct {} rows asc [no TopK]", partitions * samples).as_str(), + |b| b.iter(|| run_distinct(&rt, ctx.clone(), limit, false, true)), + ); + + let ctx_topk = rt.block_on(async { + create_context_distinct(partitions, samples, true) + .await + .unwrap() + }); + c.bench_function( + format!("distinct {} rows desc [TopK]", partitions * samples).as_str(), + |b| b.iter(|| run_distinct(&rt, ctx_topk.clone(), limit, true, false)), + ); + + c.bench_function( + format!("distinct {} rows asc [TopK]", partitions * samples).as_str(), + |b| b.iter(|| run_distinct(&rt, ctx_topk.clone(), limit, true, true)), + ); } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/core/benches/window_query_sql.rs b/datafusion/core/benches/window_query_sql.rs index e4643567a0f0..1657cae913fe 100644 --- a/datafusion/core/benches/window_query_sql.rs +++ b/datafusion/core/benches/window_query_sql.rs @@ -15,14 +15,9 @@ // specific language governing permissions and limitations // under the License. 
-#[macro_use] -extern crate criterion; -extern crate arrow; -extern crate datafusion; - mod data_utils; -use crate::criterion::Criterion; +use criterion::{Criterion, criterion_group, criterion_main}; use data_utils::create_table_provider; use datafusion::error::Result; use datafusion::execution::context::SessionContext; diff --git a/datafusion/core/src/bin/print_functions_docs.rs b/datafusion/core/src/bin/print_functions_docs.rs index 74a10bf079e6..2466d4269219 100644 --- a/datafusion/core/src/bin/print_functions_docs.rs +++ b/datafusion/core/src/bin/print_functions_docs.rs @@ -84,30 +84,6 @@ fn print_window_docs() -> Result { print_docs(providers, window_doc_sections::doc_sections()) } -// Temporary method useful to semi automate -// the migration of UDF documentation generation from code based -// to attribute based -// To be removed -#[allow(dead_code)] -fn save_doc_code_text(documentation: &Documentation, name: &str) { - let attr_text = documentation.to_doc_attribute(); - - let file_path = format!("{name}.txt"); - if std::path::Path::new(&file_path).exists() { - std::fs::remove_file(&file_path).unwrap(); - } - - // Open the file in append mode, create it if it doesn't exist - let mut file = std::fs::OpenOptions::new() - .append(true) // Open in append mode - .create(true) // Create the file if it doesn't exist - .open(file_path) - .unwrap(); - - use std::io::Write; - file.write_all(attr_text.as_bytes()).unwrap(); -} - #[expect(clippy::needless_pass_by_value)] fn print_docs( providers: Vec>, @@ -306,8 +282,7 @@ impl DocProvider for WindowUDF { } } -#[allow(clippy::borrowed_box)] -#[allow(clippy::ptr_arg)] +#[expect(clippy::borrowed_box)] fn get_names_and_aliases(functions: &Vec<&Box>) -> Vec { functions .iter() diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index fe760760eef3..2292f5855bfd 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -78,9 +78,11 @@ pub struct 
DataFrameWriteOptions { /// Controls how new data should be written to the table, determining whether /// to append, overwrite, or replace existing data. insert_op: InsertOp, - /// Controls if all partitions should be coalesced into a single output file - /// Generally will have slower performance when set to true. - single_file_output: bool, + /// Controls if all partitions should be coalesced into a single output file. + /// - `None`: Use automatic mode (extension-based heuristic) + /// - `Some(true)`: Force single file output at exact path + /// - `Some(false)`: Force directory output with generated filenames + single_file_output: Option, /// Sets which columns should be used for hive-style partitioned writes by name. /// Can be set to empty vec![] for non-partitioned writes. partition_by: Vec, @@ -94,7 +96,7 @@ impl DataFrameWriteOptions { pub fn new() -> Self { DataFrameWriteOptions { insert_op: InsertOp::Append, - single_file_output: false, + single_file_output: None, partition_by: vec![], sort_by: vec![], } @@ -108,9 +110,13 @@ impl DataFrameWriteOptions { /// Set the single_file_output value to true or false /// - /// When set to true, an output file will always be created even if the DataFrame is empty + /// - `true`: Force single file output at the exact path specified + /// - `false`: Force directory output with generated filenames + /// + /// When not called, automatic mode is used (extension-based heuristic). + /// When set to true, an output file will always be created even if the DataFrame is empty. pub fn with_single_file_output(mut self, single_file_output: bool) -> Self { - self.single_file_output = single_file_output; + self.single_file_output = Some(single_file_output); self } @@ -125,6 +131,15 @@ impl DataFrameWriteOptions { self.sort_by = sort_by; self } + + /// Build the options HashMap to pass to CopyTo for sink configuration. 
+ fn build_sink_options(&self) -> HashMap { + let mut options = HashMap::new(); + if let Some(single_file) = self.single_file_output { + options.insert("single_file_output".to_string(), single_file.to_string()); + } + options + } } impl Default for DataFrameWriteOptions { @@ -447,15 +462,31 @@ impl DataFrame { /// # Ok(()) /// # } /// ``` - pub fn drop_columns(self, columns: &[&str]) -> Result { + pub fn drop_columns(self, columns: &[T]) -> Result + where + T: Into + Clone, + { let fields_to_drop = columns .iter() - .flat_map(|name| { - self.plan - .schema() - .qualified_fields_with_unqualified_name(name) + .flat_map(|col| { + let column: Column = col.clone().into(); + match column.relation.as_ref() { + Some(_) => { + // qualified_field_from_column returns Result<(Option<&TableReference>, &FieldRef)> + vec![self.plan.schema().qualified_field_from_column(&column)] + } + None => { + // qualified_fields_with_unqualified_name returns Vec<(Option<&TableReference>, &FieldRef)> + self.plan + .schema() + .qualified_fields_with_unqualified_name(&column.name) + .into_iter() + .map(Ok) + .collect::>() + } + } }) - .collect::>(); + .collect::, _>>()?; let expr: Vec = self .plan .schema() @@ -481,7 +512,7 @@ impl DataFrame { /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_json("tests/data/unnest.json", NdJsonReadOptions::default()).await?; + /// let df = ctx.read_json("tests/data/unnest.json", JsonReadOptions::default()).await?; /// // expand into multiple columns if it's json array, flatten field name if it's nested structure /// let df = df.unnest_columns(&["b","c","d"])?; /// let expected = vec![ @@ -2024,6 +2055,8 @@ impl DataFrame { let file_type = format_as_file_type(format); + let copy_options = options.build_sink_options(); + let plan = if options.sort_by.is_empty() { self.plan } else { @@ -2036,7 +2069,7 @@ impl DataFrame { plan, path.into(), file_type, - HashMap::new(), + copy_options, 
options.partition_by, )? .build()?; @@ -2092,6 +2125,8 @@ impl DataFrame { let file_type = format_as_file_type(format); + let copy_options = options.build_sink_options(); + let plan = if options.sort_by.is_empty() { self.plan } else { @@ -2104,7 +2139,7 @@ impl DataFrame { plan, path.into(), file_type, - Default::default(), + copy_options, options.partition_by, )? .build()?; @@ -2465,6 +2500,48 @@ impl DataFrame { .collect() } + /// Find qualified columns for this dataframe from names + /// + /// # Arguments + /// * `names` - Unqualified names to find. + /// + /// # Example + /// ``` + /// # use datafusion::prelude::*; + /// # use datafusion::error::Result; + /// # use datafusion_common::ScalarValue; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// let ctx = SessionContext::new(); + /// ctx.register_csv("first_table", "tests/data/example.csv", CsvReadOptions::new()) + /// .await?; + /// let df = ctx.table("first_table").await?; + /// ctx.register_csv("second_table", "tests/data/example.csv", CsvReadOptions::new()) + /// .await?; + /// let df2 = ctx.table("second_table").await?; + /// let join_expr = df.find_qualified_columns(&["a"])?.iter() + /// .zip(df2.find_qualified_columns(&["a"])?.iter()) + /// .map(|(col1, col2)| col(*col1).eq(col(*col2))) + /// .collect::>(); + /// let df3 = df.join_on(df2, JoinType::Inner, join_expr)?; + /// # Ok(()) + /// # } + /// ``` + pub fn find_qualified_columns( + &self, + names: &[&str], + ) -> Result, &FieldRef)>> { + let schema = self.logical_plan().schema(); + names + .iter() + .map(|name| { + schema + .qualified_field_from_column(&Column::from_name(*name)) + .map_err(|_| plan_datafusion_err!("Column '{}' not found", name)) + }) + .collect() + } + /// Helper for creating DataFrame. 
/// # Example /// ``` diff --git a/datafusion/core/src/dataframe/parquet.rs b/datafusion/core/src/dataframe/parquet.rs index 6edf628e2d6d..54dadfd78cbc 100644 --- a/datafusion/core/src/dataframe/parquet.rs +++ b/datafusion/core/src/dataframe/parquet.rs @@ -76,6 +76,8 @@ impl DataFrame { let file_type = format_as_file_type(format); + let copy_options = options.build_sink_options(); + let plan = if options.sort_by.is_empty() { self.plan } else { @@ -88,7 +90,7 @@ impl DataFrame { plan, path.into(), file_type, - Default::default(), + copy_options, options.partition_by, )? .build()?; @@ -324,4 +326,156 @@ mod tests { Ok(()) } + + /// Test FileOutputMode::SingleFile - explicitly request single file output + /// for paths WITHOUT file extensions. This verifies the fix for the regression + /// where extension heuristics ignored the explicit with_single_file_output(true). + #[tokio::test] + async fn test_file_output_mode_single_file() -> Result<()> { + use arrow::array::Int32Array; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + + let ctx = SessionContext::new(); + let tmp_dir = TempDir::new()?; + + // Path WITHOUT .parquet extension - this is the key scenario + let output_path = tmp_dir.path().join("data_no_ext"); + let output_path_str = output_path.to_str().unwrap(); + + let df = ctx.read_batch(RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + )?)?; + + // Explicitly request single file output + df.write_parquet( + output_path_str, + DataFrameWriteOptions::new().with_single_file_output(true), + None, + ) + .await?; + + // Verify: output should be a FILE, not a directory + assert!( + output_path.is_file(), + "Expected single file at {:?}, but got is_file={}, is_dir={}", + output_path, + output_path.is_file(), + output_path.is_dir() + ); + + // Verify the file is readable as parquet + let file = std::fs::File::open(&output_path)?; + 
let reader = parquet::file::reader::SerializedFileReader::new(file)?; + let metadata = reader.metadata(); + assert_eq!(metadata.num_row_groups(), 1); + assert_eq!(metadata.file_metadata().num_rows(), 3); + + Ok(()) + } + + /// Test FileOutputMode::Automatic - uses extension heuristic. + /// Path WITH extension -> single file; path WITHOUT extension -> directory. + #[tokio::test] + async fn test_file_output_mode_automatic() -> Result<()> { + use arrow::array::Int32Array; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + + let ctx = SessionContext::new(); + let tmp_dir = TempDir::new()?; + + let schema = + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); + let batch = RecordBatch::try_new( + schema, + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + )?; + + // Case 1: Path WITH extension -> should create single file (Automatic mode) + let output_with_ext = tmp_dir.path().join("data.parquet"); + let df = ctx.read_batch(batch.clone())?; + df.write_parquet( + output_with_ext.to_str().unwrap(), + DataFrameWriteOptions::new(), // Automatic mode (default) + None, + ) + .await?; + + assert!( + output_with_ext.is_file(), + "Path with extension should be a single file, got is_file={}, is_dir={}", + output_with_ext.is_file(), + output_with_ext.is_dir() + ); + + // Case 2: Path WITHOUT extension -> should create directory (Automatic mode) + let output_no_ext = tmp_dir.path().join("data_dir"); + let df = ctx.read_batch(batch)?; + df.write_parquet( + output_no_ext.to_str().unwrap(), + DataFrameWriteOptions::new(), // Automatic mode (default) + None, + ) + .await?; + + assert!( + output_no_ext.is_dir(), + "Path without extension should be a directory, got is_file={}, is_dir={}", + output_no_ext.is_file(), + output_no_ext.is_dir() + ); + + Ok(()) + } + + /// Test FileOutputMode::Directory - explicitly request directory output + /// even for paths WITH file extensions. 
+ #[tokio::test] + async fn test_file_output_mode_directory() -> Result<()> { + use arrow::array::Int32Array; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + + let ctx = SessionContext::new(); + let tmp_dir = TempDir::new()?; + + // Path WITH .parquet extension but explicitly requesting directory output + let output_path = tmp_dir.path().join("output.parquet"); + let output_path_str = output_path.to_str().unwrap(); + + let df = ctx.read_batch(RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + )?)?; + + // Explicitly request directory output (single_file_output = false) + df.write_parquet( + output_path_str, + DataFrameWriteOptions::new().with_single_file_output(false), + None, + ) + .await?; + + // Verify: output should be a DIRECTORY, not a single file + assert!( + output_path.is_dir(), + "Expected directory at {:?}, but got is_file={}, is_dir={}", + output_path, + output_path.is_file(), + output_path.is_dir() + ); + + // Verify the directory contains parquet file(s) + let entries: Vec<_> = std::fs::read_dir(&output_path)? 
+ .filter_map(|e| e.ok()) + .collect(); + assert!( + !entries.is_empty(), + "Directory should contain at least one file" + ); + + Ok(()) + } } diff --git a/datafusion/core/src/datasource/file_format/avro.rs b/datafusion/core/src/datasource/file_format/avro.rs index cad35d43db48..7cf23ee294d8 100644 --- a/datafusion/core/src/datasource/file_format/avro.rs +++ b/datafusion/core/src/datasource/file_format/avro.rs @@ -95,7 +95,7 @@ mod tests { .schema() .fields() .iter() - .map(|f| format!("{}: {:?}", f.name(), f.data_type())) + .map(|f| format!("{}: {}", f.name(), f.data_type())) .collect(); assert_eq!( vec![ @@ -109,7 +109,7 @@ mod tests { "double_col: Float64", "date_string_col: Binary", "string_col: Binary", - "timestamp_col: Timestamp(Microsecond, None)", + "timestamp_col: Timestamp(µs)", ], x ); diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index aa226144a4af..51d799a5b65c 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -65,7 +65,8 @@ mod tests { use object_store::path::Path; use object_store::{ Attributes, GetOptions, GetResult, GetResultPayload, ListResult, MultipartUpload, - ObjectMeta, ObjectStore, PutMultipartOptions, PutOptions, PutPayload, PutResult, + ObjectMeta, ObjectStore, ObjectStoreExt, PutMultipartOptions, PutOptions, + PutPayload, PutResult, }; use regex::Regex; use rstest::*; @@ -104,10 +105,6 @@ mod tests { unimplemented!() } - async fn get(&self, location: &Path) -> object_store::Result { - self.get_opts(location, GetOptions::default()).await - } - async fn get_opts( &self, location: &Path, @@ -147,14 +144,6 @@ mod tests { unimplemented!() } - async fn head(&self, _location: &Path) -> object_store::Result { - unimplemented!() - } - - async fn delete(&self, _location: &Path) -> object_store::Result<()> { - unimplemented!() - } - fn list( &self, _prefix: Option<&Path>, @@ -169,17 +158,21 @@ mod tests { 
unimplemented!() } - async fn copy(&self, _from: &Path, _to: &Path) -> object_store::Result<()> { - unimplemented!() - } - - async fn copy_if_not_exists( + async fn copy_opts( &self, _from: &Path, _to: &Path, + _options: object_store::CopyOptions, ) -> object_store::Result<()> { unimplemented!() } + + fn delete_stream( + &self, + _locations: BoxStream<'static, object_store::Result>, + ) -> BoxStream<'static, object_store::Result> { + unimplemented!() + } } impl VariableStream { diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index cb2e9d787ee9..5b3e22705620 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -25,7 +25,7 @@ mod tests { use super::*; use crate::datasource::file_format::test_util::scan_format; - use crate::prelude::{NdJsonReadOptions, SessionConfig, SessionContext}; + use crate::prelude::{SessionConfig, SessionContext}; use crate::test::object_store::local_unpartitioned_file; use arrow::array::RecordBatch; use arrow_schema::Schema; @@ -46,12 +46,54 @@ mod tests { use datafusion_common::internal_err; use datafusion_common::stats::Precision; + use crate::execution::options::JsonReadOptions; use datafusion_common::Result; + use datafusion_datasource::file_compression_type::FileCompressionType; use futures::StreamExt; use insta::assert_snapshot; use object_store::local::LocalFileSystem; use regex::Regex; use rstest::rstest; + // ==================== Test Helpers ==================== + + /// Create a temporary JSON file and return (TempDir, path) + fn create_temp_json(content: &str) -> (tempfile::TempDir, String) { + let tmp_dir = tempfile::TempDir::new().unwrap(); + let path = tmp_dir.path().join("test.json"); + std::fs::write(&path, content).unwrap(); + (tmp_dir, path.to_string_lossy().to_string()) + } + + /// Infer schema from JSON array format file + async fn infer_json_array_schema( + content: &str, + ) -> Result 
{ + let (_tmp_dir, path) = create_temp_json(content); + let session = SessionContext::new(); + let ctx = session.state(); + let store = Arc::new(LocalFileSystem::new()) as _; + let format = JsonFormat::default().with_newline_delimited(false); + format + .infer_schema(&ctx, &store, &[local_unpartitioned_file(&path)]) + .await + } + + /// Register a JSON array table and run a query + async fn query_json_array(content: &str, query: &str) -> Result> { + let (_tmp_dir, path) = create_temp_json(content); + let ctx = SessionContext::new(); + let options = JsonReadOptions::default().newline_delimited(false); + ctx.register_json("test_table", &path, options).await?; + ctx.sql(query).await?.collect().await + } + + /// Register a JSON array table and run a query, return formatted string + async fn query_json_array_str(content: &str, query: &str) -> Result { + let result = query_json_array(content, query).await?; + Ok(batches_to_string(&result)) + } + + // ==================== Existing Tests ==================== #[tokio::test] async fn read_small_batches() -> Result<()> { @@ -208,7 +250,7 @@ mod tests { let ctx = SessionContext::new_with_config(config); let table_path = "tests/data/1.json"; - let options = NdJsonReadOptions::default(); + let options = JsonReadOptions::default(); ctx.register_json("json_parallel", table_path, options) .await?; @@ -240,7 +282,7 @@ mod tests { let ctx = SessionContext::new_with_config(config); let table_path = "tests/data/empty.json"; - let options = NdJsonReadOptions::default(); + let options = JsonReadOptions::default(); ctx.register_json("json_parallel_empty", table_path, options) .await?; @@ -314,7 +356,6 @@ mod tests { .digest(r#"{ "c1": 11, "c2": 12, "c3": 13, "c4": 14, "c5": 15 }"#.into()); let mut all_batches = RecordBatch::new_empty(schema.clone()); - // We get RequiresMoreData after 2 batches because of how json::Decoder works for _ in 0..2 { let output = deserializer.next()?; let DeserializerOutput::RecordBatch(batch) = output else { 
@@ -354,11 +395,11 @@ mod tests { async fn test_write_empty_json_from_sql() -> Result<()> { let ctx = SessionContext::new(); let tmp_dir = tempfile::TempDir::new()?; - let path = format!("{}/empty_sql.json", tmp_dir.path().to_string_lossy()); + let path = tmp_dir.path().join("empty_sql.json"); + let path = path.to_string_lossy().to_string(); let df = ctx.sql("SELECT CAST(1 AS BIGINT) AS id LIMIT 0").await?; df.write_json(&path, crate::dataframe::DataFrameWriteOptions::new(), None) .await?; - // Expected the file to exist and be empty assert!(std::path::Path::new(&path).exists()); let metadata = std::fs::metadata(&path)?; assert_eq!(metadata.len(), 0); @@ -381,14 +422,216 @@ mod tests { )?; let tmp_dir = tempfile::TempDir::new()?; - let path = format!("{}/empty_batch.json", tmp_dir.path().to_string_lossy()); + let path = tmp_dir.path().join("empty_batch.json"); + let path = path.to_string_lossy().to_string(); let df = ctx.read_batch(empty_batch.clone())?; df.write_json(&path, crate::dataframe::DataFrameWriteOptions::new(), None) .await?; - // Expected the file to exist and be empty assert!(std::path::Path::new(&path).exists()); let metadata = std::fs::metadata(&path)?; assert_eq!(metadata.len(), 0); Ok(()) } + + // ==================== JSON Array Format Tests ==================== + + #[tokio::test] + async fn test_json_array_schema_inference() -> Result<()> { + let schema = infer_json_array_schema( + r#"[{"a": 1, "b": 2.0, "c": true}, {"a": 2, "b": 3.5, "c": false}]"#, + ) + .await?; + + let fields: Vec<_> = schema + .fields() + .iter() + .map(|f| format!("{}: {:?}", f.name(), f.data_type())) + .collect(); + assert_eq!(vec!["a: Int64", "b: Float64", "c: Boolean"], fields); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_empty() -> Result<()> { + let schema = infer_json_array_schema("[]").await?; + assert_eq!(schema.fields().len(), 0); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_nested_struct() -> Result<()> { + let schema = 
infer_json_array_schema( + r#"[{"id": 1, "info": {"name": "Alice", "age": 30}}]"#, + ) + .await?; + + let info_field = schema.field_with_name("info").unwrap(); + assert!(matches!(info_field.data_type(), DataType::Struct(_))); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_list_type() -> Result<()> { + let schema = + infer_json_array_schema(r#"[{"id": 1, "tags": ["a", "b", "c"]}]"#).await?; + + let tags_field = schema.field_with_name("tags").unwrap(); + assert!(matches!(tags_field.data_type(), DataType::List(_))); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_basic_query() -> Result<()> { + let result = query_json_array_str( + r#"[{"a": 1, "b": "hello"}, {"a": 2, "b": "world"}, {"a": 3, "b": "test"}]"#, + "SELECT a, b FROM test_table ORDER BY a", + ) + .await?; + + assert_snapshot!(result, @r" + +---+-------+ + | a | b | + +---+-------+ + | 1 | hello | + | 2 | world | + | 3 | test | + +---+-------+ + "); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_with_nulls() -> Result<()> { + let result = query_json_array_str( + r#"[{"id": 1, "name": "Alice"}, {"id": 2, "name": null}, {"id": 3, "name": "Charlie"}]"#, + "SELECT id, name FROM test_table ORDER BY id", + ) + .await?; + + assert_snapshot!(result, @r" + +----+---------+ + | id | name | + +----+---------+ + | 1 | Alice | + | 2 | | + | 3 | Charlie | + +----+---------+ + "); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_unnest() -> Result<()> { + let result = query_json_array_str( + r#"[{"id": 1, "values": [10, 20, 30]}, {"id": 2, "values": [40, 50]}]"#, + "SELECT id, unnest(values) as value FROM test_table ORDER BY id, value", + ) + .await?; + + assert_snapshot!(result, @r" + +----+-------+ + | id | value | + +----+-------+ + | 1 | 10 | + | 1 | 20 | + | 1 | 30 | + | 2 | 40 | + | 2 | 50 | + +----+-------+ + "); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_unnest_struct() -> Result<()> { + let result = query_json_array_str( + r#"[{"id": 1, "orders": 
[{"product": "A", "qty": 2}, {"product": "B", "qty": 3}]}, {"id": 2, "orders": [{"product": "C", "qty": 1}]}]"#, + "SELECT id, unnest(orders)['product'] as product, unnest(orders)['qty'] as qty FROM test_table ORDER BY id, product", + ) + .await?; + + assert_snapshot!(result, @r" + +----+---------+-----+ + | id | product | qty | + +----+---------+-----+ + | 1 | A | 2 | + | 1 | B | 3 | + | 2 | C | 1 | + +----+---------+-----+ + "); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_nested_struct_access() -> Result<()> { + let result = query_json_array_str( + r#"[{"id": 1, "dept": {"name": "Engineering", "head": "Alice"}}, {"id": 2, "dept": {"name": "Sales", "head": "Bob"}}]"#, + "SELECT id, dept['name'] as dept_name, dept['head'] as head FROM test_table ORDER BY id", + ) + .await?; + + assert_snapshot!(result, @r" + +----+-------------+-------+ + | id | dept_name | head | + +----+-------------+-------+ + | 1 | Engineering | Alice | + | 2 | Sales | Bob | + +----+-------------+-------+ + "); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_with_compression() -> Result<()> { + use flate2::Compression; + use flate2::write::GzEncoder; + use std::io::Write; + + let tmp_dir = tempfile::TempDir::new()?; + let path = tmp_dir.path().join("array.json.gz"); + let path = path.to_string_lossy().to_string(); + + let file = std::fs::File::create(&path)?; + let mut encoder = GzEncoder::new(file, Compression::default()); + encoder.write_all( + r#"[{"a": 1, "b": "hello"}, {"a": 2, "b": "world"}]"#.as_bytes(), + )?; + encoder.finish()?; + + let ctx = SessionContext::new(); + let options = JsonReadOptions::default() + .newline_delimited(false) + .file_compression_type(FileCompressionType::GZIP) + .file_extension(".json.gz"); + + ctx.register_json("test_table", &path, options).await?; + let result = ctx + .sql("SELECT a, b FROM test_table ORDER BY a") + .await? 
+ .collect() + .await?; + + assert_snapshot!(batches_to_string(&result), @r" + +---+-------+ + | a | b | + +---+-------+ + | 1 | hello | + | 2 | world | + +---+-------+ + "); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_list_of_structs() -> Result<()> { + let batches = query_json_array( + r#"[{"id": 1, "items": [{"name": "x", "price": 10.5}]}, {"id": 2, "items": []}]"#, + "SELECT id, items FROM test_table ORDER BY id", + ) + .await?; + + assert_eq!(1, batches.len()); + assert_eq!(2, batches[0].num_rows()); + Ok(()) + } } diff --git a/datafusion/core/src/datasource/file_format/mod.rs b/datafusion/core/src/datasource/file_format/mod.rs index 6bbb63f6a17a..b04238ebc9b3 100644 --- a/datafusion/core/src/datasource/file_format/mod.rs +++ b/datafusion/core/src/datasource/file_format/mod.rs @@ -73,17 +73,7 @@ pub(crate) mod test_util { .infer_stats(state, &store, file_schema.clone(), &meta) .await?; - let file_groups = vec![ - vec![PartitionedFile { - object_meta: meta, - partition_values: vec![], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - }] - .into(), - ]; + let file_groups = vec![vec![PartitionedFile::new_from_meta(meta)].into()]; let exec = format .create_physical_plan( diff --git a/datafusion/core/src/datasource/file_format/options.rs b/datafusion/core/src/datasource/file_format/options.rs index 146c5f6f5fd0..bd0ac3608738 100644 --- a/datafusion/core/src/datasource/file_format/options.rs +++ b/datafusion/core/src/datasource/file_format/options.rs @@ -442,14 +442,23 @@ impl<'a> AvroReadOptions<'a> { } } -/// Options that control the reading of Line-delimited JSON files (NDJson) +#[deprecated( + since = "53.0.0", + note = "Use `JsonReadOptions` instead. This alias will be removed in a future version." +)] +#[doc = "Deprecated: Use [`JsonReadOptions`] instead."] +pub type NdJsonReadOptions<'a> = JsonReadOptions<'a>; + +/// Options that control the reading of JSON files. 
+/// +/// Supports both newline-delimited JSON (NDJSON) and JSON array formats. /// /// Note this structure is supplied when a datasource is created and -/// can not not vary from statement to statement. For settings that +/// can not vary from statement to statement. For settings that /// can vary statement to statement see /// [`ConfigOptions`](crate::config::ConfigOptions). #[derive(Clone)] -pub struct NdJsonReadOptions<'a> { +pub struct JsonReadOptions<'a> { /// The data source schema. pub schema: Option<&'a Schema>, /// Max number of rows to read from JSON files for schema inference if needed. Defaults to `DEFAULT_SCHEMA_INFER_MAX_RECORD`. @@ -465,9 +474,25 @@ pub struct NdJsonReadOptions<'a> { pub infinite: bool, /// Indicates how the file is sorted pub file_sort_order: Vec>, + /// Whether to read as newline-delimited JSON (default: true). + /// + /// When `true` (default), expects newline-delimited JSON (NDJSON): + /// ```text + /// {"key1": 1, "key2": "val"} + /// {"key1": 2, "key2": "vals"} + /// ``` + /// + /// When `false`, expects JSON array format: + /// ```text + /// [ + /// {"key1": 1, "key2": "val"}, + /// {"key1": 2, "key2": "vals"} + /// ] + /// ``` + pub newline_delimited: bool, } -impl Default for NdJsonReadOptions<'_> { +impl Default for JsonReadOptions<'_> { fn default() -> Self { Self { schema: None, @@ -477,11 +502,12 @@ impl Default for NdJsonReadOptions<'_> { file_compression_type: FileCompressionType::UNCOMPRESSED, infinite: false, file_sort_order: vec![], + newline_delimited: true, } } } -impl<'a> NdJsonReadOptions<'a> { +impl<'a> JsonReadOptions<'a> { /// Specify table_partition_cols for partition pruning pub fn table_partition_cols( mut self, @@ -529,6 +555,26 @@ impl<'a> NdJsonReadOptions<'a> { self.schema_infer_max_records = schema_infer_max_records; self } + + /// Set whether to read as newline-delimited JSON. 
+ /// + /// When `true` (default), expects newline-delimited JSON (NDJSON): + /// ```text + /// {"key1": 1, "key2": "val"} + /// {"key1": 2, "key2": "vals"} + /// ``` + /// + /// When `false`, expects JSON array format: + /// ```text + /// [ + /// {"key1": 1, "key2": "val"}, + /// {"key1": 2, "key2": "vals"} + /// ] + /// ``` + pub fn newline_delimited(mut self, newline_delimited: bool) -> Self { + self.newline_delimited = newline_delimited; + self + } } #[async_trait] @@ -654,7 +700,7 @@ impl ReadOptions<'_> for ParquetReadOptions<'_> { } #[async_trait] -impl ReadOptions<'_> for NdJsonReadOptions<'_> { +impl ReadOptions<'_> for JsonReadOptions<'_> { fn to_listing_options( &self, config: &SessionConfig, @@ -663,7 +709,8 @@ impl ReadOptions<'_> for NdJsonReadOptions<'_> { let file_format = JsonFormat::default() .with_options(table_options.json) .with_schema_infer_max_rec(self.schema_infer_max_records) - .with_file_compression_type(self.file_compression_type.to_owned()); + .with_file_compression_type(self.file_compression_type.to_owned()) + .with_newline_delimited(self.newline_delimited); ListingOptions::new(Arc::new(file_format)) .with_file_extension(self.file_extension) diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 47ce519f0128..6a8f7ab99975 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -130,7 +130,9 @@ mod tests { use datafusion_common::test_util::batches_to_string; use datafusion_common::{Result, ScalarValue}; use datafusion_datasource::file_format::FileFormat; - use datafusion_datasource::file_sink_config::{FileSink, FileSinkConfig}; + use datafusion_datasource::file_sink_config::{ + FileOutputMode, FileSink, FileSinkConfig, + }; use datafusion_datasource::{ListingTableUrl, PartitionedFile}; use datafusion_datasource_parquet::{ ParquetFormat, ParquetFormatFactory, ParquetSink, @@ -154,8 +156,8 @@ 
mod tests { use futures::StreamExt; use futures::stream::BoxStream; use insta::assert_snapshot; - use object_store::ObjectMeta; use object_store::local::LocalFileSystem; + use object_store::{CopyOptions, ObjectMeta}; use object_store::{ GetOptions, GetResult, ListResult, MultipartUpload, ObjectStore, PutMultipartOptions, PutOptions, PutPayload, PutResult, path::Path, @@ -163,7 +165,8 @@ mod tests { use parquet::arrow::ParquetRecordBatchStreamBuilder; use parquet::arrow::arrow_reader::ArrowReaderOptions; use parquet::file::metadata::{ - KeyValue, ParquetColumnIndex, ParquetMetaData, ParquetOffsetIndex, + KeyValue, PageIndexPolicy, ParquetColumnIndex, ParquetMetaData, + ParquetOffsetIndex, }; use parquet::file::page_index::column_index::ColumnIndexMetaData; use tokio::fs::File; @@ -308,7 +311,7 @@ mod tests { _payload: PutPayload, _opts: PutOptions, ) -> object_store::Result { - Err(object_store::Error::NotImplemented) + unimplemented!() } async fn put_multipart_opts( @@ -316,7 +319,7 @@ mod tests { _location: &Path, _opts: PutMultipartOptions, ) -> object_store::Result> { - Err(object_store::Error::NotImplemented) + unimplemented!() } async fn get_opts( @@ -328,40 +331,34 @@ mod tests { self.inner.get_opts(location, options).await } - async fn head(&self, _location: &Path) -> object_store::Result { - Err(object_store::Error::NotImplemented) - } - - async fn delete(&self, _location: &Path) -> object_store::Result<()> { - Err(object_store::Error::NotImplemented) + fn delete_stream( + &self, + _locations: BoxStream<'static, object_store::Result>, + ) -> BoxStream<'static, object_store::Result> { + unimplemented!() } fn list( &self, _prefix: Option<&Path>, ) -> BoxStream<'static, object_store::Result> { - Box::pin(futures::stream::once(async { - Err(object_store::Error::NotImplemented) - })) + unimplemented!() } async fn list_with_delimiter( &self, _prefix: Option<&Path>, ) -> object_store::Result { - Err(object_store::Error::NotImplemented) - } - - async fn copy(&self, 
_from: &Path, _to: &Path) -> object_store::Result<()> { - Err(object_store::Error::NotImplemented) + unimplemented!() } - async fn copy_if_not_exists( + async fn copy_opts( &self, _from: &Path, _to: &Path, + _options: CopyOptions, ) -> object_store::Result<()> { - Err(object_store::Error::NotImplemented) + unimplemented!() } } @@ -815,7 +812,7 @@ mod tests { .schema() .fields() .iter() - .map(|f| format!("{}: {:?}", f.name(), f.data_type())) + .map(|f| format!("{}: {}", f.name(), f.data_type())) .collect(); let y = x.join("\n"); assert_eq!(expected, y); @@ -841,7 +838,7 @@ mod tests { double_col: Float64\n\ date_string_col: Binary\n\ string_col: Binary\n\ - timestamp_col: Timestamp(Nanosecond, None)"; + timestamp_col: Timestamp(ns)"; _run_read_alltypes_plain_parquet(ForceViews::No, no_views).await?; let with_views = "id: Int32\n\ @@ -854,7 +851,7 @@ mod tests { double_col: Float64\n\ date_string_col: BinaryView\n\ string_col: BinaryView\n\ - timestamp_col: Timestamp(Nanosecond, None)"; + timestamp_col: Timestamp(ns)"; _run_read_alltypes_plain_parquet(ForceViews::Yes, with_views).await?; Ok(()) @@ -1103,7 +1100,8 @@ mod tests { let testdata = datafusion_common::test_util::parquet_test_data(); let path = format!("{testdata}/alltypes_tiny_pages.parquet"); let file = File::open(path).await?; - let options = ArrowReaderOptions::new().with_page_index(true); + let options = + ArrowReaderOptions::new().with_page_index_policy(PageIndexPolicy::Required); let builder = ParquetRecordBatchStreamBuilder::new_with_options(file, options.clone()) .await? 
@@ -1547,6 +1545,7 @@ mod tests { insert_op: InsertOp::Overwrite, keep_partition_by_columns: false, file_extension: "parquet".into(), + file_output_mode: FileOutputMode::Automatic, }; let parquet_sink = Arc::new(ParquetSink::new( file_sink_config, @@ -1638,6 +1637,7 @@ mod tests { insert_op: InsertOp::Overwrite, keep_partition_by_columns: false, file_extension: "parquet".into(), + file_output_mode: FileOutputMode::Automatic, }; let parquet_sink = Arc::new(ParquetSink::new( file_sink_config, @@ -1728,6 +1728,7 @@ mod tests { insert_op: InsertOp::Overwrite, keep_partition_by_columns: false, file_extension: "parquet".into(), + file_output_mode: FileOutputMode::Automatic, }; let parquet_sink = Arc::new(ParquetSink::new( file_sink_config, diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 93d77e10ba23..5dd11739c1f5 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -110,6 +110,7 @@ mod tests { #[cfg(feature = "parquet")] use crate::datasource::file_format::parquet::ParquetFormat; use crate::datasource::listing::table::ListingTableConfigExt; + use crate::execution::options::JsonReadOptions; use crate::prelude::*; use crate::{ datasource::{ @@ -347,7 +348,7 @@ mod tests { let table = ListingTable::try_new(config.clone()).expect("Creating the table"); let ordering_result = - table.try_create_output_ordering(state.execution_props()); + table.try_create_output_ordering(state.execution_props(), &[]); match (expected_result, ordering_result) { (Ok(expected), Ok(result)) => { @@ -808,7 +809,7 @@ mod tests { .register_json( "t", tmp_dir.path().to_str().unwrap(), - NdJsonReadOptions::default() + JsonReadOptions::default() .schema(schema.as_ref()) .file_compression_type(file_compression_type), ) diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index 
3ca388af0c4c..f85f15a6d8c6 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -54,7 +54,15 @@ impl TableProviderFactory for ListingTableFactory { cmd: &CreateExternalTable, ) -> Result> { // TODO (https://github.com/apache/datafusion/issues/11600) remove downcast_ref from here. Should file format factory be an extension to session state? - let session_state = state.as_any().downcast_ref::().unwrap(); + let session_state = + state + .as_any() + .downcast_ref::() + .ok_or_else(|| { + datafusion_common::internal_datafusion_err!( + "ListingTableFactory requires SessionState" + ) + })?; let file_format = session_state .get_file_format_factory(cmd.file_type.as_str()) .ok_or(config_datafusion_err!( @@ -63,7 +71,8 @@ impl TableProviderFactory for ListingTableFactory { ))? .create(session_state, &cmd.options)?; - let mut table_path = ListingTableUrl::parse(&cmd.location)?; + let mut table_path = + ListingTableUrl::parse(&cmd.location)?.with_table_ref(cmd.name.clone()); let file_extension = match table_path.is_collection() { // Setting the extension to be empty instead of allowing the default extension seems // odd, but was done to ensure existing behavior isn't modified. 
It seems like this @@ -545,4 +554,103 @@ mod tests { "Statistics cache should not be pre-warmed when collect_statistics is disabled" ); } + + #[tokio::test] + async fn test_create_with_invalid_session() { + use async_trait::async_trait; + use datafusion_catalog::Session; + use datafusion_common::Result; + use datafusion_common::config::TableOptions; + use datafusion_execution::TaskContext; + use datafusion_execution::config::SessionConfig; + use datafusion_physical_expr::PhysicalExpr; + use datafusion_physical_plan::ExecutionPlan; + use std::any::Any; + use std::collections::HashMap; + use std::sync::Arc; + + // A mock Session that is NOT SessionState + #[derive(Debug)] + struct MockSession; + + #[async_trait] + impl Session for MockSession { + fn session_id(&self) -> &str { + "mock_session" + } + fn config(&self) -> &SessionConfig { + unimplemented!() + } + async fn create_physical_plan( + &self, + _logical_plan: &datafusion_expr::LogicalPlan, + ) -> Result> { + unimplemented!() + } + fn create_physical_expr( + &self, + _expr: datafusion_expr::Expr, + _df_schema: &DFSchema, + ) -> Result> { + unimplemented!() + } + fn scalar_functions( + &self, + ) -> &HashMap> { + unimplemented!() + } + fn aggregate_functions( + &self, + ) -> &HashMap> { + unimplemented!() + } + fn window_functions( + &self, + ) -> &HashMap> { + unimplemented!() + } + fn runtime_env(&self) -> &Arc { + unimplemented!() + } + fn execution_props( + &self, + ) -> &datafusion_expr::execution_props::ExecutionProps { + unimplemented!() + } + fn as_any(&self) -> &dyn Any { + self + } + fn table_options(&self) -> &TableOptions { + unimplemented!() + } + fn table_options_mut(&mut self) -> &mut TableOptions { + unimplemented!() + } + fn task_ctx(&self) -> Arc { + unimplemented!() + } + } + + let factory = ListingTableFactory::new(); + let mock_session = MockSession; + + let name = TableReference::bare("foo"); + let cmd = CreateExternalTable::builder( + name, + "foo.csv".to_string(), + "csv", + 
Arc::new(DFSchema::empty()), + ) + .build(); + + // This should return an error, not panic + let result = factory.create(&mock_session, &cmd).await; + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .strip_backtrace() + .contains("Internal error: ListingTableFactory requires SessionState") + ); + } } diff --git a/datafusion/core/src/datasource/mod.rs b/datafusion/core/src/datasource/mod.rs index aefda64d3936..32b3b0799dd8 100644 --- a/datafusion/core/src/datasource/mod.rs +++ b/datafusion/core/src/datasource/mod.rs @@ -113,14 +113,7 @@ mod tests { version: None, }; - let partitioned_file = PartitionedFile { - object_meta: meta, - partition_values: vec![], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - }; + let partitioned_file = PartitionedFile::new_from_meta(meta); let f1 = Field::new("id", DataType::Int32, true); let f2 = Field::new("extra_column", DataType::Utf8, true); @@ -156,10 +149,10 @@ mod tests { &self, _logical_file_schema: SchemaRef, physical_file_schema: SchemaRef, - ) -> Arc { - Arc::new(TestPhysicalExprAdapter { + ) -> Result> { + Ok(Arc::new(TestPhysicalExprAdapter { physical_file_schema, - }) + })) } } diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index 0e40ed2df206..82c47b6c7281 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -31,7 +31,7 @@ mod tests { use datafusion_datasource::TableSchema; use datafusion_datasource_csv::CsvFormat; - use object_store::ObjectStore; + use object_store::{ObjectStore, ObjectStoreExt}; use crate::datasource::file_format::FileFormat; use crate::prelude::CsvReadOptions; diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs index 8de6a60258f0..b70791c7b239 100644 --- a/datafusion/core/src/datasource/physical_plan/json.rs +++ 
b/datafusion/core/src/datasource/physical_plan/json.rs @@ -32,7 +32,7 @@ mod tests { use crate::dataframe::DataFrameWriteOptions; use crate::execution::SessionState; - use crate::prelude::{CsvReadOptions, NdJsonReadOptions, SessionContext}; + use crate::prelude::{CsvReadOptions, JsonReadOptions, SessionContext}; use crate::test::partitioned_file_groups; use datafusion_common::Result; use datafusion_common::cast::{as_int32_array, as_int64_array, as_string_array}; @@ -136,7 +136,7 @@ mod tests { .get_ext_with_compression(&file_compression_type) .unwrap(); - let read_options = NdJsonReadOptions::default() + let read_options = JsonReadOptions::default() .file_extension(ext.as_str()) .file_compression_type(file_compression_type.to_owned()); let frame = ctx.read_json(path, read_options).await.unwrap(); @@ -389,7 +389,7 @@ mod tests { let path = format!("{TEST_DATA_BASE}/1.json"); // register json file with the execution context - ctx.register_json("test", path.as_str(), NdJsonReadOptions::default()) + ctx.register_json("test", path.as_str(), JsonReadOptions::default()) .await?; // register a local file system object store for /tmp directory @@ -431,7 +431,7 @@ mod tests { } // register each partition as well as the top level dir - let json_read_option = NdJsonReadOptions::default(); + let json_read_option = JsonReadOptions::default(); ctx.register_json( "part0", &format!("{out_dir}/{part_0_name}"), @@ -511,7 +511,7 @@ mod tests { async fn read_test_data(schema_infer_max_records: usize) -> Result { let ctx = SessionContext::new(); - let options = NdJsonReadOptions { + let options = JsonReadOptions { schema_infer_max_records, ..Default::default() }; @@ -587,7 +587,7 @@ mod tests { .get_ext_with_compression(&file_compression_type) .unwrap(); - let read_option = NdJsonReadOptions::default() + let read_option = JsonReadOptions::default() .file_compression_type(file_compression_type) .file_extension(ext.as_str()); diff --git 
a/datafusion/core/src/datasource/physical_plan/parquet.rs b/datafusion/core/src/datasource/physical_plan/parquet.rs index 4703b55ecc0d..4c6d915d5bca 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet.rs @@ -38,10 +38,10 @@ mod tests { use crate::prelude::{ParquetReadOptions, SessionConfig, SessionContext}; use crate::test::object_store::local_unpartitioned_file; use arrow::array::{ - ArrayRef, AsArray, Date64Array, Int8Array, Int32Array, Int64Array, StringArray, - StringViewArray, StructArray, TimestampNanosecondArray, + ArrayRef, AsArray, Date64Array, DictionaryArray, Int8Array, Int32Array, + Int64Array, StringArray, StringViewArray, StructArray, TimestampNanosecondArray, }; - use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaBuilder}; + use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaBuilder, UInt16Type}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; use arrow_schema::{SchemaRef, TimeUnit}; @@ -54,7 +54,7 @@ mod tests { use datafusion_datasource::source::DataSourceExec; use datafusion_datasource::file::FileSource; - use datafusion_datasource::{FileRange, PartitionedFile, TableSchema}; + use datafusion_datasource::{PartitionedFile, TableSchema}; use datafusion_datasource_parquet::source::ParquetSource; use datafusion_datasource_parquet::{ DefaultParquetFileReaderFactory, ParquetFileReaderFactory, ParquetFormat, @@ -995,6 +995,7 @@ mod tests { assert_eq!(read, 1, "Expected 1 rows to match the predicate"); assert_eq!(get_value(&metrics, "row_groups_pruned_statistics"), 0); assert_eq!(get_value(&metrics, "page_index_rows_pruned"), 2); + assert_eq!(get_value(&metrics, "page_index_pages_pruned"), 1); assert_eq!(get_value(&metrics, "pushdown_rows_pruned"), 1); // If we filter with a value that is completely out of the range of the data // we prune at the row group level. 
@@ -1168,10 +1169,16 @@ mod tests { // There are 4 rows pruned in each of batch2, batch3, and // batch4 for a total of 12. batch1 had no pruning as c2 was // filled in as null - let (page_index_pruned, page_index_matched) = + let (page_index_rows_pruned, page_index_rows_matched) = get_pruning_metric(&metrics, "page_index_rows_pruned"); - assert_eq!(page_index_pruned, 12); - assert_eq!(page_index_matched, 6); + assert_eq!(page_index_rows_pruned, 12); + assert_eq!(page_index_rows_matched, 6); + + // each page has 2 rows, so the num of pages is 1/2 the number of rows + let (page_index_pages_pruned, page_index_pages_matched) = + get_pruning_metric(&metrics, "page_index_pages_pruned"); + assert_eq!(page_index_pages_pruned, 6); + assert_eq!(page_index_pages_matched, 3); } #[tokio::test] @@ -1527,14 +1534,7 @@ mod tests { #[tokio::test] async fn parquet_exec_with_range() -> Result<()> { fn file_range(meta: &ObjectMeta, start: i64, end: i64) -> PartitionedFile { - PartitionedFile { - object_meta: meta.clone(), - partition_values: vec![], - range: Some(FileRange { start, end }), - statistics: None, - extensions: None, - metadata_size_hint: None, - } + PartitionedFile::new_from_meta(meta.clone()).with_range(start, end) } async fn assert_parquet_read( @@ -1616,21 +1616,15 @@ mod tests { .await .unwrap(); - let partitioned_file = PartitionedFile { - object_meta: meta, - partition_values: vec![ + let partitioned_file = PartitionedFile::new_from_meta(meta) + .with_partition_values(vec![ ScalarValue::from("2021"), ScalarValue::UInt8(Some(10)), ScalarValue::Dictionary( Box::new(DataType::UInt16), Box::new(ScalarValue::from("26")), ), - ], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - }; + ]); let expected_schema = Schema::new(vec![ Field::new("id", DataType::Int32, true), @@ -1711,20 +1705,13 @@ mod tests { .unwrap() .child("invalid.parquet"); - let partitioned_file = PartitionedFile { - object_meta: ObjectMeta { - location, - last_modified: 
Utc.timestamp_nanos(0), - size: 1337, - e_tag: None, - version: None, - }, - partition_values: vec![], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - }; + let partitioned_file = PartitionedFile::new_from_meta(ObjectMeta { + location, + last_modified: Utc.timestamp_nanos(0), + size: 1337, + e_tag: None, + version: None, + }); let file_schema = Arc::new(Schema::empty()); let config = FileScanConfigBuilder::new( @@ -1754,6 +1741,7 @@ mod tests { Some(3), Some(4), Some(5), + Some(6), // last page with only one row ])); let batch1 = create_batch(vec![("int", c1.clone())]); @@ -1762,7 +1750,7 @@ mod tests { let rt = RoundTrip::new() .with_predicate(filter) .with_page_index_predicate() - .round_trip(vec![batch1]) + .round_trip(vec![batch1.clone()]) .await; let metrics = rt.parquet_exec.metrics().unwrap(); @@ -1775,14 +1763,40 @@ mod tests { | 5 | +-----+ "); - let (page_index_pruned, page_index_matched) = + let (page_index_rows_pruned, page_index_rows_matched) = get_pruning_metric(&metrics, "page_index_rows_pruned"); - assert_eq!(page_index_pruned, 4); - assert_eq!(page_index_matched, 2); + assert_eq!(page_index_rows_pruned, 5); + assert_eq!(page_index_rows_matched, 2); assert!( get_value(&metrics, "page_index_eval_time") > 0, "no eval time in metrics: {metrics:#?}" ); + + // each page has 2 rows, so the num of pages is 1/2 the number of rows + let (page_index_pages_pruned, page_index_pages_matched) = + get_pruning_metric(&metrics, "page_index_pages_pruned"); + assert_eq!(page_index_pages_pruned, 3); + assert_eq!(page_index_pages_matched, 1); + + // test with a filter that matches the page with one row + let filter = col("int").eq(lit(6_i32)); + let rt = RoundTrip::new() + .with_predicate(filter) + .with_page_index_predicate() + .round_trip(vec![batch1]) + .await; + + let metrics = rt.parquet_exec.metrics().unwrap(); + + let (page_index_rows_pruned, page_index_rows_matched) = + get_pruning_metric(&metrics, "page_index_rows_pruned"); + 
assert_eq!(page_index_rows_pruned, 6); + assert_eq!(page_index_rows_matched, 1); + + let (page_index_pages_pruned, page_index_pages_matched) = + get_pruning_metric(&metrics, "page_index_pages_pruned"); + assert_eq!(page_index_pages_pruned, 3); + assert_eq!(page_index_pages_matched, 1); } /// Returns a string array with contents: @@ -2249,6 +2263,48 @@ mod tests { Ok(()) } + /// Tests that constant dictionary columns (where min == max in statistics) + /// are correctly handled. This reproduced a bug where the constant value + /// from statistics had type Utf8 but the schema expected Dictionary. + #[tokio::test] + async fn test_constant_dictionary_column_parquet() -> Result<()> { + let tmp_dir = TempDir::new()?; + let path = tmp_dir.path().to_str().unwrap().to_string() + "/test.parquet"; + + // Write parquet with dictionary column where all values are the same + let schema = Arc::new(Schema::new(vec![Field::new( + "status", + DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)), + false, + )])); + let status: DictionaryArray = + vec!["active", "active"].into_iter().collect(); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(status)])?; + let file = File::create(&path)?; + let props = WriterProperties::builder() + .set_statistics_enabled(parquet::file::properties::EnabledStatistics::Page) + .build(); + let mut writer = ArrowWriter::try_new(file, schema, Some(props))?; + writer.write(&batch)?; + writer.close()?; + + // Query the constant dictionary column + let ctx = SessionContext::new(); + ctx.register_parquet("t", &path, ParquetReadOptions::default()) + .await?; + let result = ctx.sql("SELECT status FROM t").await?.collect().await?; + + insta::assert_snapshot!(batches_to_string(&result),@r" + +--------+ + | status | + +--------+ + | active | + | active | + +--------+ + "); + Ok(()) + } + fn write_file(file: &String) { let struct_fields = Fields::from(vec![ Field::new("id", DataType::Int64, false), @@ -2376,36 +2432,22 @@ mod 
tests { ); let config = FileScanConfigBuilder::new(store_url, source) .with_file( - PartitionedFile { - object_meta: ObjectMeta { - location: Path::from(name_1), - last_modified: Utc::now(), - size: total_size_1, - e_tag: None, - version: None, - }, - partition_values: vec![], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - } - .with_metadata_size_hint(123), - ) - .with_file(PartitionedFile { - object_meta: ObjectMeta { - location: Path::from(name_2), + PartitionedFile::new_from_meta(ObjectMeta { + location: Path::from(name_1), last_modified: Utc::now(), - size: total_size_2, + size: total_size_1, e_tag: None, version: None, - }, - partition_values: vec![], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - }) + }) + .with_metadata_size_hint(123), + ) + .with_file(PartitionedFile::new_from_meta(ObjectMeta { + location: Path::from(name_2), + last_modified: Utc::now(), + size: total_size_2, + e_tag: None, + version: None, + })) .build(); let exec = DataSourceExec::from_data_source(config); diff --git a/datafusion/core/src/execution/context/json.rs b/datafusion/core/src/execution/context/json.rs index e9d799400863..f7df2ad7a1cd 100644 --- a/datafusion/core/src/execution/context/json.rs +++ b/datafusion/core/src/execution/context/json.rs @@ -15,13 +15,13 @@ // specific language governing permissions and limitations // under the License. +use super::super::options::ReadOptions; +use super::{DataFilePaths, DataFrame, ExecutionPlan, Result, SessionContext}; +use crate::execution::options::JsonReadOptions; use datafusion_common::TableReference; use datafusion_datasource_json::source::plan_to_json; use std::sync::Arc; -use super::super::options::{NdJsonReadOptions, ReadOptions}; -use super::{DataFilePaths, DataFrame, ExecutionPlan, Result, SessionContext}; - impl SessionContext { /// Creates a [`DataFrame`] for reading an JSON data source. 
/// @@ -32,7 +32,7 @@ impl SessionContext { pub async fn read_json( &self, table_paths: P, - options: NdJsonReadOptions<'_>, + options: JsonReadOptions<'_>, ) -> Result { self._read_type(table_paths, options).await } @@ -43,7 +43,7 @@ impl SessionContext { &self, table_ref: impl Into, table_path: impl AsRef, - options: NdJsonReadOptions<'_>, + options: JsonReadOptions<'_>, ) -> Result<()> { let listing_options = options .to_listing_options(&self.copied_config(), self.copied_table_options()); diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index a769bb01b435..cdc50167d16c 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -93,9 +93,9 @@ use datafusion_expr::{ logical_plan::{DdlStatement, Statement}, planner::ExprPlanner, }; -use datafusion_optimizer::Analyzer; use datafusion_optimizer::analyzer::type_coercion::TypeCoercion; use datafusion_optimizer::simplify_expressions::ExprSimplifier; +use datafusion_optimizer::{Analyzer, OptimizerContext}; use datafusion_optimizer::{AnalyzerRule, OptimizerRule}; use datafusion_session::SessionStore; @@ -749,12 +749,19 @@ impl SessionContext { ); } } - // Store the unoptimized plan into the session state. Although storing the - // optimized plan or the physical plan would be more efficient, doing so is - // not currently feasible. This is because `now()` would be optimized to a - // constant value, causing each EXECUTE to yield the same result, which is - // incorrect behavior. 
- self.state.write().store_prepared(name, fields, input)?; + // Optimize the plan without evaluating expressions like now() + let optimizer_context = OptimizerContext::new_with_config_options( + Arc::clone(self.state().config().options()), + ) + .without_query_execution_start_time(); + let plan = self.state().optimizer().optimize( + Arc::unwrap_or_clone(input), + &optimizer_context, + |_1, _2| {}, + )?; + self.state + .write() + .store_prepared(name, fields, Arc::new(plan))?; self.return_empty_dataframe() } LogicalPlan::Statement(Statement::Execute(execute)) => { @@ -1160,20 +1167,20 @@ impl SessionContext { let mut builder = RuntimeEnvBuilder::from_runtime_env(state.runtime_env()); builder = match key { "memory_limit" => { - let memory_limit = Self::parse_memory_limit(value)?; + let memory_limit = Self::parse_capacity_limit(variable, value)?; builder.with_memory_limit(memory_limit, 1.0) } "max_temp_directory_size" => { - let directory_size = Self::parse_memory_limit(value)?; + let directory_size = Self::parse_capacity_limit(variable, value)?; builder.with_max_temp_directory_size(directory_size as u64) } "temp_directory" => builder.with_temp_file_path(value), "metadata_cache_limit" => { - let limit = Self::parse_memory_limit(value)?; + let limit = Self::parse_capacity_limit(variable, value)?; builder.with_metadata_cache_limit(limit) } "list_files_cache_limit" => { - let limit = Self::parse_memory_limit(value)?; + let limit = Self::parse_capacity_limit(variable, value)?; builder.with_object_list_cache_limit(limit) } "list_files_cache_ttl" => { @@ -1245,11 +1252,23 @@ impl SessionContext { /// (1.5 * 1024.0 * 1024.0 * 1024.0) as usize /// ); /// ``` + #[deprecated( + since = "53.0.0", + note = "please use `parse_capacity_limit` function instead." 
+ )] pub fn parse_memory_limit(limit: &str) -> Result { + if limit.trim().is_empty() { + return Err(plan_datafusion_err!("Empty limit value found!")); + } let (number, unit) = limit.split_at(limit.len() - 1); let number: f64 = number.parse().map_err(|_| { plan_datafusion_err!("Failed to parse number from memory limit '{limit}'") })?; + if number.is_sign_negative() || number.is_infinite() { + return Err(plan_datafusion_err!( + "Limit value should be positive finite number" + )); + } match unit { "K" => Ok((number * 1024.0) as usize), @@ -1259,6 +1278,51 @@ impl SessionContext { } } + /// Parse capacity limit from string to number of bytes by allowing units: K, M and G. + /// Supports formats like '1.5G', '100M', '512K' + /// + /// # Examples + /// ``` + /// use datafusion::execution::context::SessionContext; + /// + /// assert_eq!( + /// SessionContext::parse_capacity_limit("datafusion.runtime.memory_limit", "1M").unwrap(), + /// 1024 * 1024 + /// ); + /// assert_eq!( + /// SessionContext::parse_capacity_limit("datafusion.runtime.memory_limit", "1.5G").unwrap(), + /// (1.5 * 1024.0 * 1024.0 * 1024.0) as usize + /// ); + /// ``` + pub fn parse_capacity_limit(config_name: &str, limit: &str) -> Result { + if limit.trim().is_empty() { + return Err(plan_datafusion_err!( + "Empty limit value found for '{config_name}'" + )); + } + let (number, unit) = limit.split_at(limit.len() - 1); + let number: f64 = number.parse().map_err(|_| { + plan_datafusion_err!( + "Failed to parse number from '{config_name}', limit '{limit}'" + ) + })?; + if number.is_sign_negative() || number.is_infinite() { + return Err(plan_datafusion_err!( + "Limit value should be positive finite number for '{config_name}'" + )); + } + + match unit { + "K" => Ok((number * 1024.0) as usize), + "M" => Ok((number * 1024.0 * 1024.0) as usize), + "G" => Ok((number * 1024.0 * 1024.0 * 1024.0) as usize), + _ => plan_err!( + "Unsupported unit '{unit}' in '{config_name}', limit '{limit}'. 
\ + Unit must be one of: 'K', 'M', 'G'" + ), + } + } + fn parse_duration(duration: &str) -> Result { let mut minutes = None; let mut seconds = None; @@ -1315,7 +1379,7 @@ impl SessionContext { let table = table_ref.table().to_owned(); let maybe_schema = { let state = self.state.read(); - let resolved = state.resolve_table_ref(table_ref); + let resolved = state.resolve_table_ref(table_ref.clone()); state .catalog_list() .catalog(&resolved.catalog) @@ -1327,6 +1391,11 @@ impl SessionContext { && table_provider.table_type() == table_type { schema.deregister_table(&table)?; + if table_type == TableType::Base + && let Some(lfc) = self.runtime_env().cache_manager.get_list_files_cache() + { + lfc.drop_table_entries(&Some(table_ref))?; + } return Ok(true); } @@ -1394,7 +1463,12 @@ impl SessionContext { })?; let state = self.state.read(); - let context = SimplifyContext::new(state.execution_props()); + let context = SimplifyContext::default() + .with_schema(Arc::clone(prepared.plan.schema())) + .with_config_options(Arc::clone(state.config_options())) + .with_query_execution_start_time( + state.execution_props().query_execution_start_time, + ); let simplifier = ExprSimplifier::new(context); // Only allow literals as parameters for now. 
@@ -2169,7 +2243,7 @@ mod tests { // configure with same memory / disk manager let memory_pool = ctx1.runtime_env().memory_pool.clone(); - let mut reservation = MemoryConsumer::new("test").register(&memory_pool); + let reservation = MemoryConsumer::new("test").register(&memory_pool); reservation.grow(100); let disk_manager = ctx1.runtime_env().disk_manager.clone(); @@ -2742,4 +2816,71 @@ mod tests { assert!(have.is_err()); } } + + #[test] + fn test_parse_memory_limit() { + // Valid memory_limit + for (limit, want) in [ + ("1.5K", (1.5 * 1024.0) as usize), + ("2M", (2f64 * 1024.0 * 1024.0) as usize), + ("1G", (1f64 * 1024.0 * 1024.0 * 1024.0) as usize), + ] { + #[expect(deprecated)] + let have = SessionContext::parse_memory_limit(limit).unwrap(); + assert_eq!(want, have); + } + + // Invalid memory_limit + for limit in [ + "1B", + "1T", + "", + " ", + "XYZG", + "-1G", + "infG", + "-infG", + "G", + "1024B", + "invalid_size", + ] { + #[expect(deprecated)] + let have = SessionContext::parse_memory_limit(limit); + assert!(have.is_err()); + } + } + + #[test] + fn test_parse_capacity_limit() { + const MEMORY_LIMIT: &str = "datafusion.runtime.memory_limit"; + + // Valid capacity_limit + for (limit, want) in [ + ("1.5K", (1.5 * 1024.0) as usize), + ("2M", (2f64 * 1024.0 * 1024.0) as usize), + ("1G", (1f64 * 1024.0 * 1024.0 * 1024.0) as usize), + ] { + let have = SessionContext::parse_capacity_limit(MEMORY_LIMIT, limit).unwrap(); + assert_eq!(want, have); + } + + // Invalid capacity_limit + for limit in [ + "1B", + "1T", + "", + " ", + "XYZG", + "-1G", + "infG", + "-infG", + "G", + "1024B", + "invalid_size", + ] { + let have = SessionContext::parse_capacity_limit(MEMORY_LIMIT, limit); + assert!(have.is_err()); + assert!(have.unwrap_err().to_string().contains(MEMORY_LIMIT)); + } + } } diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 6a9ebcdf5125..9560616c1b6d 100644 --- 
a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -57,10 +57,8 @@ use datafusion_expr::planner::ExprPlanner; #[cfg(feature = "sql")] use datafusion_expr::planner::{RelationPlanner, TypePlanner}; use datafusion_expr::registry::{FunctionRegistry, SerializerRegistry}; -use datafusion_expr::simplify::SimplifyInfo; -use datafusion_expr::{ - AggregateUDF, Explain, Expr, ExprSchemable, LogicalPlan, ScalarUDF, WindowUDF, -}; +use datafusion_expr::simplify::SimplifyContext; +use datafusion_expr::{AggregateUDF, Explain, Expr, LogicalPlan, ScalarUDF, WindowUDF}; use datafusion_optimizer::simplify_expressions::ExprSimplifier; use datafusion_optimizer::{ Analyzer, AnalyzerRule, Optimizer, OptimizerConfig, OptimizerRule, @@ -744,13 +742,18 @@ impl SessionState { expr: Expr, df_schema: &DFSchema, ) -> datafusion_common::Result> { - let simplifier = - ExprSimplifier::new(SessionSimplifyProvider::new(self, df_schema)); + let config_options = self.config_options(); + let simplify_context = SimplifyContext::default() + .with_schema(Arc::new(df_schema.clone())) + .with_config_options(Arc::clone(config_options)) + .with_query_execution_start_time( + self.execution_props().query_execution_start_time, + ); + let simplifier = ExprSimplifier::new(simplify_context); // apply type coercion here to ensure types match let mut expr = simplifier.coerce(expr, df_schema)?; // rewrite Exprs to functions if necessary - let config_options = self.config_options(); for rewrite in self.analyzer.function_rewrites() { expr = expr .transform_up(|expr| rewrite.rewrite(expr, df_schema, config_options))? 
@@ -875,11 +878,8 @@ impl SessionState { &self.catalog_list } - /// set the catalog list - pub(crate) fn register_catalog_list( - &mut self, - catalog_list: Arc, - ) { + /// Set the catalog list + pub fn register_catalog_list(&mut self, catalog_list: Arc) { self.catalog_list = catalog_list; } @@ -969,6 +969,7 @@ impl SessionState { /// be used for all values unless explicitly provided. /// /// See example on [`SessionState`] +#[derive(Clone)] pub struct SessionStateBuilder { session_id: Option, analyzer: Option, @@ -1834,12 +1835,20 @@ impl ContextProvider for SessionContextProvider<'_> { .get(name) .cloned() .ok_or_else(|| plan_datafusion_err!("table function '{name}' not found"))?; - let dummy_schema = DFSchema::empty(); - let simplifier = - ExprSimplifier::new(SessionSimplifyProvider::new(self.state, &dummy_schema)); + let simplify_context = SimplifyContext::default() + .with_config_options(Arc::clone(self.state.config_options())) + .with_query_execution_start_time( + self.state.execution_props().query_execution_start_time, + ); + let simplifier = ExprSimplifier::new(simplify_context); + let schema = DFSchema::empty(); let args = args .into_iter() - .map(|arg| simplifier.simplify(arg)) + .map(|arg| { + simplifier + .coerce(arg, &schema) + .and_then(|e| simplifier.simplify(e)) + }) .collect::>>()?; let provider = tbl_func.create_table_provider(&args)?; @@ -2063,7 +2072,7 @@ impl datafusion_execution::TaskContextProvider for SessionState { } impl OptimizerConfig for SessionState { - fn query_execution_start_time(&self) -> DateTime { + fn query_execution_start_time(&self) -> Option> { self.execution_props.query_execution_start_time } @@ -2115,35 +2124,6 @@ impl QueryPlanner for DefaultQueryPlanner { } } -struct SessionSimplifyProvider<'a> { - state: &'a SessionState, - df_schema: &'a DFSchema, -} - -impl<'a> SessionSimplifyProvider<'a> { - fn new(state: &'a SessionState, df_schema: &'a DFSchema) -> Self { - Self { state, df_schema } - } -} - -impl SimplifyInfo for 
SessionSimplifyProvider<'_> { - fn is_boolean_type(&self, expr: &Expr) -> datafusion_common::Result { - Ok(expr.get_type(self.df_schema)? == DataType::Boolean) - } - - fn nullable(&self, expr: &Expr) -> datafusion_common::Result { - expr.nullable(self.df_schema) - } - - fn execution_props(&self) -> &ExecutionProps { - self.state.execution_props() - } - - fn get_data_type(&self, expr: &Expr) -> datafusion_common::Result { - expr.get_type(self.df_schema) - } -} - #[derive(Debug)] pub(crate) struct PreparedPlan { /// Data types of the parameters diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index e83934a8e281..349eee5592ab 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -#![deny(clippy::allow_attributes)] #![doc( html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" @@ -1181,8 +1180,56 @@ doc_comment::doctest!( #[cfg(doctest)] doc_comment::doctest!( - "../../../docs/source/library-user-guide/upgrading.md", - library_user_guide_upgrading + "../../../docs/source/library-user-guide/upgrading/46.0.0.md", + library_user_guide_upgrading_46_0_0 +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/upgrading/47.0.0.md", + library_user_guide_upgrading_47_0_0 +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/upgrading/48.0.0.md", + library_user_guide_upgrading_48_0_0 +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/upgrading/48.0.1.md", + library_user_guide_upgrading_48_0_1 +); + +#[cfg(doctest)] +doc_comment::doctest!( + 
"../../../docs/source/library-user-guide/upgrading/49.0.0.md", + library_user_guide_upgrading_49_0_0 +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/upgrading/50.0.0.md", + library_user_guide_upgrading_50_0_0 +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/upgrading/51.0.0.md", + library_user_guide_upgrading_51_0_0 +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/upgrading/52.0.0.md", + library_user_guide_upgrading_52_0_0 +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/upgrading/53.0.0.md", + library_user_guide_upgrading_53_0_0 ); #[cfg(doctest)] diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index cc7d534776d7..12406b6c29dd 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -18,12 +18,12 @@ //! Planner for [`LogicalPlan`] to [`ExecutionPlan`] use std::borrow::Cow; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; use crate::datasource::file_format::file_type_to_format; use crate::datasource::listing::ListingTableUrl; -use crate::datasource::physical_plan::FileSinkConfig; +use crate::datasource::physical_plan::{FileOutputMode, FileSinkConfig}; use crate::datasource::{DefaultTableSource, source_as_provider}; use crate::error::{DataFusionError, Result}; use crate::execution::context::{ExecutionProps, SessionState}; @@ -39,7 +39,7 @@ use crate::physical_expr::{create_physical_expr, create_physical_exprs}; use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use crate::physical_plan::analyze::AnalyzeExec; use crate::physical_plan::explain::ExplainExec; -use crate::physical_plan::filter::FilterExec; +use crate::physical_plan::filter::FilterExecBuilder; use crate::physical_plan::joins::utils as join_utils; use 
crate::physical_plan::joins::{ CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode, SortMergeJoinExec, @@ -69,8 +69,8 @@ use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeRecursion, TreeNodeVisitor, }; use datafusion_common::{ - DFSchema, ScalarValue, exec_err, internal_datafusion_err, internal_err, not_impl_err, - plan_err, + DFSchema, DFSchemaRef, ScalarValue, exec_err, internal_datafusion_err, internal_err, + not_impl_err, plan_err, }; use datafusion_common::{ TableReference, assert_eq_or_internal_err, assert_or_internal_err, @@ -84,7 +84,7 @@ use datafusion_expr::expr::{ }; use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; -use datafusion_expr::utils::split_conjunction; +use datafusion_expr::utils::{expr_to_columns, split_conjunction}; use datafusion_expr::{ Analyze, BinaryExpr, DescribeTable, DmlStatement, Explain, ExplainFormat, Extension, FetchType, Filter, JoinType, Operator, RecursiveQuery, SkipType, StringifiedPlan, @@ -157,6 +157,80 @@ pub trait ExtensionPlanner { physical_inputs: &[Arc], session_state: &SessionState, ) -> Result>>; + + /// Create a physical plan for a [`LogicalPlan::TableScan`]. + /// + /// This is useful for planning valid [`TableSource`]s that are not [`TableProvider`]s. 
+ /// + /// Returns: + /// * `Ok(Some(plan))` if the planner knows how to plan the `scan` + /// * `Ok(None)` if the planner does not know how to plan the `scan` and wants to delegate the planning to another [`ExtensionPlanner`] + /// * `Err` if the planner knows how to plan the `scan` but errors while doing so + /// + /// # Example + /// + /// ```rust,ignore + /// use std::sync::Arc; + /// use datafusion::physical_plan::ExecutionPlan; + /// use datafusion::logical_expr::TableScan; + /// use datafusion::execution::context::SessionState; + /// use datafusion::error::Result; + /// use datafusion_physical_planner::{ExtensionPlanner, PhysicalPlanner}; + /// use async_trait::async_trait; + /// + /// // Your custom table source type + /// struct MyCustomTableSource { /* ... */ } + /// + /// // Your custom execution plan + /// struct MyCustomExec { /* ... */ } + /// + /// struct MyExtensionPlanner; + /// + /// #[async_trait] + /// impl ExtensionPlanner for MyExtensionPlanner { + /// async fn plan_extension( + /// &self, + /// _planner: &dyn PhysicalPlanner, + /// _node: &dyn UserDefinedLogicalNode, + /// _logical_inputs: &[&LogicalPlan], + /// _physical_inputs: &[Arc], + /// _session_state: &SessionState, + /// ) -> Result>> { + /// Ok(None) + /// } + /// + /// async fn plan_table_scan( + /// &self, + /// _planner: &dyn PhysicalPlanner, + /// scan: &TableScan, + /// _session_state: &SessionState, + /// ) -> Result>> { + /// // Check if this is your custom table source + /// if scan.source.as_any().is::() { + /// // Create a custom execution plan for your table source + /// let exec = MyCustomExec::new( + /// scan.table_name.clone(), + /// Arc::clone(scan.projected_schema.inner()), + /// ); + /// Ok(Some(Arc::new(exec))) + /// } else { + /// // Return None to let other extension planners handle it + /// Ok(None) + /// } + /// } + /// } + /// ``` + /// + /// [`TableSource`]: datafusion_expr::TableSource + /// [`TableProvider`]: datafusion_catalog::TableProvider + async fn 
plan_table_scan( + &self, + _planner: &dyn PhysicalPlanner, + _scan: &TableScan, + _session_state: &SessionState, + ) -> Result>> { + Ok(None) + } } /// Default single node physical query planner that converts a @@ -278,7 +352,8 @@ struct LogicalNode<'a> { impl DefaultPhysicalPlanner { /// Create a physical planner that uses `extension_planners` to - /// plan user-defined logical nodes [`LogicalPlan::Extension`]. + /// plan user-defined logical nodes [`LogicalPlan::Extension`] + /// or user-defined table sources in [`LogicalPlan::TableScan`]. /// The planner uses the first [`ExtensionPlanner`] to return a non-`None` /// plan. pub fn with_extension_planners( @@ -287,6 +362,24 @@ impl DefaultPhysicalPlanner { Self { extension_planners } } + fn ensure_schema_matches( + &self, + logical_schema: &DFSchemaRef, + physical_plan: &Arc, + context: &str, + ) -> Result<()> { + if !logical_schema.matches_arrow_schema(&physical_plan.schema()) { + return plan_err!( + "{} created an ExecutionPlan with mismatched schema. \ + LogicalPlan schema: {:?}, ExecutionPlan schema: {:?}", + context, + logical_schema, + physical_plan.schema() + ); + } + Ok(()) + } + /// Create a physical plan from a logical plan async fn create_initial_plan( &self, @@ -455,25 +548,53 @@ impl DefaultPhysicalPlanner { ) -> Result> { let exec_node: Arc = match node { // Leaves (no children) - LogicalPlan::TableScan(TableScan { - source, - projection, - filters, - fetch, - .. 
- }) => { - let source = source_as_provider(source)?; - // Remove all qualifiers from the scan as the provider - // doesn't know (nor should care) how the relation was - // referred to in the query - let filters = unnormalize_cols(filters.iter().cloned()); - let filters_vec = filters.into_iter().collect::>(); - let opts = ScanArgs::default() - .with_projection(projection.as_deref()) - .with_filters(Some(&filters_vec)) - .with_limit(*fetch); - let res = source.scan_with_args(session_state, opts).await?; - Arc::clone(res.plan()) + LogicalPlan::TableScan(scan) => { + let TableScan { + source, + projection, + filters, + fetch, + projected_schema, + .. + } = scan; + + if let Ok(source) = source_as_provider(source) { + // Remove all qualifiers from the scan as the provider + // doesn't know (nor should care) how the relation was + // referred to in the query + let filters = unnormalize_cols(filters.iter().cloned()); + let filters_vec = filters.into_iter().collect::>(); + let opts = ScanArgs::default() + .with_projection(projection.as_deref()) + .with_filters(Some(&filters_vec)) + .with_limit(*fetch); + let res = source.scan_with_args(session_state, opts).await?; + Arc::clone(res.plan()) + } else { + let mut maybe_plan = None; + for planner in &self.extension_planners { + if maybe_plan.is_some() { + break; + } + + maybe_plan = + planner.plan_table_scan(self, scan, session_state).await?; + } + + let plan = match maybe_plan { + Some(plan) => plan, + None => { + return plan_err!( + "No installed planner was able to plan TableScan for custom TableSource: {:?}", + scan.table_name + ); + } + }; + let context = + format!("Extension planner for table scan {}", scan.table_name); + self.ensure_schema_matches(projected_schema, &plan, &context)?; + plan + } } LogicalPlan::Values(Values { values, schema }) => { let exprs = values @@ -549,8 +670,30 @@ impl DefaultPhysicalPlanner { } }; + // Parse single_file_output option if explicitly set + let file_output_mode = match 
source_option_tuples + .get("single_file_output") + .map(|v| v.trim()) + { + None => FileOutputMode::Automatic, + Some("true") => FileOutputMode::SingleFile, + Some("false") => FileOutputMode::Directory, + Some(value) => { + return Err(DataFusionError::Configuration(format!( + "provided value for 'single_file_output' was not recognized: \"{value}\"" + ))); + } + }; + + // Filter out sink-related options that are not format options + let format_options: HashMap = source_option_tuples + .iter() + .filter(|(k, _)| k.as_str() != "single_file_output") + .map(|(k, v)| (k.clone(), v.clone())) + .collect(); + let sink_format = file_type_to_format(file_type)? - .create(session_state, source_option_tuples)?; + .create(session_state, &format_options)?; // Determine extension based on format extension and compression let file_extension = match sink_format.compression_type() { @@ -571,6 +714,7 @@ impl DefaultPhysicalPlanner { insert_op: InsertOp::Append, keep_partition_by_columns, file_extension, + file_output_mode, }; let ordering = input_exec.properties().output_ordering().cloned(); @@ -613,7 +757,7 @@ impl DefaultPhysicalPlanner { if let Some(provider) = target.as_any().downcast_ref::() { - let filters = extract_dml_filters(input)?; + let filters = extract_dml_filters(input, table_name)?; provider .table_provider .delete_from(session_state, filters) @@ -639,7 +783,7 @@ impl DefaultPhysicalPlanner { { // For UPDATE, the assignments are encoded in the projection of input // We pass the filters and let the provider handle the projection - let filters = extract_dml_filters(input)?; + let filters = extract_dml_filters(input, table_name)?; // Extract assignments from the projection in input plan let assignments = extract_update_assignments(input)?; provider @@ -655,6 +799,30 @@ impl DefaultPhysicalPlanner { ); } } + LogicalPlan::Dml(DmlStatement { + table_name, + target, + op: WriteOp::Truncate, + .. 
+ }) => { + if let Some(provider) = + target.as_any().downcast_ref::() + { + provider + .table_provider + .truncate(session_state) + .await + .map_err(|e| { + e.context(format!( + "TRUNCATE operation on table '{table_name}'" + )) + })? + } else { + return exec_err!( + "Table source can't be downcasted to DefaultTableSource" + ); + } + } LogicalPlan::Window(Window { window_expr, .. }) => { assert_or_internal_err!( !window_expr.is_empty(), @@ -938,8 +1106,12 @@ impl DefaultPhysicalPlanner { input_schema.as_arrow(), )? { PlanAsyncExpr::Sync(PlannedExprResult::Expr(runtime_expr)) => { - FilterExec::try_new(Arc::clone(&runtime_expr[0]), physical_input)? - .with_batch_size(session_state.config().batch_size())? + FilterExecBuilder::new( + Arc::clone(&runtime_expr[0]), + physical_input, + ) + .with_batch_size(session_state.config().batch_size()) + .build()? } PlanAsyncExpr::Async( async_map, @@ -949,16 +1121,17 @@ impl DefaultPhysicalPlanner { async_map.async_exprs, physical_input, )?; - FilterExec::try_new( + FilterExecBuilder::new( Arc::clone(&runtime_expr[0]), Arc::new(async_exec), - )? + ) // project the output columns excluding the async functions // The async functions are always appended to the end of the schema. - .with_projection(Some( - (0..input.schema().fields().len()).collect(), + .apply_projection(Some( + (0..input.schema().fields().len()).collect::>(), ))? - .with_batch_size(session_state.config().batch_size())? + .with_batch_size(session_state.config().batch_size()) + .build()? } _ => { return internal_err!( @@ -1091,6 +1264,7 @@ impl DefaultPhysicalPlanner { filter, join_type, null_equality, + null_aware, schema: join_schema, .. 
}) => { @@ -1342,7 +1516,7 @@ impl DefaultPhysicalPlanner { // TODO: Allow PWMJ to deal with residual equijoin conditions let join: Arc = if join_on.is_empty() { - if join_filter.is_none() && matches!(join_type, JoinType::Inner) { + if join_filter.is_none() && *join_type == JoinType::Inner { // cross join if there is no join conditions and no join filter set Arc::new(CrossJoinExec::new(physical_left, physical_right)) } else if num_range_filters == 1 @@ -1417,9 +1591,7 @@ impl DefaultPhysicalPlanner { let left_side = side_of(lhs_logical)?; let right_side = side_of(rhs_logical)?; - if matches!(left_side, Side::Both) - || matches!(right_side, Side::Both) - { + if left_side == Side::Both || right_side == Side::Both { return Ok(Arc::new(NestedLoopJoinExec::try_new( physical_left, physical_right, @@ -1487,6 +1659,8 @@ impl DefaultPhysicalPlanner { } else if session_state.config().target_partitions() > 1 && session_state.config().repartition_joins() && prefer_hash_join + && !*null_aware + // Null-aware joins must use CollectLeft { Arc::new(HashJoinExec::try_new( physical_left, @@ -1497,6 +1671,7 @@ impl DefaultPhysicalPlanner { None, PartitionMode::Auto, *null_equality, + *null_aware, )?) } else { Arc::new(HashJoinExec::try_new( @@ -1508,6 +1683,7 @@ impl DefaultPhysicalPlanner { None, PartitionMode::CollectLeft, *null_equality, + *null_aware, )?) }; @@ -1561,20 +1737,9 @@ impl DefaultPhysicalPlanner { ), }?; - // Ensure the ExecutionPlan's schema matches the - // declared logical schema to catch and warn about - // logic errors when creating user defined plans. - if !node.schema().matches_arrow_schema(&plan.schema()) { - return plan_err!( - "Extension planner for {:?} created an ExecutionPlan with mismatched schema. 
\ - LogicalPlan schema: {:?}, ExecutionPlan schema: {:?}", - node, - node.schema(), - plan.schema() - ); - } else { - plan - } + let context = format!("Extension planner for {node:?}"); + self.ensure_schema_matches(node.schema(), &plan, &context)?; + plan } // Other @@ -1902,24 +2067,149 @@ fn get_physical_expr_pair( } /// Extract filter predicates from a DML input plan (DELETE/UPDATE). -/// Walks the logical plan tree and collects Filter predicates, -/// splitting AND conjunctions into individual expressions. -/// Column qualifiers are stripped so expressions can be evaluated against -/// the TableProvider's schema. /// -fn extract_dml_filters(input: &Arc) -> Result> { +/// Walks the logical plan tree and collects Filter predicates and any filters +/// pushed down into TableScan nodes, splitting AND conjunctions into individual expressions. +/// +/// For UPDATE...FROM queries involving multiple tables, this function only extracts predicates +/// that reference the target table. Filters from source table scans are excluded to prevent +/// incorrect filter semantics. +/// +/// Column qualifiers are stripped so expressions can be evaluated against the TableProvider's +/// schema. Deduplication is performed because filters may appear in both Filter nodes and +/// TableScan.filters when the optimizer performs partial (Inexact) filter pushdown. +/// +/// # Parameters +/// - `input`: The logical plan tree to extract filters from (typically a DELETE or UPDATE plan) +/// - `target`: The target table reference to scope filter extraction (prevents multi-table filter leakage) +/// +/// # Returns +/// A vector of unqualified filter expressions that can be passed to the TableProvider for execution. +/// Returns an empty vector if no applicable filters are found. 
+/// +fn extract_dml_filters( + input: &Arc, + target: &TableReference, +) -> Result> { let mut filters = Vec::new(); + let mut allowed_refs = vec![target.clone()]; + + // First pass: collect any alias references to the target table + input.apply(|node| { + if let LogicalPlan::SubqueryAlias(alias) = node + // Check if this alias points to the target table + && let LogicalPlan::TableScan(scan) = alias.input.as_ref() + && scan.table_name.resolved_eq(target) + { + allowed_refs.push(TableReference::bare(alias.alias.to_string())); + } + Ok(TreeNodeRecursion::Continue) + })?; input.apply(|node| { - if let LogicalPlan::Filter(filter) = node { - // Split AND predicates into individual expressions - filters.extend(split_conjunction(&filter.predicate).into_iter().cloned()); + match node { + LogicalPlan::Filter(filter) => { + // Split AND predicates into individual expressions + for predicate in split_conjunction(&filter.predicate) { + if predicate_is_on_target_multi(predicate, &allowed_refs)? { + filters.push(predicate.clone()); + } + } + } + LogicalPlan::TableScan(TableScan { + table_name, + filters: scan_filters, + .. + }) => { + // Only extract filters from the target table scan. + // This prevents incorrect filter extraction in UPDATE...FROM scenarios + // where multiple table scans may have filters. 
+ if table_name.resolved_eq(target) { + for filter in scan_filters { + filters.extend(split_conjunction(filter).into_iter().cloned()); + } + } + } + // Plans without filter information + LogicalPlan::EmptyRelation(_) + | LogicalPlan::Values(_) + | LogicalPlan::DescribeTable(_) + | LogicalPlan::Explain(_) + | LogicalPlan::Analyze(_) + | LogicalPlan::Distinct(_) + | LogicalPlan::Extension(_) + | LogicalPlan::Statement(_) + | LogicalPlan::Dml(_) + | LogicalPlan::Ddl(_) + | LogicalPlan::Copy(_) + | LogicalPlan::Unnest(_) + | LogicalPlan::RecursiveQuery(_) => { + // No filters to extract from leaf/meta plans + } + // Plans with inputs (may contain filters in children) + LogicalPlan::Projection(_) + | LogicalPlan::SubqueryAlias(_) + | LogicalPlan::Limit(_) + | LogicalPlan::Sort(_) + | LogicalPlan::Union(_) + | LogicalPlan::Join(_) + | LogicalPlan::Repartition(_) + | LogicalPlan::Aggregate(_) + | LogicalPlan::Window(_) + | LogicalPlan::Subquery(_) => { + // Filter information may appear in child nodes; continue traversal + // to extract filters from Filter/TableScan nodes deeper in the plan + } } Ok(TreeNodeRecursion::Continue) })?; - // Strip table qualifiers from column references - filters.into_iter().map(strip_column_qualifiers).collect() + // Strip qualifiers and deduplicate. This ensures: + // 1. Only target-table predicates are retained from Filter nodes + // 2. Qualifiers stripped for TableProvider compatibility + // 3. Duplicates removed (from Filter nodes + TableScan.filters) + // + // Deduplication is necessary because filters may appear in both Filter nodes + // and TableScan.filters when the optimizer performs partial (Inexact) pushdown. 
+ let mut seen_filters = HashSet::new(); + filters + .into_iter() + .try_fold(Vec::new(), |mut deduped, filter| { + let unqualified = strip_column_qualifiers(filter).map_err(|e| { + e.context(format!( + "Failed to strip column qualifiers for DML filter on table '{target}'" + )) + })?; + if seen_filters.insert(unqualified.clone()) { + deduped.push(unqualified); + } + Ok(deduped) + }) +} + +/// Determine whether a predicate references only columns from the target table +/// or its aliases. +/// +/// Columns may be qualified with the target table name or any of its aliases. +/// Unqualified columns are also accepted as they implicitly belong to the target table. +fn predicate_is_on_target_multi( + expr: &Expr, + allowed_refs: &[TableReference], +) -> Result { + let mut columns = HashSet::new(); + expr_to_columns(expr, &mut columns)?; + + // Short-circuit on first mismatch: returns false if any column references a table not in allowed_refs. + // Columns are accepted if: + // 1. They are unqualified (no relation specified), OR + // 2. Their relation matches one of the allowed table references using resolved equality + Ok(!columns.iter().any(|column| { + column.relation.as_ref().is_some_and(|relation| { + !allowed_refs + .iter() + .any(|allowed| relation.resolved_eq(allowed)) + }) + })) } /// Strip table qualifiers from column references in an expression. @@ -2719,7 +3009,7 @@ impl<'a> OptimizationInvariantChecker<'a> { && !is_allowed_schema_change(previous_schema.as_ref(), plan.schema().as_ref()) { internal_err!( - "PhysicalOptimizer rule '{}' failed. Schema mismatch. Expected original schema: {:?}, got new schema: {:?}", + "PhysicalOptimizer rule '{}' failed. Schema mismatch. 
Expected original schema: {}, got new schema: {}", self.rule.name(), previous_schema, plan.schema() @@ -2834,7 +3124,9 @@ mod tests { use datafusion_execution::TaskContext; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::builder::subquery_alias; - use datafusion_expr::{LogicalPlanBuilder, UserDefinedLogicalNodeCore, col, lit}; + use datafusion_expr::{ + LogicalPlanBuilder, TableSource, UserDefinedLogicalNodeCore, col, lit, + }; use datafusion_functions_aggregate::count::count_all; use datafusion_functions_aggregate::expr_fn::sum; use datafusion_physical_expr::EquivalenceProperties; @@ -3496,12 +3788,12 @@ mod tests { assert!( stringified_plans .iter() - .any(|p| matches!(p.plan_type, PlanType::FinalLogicalPlan)) + .any(|p| p.plan_type == PlanType::FinalLogicalPlan) ); assert!( stringified_plans .iter() - .any(|p| matches!(p.plan_type, PlanType::InitialPhysicalPlan)) + .any(|p| p.plan_type == PlanType::InitialPhysicalPlan) ); assert!( stringified_plans.iter().any(|p| matches!( @@ -3512,7 +3804,7 @@ mod tests { assert!( stringified_plans .iter() - .any(|p| matches!(p.plan_type, PlanType::FinalPhysicalPlan)) + .any(|p| p.plan_type == PlanType::FinalPhysicalPlan) ); } else { panic!( @@ -3656,13 +3948,15 @@ mod tests { #[derive(Debug)] struct NoOpExecutionPlan { - cache: PlanProperties, + cache: Arc, } impl NoOpExecutionPlan { fn new(schema: SchemaRef) -> Self { let cache = Self::compute_properties(schema); - Self { cache } + Self { + cache: Arc::new(cache), + } } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. 
@@ -3700,7 +3994,7 @@ mod tests { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -3854,7 +4148,7 @@ digraph { fn children(&self) -> Vec<&Arc> { self.0.iter().collect::>() } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { unimplemented!() } fn execute( @@ -3903,7 +4197,7 @@ digraph { fn children(&self) -> Vec<&Arc> { unimplemented!() } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { unimplemented!() } fn execute( @@ -4024,7 +4318,7 @@ digraph { fn children(&self) -> Vec<&Arc> { vec![] } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { unimplemented!() } fn execute( @@ -4356,4 +4650,76 @@ digraph { assert_contains!(&err_str, "field nullability at index"); assert_contains!(&err_str, "field metadata at index"); } + + #[derive(Debug)] + struct MockTableSource { + schema: SchemaRef, + } + + impl TableSource for MockTableSource { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + } + + struct MockTableScanExtensionPlanner; + + #[async_trait] + impl ExtensionPlanner for MockTableScanExtensionPlanner { + async fn plan_extension( + &self, + _planner: &dyn PhysicalPlanner, + _node: &dyn UserDefinedLogicalNode, + _logical_inputs: &[&LogicalPlan], + _physical_inputs: &[Arc], + _session_state: &SessionState, + ) -> Result>> { + Ok(None) + } + + async fn plan_table_scan( + &self, + _planner: &dyn PhysicalPlanner, + scan: &TableScan, + _session_state: &SessionState, + ) -> Result>> { + if scan.source.as_any().is::() { + Ok(Some(Arc::new(EmptyExec::new(Arc::clone( + scan.projected_schema.inner(), + ))))) + } else { + Ok(None) + } + } + } + + #[tokio::test] + async fn test_table_scan_extension_planner() { + let session_state = make_session_state(); + let planner = Arc::new(MockTableScanExtensionPlanner); + let physical_planner = + 
DefaultPhysicalPlanner::with_extension_planners(vec![planner]); + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + let table_source = Arc::new(MockTableSource { + schema: Arc::clone(&schema), + }); + let logical_plan = LogicalPlanBuilder::scan("test", table_source, None) + .unwrap() + .build() + .unwrap(); + + let plan = physical_planner + .create_physical_plan(&logical_plan, &session_state) + .await + .unwrap(); + + assert_eq!(plan.schema(), schema); + assert!(plan.as_any().is::()); + } } diff --git a/datafusion/core/src/prelude.rs b/datafusion/core/src/prelude.rs index 50e4a2649c92..31d9d7eb471f 100644 --- a/datafusion/core/src/prelude.rs +++ b/datafusion/core/src/prelude.rs @@ -29,7 +29,7 @@ pub use crate::dataframe; pub use crate::dataframe::DataFrame; pub use crate::execution::context::{SQLOptions, SessionConfig, SessionContext}; pub use crate::execution::options::{ - AvroReadOptions, CsvReadOptions, NdJsonReadOptions, ParquetReadOptions, + AvroReadOptions, CsvReadOptions, JsonReadOptions, ParquetReadOptions, }; pub use datafusion_common::Column; diff --git a/datafusion/core/src/test/object_store.rs b/datafusion/core/src/test/object_store.rs index a0438e3d74ab..62c6699f8fcd 100644 --- a/datafusion/core/src/test/object_store.rs +++ b/datafusion/core/src/test/object_store.rs @@ -27,6 +27,7 @@ use crate::{ prelude::SessionContext, }; use futures::{FutureExt, stream::BoxStream}; +use object_store::{CopyOptions, ObjectStoreExt}; use std::{ fmt::{Debug, Display, Formatter}, sync::Arc, @@ -130,39 +131,40 @@ impl ObjectStore for BlockingObjectStore { location: &Path, options: GetOptions, ) -> object_store::Result { - self.inner.get_opts(location, options).await - } - - async fn head(&self, location: &Path) -> object_store::Result { - println!( - "{} received head call for {location}", - BlockingObjectStore::NAME - ); - // Wait until the expected number of concurrent calls is reached, but timeout after 1 second to avoid hanging 
failing tests. - let wait_result = timeout(Duration::from_secs(1), self.barrier.wait()).await; - match wait_result { - Ok(_) => println!( - "{} barrier reached for {location}", + if options.head { + println!( + "{} received head call for {location}", BlockingObjectStore::NAME - ), - Err(_) => { - let error_message = format!( - "{} barrier wait timed out for {location}", + ); + // Wait until the expected number of concurrent calls is reached, but timeout after 1 second to avoid hanging failing tests. + let wait_result = timeout(Duration::from_secs(1), self.barrier.wait()).await; + match wait_result { + Ok(_) => println!( + "{} barrier reached for {location}", BlockingObjectStore::NAME - ); - log::error!("{error_message}"); - return Err(Error::Generic { - store: BlockingObjectStore::NAME, - source: error_message.into(), - }); + ), + Err(_) => { + let error_message = format!( + "{} barrier wait timed out for {location}", + BlockingObjectStore::NAME + ); + log::error!("{error_message}"); + return Err(Error::Generic { + store: BlockingObjectStore::NAME, + source: error_message.into(), + }); + } } } + // Forward the call to the inner object store. 
- self.inner.head(location).await + self.inner.get_opts(location, options).await } - - async fn delete(&self, location: &Path) -> object_store::Result<()> { - self.inner.delete(location).await + fn delete_stream( + &self, + locations: BoxStream<'static, object_store::Result>, + ) -> BoxStream<'static, object_store::Result> { + self.inner.delete_stream(locations) } fn list( @@ -179,15 +181,12 @@ impl ObjectStore for BlockingObjectStore { self.inner.list_with_delimiter(prefix).await } - async fn copy(&self, from: &Path, to: &Path) -> object_store::Result<()> { - self.inner.copy(from, to).await - } - - async fn copy_if_not_exists( + async fn copy_opts( &self, from: &Path, to: &Path, + options: CopyOptions, ) -> object_store::Result<()> { - self.inner.copy_if_not_exists(from, to).await + self.inner.copy_opts(from, to, options).await } } diff --git a/datafusion/core/src/test_util/parquet.rs b/datafusion/core/src/test_util/parquet.rs index 44e884c23a68..dba017f83ba1 100644 --- a/datafusion/core/src/test_util/parquet.rs +++ b/datafusion/core/src/test_util/parquet.rs @@ -160,22 +160,13 @@ impl TestParquetFile { .with_table_parquet_options(parquet_options.clone()), ); let scan_config_builder = - FileScanConfigBuilder::new(self.object_store_url.clone(), source).with_file( - PartitionedFile { - object_meta: self.object_meta.clone(), - partition_values: vec![], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - }, - ); + FileScanConfigBuilder::new(self.object_store_url.clone(), source) + .with_file(PartitionedFile::new_from_meta(self.object_meta.clone())); let df_schema = Arc::clone(&self.schema).to_dfschema_ref()?; // run coercion on the filters to coerce types etc. 
- let props = ExecutionProps::new(); - let context = SimplifyContext::new(&props).with_schema(Arc::clone(&df_schema)); + let context = SimplifyContext::default().with_schema(Arc::clone(&df_schema)); if let Some(filter) = maybe_filter { let simplifier = ExprSimplifier::new(context); let filter = simplifier.coerce(filter, &df_schema).unwrap(); diff --git a/datafusion/core/tests/catalog_listing/pruned_partition_list.rs b/datafusion/core/tests/catalog_listing/pruned_partition_list.rs index f4782ee13c24..8f93dc17dbad 100644 --- a/datafusion/core/tests/catalog_listing/pruned_partition_list.rs +++ b/datafusion/core/tests/catalog_listing/pruned_partition_list.rs @@ -19,7 +19,7 @@ use std::sync::Arc; use arrow_schema::DataType; use futures::{FutureExt, StreamExt as _, TryStreamExt as _}; -use object_store::{ObjectStore as _, memory::InMemory, path::Path}; +use object_store::{ObjectStoreExt, memory::InMemory, path::Path}; use datafusion::execution::SessionStateBuilder; use datafusion_catalog_listing::helpers::{ diff --git a/datafusion/core/tests/custom_sources_cases/dml_planning.rs b/datafusion/core/tests/custom_sources_cases/dml_planning.rs index 84cf97710a90..8c4bae5e98b3 100644 --- a/datafusion/core/tests/custom_sources_cases/dml_planning.rs +++ b/datafusion/core/tests/custom_sources_cases/dml_planning.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Tests for DELETE and UPDATE planning to verify filter and assignment extraction. +//! Tests for DELETE, UPDATE, and TRUNCATE planning to verify filter and assignment extraction. 
use std::any::Any; use std::sync::{Arc, Mutex}; @@ -24,9 +24,13 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use async_trait::async_trait; use datafusion::datasource::{TableProvider, TableType}; use datafusion::error::Result; -use datafusion::execution::context::SessionContext; -use datafusion::logical_expr::Expr; +use datafusion::execution::context::{SessionConfig, SessionContext}; +use datafusion::logical_expr::{ + Expr, LogicalPlan, TableProviderFilterPushDown, TableScan, +}; use datafusion_catalog::Session; +use datafusion_common::ScalarValue; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_physical_plan::ExecutionPlan; use datafusion_physical_plan::empty::EmptyExec; @@ -34,6 +38,8 @@ use datafusion_physical_plan::empty::EmptyExec; struct CaptureDeleteProvider { schema: SchemaRef, received_filters: Arc>>>, + filter_pushdown: TableProviderFilterPushDown, + per_filter_pushdown: Option>, } impl CaptureDeleteProvider { @@ -41,6 +47,32 @@ impl CaptureDeleteProvider { Self { schema, received_filters: Arc::new(Mutex::new(None)), + filter_pushdown: TableProviderFilterPushDown::Unsupported, + per_filter_pushdown: None, + } + } + + fn new_with_filter_pushdown( + schema: SchemaRef, + filter_pushdown: TableProviderFilterPushDown, + ) -> Self { + Self { + schema, + received_filters: Arc::new(Mutex::new(None)), + filter_pushdown, + per_filter_pushdown: None, + } + } + + fn new_with_per_filter_pushdown( + schema: SchemaRef, + per_filter_pushdown: Vec, + ) -> Self { + Self { + schema, + received_filters: Arc::new(Mutex::new(None)), + filter_pushdown: TableProviderFilterPushDown::Unsupported, + per_filter_pushdown: Some(per_filter_pushdown), } } @@ -91,14 +123,29 @@ impl TableProvider for CaptureDeleteProvider { Field::new("count", DataType::UInt64, false), ]))))) } + + fn supports_filters_pushdown( + &self, + filters: &[&Expr], + ) -> Result> { + if let Some(per_filter) = &self.per_filter_pushdown + && per_filter.len() == 
filters.len() + { + return Ok(per_filter.clone()); + } + + Ok(vec![self.filter_pushdown.clone(); filters.len()]) + } } /// A TableProvider that captures filters and assignments passed to update(). -#[allow(clippy::type_complexity)] +#[expect(clippy::type_complexity)] struct CaptureUpdateProvider { schema: SchemaRef, received_filters: Arc>>>, received_assignments: Arc>>>, + filter_pushdown: TableProviderFilterPushDown, + per_filter_pushdown: Option>, } impl CaptureUpdateProvider { @@ -107,6 +154,21 @@ impl CaptureUpdateProvider { schema, received_filters: Arc::new(Mutex::new(None)), received_assignments: Arc::new(Mutex::new(None)), + filter_pushdown: TableProviderFilterPushDown::Unsupported, + per_filter_pushdown: None, + } + } + + fn new_with_filter_pushdown( + schema: SchemaRef, + filter_pushdown: TableProviderFilterPushDown, + ) -> Self { + Self { + schema, + received_filters: Arc::new(Mutex::new(None)), + received_assignments: Arc::new(Mutex::new(None)), + filter_pushdown, + per_filter_pushdown: None, } } @@ -163,6 +225,79 @@ impl TableProvider for CaptureUpdateProvider { Field::new("count", DataType::UInt64, false), ]))))) } + + fn supports_filters_pushdown( + &self, + filters: &[&Expr], + ) -> Result> { + if let Some(per_filter) = &self.per_filter_pushdown + && per_filter.len() == filters.len() + { + return Ok(per_filter.clone()); + } + + Ok(vec![self.filter_pushdown.clone(); filters.len()]) + } +} + +/// A TableProvider that captures whether truncate() was called. 
+struct CaptureTruncateProvider { + schema: SchemaRef, + truncate_called: Arc>, +} + +impl CaptureTruncateProvider { + fn new(schema: SchemaRef) -> Self { + Self { + schema, + truncate_called: Arc::new(Mutex::new(false)), + } + } + + fn was_truncated(&self) -> bool { + *self.truncate_called.lock().unwrap() + } +} + +impl std::fmt::Debug for CaptureTruncateProvider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CaptureTruncateProvider") + .field("schema", &self.schema) + .finish() + } +} + +#[async_trait] +impl TableProvider for CaptureTruncateProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + _state: &dyn Session, + _projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + Ok(Arc::new(EmptyExec::new(Arc::clone(&self.schema)))) + } + + async fn truncate(&self, _state: &dyn Session) -> Result> { + *self.truncate_called.lock().unwrap() = true; + + Ok(Arc::new(EmptyExec::new(Arc::new(Schema::new(vec![ + Field::new("count", DataType::UInt64, false), + ]))))) + } } fn test_schema() -> SchemaRef { @@ -246,6 +381,168 @@ async fn test_delete_complex_expr() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_delete_filter_pushdown_extracts_table_scan_filters() -> Result<()> { + let provider = Arc::new(CaptureDeleteProvider::new_with_filter_pushdown( + test_schema(), + TableProviderFilterPushDown::Exact, + )); + let ctx = SessionContext::new(); + ctx.register_table("t", Arc::clone(&provider) as Arc)?; + + let df = ctx.sql("DELETE FROM t WHERE id = 1").await?; + let optimized_plan = df.clone().into_optimized_plan()?; + + let mut scan_filters = Vec::new(); + optimized_plan.apply(|node| { + if let LogicalPlan::TableScan(TableScan { filters, .. 
}) = node { + scan_filters.extend(filters.clone()); + } + Ok(TreeNodeRecursion::Continue) + })?; + + assert_eq!(scan_filters.len(), 1); + assert!(scan_filters[0].to_string().contains("id")); + + df.collect().await?; + + let filters = provider + .captured_filters() + .expect("filters should be captured"); + assert_eq!(filters.len(), 1); + assert!(filters[0].to_string().contains("id")); + Ok(()) +} + +#[tokio::test] +async fn test_delete_compound_filters_with_pushdown() -> Result<()> { + let provider = Arc::new(CaptureDeleteProvider::new_with_filter_pushdown( + test_schema(), + TableProviderFilterPushDown::Exact, + )); + let ctx = SessionContext::new(); + ctx.register_table("t", Arc::clone(&provider) as Arc)?; + + ctx.sql("DELETE FROM t WHERE id = 1 AND status = 'active'") + .await? + .collect() + .await?; + + let filters = provider + .captured_filters() + .expect("filters should be captured"); + // Should receive both filters, not deduplicate valid separate predicates + assert_eq!( + filters.len(), + 2, + "compound filters should not be over-suppressed" + ); + + let filter_strs: Vec = filters.iter().map(|f| f.to_string()).collect(); + assert!( + filter_strs.iter().any(|s| s.contains("id")), + "should contain id filter" + ); + assert!( + filter_strs.iter().any(|s| s.contains("status")), + "should contain status filter" + ); + Ok(()) +} + +#[tokio::test] +async fn test_delete_mixed_filter_locations() -> Result<()> { + // Test mixed-location filters: some in Filter node, some in TableScan.filters + // This happens when provider uses TableProviderFilterPushDown::Inexact, + // meaning it can push down some predicates but not others. 
+ let provider = Arc::new(CaptureDeleteProvider::new_with_filter_pushdown( + test_schema(), + TableProviderFilterPushDown::Inexact, + )); + let ctx = SessionContext::new(); + ctx.register_table("t", Arc::clone(&provider) as Arc)?; + + // Execute DELETE with compound WHERE clause + ctx.sql("DELETE FROM t WHERE id = 1 AND status = 'active'") + .await? + .collect() + .await?; + + // Verify that both predicates are extracted and passed to delete_from(), + // even though they may be split between Filter node and TableScan.filters + let filters = provider + .captured_filters() + .expect("filters should be captured"); + assert_eq!( + filters.len(), + 2, + "should extract both predicates (union of Filter and TableScan.filters)" + ); + + let filter_strs: Vec = filters.iter().map(|f| f.to_string()).collect(); + assert!( + filter_strs.iter().any(|s| s.contains("id")), + "should contain id filter" + ); + assert!( + filter_strs.iter().any(|s| s.contains("status")), + "should contain status filter" + ); + Ok(()) +} + +#[tokio::test] +async fn test_delete_per_filter_pushdown_mixed_locations() -> Result<()> { + // Force per-filter pushdown decisions to exercise mixed locations in one query. + // First predicate is pushed down (Exact), second stays as residual (Unsupported). + let provider = Arc::new(CaptureDeleteProvider::new_with_per_filter_pushdown( + test_schema(), + vec![ + TableProviderFilterPushDown::Exact, + TableProviderFilterPushDown::Unsupported, + ], + )); + + let ctx = SessionContext::new(); + ctx.register_table("t", Arc::clone(&provider) as Arc)?; + + let df = ctx + .sql("DELETE FROM t WHERE id = 1 AND status = 'active'") + .await?; + let optimized_plan = df.clone().into_optimized_plan()?; + + // Only the first predicate should be pushed to TableScan.filters. + let mut scan_filters = Vec::new(); + optimized_plan.apply(|node| { + if let LogicalPlan::TableScan(TableScan { filters, .. 
}) = node { + scan_filters.extend(filters.clone()); + } + Ok(TreeNodeRecursion::Continue) + })?; + assert_eq!(scan_filters.len(), 1); + assert!(scan_filters[0].to_string().contains("id")); + + // Both predicates should still reach the provider (union + dedup behavior). + df.collect().await?; + + let filters = provider + .captured_filters() + .expect("filters should be captured"); + assert_eq!(filters.len(), 2); + + let filter_strs: Vec = filters.iter().map(|f| f.to_string()).collect(); + assert!( + filter_strs.iter().any(|s| s.contains("id")), + "should contain pushed-down id filter" + ); + assert!( + filter_strs.iter().any(|s| s.contains("status")), + "should contain residual status filter" + ); + + Ok(()) +} + #[tokio::test] async fn test_update_assignments() -> Result<()> { let provider = Arc::new(CaptureUpdateProvider::new(test_schema())); @@ -269,6 +566,102 @@ async fn test_update_assignments() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_update_filter_pushdown_extracts_table_scan_filters() -> Result<()> { + let provider = Arc::new(CaptureUpdateProvider::new_with_filter_pushdown( + test_schema(), + TableProviderFilterPushDown::Exact, + )); + let ctx = SessionContext::new(); + ctx.register_table("t", Arc::clone(&provider) as Arc)?; + + let df = ctx.sql("UPDATE t SET value = 100 WHERE id = 1").await?; + let optimized_plan = df.clone().into_optimized_plan()?; + + // Verify that the optimizer pushed down the filter into TableScan + let mut scan_filters = Vec::new(); + optimized_plan.apply(|node| { + if let LogicalPlan::TableScan(TableScan { filters, .. 
}) = node { + scan_filters.extend(filters.clone()); + } + Ok(TreeNodeRecursion::Continue) + })?; + + assert_eq!(scan_filters.len(), 1); + assert!(scan_filters[0].to_string().contains("id")); + + // Execute the UPDATE and verify filters were extracted and passed to update() + df.collect().await?; + + let filters = provider + .captured_filters() + .expect("filters should be captured"); + assert_eq!(filters.len(), 1); + assert!(filters[0].to_string().contains("id")); + Ok(()) +} + +#[tokio::test] +async fn test_update_filter_pushdown_passes_table_scan_filters() -> Result<()> { + let provider = Arc::new(CaptureUpdateProvider::new_with_filter_pushdown( + test_schema(), + TableProviderFilterPushDown::Exact, + )); + let ctx = SessionContext::new(); + ctx.register_table("t", Arc::clone(&provider) as Arc)?; + + let df = ctx + .sql("UPDATE t SET value = 42 WHERE status = 'ready'") + .await?; + let optimized_plan = df.clone().into_optimized_plan()?; + + let mut scan_filters = Vec::new(); + optimized_plan.apply(|node| { + if let LogicalPlan::TableScan(TableScan { filters, .. 
}) = node { + scan_filters.extend(filters.clone()); + } + Ok(TreeNodeRecursion::Continue) + })?; + + assert!( + !scan_filters.is_empty(), + "expected filter pushdown to populate TableScan filters" + ); + + df.collect().await?; + + let filters = provider + .captured_filters() + .expect("filters should be captured"); + assert!( + !filters.is_empty(), + "expected filters extracted from TableScan during UPDATE" + ); + Ok(()) +} + +#[tokio::test] +async fn test_truncate_calls_provider() -> Result<()> { + let provider = Arc::new(CaptureTruncateProvider::new(test_schema())); + let config = SessionConfig::new().set( + "datafusion.optimizer.max_passes", + &ScalarValue::UInt64(Some(0)), + ); + + let ctx = SessionContext::new_with_config(config); + + ctx.register_table("t", Arc::clone(&provider) as Arc)?; + + ctx.sql("TRUNCATE TABLE t").await?.collect().await?; + + assert!( + provider.was_truncated(), + "truncate() should be called on the TableProvider" + ); + + Ok(()) +} + #[tokio::test] async fn test_unsupported_table_delete() -> Result<()> { let schema = test_schema(); @@ -295,3 +688,132 @@ async fn test_unsupported_table_update() -> Result<()> { assert!(result.is_err() || result.unwrap().collect().await.is_err()); Ok(()) } + +#[tokio::test] +async fn test_delete_target_table_scoping() -> Result<()> { + // Test that DELETE only extracts filters from the target table, + // not from other tables (important for DELETE...FROM safety) + let target_provider = Arc::new(CaptureDeleteProvider::new_with_filter_pushdown( + test_schema(), + TableProviderFilterPushDown::Exact, + )); + let ctx = SessionContext::new(); + ctx.register_table( + "target_t", + Arc::clone(&target_provider) as Arc, + )?; + + // For now, we test single-table DELETE + // and validate that the scoping logic is correct + let df = ctx.sql("DELETE FROM target_t WHERE id > 5").await?; + df.collect().await?; + + let filters = target_provider + .captured_filters() + .expect("filters should be captured"); + 
assert_eq!(filters.len(), 1); + assert!( + filters[0].to_string().contains("id"), + "Filter should be for id column" + ); + assert!( + filters[0].to_string().contains("5"), + "Filter should contain the value 5" + ); + Ok(()) +} + +#[tokio::test] +async fn test_update_from_drops_non_target_predicates() -> Result<()> { + // UPDATE ... FROM is currently not working + // TODO fix https://github.com/apache/datafusion/issues/19950 + let target_provider = Arc::new(CaptureUpdateProvider::new_with_filter_pushdown( + test_schema(), + TableProviderFilterPushDown::Exact, + )); + let ctx = SessionContext::new(); + ctx.register_table("t1", Arc::clone(&target_provider) as Arc)?; + + let source_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("status", DataType::Utf8, true), + // t2-only column to avoid false negatives after qualifier stripping + Field::new("src_only", DataType::Utf8, true), + ])); + let source_table = datafusion::datasource::empty::EmptyTable::new(source_schema); + ctx.register_table("t2", Arc::new(source_table))?; + + let result = ctx + .sql( + "UPDATE t1 SET value = 1 FROM t2 \ + WHERE t1.id = t2.id AND t2.src_only = 'active' AND t1.value > 10", + ) + .await; + + // Verify UPDATE ... FROM is rejected with appropriate error + // TODO fix https://github.com/apache/datafusion/issues/19950 + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!( + err.to_string().contains("UPDATE ... FROM is not supported"), + "Expected 'UPDATE ... 
FROM is not supported' error, got: {err}" + ); + Ok(()) +} + +#[tokio::test] +async fn test_delete_qualifier_stripping_and_validation() -> Result<()> { + // Test that filter qualifiers are properly stripped and validated + // Unqualified predicates should work fine + let provider = Arc::new(CaptureDeleteProvider::new_with_filter_pushdown( + test_schema(), + TableProviderFilterPushDown::Exact, + )); + let ctx = SessionContext::new(); + ctx.register_table("t", Arc::clone(&provider) as Arc)?; + + // Execute DELETE with unqualified column reference + // (After parsing, the planner adds qualifiers, but our validation should accept them) + let df = ctx.sql("DELETE FROM t WHERE id = 1").await?; + df.collect().await?; + + let filters = provider + .captured_filters() + .expect("filters should be captured"); + assert!(!filters.is_empty(), "Should have extracted filter"); + + // Verify qualifiers are stripped: check that Column expressions have no qualifier + let has_qualified_column = filters[0] + .exists(|expr| Ok(matches!(expr, Expr::Column(col) if col.relation.is_some())))?; + assert!( + !has_qualified_column, + "Filter should have unqualified columns after stripping" + ); + + // Also verify the string representation doesn't contain table qualifiers + let filter_str = filters[0].to_string(); + assert!( + !filter_str.contains("t.id"), + "Filter should not contain qualified column reference, got: {filter_str}" + ); + assert!( + filter_str.contains("id") || filter_str.contains("1"), + "Filter should reference id column or the value 1, got: {filter_str}" + ); + Ok(()) +} + +#[tokio::test] +async fn test_unsupported_table_truncate() -> Result<()> { + let schema = test_schema(); + let ctx = SessionContext::new(); + + let empty_table = datafusion::datasource::empty::EmptyTable::new(schema); + ctx.register_table("empty_t", Arc::new(empty_table))?; + + let result = ctx.sql("TRUNCATE TABLE empty_t").await; + + assert!(result.is_err() || result.unwrap().collect().await.is_err()); + 
+ Ok(()) +} diff --git a/datafusion/core/tests/custom_sources_cases/mod.rs b/datafusion/core/tests/custom_sources_cases/mod.rs index 8453615c2886..f51d0a1e3653 100644 --- a/datafusion/core/tests/custom_sources_cases/mod.rs +++ b/datafusion/core/tests/custom_sources_cases/mod.rs @@ -79,7 +79,7 @@ struct CustomTableProvider; #[derive(Debug, Clone)] struct CustomExecutionPlan { projection: Option>, - cache: PlanProperties, + cache: Arc, } impl CustomExecutionPlan { @@ -88,7 +88,10 @@ impl CustomExecutionPlan { let schema = project_schema(&schema, projection.as_ref()).expect("projected schema"); let cache = Self::compute_properties(schema); - Self { projection, cache } + Self { + projection, + cache: Arc::new(cache), + } } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. @@ -157,7 +160,7 @@ impl ExecutionPlan for CustomExecutionPlan { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -180,10 +183,6 @@ impl ExecutionPlan for CustomExecutionPlan { Ok(Box::pin(TestCustomRecordBatchStream { nb_batch: 1 })) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { if partition.is_some() { return Ok(Statistics::new_unknown(&self.schema())); diff --git a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs index ca1eaa1f958e..96357d310312 100644 --- a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs +++ b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs @@ -29,7 +29,7 @@ use datafusion::logical_expr::TableProviderFilterPushDown; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::{ DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, - 
SendableRecordBatchStream, Statistics, + SendableRecordBatchStream, }; use datafusion::prelude::*; use datafusion::scalar::ScalarValue; @@ -62,13 +62,16 @@ fn create_batch(value: i32, num_rows: usize) -> Result { #[derive(Debug)] struct CustomPlan { batches: Vec, - cache: PlanProperties, + cache: Arc, } impl CustomPlan { fn new(schema: SchemaRef, batches: Vec) -> Self { let cache = Self::compute_properties(schema); - Self { batches, cache } + Self { + batches, + cache: Arc::new(cache), + } } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. @@ -109,7 +112,7 @@ impl ExecutionPlan for CustomPlan { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -149,12 +152,6 @@ impl ExecutionPlan for CustomPlan { })), ))) } - - fn statistics(&self) -> Result { - // here we could provide more accurate statistics - // but we want to test the filter pushdown not the CBOs - Ok(Statistics::new_unknown(&self.schema())) - } } #[derive(Clone, Debug)] diff --git a/datafusion/core/tests/custom_sources_cases/statistics.rs b/datafusion/core/tests/custom_sources_cases/statistics.rs index 820c2a470b37..03513ec730de 100644 --- a/datafusion/core/tests/custom_sources_cases/statistics.rs +++ b/datafusion/core/tests/custom_sources_cases/statistics.rs @@ -45,7 +45,7 @@ use async_trait::async_trait; struct StatisticsValidation { stats: Statistics, schema: Arc, - cache: PlanProperties, + cache: Arc, } impl StatisticsValidation { @@ -59,7 +59,7 @@ impl StatisticsValidation { Self { stats, schema, - cache, + cache: Arc::new(cache), } } @@ -158,7 +158,7 @@ impl ExecutionPlan for StatisticsValidation { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -181,10 +181,6 @@ impl ExecutionPlan for StatisticsValidation { unimplemented!("This plan only serves for testing statistics") } - fn statistics(&self) -> Result { - 
Ok(self.stats.clone()) - } - fn partition_statistics(&self, partition: Option) -> Result { if partition.is_some() { Ok(Statistics::new_unknown(&self.schema)) diff --git a/datafusion/core/tests/data/json_array.json b/datafusion/core/tests/data/json_array.json new file mode 100644 index 000000000000..1a8716dbf4be --- /dev/null +++ b/datafusion/core/tests/data/json_array.json @@ -0,0 +1,5 @@ +[ + {"a": 1, "b": "hello"}, + {"a": 2, "b": "world"}, + {"a": 3, "b": "test"} +] diff --git a/datafusion/core/tests/data/json_empty_array.json b/datafusion/core/tests/data/json_empty_array.json new file mode 100644 index 000000000000..fe51488c7066 --- /dev/null +++ b/datafusion/core/tests/data/json_empty_array.json @@ -0,0 +1 @@ +[] diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index c09db371912b..c94ab10a9e72 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -43,6 +43,7 @@ use datafusion_functions_nested::make_array::make_array_udf; use datafusion_functions_window::expr_fn::{first_value, lead, row_number}; use insta::assert_snapshot; use object_store::local::LocalFileSystem; +use rstest::rstest; use std::collections::HashMap; use std::fs; use std::path::Path; @@ -56,9 +57,7 @@ use datafusion::error::Result; use datafusion::execution::context::SessionContext; use datafusion::execution::session_state::SessionStateBuilder; use datafusion::logical_expr::{ColumnarValue, Volatility}; -use datafusion::prelude::{ - CsvReadOptions, JoinType, NdJsonReadOptions, ParquetReadOptions, -}; +use datafusion::prelude::{CsvReadOptions, JoinType, ParquetReadOptions}; use datafusion::test_util::{ parquet_test_data, populate_csv_partitions, register_aggregate_csv, test_table, test_table_with_cache_factory, test_table_with_name, @@ -93,6 +92,7 @@ use datafusion_physical_plan::empty::EmptyExec; use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties, displayable}; use datafusion::error::Result 
as DataFusionResult; +use datafusion::execution::options::JsonReadOptions; use datafusion_functions_window::expr_fn::lag; // Get string representation of the plan @@ -534,7 +534,8 @@ async fn drop_columns_with_nonexistent_columns() -> Result<()> { async fn drop_columns_with_empty_array() -> Result<()> { // build plan using Table API let t = test_table().await?; - let t2 = t.drop_columns(&[])?; + let drop_columns = vec![] as Vec<&str>; + let t2 = t.drop_columns(&drop_columns)?; let plan = t2.logical_plan().clone(); // build query using SQL @@ -549,6 +550,107 @@ async fn drop_columns_with_empty_array() -> Result<()> { Ok(()) } +#[tokio::test] +async fn drop_columns_qualified() -> Result<()> { + // build plan using Table API + let mut t = test_table().await?; + t = t.select_columns(&["c1", "c2", "c11"])?; + let mut t2 = test_table_with_name("another_table").await?; + t2 = t2.select_columns(&["c1", "c2", "c11"])?; + let mut t3 = t.join_on( + t2, + JoinType::Inner, + [col("aggregate_test_100.c1").eq(col("another_table.c1"))], + )?; + t3 = t3.drop_columns(&["another_table.c2", "another_table.c11"])?; + + let plan = t3.logical_plan().clone(); + + let sql = "SELECT aggregate_test_100.c1, aggregate_test_100.c2, aggregate_test_100.c11, another_table.c1 FROM (SELECT c1, c2, c11 FROM aggregate_test_100) INNER JOIN (SELECT c1, c2, c11 FROM another_table) ON aggregate_test_100.c1 = another_table.c1"; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx, "aggregate_test_100").await?; + register_aggregate_csv(&ctx, "another_table").await?; + let sql_plan = ctx.sql(sql).await?.into_unoptimized_plan(); + + // the two plans should be identical + assert_same_plan(&plan, &sql_plan); + + Ok(()) +} + +#[tokio::test] +async fn drop_columns_qualified_find_qualified() -> Result<()> { + // build plan using Table API + let mut t = test_table().await?; + t = t.select_columns(&["c1", "c2", "c11"])?; + let mut t2 = test_table_with_name("another_table").await?; + t2 = 
t2.select_columns(&["c1", "c2", "c11"])?; + let mut t3 = t.join_on( + t2.clone(), + JoinType::Inner, + [col("aggregate_test_100.c1").eq(col("another_table.c1"))], + )?; + t3 = t3.drop_columns(&t2.find_qualified_columns(&["c2", "c11"])?)?; + + let plan = t3.logical_plan().clone(); + + let sql = "SELECT aggregate_test_100.c1, aggregate_test_100.c2, aggregate_test_100.c11, another_table.c1 FROM (SELECT c1, c2, c11 FROM aggregate_test_100) INNER JOIN (SELECT c1, c2, c11 FROM another_table) ON aggregate_test_100.c1 = another_table.c1"; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx, "aggregate_test_100").await?; + register_aggregate_csv(&ctx, "another_table").await?; + let sql_plan = ctx.sql(sql).await?.into_unoptimized_plan(); + + // the two plans should be identical + assert_same_plan(&plan, &sql_plan); + + Ok(()) +} + +#[tokio::test] +async fn test_find_qualified_names() -> Result<()> { + let t = test_table().await?; + let column_names = ["c1", "c2", "c3"]; + let columns = t.find_qualified_columns(&column_names)?; + + // Expected results for each column + let binding = TableReference::bare("aggregate_test_100"); + let expected = [ + (Some(&binding), "c1"), + (Some(&binding), "c2"), + (Some(&binding), "c3"), + ]; + + // Verify we got the expected number of results + assert_eq!( + columns.len(), + expected.len(), + "Expected {} columns, got {}", + expected.len(), + columns.len() + ); + + // Iterate over the results and check each one individually + for (i, (actual, expected)) in columns.iter().zip(expected.iter()).enumerate() { + let (actual_table_ref, actual_field_ref) = actual; + let (expected_table_ref, expected_field_name) = expected; + + // Check table reference + assert_eq!( + actual_table_ref, expected_table_ref, + "Column {i}: expected table reference {expected_table_ref:?}, got {actual_table_ref:?}" + ); + + // Check field name + assert_eq!( + actual_field_ref.name(), + *expected_field_name, + "Column {i}: expected field name 
'{expected_field_name}', got '{actual_field_ref}'" + ); + } + + Ok(()) +} + #[tokio::test] async fn drop_with_quotes() -> Result<()> { // define data with a column name that has a "." in it: @@ -594,7 +696,7 @@ async fn drop_with_periods() -> Result<()> { let ctx = SessionContext::new(); ctx.register_batch("t", batch)?; - let df = ctx.table("t").await?.drop_columns(&["f.c1"])?; + let df = ctx.table("t").await?.drop_columns(&["\"f.c1\""])?; let df_results = df.collect().await?; @@ -2793,7 +2895,7 @@ async fn write_json_with_order() -> Result<()> { ctx.register_json( "data", test_path.to_str().unwrap(), - NdJsonReadOptions::default().schema(&schema), + JsonReadOptions::default().schema(&schema), ) .await?; @@ -4699,7 +4801,7 @@ async fn unnest_with_redundant_columns() -> Result<()> { @r" Projection: shapes.shape_id [shape_id:UInt32] Unnest: lists[shape_id2|depth=1] structs[] [shape_id:UInt32, shape_id2:UInt32;N] - Aggregate: groupBy=[[shapes.shape_id]], aggr=[[array_agg(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { data_type: UInt32, nullable: true });N] + Aggregate: groupBy=[[shapes.shape_id]], aggr=[[array_agg(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(UInt32);N] TableScan: shapes projection=[shape_id] [shape_id:UInt32] " ); @@ -5513,30 +5615,33 @@ async fn test_dataframe_placeholder_like_expression() -> Result<()> { Ok(()) } +#[rstest] +#[case(DataType::Utf8)] +#[case(DataType::LargeUtf8)] +#[case(DataType::Utf8View)] #[tokio::test] -async fn write_partitioned_parquet_results() -> Result<()> { - // create partitioned input file and context - let tmp_dir = TempDir::new()?; - - let ctx = SessionContext::new(); - +async fn write_partitioned_parquet_results(#[case] string_type: DataType) -> Result<()> { // Create an in memory table with schema C1 and C2, both strings let schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::Utf8, false), - Field::new("c2", DataType::Utf8, false), + Field::new("c1", 
string_type.clone(), false), + Field::new("c2", string_type.clone(), false), ])); - let record_batch = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(StringArray::from(vec!["abc", "def"])), - Arc::new(StringArray::from(vec!["123", "456"])), - ], - )?; + let columns = [ + Arc::new(StringArray::from(vec!["abc", "def"])) as ArrayRef, + Arc::new(StringArray::from(vec!["123", "456"])) as ArrayRef, + ] + .map(|col| arrow::compute::cast(&col, &string_type).unwrap()) + .to_vec(); + + let record_batch = RecordBatch::try_new(schema.clone(), columns)?; let mem_table = Arc::new(MemTable::try_new(schema, vec![vec![record_batch]])?); // Register the table in the context + // create partitioned input file and context + let tmp_dir = TempDir::new()?; + let ctx = SessionContext::new(); ctx.register_table("test", mem_table)?; let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?); @@ -5563,6 +5668,7 @@ async fn write_partitioned_parquet_results() -> Result<()> { // Check that the c2 column is gone and that c1 is abc. let results = filter_df.collect().await?; + insta::allow_duplicates! { assert_snapshot!( batches_to_string(&results), @r" @@ -5572,7 +5678,7 @@ async fn write_partitioned_parquet_results() -> Result<()> { | abc | +-----+ " - ); + )}; // Read the entire set of parquet files let df = ctx @@ -5585,9 +5691,10 @@ async fn write_partitioned_parquet_results() -> Result<()> { // Check that the df has the entire set of data let results = df.collect().await?; - assert_snapshot!( - batches_to_sort_string(&results), - @r" + insta::allow_duplicates! 
{ + assert_snapshot!( + batches_to_sort_string(&results), + @r" +-----+-----+ | c1 | c2 | +-----+-----+ @@ -5595,7 +5702,8 @@ async fn write_partitioned_parquet_results() -> Result<()> { | def | 456 | +-----+-----+ " - ); + ) + }; Ok(()) } @@ -6213,7 +6321,7 @@ async fn register_non_json_file() { .register_json( "data", "tests/data/test_binary.parquet", - NdJsonReadOptions::default(), + JsonReadOptions::default(), ) .await; assert_contains!( @@ -6426,7 +6534,7 @@ async fn test_fill_null_all_columns() -> Result<()> { async fn test_insert_into_casting_support() -> Result<()> { // Testing case1: // Inserting query schema mismatch: Expected table field 'a' with type Float16, but got 'a' with type Utf8. - // And the cast is not supported from Utf8 to Float16. + // And the cast is not supported from Binary to Float16. // Create a new schema with one field called "a" of type Float16, and setting nullable to false let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float16, false)])); @@ -6437,7 +6545,10 @@ async fn test_insert_into_casting_support() -> Result<()> { let initial_table = Arc::new(MemTable::try_new(schema.clone(), vec![vec![]])?); session_ctx.register_table("t", initial_table.clone())?; - let mut write_df = session_ctx.sql("values ('a123'), ('b456')").await.unwrap(); + let mut write_df = session_ctx + .sql("values (x'a123'), (x'b456')") + .await + .unwrap(); write_df = write_df .clone() @@ -6451,7 +6562,7 @@ async fn test_insert_into_casting_support() -> Result<()> { assert_contains!( e.to_string(), - "Inserting query schema mismatch: Expected table field 'a' with type Float16, but got 'a' with type Utf8." + "Inserting query schema mismatch: Expected table field 'a' with type Float16, but got 'a' with type Binary." 
); // Testing case2: diff --git a/datafusion/core/tests/datasource/object_store_access.rs b/datafusion/core/tests/datasource/object_store_access.rs index 561de2152039..30654c687f8d 100644 --- a/datafusion/core/tests/datasource/object_store_access.rs +++ b/datafusion/core/tests/datasource/object_store_access.rs @@ -36,8 +36,9 @@ use insta::assert_snapshot; use object_store::memory::InMemory; use object_store::path::Path; use object_store::{ - GetOptions, GetRange, GetResult, ListResult, MultipartUpload, ObjectMeta, - ObjectStore, PutMultipartOptions, PutOptions, PutPayload, PutResult, + CopyOptions, GetOptions, GetRange, GetResult, ListResult, MultipartUpload, + ObjectMeta, ObjectStore, ObjectStoreExt, PutMultipartOptions, PutOptions, PutPayload, + PutResult, }; use parking_lot::Mutex; use std::fmt; @@ -54,8 +55,8 @@ async fn create_single_csv_file() { @r" RequestCountingObjectStore() Total Requests: 2 - - HEAD path=csv_table.csv - - GET path=csv_table.csv + - GET (opts) path=csv_table.csv head=true + - GET (opts) path=csv_table.csv " ); } @@ -76,7 +77,7 @@ async fn query_single_csv_file() { ------- Object Store Request Summary ------- RequestCountingObjectStore() Total Requests: 2 - - HEAD path=csv_table.csv + - GET (opts) path=csv_table.csv head=true - GET (opts) path=csv_table.csv " ); @@ -91,9 +92,9 @@ async fn create_multi_file_csv_file() { RequestCountingObjectStore() Total Requests: 4 - LIST prefix=data - - GET path=data/file_0.csv - - GET path=data/file_1.csv - - GET path=data/file_2.csv + - GET (opts) path=data/file_0.csv + - GET (opts) path=data/file_1.csv + - GET (opts) path=data/file_2.csv " ); } @@ -351,8 +352,8 @@ async fn create_single_parquet_file_default() { @r" RequestCountingObjectStore() Total Requests: 2 - - HEAD path=parquet_table.parquet - - GET (range) range=0-2994 path=parquet_table.parquet + - GET (opts) path=parquet_table.parquet head=true + - GET (opts) path=parquet_table.parquet range=0-2994 " ); } @@ -370,8 +371,8 @@ async fn 
create_single_parquet_file_prefetch() { @r" RequestCountingObjectStore() Total Requests: 2 - - HEAD path=parquet_table.parquet - - GET (range) range=1994-2994 path=parquet_table.parquet + - GET (opts) path=parquet_table.parquet head=true + - GET (opts) path=parquet_table.parquet range=1994-2994 " ); } @@ -399,10 +400,10 @@ async fn create_single_parquet_file_too_small_prefetch() { @r" RequestCountingObjectStore() Total Requests: 4 - - HEAD path=parquet_table.parquet - - GET (range) range=2494-2994 path=parquet_table.parquet - - GET (range) range=2264-2986 path=parquet_table.parquet - - GET (range) range=2124-2264 path=parquet_table.parquet + - GET (opts) path=parquet_table.parquet head=true + - GET (opts) path=parquet_table.parquet range=2494-2994 + - GET (opts) path=parquet_table.parquet range=2264-2986 + - GET (opts) path=parquet_table.parquet range=2124-2264 " ); } @@ -431,9 +432,9 @@ async fn create_single_parquet_file_small_prefetch() { @r" RequestCountingObjectStore() Total Requests: 3 - - HEAD path=parquet_table.parquet - - GET (range) range=2254-2994 path=parquet_table.parquet - - GET (range) range=2124-2264 path=parquet_table.parquet + - GET (opts) path=parquet_table.parquet head=true + - GET (opts) path=parquet_table.parquet range=2254-2994 + - GET (opts) path=parquet_table.parquet range=2124-2264 " ); } @@ -455,8 +456,8 @@ async fn create_single_parquet_file_no_prefetch() { @r" RequestCountingObjectStore() Total Requests: 2 - - HEAD path=parquet_table.parquet - - GET (range) range=0-2994 path=parquet_table.parquet + - GET (opts) path=parquet_table.parquet head=true + - GET (opts) path=parquet_table.parquet range=0-2994 " ); } @@ -476,7 +477,7 @@ async fn query_single_parquet_file() { ------- Object Store Request Summary ------- RequestCountingObjectStore() Total Requests: 3 - - HEAD path=parquet_table.parquet + - GET (opts) path=parquet_table.parquet head=true - GET (ranges) path=parquet_table.parquet ranges=4-534,534-1064 - GET (ranges) 
path=parquet_table.parquet ranges=1064-1594,1594-2124 " @@ -500,7 +501,7 @@ async fn query_single_parquet_file_with_single_predicate() { ------- Object Store Request Summary ------- RequestCountingObjectStore() Total Requests: 2 - - HEAD path=parquet_table.parquet + - GET (opts) path=parquet_table.parquet head=true - GET (ranges) path=parquet_table.parquet ranges=1064-1481,1481-1594,1594-2011,2011-2124 " ); @@ -524,7 +525,7 @@ async fn query_single_parquet_file_multi_row_groups_multiple_predicates() { ------- Object Store Request Summary ------- RequestCountingObjectStore() Total Requests: 3 - - HEAD path=parquet_table.parquet + - GET (opts) path=parquet_table.parquet head=true - GET (ranges) path=parquet_table.parquet ranges=4-421,421-534,534-951,951-1064 - GET (ranges) path=parquet_table.parquet ranges=1064-1481,1481-1594,1594-2011,2011-2124 " @@ -701,7 +702,7 @@ impl Test { let mut buffer = vec![]; let props = parquet::file::properties::WriterProperties::builder() - .set_max_row_group_size(100) + .set_max_row_group_row_count(Some(100)) .build(); let mut writer = parquet::arrow::ArrowWriter::try_new( &mut buffer, @@ -752,11 +753,8 @@ impl Test { /// Details of individual requests made through the [`RequestCountingObjectStore`] #[derive(Clone, Debug)] enum RequestDetails { - Get { path: Path }, GetOpts { path: Path, get_options: GetOptions }, GetRanges { path: Path, ranges: Vec> }, - GetRange { path: Path, range: Range }, - Head { path: Path }, List { prefix: Option }, ListWithDelimiter { prefix: Option }, ListWithOffset { prefix: Option, offset: Path }, @@ -774,9 +772,6 @@ fn display_range(range: &Range) -> impl Display + '_ { impl Display for RequestDetails { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match self { - RequestDetails::Get { path } => { - write!(f, "GET path={path}") - } RequestDetails::GetOpts { path, get_options } => { write!(f, "GET (opts) path={path}")?; if let Some(range) = &get_options.range { @@ -814,13 +809,6 @@ impl Display for 
RequestDetails { } Ok(()) } - RequestDetails::GetRange { path, range } => { - let range = display_range(range); - write!(f, "GET (range) range={range} path={path}") - } - RequestDetails::Head { path } => { - write!(f, "HEAD path={path}") - } RequestDetails::List { prefix } => { write!(f, "LIST")?; if let Some(prefix) = prefix { @@ -893,7 +881,7 @@ impl ObjectStore for RequestCountingObjectStore { _payload: PutPayload, _opts: PutOptions, ) -> object_store::Result { - Err(object_store::Error::NotImplemented) + unimplemented!() } async fn put_multipart_opts( @@ -901,15 +889,7 @@ impl ObjectStore for RequestCountingObjectStore { _location: &Path, _opts: PutMultipartOptions, ) -> object_store::Result> { - Err(object_store::Error::NotImplemented) - } - - async fn get(&self, location: &Path) -> object_store::Result { - let result = self.inner.get(location).await?; - self.requests.lock().push(RequestDetails::Get { - path: location.to_owned(), - }); - Ok(result) + unimplemented!() } async fn get_opts( @@ -925,19 +905,6 @@ impl ObjectStore for RequestCountingObjectStore { Ok(result) } - async fn get_range( - &self, - location: &Path, - range: Range, - ) -> object_store::Result { - let result = self.inner.get_range(location, range.clone()).await?; - self.requests.lock().push(RequestDetails::GetRange { - path: location.to_owned(), - range: range.clone(), - }); - Ok(result) - } - async fn get_ranges( &self, location: &Path, @@ -951,18 +918,6 @@ impl ObjectStore for RequestCountingObjectStore { Ok(result) } - async fn head(&self, location: &Path) -> object_store::Result { - let result = self.inner.head(location).await?; - self.requests.lock().push(RequestDetails::Head { - path: location.to_owned(), - }); - Ok(result) - } - - async fn delete(&self, _location: &Path) -> object_store::Result<()> { - Err(object_store::Error::NotImplemented) - } - fn list( &self, prefix: Option<&Path>, @@ -998,15 +953,19 @@ impl ObjectStore for RequestCountingObjectStore { 
self.inner.list_with_delimiter(prefix).await } - async fn copy(&self, _from: &Path, _to: &Path) -> object_store::Result<()> { - Err(object_store::Error::NotImplemented) + fn delete_stream( + &self, + _locations: BoxStream<'static, object_store::Result>, + ) -> BoxStream<'static, object_store::Result> { + unimplemented!() } - async fn copy_if_not_exists( + async fn copy_opts( &self, _from: &Path, _to: &Path, + _options: CopyOptions, ) -> object_store::Result<()> { - Err(object_store::Error::NotImplemented) + unimplemented!() } } diff --git a/datafusion/core/tests/execution/coop.rs b/datafusion/core/tests/execution/coop.rs index 27dacf598c2c..e02364a0530c 100644 --- a/datafusion/core/tests/execution/coop.rs +++ b/datafusion/core/tests/execution/coop.rs @@ -24,7 +24,7 @@ use datafusion::physical_expr::aggregate::AggregateExprBuilder; use datafusion::physical_plan; use datafusion::physical_plan::ExecutionPlan; use datafusion::physical_plan::aggregates::{ - AggregateExec, AggregateMode, PhysicalGroupBy, + AggregateExec, AggregateMode, LimitOptions, PhysicalGroupBy, }; use datafusion::physical_plan::execution_plan::Boundedness; use datafusion::prelude::SessionContext; @@ -41,7 +41,6 @@ use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_optimizer::ensure_coop::EnsureCooperative; -use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion_physical_plan::coop::make_cooperative; use datafusion_physical_plan::filter::FilterExec; use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode, SortMergeJoinExec}; @@ -234,6 +233,7 @@ async fn agg_grouped_topk_yields( #[values(false, true)] pretend_infinite: bool, ) -> Result<(), Box> { // build session + let session_ctx = SessionContext::new(); // set up a top-k aggregation @@ -261,7 +261,7 @@ async fn agg_grouped_topk_yields( 
inf.clone(), inf.schema(), )? - .with_limit(Some(100)), + .with_limit_options(Some(LimitOptions::new(100))), ); query_yields(aggr, session_ctx.task_ctx()).await @@ -425,10 +425,7 @@ async fn filter_reject_all_batches_yields( )); let filtered = Arc::new(FilterExec::try_new(false_predicate, Arc::new(infinite))?); - // Use CoalesceBatchesExec to guarantee each Filter pull always yields an 8192-row batch - let coalesced = Arc::new(CoalesceBatchesExec::new(filtered, 8_192)); - - query_yields(coalesced, session_ctx.task_ctx()).await + query_yields(filtered, session_ctx.task_ctx()).await } #[rstest] @@ -584,17 +581,18 @@ async fn join_yields( let left_keys: Vec> = vec![Arc::new(Column::new("value", 0))]; let right_keys: Vec> = vec![Arc::new(Column::new("value", 0))]; - // Wrap each side in CoalesceBatches + Repartition so they are both hashed into 1 partition - let coalesced_left = - Arc::new(CoalesceBatchesExec::new(Arc::new(infinite_left), 8_192)); - let coalesced_right = - Arc::new(CoalesceBatchesExec::new(Arc::new(infinite_right), 8_192)); - let part_left = Partitioning::Hash(left_keys, 1); let part_right = Partitioning::Hash(right_keys, 1); - let hashed_left = Arc::new(RepartitionExec::try_new(coalesced_left, part_left)?); - let hashed_right = Arc::new(RepartitionExec::try_new(coalesced_right, part_right)?); + // Wrap each side in Repartition so they are both hashed into 1 partition + let hashed_left = Arc::new(RepartitionExec::try_new( + Arc::new(infinite_left), + part_left, + )?); + let hashed_right = Arc::new(RepartitionExec::try_new( + Arc::new(infinite_right), + part_right, + )?); // Build an Inner HashJoinExec → left.value = right.value let join = Arc::new(HashJoinExec::try_new( @@ -609,6 +607,7 @@ async fn join_yields( None, PartitionMode::CollectLeft, NullEquality::NullEqualsNull, + false, )?); query_yields(join, session_ctx.task_ctx()).await @@ -632,17 +631,18 @@ async fn join_agg_yields( let left_keys: Vec> = vec![Arc::new(Column::new("value", 0))]; let 
right_keys: Vec> = vec![Arc::new(Column::new("value", 0))]; - // Wrap each side in CoalesceBatches + Repartition so they are both hashed into 1 partition - let coalesced_left = - Arc::new(CoalesceBatchesExec::new(Arc::new(infinite_left), 8_192)); - let coalesced_right = - Arc::new(CoalesceBatchesExec::new(Arc::new(infinite_right), 8_192)); - let part_left = Partitioning::Hash(left_keys, 1); let part_right = Partitioning::Hash(right_keys, 1); - let hashed_left = Arc::new(RepartitionExec::try_new(coalesced_left, part_left)?); - let hashed_right = Arc::new(RepartitionExec::try_new(coalesced_right, part_right)?); + // Wrap each side in Repartition so they are both hashed into 1 partition + let hashed_left = Arc::new(RepartitionExec::try_new( + Arc::new(infinite_left), + part_left, + )?); + let hashed_right = Arc::new(RepartitionExec::try_new( + Arc::new(infinite_right), + part_right, + )?); // Build an Inner HashJoinExec → left.value = right.value let join = Arc::new(HashJoinExec::try_new( @@ -657,6 +657,7 @@ async fn join_agg_yields( None, PartitionMode::CollectLeft, NullEquality::NullEqualsNull, + false, )?); // Project only one column (“value” from the left side) because we just want to sum that @@ -722,6 +723,7 @@ async fn hash_join_yields( None, PartitionMode::CollectLeft, NullEquality::NullEqualsNull, + false, )?); query_yields(join, session_ctx.task_ctx()).await @@ -753,9 +755,10 @@ async fn hash_join_without_repartition_and_no_agg( /* filter */ None, &JoinType::Inner, /* output64 */ None, - // Using CollectLeft is fine—just avoid RepartitionExec’s partitioned channels. + // Using CollectLeft is fine—just avoid RepartitionExec's partitioned channels. 
PartitionMode::CollectLeft, NullEquality::NullEqualsNull, + false, )?); query_yields(join, session_ctx.task_ctx()).await @@ -764,7 +767,7 @@ async fn hash_join_without_repartition_and_no_agg( #[derive(Debug)] enum Yielded { ReadyOrPending, - Err(#[allow(dead_code)] DataFusionError), + Err(#[expect(dead_code)] DataFusionError), Timeout, } @@ -791,9 +794,9 @@ async fn stream_yields( let yielded = select! { result = join_handle => { match result { - Ok(Pending) => Yielded::ReadyOrPending, - Ok(Ready(Ok(_))) => Yielded::ReadyOrPending, - Ok(Ready(Err(e))) => Yielded::Err(e), + Ok(Poll::Pending) => Yielded::ReadyOrPending, + Ok(Poll::Ready(Ok(_))) => Yielded::ReadyOrPending, + Ok(Poll::Ready(Err(e))) => Yielded::Err(e), Err(_) => Yielded::Err(exec_datafusion_err!("join error")), } }, diff --git a/datafusion/core/tests/expr_api/mod.rs b/datafusion/core/tests/expr_api/mod.rs index 90c1b96749b3..91dd5de7fcd6 100644 --- a/datafusion/core/tests/expr_api/mod.rs +++ b/datafusion/core/tests/expr_api/mod.rs @@ -24,7 +24,6 @@ use arrow::util::pretty::{pretty_format_batches, pretty_format_columns}; use datafusion::prelude::*; use datafusion_common::{DFSchema, ScalarValue}; use datafusion_expr::ExprFunctionExt; -use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr::NullTreatment; use datafusion_expr::simplify::SimplifyContext; use datafusion_functions::core::expr_ext::FieldAccessor; @@ -422,9 +421,7 @@ fn create_simplified_expr_test(expr: Expr, expected_expr: &str) { let df_schema = DFSchema::try_from(batch.schema()).unwrap(); // Simplify the expression first - let props = ExecutionProps::new(); - let simplify_context = - SimplifyContext::new(&props).with_schema(df_schema.clone().into()); + let simplify_context = SimplifyContext::default().with_schema(Arc::new(df_schema)); let simplifier = ExprSimplifier::new(simplify_context).with_max_cycles(10); let simplified = simplifier.simplify(expr).unwrap(); create_expr_test(simplified, expected_expr); diff --git 
a/datafusion/core/tests/expr_api/simplification.rs b/datafusion/core/tests/expr_api/simplification.rs index a42dfc951da0..02f2503faf22 100644 --- a/datafusion/core/tests/expr_api/simplification.rs +++ b/datafusion/core/tests/expr_api/simplification.rs @@ -23,16 +23,16 @@ use arrow::array::types::IntervalDayTime; use arrow::array::{ArrayRef, Int32Array}; use arrow::datatypes::{DataType, Field, Schema}; use chrono::{DateTime, TimeZone, Utc}; -use datafusion::{error::Result, execution::context::ExecutionProps, prelude::*}; +use datafusion::{error::Result, prelude::*}; use datafusion_common::ScalarValue; use datafusion_common::cast::as_int32_array; use datafusion_common::{DFSchemaRef, ToDFSchema}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::logical_plan::builder::table_scan_with_filters; -use datafusion_expr::simplify::SimplifyInfo; +use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::{ - Cast, ColumnarValue, ExprSchemable, LogicalPlan, LogicalPlanBuilder, ScalarUDF, - Volatility, table_scan, + Cast, ColumnarValue, ExprSchemable, LogicalPlan, LogicalPlanBuilder, Projection, + ScalarUDF, Volatility, table_scan, }; use datafusion_functions::math; use datafusion_optimizer::optimizer::Optimizer; @@ -40,50 +40,6 @@ use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyExpress use datafusion_optimizer::{OptimizerContext, OptimizerRule}; use std::sync::Arc; -/// In order to simplify expressions, DataFusion must have information -/// about the expressions. 
-/// -/// You can provide that information using DataFusion [DFSchema] -/// objects or from some other implementation -struct MyInfo { - /// The input schema - schema: DFSchemaRef, - - /// Execution specific details needed for constant evaluation such - /// as the current time for `now()` and [VariableProviders] - execution_props: ExecutionProps, -} - -impl SimplifyInfo for MyInfo { - fn is_boolean_type(&self, expr: &Expr) -> Result { - Ok(matches!( - expr.get_type(self.schema.as_ref())?, - DataType::Boolean - )) - } - - fn nullable(&self, expr: &Expr) -> Result { - expr.nullable(self.schema.as_ref()) - } - - fn execution_props(&self) -> &ExecutionProps { - &self.execution_props - } - - fn get_data_type(&self, expr: &Expr) -> Result { - expr.get_type(self.schema.as_ref()) - } -} - -impl From for MyInfo { - fn from(schema: DFSchemaRef) -> Self { - Self { - schema, - execution_props: ExecutionProps::new(), - } - } -} - /// A schema like: /// /// a: Int32 (possibly with nulls) @@ -132,14 +88,10 @@ fn test_evaluate_with_start_time( expected_expr: Expr, date_time: &DateTime, ) { - let execution_props = - ExecutionProps::new().with_query_execution_start_time(*date_time); - - let info: MyInfo = MyInfo { - schema: schema(), - execution_props, - }; - let simplifier = ExprSimplifier::new(info); + let context = SimplifyContext::default() + .with_schema(schema()) + .with_query_execution_start_time(Some(*date_time)); + let simplifier = ExprSimplifier::new(context); let simplified_expr = simplifier .simplify(input_expr.clone()) .expect("successfully evaluated"); @@ -201,7 +153,9 @@ fn to_timestamp_expr(arg: impl Into) -> Expr { #[test] fn basic() { - let info: MyInfo = schema().into(); + let context = SimplifyContext::default() + .with_schema(schema()) + .with_query_execution_start_time(Some(Utc::now())); // The `Expr` is a core concept in DataFusion, and DataFusion can // help simplify it. 
@@ -210,21 +164,21 @@ fn basic() { // optimize form `a < 5` automatically let expr = col("a").lt(lit(2i32) + lit(3i32)); - let simplifier = ExprSimplifier::new(info); + let simplifier = ExprSimplifier::new(context); let simplified = simplifier.simplify(expr).unwrap(); assert_eq!(simplified, col("a").lt(lit(5i32))); } #[test] fn fold_and_simplify() { - let info: MyInfo = schema().into(); + let context = SimplifyContext::default().with_schema(schema()); // What will it do with the expression `concat('foo', 'bar') == 'foobar')`? let expr = concat(vec![lit("foo"), lit("bar")]).eq(lit("foobar")); // Since datafusion applies both simplification *and* rewriting // some expressions can be entirely simplified - let simplifier = ExprSimplifier::new(info); + let simplifier = ExprSimplifier::new(context); let simplified = simplifier.simplify(expr).unwrap(); assert_eq!(simplified, lit(true)) } @@ -523,6 +477,72 @@ fn multiple_now() -> Result<()> { Ok(()) } +/// Unwraps an alias expression to get the inner expression +fn unrwap_aliases(expr: &Expr) -> &Expr { + match expr { + Expr::Alias(alias) => unrwap_aliases(&alias.expr), + expr => expr, + } +} + +/// Test that `now()` is simplified to a literal when execution start time is set, +/// but remains as an expression when no execution start time is available. +#[test] +fn now_simplification_with_and_without_start_time() { + let plan = LogicalPlanBuilder::empty(false) + .project(vec![now()]) + .unwrap() + .build() + .unwrap(); + + // Case 1: With execution start time set, now() should be simplified to a literal + { + let time = DateTime::::from_timestamp_nanos(123); + let ctx: OptimizerContext = + OptimizerContext::new().with_query_execution_start_time(time); + let optimizer = SimplifyExpressions {}; + let simplified = optimizer + .rewrite(plan.clone(), &ctx) + .expect("rewrite should succeed") + .data; + let LogicalPlan::Projection(Projection { expr, .. 
}) = simplified else { + panic!("Expected Projection plan"); + }; + assert_eq!(expr.len(), 1); + let simplified = unrwap_aliases(expr.first().unwrap()); + // Should be a literal timestamp + match simplified { + Expr::Literal(ScalarValue::TimestampNanosecond(Some(ts), _), _) => { + assert_eq!(*ts, time.timestamp_nanos_opt().unwrap()); + } + other => panic!("Expected timestamp literal, got: {other:?}"), + } + } + + // Case 2: Without execution start time, now() should remain as a function call + { + let ctx: OptimizerContext = + OptimizerContext::new().without_query_execution_start_time(); + let optimizer = SimplifyExpressions {}; + let simplified = optimizer + .rewrite(plan, &ctx) + .expect("rewrite should succeed") + .data; + let LogicalPlan::Projection(Projection { expr, .. }) = simplified else { + panic!("Expected Projection plan"); + }; + assert_eq!(expr.len(), 1); + let simplified = unrwap_aliases(expr.first().unwrap()); + // Should still be a now() function call + match simplified { + Expr::ScalarFunction(ScalarFunction { func, .. 
}) => { + assert_eq!(func.name(), "now"); + } + other => panic!("Expected now() function call, got: {other:?}"), + } + } +} + // ------------------------------ // --- Simplifier tests ----- // ------------------------------ @@ -545,11 +565,8 @@ fn expr_test_schema() -> DFSchemaRef { } fn test_simplify(input_expr: Expr, expected_expr: Expr) { - let info: MyInfo = MyInfo { - schema: expr_test_schema(), - execution_props: ExecutionProps::new(), - }; - let simplifier = ExprSimplifier::new(info); + let context = SimplifyContext::default().with_schema(expr_test_schema()); + let simplifier = ExprSimplifier::new(context); let simplified_expr = simplifier .simplify(input_expr.clone()) .expect("successfully evaluated"); @@ -564,11 +581,10 @@ fn test_simplify_with_cycle_count( expected_expr: Expr, expected_count: u32, ) { - let info: MyInfo = MyInfo { - schema: expr_test_schema(), - execution_props: ExecutionProps::new(), - }; - let simplifier = ExprSimplifier::new(info); + let context = SimplifyContext::default() + .with_schema(expr_test_schema()) + .with_query_execution_start_time(Some(Utc::now())); + let simplifier = ExprSimplifier::new(context); let (simplified_expr, count) = simplifier .simplify_with_cycle_count_transformed(input_expr.clone()) .expect("successfully evaluated"); diff --git a/datafusion/core/tests/fifo/mod.rs b/datafusion/core/tests/fifo/mod.rs index 36cc769417db..3d99cc72fa59 100644 --- a/datafusion/core/tests/fifo/mod.rs +++ b/datafusion/core/tests/fifo/mod.rs @@ -94,7 +94,6 @@ mod unix_test { /// This function creates a writing task for the FIFO file. To verify /// incremental processing, it waits for a signal to continue writing after /// a certain number of lines are written. 
- #[allow(clippy::disallowed_methods)] fn create_writing_task( file_path: PathBuf, header: String, @@ -105,6 +104,7 @@ mod unix_test { // Timeout for a long period of BrokenPipe error let broken_pipe_timeout = Duration::from_secs(10); // Spawn a new task to write to the FIFO file + #[expect(clippy::disallowed_methods)] tokio::spawn(async move { let mut file = tokio::fs::OpenOptions::new() .write(true) @@ -357,7 +357,7 @@ mod unix_test { (sink_fifo_path.clone(), sink_fifo_path.display()); // Spawn a new thread to read sink EXTERNAL TABLE. - #[allow(clippy::disallowed_methods)] // spawn allowed only in tests + #[expect(clippy::disallowed_methods)] // spawn allowed only in tests tasks.push(spawn_blocking(move || { let file = File::open(sink_fifo_path_thread).unwrap(); let schema = Arc::new(Schema::new(vec![ diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 97d1db5728cf..d64223abdb76 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -554,7 +554,7 @@ async fn verify_ordered_aggregate(frame: &DataFrame, expected_sort: bool) { InputOrderMode::PartiallySorted(_) | InputOrderMode::Sorted )); } else { - assert!(matches!(exec.input_order_mode(), InputOrderMode::Linear)); + assert_eq!(*exec.input_order_mode(), InputOrderMode::Linear); } } Ok(TreeNodeRecursion::Continue) diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs index bf71053d6c85..fe31098622c5 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs @@ -214,7 +214,7 @@ impl GeneratedSessionContextBuilder { /// The generated params for [`SessionContext`] #[derive(Debug)] -#[allow(dead_code)] +#[expect(dead_code)] pub struct SessionContextParams { batch_size: 
usize, target_partitions: usize, diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/query_builder.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/query_builder.rs index 0d04e98536f2..7bb6177c3101 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/query_builder.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/query_builder.rs @@ -182,13 +182,13 @@ impl QueryBuilder { /// Add max columns num in group by(default: 3), for example if it is set to 1, /// the generated sql will group by at most 1 column - #[allow(dead_code)] + #[expect(dead_code)] pub fn with_max_group_by_columns(mut self, max_group_by_columns: usize) -> Self { self.max_group_by_columns = max_group_by_columns; self } - #[allow(dead_code)] + #[expect(dead_code)] pub fn with_min_group_by_columns(mut self, min_group_by_columns: usize) -> Self { self.min_group_by_columns = min_group_by_columns; self @@ -202,7 +202,7 @@ impl QueryBuilder { } /// Add if also test the no grouping aggregation case(default: true) - #[allow(dead_code)] + #[expect(dead_code)] pub fn with_no_grouping(mut self, no_grouping: bool) -> Self { self.no_grouping = no_grouping; self diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index ce422494db10..669b98e39fec 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -849,6 +849,7 @@ impl JoinFuzzTestCase { None, PartitionMode::Partitioned, NullEquality::NullEqualsNothing, + false, ) .unwrap(), ) @@ -1086,7 +1087,7 @@ impl JoinFuzzTestCase { /// Files can be of different sizes /// The method can be useful to read partitions have been saved by `save_partitioned_batches_as_parquet` /// for test debugging purposes - #[allow(dead_code)] + #[expect(dead_code)] async fn load_partitioned_batches_from_parquet( dir: &str, ) -> std::io::Result> { diff --git a/datafusion/core/tests/fuzz_cases/once_exec.rs 
b/datafusion/core/tests/fuzz_cases/once_exec.rs index 49e2caaa7417..69edf9be1d82 100644 --- a/datafusion/core/tests/fuzz_cases/once_exec.rs +++ b/datafusion/core/tests/fuzz_cases/once_exec.rs @@ -32,7 +32,7 @@ use std::sync::{Arc, Mutex}; pub struct OnceExec { /// the results to send back stream: Mutex>, - cache: PlanProperties, + cache: Arc, } impl Debug for OnceExec { @@ -46,7 +46,7 @@ impl OnceExec { let cache = Self::compute_properties(stream.schema()); Self { stream: Mutex::new(Some(stream)), - cache, + cache: Arc::new(cache), } } @@ -83,7 +83,7 @@ impl ExecutionPlan for OnceExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } diff --git a/datafusion/core/tests/fuzz_cases/pruning.rs b/datafusion/core/tests/fuzz_cases/pruning.rs index 8a84e4c5d181..8ce5207f9119 100644 --- a/datafusion/core/tests/fuzz_cases/pruning.rs +++ b/datafusion/core/tests/fuzz_cases/pruning.rs @@ -31,7 +31,9 @@ use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_plan::{ExecutionPlan, collect, filter::FilterExec}; use itertools::Itertools; -use object_store::{ObjectStore, PutPayload, memory::InMemory, path::Path}; +use object_store::{ + ObjectStore, ObjectStoreExt, PutPayload, memory::InMemory, path::Path, +}; use parquet::{ arrow::ArrowWriter, file::properties::{EnabledStatistics, WriterProperties}, diff --git a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs index c424a314270c..8f3b8ea05324 100644 --- a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs @@ -301,7 +301,7 @@ mod sp_repartition_fuzz_tests { let mut handles = Vec::new(); for seed in seed_start..seed_end { - #[allow(clippy::disallowed_methods)] // spawn allowed only in tests + #[expect(clippy::disallowed_methods)] // spawn 
allowed only in tests let job = tokio::spawn(run_sort_preserving_repartition_test( make_staggered_batches::(n_row, n_distinct, seed as u64), is_first_roundrobin, diff --git a/datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs b/datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs index 16481516e0be..d401557e966d 100644 --- a/datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs +++ b/datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs @@ -278,9 +278,11 @@ async fn run_sort_test_with_limited_memory( let string_item_size = record_batch_memory_size / record_batch_size as usize; - let string_array = Arc::new(StringArray::from_iter_values( - (0..record_batch_size).map(|_| "a".repeat(string_item_size)), - )); + let string_array = + Arc::new(StringArray::from_iter_values(std::iter::repeat_n( + "a".repeat(string_item_size), + record_batch_size as usize, + ))); RecordBatch::try_new( Arc::clone(&schema), @@ -536,9 +538,11 @@ async fn run_test_aggregate_with_high_cardinality( let string_item_size = record_batch_memory_size / record_batch_size as usize; - let string_array = Arc::new(StringArray::from_iter_values( - (0..record_batch_size).map(|_| "a".repeat(string_item_size)), - )); + let string_array = + Arc::new(StringArray::from_iter_values(std::iter::repeat_n( + "a".repeat(string_item_size), + record_batch_size as usize, + ))); RecordBatch::try_new( Arc::clone(&schema), diff --git a/datafusion/core/tests/fuzz_cases/topk_filter_pushdown.rs b/datafusion/core/tests/fuzz_cases/topk_filter_pushdown.rs index 7f994daeaa58..d14afaf1b326 100644 --- a/datafusion/core/tests/fuzz_cases/topk_filter_pushdown.rs +++ b/datafusion/core/tests/fuzz_cases/topk_filter_pushdown.rs @@ -31,7 +31,7 @@ use datafusion_execution::object_store::ObjectStoreUrl; use itertools::Itertools; use object_store::memory::InMemory; use object_store::path::Path; -use object_store::{ObjectStore, PutPayload}; +use 
object_store::{ObjectStore, ObjectStoreExt, PutPayload}; use parquet::arrow::ArrowWriter; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 1212c081ebe0..82b6d0e4e9d8 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -589,7 +589,7 @@ async fn run_window_test( orderby_columns: Vec<&str>, search_mode: InputOrderMode, ) -> Result<()> { - let is_linear = !matches!(search_mode, Sorted); + let is_linear = search_mode != Sorted; let mut rng = StdRng::seed_from_u64(random_seed); let schema = input1[0].schema(); let session_config = SessionConfig::new().with_batch_size(50); diff --git a/datafusion/core/tests/macro_hygiene/mod.rs b/datafusion/core/tests/macro_hygiene/mod.rs index 48f0103113cf..9fd60cd1f06f 100644 --- a/datafusion/core/tests/macro_hygiene/mod.rs +++ b/datafusion/core/tests/macro_hygiene/mod.rs @@ -73,7 +73,7 @@ mod config_field { #[test] fn test_macro() { #[derive(Debug)] - #[allow(dead_code)] + #[expect(dead_code)] struct E; impl std::fmt::Display for E { @@ -84,7 +84,7 @@ mod config_field { impl std::error::Error for E {} - #[allow(dead_code)] + #[expect(dead_code)] #[derive(Default)] struct S; diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs index c28d23ba0602..0076a762106e 100644 --- a/datafusion/core/tests/memory_limit/mod.rs +++ b/datafusion/core/tests/memory_limit/mod.rs @@ -24,6 +24,7 @@ use std::sync::{Arc, LazyLock}; #[cfg(feature = "extended_tests")] mod memory_limit_validation; mod repartition_mem_limit; +mod union_nullable_spill; use arrow::array::{ArrayRef, DictionaryArray, Int32Array, RecordBatch, StringViewArray}; use arrow::compute::SortOptions; use arrow::datatypes::{Int32Type, SchemaRef}; diff --git a/datafusion/core/tests/memory_limit/union_nullable_spill.rs 
b/datafusion/core/tests/memory_limit/union_nullable_spill.rs new file mode 100644 index 000000000000..c5ef2387d3cd --- /dev/null +++ b/datafusion/core/tests/memory_limit/union_nullable_spill.rs @@ -0,0 +1,162 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::sync::Arc; + +use arrow::array::{Array, Int64Array, RecordBatch}; +use arrow::compute::SortOptions; +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion::datasource::memory::MemorySourceConfig; +use datafusion_execution::config::SessionConfig; +use datafusion_execution::memory_pool::FairSpillPool; +use datafusion_execution::runtime_env::RuntimeEnvBuilder; +use datafusion_physical_expr::expressions::col; +use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_physical_plan::repartition::RepartitionExec; +use datafusion_physical_plan::sorts::sort::sort_batch; +use datafusion_physical_plan::union::UnionExec; +use datafusion_physical_plan::{ExecutionPlan, Partitioning}; +use futures::StreamExt; + +const NUM_BATCHES: usize = 200; +const ROWS_PER_BATCH: usize = 10; + +fn non_nullable_schema() -> Arc { + Arc::new(Schema::new(vec![ + Field::new("key", DataType::Int64, false), + Field::new("val", DataType::Int64, false), + ])) +} + +fn nullable_schema() -> Arc { + Arc::new(Schema::new(vec![ + Field::new("key", DataType::Int64, false), + Field::new("val", DataType::Int64, true), + ])) +} + +fn non_nullable_batches() -> Vec { + (0..NUM_BATCHES) + .map(|i| { + let start = (i * ROWS_PER_BATCH) as i64; + let keys: Vec = (start..start + ROWS_PER_BATCH as i64).collect(); + RecordBatch::try_new( + non_nullable_schema(), + vec![ + Arc::new(Int64Array::from(keys)), + Arc::new(Int64Array::from(vec![0i64; ROWS_PER_BATCH])), + ], + ) + .unwrap() + }) + .collect() +} + +fn nullable_batches() -> Vec { + (0..NUM_BATCHES) + .map(|i| { + let start = (i * ROWS_PER_BATCH) as i64; + let keys: Vec = (start..start + ROWS_PER_BATCH as i64).collect(); + let vals: Vec> = (0..ROWS_PER_BATCH) + .map(|j| if j % 3 == 1 { None } else { Some(j as i64) }) + .collect(); + RecordBatch::try_new( + nullable_schema(), + vec![ + Arc::new(Int64Array::from(keys)), + Arc::new(Int64Array::from(vals)), + ], + ) + .unwrap() + }) + .collect() +} + +fn 
build_task_ctx(pool_size: usize) -> Arc { + let session_config = SessionConfig::new().with_batch_size(2); + let runtime = RuntimeEnvBuilder::new() + .with_memory_pool(Arc::new(FairSpillPool::new(pool_size))) + .build_arc() + .unwrap(); + Arc::new( + datafusion_execution::TaskContext::default() + .with_session_config(session_config) + .with_runtime(runtime), + ) +} + +/// Exercises spilling through UnionExec -> RepartitionExec where union children +/// have mismatched nullability (one child's `val` is non-nullable, the other's +/// is nullable with NULLs). A tiny FairSpillPool forces all batches to spill. +/// +/// UnionExec returns child streams without schema coercion, so batches from +/// different children carry different per-field nullability into the shared +/// SpillPool. The IPC writer must use the SpillManager's canonical (nullable) +/// schema — not the first batch's schema — so readback batches are valid. +/// +/// Otherwise, sort_batch will panic with +/// `Column 'val' is declared as non-nullable but contains null values` +#[tokio::test] +async fn test_sort_union_repartition_spill_mixed_nullability() { + let non_nullable_exec = MemorySourceConfig::try_new_exec( + &[non_nullable_batches()], + non_nullable_schema(), + None, + ) + .unwrap(); + + let nullable_exec = + MemorySourceConfig::try_new_exec(&[nullable_batches()], nullable_schema(), None) + .unwrap(); + + let union_exec = UnionExec::try_new(vec![non_nullable_exec, nullable_exec]).unwrap(); + assert!(union_exec.schema().field(1).is_nullable()); + + let repartition = Arc::new( + RepartitionExec::try_new(union_exec, Partitioning::RoundRobinBatch(1)).unwrap(), + ); + + let task_ctx = build_task_ctx(200); + let mut stream = repartition.execute(0, task_ctx).unwrap(); + + let sort_expr = LexOrdering::new(vec![PhysicalSortExpr { + expr: col("key", &nullable_schema()).unwrap(), + options: SortOptions::default(), + }]) + .unwrap(); + + let mut total_rows = 0usize; + let mut total_nulls = 0usize; + while let 
Some(result) = stream.next().await { + let batch = result.unwrap(); + + let batch = sort_batch(&batch, &sort_expr, None).unwrap(); + + total_rows += batch.num_rows(); + total_nulls += batch.column(1).null_count(); + } + + assert_eq!( + total_rows, + NUM_BATCHES * ROWS_PER_BATCH * 2, + "All rows from both UNION branches should be present" + ); + assert!( + total_nulls > 0, + "Expected some null values in output (i.e. nullable batches were processed)" + ); +} diff --git a/datafusion/core/tests/parquet/custom_reader.rs b/datafusion/core/tests/parquet/custom_reader.rs index 31ec6efd1951..ae11fa9a1133 100644 --- a/datafusion/core/tests/parquet/custom_reader.rs +++ b/datafusion/core/tests/parquet/custom_reader.rs @@ -43,7 +43,7 @@ use futures::{FutureExt, TryFutureExt}; use insta::assert_snapshot; use object_store::memory::InMemory; use object_store::path::Path; -use object_store::{ObjectMeta, ObjectStore}; +use object_store::{ObjectMeta, ObjectStore, ObjectStoreExt}; use parquet::arrow::ArrowWriter; use parquet::arrow::arrow_reader::ArrowReaderOptions; use parquet::arrow::async_reader::AsyncFileReader; @@ -69,13 +69,9 @@ async fn route_data_access_ops_to_parquet_file_reader_factory() { store_parquet_in_memory(vec![batch]).await; let file_group = parquet_files_meta .into_iter() - .map(|meta| PartitionedFile { - object_meta: meta, - partition_values: vec![], - range: None, - statistics: None, - extensions: Some(Arc::new(String::from(EXPECTED_USER_DEFINED_METADATA))), - metadata_size_hint: None, + .map(|meta| { + PartitionedFile::new_from_meta(meta) + .with_extensions(Arc::new(String::from(EXPECTED_USER_DEFINED_METADATA))) }) .collect(); diff --git a/datafusion/core/tests/parquet/expr_adapter.rs b/datafusion/core/tests/parquet/expr_adapter.rs index 515422ed750e..efd492ed2780 100644 --- a/datafusion/core/tests/parquet/expr_adapter.rs +++ b/datafusion/core/tests/parquet/expr_adapter.rs @@ -37,7 +37,7 @@ use datafusion_physical_expr_adapter::{ DefaultPhysicalExprAdapter, 
DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory, }; -use object_store::{ObjectStore, memory::InMemory, path::Path}; +use object_store::{ObjectStore, ObjectStoreExt, memory::InMemory, path::Path}; use parquet::arrow::ArrowWriter; async fn write_parquet(batch: RecordBatch, store: Arc, path: &str) { @@ -63,15 +63,15 @@ impl PhysicalExprAdapterFactory for CustomPhysicalExprAdapterFactory { &self, logical_file_schema: SchemaRef, physical_file_schema: SchemaRef, - ) -> Arc { - Arc::new(CustomPhysicalExprAdapter { + ) -> Result> { + Ok(Arc::new(CustomPhysicalExprAdapter { logical_file_schema: Arc::clone(&logical_file_schema), physical_file_schema: Arc::clone(&physical_file_schema), inner: Arc::new(DefaultPhysicalExprAdapter::new( logical_file_schema, physical_file_schema, )), - }) + })) } } diff --git a/datafusion/core/tests/parquet/external_access_plan.rs b/datafusion/core/tests/parquet/external_access_plan.rs index 0c02c8fe523d..9ff8137687c9 100644 --- a/datafusion/core/tests/parquet/external_access_plan.rs +++ b/datafusion/core/tests/parquet/external_access_plan.rs @@ -409,7 +409,7 @@ fn get_test_data() -> TestData { .expect("tempfile creation"); let props = WriterProperties::builder() - .set_max_row_group_size(row_per_group) + .set_max_row_group_row_count(Some(row_per_group)) .build(); let batches = create_data_batch(scenario); diff --git a/datafusion/core/tests/parquet/filter_pushdown.rs b/datafusion/core/tests/parquet/filter_pushdown.rs index e3a191ee9ade..e6266b2c088d 100644 --- a/datafusion/core/tests/parquet/filter_pushdown.rs +++ b/datafusion/core/tests/parquet/filter_pushdown.rs @@ -63,7 +63,7 @@ async fn single_file() { // Set the row group size smaller so can test with fewer rows let props = WriterProperties::builder() - .set_max_row_group_size(1024) + .set_max_row_group_row_count(Some(1024)) .build(); // Only create the parquet file once as it is fairly large @@ -220,7 +220,6 @@ async fn single_file() { } #[tokio::test] 
-#[allow(dead_code)] async fn single_file_small_data_pages() { let batches = read_parquet_test_data( "tests/data/filter_pushdown/single_file_small_pages.gz.parquet", @@ -231,7 +230,7 @@ async fn single_file_small_data_pages() { // Set a low row count limit to improve page filtering let props = WriterProperties::builder() - .set_max_row_group_size(2048) + .set_max_row_group_row_count(Some(2048)) .set_data_page_row_count_limit(512) .set_write_batch_size(512) .build(); @@ -644,6 +643,22 @@ async fn predicate_cache_pushdown_default() -> datafusion_common::Result<()> { .await } +#[tokio::test] +async fn predicate_cache_stats_issue_19561() -> datafusion_common::Result<()> { + let mut config = SessionConfig::new(); + config.options_mut().execution.parquet.pushdown_filters = true; + // force to get multiple batches to trigger repeated metric compound bug + config.options_mut().execution.batch_size = 1; + let ctx = SessionContext::new_with_config(config); + // The cache is on by default, and used when filter pushdown is enabled + PredicateCacheTest { + expected_inner_records: 8, + expected_records: 4, + } + .run(&ctx) + .await +} + #[tokio::test] async fn predicate_cache_pushdown_default_selections_only() -> datafusion_common::Result<()> { diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index 35b5918d9e8b..0535ddd9247d 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -30,6 +30,7 @@ use arrow::{ record_batch::RecordBatch, util::pretty::pretty_format_batches, }; +use arrow_schema::SchemaRef; use chrono::{Datelike, Duration, TimeDelta}; use datafusion::{ datasource::{TableProvider, provider_as_source}, @@ -50,6 +51,7 @@ mod expr_adapter; mod external_access_plan; mod file_statistics; mod filter_pushdown; +mod ordering; mod page_pruning; mod row_group_pruning; mod schema; @@ -109,6 +111,26 @@ struct ContextWithParquet { ctx: SessionContext, } +struct PruningMetric { + total_pruned: usize, + 
total_matched: usize, + total_fully_matched: usize, +} + +impl PruningMetric { + pub fn total_pruned(&self) -> usize { + self.total_pruned + } + + pub fn total_matched(&self) -> usize { + self.total_matched + } + + pub fn total_fully_matched(&self) -> usize { + self.total_fully_matched + } +} + /// The output of running one of the test cases struct TestOutput { /// The input query SQL @@ -126,8 +148,8 @@ struct TestOutput { impl TestOutput { /// retrieve the value of the named metric, if any fn metric_value(&self, metric_name: &str) -> Option { - if let Some((pruned, _matched)) = self.pruning_metric(metric_name) { - return Some(pruned); + if let Some(pm) = self.pruning_metric(metric_name) { + return Some(pm.total_pruned()); } self.parquet_metrics @@ -140,9 +162,10 @@ impl TestOutput { }) } - fn pruning_metric(&self, metric_name: &str) -> Option<(usize, usize)> { + fn pruning_metric(&self, metric_name: &str) -> Option { let mut total_pruned = 0; let mut total_matched = 0; + let mut total_fully_matched = 0; let mut found = false; for metric in self.parquet_metrics.iter() { @@ -154,12 +177,18 @@ impl TestOutput { { total_pruned += pruning_metrics.pruned(); total_matched += pruning_metrics.matched(); + total_fully_matched += pruning_metrics.fully_matched(); + found = true; } } if found { - Some((total_pruned, total_matched)) + Some(PruningMetric { + total_pruned, + total_matched, + total_fully_matched, + }) } else { None } @@ -171,27 +200,33 @@ impl TestOutput { } /// The number of row_groups pruned / matched by bloom filter - fn row_groups_bloom_filter(&self) -> Option<(usize, usize)> { + fn row_groups_bloom_filter(&self) -> Option { self.pruning_metric("row_groups_pruned_bloom_filter") } /// The number of row_groups matched by statistics fn row_groups_matched_statistics(&self) -> Option { self.pruning_metric("row_groups_pruned_statistics") - .map(|(_pruned, matched)| matched) + .map(|pm| pm.total_matched()) + } + + /// The number of row_groups fully matched by 
statistics + fn row_groups_fully_matched_statistics(&self) -> Option { + self.pruning_metric("row_groups_pruned_statistics") + .map(|pm| pm.total_fully_matched()) } /// The number of row_groups pruned by statistics fn row_groups_pruned_statistics(&self) -> Option { self.pruning_metric("row_groups_pruned_statistics") - .map(|(pruned, _matched)| pruned) + .map(|pm| pm.total_pruned()) } /// Metric `files_ranges_pruned_statistics` tracks both pruned and matched count, /// for testing purpose, here it only aggregate the `pruned` count. fn files_ranges_pruned_statistics(&self) -> Option { self.pruning_metric("files_ranges_pruned_statistics") - .map(|(pruned, _matched)| pruned) + .map(|pm| pm.total_pruned()) } /// The number of row_groups matched by bloom filter or statistics @@ -200,14 +235,13 @@ impl TestOutput { /// filter: 7 total -> 3 matched, this function returns 3 for the final matched /// count. fn row_groups_matched(&self) -> Option { - self.row_groups_bloom_filter() - .map(|(_pruned, matched)| matched) + self.row_groups_bloom_filter().map(|pm| pm.total_matched()) } /// The number of row_groups pruned fn row_groups_pruned(&self) -> Option { self.row_groups_bloom_filter() - .map(|(pruned, _matched)| pruned) + .map(|pm| pm.total_pruned()) .zip(self.row_groups_pruned_statistics()) .map(|(a, b)| a + b) } @@ -215,7 +249,13 @@ impl TestOutput { /// The number of row pages pruned fn row_pages_pruned(&self) -> Option { self.pruning_metric("page_index_rows_pruned") - .map(|(pruned, _matched)| pruned) + .map(|pm| pm.total_pruned()) + } + + /// The number of row groups pruned by limit pruning + fn limit_pruned_row_groups(&self) -> Option { + self.pruning_metric("limit_pruned_row_groups") + .map(|pm| pm.total_pruned()) } fn description(&self) -> String { @@ -231,20 +271,41 @@ impl TestOutput { /// and the appropriate scenario impl ContextWithParquet { async fn new(scenario: Scenario, unit: Unit) -> Self { - Self::with_config(scenario, unit, SessionConfig::new()).await + 
Self::with_config(scenario, unit, SessionConfig::new(), None, None).await + } + + /// Set custom schema and batches for the test + pub async fn with_custom_data( + scenario: Scenario, + unit: Unit, + schema: Arc, + batches: Vec, + ) -> Self { + Self::with_config( + scenario, + unit, + SessionConfig::new(), + Some(schema), + Some(batches), + ) + .await } async fn with_config( scenario: Scenario, unit: Unit, mut config: SessionConfig, + custom_schema: Option, + custom_batches: Option>, ) -> Self { // Use a single partition for deterministic results no matter how many CPUs the host has config = config.with_target_partitions(1); let file = match unit { Unit::RowGroup(row_per_group) => { config = config.with_parquet_bloom_filter_pruning(true); - make_test_file_rg(scenario, row_per_group).await + config.options_mut().execution.parquet.pushdown_filters = true; + make_test_file_rg(scenario, row_per_group, custom_schema, custom_batches) + .await } Unit::Page(row_per_page) => { config = config.with_parquet_page_index_pruning(true); @@ -515,9 +576,9 @@ fn make_uint_batches(start: u8, end: u8) -> RecordBatch { Field::new("u64", DataType::UInt64, true), ])); let v8: Vec = (start..end).collect(); - let v16: Vec = (start as _..end as _).collect(); - let v32: Vec = (start as _..end as _).collect(); - let v64: Vec = (start as _..end as _).collect(); + let v16: Vec = (start as u16..end as u16).collect(); + let v32: Vec = (start as u32..end as u32).collect(); + let v64: Vec = (start as u64..end as u64).collect(); RecordBatch::try_new( schema, vec![ @@ -1074,7 +1135,12 @@ fn create_data_batch(scenario: Scenario) -> Vec { } /// Create a test parquet file with various data types -async fn make_test_file_rg(scenario: Scenario, row_per_group: usize) -> NamedTempFile { +async fn make_test_file_rg( + scenario: Scenario, + row_per_group: usize, + custom_schema: Option, + custom_batches: Option>, +) -> NamedTempFile { let mut output_file = tempfile::Builder::new() .prefix("parquet_pruning") 
.suffix(".parquet") @@ -1082,13 +1148,19 @@ async fn make_test_file_rg(scenario: Scenario, row_per_group: usize) -> NamedTem .expect("tempfile creation"); let props = WriterProperties::builder() - .set_max_row_group_size(row_per_group) + .set_max_row_group_row_count(Some(row_per_group)) .set_bloom_filter_enabled(true) .set_statistics_enabled(EnabledStatistics::Page) .build(); - let batches = create_data_batch(scenario); - let schema = batches[0].schema(); + let (batches, schema) = + if let (Some(schema), Some(batches)) = (custom_schema, custom_batches) { + (batches, schema) + } else { + let batches = create_data_batch(scenario); + let schema = batches[0].schema(); + (batches, schema) + }; let mut writer = ArrowWriter::try_new(&mut output_file, schema, Some(props)).unwrap(); diff --git a/datafusion/core/tests/parquet/ordering.rs b/datafusion/core/tests/parquet/ordering.rs new file mode 100644 index 000000000000..faecb4ca6a86 --- /dev/null +++ b/datafusion/core/tests/parquet/ordering.rs @@ -0,0 +1,103 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Tests for ordering in Parquet sorting_columns metadata + +use datafusion::prelude::SessionContext; +use datafusion_common::Result; +use tempfile::tempdir; + +/// Test that CREATE TABLE ... WITH ORDER writes sorting_columns to Parquet metadata +#[tokio::test] +async fn test_create_table_with_order_writes_sorting_columns() -> Result<()> { + use parquet::file::reader::FileReader; + use parquet::file::serialized_reader::SerializedFileReader; + use std::fs::File; + + let ctx = SessionContext::new(); + let tmp_dir = tempdir()?; + let table_path = tmp_dir.path().join("sorted_table"); + std::fs::create_dir_all(&table_path)?; + + // Create external table with ordering + let create_table_sql = format!( + "CREATE EXTERNAL TABLE sorted_data (a INT, b VARCHAR) \ + STORED AS PARQUET \ + LOCATION '{}' \ + WITH ORDER (a ASC NULLS FIRST, b DESC NULLS LAST)", + table_path.display() + ); + ctx.sql(&create_table_sql).await?; + + // Insert sorted data + ctx.sql("INSERT INTO sorted_data VALUES (1, 'x'), (2, 'y'), (3, 'z')") + .await? + .collect() + .await?; + + // Find the parquet file that was written + let parquet_files: Vec<_> = std::fs::read_dir(&table_path)? 
+ .filter_map(|e| e.ok()) + .filter(|e| e.path().extension().is_some_and(|ext| ext == "parquet")) + .collect(); + + assert!( + !parquet_files.is_empty(), + "Expected at least one parquet file in {}", + table_path.display() + ); + + // Read the parquet file and verify sorting_columns metadata + let file = File::open(parquet_files[0].path())?; + let reader = SerializedFileReader::new(file)?; + let metadata = reader.metadata(); + + // Check that row group has sorting_columns + let row_group = metadata.row_group(0); + let sorting_columns = row_group.sorting_columns(); + + assert!( + sorting_columns.is_some(), + "Expected sorting_columns in row group metadata" + ); + let sorting = sorting_columns.unwrap(); + assert_eq!(sorting.len(), 2, "Expected 2 sorting columns"); + + // First column: a ASC NULLS FIRST (column_idx = 0) + assert_eq!(sorting[0].column_idx, 0, "First sort column should be 'a'"); + assert!( + !sorting[0].descending, + "First column should be ASC (descending=false)" + ); + assert!( + sorting[0].nulls_first, + "First column should have NULLS FIRST" + ); + + // Second column: b DESC NULLS LAST (column_idx = 1) + assert_eq!(sorting[1].column_idx, 1, "Second sort column should be 'b'"); + assert!( + sorting[1].descending, + "Second column should be DESC (descending=true)" + ); + assert!( + !sorting[1].nulls_first, + "Second column should have NULLS LAST" + ); + + Ok(()) +} diff --git a/datafusion/core/tests/parquet/page_pruning.rs b/datafusion/core/tests/parquet/page_pruning.rs index 17392974b63a..6d49e0bcc676 100644 --- a/datafusion/core/tests/parquet/page_pruning.rs +++ b/datafusion/core/tests/parquet/page_pruning.rs @@ -20,7 +20,8 @@ use std::sync::Arc; use crate::parquet::Unit::Page; use crate::parquet::{ContextWithParquet, Scenario}; -use arrow::array::RecordBatch; +use arrow::array::{Int32Array, RecordBatch}; +use arrow::datatypes::{DataType, Field, Schema}; use datafusion::datasource::file_format::FileFormat; use 
datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::PartitionedFile; @@ -30,7 +31,7 @@ use datafusion::datasource::source::DataSourceExec; use datafusion::execution::context::SessionState; use datafusion::physical_plan::ExecutionPlan; use datafusion::physical_plan::metrics::MetricValue; -use datafusion::prelude::SessionContext; +use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::{ScalarValue, ToDFSchema}; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::{Expr, col, lit}; @@ -40,6 +41,8 @@ use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use futures::StreamExt; use object_store::ObjectMeta; use object_store::path::Path; +use parquet::arrow::ArrowWriter; +use parquet::file::properties::WriterProperties; async fn get_parquet_exec( state: &SessionState, @@ -67,14 +70,7 @@ async fn get_parquet_exec( .await .unwrap(); - let partitioned_file = PartitionedFile { - object_meta: meta, - partition_values: vec![], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - }; + let partitioned_file = PartitionedFile::new_from_meta(meta); let df_schema = schema.clone().to_dfschema().unwrap(); let execution_props = ExecutionProps::new(); @@ -968,3 +964,56 @@ fn cast_count_metric(metric: MetricValue) -> Option { _ => None, } } + +#[tokio::test] +async fn test_parquet_opener_without_page_index() { + // Defines a simple schema and batch + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + + // Create a temp file + let file = tempfile::Builder::new() + .suffix(".parquet") + .tempfile() + .unwrap(); + let path = file.path().to_str().unwrap().to_string(); + + // Write parquet WITHOUT page index + // The default WriterProperties does not write page index, but we set it explicitly + // to be 
robust against future changes in defaults as requested by reviewers. + let props = WriterProperties::builder() + .set_statistics_enabled(parquet::file::properties::EnabledStatistics::None) + .build(); + + let file_fs = std::fs::File::create(&path).unwrap(); + let mut writer = ArrowWriter::try_new(file_fs, batch.schema(), Some(props)).unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + + // Setup SessionContext with PageIndex enabled + // This triggers the ParquetOpener to try and load page index if available + let config = SessionConfig::new().with_parquet_page_index_pruning(true); + + let ctx = SessionContext::new_with_config(config); + + // Register the table + ctx.register_parquet("t", &path, Default::default()) + .await + .unwrap(); + + // Query the table + // If the bug exists, this might fail because Opener tries to load PageIndex forcefully + let df = ctx.sql("SELECT * FROM t").await.unwrap(); + let batches = df + .collect() + .await + .expect("Failed to read parquet file without page index"); + + // We expect this to succeed, but currently it might fail + assert_eq!(batches.len(), 1); + assert_eq!(batches[0].num_rows(), 3); +} diff --git a/datafusion/core/tests/parquet/row_group_pruning.rs b/datafusion/core/tests/parquet/row_group_pruning.rs index 0411298055f2..445ae7e97f22 100644 --- a/datafusion/core/tests/parquet/row_group_pruning.rs +++ b/datafusion/core/tests/parquet/row_group_pruning.rs @@ -18,8 +18,12 @@ //! This file contains an end to end test of parquet pruning. It writes //! data into a parquet file and then verifies row groups are pruned as //! expected. 
+use std::sync::Arc; + +use arrow::array::{ArrayRef, Int32Array, RecordBatch}; +use arrow_schema::{DataType, Field, Schema}; use datafusion::prelude::SessionConfig; -use datafusion_common::ScalarValue; +use datafusion_common::{DataFusionError, ScalarValue}; use itertools::Itertools; use crate::parquet::Unit::RowGroup; @@ -30,10 +34,12 @@ struct RowGroupPruningTest { query: String, expected_errors: Option, expected_row_group_matched_by_statistics: Option, + expected_row_group_fully_matched_by_statistics: Option, expected_row_group_pruned_by_statistics: Option, expected_files_pruned_by_statistics: Option, expected_row_group_matched_by_bloom_filter: Option, expected_row_group_pruned_by_bloom_filter: Option, + expected_limit_pruned_row_groups: Option, expected_rows: usize, } impl RowGroupPruningTest { @@ -45,9 +51,11 @@ impl RowGroupPruningTest { expected_errors: None, expected_row_group_matched_by_statistics: None, expected_row_group_pruned_by_statistics: None, + expected_row_group_fully_matched_by_statistics: None, expected_files_pruned_by_statistics: None, expected_row_group_matched_by_bloom_filter: None, expected_row_group_pruned_by_bloom_filter: None, + expected_limit_pruned_row_groups: None, expected_rows: 0, } } @@ -76,6 +84,15 @@ impl RowGroupPruningTest { self } + // Set the expected fully matched row groups by statistics + fn with_fully_matched_by_stats( + mut self, + fully_matched_by_stats: Option, + ) -> Self { + self.expected_row_group_fully_matched_by_statistics = fully_matched_by_stats; + self + } + // Set the expected pruned row groups by statistics fn with_pruned_by_stats(mut self, pruned_by_stats: Option) -> Self { self.expected_row_group_pruned_by_statistics = pruned_by_stats; @@ -99,6 +116,11 @@ impl RowGroupPruningTest { self } + fn with_limit_pruned_row_groups(mut self, pruned_by_limit: Option) -> Self { + self.expected_limit_pruned_row_groups = pruned_by_limit; + self + } + /// Set the number of expected rows from the output of this test fn 
with_expected_rows(mut self, rows: usize) -> Self { self.expected_rows = rows; @@ -135,15 +157,74 @@ impl RowGroupPruningTest { ); let bloom_filter_metrics = output.row_groups_bloom_filter(); assert_eq!( - bloom_filter_metrics.map(|(_pruned, matched)| matched), + bloom_filter_metrics.as_ref().map(|pm| pm.total_matched()), self.expected_row_group_matched_by_bloom_filter, "mismatched row_groups_matched_bloom_filter", ); assert_eq!( - bloom_filter_metrics.map(|(pruned, _matched)| pruned), + bloom_filter_metrics.map(|pm| pm.total_pruned()), self.expected_row_group_pruned_by_bloom_filter, "mismatched row_groups_pruned_bloom_filter", ); + + assert_eq!( + output.result_rows, + self.expected_rows, + "Expected {} rows, got {}: {}", + output.result_rows, + self.expected_rows, + output.description(), + ); + } + + // Execute the test with the current configuration + async fn test_row_group_prune_with_custom_data( + self, + schema: Arc, + batches: Vec, + max_row_per_group: usize, + ) { + let output = ContextWithParquet::with_custom_data( + self.scenario, + RowGroup(max_row_per_group), + schema, + batches, + ) + .await + .query(&self.query) + .await; + + println!("{}", output.description()); + assert_eq!( + output.predicate_evaluation_errors(), + self.expected_errors, + "mismatched predicate_evaluation error" + ); + assert_eq!( + output.row_groups_matched_statistics(), + self.expected_row_group_matched_by_statistics, + "mismatched row_groups_matched_statistics", + ); + assert_eq!( + output.row_groups_fully_matched_statistics(), + self.expected_row_group_fully_matched_by_statistics, + "mismatched row_groups_fully_matched_statistics", + ); + assert_eq!( + output.row_groups_pruned_statistics(), + self.expected_row_group_pruned_by_statistics, + "mismatched row_groups_pruned_statistics", + ); + assert_eq!( + output.files_ranges_pruned_statistics(), + self.expected_files_pruned_by_statistics, + "mismatched files_ranges_pruned_statistics", + ); + assert_eq!( + 
output.limit_pruned_row_groups(), + self.expected_limit_pruned_row_groups, + "mismatched limit_pruned_row_groups", + ); assert_eq!( output.result_rows, self.expected_rows, @@ -289,11 +370,16 @@ async fn prune_disabled() { let expected_rows = 10; let config = SessionConfig::new().with_parquet_pruning(false); - let output = - ContextWithParquet::with_config(Scenario::Timestamps, RowGroup(5), config) - .await - .query(query) - .await; + let output = ContextWithParquet::with_config( + Scenario::Timestamps, + RowGroup(5), + config, + None, + None, + ) + .await + .query(query) + .await; println!("{}", output.description()); // This should not prune any @@ -1636,3 +1722,240 @@ async fn test_bloom_filter_decimal_dict() { .test_row_group_prune() .await; } + +// Helper function to create a batch with a single Int32 column. +fn make_i32_batch( + name: &str, + values: Vec, +) -> datafusion_common::error::Result { + let schema = Arc::new(Schema::new(vec![Field::new(name, DataType::Int32, false)])); + let array: ArrayRef = Arc::new(Int32Array::from(values)); + RecordBatch::try_new(schema, vec![array]).map_err(DataFusionError::from) +} + +// Helper function to create a batch with two Int32 columns +fn make_two_col_i32_batch( + name_a: &str, + name_b: &str, + values_a: Vec, + values_b: Vec, +) -> datafusion_common::error::Result { + let schema = Arc::new(Schema::new(vec![ + Field::new(name_a, DataType::Int32, false), + Field::new(name_b, DataType::Int32, false), + ])); + let array_a: ArrayRef = Arc::new(Int32Array::from(values_a)); + let array_b: ArrayRef = Arc::new(Int32Array::from(values_b)); + RecordBatch::try_new(schema, vec![array_a, array_b]).map_err(DataFusionError::from) +} + +#[tokio::test] +async fn test_limit_pruning_basic() -> datafusion_common::error::Result<()> { + // Scenario: Simple integer column, multiple row groups + // Query: SELECT c1 FROM t WHERE c1 = 0 LIMIT 2 + // We expect 2 rows in total. 
+ + // Row Group 0: c1 = [0, -2] -> Partially matched, 1 row + // Row Group 1: c1 = [1, 2] -> Fully matched, 2 rows + // Row Group 2: c1 = [3, 4] -> Fully matched, 2 rows + // Row Group 3: c1 = [5, 6] -> Fully matched, 2 rows + // Row Group 4: c1 = [-1, -2] -> Not matched + + // If limit = 2, and RG1 is fully matched and has 2 rows, we should + // only scan RG1 and prune other row groups + // RG4 is pruned by statistics. RG2 and RG3 are pruned by limit. + // So 2 row groups are effectively pruned due to limit pruning. + + let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); + let query = "SELECT c1 FROM t WHERE c1 >= 0 LIMIT 2"; + + let batches = vec![ + make_i32_batch("c1", vec![0, -2])?, + make_i32_batch("c1", vec![0, 0])?, + make_i32_batch("c1", vec![0, 0])?, + make_i32_batch("c1", vec![0, 0])?, + make_i32_batch("c1", vec![-1, -2])?, + ]; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) // Assuming Scenario::Int can handle this data + .with_query(query) + .with_expected_errors(Some(0)) + .with_expected_rows(2) + .with_pruned_files(Some(0)) + .with_matched_by_stats(Some(4)) + .with_fully_matched_by_stats(Some(3)) + .with_pruned_by_stats(Some(1)) + .with_limit_pruned_row_groups(Some(3)) + .test_row_group_prune_with_custom_data(schema, batches, 2) + .await; + + Ok(()) +} + +#[tokio::test] +async fn test_limit_pruning_complex_filter() -> datafusion_common::error::Result<()> { + // Test Case 1: Complex filter with two columns (a = 1 AND b > 1 AND b < 4) + // Row Group 0: a=[1,1,1], b=[0,2,3] -> Partially matched, 2 rows match (b=2,3) + // Row Group 1: a=[1,1,1], b=[2,2,2] -> Fully matched, 3 rows + // Row Group 2: a=[1,1,1], b=[2,3,3] -> Fully matched, 3 rows + // Row Group 3: a=[1,1,1], b=[2,2,3] -> Fully matched, 3 rows + // Row Group 4: a=[2,2,2], b=[2,2,2] -> Not matched (a != 1) + // Row Group 5: a=[1,1,1], b=[5,6,7] -> Not matched (b >= 4) + + // With LIMIT 5, we need RG1 (3 rows) + RG2 (2 rows from 3) = 5 rows + // 
RG4 and RG5 should be pruned by statistics + // RG3 should be pruned by limit + // RG0 is partially matched, so it depends on the order + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + let query = "SELECT a, b FROM t WHERE a = 1 AND b > 1 AND b < 4 LIMIT 5"; + + let batches = vec![ + make_two_col_i32_batch("a", "b", vec![1, 1, 1], vec![0, 2, 3])?, + make_two_col_i32_batch("a", "b", vec![1, 1, 1], vec![2, 2, 2])?, + make_two_col_i32_batch("a", "b", vec![1, 1, 1], vec![2, 3, 3])?, + make_two_col_i32_batch("a", "b", vec![1, 1, 1], vec![2, 2, 3])?, + make_two_col_i32_batch("a", "b", vec![2, 2, 2], vec![2, 2, 2])?, + make_two_col_i32_batch("a", "b", vec![1, 1, 1], vec![5, 6, 7])?, + ]; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(query) + .with_expected_errors(Some(0)) + .with_expected_rows(5) + .with_pruned_files(Some(0)) + .with_matched_by_stats(Some(4)) // RG0,1,2,3 are matched + .with_fully_matched_by_stats(Some(3)) + .with_pruned_by_stats(Some(2)) // RG4,5 are pruned + .with_limit_pruned_row_groups(Some(2)) // RG0, RG3 is pruned by limit + .test_row_group_prune_with_custom_data(schema, batches, 3) + .await; + + Ok(()) +} + +#[tokio::test] +async fn test_limit_pruning_multiple_fully_matched() +-> datafusion_common::error::Result<()> { + // Test Case 2: Limit requires multiple fully matched row groups + // Row Group 0: a=[5,5,5,5] -> Fully matched, 4 rows + // Row Group 1: a=[5,5,5,5] -> Fully matched, 4 rows + // Row Group 2: a=[5,5,5,5] -> Fully matched, 4 rows + // Row Group 3: a=[5,5,5,5] -> Fully matched, 4 rows + // Row Group 4: a=[1,2,3,4] -> Not matched + + // With LIMIT 8, we need RG0 (4 rows) + RG1 (4 rows) 8 rows + // RG2,3 should be pruned by limit + // RG4 should be pruned by statistics + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let query = "SELECT a FROM t WHERE a = 5 LIMIT 8"; + + let 
batches = vec![ + make_i32_batch("a", vec![5, 5, 5, 5])?, + make_i32_batch("a", vec![5, 5, 5, 5])?, + make_i32_batch("a", vec![5, 5, 5, 5])?, + make_i32_batch("a", vec![5, 5, 5, 5])?, + make_i32_batch("a", vec![1, 2, 3, 4])?, + ]; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(query) + .with_expected_errors(Some(0)) + .with_expected_rows(8) + .with_pruned_files(Some(0)) + .with_matched_by_stats(Some(4)) // RG0,1,2,3 matched + .with_fully_matched_by_stats(Some(4)) + .with_pruned_by_stats(Some(1)) // RG4 pruned + .with_limit_pruned_row_groups(Some(2)) // RG2,3 pruned by limit + .test_row_group_prune_with_custom_data(schema, batches, 4) + .await; + + Ok(()) +} + +#[tokio::test] +async fn test_limit_pruning_no_fully_matched() -> datafusion_common::error::Result<()> { + // Test Case 3: No fully matched row groups - all are partially matched + // Row Group 0: a=[1,2,3] -> Partially matched, 1 row (a=2) + // Row Group 1: a=[2,3,4] -> Partially matched, 1 row (a=2) + // Row Group 2: a=[2,5,6] -> Partially matched, 1 row (a=2) + // Row Group 3: a=[2,7,8] -> Partially matched, 1 row (a=2) + // Row Group 4: a=[9,10,11] -> Not matched + + // With LIMIT 3, we need to scan RG0,1,2 to get 3 matching rows + // Cannot prune much by limit since all matching RGs are partial + // RG4 should be pruned by statistics + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let query = "SELECT a FROM t WHERE a = 2 LIMIT 3"; + + let batches = vec![ + make_i32_batch("a", vec![1, 2, 3])?, + make_i32_batch("a", vec![2, 3, 4])?, + make_i32_batch("a", vec![2, 5, 6])?, + make_i32_batch("a", vec![2, 7, 8])?, + make_i32_batch("a", vec![9, 10, 11])?, + ]; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(query) + .with_expected_errors(Some(0)) + .with_expected_rows(3) + .with_pruned_files(Some(0)) + .with_matched_by_stats(Some(4)) // RG0,1,2,3 matched + .with_fully_matched_by_stats(Some(0)) + 
.with_pruned_by_stats(Some(1)) // RG4 pruned + .with_limit_pruned_row_groups(Some(0)) // RG3 pruned by limit + .test_row_group_prune_with_custom_data(schema, batches, 3) + .await; + + Ok(()) +} + +#[tokio::test] +async fn test_limit_pruning_exceeds_fully_matched() -> datafusion_common::error::Result<()> +{ + // Test Case 4: Limit exceeds all fully matched rows, need partially matched + // Row Group 0: a=[10,11,12,12] -> Partially matched, 1 row (a=10) + // Row Group 1: a=[10,10,10,10] -> Fully matched, 4 rows + // Row Group 2: a=[10,10,10,10] -> Fully matched, 4 rows + // Row Group 3: a=[10,13,14,11] -> Partially matched, 1 row (a=10) + // Row Group 4: a=[20,21,22,22] -> Not matched + + // With LIMIT 10, we need RG1 (4) + RG2 (4) = 8 from fully matched + // Still need 2 more, so we need to scan partially matched RG0 and RG3 + // All matching row groups should be scanned, only RG4 pruned by statistics + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let query = "SELECT a FROM t WHERE a = 10 LIMIT 10"; + + let batches = vec![ + make_i32_batch("a", vec![10, 11, 12, 12])?, + make_i32_batch("a", vec![10, 10, 10, 10])?, + make_i32_batch("a", vec![10, 10, 10, 10])?, + make_i32_batch("a", vec![10, 13, 14, 11])?, + make_i32_batch("a", vec![20, 21, 22, 22])?, + ]; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(query) + .with_expected_errors(Some(0)) + .with_expected_rows(10) // Total: 1 + 4 + 4 + 1 = 10 + .with_pruned_files(Some(0)) + .with_matched_by_stats(Some(4)) // RG0,1,2,3 matched + .with_fully_matched_by_stats(Some(2)) + .with_pruned_by_stats(Some(1)) // RG4 pruned + .with_limit_pruned_row_groups(Some(0)) // No limit pruning since we need all RGs + .test_row_group_prune_with_custom_data(schema, batches, 4) + .await; + Ok(()) +} diff --git a/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs b/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs index 
1fdc0ae6c7f6..4218f76fa135 100644 --- a/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs @@ -20,11 +20,15 @@ use std::sync::Arc; use crate::physical_optimizer::test_utils::TestAggregate; use arrow::array::Int32Array; +use arrow::array::{Int64Array, StringArray}; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; +use datafusion::datasource::memory::MemTable; use datafusion::datasource::memory::MemorySourceConfig; use datafusion::datasource::source::DataSourceExec; +use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::Result; +use datafusion_common::assert_batches_eq; use datafusion_common::cast::as_int64_array; use datafusion_common::config::ConfigOptions; use datafusion_execution::TaskContext; @@ -38,6 +42,7 @@ use datafusion_physical_plan::aggregates::AggregateMode; use datafusion_physical_plan::aggregates::PhysicalGroupBy; use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::common; +use datafusion_physical_plan::displayable; use datafusion_physical_plan::filter::FilterExec; use datafusion_physical_plan::projection::ProjectionExec; @@ -316,3 +321,84 @@ async fn test_count_with_nulls_inexact_stat() -> Result<()> { Ok(()) } + +/// Tests that TopK aggregation correctly handles UTF-8 (string) types in both grouping keys and aggregate values. +/// +/// The TopK optimization is designed to efficiently handle `GROUP BY ... ORDER BY aggregate LIMIT n` queries +/// by maintaining only the top K groups during aggregation. However, not all type combinations are supported. +/// +/// This test verifies two scenarios: +/// 1. **Supported case**: UTF-8 grouping key with numeric aggregate (max/min) - should use TopK optimization +/// 2. 
**Unsupported case**: UTF-8 grouping key with UTF-8 aggregate value - must gracefully fall back to +/// standard aggregation without panicking +/// +/// The fallback behavior is critical because attempting to use TopK with unsupported types could cause +/// runtime panics. This test ensures the optimizer correctly detects incompatible types and chooses +/// the appropriate execution path. +#[tokio::test] +async fn utf8_grouping_min_max_limit_fallbacks() -> Result<()> { + let mut config = SessionConfig::new(); + config.options_mut().optimizer.enable_topk_aggregation = true; + let ctx = SessionContext::new_with_config(config); + + let batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![ + Field::new("g", DataType::Utf8, false), + Field::new("val_str", DataType::Utf8, false), + Field::new("val_num", DataType::Int64, false), + ])), + vec![ + Arc::new(StringArray::from(vec!["a", "b", "a"])), + Arc::new(StringArray::from(vec!["alpha", "bravo", "charlie"])), + Arc::new(Int64Array::from(vec![1, 2, 3])), + ], + )?; + let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?; + ctx.register_table("t", Arc::new(table))?; + + // Supported path: numeric min/max with UTF-8 grouping should still use TopK aggregation + // and return correct results. + let supported_df = ctx + .sql("SELECT g, max(val_num) AS m FROM t GROUP BY g ORDER BY m DESC LIMIT 1") + .await?; + let supported_batches = supported_df.collect().await?; + assert_batches_eq!( + &[ + "+---+---+", + "| g | m |", + "+---+---+", + "| a | 3 |", + "+---+---+" + ], + &supported_batches + ); + + // Unsupported TopK value type: string min/max should fall back without panicking. + let unsupported_df = ctx + .sql("SELECT g, max(val_str) AS s FROM t GROUP BY g ORDER BY s DESC LIMIT 1") + .await?; + let unsupported_plan = unsupported_df.clone().create_physical_plan().await?; + let unsupported_batches = unsupported_df.collect().await?; + + // Ensure the plan avoided the TopK-specific stream implementation. 
+ let plan_display = displayable(unsupported_plan.as_ref()) + .indent(true) + .to_string(); + assert!( + !plan_display.contains("GroupedTopKAggregateStream"), + "Unsupported UTF-8 aggregate value should not use TopK: {plan_display}" + ); + + assert_batches_eq!( + &[ + "+---+---------+", + "| g | s |", + "+---+---------+", + "| a | charlie |", + "+---+---------+" + ], + &unsupported_batches + ); + + Ok(()) +} diff --git a/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs index 2fdfece2a86e..9e63c341c92d 100644 --- a/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs @@ -37,7 +37,7 @@ use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_optimizer::combine_partial_final_agg::CombinePartialFinalAggregate; use datafusion_physical_plan::ExecutionPlan; use datafusion_physical_plan::aggregates::{ - AggregateExec, AggregateMode, PhysicalGroupBy, + AggregateExec, AggregateMode, LimitOptions, PhysicalGroupBy, }; use datafusion_physical_plan::displayable; use datafusion_physical_plan::repartition::RepartitionExec; @@ -260,7 +260,7 @@ fn aggregations_with_limit_combined() -> datafusion_common::Result<()> { schema, ) .unwrap() - .with_limit(Some(5)), + .with_limit_options(Some(LimitOptions::new(5))), ); let plan: Arc = final_agg; // should combine the Partial/Final AggregateExecs to a Single AggregateExec diff --git a/datafusion/core/tests/physical_optimizer/enforce_distribution.rs b/datafusion/core/tests/physical_optimizer/enforce_distribution.rs index 7cedaf86cb52..5df634c70bcb 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_distribution.rs @@ -56,7 +56,7 @@ use datafusion_physical_optimizer::output_requirements::OutputRequirements; use datafusion_physical_plan::aggregates::{ 
AggregateExec, AggregateMode, PhysicalGroupBy, }; -use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; + use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::execution_plan::ExecutionPlan; use datafusion_physical_plan::expressions::col; @@ -67,8 +67,7 @@ use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr}; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::union::UnionExec; use datafusion_physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlanProperties, PlanProperties, Statistics, - displayable, + DisplayAs, DisplayFormatType, ExecutionPlanProperties, PlanProperties, displayable, }; use insta::Settings; @@ -120,7 +119,7 @@ macro_rules! assert_plan { struct SortRequiredExec { input: Arc, expr: LexOrdering, - cache: PlanProperties, + cache: Arc, } impl SortRequiredExec { @@ -132,7 +131,7 @@ impl SortRequiredExec { Self { input, expr: requirement, - cache, + cache: Arc::new(cache), } } @@ -174,7 +173,7 @@ impl ExecutionPlan for SortRequiredExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -210,10 +209,6 @@ impl ExecutionPlan for SortRequiredExec { ) -> Result { unreachable!(); } - - fn statistics(&self) -> Result { - self.input.partition_statistics(None) - } } fn parquet_exec() -> Arc { @@ -1741,9 +1736,6 @@ fn merge_does_not_need_sort() -> Result<()> { // Scan some sorted parquet files let exec = parquet_exec_multiple_sorted(vec![sort_key.clone()]); - // CoalesceBatchesExec to mimic behavior after a filter - let exec = Arc::new(CoalesceBatchesExec::new(exec, 4096)); - // Merge from multiple parquet files and keep the data sorted let exec: Arc = Arc::new(SortPreservingMergeExec::new(sort_key, exec)); @@ -1757,8 +1749,7 @@ fn merge_does_not_need_sort() -> Result<()> { assert_plan!(plan_distrib, @r" SortPreservingMergeExec: [a@0 ASC] - 
CoalesceBatchesExec: target_batch_size=4096 - DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet "); // Test: result IS DIFFERENT, if EnforceSorting is run first: @@ -1772,8 +1763,7 @@ fn merge_does_not_need_sort() -> Result<()> { @r" SortExec: expr=[a@0 ASC], preserve_partitioning=[false] CoalescePartitionsExec - CoalesceBatchesExec: target_batch_size=4096 - DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet "); Ok(()) diff --git a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs index 47e3adb45511..6349ff1cd109 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs @@ -20,21 +20,20 @@ use std::sync::Arc; use crate::memory_limit::DummyStreamPartition; use crate::physical_optimizer::test_utils::{ RequirementsTestExec, aggregate_exec, bounded_window_exec, - bounded_window_exec_with_partition, check_integrity, coalesce_batches_exec, - coalesce_partitions_exec, create_test_schema, create_test_schema2, - create_test_schema3, filter_exec, global_limit_exec, hash_join_exec, - local_limit_exec, memory_exec, parquet_exec, parquet_exec_with_sort, projection_exec, - repartition_exec, sort_exec, sort_exec_with_fetch, sort_expr, sort_expr_options, - sort_merge_join_exec, sort_preserving_merge_exec, - sort_preserving_merge_exec_with_fetch, spr_repartition_exec, stream_exec_ordered, - union_exec, + bounded_window_exec_with_partition, check_integrity, coalesce_partitions_exec, + create_test_schema, create_test_schema2, 
create_test_schema3, filter_exec, + global_limit_exec, hash_join_exec, local_limit_exec, memory_exec, parquet_exec, + parquet_exec_with_sort, projection_exec, repartition_exec, sort_exec, + sort_exec_with_fetch, sort_expr, sort_expr_options, sort_merge_join_exec, + sort_preserving_merge_exec, sort_preserving_merge_exec_with_fetch, + spr_repartition_exec, stream_exec_ordered, union_exec, }; -use arrow::compute::SortOptions; +use arrow::compute::{SortOptions}; use arrow::datatypes::{DataType, SchemaRef}; use datafusion_common::config::{ConfigOptions, CsvOptions}; use datafusion_common::tree_node::{TreeNode, TransformedResult}; -use datafusion_common::{Result, TableReference}; +use datafusion_common::{create_array, Result, TableReference}; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_datasource::source::DataSourceExec; use datafusion_expr_common::operator::Operator; @@ -59,7 +58,7 @@ use datafusion_physical_optimizer::enforce_distribution::EnforceDistribution; use datafusion_physical_optimizer::output_requirements::OutputRequirementExec; use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion::prelude::*; -use arrow::array::{Int32Array, RecordBatch}; +use arrow::array::{record_batch, ArrayRef, Int32Array, RecordBatch}; use arrow::datatypes::{Field}; use arrow_schema::Schema; use datafusion_execution::TaskContext; @@ -1845,9 +1844,7 @@ async fn test_remove_unnecessary_sort_window_multilayer() -> Result<()> { )] .into(); let sort = sort_exec(ordering.clone(), source); - // Add dummy layer propagating Sort above, to test whether sort can be removed from multi layer before - let coalesce_batches = coalesce_batches_exec(sort, 128); - let window_agg = bounded_window_exec("non_nullable_col", ordering, coalesce_batches); + let window_agg = bounded_window_exec("non_nullable_col", ordering, sort); let ordering2: LexOrdering = [sort_expr_options( "non_nullable_col", &window_agg.schema(), @@ -1873,17 +1870,15 @@ async fn 
test_remove_unnecessary_sort_window_multilayer() -> Result<()> { FilterExec: NOT non_nullable_col@1 SortExec: expr=[non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false] BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] - CoalesceBatchesExec: target_batch_size=128 - SortExec: expr=[non_nullable_col@1 DESC], preserve_partitioning=[false] - DataSourceExec: partitions=1, partition_sizes=[0] + SortExec: expr=[non_nullable_col@1 DESC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] Optimized Plan: WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64 }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] FilterExec: NOT non_nullable_col@1 BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] - CoalesceBatchesExec: target_batch_size=128 - SortExec: expr=[non_nullable_col@1 DESC], preserve_partitioning=[false] - DataSourceExec: partitions=1, partition_sizes=[0] + SortExec: expr=[non_nullable_col@1 DESC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] "#); Ok(()) @@ -2810,3 +2805,47 @@ async fn test_partial_sort_with_homogeneous_batches() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn test_sort_with_streaming_table() -> Result<()> { + let batch = record_batch!(("a", Int32, [1, 2, 3]), ("b", Int32, [1, 2, 3]))?; + + let ctx = SessionContext::new(); + + let sort_order = vec![ + SortExpr::new( + Expr::Column(datafusion_common::Column::new( + Option::::None, + "a", + )), + true, + false, + ), + SortExpr::new( + Expr::Column(datafusion_common::Column::new( + Option::::None, + "b", + )), + true, + false, + ), + ]; + let schema = batch.schema(); + let batches = Arc::new(DummyStreamPartition { + schema: schema.clone(), + batches: vec![batch], + }) as _; + 
let provider = StreamingTable::try_new(schema.clone(), vec![batches])? + .with_sort_order(sort_order); + ctx.register_table("test_table", Arc::new(provider))?; + + let sql = "SELECT a FROM test_table GROUP BY a ORDER BY a"; + let results = ctx.sql(sql).await?.collect().await?; + + assert_eq!(results.len(), 1); + assert_eq!(results[0].num_columns(), 1); + let expected = create_array!(Int32, vec![1, 2, 3]) as ArrayRef; + assert_eq!(results[0].column(0), &expected); + + Ok(()) +} diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs b/datafusion/core/tests/physical_optimizer/filter_pushdown.rs similarity index 79% rename from datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs rename to datafusion/core/tests/physical_optimizer/filter_pushdown.rs index d6357fdf6bc7..99db81d34d8f 100644 --- a/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs +++ b/datafusion/core/tests/physical_optimizer/filter_pushdown.rs @@ -56,23 +56,21 @@ use datafusion_physical_optimizer::{ use datafusion_physical_plan::{ ExecutionPlan, aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}, - coalesce_batches::CoalesceBatchesExec, coalesce_partitions::CoalescePartitionsExec, collect, - filter::FilterExec, + filter::{FilterExec, FilterExecBuilder}, + projection::ProjectionExec, repartition::RepartitionExec, sorts::sort::SortExec, }; +use super::pushdown_utils::{ + OptimizationTest, TestNode, TestScanBuilder, TestSource, format_plan_for_test, +}; use datafusion_physical_plan::union::UnionExec; use futures::StreamExt; use object_store::{ObjectStore, memory::InMemory}; use regex::Regex; -use util::{OptimizationTest, TestNode, TestScanBuilder, format_plan_for_test}; - -use crate::physical_optimizer::filter_pushdown::util::TestSource; - -mod util; #[test] fn test_pushdown_into_scan() { @@ -234,6 +232,7 @@ async fn test_dynamic_filter_pushdown_through_hash_join_with_topk() { None, PartitionMode::Partitioned, 
datafusion_common::NullEquality::NullEqualsNothing, + false, ) .unwrap(), ); @@ -355,6 +354,7 @@ async fn test_static_filter_pushdown_through_hash_join() { None, PartitionMode::Partitioned, datafusion_common::NullEquality::NullEqualsNothing, + false, ) .unwrap(), ); @@ -401,7 +401,8 @@ async fn test_static_filter_pushdown_through_hash_join() { " ); - // Test left join - filters should NOT be pushed down + // Test left join: filter on preserved (build) side is pushed down, + // filter on non-preserved (probe) side is NOT pushed down. let join = Arc::new( HashJoinExec::try_new( TestScanBuilder::new(Arc::clone(&build_side_schema)) @@ -419,30 +420,36 @@ async fn test_static_filter_pushdown_through_hash_join() { None, PartitionMode::Partitioned, datafusion_common::NullEquality::NullEqualsNothing, + false, ) .unwrap(), ); let join_schema = join.schema(); - let filter = col_lit_predicate("a", "aa", &join_schema); - let plan = - Arc::new(FilterExec::try_new(filter, join).unwrap()) as Arc; + // Filter on build side column (preserved): should be pushed down + let left_filter = col_lit_predicate("a", "aa", &join_schema); + // Filter on probe side column (not preserved): should NOT be pushed down + let right_filter = col_lit_predicate("e", "ba", &join_schema); + let filter = + Arc::new(FilterExec::try_new(left_filter, Arc::clone(&join) as _).unwrap()); + let plan = Arc::new(FilterExec::try_new(right_filter, filter).unwrap()) + as Arc; - // Test that filters are NOT pushed down for left join insta::assert_snapshot!( OptimizationTest::new(plan, FilterPushdown::new(), true), @r" OptimizationTest: input: - - FilterExec: a@0 = aa - - HashJoinExec: mode=Partitioned, join_type=Left, on=[(a@0, d@0)] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true + - FilterExec: e@4 = ba + - FilterExec: a@0 
= aa + - HashJoinExec: mode=Partitioned, join_type=Left, on=[(a@0, d@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true output: Ok: - - FilterExec: a@0 = aa + - FilterExec: e@4 = ba - HashJoinExec: mode=Partitioned, join_type=Left, on=[(a@0, d@0)] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = aa - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true " ); @@ -478,9 +485,10 @@ fn test_filter_with_projection() { let projection = vec![1, 0]; let predicate = col_lit_predicate("a", "foo", &schema()); let plan = Arc::new( - FilterExec::try_new(predicate, Arc::clone(&scan)) + FilterExecBuilder::new(predicate, Arc::clone(&scan)) + .apply_projection(Some(projection)) .unwrap() - .with_projection(Some(projection)) + .build() .unwrap(), ); @@ -503,9 +511,10 @@ fn test_filter_with_projection() { let projection = vec![1]; let predicate = col_lit_predicate("a", "foo", &schema()); let plan = Arc::new( - FilterExec::try_new(predicate, scan) + FilterExecBuilder::new(predicate, scan) + .apply_projection(Some(projection)) .unwrap() - .with_projection(Some(projection)) + .build() .unwrap(), ); insta::assert_snapshot!( @@ -527,9 +536,8 @@ fn test_filter_with_projection() { fn test_push_down_through_transparent_nodes() { // expect the predicate to be pushed down into the DataSource let scan = TestScanBuilder::new(schema()).with_support(true).build(); - let coalesce = Arc::new(CoalesceBatchesExec::new(scan, 1)); let predicate = col_lit_predicate("a", "foo", &schema()); - let filter = 
Arc::new(FilterExec::try_new(predicate, coalesce).unwrap()); + let filter = Arc::new(FilterExec::try_new(predicate, scan).unwrap()); let repartition = Arc::new( RepartitionExec::try_new(filter, Partitioning::RoundRobinBatch(1)).unwrap(), ); @@ -545,13 +553,11 @@ fn test_push_down_through_transparent_nodes() { - FilterExec: b@1 = bar - RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=1 - FilterExec: a@0 = foo - - CoalesceBatchesExec: target_batch_size=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true output: Ok: - RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=1 - - CoalesceBatchesExec: target_batch_size=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo AND b@1 = bar + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo AND b@1 = bar " ); } @@ -564,10 +570,11 @@ fn test_pushdown_through_aggregates_on_grouping_columns() { // 2. 
An outer filter (b@1 = bar) above the aggregate - also gets pushed through because 'b' is a grouping column let scan = TestScanBuilder::new(schema()).with_support(true).build(); - let coalesce = Arc::new(CoalesceBatchesExec::new(scan, 10)); - let filter = Arc::new( - FilterExec::try_new(col_lit_predicate("a", "foo", &schema()), coalesce).unwrap(), + FilterExecBuilder::new(col_lit_predicate("a", "foo", &schema()), scan) + .with_batch_size(10) + .build() + .unwrap(), ); let aggregate_expr = vec![ @@ -594,10 +601,13 @@ fn test_pushdown_through_aggregates_on_grouping_columns() { .unwrap(), ); - let coalesce = Arc::new(CoalesceBatchesExec::new(aggregate, 100)); - let predicate = col_lit_predicate("b", "bar", &schema()); - let plan = Arc::new(FilterExec::try_new(predicate, coalesce).unwrap()); + let plan = Arc::new( + FilterExecBuilder::new(predicate, aggregate) + .with_batch_size(100) + .build() + .unwrap(), + ); // Both filters should be pushed down to the DataSource since both reference grouping columns insta::assert_snapshot!( @@ -606,17 +616,13 @@ fn test_pushdown_through_aggregates_on_grouping_columns() { OptimizationTest: input: - FilterExec: b@1 = bar - - CoalesceBatchesExec: target_batch_size=100 - - AggregateExec: mode=Final, gby=[a@0 as a, b@1 as b], aggr=[cnt], ordering_mode=PartiallySorted([0]) - - FilterExec: a@0 = foo - - CoalesceBatchesExec: target_batch_size=10 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - AggregateExec: mode=Final, gby=[a@0 as a, b@1 as b], aggr=[cnt], ordering_mode=PartiallySorted([0]) + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true output: Ok: - - CoalesceBatchesExec: target_batch_size=100 - - AggregateExec: mode=Final, gby=[a@0 as a, b@1 as b], aggr=[cnt], ordering_mode=Sorted - - CoalesceBatchesExec: target_batch_size=10 - - DataSourceExec: 
file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo AND b@1 = bar + - AggregateExec: mode=Final, gby=[a@0 as a, b@1 as b], aggr=[cnt], ordering_mode=Sorted + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo AND b@1 = bar " ); } @@ -921,61 +927,6 @@ async fn test_topk_filter_passes_through_coalesce_partitions() { ); } -#[tokio::test] -async fn test_topk_filter_passes_through_coalesce_batches() { - let batches = vec![ - record_batch!( - ("a", Utf8, ["aa", "ab"]), - ("b", Utf8, ["bd", "bc"]), - ("c", Float64, [1.0, 2.0]) - ) - .unwrap(), - record_batch!( - ("a", Utf8, ["ac", "ad"]), - ("b", Utf8, ["bb", "ba"]), - ("c", Float64, [2.0, 1.0]) - ) - .unwrap(), - ]; - - let scan = TestScanBuilder::new(schema()) - .with_support(true) - .with_batches(batches) - .build(); - - let coalesce_batches = - Arc::new(CoalesceBatchesExec::new(scan, 1024)) as Arc; - - // Add SortExec with TopK - let plan = Arc::new( - SortExec::new( - LexOrdering::new(vec![PhysicalSortExpr::new( - col("b", &schema()).unwrap(), - SortOptions::new(true, false), - )]) - .unwrap(), - coalesce_batches, - ) - .with_fetch(Some(1)), - ) as Arc; - - insta::assert_snapshot!( - OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new_post_optimization(), true), - @r" - OptimizationTest: - input: - - SortExec: TopK(fetch=1), expr=[b@1 DESC NULLS LAST], preserve_partitioning=[false] - - CoalesceBatchesExec: target_batch_size=1024 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - output: - Ok: - - SortExec: TopK(fetch=1), expr=[b@1 DESC NULLS LAST], preserve_partitioning=[false] - - CoalesceBatchesExec: target_batch_size=1024 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ 
empty ] - " - ); -} - #[tokio::test] async fn test_hashjoin_dynamic_filter_pushdown() { use datafusion_common::JoinType; @@ -1040,6 +991,7 @@ async fn test_hashjoin_dynamic_filter_pushdown() { None, PartitionMode::CollectLeft, datafusion_common::NullEquality::NullEqualsNothing, + false, ) .unwrap(), ) as Arc; @@ -1118,23 +1070,11 @@ async fn test_hashjoin_dynamic_filter_pushdown_partitioned() { // | plan_type | plan | // +---------------+------------------------------------------------------------+ // | physical_plan | ┌───────────────────────────┐ | - // | | │ CoalesceBatchesExec │ | - // | | │ -------------------- │ | - // | | │ target_batch_size: │ | - // | | │ 8192 │ | - // | | └─────────────┬─────────────┘ | - // | | ┌─────────────┴─────────────┐ | // | | │ HashJoinExec │ | // | | │ -------------------- ├──────────────┐ | // | | │ on: (k = k) │ │ | // | | └─────────────┬─────────────┘ │ | // | | ┌─────────────┴─────────────┐┌─────────────┴─────────────┐ | - // | | │ CoalesceBatchesExec ││ CoalesceBatchesExec │ | - // | | │ -------------------- ││ -------------------- │ | - // | | │ target_batch_size: ││ target_batch_size: │ | - // | | │ 8192 ││ 8192 │ | - // | | └─────────────┬─────────────┘└─────────────┬─────────────┘ | - // | | ┌─────────────┴─────────────┐┌─────────────┴─────────────┐ | // | | │ RepartitionExec ││ RepartitionExec │ | // | | │ -------------------- ││ -------------------- │ | // | | │ partition_count(in->out): ││ partition_count(in->out): │ | @@ -1194,7 +1134,7 @@ async fn test_hashjoin_dynamic_filter_pushdown_partitioned() { // Create RepartitionExec nodes for both sides with hash partitioning on join keys let partition_count = 12; - // Build side: DataSource -> RepartitionExec (Hash) -> CoalesceBatchesExec + // Build side: DataSource -> RepartitionExec (Hash) let build_hash_exprs = vec![ col("a", &build_side_schema).unwrap(), col("b", &build_side_schema).unwrap(), @@ -1206,9 +1146,8 @@ async fn 
test_hashjoin_dynamic_filter_pushdown_partitioned() { ) .unwrap(), ); - let build_coalesce = Arc::new(CoalesceBatchesExec::new(build_repartition, 8192)); - // Probe side: DataSource -> RepartitionExec (Hash) -> CoalesceBatchesExec + // Probe side: DataSource -> RepartitionExec (Hash) let probe_hash_exprs = vec![ col("a", &probe_side_schema).unwrap(), col("b", &probe_side_schema).unwrap(), @@ -1220,7 +1159,6 @@ async fn test_hashjoin_dynamic_filter_pushdown_partitioned() { ) .unwrap(), ); - let probe_coalesce = Arc::new(CoalesceBatchesExec::new(probe_repartition, 8192)); // Create HashJoinExec with partitioned inputs let on = vec![ @@ -1235,23 +1173,21 @@ async fn test_hashjoin_dynamic_filter_pushdown_partitioned() { ]; let hash_join = Arc::new( HashJoinExec::try_new( - build_coalesce, - probe_coalesce, + build_repartition, + probe_repartition, on, None, &JoinType::Inner, None, PartitionMode::Partitioned, datafusion_common::NullEquality::NullEqualsNothing, + false, ) .unwrap(), ); - // Top-level CoalesceBatchesExec - let cb = - Arc::new(CoalesceBatchesExec::new(hash_join, 8192)) as Arc; // Top-level CoalescePartitionsExec - let cp = Arc::new(CoalescePartitionsExec::new(cb)) as Arc; + let cp = Arc::new(CoalescePartitionsExec::new(hash_join)) as Arc; // Add a sort for deterministic output let plan = Arc::new(SortExec::new( LexOrdering::new(vec![PhysicalSortExpr::new( @@ -1270,26 +1206,20 @@ async fn test_hashjoin_dynamic_filter_pushdown_partitioned() { input: - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] - CoalescePartitionsExec - - CoalesceBatchesExec: target_batch_size=8192 - - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] - - CoalesceBatchesExec: target_batch_size=8192 - - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - - CoalesceBatchesExec: target_batch_size=8192 
- - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true output: Ok: - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] - CoalescePartitionsExec - - CoalesceBatchesExec: target_batch_size=8192 - - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] - - CoalesceBatchesExec: target_batch_size=8192 - - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - - CoalesceBatchesExec: target_batch_size=8192 - - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] " ); @@ -1319,14 +1249,11 @@ async fn 
test_hashjoin_dynamic_filter_pushdown_partitioned() { @r" - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] - CoalescePartitionsExec - - CoalesceBatchesExec: target_batch_size=8192 - - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] - - CoalesceBatchesExec: target_batch_size=8192 - - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - - CoalesceBatchesExec: target_batch_size=8192 - - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ CASE hash_repartition % 12 WHEN 2 THEN a@0 >= ab AND a@0 <= ab AND b@1 >= bb AND b@1 <= bb AND struct(a@0, b@1) IN (SET) ([{c0:ab,c1:bb}]) WHEN 4 THEN a@0 >= aa AND a@0 <= aa AND b@1 >= ba AND b@1 <= ba AND struct(a@0, b@1) IN (SET) ([{c0:aa,c1:ba}]) ELSE false END ] + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ CASE hash_repartition % 12 WHEN 2 THEN a@0 >= ab AND a@0 <= ab AND b@1 >= bb AND b@1 <= bb AND struct(a@0, b@1) IN (SET) ([{c0:ab,c1:bb}]) WHEN 4 THEN a@0 >= aa AND a@0 <= aa AND b@1 >= ba AND b@1 <= ba AND struct(a@0, b@1) IN (SET) ([{c0:aa,c1:ba}]) ELSE false END ] " ); @@ -1340,14 +1267,11 @@ async fn test_hashjoin_dynamic_filter_pushdown_partitioned() { @r" - SortExec: expr=[a@0 DESC NULLS LAST], 
preserve_partitioning=[false] - CoalescePartitionsExec - - CoalesceBatchesExec: target_batch_size=8192 - - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] - - CoalesceBatchesExec: target_batch_size=8192 - - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - - CoalesceBatchesExec: target_batch_size=8192 - - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= ba AND b@1 <= bb AND struct(a@0, b@1) IN (SET) ([{c0:aa,c1:ba}, {c0:ab,c1:bb}]) ] + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= ba AND b@1 <= bb AND struct(a@0, b@1) IN (SET) ([{c0:aa,c1:ba}, {c0:ab,c1:bb}]) ] " ); @@ -1418,7 +1342,7 @@ async fn test_hashjoin_dynamic_filter_pushdown_collect_left() { // Create RepartitionExec nodes for both sides with hash partitioning on join keys let partition_count = 12; - // Probe side: DataSource -> RepartitionExec(Hash) -> CoalesceBatchesExec + // Probe side: DataSource -> RepartitionExec(Hash) let probe_hash_exprs = vec![ col("a", &probe_side_schema).unwrap(), col("b", &probe_side_schema).unwrap(), @@ -1430,7 +1354,6 @@ async fn test_hashjoin_dynamic_filter_pushdown_collect_left() { ) .unwrap(), ); - let probe_coalesce 
= Arc::new(CoalesceBatchesExec::new(probe_repartition, 8192)); let on = vec![ ( @@ -1445,22 +1368,20 @@ async fn test_hashjoin_dynamic_filter_pushdown_collect_left() { let hash_join = Arc::new( HashJoinExec::try_new( build_scan, - probe_coalesce, + probe_repartition, on, None, &JoinType::Inner, None, PartitionMode::CollectLeft, datafusion_common::NullEquality::NullEqualsNothing, + false, ) .unwrap(), ); - // Top-level CoalesceBatchesExec - let cb = - Arc::new(CoalesceBatchesExec::new(hash_join, 8192)) as Arc; // Top-level CoalescePartitionsExec - let cp = Arc::new(CoalescePartitionsExec::new(cb)) as Arc; + let cp = Arc::new(CoalescePartitionsExec::new(hash_join)) as Arc; // Add a sort for deterministic output let plan = Arc::new(SortExec::new( LexOrdering::new(vec![PhysicalSortExpr::new( @@ -1479,22 +1400,18 @@ async fn test_hashjoin_dynamic_filter_pushdown_collect_left() { input: - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] - CoalescePartitionsExec - - CoalesceBatchesExec: target_batch_size=8192 - - HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - - CoalesceBatchesExec: target_batch_size=8192 - - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true + - HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true output: Ok: - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] - CoalescePartitionsExec - - 
CoalesceBatchesExec: target_batch_size=8192 - - HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - - CoalesceBatchesExec: target_batch_size=8192 - - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] + - HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] " ); @@ -1523,12 +1440,10 @@ async fn test_hashjoin_dynamic_filter_pushdown_collect_left() { @r" - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] - CoalescePartitionsExec - - CoalesceBatchesExec: target_batch_size=8192 - - HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - - CoalesceBatchesExec: target_batch_size=8192 - - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= ba AND b@1 <= bb AND struct(a@0, b@1) IN (SET) ([{c0:aa,c1:ba}, {c0:ab,c1:bb}]) ] + - HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, 
pushdown_supported=true + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= ba AND b@1 <= bb AND struct(a@0, b@1) IN (SET) ([{c0:aa,c1:ba}, {c0:ab,c1:bb}]) ] " ); @@ -1629,6 +1544,7 @@ async fn test_nested_hashjoin_dynamic_filter_pushdown() { None, PartitionMode::Partitioned, datafusion_common::NullEquality::NullEqualsNothing, + false, ) .unwrap(), ); @@ -1648,6 +1564,7 @@ async fn test_nested_hashjoin_dynamic_filter_pushdown() { None, PartitionMode::Partitioned, datafusion_common::NullEquality::NullEqualsNothing, + false, ) .unwrap(), ) as Arc; @@ -1763,6 +1680,7 @@ async fn test_hashjoin_parent_filter_pushdown() { None, PartitionMode::Partitioned, datafusion_common::NullEquality::NullEqualsNothing, + false, ) .unwrap(), ); @@ -1810,6 +1728,218 @@ async fn test_hashjoin_parent_filter_pushdown() { ); } +#[test] +fn test_hashjoin_parent_filter_pushdown_same_column_names() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + let build_side_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("build_val", DataType::Utf8, false), + ])); + let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) + .with_support(true) + .build(); + + let probe_side_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("probe_val", DataType::Utf8, false), + ])); + let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) + .with_support(true) + .build(); + + let on = vec![( + col("id", &build_side_schema).unwrap(), + col("id", &probe_side_schema).unwrap(), + )]; + let join = Arc::new( + HashJoinExec::try_new( + build_scan, + probe_scan, + on, + None, + &JoinType::Inner, + None, + PartitionMode::Partitioned, + 
datafusion_common::NullEquality::NullEqualsNothing, + false, + ) + .unwrap(), + ); + + let join_schema = join.schema(); + + let build_id_filter = col_lit_predicate("id", "aa", &join_schema); + let probe_val_filter = col_lit_predicate("probe_val", "x", &join_schema); + + let filter = + Arc::new(FilterExec::try_new(build_id_filter, Arc::clone(&join) as _).unwrap()); + let plan = Arc::new(FilterExec::try_new(probe_val_filter, filter).unwrap()) + as Arc; + + insta::assert_snapshot!( + OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: probe_val@3 = x + - FilterExec: id@0 = aa + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(id@0, id@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id, build_val], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id, probe_val], file_type=test, pushdown_supported=true + output: + Ok: + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(id@0, id@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id, build_val], file_type=test, pushdown_supported=true, predicate=id@0 = aa + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id, probe_val], file_type=test, pushdown_supported=true, predicate=probe_val@1 = x + " + ); +} + +#[test] +fn test_hashjoin_parent_filter_pushdown_mark_join() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + let left_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("val", DataType::Utf8, false), + ])); + let left_scan = TestScanBuilder::new(Arc::clone(&left_schema)) + .with_support(true) + .build(); + + let right_schema = + Arc::new(Schema::new(vec![Field::new("id", DataType::Utf8, false)])); + let right_scan = TestScanBuilder::new(Arc::clone(&right_schema)) + .with_support(true) + .build(); + + 
let on = vec![( + col("id", &left_schema).unwrap(), + col("id", &right_schema).unwrap(), + )]; + let join = Arc::new( + HashJoinExec::try_new( + left_scan, + right_scan, + on, + None, + &JoinType::LeftMark, + None, + PartitionMode::Partitioned, + datafusion_common::NullEquality::NullEqualsNothing, + false, + ) + .unwrap(), + ); + + let join_schema = join.schema(); + + let left_filter = col_lit_predicate("val", "x", &join_schema); + let mark_filter = col_lit_predicate("mark", true, &join_schema); + + let filter = + Arc::new(FilterExec::try_new(left_filter, Arc::clone(&join) as _).unwrap()); + let plan = Arc::new(FilterExec::try_new(mark_filter, filter).unwrap()) + as Arc; + + insta::assert_snapshot!( + OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: mark@2 = true + - FilterExec: val@1 = x + - HashJoinExec: mode=Partitioned, join_type=LeftMark, on=[(id@0, id@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id, val], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id], file_type=test, pushdown_supported=true + output: + Ok: + - FilterExec: mark@2 = true + - HashJoinExec: mode=Partitioned, join_type=LeftMark, on=[(id@0, id@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id, val], file_type=test, pushdown_supported=true, predicate=val@1 = x + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id], file_type=test, pushdown_supported=true + " + ); +} + +/// Test that filters on join key columns are pushed to both sides of semi/anti joins. +/// For LeftSemi/LeftAnti, the output only contains left columns, but filters on +/// join key columns can also be pushed to the right (non-preserved) side because +/// the equijoin condition guarantees the key values match. 
+#[test] +fn test_hashjoin_parent_filter_pushdown_semi_anti_join() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + let left_schema = Arc::new(Schema::new(vec![ + Field::new("k", DataType::Utf8, false), + Field::new("v", DataType::Utf8, false), + ])); + let left_scan = TestScanBuilder::new(Arc::clone(&left_schema)) + .with_support(true) + .build(); + + let right_schema = Arc::new(Schema::new(vec![ + Field::new("k", DataType::Utf8, false), + Field::new("w", DataType::Utf8, false), + ])); + let right_scan = TestScanBuilder::new(Arc::clone(&right_schema)) + .with_support(true) + .build(); + + let on = vec![( + col("k", &left_schema).unwrap(), + col("k", &right_schema).unwrap(), + )]; + + let join = Arc::new( + HashJoinExec::try_new( + left_scan, + right_scan, + on, + None, + &JoinType::LeftSemi, + None, + PartitionMode::Partitioned, + datafusion_common::NullEquality::NullEqualsNothing, + false, + ) + .unwrap(), + ); + + let join_schema = join.schema(); + // Filter on join key column: k = 'x' — should be pushed to BOTH sides + let key_filter = col_lit_predicate("k", "x", &join_schema); + // Filter on non-key column: v = 'y' — should only be pushed to the left side + let val_filter = col_lit_predicate("v", "y", &join_schema); + + let filter = + Arc::new(FilterExec::try_new(key_filter, Arc::clone(&join) as _).unwrap()); + let plan = Arc::new(FilterExec::try_new(val_filter, filter).unwrap()) + as Arc; + + insta::assert_snapshot!( + OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: v@1 = y + - FilterExec: k@0 = x + - HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(k@0, k@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[k, v], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[k, w], file_type=test, pushdown_supported=true + output: + Ok: + - 
HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(k@0, k@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[k, v], file_type=test, pushdown_supported=true, predicate=k@0 = x AND v@1 = y + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[k, w], file_type=test, pushdown_supported=true, predicate=k@0 = x + " + ); +} + /// Integration test for dynamic filter pushdown with TopK. /// We use an integration test because there are complex interactions in the optimizer rules /// that the unit tests applying a single optimizer rule do not cover. @@ -1850,15 +1980,17 @@ STORED AS PARQUET; .unwrap(); // Create a TopK query that will use dynamic filter pushdown + // Note that we use t * t as the order by expression to avoid + // the order pushdown optimizer from optimizing away the TopK. let df = ctx - .sql(r"EXPLAIN ANALYZE SELECT t FROM topk_pushdown ORDER BY t LIMIT 10;") + .sql(r"EXPLAIN ANALYZE SELECT t FROM topk_pushdown ORDER BY t * t LIMIT 10;") .await .unwrap(); let batches = df.collect().await.unwrap(); let explain = format!("{}", pretty_format_batches(&batches).unwrap()); assert!(explain.contains("output_rows=128")); // Read 1 row group - assert!(explain.contains("t@0 < 1372708809")); // Dynamic filter was applied + assert!(explain.contains("t@0 < 1884329474306198481")); // Dynamic filter was applied assert!( explain.contains("pushdown_rows_matched=128, pushdown_rows_pruned=99.87 K"), "{explain}" @@ -1894,6 +2026,67 @@ fn test_filter_pushdown_through_union() { ); } +#[test] +fn test_filter_pushdown_through_union_mixed_support() { + // Test case where one child supports filter pushdown and one doesn't + let scan1 = TestScanBuilder::new(schema()).with_support(true).build(); + let scan2 = TestScanBuilder::new(schema()).with_support(false).build(); + + let union = UnionExec::try_new(vec![scan1, scan2]).unwrap(); + + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = 
Arc::new(FilterExec::try_new(predicate, union).unwrap()); + + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = foo + - UnionExec + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + output: + Ok: + - UnionExec + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + " + ); +} + +#[test] +fn test_filter_pushdown_through_union_does_not_support() { + // Test case where one child supports filter pushdown and one doesn't + let scan1 = TestScanBuilder::new(schema()).with_support(false).build(); + let scan2 = TestScanBuilder::new(schema()).with_support(false).build(); + + let union = UnionExec::try_new(vec![scan1, scan2]).unwrap(); + + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new(FilterExec::try_new(predicate, union).unwrap()); + + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @" + OptimizationTest: + input: + - FilterExec: a@0 = foo + - UnionExec + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + output: + Ok: + - UnionExec + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], 
file_type=test, pushdown_supported=false + " + ); +} + /// Schema: /// a: String /// b: String @@ -1911,6 +2104,234 @@ fn schema() -> SchemaRef { Arc::clone(&TEST_SCHEMA) } +struct ProjectionDynFilterTestCase { + schema: SchemaRef, + batches: Vec, + projection: Vec<(Arc, String)>, + sort_expr: PhysicalSortExpr, + expected_plans: Vec, +} + +async fn run_projection_dyn_filter_case(case: ProjectionDynFilterTestCase) { + let ProjectionDynFilterTestCase { + schema, + batches, + projection, + sort_expr, + expected_plans, + } = case; + + let scan = TestScanBuilder::new(Arc::clone(&schema)) + .with_support(true) + .with_batches(batches) + .build(); + + let projection_exec = Arc::new(ProjectionExec::try_new(projection, scan).unwrap()); + + let sort = Arc::new( + SortExec::new(LexOrdering::new(vec![sort_expr]).unwrap(), projection_exec) + .with_fetch(Some(2)), + ) as Arc; + + let mut config = ConfigOptions::default(); + config.execution.parquet.pushdown_filters = true; + config.optimizer.enable_dynamic_filter_pushdown = true; + + let optimized_plan = FilterPushdown::new_post_optimization() + .optimize(Arc::clone(&sort), &config) + .unwrap(); + + pretty_assertions::assert_eq!( + format_plan_for_test(&optimized_plan).trim(), + expected_plans[0].trim() + ); + + let config = SessionConfig::new().with_batch_size(2); + let session_ctx = SessionContext::new_with_config(config); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + let mut stream = optimized_plan.execute(0, Arc::clone(&task_ctx)).unwrap(); + for (idx, expected_plan) in expected_plans.iter().enumerate().skip(1) { + stream.next().await.unwrap().unwrap(); + let formatted_plan = format_plan_for_test(&optimized_plan); + pretty_assertions::assert_eq!( + formatted_plan.trim(), + expected_plan.trim(), + "Mismatch at iteration {}", + idx + ); + } +} + +#[tokio::test] +async fn 
test_topk_with_projection_transformation_on_dyn_filter() { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Float64, false), + ])); + let simple_abc = vec![ + record_batch!( + ("a", Int32, [1, 2, 3]), + ("b", Utf8, ["x", "y", "z"]), + ("c", Float64, [1.0, 2.0, 3.0]) + ) + .unwrap(), + ]; + + // Case 1: Reordering [b, a] + run_projection_dyn_filter_case(ProjectionDynFilterTestCase { + schema: Arc::clone(&schema), + batches: simple_abc.clone(), + projection: vec![ + (col("b", &schema).unwrap(), "b".to_string()), + (col("a", &schema).unwrap(), "a".to_string()), + ], + sort_expr: PhysicalSortExpr::new( + Arc::new(Column::new("a", 1)), + SortOptions::default(), + ), + expected_plans: vec![ +r#" - SortExec: TopK(fetch=2), expr=[a@1 ASC], preserve_partitioning=[false] + - ProjectionExec: expr=[b@1 as b, a@0 as a] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ]"#.to_string(), +r#" - SortExec: TopK(fetch=2), expr=[a@1 ASC], preserve_partitioning=[false], filter=[a@1 IS NULL OR a@1 < 2] + - ProjectionExec: expr=[b@1 as b, a@0 as a] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 IS NULL OR a@0 < 2 ]"#.to_string()] + }) + .await; + + // Case 2: Pruning [a] + run_projection_dyn_filter_case(ProjectionDynFilterTestCase { + schema: Arc::clone(&schema), + batches: simple_abc.clone(), + projection: vec![(col("a", &schema).unwrap(), "a".to_string())], + sort_expr: PhysicalSortExpr::new( + Arc::new(Column::new("a", 0)), + SortOptions::default(), + ), + expected_plans: vec![ + r#" - SortExec: TopK(fetch=2), expr=[a@0 ASC], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 as a] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], 
file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ]"#.to_string(), + r#" - SortExec: TopK(fetch=2), expr=[a@0 ASC], preserve_partitioning=[false], filter=[a@0 IS NULL OR a@0 < 2] + - ProjectionExec: expr=[a@0 as a] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 IS NULL OR a@0 < 2 ]"#.to_string(), + ], + }) + .await; + + // Case 3: Identity [a, b] + run_projection_dyn_filter_case(ProjectionDynFilterTestCase { + schema: Arc::clone(&schema), + batches: simple_abc.clone(), + projection: vec![ + (col("a", &schema).unwrap(), "a".to_string()), + (col("b", &schema).unwrap(), "b".to_string()), + ], + sort_expr: PhysicalSortExpr::new( + Arc::new(Column::new("a", 0)), + SortOptions::default(), + ), + expected_plans: vec![ + r#" - SortExec: TopK(fetch=2), expr=[a@0 ASC], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 as a, b@1 as b] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ]"#.to_string(), + r#" - SortExec: TopK(fetch=2), expr=[a@0 ASC], preserve_partitioning=[false], filter=[a@0 IS NULL OR a@0 < 2] + - ProjectionExec: expr=[a@0 as a, b@1 as b] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 IS NULL OR a@0 < 2 ]"#.to_string(), + ], + }) + .await; + + // Case 4: Expressions [a + 1, b] + run_projection_dyn_filter_case(ProjectionDynFilterTestCase { + schema: Arc::clone(&schema), + batches: simple_abc.clone(), + projection: vec![ + ( + Arc::new(BinaryExpr::new( + col("a", &schema).unwrap(), + Operator::Plus, + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + )), + "a_plus_1".to_string(), + ), + (col("b", &schema).unwrap(), "b".to_string()), + ], + sort_expr: PhysicalSortExpr::new( + Arc::new(Column::new("a_plus_1", 
0)), + SortOptions::default(), + ), + expected_plans: vec![ + r#" - SortExec: TopK(fetch=2), expr=[a_plus_1@0 ASC], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 + 1 as a_plus_1, b@1 as b] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ]"#.to_string(), + r#" - SortExec: TopK(fetch=2), expr=[a_plus_1@0 ASC], preserve_partitioning=[false], filter=[a_plus_1@0 IS NULL OR a_plus_1@0 < 3] + - ProjectionExec: expr=[a@0 + 1 as a_plus_1, b@1 as b] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 + 1 IS NULL OR a@0 + 1 < 3 ]"#.to_string(), + ], + }) + .await; + + // Case 5: [a as b, b as a] (swapped columns) + run_projection_dyn_filter_case(ProjectionDynFilterTestCase { + schema: Arc::clone(&schema), + batches: simple_abc.clone(), + projection: vec![ + (col("a", &schema).unwrap(), "b".to_string()), + (col("b", &schema).unwrap(), "a".to_string()), + ], + sort_expr: PhysicalSortExpr::new( + Arc::new(Column::new("b", 0)), + SortOptions::default(), + ), + expected_plans: vec![ + r#" - SortExec: TopK(fetch=2), expr=[b@0 ASC], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 as b, b@1 as a] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ]"#.to_string(), + r#" - SortExec: TopK(fetch=2), expr=[b@0 ASC], preserve_partitioning=[false], filter=[b@0 IS NULL OR b@0 < 2] + - ProjectionExec: expr=[a@0 as b, b@1 as a] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 IS NULL OR a@0 < 2 ]"#.to_string(), + ], + }) + .await; + + // Case 6: Confusing expr [a + 1 as a, b] + run_projection_dyn_filter_case(ProjectionDynFilterTestCase { + schema: 
Arc::clone(&schema), + batches: simple_abc.clone(), + projection: vec![ + ( + Arc::new(BinaryExpr::new( + col("a", &schema).unwrap(), + Operator::Plus, + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + )), + "a".to_string(), + ), + (col("b", &schema).unwrap(), "b".to_string()), + ], + sort_expr: PhysicalSortExpr::new( + Arc::new(Column::new("a", 0)), + SortOptions::default(), + ), + expected_plans: vec![ + r#" - SortExec: TopK(fetch=2), expr=[a@0 ASC], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 + 1 as a, b@1 as b] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ]"#.to_string(), + r#" - SortExec: TopK(fetch=2), expr=[a@0 ASC], preserve_partitioning=[false], filter=[a@0 IS NULL OR a@0 < 3] + - ProjectionExec: expr=[a@0 + 1 as a, b@1 as b] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 + 1 IS NULL OR a@0 + 1 < 3 ]"#.to_string(), + ], + }) + .await; +} + /// Returns a predicate that is a binary expression col = lit fn col_lit_predicate( column_name: &str, @@ -2835,7 +3256,6 @@ async fn test_hashjoin_dynamic_filter_all_partitions_empty() { ) .unwrap(), ); - let build_coalesce = Arc::new(CoalesceBatchesExec::new(build_repartition, 8192)); let probe_hash_exprs = vec![ col("a", &probe_side_schema).unwrap(), @@ -2848,7 +3268,6 @@ async fn test_hashjoin_dynamic_filter_all_partitions_empty() { ) .unwrap(), ); - let probe_coalesce = Arc::new(CoalesceBatchesExec::new(probe_repartition, 8192)); // Create HashJoinExec let on = vec![ @@ -2861,23 +3280,21 @@ async fn test_hashjoin_dynamic_filter_all_partitions_empty() { col("b", &probe_side_schema).unwrap(), ), ]; - let hash_join = Arc::new( + let plan = Arc::new( HashJoinExec::try_new( - build_coalesce, - probe_coalesce, + build_repartition, + probe_repartition, on, None, &JoinType::Inner, None, 
PartitionMode::Partitioned, datafusion_common::NullEquality::NullEqualsNothing, + false, ) .unwrap(), ); - let plan = - Arc::new(CoalesceBatchesExec::new(hash_join, 8192)) as Arc; - // Apply the filter pushdown optimizer let mut config = SessionConfig::new(); config.options_mut().execution.parquet.pushdown_filters = true; @@ -2887,14 +3304,11 @@ async fn test_hashjoin_dynamic_filter_all_partitions_empty() { insta::assert_snapshot!( format_plan_for_test(&plan), @r" - - CoalesceBatchesExec: target_batch_size=8192 - - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] - - CoalesceBatchesExec: target_batch_size=8192 - - RepartitionExec: partitioning=Hash([a@0, b@1], 4), input_partitions=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b], file_type=test, pushdown_supported=true - - CoalesceBatchesExec: target_batch_size=8192 - - RepartitionExec: partitioning=Hash([a@0, b@1], 4), input_partitions=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - RepartitionExec: partitioning=Hash([a@0, b@1], 4), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b], file_type=test, pushdown_supported=true + - RepartitionExec: partitioning=Hash([a@0, b@1], 4), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] " ); @@ -2915,14 +3329,11 @@ async fn test_hashjoin_dynamic_filter_all_partitions_empty() { insta::assert_snapshot!( format_plan_for_test(&plan), @r" - - CoalesceBatchesExec: target_batch_size=8192 - - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] - - CoalesceBatchesExec: target_batch_size=8192 - - RepartitionExec: partitioning=Hash([a@0, b@1], 
4), input_partitions=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b], file_type=test, pushdown_supported=true - - CoalesceBatchesExec: target_batch_size=8192 - - RepartitionExec: partitioning=Hash([a@0, b@1], 4), input_partitions=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ false ] + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - RepartitionExec: partitioning=Hash([a@0, b@1], 4), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b], file_type=test, pushdown_supported=true + - RepartitionExec: partitioning=Hash([a@0, b@1], 4), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ false ] " ); } @@ -2998,7 +3409,7 @@ async fn test_hashjoin_dynamic_filter_with_nulls() { col("b", &probe_side_schema).unwrap(), ), ]; - let hash_join = Arc::new( + let plan = Arc::new( HashJoinExec::try_new( build_scan, Arc::clone(&probe_scan), @@ -3008,13 +3419,11 @@ async fn test_hashjoin_dynamic_filter_with_nulls() { None, PartitionMode::CollectLeft, datafusion_common::NullEquality::NullEqualsNothing, + false, ) .unwrap(), ); - let plan = - Arc::new(CoalesceBatchesExec::new(hash_join, 8192)) as Arc; - // Apply the filter pushdown optimizer let mut config = SessionConfig::new(); config.options_mut().execution.parquet.pushdown_filters = true; @@ -3024,10 +3433,9 @@ async fn test_hashjoin_dynamic_filter_with_nulls() { insta::assert_snapshot!( format_plan_for_test(&plan), @r" - - CoalesceBatchesExec: target_batch_size=8192 - - HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b], file_type=test, pushdown_supported=true - - DataSourceExec: file_groups={1 group: 
[[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] + - HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] " ); @@ -3048,10 +3456,9 @@ async fn test_hashjoin_dynamic_filter_with_nulls() { insta::assert_snapshot!( format_plan_for_test(&plan), @r" - - CoalesceBatchesExec: target_batch_size=8192 - - HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b], file_type=test, pushdown_supported=true - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= 1 AND b@1 <= 2 AND struct(a@0, b@1) IN (SET) ([{c0:aa,c1:1}, {c0:,c1:2}, {c0:ab,c1:}]) ] + - HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= 1 AND b@1 <= 2 AND struct(a@0, b@1) IN (SET) ([{c0:aa,c1:1}, {c0:,c1:2}, {c0:ab,c1:}]) ] " ); @@ -3116,7 +3523,7 @@ async fn test_hashjoin_hash_table_pushdown_partitioned() { // Create RepartitionExec nodes for both sides with hash partitioning on join keys let partition_count = 12; - // Build side: DataSource -> RepartitionExec (Hash) -> CoalesceBatchesExec + // Build side: DataSource -> RepartitionExec (Hash) let build_hash_exprs = vec![ col("a", 
&build_side_schema).unwrap(), col("b", &build_side_schema).unwrap(), @@ -3128,9 +3535,8 @@ async fn test_hashjoin_hash_table_pushdown_partitioned() { ) .unwrap(), ); - let build_coalesce = Arc::new(CoalesceBatchesExec::new(build_repartition, 8192)); - // Probe side: DataSource -> RepartitionExec (Hash) -> CoalesceBatchesExec + // Probe side: DataSource -> RepartitionExec (Hash) let probe_hash_exprs = vec![ col("a", &probe_side_schema).unwrap(), col("b", &probe_side_schema).unwrap(), @@ -3142,7 +3548,6 @@ async fn test_hashjoin_hash_table_pushdown_partitioned() { ) .unwrap(), ); - let probe_coalesce = Arc::new(CoalesceBatchesExec::new(probe_repartition, 8192)); // Create HashJoinExec with partitioned inputs let on = vec![ @@ -3157,23 +3562,21 @@ async fn test_hashjoin_hash_table_pushdown_partitioned() { ]; let hash_join = Arc::new( HashJoinExec::try_new( - build_coalesce, - probe_coalesce, + build_repartition, + probe_repartition, on, None, &JoinType::Inner, None, PartitionMode::Partitioned, datafusion_common::NullEquality::NullEqualsNothing, + false, ) .unwrap(), ); - // Top-level CoalesceBatchesExec - let cb = - Arc::new(CoalesceBatchesExec::new(hash_join, 8192)) as Arc; // Top-level CoalescePartitionsExec - let cp = Arc::new(CoalescePartitionsExec::new(cb)) as Arc; + let cp = Arc::new(CoalescePartitionsExec::new(hash_join)) as Arc; // Add a sort for deterministic output let plan = Arc::new(SortExec::new( LexOrdering::new(vec![PhysicalSortExpr::new( @@ -3285,7 +3688,7 @@ async fn test_hashjoin_hash_table_pushdown_collect_left() { // Create RepartitionExec nodes for both sides with hash partitioning on join keys let partition_count = 12; - // Probe side: DataSource -> RepartitionExec(Hash) -> CoalesceBatchesExec + // Probe side: DataSource -> RepartitionExec(Hash) let probe_hash_exprs = vec![ col("a", &probe_side_schema).unwrap(), col("b", &probe_side_schema).unwrap(), @@ -3297,7 +3700,6 @@ async fn test_hashjoin_hash_table_pushdown_collect_left() { ) .unwrap(), ); 
- let probe_coalesce = Arc::new(CoalesceBatchesExec::new(probe_repartition, 8192)); let on = vec![ ( @@ -3312,22 +3714,20 @@ async fn test_hashjoin_hash_table_pushdown_collect_left() { let hash_join = Arc::new( HashJoinExec::try_new( build_scan, - probe_coalesce, + probe_repartition, on, None, &JoinType::Inner, None, PartitionMode::CollectLeft, datafusion_common::NullEquality::NullEqualsNothing, + false, ) .unwrap(), ); - // Top-level CoalesceBatchesExec - let cb = - Arc::new(CoalesceBatchesExec::new(hash_join, 8192)) as Arc; // Top-level CoalescePartitionsExec - let cp = Arc::new(CoalescePartitionsExec::new(cb)) as Arc; + let cp = Arc::new(CoalescePartitionsExec::new(hash_join)) as Arc; // Add a sort for deterministic output let plan = Arc::new(SortExec::new( LexOrdering::new(vec![PhysicalSortExpr::new( @@ -3446,7 +3846,7 @@ async fn test_hashjoin_hash_table_pushdown_integer_keys() { col("id2", &probe_side_schema).unwrap(), ), ]; - let hash_join = Arc::new( + let plan = Arc::new( HashJoinExec::try_new( build_scan, Arc::clone(&probe_scan), @@ -3456,13 +3856,11 @@ async fn test_hashjoin_hash_table_pushdown_integer_keys() { None, PartitionMode::CollectLeft, datafusion_common::NullEquality::NullEqualsNothing, + false, ) .unwrap(), ); - let plan = - Arc::new(CoalesceBatchesExec::new(hash_join, 8192)) as Arc; - // Apply optimization with forced HashTable strategy let session_config = SessionConfig::default() .with_batch_size(10) @@ -3567,6 +3965,7 @@ async fn test_hashjoin_dynamic_filter_pushdown_is_used() { None, PartitionMode::CollectLeft, datafusion_common::NullEquality::NullEqualsNothing, + false, ) .unwrap(), ) as Arc; @@ -3600,3 +3999,90 @@ async fn test_hashjoin_dynamic_filter_pushdown_is_used() { ); } } + +/// Regression test for https://github.com/apache/datafusion/issues/20109 +#[tokio::test] +async fn test_filter_with_projection_pushdown() { + use arrow::array::{Int64Array, RecordBatch, StringArray}; + use datafusion_physical_plan::collect; + use 
datafusion_physical_plan::filter::FilterExecBuilder; + + // Create schema: [time, event, size] + let schema = Arc::new(Schema::new(vec![ + Field::new("time", DataType::Int64, false), + Field::new("event", DataType::Utf8, false), + Field::new("size", DataType::Int64, false), + ])); + + // Create sample data + let timestamps = vec![100i64, 200, 300, 400, 500]; + let events = vec!["Ingestion", "Ingestion", "Query", "Ingestion", "Query"]; + let sizes = vec![10i64, 20, 30, 40, 50]; + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int64Array::from(timestamps)), + Arc::new(StringArray::from(events)), + Arc::new(Int64Array::from(sizes)), + ], + ) + .unwrap(); + + // Create data source + let memory_exec = datafusion_datasource::memory::MemorySourceConfig::try_new_exec( + &[vec![batch]], + schema.clone(), + None, + ) + .unwrap(); + + // First FilterExec: time < 350 with projection=[event@1, size@2] + let time_col = col("time", &memory_exec.schema()).unwrap(); + let time_filter = Arc::new(BinaryExpr::new( + time_col, + Operator::Lt, + Arc::new(Literal::new(ScalarValue::Int64(Some(350)))), + )); + let filter1 = Arc::new( + FilterExecBuilder::new(time_filter, memory_exec) + .apply_projection(Some(vec![1, 2])) + .unwrap() + .build() + .unwrap(), + ); + + // Second FilterExec: event = 'Ingestion' with projection=[size@1] + let event_col = col("event", &filter1.schema()).unwrap(); + let event_filter = Arc::new(BinaryExpr::new( + event_col, + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Utf8(Some( + "Ingestion".to_string(), + )))), + )); + let filter2 = Arc::new( + FilterExecBuilder::new(event_filter, filter1) + .apply_projection(Some(vec![1])) + .unwrap() + .build() + .unwrap(), + ); + + // Apply filter pushdown optimization + let config = ConfigOptions::default(); + let optimized_plan = FilterPushdown::new() + .optimize(Arc::clone(&filter2) as Arc, &config) + .unwrap(); + + // Execute the optimized plan - this should not error + let ctx = 
SessionContext::new(); + let result = collect(optimized_plan, ctx.task_ctx()).await.unwrap(); + + // Verify results: should return rows where time < 350 AND event = 'Ingestion' + // That's rows with time=100,200 (both have event='Ingestion'), so sizes 10,20 + let expected = [ + "+------+", "| size |", "+------+", "| 10 |", "| 20 |", "+------+", + ]; + assert_batches_eq!(expected, &result); +} diff --git a/datafusion/core/tests/physical_optimizer/join_selection.rs b/datafusion/core/tests/physical_optimizer/join_selection.rs index 37bcefd418bd..ef0bbfc7f422 100644 --- a/datafusion/core/tests/physical_optimizer/join_selection.rs +++ b/datafusion/core/tests/physical_optimizer/join_selection.rs @@ -222,6 +222,7 @@ async fn test_join_with_swap() { None, PartitionMode::CollectLeft, NullEquality::NullEqualsNothing, + false, ) .unwrap(), ); @@ -284,6 +285,7 @@ async fn test_left_join_no_swap() { None, PartitionMode::CollectLeft, NullEquality::NullEqualsNothing, + false, ) .unwrap(), ); @@ -333,6 +335,7 @@ async fn test_join_with_swap_semi() { None, PartitionMode::Partitioned, NullEquality::NullEqualsNothing, + false, ) .unwrap(); @@ -388,6 +391,7 @@ async fn test_join_with_swap_mark() { None, PartitionMode::Partitioned, NullEquality::NullEqualsNothing, + false, ) .unwrap(); @@ -461,6 +465,7 @@ async fn test_nested_join_swap() { None, PartitionMode::CollectLeft, NullEquality::NullEqualsNothing, + false, ) .unwrap(); let child_schema = child_join.schema(); @@ -478,6 +483,7 @@ async fn test_nested_join_swap() { None, PartitionMode::CollectLeft, NullEquality::NullEqualsNothing, + false, ) .unwrap(); @@ -518,6 +524,7 @@ async fn test_join_no_swap() { None, PartitionMode::CollectLeft, NullEquality::NullEqualsNothing, + false, ) .unwrap(), ); @@ -745,6 +752,7 @@ async fn test_hash_join_swap_on_joins_with_projections( Some(projection), PartitionMode::Partitioned, NullEquality::NullEqualsNothing, + false, )?); let swapped = join @@ -754,7 +762,7 @@ async fn 
test_hash_join_swap_on_joins_with_projections( "ProjectionExec won't be added above if HashJoinExec contains embedded projection", ); - assert_eq!(swapped_join.projection, Some(vec![0_usize])); + assert_eq!(swapped_join.projection.as_deref().unwrap(), &[0_usize]); assert_eq!(swapped.schema().fields.len(), 1); assert_eq!(swapped.schema().fields[0].name(), "small_col"); Ok(()) @@ -906,6 +914,7 @@ fn check_join_partition_mode( None, PartitionMode::Auto, NullEquality::NullEqualsNothing, + false, ) .unwrap(), ); @@ -970,7 +979,7 @@ impl RecordBatchStream for UnboundedStream { pub struct UnboundedExec { batch_produce: Option, batch: RecordBatch, - cache: PlanProperties, + cache: Arc, } impl UnboundedExec { @@ -986,7 +995,7 @@ impl UnboundedExec { Self { batch_produce, batch, - cache, + cache: Arc::new(cache), } } @@ -1043,7 +1052,7 @@ impl ExecutionPlan for UnboundedExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -1082,7 +1091,7 @@ pub enum SourceType { pub struct StatisticsExec { stats: Statistics, schema: Arc, - cache: PlanProperties, + cache: Arc, } impl StatisticsExec { @@ -1096,7 +1105,7 @@ impl StatisticsExec { Self { stats, schema: Arc::new(schema), - cache, + cache: Arc::new(cache), } } @@ -1144,7 +1153,7 @@ impl ExecutionPlan for StatisticsExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -1167,10 +1176,6 @@ impl ExecutionPlan for StatisticsExec { unimplemented!("This plan only serves for testing statistics") } - fn statistics(&self) -> Result { - Ok(self.stats.clone()) - } - fn partition_statistics(&self, partition: Option) -> Result { Ok(if partition.is_some() { Statistics::new_unknown(&self.schema) @@ -1554,6 +1559,7 @@ async fn test_join_with_maybe_swap_unbounded_case(t: TestCase) -> Result<()> { None, t.initial_mode, NullEquality::NullEqualsNothing, + false, )?) 
as _; let optimized_join_plan = diff --git a/datafusion/core/tests/physical_optimizer/limit_pushdown.rs b/datafusion/core/tests/physical_optimizer/limit_pushdown.rs index b32a9bbd2543..b8c4d6d6f0d7 100644 --- a/datafusion/core/tests/physical_optimizer/limit_pushdown.rs +++ b/datafusion/core/tests/physical_optimizer/limit_pushdown.rs @@ -18,7 +18,7 @@ use std::sync::Arc; use crate::physical_optimizer::test_utils::{ - coalesce_batches_exec, coalesce_partitions_exec, global_limit_exec, local_limit_exec, + coalesce_partitions_exec, global_limit_exec, hash_join_exec, local_limit_exec, sort_exec, sort_preserving_merge_exec, stream_exec, }; @@ -26,14 +26,16 @@ use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::config::ConfigOptions; use datafusion_common::error::Result; -use datafusion_expr::Operator; +use datafusion_expr::{JoinType, Operator}; use datafusion_physical_expr::Partitioning; use datafusion_physical_expr::expressions::{BinaryExpr, col, lit}; +use datafusion_physical_expr_common::physical_expr::PhysicalExprRef; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_optimizer::limit_pushdown::LimitPushdown; use datafusion_physical_plan::empty::EmptyExec; use datafusion_physical_plan::filter::FilterExec; +use datafusion_physical_plan::joins::NestedLoopJoinExec; use datafusion_physical_plan::projection::ProjectionExec; use datafusion_physical_plan::repartition::RepartitionExec; use datafusion_physical_plan::{ExecutionPlan, get_plan_string}; @@ -87,6 +89,20 @@ fn empty_exec(schema: SchemaRef) -> Arc { Arc::new(EmptyExec::new(schema)) } +fn nested_loop_join_exec( + left: Arc, + right: Arc, + join_type: JoinType, +) -> Result> { + Ok(Arc::new(NestedLoopJoinExec::try_new( + left, right, None, &join_type, None, + )?)) +} + +fn format_plan(plan: &Arc) -> String { + get_plan_string(plan).join("\n") +} 
+ #[test] fn transforms_streaming_table_exec_into_fetching_version_when_skip_is_zero() -> Result<()> { @@ -94,20 +110,23 @@ fn transforms_streaming_table_exec_into_fetching_version_when_skip_is_zero() -> let streaming_table = stream_exec(&schema); let global_limit = global_limit_exec(streaming_table, 0, Some(5)); - let initial = get_plan_string(&global_limit); - let expected_initial = [ - "GlobalLimitExec: skip=0, fetch=5", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true", - ]; - assert_eq!(initial, expected_initial); + let initial = format_plan(&global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=0, fetch=5 + StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true + " + ); let after_optimize = LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; - let expected = [ - "StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true, fetch=5", - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let optimized = format_plan(&after_optimize); + insta::assert_snapshot!( + optimized, + @"StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true, fetch=5" + ); Ok(()) } @@ -119,122 +138,223 @@ fn transforms_streaming_table_exec_into_fetching_version_and_keeps_the_global_li let streaming_table = stream_exec(&schema); let global_limit = global_limit_exec(streaming_table, 2, Some(5)); - let initial = get_plan_string(&global_limit); - let expected_initial = [ - "GlobalLimitExec: skip=2, fetch=5", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true", - ]; - assert_eq!(initial, expected_initial); + let initial = format_plan(&global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=2, fetch=5 + StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true + " + ); let after_optimize = 
LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; - let expected = [ - "GlobalLimitExec: skip=2, fetch=5", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true, fetch=7", - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let optimized = format_plan(&after_optimize); + insta::assert_snapshot!( + optimized, + @r" + GlobalLimitExec: skip=2, fetch=5 + StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true, fetch=7 + " + ); Ok(()) } +fn join_on_columns( + left_col: &str, + right_col: &str, +) -> Vec<(PhysicalExprRef, PhysicalExprRef)> { + vec![( + Arc::new(datafusion_physical_expr::expressions::Column::new( + left_col, 0, + )) as _, + Arc::new(datafusion_physical_expr::expressions::Column::new( + right_col, 0, + )) as _, + )] +} + #[test] -fn transforms_coalesce_batches_exec_into_fetching_version_and_removes_local_limit() --> Result<()> { +fn absorbs_limit_into_hash_join_inner() -> Result<()> { + // HashJoinExec with Inner join should absorb limit via with_fetch let schema = create_schema(); - let streaming_table = stream_exec(&schema); - let repartition = repartition_exec(streaming_table)?; - let filter = filter_exec(schema, repartition)?; - let coalesce_batches = coalesce_batches_exec(filter, 8192); - let local_limit = local_limit_exec(coalesce_batches, 5); - let coalesce_partitions = coalesce_partitions_exec(local_limit); - let global_limit = global_limit_exec(coalesce_partitions, 0, Some(5)); - - let initial = get_plan_string(&global_limit); - let expected_initial = [ - "GlobalLimitExec: skip=0, fetch=5", - " CoalescePartitionsExec", - " LocalLimitExec: fetch=5", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c3@2 > 0", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true", - ]; - assert_eq!(initial, expected_initial); + let left = 
empty_exec(Arc::clone(&schema)); + let right = empty_exec(Arc::clone(&schema)); + let on = join_on_columns("c1", "c1"); + let hash_join = hash_join_exec(left, right, on, None, &JoinType::Inner)?; + let global_limit = global_limit_exec(hash_join, 0, Some(5)); + + let initial = format_plan(&global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=0, fetch=5 + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c1@0, c1@0)] + EmptyExec + EmptyExec + " + ); let after_optimize = LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; - - let expected = [ - "CoalescePartitionsExec: fetch=5", - " CoalesceBatchesExec: target_batch_size=8192, fetch=5", - " FilterExec: c3@2 > 0", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true", - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let optimized = format_plan(&after_optimize); + // The limit should be absorbed by the hash join (not pushed to children) + insta::assert_snapshot!( + optimized, + @r" + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c1@0, c1@0)], fetch=5 + EmptyExec + EmptyExec + " + ); Ok(()) } #[test] -fn pushes_global_limit_exec_through_projection_exec() -> Result<()> { +fn absorbs_limit_into_hash_join_right() -> Result<()> { + // HashJoinExec with Right join should absorb limit via with_fetch let schema = create_schema(); - let streaming_table = stream_exec(&schema); - let filter = filter_exec(Arc::clone(&schema), streaming_table)?; - let projection = projection_exec(schema, filter)?; - let global_limit = global_limit_exec(projection, 0, Some(5)); + let left = empty_exec(Arc::clone(&schema)); + let right = empty_exec(Arc::clone(&schema)); + let on = join_on_columns("c1", "c1"); + let hash_join = hash_join_exec(left, right, on, None, &JoinType::Right)?; + let global_limit = global_limit_exec(hash_join, 0, Some(10)); + + let initial = 
format_plan(&global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=0, fetch=10 + HashJoinExec: mode=Partitioned, join_type=Right, on=[(c1@0, c1@0)] + EmptyExec + EmptyExec + " + ); - let initial = get_plan_string(&global_limit); - let expected_initial = [ - "GlobalLimitExec: skip=0, fetch=5", - " ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", - " FilterExec: c3@2 > 0", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true", - ]; - assert_eq!(initial, expected_initial); + let after_optimize = + LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; + let optimized = format_plan(&after_optimize); + // The limit should be absorbed by the hash join + insta::assert_snapshot!( + optimized, + @r" + HashJoinExec: mode=Partitioned, join_type=Right, on=[(c1@0, c1@0)], fetch=10 + EmptyExec + EmptyExec + " + ); + + Ok(()) +} + +#[test] +fn absorbs_limit_into_hash_join_left() -> Result<()> { + // during probing, then unmatched rows at the end, stopping when limit is reached + let schema = create_schema(); + let left = empty_exec(Arc::clone(&schema)); + let right = empty_exec(Arc::clone(&schema)); + let on = join_on_columns("c1", "c1"); + let hash_join = hash_join_exec(left, right, on, None, &JoinType::Left)?; + let global_limit = global_limit_exec(hash_join, 0, Some(5)); + + let initial = format_plan(&global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=0, fetch=5 + HashJoinExec: mode=Partitioned, join_type=Left, on=[(c1@0, c1@0)] + EmptyExec + EmptyExec + " + ); let after_optimize = LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; + let optimized = format_plan(&after_optimize); + // Left join now absorbs the limit + insta::assert_snapshot!( + optimized, + @r" + HashJoinExec: mode=Partitioned, join_type=Left, on=[(c1@0, c1@0)], fetch=5 + EmptyExec + EmptyExec + " + ); + + Ok(()) +} - let expected = [ - "ProjectionExec: expr=[c1@0 as c1, 
c2@1 as c2, c3@2 as c3]", - " FilterExec: c3@2 > 0, fetch=5", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true", - ]; - assert_eq!(get_plan_string(&after_optimize), expected); +#[test] +fn absorbs_limit_with_skip_into_hash_join() -> Result<()> { + let schema = create_schema(); + let left = empty_exec(Arc::clone(&schema)); + let right = empty_exec(Arc::clone(&schema)); + let on = join_on_columns("c1", "c1"); + let hash_join = hash_join_exec(left, right, on, None, &JoinType::Inner)?; + let global_limit = global_limit_exec(hash_join, 3, Some(5)); + + let initial = format_plan(&global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=3, fetch=5 + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c1@0, c1@0)] + EmptyExec + EmptyExec + " + ); + + let after_optimize = + LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; + let optimized = format_plan(&after_optimize); + // With skip, GlobalLimit is kept but fetch (skip + limit = 8) is absorbed by the join + insta::assert_snapshot!( + optimized, + @r" + GlobalLimitExec: skip=3, fetch=5 + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c1@0, c1@0)], fetch=8 + EmptyExec + EmptyExec + " + ); Ok(()) } #[test] -fn pushes_global_limit_exec_through_projection_exec_and_transforms_coalesce_batches_exec_into_fetching_version() --> Result<()> { +fn pushes_global_limit_exec_through_projection_exec() -> Result<()> { let schema = create_schema(); let streaming_table = stream_exec(&schema); - let coalesce_batches = coalesce_batches_exec(streaming_table, 8192); - let projection = projection_exec(schema, coalesce_batches)?; + let filter = filter_exec(Arc::clone(&schema), streaming_table)?; + let projection = projection_exec(schema, filter)?; let global_limit = global_limit_exec(projection, 0, Some(5)); - let initial = get_plan_string(&global_limit); - let expected_initial = [ - "GlobalLimitExec: skip=0, fetch=5", - " ProjectionExec: expr=[c1@0 as 
c1, c2@1 as c2, c3@2 as c3]", - " CoalesceBatchesExec: target_batch_size=8192", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true", - ]; - - assert_eq!(initial, expected_initial); + let initial = format_plan(&global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=0, fetch=5 + ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3] + FilterExec: c3@2 > 0 + StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true + " + ); let after_optimize = LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; - let expected = [ - "ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", - " CoalesceBatchesExec: target_batch_size=8192, fetch=5", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true", - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let optimized = format_plan(&after_optimize); + insta::assert_snapshot!( + optimized, + @r" + ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3] + FilterExec: c3@2 > 0, fetch=5 + StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true + " + ); Ok(()) } @@ -243,8 +363,7 @@ fn pushes_global_limit_exec_through_projection_exec_and_transforms_coalesce_batc fn pushes_global_limit_into_multiple_fetch_plans() -> Result<()> { let schema = create_schema(); let streaming_table = stream_exec(&schema); - let coalesce_batches = coalesce_batches_exec(streaming_table, 8192); - let projection = projection_exec(Arc::clone(&schema), coalesce_batches)?; + let projection = projection_exec(Arc::clone(&schema), streaming_table)?; let repartition = repartition_exec(projection)?; let ordering: LexOrdering = [PhysicalSortExpr { expr: col("c1", &schema)?, @@ -255,31 +374,33 @@ fn pushes_global_limit_into_multiple_fetch_plans() -> Result<()> { let spm = sort_preserving_merge_exec(ordering, sort); let global_limit = global_limit_exec(spm, 0, Some(5)); - let initial = 
get_plan_string(&global_limit); - let expected_initial = [ - "GlobalLimitExec: skip=0, fetch=5", - " SortPreservingMergeExec: [c1@0 ASC]", - " SortExec: expr=[c1@0 ASC], preserve_partitioning=[false]", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", - " CoalesceBatchesExec: target_batch_size=8192", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true", - ]; - - assert_eq!(initial, expected_initial); + let initial = format_plan(&global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=0, fetch=5 + SortPreservingMergeExec: [c1@0 ASC] + SortExec: expr=[c1@0 ASC], preserve_partitioning=[false] + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3] + StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true + " + ); let after_optimize = LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; - let expected = [ - "SortPreservingMergeExec: [c1@0 ASC], fetch=5", - " SortExec: TopK(fetch=5), expr=[c1@0 ASC], preserve_partitioning=[false]", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", - " CoalesceBatchesExec: target_batch_size=8192", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true", - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let optimized = format_plan(&after_optimize); + insta::assert_snapshot!( + optimized, + @r" + SortPreservingMergeExec: [c1@0 ASC], fetch=5 + SortExec: TopK(fetch=5), expr=[c1@0 ASC], preserve_partitioning=[false] + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3] + StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true + " + ); 
Ok(()) } @@ -294,26 +415,31 @@ fn keeps_pushed_local_limit_exec_when_there_are_multiple_input_partitions() -> R let coalesce_partitions = coalesce_partitions_exec(filter); let global_limit = global_limit_exec(coalesce_partitions, 0, Some(5)); - let initial = get_plan_string(&global_limit); - let expected_initial = [ - "GlobalLimitExec: skip=0, fetch=5", - " CoalescePartitionsExec", - " FilterExec: c3@2 > 0", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true", - ]; - assert_eq!(initial, expected_initial); + let initial = format_plan(&global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=0, fetch=5 + CoalescePartitionsExec + FilterExec: c3@2 > 0 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true + " + ); let after_optimize = LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; - let expected = [ - "CoalescePartitionsExec: fetch=5", - " FilterExec: c3@2 > 0, fetch=5", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true", - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let optimized = format_plan(&after_optimize); + insta::assert_snapshot!( + optimized, + @r" + CoalescePartitionsExec: fetch=5 + FilterExec: c3@2 > 0, fetch=5 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true + " + ); Ok(()) } @@ -325,20 +451,27 @@ fn merges_local_limit_with_local_limit() -> Result<()> { let child_local_limit = local_limit_exec(empty_exec, 10); let parent_local_limit = local_limit_exec(child_local_limit, 20); - let initial = get_plan_string(&parent_local_limit); - let expected_initial = [ - 
"LocalLimitExec: fetch=20", - " LocalLimitExec: fetch=10", - " EmptyExec", - ]; - - assert_eq!(initial, expected_initial); + let initial = format_plan(&parent_local_limit); + insta::assert_snapshot!( + initial, + @r" + LocalLimitExec: fetch=20 + LocalLimitExec: fetch=10 + EmptyExec + " + ); let after_optimize = LimitPushdown::new().optimize(parent_local_limit, &ConfigOptions::new())?; - let expected = ["GlobalLimitExec: skip=0, fetch=10", " EmptyExec"]; - assert_eq!(get_plan_string(&after_optimize), expected); + let optimized = format_plan(&after_optimize); + insta::assert_snapshot!( + optimized, + @r" + GlobalLimitExec: skip=0, fetch=10 + EmptyExec + " + ); Ok(()) } @@ -350,20 +483,27 @@ fn merges_global_limit_with_global_limit() -> Result<()> { let child_global_limit = global_limit_exec(empty_exec, 10, Some(30)); let parent_global_limit = global_limit_exec(child_global_limit, 10, Some(20)); - let initial = get_plan_string(&parent_global_limit); - let expected_initial = [ - "GlobalLimitExec: skip=10, fetch=20", - " GlobalLimitExec: skip=10, fetch=30", - " EmptyExec", - ]; - - assert_eq!(initial, expected_initial); + let initial = format_plan(&parent_global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=10, fetch=20 + GlobalLimitExec: skip=10, fetch=30 + EmptyExec + " + ); let after_optimize = LimitPushdown::new().optimize(parent_global_limit, &ConfigOptions::new())?; - let expected = ["GlobalLimitExec: skip=20, fetch=20", " EmptyExec"]; - assert_eq!(get_plan_string(&after_optimize), expected); + let optimized = format_plan(&after_optimize); + insta::assert_snapshot!( + optimized, + @r" + GlobalLimitExec: skip=20, fetch=20 + EmptyExec + " + ); Ok(()) } @@ -375,20 +515,27 @@ fn merges_global_limit_with_local_limit() -> Result<()> { let local_limit = local_limit_exec(empty_exec, 40); let global_limit = global_limit_exec(local_limit, 20, Some(30)); - let initial = get_plan_string(&global_limit); - let expected_initial = [ - 
"GlobalLimitExec: skip=20, fetch=30", - " LocalLimitExec: fetch=40", - " EmptyExec", - ]; - - assert_eq!(initial, expected_initial); + let initial = format_plan(&global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=20, fetch=30 + LocalLimitExec: fetch=40 + EmptyExec + " + ); let after_optimize = LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; - let expected = ["GlobalLimitExec: skip=20, fetch=20", " EmptyExec"]; - assert_eq!(get_plan_string(&after_optimize), expected); + let optimized = format_plan(&after_optimize); + insta::assert_snapshot!( + optimized, + @r" + GlobalLimitExec: skip=20, fetch=20 + EmptyExec + " + ); Ok(()) } @@ -400,20 +547,138 @@ fn merges_local_limit_with_global_limit() -> Result<()> { let global_limit = global_limit_exec(empty_exec, 20, Some(30)); let local_limit = local_limit_exec(global_limit, 20); - let initial = get_plan_string(&local_limit); - let expected_initial = [ - "LocalLimitExec: fetch=20", - " GlobalLimitExec: skip=20, fetch=30", - " EmptyExec", - ]; - - assert_eq!(initial, expected_initial); + let initial = format_plan(&local_limit); + insta::assert_snapshot!( + initial, + @r" + LocalLimitExec: fetch=20 + GlobalLimitExec: skip=20, fetch=30 + EmptyExec + " + ); let after_optimize = LimitPushdown::new().optimize(local_limit, &ConfigOptions::new())?; - let expected = ["GlobalLimitExec: skip=20, fetch=20", " EmptyExec"]; - assert_eq!(get_plan_string(&after_optimize), expected); + let optimized = format_plan(&after_optimize); + insta::assert_snapshot!( + optimized, + @r" + GlobalLimitExec: skip=20, fetch=20 + EmptyExec + " + ); + + Ok(()) +} + +#[test] +fn preserves_nested_global_limit() -> Result<()> { + // If there are multiple limits in an execution plan, they all need to be + // preserved in the optimized plan. 
+ // + // Plan structure: + // GlobalLimitExec: skip=1, fetch=1 + // NestedLoopJoinExec (Left) + // EmptyExec (left side) + // GlobalLimitExec: skip=2, fetch=1 + // NestedLoopJoinExec (Right) + // EmptyExec (left side) + // EmptyExec (right side) + let schema = create_schema(); + + // Build inner join: NestedLoopJoin(Empty, Empty) + let inner_left = empty_exec(Arc::clone(&schema)); + let inner_right = empty_exec(Arc::clone(&schema)); + let inner_join = nested_loop_join_exec(inner_left, inner_right, JoinType::Right)?; + + // Add inner limit: GlobalLimitExec: skip=2, fetch=1 + let inner_limit = global_limit_exec(inner_join, 2, Some(1)); + + // Build outer join: NestedLoopJoin(Empty, GlobalLimit) + let outer_left = empty_exec(Arc::clone(&schema)); + let outer_join = nested_loop_join_exec(outer_left, inner_limit, JoinType::Left)?; + + // Add outer limit: GlobalLimitExec: skip=1, fetch=1 + let outer_limit = global_limit_exec(outer_join, 1, Some(1)); + + let initial = format_plan(&outer_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=1, fetch=1 + NestedLoopJoinExec: join_type=Left + EmptyExec + GlobalLimitExec: skip=2, fetch=1 + NestedLoopJoinExec: join_type=Right + EmptyExec + EmptyExec + " + ); + + let after_optimize = + LimitPushdown::new().optimize(outer_limit, &ConfigOptions::new())?; + let optimized = format_plan(&after_optimize); + insta::assert_snapshot!( + optimized, + @r" + GlobalLimitExec: skip=1, fetch=1 + NestedLoopJoinExec: join_type=Left + EmptyExec + GlobalLimitExec: skip=2, fetch=1 + NestedLoopJoinExec: join_type=Right + EmptyExec + EmptyExec + " + ); + + Ok(()) +} + +#[test] +fn preserves_skip_before_sort() -> Result<()> { + // If there's a limit with skip before a node that (1) supports fetch but + // (2) does not support limit pushdown, that limit should not be removed. 
+ // + // Plan structure: + // GlobalLimitExec: skip=1, fetch=None + // SortExec: TopK(fetch=4) + // EmptyExec + let schema = create_schema(); + + let empty = empty_exec(Arc::clone(&schema)); + + let ordering = [PhysicalSortExpr { + expr: col("c1", &schema)?, + options: SortOptions::default(), + }]; + let sort = sort_exec(ordering.into(), empty) + .with_fetch(Some(4)) + .unwrap(); + + let outer_limit = global_limit_exec(sort, 1, None); + + let initial = format_plan(&outer_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=1, fetch=None + SortExec: TopK(fetch=4), expr=[c1@0 ASC], preserve_partitioning=[false] + EmptyExec + " + ); + + let after_optimize = + LimitPushdown::new().optimize(outer_limit, &ConfigOptions::new())?; + let optimized = format_plan(&after_optimize); + insta::assert_snapshot!( + optimized, + @r" + GlobalLimitExec: skip=1, fetch=3 + SortExec: TopK(fetch=4), expr=[c1@0 ASC], preserve_partitioning=[false] + EmptyExec + " + ); Ok(()) } diff --git a/datafusion/core/tests/physical_optimizer/mod.rs b/datafusion/core/tests/physical_optimizer/mod.rs index d11322cd26be..cf179cb727cf 100644 --- a/datafusion/core/tests/physical_optimizer/mod.rs +++ b/datafusion/core/tests/physical_optimizer/mod.rs @@ -24,7 +24,6 @@ mod combine_partial_final_agg; mod enforce_distribution; mod enforce_sorting; mod enforce_sorting_monotonicity; -#[expect(clippy::needless_pass_by_value)] mod filter_pushdown; mod join_selection; #[expect(clippy::needless_pass_by_value)] @@ -38,3 +37,5 @@ mod sanity_checker; #[expect(clippy::needless_pass_by_value)] mod test_utils; mod window_optimize; + +mod pushdown_utils; diff --git a/datafusion/core/tests/physical_optimizer/partition_statistics.rs b/datafusion/core/tests/physical_optimizer/partition_statistics.rs index ba53d079e305..d73db6fe7480 100644 --- a/datafusion/core/tests/physical_optimizer/partition_statistics.rs +++ b/datafusion/core/tests/physical_optimizer/partition_statistics.rs @@ -41,7 +41,6 @@ mod test 
{ use datafusion_physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, }; - use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::common::compute_record_batch_statistics; use datafusion_physical_plan::empty::EmptyExec; @@ -387,17 +386,17 @@ mod test { column_statistics: vec![ ColumnStatistics { null_count: Precision::Exact(0), - max_value: Precision::Exact(ScalarValue::Null), - min_value: Precision::Exact(ScalarValue::Null), - sum_value: Precision::Exact(ScalarValue::Null), + max_value: Precision::Exact(ScalarValue::Int32(None)), + min_value: Precision::Exact(ScalarValue::Int32(None)), + sum_value: Precision::Exact(ScalarValue::Int32(None)), distinct_count: Precision::Exact(0), byte_size: Precision::Exact(16), }, ColumnStatistics { null_count: Precision::Exact(0), - max_value: Precision::Exact(ScalarValue::Null), - min_value: Precision::Exact(ScalarValue::Null), - sum_value: Precision::Exact(ScalarValue::Null), + max_value: Precision::Exact(ScalarValue::Date32(None)), + min_value: Precision::Exact(ScalarValue::Date32(None)), + sum_value: Precision::Exact(ScalarValue::Date32(None)), distinct_count: Precision::Exact(0), byte_size: Precision::Exact(16), // 4 rows * 4 bytes (Date32) }, @@ -416,17 +415,17 @@ mod test { column_statistics: vec![ ColumnStatistics { null_count: Precision::Exact(0), - max_value: Precision::Exact(ScalarValue::Null), - min_value: Precision::Exact(ScalarValue::Null), - sum_value: Precision::Exact(ScalarValue::Null), + max_value: Precision::Exact(ScalarValue::Int32(None)), + min_value: Precision::Exact(ScalarValue::Int32(None)), + sum_value: Precision::Exact(ScalarValue::Int32(None)), distinct_count: Precision::Exact(0), byte_size: Precision::Exact(8), }, ColumnStatistics { null_count: Precision::Exact(0), - max_value: Precision::Exact(ScalarValue::Null), - min_value: Precision::Exact(ScalarValue::Null), - 
sum_value: Precision::Exact(ScalarValue::Null), + max_value: Precision::Exact(ScalarValue::Date32(None)), + min_value: Precision::Exact(ScalarValue::Date32(None)), + sum_value: Precision::Exact(ScalarValue::Date32(None)), distinct_count: Precision::Exact(0), byte_size: Precision::Exact(8), // 2 rows * 4 bytes (Date32) }, @@ -713,43 +712,6 @@ mod test { Ok(()) } - #[tokio::test] - async fn test_statistic_by_partition_of_coalesce_batches() -> Result<()> { - let scan = create_scan_exec_with_statistics(None, Some(2)).await; - let coalesce_batches: Arc = - Arc::new(CoalesceBatchesExec::new(scan, 2)); - // Partition 1: ids [3,4], dates [2025-03-01, 2025-03-02] - let expected_statistic_partition_1 = create_partition_statistics( - 2, - 16, - 3, - 4, - Some((DATE_2025_03_01, DATE_2025_03_02)), - ); - // Partition 2: ids [1,2], dates [2025-03-03, 2025-03-04] - let expected_statistic_partition_2 = create_partition_statistics( - 2, - 16, - 1, - 2, - Some((DATE_2025_03_03, DATE_2025_03_04)), - ); - let statistics = (0..coalesce_batches.output_partitioning().partition_count()) - .map(|idx| coalesce_batches.partition_statistics(Some(idx))) - .collect::>>()?; - assert_eq!(statistics.len(), 2); - assert_eq!(statistics[0], expected_statistic_partition_1); - assert_eq!(statistics[1], expected_statistic_partition_2); - - // Check the statistics_by_partition with real results - let expected_stats = vec![ - ExpectedStatistics::NonEmpty(3, 4, 2), - ExpectedStatistics::NonEmpty(1, 2, 2), - ]; - validate_statistics_with_data(coalesce_batches, expected_stats, 0).await?; - Ok(()) - } - #[tokio::test] async fn test_statistic_by_partition_of_coalesce_partitions() -> Result<()> { let scan = create_scan_exec_with_statistics(None, Some(2)).await; @@ -864,7 +826,7 @@ mod test { let plan_string = get_plan_string(&aggregate_exec_partial).swap_remove(0); assert_snapshot!( plan_string, - @"AggregateExec: mode=Partial, gby=[id@0 as id, 1 + id@0 as expr], aggr=[COUNT(c)], ordering_mode=Sorted" + 
@"AggregateExec: mode=Partial, gby=[id@0 as id, 1 + id@0 as expr], aggr=[COUNT(c)]" ); let p0_statistics = aggregate_exec_partial.partition_statistics(Some(0))?; @@ -1332,4 +1294,64 @@ mod test { Ok(()) } + + #[tokio::test] + async fn test_statistics_by_partition_of_empty_exec() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ])); + + // Try to test with single partition + let empty_single = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + let stats = empty_single.partition_statistics(Some(0))?; + assert_eq!(stats.num_rows, Precision::Exact(0)); + assert_eq!(stats.total_byte_size, Precision::Exact(0)); + assert_eq!(stats.column_statistics.len(), 2); + + for col_stat in &stats.column_statistics { + assert_eq!(col_stat.null_count, Precision::Exact(0)); + assert_eq!(col_stat.distinct_count, Precision::Exact(0)); + assert_eq!(col_stat.byte_size, Precision::Exact(0)); + assert_eq!(col_stat.min_value, Precision::::Absent); + assert_eq!(col_stat.max_value, Precision::::Absent); + assert_eq!(col_stat.sum_value, Precision::::Absent); + assert_eq!(col_stat.byte_size, Precision::Exact(0)); + } + + let overall_stats = empty_single.partition_statistics(None)?; + assert_eq!(stats, overall_stats); + + validate_statistics_with_data(empty_single, vec![ExpectedStatistics::Empty], 0) + .await?; + + // Test with multiple partitions + let empty_multi: Arc = + Arc::new(EmptyExec::new(Arc::clone(&schema)).with_partitions(3)); + + let statistics = (0..empty_multi.output_partitioning().partition_count()) + .map(|idx| empty_multi.partition_statistics(Some(idx))) + .collect::>>()?; + + assert_eq!(statistics.len(), 3); + + for stat in &statistics { + assert_eq!(stat.num_rows, Precision::Exact(0)); + assert_eq!(stat.total_byte_size, Precision::Exact(0)); + } + + validate_statistics_with_data( + empty_multi, + vec![ + ExpectedStatistics::Empty, + ExpectedStatistics::Empty, + 
ExpectedStatistics::Empty, + ], + 0, + ) + .await?; + + Ok(()) + } } diff --git a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs index 480f5c8cc97b..00e016ae02ca 100644 --- a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs @@ -45,7 +45,6 @@ use datafusion_physical_expr_common::sort_expr::{ use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_optimizer::output_requirements::OutputRequirementExec; use datafusion_physical_optimizer::projection_pushdown::ProjectionPushdown; -use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::coop::CooperativeExec; use datafusion_physical_plan::filter::FilterExec; @@ -1285,6 +1284,7 @@ fn test_hash_join_after_projection() -> Result<()> { None, PartitionMode::Auto, NullEquality::NullEqualsNothing, + false, )?); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ @@ -1681,24 +1681,15 @@ fn test_partition_col_projection_pushdown_expr() -> Result<()> { } #[test] -fn test_coalesce_batches_after_projection() -> Result<()> { +fn test_cooperative_exec_after_projection() -> Result<()> { let csv = create_simple_csv_exec(); - let filter = Arc::new(FilterExec::try_new( - Arc::new(BinaryExpr::new( - Arc::new(Column::new("c", 2)), - Operator::Gt, - Arc::new(Literal::new(ScalarValue::Int32(Some(0)))), - )), - csv, - )?); - let coalesce_batches: Arc = - Arc::new(CoalesceBatchesExec::new(filter, 8192)); + let cooperative: Arc = Arc::new(CooperativeExec::new(csv)); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a"), ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b"), ], - coalesce_batches, + cooperative, )?); let initial = 
displayable(projection.as_ref()).indent(true).to_string(); @@ -1708,9 +1699,8 @@ fn test_coalesce_batches_after_projection() -> Result<()> { actual, @r" ProjectionExec: expr=[a@0 as a, b@1 as b] - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: c@2 > 0 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + CooperativeExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false " ); @@ -1722,13 +1712,12 @@ fn test_coalesce_batches_after_projection() -> Result<()> { .to_string(); let actual = after_optimize_string.trim(); - // Projection should be pushed down through CoalesceBatchesExec + // Projection should be pushed down through CooperativeExec assert_snapshot!( actual, @r" - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: c@2 > 0, projection=[a@0, b@1] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + CooperativeExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b], file_type=csv, has_header=false " ); @@ -1736,43 +1725,43 @@ fn test_coalesce_batches_after_projection() -> Result<()> { } #[test] -fn test_cooperative_exec_after_projection() -> Result<()> { - let csv = create_simple_csv_exec(); - let cooperative: Arc = Arc::new(CooperativeExec::new(csv)); - let projection: Arc = Arc::new(ProjectionExec::try_new( - vec![ - ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a"), - ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b"), - ], - cooperative, - )?); +fn test_hash_join_empty_projection_embeds() -> Result<()> { + let left_csv = create_simple_csv_exec(); + let right_csv = create_simple_csv_exec(); - let initial = displayable(projection.as_ref()).indent(true).to_string(); - let actual = initial.trim(); + let join = Arc::new(HashJoinExec::try_new( + left_csv, + right_csv, + vec![(Arc::new(Column::new("a", 0)), Arc::new(Column::new("a", 0)))], + None, + 
&JoinType::Right, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + false, + )?); - assert_snapshot!( - actual, - @r" - ProjectionExec: expr=[a@0 as a, b@1 as b] - CooperativeExec - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false - " - ); + // Empty projection: no columns needed from the join output + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![] as Vec, + join, + )?); let after_optimize = ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; - let after_optimize_string = displayable(after_optimize.as_ref()) .indent(true) .to_string(); let actual = after_optimize_string.trim(); - // Projection should be pushed down through CooperativeExec + // The empty projection should be embedded into the HashJoinExec, + // resulting in projection=[] on the join and no ProjectionExec wrapper. assert_snapshot!( actual, @r" - CooperativeExec - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b], file_type=csv, has_header=false + HashJoinExec: mode=CollectLeft, join_type=Right, on=[(a@0, a@0)], projection=[] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false " ); diff --git a/datafusion/core/tests/physical_optimizer/pushdown_sort.rs b/datafusion/core/tests/physical_optimizer/pushdown_sort.rs index caef0fba052c..d6fd4d8d00ae 100644 --- a/datafusion/core/tests/physical_optimizer/pushdown_sort.rs +++ b/datafusion/core/tests/physical_optimizer/pushdown_sort.rs @@ -32,10 +32,10 @@ use datafusion_physical_optimizer::pushdown_sort::PushdownSort; use std::sync::Arc; use crate::physical_optimizer::test_utils::{ - OptimizationTest, coalesce_batches_exec, coalesce_partitions_exec, parquet_exec, - parquet_exec_with_sort, projection_exec, projection_exec_with_alias, - repartition_exec, schema, 
simple_projection_exec, sort_exec, sort_exec_with_fetch, - sort_expr, sort_expr_named, test_scan_with_ordering, + OptimizationTest, coalesce_partitions_exec, parquet_exec, parquet_exec_with_sort, + projection_exec, projection_exec_with_alias, repartition_exec, schema, + simple_projection_exec, sort_exec, sort_exec_with_fetch, sort_expr, sort_expr_named, + test_scan_with_ordering, }; #[test] @@ -231,8 +231,7 @@ fn test_prefix_match_through_transparent_nodes() { let source_ordering = LexOrdering::new(vec![a.clone().reverse(), b, c.reverse()]).unwrap(); let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); - let coalesce = coalesce_batches_exec(source, 1024); - let repartition = repartition_exec(coalesce); + let repartition = repartition_exec(source); // Request only [a ASC NULLS FIRST] - prefix of reversed ordering let prefix_ordering = LexOrdering::new(vec![a.clone().asc().nulls_first()]).unwrap(); @@ -245,14 +244,12 @@ fn test_prefix_match_through_transparent_nodes() { input: - SortExec: expr=[a@0 ASC], preserve_partitioning=[false] - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true - - CoalesceBatchesExec: target_batch_size=1024 - - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 DESC NULLS LAST, b@1 ASC, c@2 DESC NULLS LAST], file_type=parquet + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 DESC NULLS LAST, b@1 ASC, c@2 DESC NULLS LAST], file_type=parquet output: Ok: - SortExec: expr=[a@0 ASC], preserve_partitioning=[false] - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - - CoalesceBatchesExec: target_batch_size=1024 - - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true " ); } @@ -322,35 
+319,6 @@ fn test_no_prefix_match_longer_than_source() { // ORIGINAL TESTS // ============================================================================ -#[test] -fn test_sort_through_coalesce_batches() { - // Sort pushes through CoalesceBatchesExec - let schema = schema(); - let a = sort_expr("a", &schema); - let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); - let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); - let coalesce = coalesce_batches_exec(source, 1024); - - let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); - let plan = sort_exec(desc_ordering, coalesce); - - insta::assert_snapshot!( - OptimizationTest::new(plan, PushdownSort::new(), true), - @r" - OptimizationTest: - input: - - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] - - CoalesceBatchesExec: target_batch_size=1024 - - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet - output: - Ok: - - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] - - CoalesceBatchesExec: target_batch_size=1024 - - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true - " - ); -} - #[test] fn test_sort_through_repartition() { // Sort should push through RepartitionExec @@ -416,20 +384,17 @@ fn test_nested_sorts() { fn test_non_sort_plans_unchanged() { // Plans without SortExec should pass through unchanged let schema = schema(); - let source = parquet_exec(schema.clone()); - let plan = coalesce_batches_exec(source, 1024); + let plan = parquet_exec(schema.clone()); insta::assert_snapshot!( OptimizationTest::new(plan, PushdownSort::new(), true), @r" OptimizationTest: input: - - CoalesceBatchesExec: target_batch_size=1024 - - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], 
file_type=parquet output: Ok: - - CoalesceBatchesExec: target_batch_size=1024 - - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet " ); } @@ -482,8 +447,7 @@ fn test_complex_plan_with_multiple_operators() { let a = sort_expr("a", &schema); let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); - let coalesce_batches = coalesce_batches_exec(source, 1024); - let repartition = repartition_exec(coalesce_batches); + let repartition = repartition_exec(source); let coalesce_parts = coalesce_partitions_exec(repartition); let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); @@ -497,15 +461,13 @@ fn test_complex_plan_with_multiple_operators() { - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] - CoalescePartitionsExec - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true - - CoalesceBatchesExec: target_batch_size=1024 - - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet output: Ok: - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] - CoalescePartitionsExec - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - - CoalesceBatchesExec: target_batch_size=1024 - - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true " ); } @@ -870,7 +832,7 @@ fn test_sort_pushdown_projection_with_limit() { } #[test] -fn test_sort_pushdown_through_projection_and_coalesce() { +fn 
test_sort_pushdown_through_projection() { // Sort pushes through both projection and coalesce batches let schema = schema(); @@ -879,10 +841,8 @@ fn test_sort_pushdown_through_projection_and_coalesce() { let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); - let coalesce = coalesce_batches_exec(source, 1024); - // Projection: SELECT a, b - let projection = simple_projection_exec(coalesce, vec![0, 1]); + let projection = simple_projection_exec(source, vec![0, 1]); // Request [a DESC] let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); @@ -895,14 +855,12 @@ fn test_sort_pushdown_through_projection_and_coalesce() { input: - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] - ProjectionExec: expr=[a@0 as a, b@1 as b] - - CoalesceBatchesExec: target_batch_size=1024 - - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet output: Ok: - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] - ProjectionExec: expr=[a@0 as a, b@1 as b] - - CoalesceBatchesExec: target_batch_size=1024 - - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true " ); } diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs b/datafusion/core/tests/physical_optimizer/pushdown_utils.rs similarity index 92% rename from datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs rename to datafusion/core/tests/physical_optimizer/pushdown_utils.rs index 1afdc4823f0a..91ae6c414e9e 100644 --- a/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs +++ 
b/datafusion/core/tests/physical_optimizer/pushdown_utils.rs @@ -24,6 +24,7 @@ use datafusion_datasource::{ file_scan_config::FileScanConfigBuilder, file_stream::FileOpenFuture, file_stream::FileOpener, source::DataSourceExec, }; +use datafusion_physical_expr::projection::ProjectionExprs; use datafusion_physical_expr_common::physical_expr::fmt_sql; use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::filter::batch_filter; @@ -50,7 +51,7 @@ use std::{ pub struct TestOpener { batches: Vec, batch_size: Option, - projection: Option>, + projection: Option, predicate: Option>, } @@ -60,6 +61,7 @@ impl FileOpener for TestOpener { if self.batches.is_empty() { return Ok((async { Ok(TestStream::new(vec![]).boxed()) }).boxed()); } + let schema = self.batches[0].schema(); if let Some(batch_size) = self.batch_size { let batch = concat_batches(&batches[0].schema(), &batches)?; let mut new_batches = Vec::new(); @@ -83,9 +85,10 @@ impl FileOpener for TestOpener { batches = new_batches; if let Some(projection) = &self.projection { + let projector = projection.make_projector(&schema)?; batches = batches .into_iter() - .map(|batch| batch.project(projection).unwrap()) + .map(|batch| projector.project_batch(&batch).unwrap()) .collect(); } @@ -103,14 +106,13 @@ pub struct TestSource { batch_size: Option, batches: Vec, metrics: ExecutionPlanMetricsSet, - projection: Option>, + projection: Option, table_schema: datafusion_datasource::TableSchema, } impl TestSource { pub fn new(schema: SchemaRef, support: bool, batches: Vec) -> Self { - let table_schema = - datafusion_datasource::TableSchema::new(Arc::clone(&schema), vec![]); + let table_schema = datafusion_datasource::TableSchema::new(schema, vec![]); Self { support, metrics: ExecutionPlanMetricsSet::new(), @@ -210,6 +212,30 @@ impl FileSource for TestSource { } } + fn try_pushdown_projection( + &self, + projection: &ProjectionExprs, + ) -> Result>> { + if let Some(existing_projection) = &self.projection 
{ + // Combine existing projection with new projection + let combined_projection = existing_projection.try_merge(projection)?; + Ok(Some(Arc::new(TestSource { + projection: Some(combined_projection), + table_schema: self.table_schema.clone(), + ..self.clone() + }))) + } else { + Ok(Some(Arc::new(TestSource { + projection: Some(projection.clone()), + ..self.clone() + }))) + } + } + + fn projection(&self) -> Option<&ProjectionExprs> { + self.projection.as_ref() + } + fn table_schema(&self) -> &datafusion_datasource::TableSchema { &self.table_schema } @@ -332,6 +358,7 @@ pub struct OptimizationTest { } impl OptimizationTest { + #[expect(clippy::needless_pass_by_value)] pub fn new( input_plan: Arc, opt: O, @@ -447,7 +474,7 @@ impl ExecutionPlan for TestNode { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { self.input.properties() } diff --git a/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs index d93081f5ceb8..cdfed5011696 100644 --- a/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs @@ -18,10 +18,10 @@ use std::sync::Arc; use crate::physical_optimizer::test_utils::{ - check_integrity, coalesce_batches_exec, coalesce_partitions_exec, - create_test_schema3, parquet_exec_with_sort, sort_exec, - sort_exec_with_preserve_partitioning, sort_preserving_merge_exec, - sort_preserving_merge_exec_with_fetch, stream_exec_ordered_with_projection, + check_integrity, coalesce_partitions_exec, create_test_schema3, + parquet_exec_with_sort, sort_exec, sort_exec_with_preserve_partitioning, + sort_preserving_merge_exec, sort_preserving_merge_exec_with_fetch, + stream_exec_ordered_with_projection, }; use datafusion::prelude::SessionContext; @@ -41,7 +41,6 @@ use 
datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use datafusion_physical_optimizer::enforce_sorting::replace_with_order_preserving_variants::{ plan_with_order_breaking_variants, plan_with_order_preserving_variants, replace_with_order_preserving_variants, OrderPreservationContext }; -use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion_physical_plan::filter::FilterExec; use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; use datafusion::datasource::memory::MemorySourceConfig; @@ -50,7 +49,7 @@ use datafusion_physical_plan::{ collect, displayable, ExecutionPlan, Partitioning, }; -use object_store::ObjectStore; +use object_store::ObjectStoreExt; use object_store::memory::InMemory; use rstest::rstest; use url::Url; @@ -440,9 +439,7 @@ async fn test_replace_multiple_input_repartition_with_extra_steps( let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); - let coalesce_batches_exec = coalesce_batches_exec(filter, 8192); - let sort = - sort_exec_with_preserve_partitioning(ordering.clone(), coalesce_batches_exec); + let sort = sort_exec_with_preserve_partitioning(ordering.clone(), filter); let physical_plan = sort_preserving_merge_exec(ordering, sort); let run = ReplaceTest::new(physical_plan) @@ -458,19 +455,17 @@ async fn test_replace_multiple_input_repartition_with_extra_steps( Input: SortPreservingMergeExec: [a@0 ASC NULLS LAST] SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: c@1 > 3 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true - StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] - - Optimized: - SortPreservingMergeExec: 
[a@0 ASC NULLS LAST] - CoalesceBatchesExec: target_batch_size=8192 FilterExec: c@1 > 3 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + + Optimized: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] "); }, (Boundedness::Bounded, SortPreference::MaximizeParallelism) => { @@ -478,11 +473,10 @@ async fn test_replace_multiple_input_repartition_with_extra_steps( Input / Optimized: SortPreservingMergeExec: [a@0 ASC NULLS LAST] SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: c@1 > 3 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true - DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); }, (Boundedness::Bounded, SortPreference::PreserveOrder) => { @@ -490,19 +484,17 @@ async fn test_replace_multiple_input_repartition_with_extra_steps( Input: SortPreservingMergeExec: [a@0 ASC 
NULLS LAST] SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: c@1 > 3 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true - DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST - - Optimized: - SortPreservingMergeExec: [a@0 ASC NULLS LAST] - CoalesceBatchesExec: target_batch_size=8192 FilterExec: c@1 > 3 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + + Optimized: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); } } @@ -527,12 +519,9 @@ async fn test_replace_multiple_input_repartition_with_extra_steps_2( Boundedness::Bounded => memory_exec_sorted(&schema, ordering.clone()), }; let repartition_rr = repartition_exec_round_robin(source); - let coalesce_batches_exec_1 = coalesce_batches_exec(repartition_rr, 8192); - let repartition_hash = repartition_exec_hash(coalesce_batches_exec_1); + let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); - let coalesce_batches_exec_2 = coalesce_batches_exec(filter, 8192); - let sort = - sort_exec_with_preserve_partitioning(ordering.clone(), coalesce_batches_exec_2); + let sort = 
sort_exec_with_preserve_partitioning(ordering.clone(), filter); let physical_plan = sort_preserving_merge_exec(ordering, sort); let run = ReplaceTest::new(physical_plan) @@ -548,21 +537,17 @@ async fn test_replace_multiple_input_repartition_with_extra_steps_2( Input: SortPreservingMergeExec: [a@0 ASC NULLS LAST] SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: c@1 > 3 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - CoalesceBatchesExec: target_batch_size=8192 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true - StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] Optimized: SortPreservingMergeExec: [a@0 ASC NULLS LAST] - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: c@1 > 3 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST - CoalesceBatchesExec: target_batch_size=8192 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true - StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] "); }, (Boundedness::Bounded, 
SortPreference::MaximizeParallelism) => { @@ -570,12 +555,10 @@ async fn test_replace_multiple_input_repartition_with_extra_steps_2( Input / Optimized: SortPreservingMergeExec: [a@0 ASC NULLS LAST] SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: c@1 > 3 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - CoalesceBatchesExec: target_batch_size=8192 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true - DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); }, (Boundedness::Bounded, SortPreference::PreserveOrder) => { @@ -583,21 +566,17 @@ async fn test_replace_multiple_input_repartition_with_extra_steps_2( Input: SortPreservingMergeExec: [a@0 ASC NULLS LAST] SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: c@1 > 3 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - CoalesceBatchesExec: target_batch_size=8192 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true - DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST Optimized: SortPreservingMergeExec: [a@0 ASC NULLS LAST] - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: c@1 > 3 - RepartitionExec: partitioning=Hash([c@1], 8), 
input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST - CoalesceBatchesExec: target_batch_size=8192 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true - DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); } } @@ -622,8 +601,7 @@ async fn test_not_replacing_when_no_need_to_preserve_sorting( let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); - let coalesce_batches_exec = coalesce_batches_exec(filter, 8192); - let physical_plan = coalesce_partitions_exec(coalesce_batches_exec); + let physical_plan = coalesce_partitions_exec(filter); let run = ReplaceTest::new(physical_plan) .with_boundedness(boundedness) @@ -637,22 +615,20 @@ async fn test_not_replacing_when_no_need_to_preserve_sorting( assert_snapshot!(physical_plan, @r" Input / Optimized: CoalescePartitionsExec - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: c@1 > 3 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true - StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] "); }, 
(Boundedness::Bounded, SortPreference::MaximizeParallelism) => { assert_snapshot!(physical_plan, @r" Input / Optimized: CoalescePartitionsExec - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: c@1 > 3 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true - DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); // Expected bounded results same with and without flag, because there is no executor with ordering requirement }, @@ -660,11 +636,10 @@ async fn test_not_replacing_when_no_need_to_preserve_sorting( assert_snapshot!(physical_plan, @r" Input / Optimized: CoalescePartitionsExec - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: c@1 > 3 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true - DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); } } @@ -691,8 +666,7 @@ async fn test_with_multiple_replaceable_repartitions( let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); - let coalesce_batches = coalesce_batches_exec(filter, 8192); - let repartition_hash_2 = repartition_exec_hash(coalesce_batches); + let 
repartition_hash_2 = repartition_exec_hash(filter); let sort = sort_exec_with_preserve_partitioning(ordering.clone(), repartition_hash_2); let physical_plan = sort_preserving_merge_exec(ordering, sort); @@ -710,20 +684,18 @@ async fn test_with_multiple_replaceable_repartitions( SortPreservingMergeExec: [a@0 ASC NULLS LAST] SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: c@1 > 3 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true - StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] Optimized: SortPreservingMergeExec: [a@0 ASC NULLS LAST] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: c@1 > 3 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true - StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, 
d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] "); }, (Boundedness::Bounded, SortPreference::MaximizeParallelism) => { @@ -732,11 +704,10 @@ async fn test_with_multiple_replaceable_repartitions( SortPreservingMergeExec: [a@0 ASC NULLS LAST] SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: c@1 > 3 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true - DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); }, (Boundedness::Bounded, SortPreference::PreserveOrder) => { @@ -745,20 +716,18 @@ async fn test_with_multiple_replaceable_repartitions( SortPreservingMergeExec: [a@0 ASC NULLS LAST] SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: c@1 > 3 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true - DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST Optimized: SortPreservingMergeExec: [a@0 ASC NULLS LAST] RepartitionExec: partitioning=Hash([c@1], 
8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: c@1 > 3 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true - DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); } } @@ -1041,8 +1010,6 @@ async fn test_with_multiple_child_trees( }; let left_repartition_rr = repartition_exec_round_robin(left_source); let left_repartition_hash = repartition_exec_hash(left_repartition_rr); - let left_coalesce_partitions = - Arc::new(CoalesceBatchesExec::new(left_repartition_hash, 4096)); let right_ordering = [sort_expr("a", &schema)].into(); let right_source = match boundedness { @@ -1053,11 +1020,8 @@ async fn test_with_multiple_child_trees( }; let right_repartition_rr = repartition_exec_round_robin(right_source); let right_repartition_hash = repartition_exec_hash(right_repartition_rr); - let right_coalesce_partitions = - Arc::new(CoalesceBatchesExec::new(right_repartition_hash, 4096)); - let hash_join_exec = - hash_join_exec(left_coalesce_partitions, right_coalesce_partitions); + let hash_join_exec = hash_join_exec(left_repartition_hash, right_repartition_hash); let ordering: LexOrdering = [sort_expr_default("a", &hash_join_exec.schema())].into(); let sort = sort_exec_with_preserve_partitioning(ordering.clone(), hash_join_exec); let physical_plan = sort_preserving_merge_exec(ordering, sort); @@ -1076,14 +1040,12 @@ async fn test_with_multiple_child_trees( 
SortPreservingMergeExec: [a@0 ASC] SortExec: expr=[a@0 ASC], preserve_partitioning=[true] HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)] - CoalesceBatchesExec: target_batch_size=4096 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true - StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] - CoalesceBatchesExec: target_batch_size=4096 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true - StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] "); }, (Boundedness::Bounded, _) => { @@ -1092,14 +1054,12 @@ async fn test_with_multiple_child_trees( SortPreservingMergeExec: [a@0 ASC] SortExec: expr=[a@0 ASC], preserve_partitioning=[true] HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)] - CoalesceBatchesExec: target_batch_size=4096 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true - DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST - CoalesceBatchesExec: target_batch_size=4096 - 
RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true - DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); // Expected bounded results same with and without flag, because ordering get lost during intermediate executor anyway. // Hence, no need to preserve existing ordering. @@ -1179,6 +1139,7 @@ fn hash_join_exec( None, PartitionMode::Partitioned, NullEquality::NullEqualsNothing, + false, ) .unwrap(), ) diff --git a/datafusion/core/tests/physical_optimizer/test_utils.rs b/datafusion/core/tests/physical_optimizer/test_utils.rs index 5b50181d7fd3..f8c91ba272a9 100644 --- a/datafusion/core/tests/physical_optimizer/test_utils.rs +++ b/datafusion/core/tests/physical_optimizer/test_utils.rs @@ -53,7 +53,6 @@ use datafusion_physical_optimizer::limited_distinct_aggregation::LimitedDistinct use datafusion_physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, }; -use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion_physical_plan::filter::FilterExec; @@ -248,6 +247,7 @@ pub fn hash_join_exec( None, PartitionMode::Partitioned, NullEquality::NullEqualsNothing, + false, )?)) } @@ -360,13 +360,6 @@ pub fn aggregate_exec(input: Arc) -> Arc 
{ ) } -pub fn coalesce_batches_exec( - input: Arc, - batch_size: usize, -) -> Arc { - Arc::new(CoalesceBatchesExec::new(input, batch_size)) -} - pub fn sort_exec( ordering: LexOrdering, input: Arc, @@ -461,7 +454,7 @@ impl ExecutionPlan for RequirementsTestExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { self.input.properties() } @@ -832,7 +825,7 @@ pub fn sort_expr_named(name: &str, index: usize) -> PhysicalSortExpr { pub struct TestScan { schema: SchemaRef, output_ordering: Vec, - plan_properties: PlanProperties, + plan_properties: Arc, // Store the requested ordering for display requested_ordering: Option, } @@ -866,7 +859,7 @@ impl TestScan { Self { schema, output_ordering, - plan_properties, + plan_properties: Arc::new(plan_properties), requested_ordering: None, } } @@ -922,7 +915,7 @@ impl ExecutionPlan for TestScan { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.plan_properties } diff --git a/datafusion/core/tests/set_comparison.rs b/datafusion/core/tests/set_comparison.rs new file mode 100644 index 000000000000..464d6c937b32 --- /dev/null +++ b/datafusion/core/tests/set_comparison.rs @@ -0,0 +1,193 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::{Int32Array, StringArray}; +use arrow::datatypes::{DataType, Field, Schema}; +use arrow::record_batch::RecordBatch; +use datafusion::prelude::SessionContext; +use datafusion_common::{Result, assert_batches_eq, assert_contains}; + +fn build_table(values: &[i32]) -> Result { + let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, true)])); + let array = + Arc::new(Int32Array::from(values.to_vec())) as Arc; + RecordBatch::try_new(schema, vec![array]).map_err(Into::into) +} + +#[tokio::test] +async fn set_comparison_any() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_batch("t", build_table(&[1, 6, 10])?)?; + // Include a NULL in the subquery input to ensure we propagate UNKNOWN correctly. + ctx.register_batch("s", { + let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, true)])); + let array = Arc::new(Int32Array::from(vec![Some(5), None])) + as Arc; + RecordBatch::try_new(schema, vec![array])? 
+ })?; + + let df = ctx + .sql("select v from t where v > any(select v from s)") + .await?; + let results = df.collect().await?; + + assert_batches_eq!( + &["+----+", "| v |", "+----+", "| 6 |", "| 10 |", "+----+",], + &results + ); + Ok(()) +} + +#[tokio::test] +async fn set_comparison_any_aggregate_subquery() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_batch("t", build_table(&[1, 7])?)?; + ctx.register_batch("s", build_table(&[1, 2, 3])?)?; + + let df = ctx + .sql( + "select v from t where v > any(select sum(v) from s group by v % 2) order by v", + ) + .await?; + let results = df.collect().await?; + + assert_batches_eq!(&["+---+", "| v |", "+---+", "| 7 |", "+---+",], &results); + Ok(()) +} + +#[tokio::test] +async fn set_comparison_all_empty() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_batch("t", build_table(&[1, 6, 10])?)?; + ctx.register_batch( + "e", + RecordBatch::new_empty(Arc::new(Schema::new(vec![Field::new( + "v", + DataType::Int32, + true, + )]))), + )?; + + let df = ctx + .sql("select v from t where v < all(select v from e)") + .await?; + let results = df.collect().await?; + + assert_batches_eq!( + &[ + "+----+", "| v |", "+----+", "| 1 |", "| 6 |", "| 10 |", "+----+", + ], + &results + ); + Ok(()) +} + +#[tokio::test] +async fn set_comparison_type_mismatch() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_batch("t", build_table(&[1])?)?; + ctx.register_batch("strings", { + let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)])); + let array = Arc::new(StringArray::from(vec![Some("a"), Some("b")])) + as Arc; + RecordBatch::try_new(schema, vec![array])? 
+ })?; + + let df = ctx + .sql("select v from t where v > any(select s from strings)") + .await?; + let err = df.collect().await.unwrap_err(); + assert_contains!( + err.to_string(), + "expr type Int32 can't cast to Utf8 in SetComparison" + ); + Ok(()) +} + +#[tokio::test] +async fn set_comparison_multiple_operators() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_batch("t", build_table(&[1, 2, 3, 4])?)?; + ctx.register_batch("s", build_table(&[2, 3])?)?; + + let df = ctx + .sql("select v from t where v = any(select v from s) order by v") + .await?; + let results = df.collect().await?; + assert_batches_eq!( + &["+---+", "| v |", "+---+", "| 2 |", "| 3 |", "+---+",], + &results + ); + + let df = ctx + .sql("select v from t where v != all(select v from s) order by v") + .await?; + let results = df.collect().await?; + assert_batches_eq!( + &["+---+", "| v |", "+---+", "| 1 |", "| 4 |", "+---+",], + &results + ); + + let df = ctx + .sql("select v from t where v >= all(select v from s) order by v") + .await?; + let results = df.collect().await?; + assert_batches_eq!( + &["+---+", "| v |", "+---+", "| 3 |", "| 4 |", "+---+",], + &results + ); + + let df = ctx + .sql("select v from t where v <= any(select v from s) order by v") + .await?; + let results = df.collect().await?; + assert_batches_eq!( + &[ + "+---+", "| v |", "+---+", "| 1 |", "| 2 |", "| 3 |", "+---+", + ], + &results + ); + Ok(()) +} + +#[tokio::test] +async fn set_comparison_null_semantics_all() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_batch("t", build_table(&[5])?)?; + ctx.register_batch("s", { + let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, true)])); + let array = Arc::new(Int32Array::from(vec![Some(1), None])) + as Arc; + RecordBatch::try_new(schema, vec![array])? 
+ })?; + + let df = ctx + .sql("select v from t where v != all(select v from s)") + .await?; + let results = df.collect().await?; + let row_count: usize = results.iter().map(|batch| batch.num_rows()).sum(); + assert_eq!(0, row_count); + Ok(()) +} diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index 75cd78e47aff..5f62f7204eff 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -144,7 +144,6 @@ async fn explain_analyze_baseline_metrics() { || plan.as_any().downcast_ref::().is_some() || plan.as_any().downcast_ref::().is_some() || plan.as_any().downcast_ref::().is_some() - || plan.as_any().downcast_ref::().is_some() || plan.as_any().downcast_ref::().is_some() || plan.as_any().downcast_ref::().is_some() || plan.as_any().downcast_ref::().is_some() @@ -880,12 +879,13 @@ async fn parquet_explain_analyze() { let i_rowgroup_stat = formatted.find("row_groups_pruned_statistics").unwrap(); let i_rowgroup_bloomfilter = formatted.find("row_groups_pruned_bloom_filter").unwrap(); - let i_page = formatted.find("page_index_rows_pruned").unwrap(); + let i_page_rows = formatted.find("page_index_rows_pruned").unwrap(); + let i_page_pages = formatted.find("page_index_pages_pruned").unwrap(); assert!( (i_file < i_rowgroup_stat) && (i_rowgroup_stat < i_rowgroup_bloomfilter) - && (i_rowgroup_bloomfilter < i_page), + && (i_rowgroup_bloomfilter < i_page_pages && i_page_pages < i_page_rows), "The parquet pruning metrics should be displayed in an order of: file range -> row group statistics -> row group bloom filter -> page index." 
); } diff --git a/datafusion/core/tests/sql/path_partition.rs b/datafusion/core/tests/sql/path_partition.rs index c6f920584dc2..a9061849795c 100644 --- a/datafusion/core/tests/sql/path_partition.rs +++ b/datafusion/core/tests/sql/path_partition.rs @@ -20,7 +20,6 @@ use std::collections::BTreeSet; use std::fs::File; use std::io::{Read, Seek, SeekFrom}; -use std::ops::Range; use std::sync::Arc; use arrow::datatypes::DataType; @@ -43,9 +42,12 @@ use datafusion_execution::config::SessionConfig; use async_trait::async_trait; use bytes::Bytes; use chrono::{TimeZone, Utc}; +use futures::StreamExt; use futures::stream::{self, BoxStream}; use insta::assert_snapshot; -use object_store::{Attributes, MultipartUpload, PutMultipartOptions, PutPayload}; +use object_store::{ + Attributes, CopyOptions, GetRange, MultipartUpload, PutMultipartOptions, PutPayload, +}; use object_store::{ GetOptions, GetResult, GetResultPayload, ListResult, ObjectMeta, ObjectStore, PutOptions, PutResult, path::Path, @@ -620,7 +622,7 @@ async fn create_partitioned_alltypes_parquet_table( } #[derive(Debug)] -/// An object store implem that is mirrors a given file to multiple paths. +/// An object store implem that mirrors a given file to multiple paths. 
pub struct MirroringObjectStore { /// The `(path,size)` of the files that "exist" in the store files: Vec, @@ -669,12 +671,13 @@ impl ObjectStore for MirroringObjectStore { async fn get_opts( &self, location: &Path, - _options: GetOptions, + options: GetOptions, ) -> object_store::Result { self.files.iter().find(|x| *x == location).unwrap(); let path = std::path::PathBuf::from(&self.mirrored_file); let file = File::open(&path).unwrap(); let metadata = file.metadata().unwrap(); + let meta = ObjectMeta { location: location.clone(), last_modified: metadata.modified().map(chrono::DateTime::from).unwrap(), @@ -683,37 +686,35 @@ impl ObjectStore for MirroringObjectStore { version: None, }; + let payload = if options.head { + // no content for head requests + GetResultPayload::Stream(stream::empty().boxed()) + } else if let Some(range) = options.range { + let GetRange::Bounded(range) = range else { + unimplemented!("Unbounded range not supported in MirroringObjectStore"); + }; + let mut file = File::open(path).unwrap(); + file.seek(SeekFrom::Start(range.start)).unwrap(); + + let to_read = range.end - range.start; + let to_read: usize = to_read.try_into().unwrap(); + let mut data = Vec::with_capacity(to_read); + let read = file.take(to_read as u64).read_to_end(&mut data).unwrap(); + assert_eq!(read, to_read); + let stream = stream::once(async move { Ok(Bytes::from(data)) }).boxed(); + GetResultPayload::Stream(stream) + } else { + GetResultPayload::File(file, path) + }; + Ok(GetResult { range: 0..meta.size, - payload: GetResultPayload::File(file, path), + payload, meta, attributes: Attributes::default(), }) } - async fn get_range( - &self, - location: &Path, - range: Range, - ) -> object_store::Result { - self.files.iter().find(|x| *x == location).unwrap(); - let path = std::path::PathBuf::from(&self.mirrored_file); - let mut file = File::open(path).unwrap(); - file.seek(SeekFrom::Start(range.start)).unwrap(); - - let to_read = range.end - range.start; - let to_read: usize 
= to_read.try_into().unwrap(); - let mut data = Vec::with_capacity(to_read); - let read = file.take(to_read as u64).read_to_end(&mut data).unwrap(); - assert_eq!(read, to_read); - - Ok(data.into()) - } - - async fn delete(&self, _location: &Path) -> object_store::Result<()> { - unimplemented!() - } - fn list( &self, prefix: Option<&Path>, @@ -783,14 +784,18 @@ impl ObjectStore for MirroringObjectStore { }) } - async fn copy(&self, _from: &Path, _to: &Path) -> object_store::Result<()> { + fn delete_stream( + &self, + _locations: BoxStream<'static, object_store::Result>, + ) -> BoxStream<'static, object_store::Result> { unimplemented!() } - async fn copy_if_not_exists( + async fn copy_opts( &self, _from: &Path, _to: &Path, + _options: CopyOptions, ) -> object_store::Result<()> { unimplemented!() } diff --git a/datafusion/core/tests/sql/runtime_config.rs b/datafusion/core/tests/sql/runtime_config.rs index d85892c25457..cf5237d72580 100644 --- a/datafusion/core/tests/sql/runtime_config.rs +++ b/datafusion/core/tests/sql/runtime_config.rs @@ -145,7 +145,7 @@ async fn test_memory_limit_enforcement() { } #[tokio::test] -async fn test_invalid_memory_limit() { +async fn test_invalid_memory_limit_when_unit_is_invalid() { let ctx = SessionContext::new(); let result = ctx @@ -154,7 +154,26 @@ async fn test_invalid_memory_limit() { assert!(result.is_err()); let error_message = result.unwrap_err().to_string(); - assert!(error_message.contains("Unsupported unit 'X'")); + assert!( + error_message + .contains("Unsupported unit 'X' in 'datafusion.runtime.memory_limit'") + && error_message.contains("Unit must be one of: 'K', 'M', 'G'") + ); +} + +#[tokio::test] +async fn test_invalid_memory_limit_when_limit_is_not_numeric() { + let ctx = SessionContext::new(); + + let result = ctx + .sql("SET datafusion.runtime.memory_limit = 'invalid_memory_limit'") + .await; + + assert!(result.is_err()); + let error_message = result.unwrap_err().to_string(); + assert!(error_message.contains( + 
"Failed to parse number from 'datafusion.runtime.memory_limit', limit 'invalid_memory_limit'" + )); } #[tokio::test] diff --git a/datafusion/core/tests/sql/unparser.rs b/datafusion/core/tests/sql/unparser.rs index 8b56bf67a261..ab1015b2d18d 100644 --- a/datafusion/core/tests/sql/unparser.rs +++ b/datafusion/core/tests/sql/unparser.rs @@ -47,6 +47,7 @@ use datafusion_physical_plan::ExecutionPlanProperties; use datafusion_sql::unparser::Unparser; use datafusion_sql::unparser::dialect::DefaultDialect; use itertools::Itertools; +use recursive::{set_minimum_stack_size, set_stack_allocation_size}; /// Paths to benchmark query files (supports running from repo root or different working directories). const BENCHMARK_PATHS: &[&str] = &["../../benchmarks/", "./benchmarks/"]; @@ -458,5 +459,8 @@ async fn test_clickbench_unparser_roundtrip() { #[tokio::test] async fn test_tpch_unparser_roundtrip() { + // Grow stacker segments earlier to avoid deep unparser recursion overflow in q20. + set_minimum_stack_size(512 * 1024); + set_stack_allocation_size(8 * 1024 * 1024); run_roundtrip_tests("TPC-H", tpch_queries(), tpch_test_context).await; } diff --git a/datafusion/core/tests/tracing/traceable_object_store.rs b/datafusion/core/tests/tracing/traceable_object_store.rs index 00aa4ea3f36d..71a61dbf8772 100644 --- a/datafusion/core/tests/tracing/traceable_object_store.rs +++ b/datafusion/core/tests/tracing/traceable_object_store.rs @@ -18,10 +18,11 @@ //! 
Object store implementation used for testing use crate::tracing::asserting_tracer::assert_traceability; +use futures::StreamExt; use futures::stream::BoxStream; use object_store::{ - GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, ObjectStore, - PutMultipartOptions, PutOptions, PutPayload, PutResult, path::Path, + CopyOptions, GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, + ObjectStore, PutMultipartOptions, PutOptions, PutPayload, PutResult, path::Path, }; use std::fmt::{Debug, Display, Formatter}; use std::sync::Arc; @@ -83,14 +84,17 @@ impl ObjectStore for TraceableObjectStore { self.inner.get_opts(location, options).await } - async fn head(&self, location: &Path) -> object_store::Result { - assert_traceability().await; - self.inner.head(location).await - } - - async fn delete(&self, location: &Path) -> object_store::Result<()> { - assert_traceability().await; - self.inner.delete(location).await + fn delete_stream( + &self, + locations: BoxStream<'static, object_store::Result>, + ) -> BoxStream<'static, object_store::Result> { + self.inner + .delete_stream(locations) + .then(|res| async { + futures::executor::block_on(assert_traceability()); + res + }) + .boxed() } fn list( @@ -109,17 +113,13 @@ impl ObjectStore for TraceableObjectStore { self.inner.list_with_delimiter(prefix).await } - async fn copy(&self, from: &Path, to: &Path) -> object_store::Result<()> { - assert_traceability().await; - self.inner.copy(from, to).await - } - - async fn copy_if_not_exists( + async fn copy_opts( &self, from: &Path, to: &Path, + options: CopyOptions, ) -> object_store::Result<()> { assert_traceability().await; - self.inner.copy_if_not_exists(from, to).await + self.inner.copy_opts(from, to, options).await } } diff --git a/datafusion/core/tests/user_defined/insert_operation.rs b/datafusion/core/tests/user_defined/insert_operation.rs index 7ad00dece1b2..4d2a31ca1f96 100644 --- a/datafusion/core/tests/user_defined/insert_operation.rs +++ 
b/datafusion/core/tests/user_defined/insert_operation.rs @@ -122,20 +122,22 @@ impl TableProvider for TestInsertTableProvider { #[derive(Debug)] struct TestInsertExec { op: InsertOp, - plan_properties: PlanProperties, + plan_properties: Arc, } impl TestInsertExec { fn new(op: InsertOp) -> Self { Self { op, - plan_properties: PlanProperties::new( - EquivalenceProperties::new(make_count_schema()), - Partitioning::UnknownPartitioning(1), - EmissionType::Incremental, - Boundedness::Bounded, - ) - .with_scheduling_type(SchedulingType::Cooperative), + plan_properties: Arc::new( + PlanProperties::new( + EquivalenceProperties::new(make_count_schema()), + Partitioning::UnknownPartitioning(1), + EmissionType::Incremental, + Boundedness::Bounded, + ) + .with_scheduling_type(SchedulingType::Cooperative), + ), } } } @@ -159,7 +161,7 @@ impl ExecutionPlan for TestInsertExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.plan_properties } diff --git a/datafusion/core/tests/user_defined/relation_planner.rs b/datafusion/core/tests/user_defined/relation_planner.rs index bda9b37ebea6..54af53ad858d 100644 --- a/datafusion/core/tests/user_defined/relation_planner.rs +++ b/datafusion/core/tests/user_defined/relation_planner.rs @@ -68,9 +68,11 @@ fn plan_static_values_table( .project(vec![col("column1").alias(column_name)])? .build()?; - Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias))) + Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new( + plan, alias, + )))) } - other => Ok(RelationPlanning::Original(other)), + other => Ok(RelationPlanning::Original(Box::new(other))), } } @@ -176,9 +178,11 @@ impl RelationPlanner for SamplingJoinPlanner { .cross_join(right_sampled)? 
.build()?; - Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias))) + Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new( + plan, alias, + )))) } - other => Ok(RelationPlanning::Original(other)), + other => Ok(RelationPlanning::Original(Box::new(other))), } } } @@ -195,7 +199,7 @@ impl RelationPlanner for PassThroughPlanner { _context: &mut dyn RelationPlannerContext, ) -> Result { // Never handles anything - always delegates - Ok(RelationPlanning::Original(relation)) + Ok(RelationPlanning::Original(Box::new(relation))) } } @@ -217,7 +221,7 @@ impl RelationPlanner for PremiumFeaturePlanner { to unlock advanced array operations." .to_string(), )), - other => Ok(RelationPlanning::Original(other)), + other => Ok(RelationPlanning::Original(Box::new(other))), } } } diff --git a/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs index 168d81fc6b44..31af4445ace0 100644 --- a/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs @@ -21,16 +21,14 @@ use arrow::array::{Int32Array, RecordBatch, StringArray}; use arrow::datatypes::{DataType, Field, Schema}; use async_trait::async_trait; use datafusion::prelude::*; +use datafusion_common::test_util::format_batches; use datafusion_common::{Result, assert_batches_eq}; use datafusion_expr::async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl}; use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; -// This test checks the case where batch_size doesn't evenly divide -// the number of rows. 
-#[tokio::test] -async fn test_async_udf_with_non_modular_batch_size() -> Result<()> { +fn register_table_and_udf() -> Result { let num_rows = 3; let batch_size = 2; @@ -59,6 +57,15 @@ async fn test_async_udf_with_non_modular_batch_size() -> Result<()> { .into_scalar_udf(), ); + Ok(ctx) +} + +// This test checks the case where batch_size doesn't evenly divide +// the number of rows. +#[tokio::test] +async fn test_async_udf_with_non_modular_batch_size() -> Result<()> { + let ctx = register_table_and_udf()?; + let df = ctx .sql("SELECT id, test_async_udf(prompt) as result FROM test_table") .await?; @@ -81,6 +88,31 @@ async fn test_async_udf_with_non_modular_batch_size() -> Result<()> { Ok(()) } +// This test checks if metrics are printed for `AsyncFuncExec` +#[tokio::test] +async fn test_async_udf_metrics() -> Result<()> { + let ctx = register_table_and_udf()?; + + let df = ctx + .sql( + "EXPLAIN ANALYZE SELECT id, test_async_udf(prompt) as result FROM test_table", + ) + .await?; + + let result = df.collect().await?; + + let explain_analyze_str = format_batches(&result)?.to_string(); + let async_func_exec_without_metrics = + explain_analyze_str.split("\n").any(|metric_line| { + metric_line.contains("AsyncFuncExec") + && !metric_line.contains("output_rows=3") + }); + + assert!(!async_func_exec_without_metrics); + + Ok(()) +} + #[derive(Debug, PartialEq, Eq, Hash, Clone)] struct TestAsyncUDFImpl { batch_size: usize, diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index d53e07673960..f97923ffc5be 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -84,7 +84,7 @@ use datafusion::{ physical_expr::EquivalenceProperties, physical_plan::{ DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, - PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, + PlanProperties, 
RecordBatchStream, SendableRecordBatchStream, }, physical_planner::{DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner}, prelude::{SessionConfig, SessionContext}, @@ -653,13 +653,17 @@ struct TopKExec { input: Arc, /// The maximum number of values k: usize, - cache: PlanProperties, + cache: Arc, } impl TopKExec { fn new(input: Arc, k: usize) -> Self { let cache = Self::compute_properties(input.schema()); - Self { input, k, cache } + Self { + input, + k, + cache: Arc::new(cache), + } } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. @@ -704,7 +708,7 @@ impl ExecutionPlan for TopKExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -742,12 +746,6 @@ impl ExecutionPlan for TopKExec { state: BTreeMap::new(), })) } - - fn statistics(&self) -> Result { - // to improve the optimizability of this plan - // better statistics inference could be provided - Ok(Statistics::new_unknown(&self.schema())) - } } // A very specialized TopK implementation diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index b86cd94a8a9b..b4ce3a03dbcb 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -42,7 +42,7 @@ use datafusion_common::{ assert_batches_sorted_eq, assert_contains, exec_datafusion_err, exec_err, not_impl_err, plan_err, }; -use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; use datafusion_expr::{ Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, LogicalPlanBuilder, OperateFunctionArg, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, @@ -699,7 +699,7 @@ impl ScalarUDFImpl for CastToI64UDF { fn simplify( &self, 
mut args: Vec, - info: &dyn SimplifyInfo, + info: &SimplifyContext, ) -> Result { // DataFusion should have ensured the function is called with just a // single argument @@ -975,7 +975,7 @@ impl ScalarUDFImpl for ScalarFunctionWrapper { fn simplify( &self, args: Vec, - _info: &dyn SimplifyInfo, + _info: &SimplifyContext, ) -> Result { let replacement = Self::replacement(&self.expr, &args, &self.defaults)?; @@ -1306,19 +1306,14 @@ async fn create_scalar_function_from_sql_statement_default_arguments() -> Result "Error during planning: Non-default arguments cannot follow default arguments."; assert!(expected.starts_with(&err.strip_backtrace())); - // FIXME: The `DEFAULT` syntax does not work with positional params - let bad_expression_sql = r#" + let expression_sql = r#" CREATE FUNCTION bad_expression_fun(DOUBLE, DOUBLE DEFAULT 2.0) RETURNS DOUBLE RETURN $1 + $2 "#; - let err = ctx - .sql(bad_expression_sql) - .await - .expect_err("sqlparser error"); - let expected = - "SQL error: ParserError(\"Expected: ), found: 2.0 at Line: 2, Column: 63\")"; - assert!(expected.starts_with(&err.strip_backtrace())); + let result = ctx.sql(expression_sql).await; + + assert!(result.is_ok()); Ok(()) } diff --git a/datafusion/core/tests/user_defined/user_defined_table_functions.rs b/datafusion/core/tests/user_defined/user_defined_table_functions.rs index 8be8609c6248..95694d00a6c3 100644 --- a/datafusion/core/tests/user_defined/user_defined_table_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_table_functions.rs @@ -221,6 +221,31 @@ impl TableFunctionImpl for SimpleCsvTableFunc { } } +/// Test that expressions passed to UDTFs are properly type-coerced +/// This is a regression test for https://github.com/apache/datafusion/issues/19914 +#[tokio::test] +async fn test_udtf_type_coercion() -> Result<()> { + use datafusion::datasource::MemTable; + + #[derive(Debug)] + struct NoOpTableFunc; + + impl TableFunctionImpl for NoOpTableFunc { + fn call(&self, _: &[Expr]) -> 
Result> { + let schema = Arc::new(arrow::datatypes::Schema::empty()); + Ok(Arc::new(MemTable::try_new(schema, vec![vec![]])?)) + } + } + + let ctx = SessionContext::new(); + ctx.register_udtf("f", Arc::new(NoOpTableFunc)); + + // This should not panic - the array elements should be coerced to Float64 + let _ = ctx.sql("SELECT * FROM f(ARRAY[0.1, 1, 2])").await?; + + Ok(()) +} + fn read_csv_batches(csv_path: impl AsRef) -> Result<(SchemaRef, Vec)> { let mut file = File::open(csv_path)?; let (schema, _) = Format::default() diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs index 57baf271c591..775325a33718 100644 --- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs @@ -536,7 +536,7 @@ impl OddCounter { impl SimpleWindowUDF { fn new(test_state: Arc) -> Self { let signature = - Signature::exact(vec![DataType::Float64], Volatility::Immutable); + Signature::exact(vec![DataType::Int64], Volatility::Immutable); Self { signature, test_state: test_state.into(), diff --git a/datafusion/datasource-arrow/NOTICE.txt b/datafusion/datasource-arrow/NOTICE.txt index 7f3c80d606c0..0bd2d52368fe 100644 --- a/datafusion/datasource-arrow/NOTICE.txt +++ b/datafusion/datasource-arrow/NOTICE.txt @@ -1,5 +1,5 @@ Apache DataFusion -Copyright 2019-2025 The Apache Software Foundation +Copyright 2019-2026 The Apache Software Foundation This product includes software developed at The Apache Software Foundation (http://www.apache.org/). 
diff --git a/datafusion/datasource-arrow/src/file_format.rs b/datafusion/datasource-arrow/src/file_format.rs index 9997d23d4c61..f60bce324993 100644 --- a/datafusion/datasource-arrow/src/file_format.rs +++ b/datafusion/datasource-arrow/src/file_format.rs @@ -63,7 +63,8 @@ use datafusion_session::Session; use futures::StreamExt; use futures::stream::BoxStream; use object_store::{ - GetOptions, GetRange, GetResultPayload, ObjectMeta, ObjectStore, path::Path, + GetOptions, GetRange, GetResultPayload, ObjectMeta, ObjectStore, ObjectStoreExt, + path::Path, }; use tokio::io::AsyncWriteExt; diff --git a/datafusion/datasource-arrow/src/mod.rs b/datafusion/datasource-arrow/src/mod.rs index cbfd7887093e..4816a45942e5 100644 --- a/datafusion/datasource-arrow/src/mod.rs +++ b/datafusion/datasource-arrow/src/mod.rs @@ -19,7 +19,6 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] -#![deny(clippy::allow_attributes)] //! 
[`ArrowFormat`]: Apache Arrow file format abstractions diff --git a/datafusion/datasource-arrow/src/source.rs b/datafusion/datasource-arrow/src/source.rs index 4c8fd5b3407b..99446cb87623 100644 --- a/datafusion/datasource-arrow/src/source.rs +++ b/datafusion/datasource-arrow/src/source.rs @@ -52,7 +52,7 @@ use datafusion_datasource::file_stream::FileOpenFuture; use datafusion_datasource::file_stream::FileOpener; use futures::StreamExt; use itertools::Itertools; -use object_store::{GetOptions, GetRange, GetResultPayload, ObjectStore}; +use object_store::{GetOptions, GetRange, GetResultPayload, ObjectStore, ObjectStoreExt}; /// Enum indicating which Arrow IPC format to use #[derive(Clone, Copy, Debug)] diff --git a/datafusion/datasource-avro/src/avro_to_arrow/schema.rs b/datafusion/datasource-avro/src/avro_to_arrow/schema.rs index 0e8f2a4d5608..053be3c9aff9 100644 --- a/datafusion/datasource-avro/src/avro_to_arrow/schema.rs +++ b/datafusion/datasource-avro/src/avro_to_arrow/schema.rs @@ -117,8 +117,8 @@ fn schema_to_field_with_props( .iter() .map(|s| schema_to_field_with_props(s, None, has_nullable, None)) .collect::>>()?; - let type_ids = 0_i8..fields.len() as i8; - DataType::Union(UnionFields::new(type_ids, fields), UnionMode::Dense) + // Assign type_ids based on the order in which they appear + DataType::Union(UnionFields::from_fields(fields), UnionMode::Dense) } } AvroSchema::Record(RecordSchema { fields, .. 
}) => { diff --git a/datafusion/datasource-avro/src/file_format.rs b/datafusion/datasource-avro/src/file_format.rs index 2447c032e700..c4960dbcc99b 100644 --- a/datafusion/datasource-avro/src/file_format.rs +++ b/datafusion/datasource-avro/src/file_format.rs @@ -41,7 +41,7 @@ use datafusion_physical_plan::ExecutionPlan; use datafusion_session::Session; use async_trait::async_trait; -use object_store::{GetResultPayload, ObjectMeta, ObjectStore}; +use object_store::{GetResultPayload, ObjectMeta, ObjectStore, ObjectStoreExt}; #[derive(Default)] /// Factory struct used to create [`AvroFormat`] diff --git a/datafusion/datasource-avro/src/mod.rs b/datafusion/datasource-avro/src/mod.rs index 22c40e203a01..5ad209591e38 100644 --- a/datafusion/datasource-avro/src/mod.rs +++ b/datafusion/datasource-avro/src/mod.rs @@ -24,7 +24,6 @@ // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] -#![deny(clippy::allow_attributes)] //! An [Avro](https://avro.apache.org/) based [`FileSource`](datafusion_datasource::file::FileSource) implementation and related functionality. 
diff --git a/datafusion/datasource-avro/src/source.rs b/datafusion/datasource-avro/src/source.rs index 1c466be266f1..bd9ff2a7a842 100644 --- a/datafusion/datasource-avro/src/source.rs +++ b/datafusion/datasource-avro/src/source.rs @@ -147,7 +147,7 @@ mod private { use bytes::Buf; use datafusion_datasource::{PartitionedFile, file_stream::FileOpenFuture}; use futures::StreamExt; - use object_store::{GetResultPayload, ObjectStore}; + use object_store::{GetResultPayload, ObjectStore, ObjectStoreExt}; pub struct AvroOpener { pub config: Arc, diff --git a/datafusion/datasource-csv/src/file_format.rs b/datafusion/datasource-csv/src/file_format.rs index efb7829179e0..7a253d81db9f 100644 --- a/datafusion/datasource-csv/src/file_format.rs +++ b/datafusion/datasource-csv/src/file_format.rs @@ -60,7 +60,9 @@ use bytes::{Buf, Bytes}; use datafusion_datasource::source::DataSourceExec; use futures::stream::BoxStream; use futures::{Stream, StreamExt, TryStreamExt, pin_mut}; -use object_store::{ObjectMeta, ObjectStore, delimited::newline_delimited_stream}; +use object_store::{ + ObjectMeta, ObjectStore, ObjectStoreExt, delimited::newline_delimited_stream, +}; use regex::Regex; #[derive(Default)] diff --git a/datafusion/datasource-csv/src/mod.rs b/datafusion/datasource-csv/src/mod.rs index d58ce1188550..fdfee05d86a7 100644 --- a/datafusion/datasource-csv/src/mod.rs +++ b/datafusion/datasource-csv/src/mod.rs @@ -19,7 +19,6 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] -#![deny(clippy::allow_attributes)] pub mod file_format; pub mod source; diff --git a/datafusion/datasource-json/Cargo.toml b/datafusion/datasource-json/Cargo.toml index 37fa8d43a081..bd0cead8d2af 100644 --- a/datafusion/datasource-json/Cargo.toml +++ b/datafusion/datasource-json/Cargo.toml @@ -44,7 +44,9 @@ datafusion-physical-plan = { workspace = true } datafusion-session = { workspace = true } 
futures = { workspace = true } object_store = { workspace = true } +serde_json = { workspace = true } tokio = { workspace = true } +tokio-stream = { workspace = true, features = ["sync"] } # Note: add additional linter rules in lib.rs. # Rust does not support workspace + new linter rules in subcrates yet diff --git a/datafusion/datasource-json/src/file_format.rs b/datafusion/datasource-json/src/file_format.rs index a14458b5acd3..8fe445705a21 100644 --- a/datafusion/datasource-json/src/file_format.rs +++ b/datafusion/datasource-json/src/file_format.rs @@ -15,13 +15,13 @@ // specific language governing permissions and limitations // under the License. -//! [`JsonFormat`]: Line delimited JSON [`FileFormat`] abstractions +//! [`JsonFormat`]: Line delimited and array JSON [`FileFormat`] abstractions use std::any::Any; use std::collections::HashMap; use std::fmt; use std::fmt::Debug; -use std::io::BufReader; +use std::io::{BufReader, Read}; use std::sync::Arc; use crate::source::JsonSource; @@ -31,6 +31,7 @@ use arrow::datatypes::{Schema, SchemaRef}; use arrow::error::ArrowError; use arrow::json; use arrow::json::reader::{ValueIter, infer_json_schema_from_iterator}; +use bytes::{Buf, Bytes}; use datafusion_common::config::{ConfigField, ConfigFileType, JsonOptions}; use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::{ @@ -48,6 +49,7 @@ use datafusion_datasource::file_format::{ use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder}; use datafusion_datasource::file_sink_config::{FileSink, FileSinkConfig}; use datafusion_datasource::sink::{DataSink, DataSinkExec}; +use datafusion_datasource::source::DataSourceExec; use datafusion_datasource::write::BatchSerializer; use datafusion_datasource::write::demux::DemuxedStreamReceiver; use datafusion_datasource::write::orchestration::spawn_writer_tasks_and_join; @@ -57,10 +59,9 @@ use datafusion_physical_expr_common::sort_expr::LexRequirement; use 
datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; use datafusion_session::Session; +use crate::utils::JsonArrayToNdjsonReader; use async_trait::async_trait; -use bytes::{Buf, Bytes}; -use datafusion_datasource::source::DataSourceExec; -use object_store::{GetResultPayload, ObjectMeta, ObjectStore}; +use object_store::{GetResultPayload, ObjectMeta, ObjectStore, ObjectStoreExt}; #[derive(Default)] /// Factory struct used to create [JsonFormat] @@ -132,7 +133,26 @@ impl Debug for JsonFormatFactory { } } -/// New line delimited JSON `FileFormat` implementation. +/// JSON `FileFormat` implementation supporting both line-delimited and array formats. +/// +/// # Supported Formats +/// +/// ## Line-Delimited JSON (default, `newline_delimited = true`) +/// ```text +/// {"key1": 1, "key2": "val"} +/// {"key1": 2, "key2": "vals"} +/// ``` +/// +/// ## JSON Array Format (`newline_delimited = false`) +/// ```text +/// [ +/// {"key1": 1, "key2": "val"}, +/// {"key1": 2, "key2": "vals"} +/// ] +/// ``` +/// +/// Note: JSON array format is processed using streaming conversion, +/// which is memory-efficient even for large files. #[derive(Debug, Default)] pub struct JsonFormat { options: JsonOptions, @@ -166,6 +186,57 @@ impl JsonFormat { self.options.compression = file_compression_type.into(); self } + + /// Set whether to read as newline-delimited JSON (NDJSON). + /// + /// When `true` (default), expects newline-delimited format: + /// ```text + /// {"a": 1} + /// {"a": 2} + /// ``` + /// + /// When `false`, expects JSON array format: + /// ```text + /// [{"a": 1}, {"a": 2}] + /// ``` + pub fn with_newline_delimited(mut self, newline_delimited: bool) -> Self { + self.options.newline_delimited = newline_delimited; + self + } + + /// Returns whether this format expects newline-delimited JSON. + pub fn is_newline_delimited(&self) -> bool { + self.options.newline_delimited + } +} + +/// Infer schema from JSON array format using streaming conversion. 
+/// +/// This function converts JSON array format to NDJSON on-the-fly and uses +/// arrow-json's schema inference. It properly tracks the number of records +/// processed for correct `records_to_read` management. +/// +/// # Returns +/// A tuple of (Schema, records_consumed) where records_consumed is the +/// number of records that were processed for schema inference. +fn infer_schema_from_json_array( + reader: R, + max_records: usize, +) -> Result<(Schema, usize)> { + let ndjson_reader = JsonArrayToNdjsonReader::new(reader); + + let iter = ValueIter::new(ndjson_reader, None); + let mut count = 0; + + let schema = infer_json_schema_from_iterator(iter.take_while(|_| { + let should_take = count < max_records; + if should_take { + count += 1; + } + should_take + }))?; + + Ok((schema, count)) } #[async_trait] @@ -202,37 +273,67 @@ impl FileFormat for JsonFormat { .schema_infer_max_rec .unwrap_or(DEFAULT_SCHEMA_INFER_MAX_RECORD); let file_compression_type = FileCompressionType::from(self.options.compression); + let newline_delimited = self.options.newline_delimited; + for object in objects { - let mut take_while = || { - let should_take = records_to_read > 0; - if should_take { - records_to_read -= 1; - } - should_take - }; + // Early exit if we've read enough records + if records_to_read == 0 { + break; + } let r = store.as_ref().get(&object.location).await?; - let schema = match r.payload { + + let (schema, records_consumed) = match r.payload { #[cfg(not(target_arch = "wasm32"))] GetResultPayload::File(file, _) => { let decoder = file_compression_type.convert_read(file)?; - let mut reader = BufReader::new(decoder); - let iter = ValueIter::new(&mut reader, None); - infer_json_schema_from_iterator(iter.take_while(|_| take_while()))? 
+ let reader = BufReader::new(decoder); + + if newline_delimited { + // NDJSON: use ValueIter directly + let iter = ValueIter::new(reader, None); + let mut count = 0; + let schema = + infer_json_schema_from_iterator(iter.take_while(|_| { + let should_take = count < records_to_read; + if should_take { + count += 1; + } + should_take + }))?; + (schema, count) + } else { + // JSON array format: use streaming converter + infer_schema_from_json_array(reader, records_to_read)? + } } GetResultPayload::Stream(_) => { let data = r.bytes().await?; let decoder = file_compression_type.convert_read(data.reader())?; - let mut reader = BufReader::new(decoder); - let iter = ValueIter::new(&mut reader, None); - infer_json_schema_from_iterator(iter.take_while(|_| take_while()))? + let reader = BufReader::new(decoder); + + if newline_delimited { + let iter = ValueIter::new(reader, None); + let mut count = 0; + let schema = + infer_json_schema_from_iterator(iter.take_while(|_| { + let should_take = count < records_to_read; + if should_take { + count += 1; + } + should_take + }))?; + (schema, count) + } else { + // JSON array format: use streaming converter + infer_schema_from_json_array(reader, records_to_read)? 
+ } } }; schemas.push(schema); - if records_to_read == 0 { - break; - } + // Correctly decrement records_to_read + records_to_read = records_to_read.saturating_sub(records_consumed); } let schema = Schema::try_merge(schemas)?; @@ -281,7 +382,10 @@ impl FileFormat for JsonFormat { } fn file_source(&self, table_schema: TableSchema) -> Arc { - Arc::new(JsonSource::new(table_schema)) + Arc::new( + JsonSource::new(table_schema) + .with_newline_delimited(self.options.newline_delimited), + ) } } diff --git a/datafusion/datasource-json/src/mod.rs b/datafusion/datasource-json/src/mod.rs index 3d27d4cc5ef5..7dc0a0c7ba0f 100644 --- a/datafusion/datasource-json/src/mod.rs +++ b/datafusion/datasource-json/src/mod.rs @@ -19,9 +19,9 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] -#![deny(clippy::allow_attributes)] pub mod file_format; pub mod source; +pub mod utils; pub use file_format::*; diff --git a/datafusion/datasource-json/src/source.rs b/datafusion/datasource-json/src/source.rs index 5797054f11b9..52a38f49945c 100644 --- a/datafusion/datasource-json/src/source.rs +++ b/datafusion/datasource-json/src/source.rs @@ -15,17 +15,19 @@ // specific language governing permissions and limitations // under the License. -//! Execution plan for reading line-delimited JSON files +//! 
Execution plan for reading JSON files (line-delimited and array formats) use std::any::Any; use std::io::{BufReader, Read, Seek, SeekFrom}; +use std::pin::Pin; use std::sync::Arc; -use std::task::Poll; +use std::task::{Context, Poll}; use crate::file_format::JsonDecoder; +use crate::utils::{ChannelReader, JsonArrayToNdjsonReader}; use datafusion_common::error::{DataFusionError, Result}; -use datafusion_common_runtime::JoinSet; +use datafusion_common_runtime::{JoinSet, SpawnedTask}; use datafusion_datasource::decoder::{DecoderDeserializer, deserialize_stream}; use datafusion_datasource::file_compression_type::FileCompressionType; use datafusion_datasource::file_stream::{FileOpenFuture, FileOpener}; @@ -36,6 +38,7 @@ use datafusion_datasource::{ use datafusion_physical_plan::projection::ProjectionExprs; use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; +use arrow::array::RecordBatch; use arrow::json::ReaderBuilder; use arrow::{datatypes::SchemaRef, json}; use datafusion_datasource::file::FileSource; @@ -43,10 +46,55 @@ use datafusion_datasource::file_scan_config::FileScanConfig; use datafusion_execution::TaskContext; use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; -use futures::{StreamExt, TryStreamExt}; +use futures::{Stream, StreamExt, TryStreamExt}; use object_store::buffered::BufWriter; use object_store::{GetOptions, GetResultPayload, ObjectStore}; use tokio::io::AsyncWriteExt; +use tokio_stream::wrappers::ReceiverStream; + +/// Channel buffer size for streaming JSON array processing. +/// With ~128KB average chunk size, 128 chunks ≈ 16MB buffer. 
+const CHANNEL_BUFFER_SIZE: usize = 128; + +/// Buffer size for JsonArrayToNdjsonReader (2MB each, 4MB total for input+output) +const JSON_CONVERTER_BUFFER_SIZE: usize = 2 * 1024 * 1024; + +// ============================================================================ +// JsonArrayStream - Custom stream wrapper to hold SpawnedTask handles +// ============================================================================ + +/// A stream wrapper that holds SpawnedTask handles to keep them alive +/// until the stream is fully consumed or dropped. +/// +/// This ensures cancel-safety: when the stream is dropped, the tasks +/// are properly aborted via SpawnedTask's Drop implementation. +struct JsonArrayStream { + inner: ReceiverStream>, + /// Task that reads from object store and sends bytes to channel. + /// Kept alive until stream is consumed or dropped. + _read_task: SpawnedTask<()>, + /// Task that parses JSON and sends RecordBatches. + /// Kept alive until stream is consumed or dropped. + _parse_task: SpawnedTask<()>, +} + +impl Stream for JsonArrayStream { + type Item = std::result::Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.inner).poll_next(cx) + } + + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } +} +// ============================================================================ +// JsonOpener and JsonSource +// ============================================================================ /// A [`FileOpener`] that opens a JSON file and yields a [`FileOpenFuture`] pub struct JsonOpener { @@ -54,21 +102,26 @@ pub struct JsonOpener { projected_schema: SchemaRef, file_compression_type: FileCompressionType, object_store: Arc, + /// When `true` (default), expects newline-delimited JSON (NDJSON). + /// When `false`, expects JSON array format `[{...}, {...}]`. 
+ newline_delimited: bool, } impl JsonOpener { - /// Returns a [`JsonOpener`] + /// Returns a [`JsonOpener`] pub fn new( batch_size: usize, projected_schema: SchemaRef, file_compression_type: FileCompressionType, object_store: Arc, + newline_delimited: bool, ) -> Self { Self { batch_size, projected_schema, file_compression_type, object_store, + newline_delimited, } } } @@ -80,6 +133,9 @@ pub struct JsonSource { batch_size: Option, metrics: ExecutionPlanMetricsSet, projection: SplitProjection, + /// When `true` (default), expects newline-delimited JSON (NDJSON). + /// When `false`, expects JSON array format `[{...}, {...}]`. + newline_delimited: bool, } impl JsonSource { @@ -91,8 +147,18 @@ impl JsonSource { table_schema, batch_size: None, metrics: ExecutionPlanMetricsSet::new(), + newline_delimited: true, } } + + /// Set whether to read as newline-delimited JSON. + /// + /// When `true` (default), expects newline-delimited format. + /// When `false`, expects JSON array format `[{...}, {...}]`. + pub fn with_newline_delimited(mut self, newline_delimited: bool) -> Self { + self.newline_delimited = newline_delimited; + self + } } impl From for Arc { @@ -120,6 +186,7 @@ impl FileSource for JsonSource { projected_schema, file_compression_type: base_config.file_compression_type, object_store, + newline_delimited: self.newline_delimited, }) as Arc; // Wrap with ProjectionOpener @@ -172,7 +239,7 @@ impl FileSource for JsonSource { } impl FileOpener for JsonOpener { - /// Open a partitioned NDJSON file. + /// Open a partitioned JSON file. /// /// If `file_meta.range` is `None`, the entire file is opened. /// Else `file_meta.range` is `Some(FileRange{start, end})`, which corresponds to the byte range [start, end) within the file. @@ -181,11 +248,23 @@ impl FileOpener for JsonOpener { /// are applied to determine which lines to read: /// 1. The first line of the partition is the line in which the index of the first character >= `start`. /// 2. 
The last line of the partition is the line in which the byte at position `end - 1` resides. + /// + /// Note: JSON array format does not support range-based scanning. fn open(&self, partitioned_file: PartitionedFile) -> Result { let store = Arc::clone(&self.object_store); let schema = Arc::clone(&self.projected_schema); let batch_size = self.batch_size; let file_compression_type = self.file_compression_type.to_owned(); + let newline_delimited = self.newline_delimited; + + // JSON array format requires reading the complete file + if !newline_delimited && partitioned_file.range.is_some() { + return Err(DataFusionError::NotImplemented( + "JSON array format does not support range-based file scanning. \ + Disable repartition_file_scans or use newline-delimited JSON format." + .to_string(), + )); + } Ok(Box::pin(async move { let calculated_range = @@ -218,31 +297,150 @@ impl FileOpener for JsonOpener { Some(_) => { file.seek(SeekFrom::Start(result.range.start as _))?; let limit = result.range.end - result.range.start; - file_compression_type.convert_read(file.take(limit as u64))? + file_compression_type.convert_read(file.take(limit))? 
} }; - let reader = ReaderBuilder::new(schema) - .with_batch_size(batch_size) - .build(BufReader::new(bytes))?; - - Ok(futures::stream::iter(reader) - .map(|r| r.map_err(Into::into)) - .boxed()) + if newline_delimited { + // NDJSON: use BufReader directly + let reader = BufReader::new(bytes); + let arrow_reader = ReaderBuilder::new(schema) + .with_batch_size(batch_size) + .build(reader)?; + + Ok(futures::stream::iter(arrow_reader) + .map(|r| r.map_err(Into::into)) + .boxed()) + } else { + // JSON array format: wrap with streaming converter + let ndjson_reader = JsonArrayToNdjsonReader::with_capacity( + bytes, + JSON_CONVERTER_BUFFER_SIZE, + ); + let arrow_reader = ReaderBuilder::new(schema) + .with_batch_size(batch_size) + .build(ndjson_reader)?; + + Ok(futures::stream::iter(arrow_reader) + .map(|r| r.map_err(Into::into)) + .boxed()) + } } GetResultPayload::Stream(s) => { - let s = s.map_err(DataFusionError::from); - - let decoder = ReaderBuilder::new(schema) - .with_batch_size(batch_size) - .build_decoder()?; - let input = file_compression_type.convert_stream(s.boxed())?.fuse(); - - let stream = deserialize_stream( - input, - DecoderDeserializer::new(JsonDecoder::new(decoder)), - ); - Ok(stream.map_err(Into::into).boxed()) + if newline_delimited { + // Newline-delimited JSON (NDJSON) streaming reader + let s = s.map_err(DataFusionError::from); + let decoder = ReaderBuilder::new(schema) + .with_batch_size(batch_size) + .build_decoder()?; + let input = + file_compression_type.convert_stream(s.boxed())?.fuse(); + let stream = deserialize_stream( + input, + DecoderDeserializer::new(JsonDecoder::new(decoder)), + ); + Ok(stream.map_err(Into::into).boxed()) + } else { + // JSON array format: streaming conversion with channel-based byte transfer + // + // Architecture: + // 1. Async task reads from object store stream, decompresses, sends to channel + // 2. Blocking task receives bytes, converts JSON array to NDJSON, parses to Arrow + // 3. 
RecordBatches are sent back via another channel + // + // Memory budget (~32MB): + // - sync_channel: CHANNEL_BUFFER_SIZE chunks (~16MB) + // - JsonArrayToNdjsonReader: 2 × JSON_CONVERTER_BUFFER_SIZE (~4MB) + // - Arrow JsonReader internal buffer (~8MB) + // - Miscellaneous (~4MB) + + let s = s.map_err(DataFusionError::from); + let decompressed_stream = + file_compression_type.convert_stream(s.boxed())?; + + // Channel for bytes: async producer -> blocking consumer + // Uses tokio::sync::mpsc so the async send never blocks a + // tokio worker thread; the consumer calls blocking_recv() + // inside spawn_blocking. + let (byte_tx, byte_rx) = tokio::sync::mpsc::channel::( + CHANNEL_BUFFER_SIZE, + ); + + // Channel for results: sync producer -> async consumer + let (result_tx, result_rx) = tokio::sync::mpsc::channel(2); + let error_tx = result_tx.clone(); + + // Async task: read from object store stream and send bytes to channel + // Store the SpawnedTask to keep it alive until stream is dropped + let read_task = SpawnedTask::spawn(async move { + tokio::pin!(decompressed_stream); + while let Some(chunk) = decompressed_stream.next().await { + match chunk { + Ok(bytes) => { + if byte_tx.send(bytes).await.is_err() { + break; // Consumer dropped + } + } + Err(e) => { + let _ = error_tx + .send(Err( + arrow::error::ArrowError::ExternalError( + Box::new(e), + ), + )) + .await; + break; + } + } + } + // byte_tx dropped here, signals EOF to ChannelReader + }); + + // Blocking task: receive bytes from channel and parse JSON + // Store the SpawnedTask to keep it alive until stream is dropped + let parse_task = SpawnedTask::spawn_blocking(move || { + let channel_reader = ChannelReader::new(byte_rx); + let mut ndjson_reader = + JsonArrayToNdjsonReader::with_capacity( + channel_reader, + JSON_CONVERTER_BUFFER_SIZE, + ); + + match ReaderBuilder::new(schema) + .with_batch_size(batch_size) + .build(&mut ndjson_reader) + { + Ok(arrow_reader) => { + for batch_result in arrow_reader { + 
if result_tx.blocking_send(batch_result).is_err() + { + break; // Receiver dropped + } + } + } + Err(e) => { + let _ = result_tx.blocking_send(Err(e)); + } + } + + // Validate the JSON array was properly formed + if let Err(e) = ndjson_reader.validate_complete() { + let _ = result_tx.blocking_send(Err( + arrow::error::ArrowError::JsonError(e.to_string()), + )); + } + // result_tx dropped here, closes the stream + }); + + // Wrap in JsonArrayStream to keep tasks alive until stream is consumed + let stream = JsonArrayStream { + inner: ReceiverStream::new(result_rx), + _read_task: read_task, + _parse_task: parse_task, + }; + + Ok(stream.map(|r| r.map_err(Into::into)).boxed()) + } } } })) @@ -303,3 +501,307 @@ pub async fn plan_to_json( Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::datatypes::{DataType, Field, Schema}; + use bytes::Bytes; + use datafusion_datasource::FileRange; + use futures::TryStreamExt; + use object_store::memory::InMemory; + use object_store::path::Path; + use object_store::{ObjectStoreExt, PutPayload}; + + /// Helper to create a test schema + fn test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, true), + Field::new("name", DataType::Utf8, true), + ])) + } + + #[tokio::test] + async fn test_json_array_from_file() -> Result<()> { + // Test reading JSON array format from a file + let json_data = r#"[{"id": 1, "name": "alice"}, {"id": 2, "name": "bob"}]"#; + + let store = Arc::new(InMemory::new()); + let path = Path::from("test.json"); + store + .put(&path, PutPayload::from_static(json_data.as_bytes())) + .await?; + + let opener = JsonOpener::new( + 1024, + test_schema(), + FileCompressionType::UNCOMPRESSED, + store.clone(), + false, // JSON array format + ); + + let meta = store.head(&path).await?; + let file = PartitionedFile::new(path.to_string(), meta.size); + + let stream = opener.open(file)?.await?; + let batches: Vec<_> = stream.try_collect().await?; + + assert_eq!(batches.len(), 
1); + assert_eq!(batches[0].num_rows(), 2); + + Ok(()) + } + + #[tokio::test] + async fn test_json_array_from_stream() -> Result<()> { + // Test reading JSON array format from object store stream (simulates S3) + let json_data = r#"[{"id": 1, "name": "alice"}, {"id": 2, "name": "bob"}, {"id": 3, "name": "charlie"}]"#; + + // Use InMemory store which returns Stream payload + let store = Arc::new(InMemory::new()); + let path = Path::from("test_stream.json"); + store + .put(&path, PutPayload::from_static(json_data.as_bytes())) + .await?; + + let opener = JsonOpener::new( + 2, // small batch size to test multiple batches + test_schema(), + FileCompressionType::UNCOMPRESSED, + store.clone(), + false, // JSON array format + ); + + let meta = store.head(&path).await?; + let file = PartitionedFile::new(path.to_string(), meta.size); + + let stream = opener.open(file)?.await?; + let batches: Vec<_> = stream.try_collect().await?; + + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, 3); + + Ok(()) + } + + #[tokio::test] + async fn test_json_array_nested_objects() -> Result<()> { + // Test JSON array with nested objects and arrays + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, true), + Field::new("data", DataType::Utf8, true), + ])); + + let json_data = r#"[ + {"id": 1, "data": "{\"nested\": true}"}, + {"id": 2, "data": "[1, 2, 3]"} + ]"#; + + let store = Arc::new(InMemory::new()); + let path = Path::from("nested.json"); + store + .put(&path, PutPayload::from_static(json_data.as_bytes())) + .await?; + + let opener = JsonOpener::new( + 1024, + schema, + FileCompressionType::UNCOMPRESSED, + store.clone(), + false, + ); + + let meta = store.head(&path).await?; + let file = PartitionedFile::new(path.to_string(), meta.size); + + let stream = opener.open(file)?.await?; + let batches: Vec<_> = stream.try_collect().await?; + + assert_eq!(batches[0].num_rows(), 2); + + Ok(()) + } + + #[tokio::test] + async fn 
test_json_array_empty() -> Result<()> { + // Test empty JSON array + let json_data = "[]"; + + let store = Arc::new(InMemory::new()); + let path = Path::from("empty.json"); + store + .put(&path, PutPayload::from_static(json_data.as_bytes())) + .await?; + + let opener = JsonOpener::new( + 1024, + test_schema(), + FileCompressionType::UNCOMPRESSED, + store.clone(), + false, + ); + + let meta = store.head(&path).await?; + let file = PartitionedFile::new(path.to_string(), meta.size); + + let stream = opener.open(file)?.await?; + let batches: Vec<_> = stream.try_collect().await?; + + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, 0); + + Ok(()) + } + + #[tokio::test] + async fn test_json_array_range_not_supported() { + // Test that range-based scanning returns error for JSON array format + let store = Arc::new(InMemory::new()); + let path = Path::from("test.json"); + store + .put(&path, PutPayload::from_static(b"[]")) + .await + .unwrap(); + + let opener = JsonOpener::new( + 1024, + test_schema(), + FileCompressionType::UNCOMPRESSED, + store.clone(), + false, // JSON array format + ); + + let meta = store.head(&path).await.unwrap(); + let mut file = PartitionedFile::new(path.to_string(), meta.size); + file.range = Some(FileRange { start: 0, end: 10 }); + + let result = opener.open(file); + match result { + Ok(_) => panic!("Expected error for range-based JSON array scanning"), + Err(e) => { + assert!( + e.to_string().contains("does not support range-based"), + "Unexpected error message: {e}" + ); + } + } + } + + #[tokio::test] + async fn test_ndjson_still_works() -> Result<()> { + // Ensure NDJSON format still works correctly + let json_data = + "{\"id\": 1, \"name\": \"alice\"}\n{\"id\": 2, \"name\": \"bob\"}\n"; + + let store = Arc::new(InMemory::new()); + let path = Path::from("test.ndjson"); + store + .put(&path, PutPayload::from_static(json_data.as_bytes())) + .await?; + + let opener = JsonOpener::new( + 1024, + 
test_schema(), + FileCompressionType::UNCOMPRESSED, + store.clone(), + true, // NDJSON format + ); + + let meta = store.head(&path).await?; + let file = PartitionedFile::new(path.to_string(), meta.size); + + let stream = opener.open(file)?.await?; + let batches: Vec<_> = stream.try_collect().await?; + + assert_eq!(batches.len(), 1); + assert_eq!(batches[0].num_rows(), 2); + + Ok(()) + } + + #[tokio::test] + async fn test_json_array_large_file() -> Result<()> { + // Test with a larger JSON array to verify streaming works + let mut json_data = String::from("["); + for i in 0..1000 { + if i > 0 { + json_data.push(','); + } + json_data.push_str(&format!(r#"{{"id": {i}, "name": "user{i}"}}"#)); + } + json_data.push(']'); + + let store = Arc::new(InMemory::new()); + let path = Path::from("large.json"); + store + .put(&path, PutPayload::from(Bytes::from(json_data))) + .await?; + + let opener = JsonOpener::new( + 100, // batch size of 100 + test_schema(), + FileCompressionType::UNCOMPRESSED, + store.clone(), + false, + ); + + let meta = store.head(&path).await?; + let file = PartitionedFile::new(path.to_string(), meta.size); + + let stream = opener.open(file)?.await?; + let batches: Vec<_> = stream.try_collect().await?; + + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, 1000); + + // Should have multiple batches due to batch_size=100 + assert!(batches.len() >= 10); + + Ok(()) + } + + #[tokio::test] + async fn test_json_array_stream_cancellation() -> Result<()> { + // Test that cancellation works correctly (tasks are aborted when stream is dropped) + let mut json_data = String::from("["); + for i in 0..10000 { + if i > 0 { + json_data.push(','); + } + json_data.push_str(&format!(r#"{{"id": {i}, "name": "user{i}"}}"#)); + } + json_data.push(']'); + + let store = Arc::new(InMemory::new()); + let path = Path::from("cancel_test.json"); + store + .put(&path, PutPayload::from(Bytes::from(json_data))) + .await?; + + let opener = 
JsonOpener::new( + 10, // small batch size + test_schema(), + FileCompressionType::UNCOMPRESSED, + store.clone(), + false, + ); + + let meta = store.head(&path).await?; + let file = PartitionedFile::new(path.to_string(), meta.size); + + let mut stream = opener.open(file)?.await?; + + // Read only first batch, then drop the stream (simulating cancellation) + let first_batch = stream.next().await; + assert!(first_batch.is_some()); + + // Drop the stream - this should abort the spawned tasks via SpawnedTask's Drop + drop(stream); + + // Give tasks time to be aborted + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // If we reach here without hanging, cancellation worked + Ok(()) + } +} diff --git a/datafusion/datasource-json/src/utils.rs b/datafusion/datasource-json/src/utils.rs new file mode 100644 index 000000000000..bc75799edff7 --- /dev/null +++ b/datafusion/datasource-json/src/utils.rs @@ -0,0 +1,778 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Utility types for JSON processing + +use std::io::{BufRead, Read}; + +use bytes::Bytes; + +// ============================================================================ +// JsonArrayToNdjsonReader - Streaming JSON Array to NDJSON Converter +// ============================================================================ +// +// Architecture: +// +// ```text +// ┌─────────────────────────────────────────────────────────────┐ +// │ JSON Array File (potentially very large, e.g. 33GB) │ +// │ [{"a":1}, {"a":2}, {"a":3}, ...... {"a":1000000}] │ +// └─────────────────────────────────────────────────────────────┘ +// │ +// ▼ read chunks via ChannelReader +// ┌───────────────────┐ +// │ JsonArrayToNdjson │ ← character substitution only: +// │ Reader │ '[' skip, ',' → '\n', ']' stop +// └───────────────────┘ +// │ +// ▼ outputs NDJSON format +// ┌───────────────────┐ +// │ Arrow Reader │ ← internal buffer, batch parsing +// │ batch_size=8192 │ +// └───────────────────┘ +// │ +// ▼ outputs RecordBatch +// ┌───────────────────┐ +// │ RecordBatch │ +// └───────────────────┘ +// ``` +// +// Memory Efficiency: +// +// | Approach | Memory for 33GB file | Parse count | +// |---------------------------------------|----------------------|-------------| +// | Load entire file + serde_json | ~100GB+ | 3x | +// | Streaming with JsonArrayToNdjsonReader| ~32MB (configurable) | 1x | +// +// Design Note: +// +// This implementation uses `inner: R` directly (not `BufReader`) and manages +// its own input buffer. This is critical for compatibility with `SyncIoBridge` +// and `ChannelReader` in `spawn_blocking` contexts. 
+// + +/// Default buffer size for JsonArrayToNdjsonReader (2MB for better throughput) +const DEFAULT_BUF_SIZE: usize = 2 * 1024 * 1024; + +/// Parser state for JSON array streaming +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum JsonArrayState { + /// Initial state, looking for opening '[' + Start, + /// Inside the JSON array, processing objects + InArray, + /// Reached the closing ']', finished + Done, +} + +/// A streaming reader that converts JSON array format to NDJSON format. +/// +/// This reader wraps an underlying reader containing JSON array data +/// `[{...}, {...}, ...]` and transforms it on-the-fly to newline-delimited +/// JSON format that Arrow's JSON reader can process. +/// +/// Implements both `Read` and `BufRead` traits for compatibility with Arrow's +/// `ReaderBuilder::build()` which requires `BufRead`. +/// +/// # Transformation Rules +/// +/// - Skip leading `[` and whitespace before it +/// - Convert top-level `,` (between objects) to `\n` +/// - Skip whitespace at top level (between objects) +/// - Stop at trailing `]` +/// - Preserve everything inside objects (including nested `[`, `]`, `,`) +/// - Properly handle strings (ignore special chars inside quotes) +/// +/// # Example +/// +/// ```text +/// Input: [{"a":1}, {"b":[1,2]}, {"c":"x,y"}] +/// Output: {"a":1} +/// {"b":[1,2]} +/// {"c":"x,y"} +/// ``` +pub struct JsonArrayToNdjsonReader { + /// Inner reader - we use R directly (not `BufReader`) for SyncIoBridge compatibility + inner: R, + state: JsonArrayState, + /// Tracks nesting depth of `{` and `[` to identify top-level commas + depth: i32, + /// Whether we're currently inside a JSON string + in_string: bool, + /// Whether the next character is escaped (after `\`) + escape_next: bool, + /// Input buffer - stores raw bytes read from inner reader + input_buffer: Vec, + /// Current read position in input buffer + input_pos: usize, + /// Number of valid bytes in input buffer + input_filled: usize, + /// Output buffer - stores 
transformed NDJSON bytes + output_buffer: Vec, + /// Current read position in output buffer + output_pos: usize, + /// Number of valid bytes in output buffer + output_filled: usize, + /// Whether trailing non-whitespace content was detected after ']' + has_trailing_content: bool, + /// Whether leading non-whitespace content was detected before '[' + has_leading_content: bool, +} + +impl JsonArrayToNdjsonReader { + /// Create a new streaming reader that converts JSON array to NDJSON. + pub fn new(reader: R) -> Self { + Self::with_capacity(reader, DEFAULT_BUF_SIZE) + } + + /// Create a new streaming reader with custom buffer size. + /// + /// Larger buffers improve throughput but use more memory. + /// Total memory usage is approximately 2 * capacity (input + output buffers). + pub fn with_capacity(reader: R, capacity: usize) -> Self { + Self { + inner: reader, + state: JsonArrayState::Start, + depth: 0, + in_string: false, + escape_next: false, + input_buffer: vec![0; capacity], + input_pos: 0, + input_filled: 0, + output_buffer: vec![0; capacity], + output_pos: 0, + output_filled: 0, + has_trailing_content: false, + has_leading_content: false, + } + } + + /// Check if the JSON array was properly terminated. + /// + /// This should be called after all data has been read. 
+ /// + /// Returns an error if: + /// - Unbalanced braces/brackets (depth != 0) + /// - Unterminated string + /// - Missing closing `]` + /// - Unexpected trailing content after `]` + pub fn validate_complete(&self) -> std::io::Result<()> { + if self.has_leading_content { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "Malformed JSON: unexpected leading content before '['", + )); + } + if self.depth != 0 { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "Malformed JSON array: unbalanced braces or brackets", + )); + } + if self.in_string { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "Malformed JSON array: unterminated string", + )); + } + if self.state != JsonArrayState::Done { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "Incomplete JSON array: expected closing bracket ']'", + )); + } + if self.has_trailing_content { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "Malformed JSON: unexpected trailing content after ']'", + )); + } + Ok(()) + } + + /// Process a single byte and return the transformed byte (if any) + #[inline] + fn process_byte(&mut self, byte: u8) -> Option { + match self.state { + JsonArrayState::Start => { + // Looking for the opening '[', skip whitespace + if byte == b'[' { + self.state = JsonArrayState::InArray; + } else if !byte.is_ascii_whitespace() { + self.has_leading_content = true; + } + None + } + JsonArrayState::InArray => { + // Handle escape sequences in strings + if self.escape_next { + self.escape_next = false; + return Some(byte); + } + + if self.in_string { + // Inside a string: handle escape and closing quote + match byte { + b'\\' => self.escape_next = true, + b'"' => self.in_string = false, + _ => {} + } + Some(byte) + } else { + // Outside strings: track depth and transform + match byte { + b'"' => { + self.in_string = true; + Some(byte) + } + b'{' | b'[' => { + self.depth += 1; + Some(byte) + } + b'}' => { 
+ self.depth -= 1; + Some(byte) + } + b']' => { + if self.depth == 0 { + // Top-level ']' means end of array + self.state = JsonArrayState::Done; + None + } else { + // Nested ']' inside an object + self.depth -= 1; + Some(byte) + } + } + b',' if self.depth == 0 => { + // Top-level comma between objects → newline + Some(b'\n') + } + _ => { + // At depth 0, skip whitespace between objects + if self.depth == 0 && byte.is_ascii_whitespace() { + None + } else { + Some(byte) + } + } + } + } + } + JsonArrayState::Done => { + // After ']', check for non-whitespace trailing content + if !byte.is_ascii_whitespace() { + self.has_trailing_content = true; + } + None + } + } + } + + /// Refill input buffer from inner reader if needed. + /// Returns true if there's data available, false on EOF. + fn refill_input_if_needed(&mut self) -> std::io::Result { + if self.input_pos >= self.input_filled { + // Input buffer exhausted, read more from inner + let bytes_read = self.inner.read(&mut self.input_buffer)?; + if bytes_read == 0 { + return Ok(false); // EOF + } + self.input_pos = 0; + self.input_filled = bytes_read; + } + Ok(true) + } + + /// Fill the output buffer with transformed data. + /// + /// This method manages its own input buffer, reading from the inner reader + /// as needed. When the output buffer is full, we stop processing but preserve + /// the current position in the input buffer for the next call. + fn fill_output_buffer(&mut self) -> std::io::Result<()> { + let mut write_pos = 0; + + while write_pos < self.output_buffer.len() { + // Refill input buffer if exhausted + if !self.refill_input_if_needed()? 
{ + break; // EOF + } + + // Process bytes from input buffer + while self.input_pos < self.input_filled + && write_pos < self.output_buffer.len() + { + let byte = self.input_buffer[self.input_pos]; + self.input_pos += 1; + + if let Some(transformed) = self.process_byte(byte) { + self.output_buffer[write_pos] = transformed; + write_pos += 1; + } + } + } + + self.output_pos = 0; + self.output_filled = write_pos; + Ok(()) + } +} + +impl Read for JsonArrayToNdjsonReader { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + // If output buffer is empty, fill it + if self.output_pos >= self.output_filled { + self.fill_output_buffer()?; + if self.output_filled == 0 { + return Ok(0); // EOF + } + } + + // Copy from output buffer to caller's buffer + let available = self.output_filled - self.output_pos; + let to_copy = std::cmp::min(available, buf.len()); + buf[..to_copy].copy_from_slice( + &self.output_buffer[self.output_pos..self.output_pos + to_copy], + ); + self.output_pos += to_copy; + Ok(to_copy) + } +} + +impl BufRead for JsonArrayToNdjsonReader { + fn fill_buf(&mut self) -> std::io::Result<&[u8]> { + if self.output_pos >= self.output_filled { + self.fill_output_buffer()?; + } + Ok(&self.output_buffer[self.output_pos..self.output_filled]) + } + + fn consume(&mut self, amt: usize) { + self.output_pos = std::cmp::min(self.output_pos + amt, self.output_filled); + } +} + +// ============================================================================ +// ChannelReader - Sync reader that receives bytes from async channel +// ============================================================================ +// +// Architecture: +// +// ```text +// ┌─────────────────────────────────────────────────────────────────────────┐ +// │ S3 / MinIO (async) │ +// │ (33GB JSON Array File) │ +// └─────────────────────────────────────────────────────────────────────────┘ +// │ +// ▼ async stream (Bytes chunks) +// 
┌─────────────────────────────────────────────────────────────────────────┐ +// │ Async Task (tokio runtime) │ +// │ while let Some(chunk) = stream.next().await │ +// │ byte_tx.send(chunk) │ +// └─────────────────────────────────────────────────────────────────────────┘ +// │ +// ▼ tokio::sync::mpsc::channel +// │ (bounded, ~32MB buffer) +// ▼ +// ┌─────────────────────────────────────────────────────────────────────────┐ +// │ Blocking Task (spawn_blocking) │ +// │ ┌──────────────┐ ┌────────────────────────┐ ┌──────────────────┐ │ +// │ │ChannelReader │ → │JsonArrayToNdjsonReader │ → │ Arrow JsonReader │ │ +// │ │ (Read) │ │ [{},...] → {}\n{} │ │ (RecordBatch) │ │ +// │ └──────────────┘ └────────────────────────┘ └──────────────────┘ │ +// └─────────────────────────────────────────────────────────────────────────┘ +// │ +// ▼ tokio::sync::mpsc::channel +// ┌─────────────────────────────────────────────────────────────────────────┐ +// │ ReceiverStream (async) │ +// │ → DataFusion execution engine │ +// └─────────────────────────────────────────────────────────────────────────┘ +// ``` +// +// Memory Budget (~32MB total): +// - sync_channel buffer: 128 chunks × ~128KB = ~16MB +// - JsonArrayToNdjsonReader: 2 × 2MB = 4MB +// - Arrow JsonReader internal: ~8MB +// - Miscellaneous: ~4MB +// + +/// A synchronous `Read` implementation that receives bytes from an async channel. +/// +/// This enables true streaming between async and sync contexts without +/// loading the entire file into memory. Uses `tokio::sync::mpsc::Receiver` +/// with `blocking_recv()` so the async producer never blocks a tokio worker +/// thread, while the sync consumer (running in `spawn_blocking`) safely blocks. +pub struct ChannelReader { + rx: tokio::sync::mpsc::Receiver, + current: Option, + pos: usize, +} + +impl ChannelReader { + /// Create a new ChannelReader from a tokio mpsc receiver. 
+ pub fn new(rx: tokio::sync::mpsc::Receiver) -> Self { + Self { + rx, + current: None, + pos: 0, + } + } +} + +impl Read for ChannelReader { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + loop { + // If we have current chunk with remaining data, read from it + if let Some(ref chunk) = self.current { + let remaining = chunk.len() - self.pos; + if remaining > 0 { + let to_copy = std::cmp::min(remaining, buf.len()); + buf[..to_copy].copy_from_slice(&chunk[self.pos..self.pos + to_copy]); + self.pos += to_copy; + return Ok(to_copy); + } + } + + // Current chunk exhausted, get next from channel + match self.rx.blocking_recv() { + Some(bytes) => { + self.current = Some(bytes); + self.pos = 0; + // Loop back to read from new chunk + } + None => return Ok(0), // Channel closed = EOF + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_json_array_to_ndjson_simple() { + let input = r#"[{"a":1}, {"a":2}, {"a":3}]"#; + let mut reader = JsonArrayToNdjsonReader::new(input.as_bytes()); + let mut output = String::new(); + reader.read_to_string(&mut output).unwrap(); + assert_eq!(output, "{\"a\":1}\n{\"a\":2}\n{\"a\":3}"); + } + + #[test] + fn test_json_array_to_ndjson_nested() { + let input = r#"[{"a":{"b":1}}, {"c":[1,2,3]}]"#; + let mut reader = JsonArrayToNdjsonReader::new(input.as_bytes()); + let mut output = String::new(); + reader.read_to_string(&mut output).unwrap(); + assert_eq!(output, "{\"a\":{\"b\":1}}\n{\"c\":[1,2,3]}"); + } + + #[test] + fn test_json_array_to_ndjson_strings_with_special_chars() { + let input = r#"[{"a":"[1,2]"}, {"b":"x,y"}]"#; + let mut reader = JsonArrayToNdjsonReader::new(input.as_bytes()); + let mut output = String::new(); + reader.read_to_string(&mut output).unwrap(); + assert_eq!(output, "{\"a\":\"[1,2]\"}\n{\"b\":\"x,y\"}"); + } + + #[test] + fn test_json_array_to_ndjson_escaped_quotes() { + let input = r#"[{"a":"say \"hello\""}, {"b":1}]"#; + let mut reader = 
JsonArrayToNdjsonReader::new(input.as_bytes()); + let mut output = String::new(); + reader.read_to_string(&mut output).unwrap(); + assert_eq!(output, "{\"a\":\"say \\\"hello\\\"\"}\n{\"b\":1}"); + } + + #[test] + fn test_json_array_to_ndjson_empty() { + let input = r#"[]"#; + let mut reader = JsonArrayToNdjsonReader::new(input.as_bytes()); + let mut output = String::new(); + reader.read_to_string(&mut output).unwrap(); + assert_eq!(output, ""); + } + + #[test] + fn test_json_array_to_ndjson_single_element() { + let input = r#"[{"a":1}]"#; + let mut reader = JsonArrayToNdjsonReader::new(input.as_bytes()); + let mut output = String::new(); + reader.read_to_string(&mut output).unwrap(); + assert_eq!(output, "{\"a\":1}"); + } + + #[test] + fn test_json_array_to_ndjson_bufread() { + let input = r#"[{"a":1}, {"a":2}]"#; + let mut reader = JsonArrayToNdjsonReader::new(input.as_bytes()); + + let buf = reader.fill_buf().unwrap(); + assert!(!buf.is_empty()); + + let first_len = buf.len(); + reader.consume(first_len); + + let mut output = String::new(); + reader.read_to_string(&mut output).unwrap(); + } + + #[test] + fn test_json_array_to_ndjson_whitespace() { + let input = r#" [ {"a":1} , {"a":2} ] "#; + let mut reader = JsonArrayToNdjsonReader::new(input.as_bytes()); + let mut output = String::new(); + reader.read_to_string(&mut output).unwrap(); + // Top-level whitespace is skipped, internal whitespace preserved + assert_eq!(output, "{\"a\":1}\n{\"a\":2}"); + } + + #[test] + fn test_validate_complete_valid_json() { + let valid_json = r#"[{"a":1},{"a":2}]"#; + let mut reader = JsonArrayToNdjsonReader::new(valid_json.as_bytes()); + let mut output = String::new(); + reader.read_to_string(&mut output).unwrap(); + reader.validate_complete().unwrap(); + } + + #[test] + fn test_json_array_with_trailing_junk() { + let input = r#" [ {"a":1} , {"a":2} ] some { junk [ here ] "#; + let mut reader = JsonArrayToNdjsonReader::new(input.as_bytes()); + let mut output = String::new(); + 
reader.read_to_string(&mut output).unwrap(); + + // Should extract the valid array content + assert_eq!(output, "{\"a\":1}\n{\"a\":2}"); + + // But validation should catch the trailing junk + let result = reader.validate_complete(); + assert!(result.is_err()); + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("trailing content") + || err_msg.contains("Unexpected trailing"), + "Expected trailing content error, got: {err_msg}" + ); + } + + #[test] + fn test_validate_complete_incomplete_array() { + let invalid_json = r#"[{"a":1},{"a":2}"#; // Missing closing ] + let mut reader = JsonArrayToNdjsonReader::new(invalid_json.as_bytes()); + let mut output = String::new(); + reader.read_to_string(&mut output).unwrap(); + + let result = reader.validate_complete(); + assert!(result.is_err()); + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("expected closing bracket") + || err_msg.contains("missing closing"), + "Expected missing bracket error, got: {err_msg}" + ); + } + + #[test] + fn test_validate_complete_unbalanced_braces() { + let invalid_json = r#"[{"a":1},{"a":2]"#; // Wrong closing bracket + let mut reader = JsonArrayToNdjsonReader::new(invalid_json.as_bytes()); + let mut output = String::new(); + reader.read_to_string(&mut output).unwrap(); + + let result = reader.validate_complete(); + assert!(result.is_err()); + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("unbalanced") + || err_msg.contains("expected closing bracket"), + "Expected unbalanced or missing bracket error, got: {err_msg}" + ); + } + + #[test] + fn test_json_array_with_leading_junk() { + let input = r#"junk[{"a":1}, {"a":2}]"#; + let mut reader = JsonArrayToNdjsonReader::new(input.as_bytes()); + let mut output = String::new(); + reader.read_to_string(&mut output).unwrap(); + + // Should still extract the valid array content + assert_eq!(output, "{\"a\":1}\n{\"a\":2}"); + + // But validation should catch the leading 
junk + let result = reader.validate_complete(); + assert!(result.is_err()); + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("leading content"), + "Expected leading content error, got: {err_msg}" + ); + } + + #[test] + fn test_json_array_with_leading_whitespace_ok() { + let input = r#" + [{"a":1}, {"a":2}]"#; + let mut reader = JsonArrayToNdjsonReader::new(input.as_bytes()); + let mut output = String::new(); + reader.read_to_string(&mut output).unwrap(); + assert_eq!(output, "{\"a\":1}\n{\"a\":2}"); + + // Leading whitespace should be fine + reader.validate_complete().unwrap(); + } + + #[test] + fn test_validate_complete_valid_with_trailing_whitespace() { + let input = r#"[{"a":1},{"a":2}] + "#; // Trailing whitespace is OK + let mut reader = JsonArrayToNdjsonReader::new(input.as_bytes()); + let mut output = String::new(); + reader.read_to_string(&mut output).unwrap(); + + // Whitespace after ] should be allowed + reader.validate_complete().unwrap(); + } + + /// Test that data is not lost at buffer boundaries. + /// + /// This test creates input larger than the internal buffer to verify + /// that newline characters are not dropped when they occur at buffer boundaries. 
+ #[test] + fn test_buffer_boundary_no_data_loss() { + // Create objects ~9KB each, so 10 objects = ~90KB + let large_value = "x".repeat(9000); + + let mut objects = vec![]; + for i in 0..10 { + objects.push(format!(r#"{{"id":{i},"data":"{large_value}"}}"#)); + } + + let input = format!("[{}]", objects.join(",")); + + // Use small buffer to force multiple fill cycles + let mut reader = JsonArrayToNdjsonReader::with_capacity(input.as_bytes(), 8192); + let mut output = String::new(); + reader.read_to_string(&mut output).unwrap(); + + // Verify correct number of newlines (9 newlines separate 10 objects) + let newline_count = output.matches('\n').count(); + assert_eq!( + newline_count, 9, + "Expected 9 newlines separating 10 objects, got {newline_count}" + ); + + // Verify each line is valid JSON + for (i, line) in output.lines().enumerate() { + let parsed: Result = serde_json::from_str(line); + assert!( + parsed.is_ok(), + "Line {} is not valid JSON: {}...", + i, + &line[..100.min(line.len())] + ); + + // Verify the id field matches expected value + let value = parsed.unwrap(); + assert_eq!( + value["id"].as_i64(), + Some(i as i64), + "Object {i} has wrong id" + ); + } + } + + /// Test with real-world-like data format (with leading whitespace and newlines) + #[test] + fn test_real_world_format_large() { + let large_value = "x".repeat(8000); + + // Format similar to real files: opening bracket on its own line, + // each object indented with 2 spaces + let mut objects = vec![]; + for i in 0..10 { + objects.push(format!(r#" {{"id":{i},"data":"{large_value}"}}"#)); + } + + let input = format!("[\n{}\n]", objects.join(",\n")); + + let mut reader = JsonArrayToNdjsonReader::with_capacity(input.as_bytes(), 8192); + let mut output = String::new(); + reader.read_to_string(&mut output).unwrap(); + + let lines: Vec<&str> = output.lines().collect(); + assert_eq!(lines.len(), 10, "Expected 10 objects"); + + for (i, line) in lines.iter().enumerate() { + assert!( + 
line.starts_with("{\"id\""), + "Line {} should start with object, got: {}...", + i, + &line[..50.min(line.len())] + ); + } + } + + /// Test ChannelReader + #[test] + fn test_channel_reader() { + let (tx, rx) = tokio::sync::mpsc::channel(4); + + // Send some chunks (try_send is non-async) + tx.try_send(Bytes::from("Hello, ")).unwrap(); + tx.try_send(Bytes::from("World!")).unwrap(); + drop(tx); // Close channel + + let mut reader = ChannelReader::new(rx); + let mut output = String::new(); + reader.read_to_string(&mut output).unwrap(); + + assert_eq!(output, "Hello, World!"); + } + + /// Test ChannelReader with small reads + #[test] + fn test_channel_reader_small_reads() { + let (tx, rx) = tokio::sync::mpsc::channel(4); + + tx.try_send(Bytes::from("ABCDEFGHIJ")).unwrap(); + drop(tx); + + let mut reader = ChannelReader::new(rx); + let mut buf = [0u8; 3]; + + // Read in small chunks + assert_eq!(reader.read(&mut buf).unwrap(), 3); + assert_eq!(&buf, b"ABC"); + + assert_eq!(reader.read(&mut buf).unwrap(), 3); + assert_eq!(&buf, b"DEF"); + + assert_eq!(reader.read(&mut buf).unwrap(), 3); + assert_eq!(&buf, b"GHI"); + + assert_eq!(reader.read(&mut buf).unwrap(), 1); + assert_eq!(&buf[..1], b"J"); + + // EOF + assert_eq!(reader.read(&mut buf).unwrap(), 0); + } +} diff --git a/datafusion/datasource-parquet/Cargo.toml b/datafusion/datasource-parquet/Cargo.toml index a5f6f56ac6f3..b865422366f4 100644 --- a/datafusion/datasource-parquet/Cargo.toml +++ b/datafusion/datasource-parquet/Cargo.toml @@ -56,6 +56,9 @@ tokio = { workspace = true } [dev-dependencies] chrono = { workspace = true } +criterion = { workspace = true } +datafusion-functions-nested = { workspace = true } +tempfile = { workspace = true } # Note: add additional linter rules in lib.rs. 
# Rust does not support workspace + new linter rules in subcrates yet @@ -73,3 +76,7 @@ parquet_encryption = [ "datafusion-common/parquet_encryption", "datafusion-execution/parquet_encryption", ] + +[[bench]] +name = "parquet_nested_filter_pushdown" +harness = false diff --git a/datafusion/datasource-parquet/benches/parquet_nested_filter_pushdown.rs b/datafusion/datasource-parquet/benches/parquet_nested_filter_pushdown.rs new file mode 100644 index 000000000000..02137b5a1d28 --- /dev/null +++ b/datafusion/datasource-parquet/benches/parquet_nested_filter_pushdown.rs @@ -0,0 +1,238 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::path::{Path, PathBuf}; +use std::sync::{Arc, LazyLock}; + +use arrow::array::{ + BinaryBuilder, BooleanArray, ListBuilder, RecordBatch, StringBuilder, +}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use criterion::{Criterion, Throughput, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_datasource_parquet::{ParquetFileMetrics, build_row_filter}; +use datafusion_expr::{Expr, col}; +use datafusion_functions_nested::expr_fn::array_has; +use datafusion_physical_expr::planner::logical2physical; +use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; +use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; +use parquet::arrow::{ArrowWriter, ProjectionMask}; +use parquet::file::properties::WriterProperties; +use tempfile::TempDir; + +const ROW_GROUP_ROW_COUNT: usize = 10_000; +const TOTAL_ROW_GROUPS: usize = 10; +const TOTAL_ROWS: usize = ROW_GROUP_ROW_COUNT * TOTAL_ROW_GROUPS; +const TARGET_VALUE: &str = "target_value"; +const COLUMN_NAME: &str = "list_col"; +const PAYLOAD_COLUMN_NAME: &str = "payload"; +// Large binary payload to emphasize decoding overhead when pushdown is disabled. 
+const PAYLOAD_BYTES: usize = 8 * 1024; + +struct BenchmarkDataset { + _tempdir: TempDir, + file_path: PathBuf, +} + +impl BenchmarkDataset { + fn path(&self) -> &Path { + &self.file_path + } +} + +static DATASET: LazyLock = LazyLock::new(|| { + create_dataset().expect("failed to prepare parquet benchmark dataset") +}); + +fn parquet_nested_filter_pushdown(c: &mut Criterion) { + let dataset_path = DATASET.path().to_owned(); + let mut group = c.benchmark_group("parquet_nested_filter_pushdown"); + group.throughput(Throughput::Elements(TOTAL_ROWS as u64)); + + group.bench_function("no_pushdown", |b| { + let file_schema = setup_reader(&dataset_path); + let predicate = logical2physical(&create_predicate(), &file_schema); + b.iter(|| { + let matched = scan_with_predicate(&dataset_path, &predicate, false) + .expect("baseline parquet scan with filter succeeded"); + assert_eq!(matched, ROW_GROUP_ROW_COUNT); + }); + }); + + group.bench_function("with_pushdown", |b| { + let file_schema = setup_reader(&dataset_path); + let predicate = logical2physical(&create_predicate(), &file_schema); + b.iter(|| { + let matched = scan_with_predicate(&dataset_path, &predicate, true) + .expect("pushdown parquet scan with filter succeeded"); + assert_eq!(matched, ROW_GROUP_ROW_COUNT); + }); + }); + + group.finish(); +} + +fn setup_reader(path: &Path) -> SchemaRef { + let file = std::fs::File::open(path).expect("failed to open file"); + let builder = + ParquetRecordBatchReaderBuilder::try_new(file).expect("failed to build reader"); + Arc::clone(builder.schema()) +} + +fn create_predicate() -> Expr { + array_has( + col(COLUMN_NAME), + Expr::Literal(ScalarValue::Utf8(Some(TARGET_VALUE.to_string())), None), + ) +} + +fn scan_with_predicate( + path: &Path, + predicate: &Arc, + pushdown: bool, +) -> datafusion_common::Result { + let file = std::fs::File::open(path)?; + let builder = ParquetRecordBatchReaderBuilder::try_new(file)?; + let metadata = builder.metadata().clone(); + let file_schema = 
builder.schema(); + let projection = ProjectionMask::all(); + + let metrics = ExecutionPlanMetricsSet::new(); + let file_metrics = ParquetFileMetrics::new(0, &path.display().to_string(), &metrics); + + let builder = if pushdown { + if let Some(row_filter) = + build_row_filter(predicate, file_schema, &metadata, false, &file_metrics)? + { + builder.with_row_filter(row_filter) + } else { + builder + } + } else { + builder + }; + + let reader = builder.with_projection(projection).build()?; + + let mut matched_rows = 0usize; + for batch in reader { + let batch = batch?; + matched_rows += count_matches(predicate, &batch)?; + } + + if pushdown { + let pruned_rows = file_metrics.pushdown_rows_pruned.value(); + assert_eq!( + pruned_rows, + TOTAL_ROWS - matched_rows, + "row-level pushdown should prune 90% of rows" + ); + } + + Ok(matched_rows) +} + +fn count_matches( + expr: &Arc, + batch: &RecordBatch, +) -> datafusion_common::Result { + let values = expr.evaluate(batch)?.into_array(batch.num_rows())?; + let bools = values + .as_any() + .downcast_ref::() + .expect("boolean filter result"); + + Ok(bools.iter().filter(|v| matches!(v, Some(true))).count()) +} + +fn create_dataset() -> datafusion_common::Result { + let tempdir = TempDir::new()?; + let file_path = tempdir.path().join("nested_lists.parquet"); + + let field = Arc::new(Field::new("item", DataType::Utf8, true)); + let schema = Arc::new(Schema::new(vec![ + Field::new(COLUMN_NAME, DataType::List(field), false), + Field::new(PAYLOAD_COLUMN_NAME, DataType::Binary, false), + ])); + + let writer_props = WriterProperties::builder() + .set_max_row_group_row_count(Some(ROW_GROUP_ROW_COUNT)) + .build(); + + let mut writer = ArrowWriter::try_new( + std::fs::File::create(&file_path)?, + Arc::clone(&schema), + Some(writer_props), + )?; + + // Create sorted row groups with distinct values so that min/max statistics + // allow skipping most groups when applying a selective predicate. 
+ let sorted_values = [ + "alpha", + "bravo", + "charlie", + "delta", + "echo", + "foxtrot", + "golf", + "hotel", + "india", + TARGET_VALUE, + ]; + + for value in sorted_values { + let batch = build_list_batch(&schema, value, ROW_GROUP_ROW_COUNT)?; + writer.write(&batch)?; + } + + writer.close()?; + + // Ensure the writer respected the requested row group size + let reader = + ParquetRecordBatchReaderBuilder::try_new(std::fs::File::open(&file_path)?)?; + assert_eq!(reader.metadata().row_groups().len(), TOTAL_ROW_GROUPS); + + Ok(BenchmarkDataset { + _tempdir: tempdir, + file_path, + }) +} + +fn build_list_batch( + schema: &SchemaRef, + value: &str, + len: usize, +) -> datafusion_common::Result { + let mut builder = ListBuilder::new(StringBuilder::new()); + let mut payload_builder = BinaryBuilder::new(); + let payload = vec![1u8; PAYLOAD_BYTES]; + for _ in 0..len { + builder.values().append_value(value); + builder.append(true); + payload_builder.append_value(&payload); + } + + let array = builder.finish(); + let payload_array = payload_builder.finish(); + Ok(RecordBatch::try_new( + Arc::clone(schema), + vec![Arc::new(array), Arc::new(payload_array)], + )?) 
+} + +criterion_group!(benches, parquet_nested_filter_pushdown); +criterion_main!(benches); diff --git a/datafusion/datasource-parquet/src/access_plan.rs b/datafusion/datasource-parquet/src/access_plan.rs index 570792d40e5b..44911fcf2a9c 100644 --- a/datafusion/datasource-parquet/src/access_plan.rs +++ b/datafusion/datasource-parquet/src/access_plan.rs @@ -82,6 +82,10 @@ use parquet::file::metadata::RowGroupMetaData; /// └───────────────────┘ /// Row Group 3 /// ``` +/// +/// For more background, please also see the [Embedding User-Defined Indexes in Apache Parquet Files blog] +/// +/// [Embedding User-Defined Indexes in Apache Parquet Files blog]: https://datafusion.apache.org/blog/2025/07/14/user-defined-parquet-indexes #[derive(Debug, Clone, PartialEq)] pub struct ParquetAccessPlan { /// How to access the i-th row group diff --git a/datafusion/datasource-parquet/src/file_format.rs b/datafusion/datasource-parquet/src/file_format.rs index 5e482382be68..edbdd618edb0 100644 --- a/datafusion/datasource-parquet/src/file_format.rs +++ b/datafusion/datasource-parquet/src/file_format.rs @@ -54,11 +54,11 @@ use datafusion_datasource::sink::{DataSink, DataSinkExec}; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::dml::InsertOp; -use datafusion_physical_expr_common::sort_expr::LexRequirement; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; use datafusion_session::Session; -use crate::metadata::DFParquetMetadata; +use crate::metadata::{DFParquetMetadata, lex_ordering_to_sorting_columns}; use crate::reader::CachedParquetFileReaderFactory; use crate::source::{ParquetSource, parse_coerce_int96_string}; use async_trait::async_trait; @@ -70,7 +70,7 @@ use futures::future::BoxFuture; use futures::{FutureExt, StreamExt, TryStreamExt}; use 
object_store::buffered::BufWriter; use object_store::path::Path; -use object_store::{ObjectMeta, ObjectStore}; +use object_store::{ObjectMeta, ObjectStore, ObjectStoreExt}; use parquet::arrow::arrow_writer::{ ArrowColumnChunk, ArrowColumnWriter, ArrowLeafColumn, ArrowRowGroupWriterFactory, ArrowWriterOptions, compute_leaves, @@ -81,8 +81,10 @@ use parquet::basic::Type; #[cfg(feature = "parquet_encryption")] use parquet::encryption::encrypt::FileEncryptionProperties; use parquet::errors::ParquetError; -use parquet::file::metadata::ParquetMetaData; -use parquet::file::properties::{WriterProperties, WriterPropertiesBuilder}; +use parquet::file::metadata::{ParquetMetaData, SortingColumn}; +use parquet::file::properties::{ + DEFAULT_MAX_ROW_GROUP_ROW_COUNT, WriterProperties, WriterPropertiesBuilder, +}; use parquet::file::writer::SerializedFileWriter; use parquet::schema::types::SchemaDescriptor; use tokio::io::{AsyncWrite, AsyncWriteExt}; @@ -391,7 +393,7 @@ impl FileFormat for ParquetFormat { }) .boxed() // Workaround https://github.com/rust-lang/rust/issues/64552 // fetch schemas concurrently, if requested - .buffered(state.config_options().execution.meta_fetch_concurrency) + .buffer_unordered(state.config_options().execution.meta_fetch_concurrency) .try_collect() .await?; @@ -401,12 +403,10 @@ impl FileFormat for ParquetFormat { // is not deterministic. Thus, to ensure deterministic schema inference // sort the files first. 
// https://github.com/apache/datafusion/pull/6629 - schemas.sort_by(|(location1, _), (location2, _)| location1.cmp(location2)); + schemas + .sort_unstable_by(|(location1, _), (location2, _)| location1.cmp(location2)); - let schemas = schemas - .into_iter() - .map(|(_, schema)| schema) - .collect::>(); + let schemas = schemas.into_iter().map(|(_, schema)| schema); let schema = if self.skip_metadata() { Schema::try_merge(clear_metadata(schemas)) @@ -449,6 +449,57 @@ impl FileFormat for ParquetFormat { .await } + async fn infer_ordering( + &self, + state: &dyn Session, + store: &Arc, + table_schema: SchemaRef, + object: &ObjectMeta, + ) -> Result> { + let file_decryption_properties = + get_file_decryption_properties(state, &self.options, &object.location) + .await?; + let file_metadata_cache = + state.runtime_env().cache_manager.get_file_metadata_cache(); + let metadata = DFParquetMetadata::new(store, object) + .with_metadata_size_hint(self.metadata_size_hint()) + .with_decryption_properties(file_decryption_properties) + .with_file_metadata_cache(Some(file_metadata_cache)) + .fetch_metadata() + .await?; + crate::metadata::ordering_from_parquet_metadata(&metadata, &table_schema) + } + + async fn infer_stats_and_ordering( + &self, + state: &dyn Session, + store: &Arc, + table_schema: SchemaRef, + object: &ObjectMeta, + ) -> Result { + let file_decryption_properties = + get_file_decryption_properties(state, &self.options, &object.location) + .await?; + let file_metadata_cache = + state.runtime_env().cache_manager.get_file_metadata_cache(); + let metadata = DFParquetMetadata::new(store, object) + .with_metadata_size_hint(self.metadata_size_hint()) + .with_decryption_properties(file_decryption_properties) + .with_file_metadata_cache(Some(file_metadata_cache)) + .fetch_metadata() + .await?; + let statistics = DFParquetMetadata::statistics_from_parquet_metadata( + &metadata, + &table_schema, + )?; + let ordering = + crate::metadata::ordering_from_parquet_metadata(&metadata, 
&table_schema)?; + Ok( + datafusion_datasource::file_format::FileMeta::new(statistics) + .with_ordering(ordering), + ) + } + async fn create_physical_plan( &self, state: &dyn Session, @@ -500,7 +551,22 @@ impl FileFormat for ParquetFormat { return not_impl_err!("Overwrites are not implemented yet for Parquet"); } - let sink = Arc::new(ParquetSink::new(conf, self.options.clone())); + // Convert ordering requirements to Parquet SortingColumns for file metadata + let sorting_columns = if let Some(ref requirements) = order_requirements { + let ordering: LexOrdering = requirements.clone().into(); + // In cases like `COPY (... ORDER BY ...) TO ...` the ORDER BY clause + // may not be compatible with Parquet sorting columns (e.g. ordering on `random()`). + // So if we cannot create a Parquet sorting column from the ordering requirement, + // we skip setting sorting columns on the Parquet sink. + lex_ordering_to_sorting_columns(&ordering).ok() + } else { + None + }; + + let sink = Arc::new( + ParquetSink::new(conf, self.options.clone()) + .with_sorting_columns(sorting_columns), + ); Ok(Arc::new(DataSinkExec::new(input, sink, order_requirements)) as _) } @@ -1088,6 +1154,8 @@ pub struct ParquetSink { /// File metadata from successfully produced parquet files. The Mutex is only used /// to allow inserting to HashMap from behind borrowed reference in DataSink::write_all. written: Arc>>, + /// Optional sorting columns to write to Parquet metadata + sorting_columns: Option>, } impl Debug for ParquetSink { @@ -1119,9 +1187,19 @@ impl ParquetSink { config, parquet_options, written: Default::default(), + sorting_columns: None, } } + /// Set sorting columns for the Parquet file metadata. + pub fn with_sorting_columns( + mut self, + sorting_columns: Option>, + ) -> Self { + self.sorting_columns = sorting_columns; + self + } + /// Retrieve the file metadata for the written files, keyed to the path /// which may be partitioned (in the case of hive style partitioning). 
pub fn written(&self) -> HashMap { @@ -1145,6 +1223,12 @@ impl ParquetSink { } let mut builder = WriterPropertiesBuilder::try_from(&parquet_opts)?; + + // Set sorting columns if configured + if let Some(ref sorting_columns) = self.sorting_columns { + builder = builder.set_sorting_columns(Some(sorting_columns.clone())); + } + builder = set_writer_encryption_properties( builder, runtime, @@ -1276,7 +1360,7 @@ impl FileSink for ParquetSink { parquet_props.clone(), ) .await?; - let mut reservation = MemoryConsumer::new(format!("ParquetSink[{path}]")) + let reservation = MemoryConsumer::new(format!("ParquetSink[{path}]")) .register(context.memory_pool()); file_write_tasks.spawn(async move { while let Some(batch) = rx.recv().await { @@ -1381,7 +1465,7 @@ impl DataSink for ParquetSink { async fn column_serializer_task( mut rx: Receiver, mut writer: ArrowColumnWriter, - mut reservation: MemoryReservation, + reservation: MemoryReservation, ) -> Result<(ArrowColumnWriter, MemoryReservation)> { while let Some(col) = rx.recv().await { writer.write(&col)?; @@ -1466,7 +1550,7 @@ fn spawn_rg_join_and_finalize_task( rg_rows: usize, pool: &Arc, ) -> SpawnedTask { - let mut rg_reservation = + let rg_reservation = MemoryConsumer::new("ParquetSink(SerializedRowGroupWriter)").register(pool); SpawnedTask::spawn(async move { @@ -1505,7 +1589,9 @@ fn spawn_parquet_parallel_serialization_task( ) -> SpawnedTask> { SpawnedTask::spawn(async move { let max_buffer_rb = parallel_options.max_buffered_record_batches_per_stream; - let max_row_group_rows = writer_props.max_row_group_size(); + let max_row_group_rows = writer_props + .max_row_group_row_count() + .unwrap_or(DEFAULT_MAX_ROW_GROUP_ROW_COUNT); let mut row_group_index = 0; let col_writers = row_group_writer_factory.create_column_writers(row_group_index)?; @@ -1598,12 +1684,12 @@ async fn concatenate_parallel_row_groups( mut object_store_writer: Box, pool: Arc, ) -> Result { - let mut file_reservation = + let file_reservation = 
MemoryConsumer::new("ParquetSink(SerializedFileWriter)").register(&pool); while let Some(task) = serialize_rx.recv().await { let result = task.join_unwind().await; - let (serialized_columns, mut rg_reservation, _cnt) = + let (serialized_columns, rg_reservation, _cnt) = result.map_err(|e| DataFusionError::ExecutionJoin(Box::new(e)))??; let mut rg_out = parquet_writer.next_row_group()?; diff --git a/datafusion/datasource-parquet/src/metadata.rs b/datafusion/datasource-parquet/src/metadata.rs index 8b11ba64ae7f..5a4c0bcdd514 100644 --- a/datafusion/datasource-parquet/src/metadata.rs +++ b/datafusion/datasource-parquet/src/metadata.rs @@ -21,7 +21,7 @@ use crate::{ ObjectStoreFetch, apply_file_schema_type_coercions, coerce_int96_to_resolution, }; -use arrow::array::{ArrayRef, BooleanArray}; +use arrow::array::{Array, ArrayRef, BooleanArray}; use arrow::compute::and; use arrow::compute::kernels::cmp::eq; use arrow::compute::sum; @@ -31,8 +31,12 @@ use datafusion_common::stats::Precision; use datafusion_common::{ ColumnStatistics, DataFusionError, Result, ScalarValue, Statistics, }; -use datafusion_execution::cache::cache_manager::{FileMetadata, FileMetadataCache}; +use datafusion_execution::cache::cache_manager::{ + CachedFileMetadataEntry, FileMetadata, FileMetadataCache, +}; use datafusion_functions_aggregate_common::min_max::{MaxAccumulator, MinAccumulator}; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use datafusion_physical_plan::Accumulator; use log::debug; use object_store::path::Path; @@ -41,6 +45,7 @@ use parquet::arrow::arrow_reader::statistics::StatisticsConverter; use parquet::arrow::{parquet_column, parquet_to_arrow_schema}; use parquet::file::metadata::{ PageIndexPolicy, ParquetMetaData, ParquetMetaDataReader, RowGroupMetaData, + SortingColumn, }; use parquet::schema::types::SchemaDescriptor; use std::any::Any; @@ -125,19 +130,15 @@ impl<'a> DFParquetMetadata<'a> { 
!cfg!(feature = "parquet_encryption") || decryption_properties.is_none(); if cache_metadata - && let Some(parquet_metadata) = file_metadata_cache - .as_ref() - .and_then(|file_metadata_cache| file_metadata_cache.get(object_meta)) - .and_then(|file_metadata| { - file_metadata - .as_any() - .downcast_ref::() - .map(|cached_parquet_metadata| { - Arc::clone(cached_parquet_metadata.parquet_metadata()) - }) - }) + && let Some(file_metadata_cache) = file_metadata_cache.as_ref() + && let Some(cached) = file_metadata_cache.get(&object_meta.location) + && cached.is_valid_for(object_meta) + && let Some(cached_parquet) = cached + .file_metadata + .as_any() + .downcast_ref::() { - return Ok(parquet_metadata); + return Ok(Arc::clone(cached_parquet.parquet_metadata())); } let mut reader = @@ -163,8 +164,11 @@ impl<'a> DFParquetMetadata<'a> { if cache_metadata && let Some(file_metadata_cache) = file_metadata_cache { file_metadata_cache.put( - object_meta, - Arc::new(CachedParquetMetaData::new(Arc::clone(&metadata))), + &object_meta.location, + CachedFileMetadataEntry::new( + (*object_meta).clone(), + Arc::new(CachedParquetMetaData::new(Arc::clone(&metadata))), + ), ); } @@ -483,22 +487,40 @@ fn summarize_min_max_null_counts( if let Some(max_acc) = &mut accumulators.max_accs[logical_schema_index] { max_acc.update_batch(&[Arc::clone(&max_values)])?; - let mut cur_max_acc = max_acc.clone(); - accumulators.is_max_value_exact[logical_schema_index] = has_any_exact_match( - &cur_max_acc.evaluate()?, - &max_values, - &is_max_value_exact_stat, - ); + + // handle the common special case when all row groups have exact statistics + let exactness = &is_max_value_exact_stat; + if !exactness.is_empty() + && exactness.null_count() == 0 + && exactness.true_count() == exactness.len() + { + accumulators.is_max_value_exact[logical_schema_index] = Some(true); + } else if exactness.true_count() == 0 { + accumulators.is_max_value_exact[logical_schema_index] = Some(false); + } else { + let val = 
max_acc.evaluate()?; + accumulators.is_max_value_exact[logical_schema_index] = + has_any_exact_match(&val, &max_values, exactness); + } } if let Some(min_acc) = &mut accumulators.min_accs[logical_schema_index] { min_acc.update_batch(&[Arc::clone(&min_values)])?; - let mut cur_min_acc = min_acc.clone(); - accumulators.is_min_value_exact[logical_schema_index] = has_any_exact_match( - &cur_min_acc.evaluate()?, - &min_values, - &is_min_value_exact_stat, - ); + + // handle the common special case when all row groups have exact statistics + let exactness = &is_min_value_exact_stat; + if !exactness.is_empty() + && exactness.null_count() == 0 + && exactness.true_count() == exactness.len() + { + accumulators.is_min_value_exact[logical_schema_index] = Some(true); + } else if exactness.true_count() == 0 { + accumulators.is_min_value_exact[logical_schema_index] = Some(false); + } else { + let val = min_acc.evaluate()?; + accumulators.is_min_value_exact[logical_schema_index] = + has_any_exact_match(&val, &min_values, exactness); + } } accumulators.null_counts_array[logical_schema_index] = match sum(&null_counts) { @@ -578,6 +600,15 @@ fn has_any_exact_match( array: &ArrayRef, exactness: &BooleanArray, ) -> Option { + if value.is_null() { + return Some(false); + } + + // Shortcut for single row group + if array.len() == 1 { + return Some(exactness.is_valid(0) && exactness.value(0)); + } + let scalar_array = value.to_scalar().ok()?; let eq_mask = eq(&scalar_array, &array).ok()?; let combined_mask = and(&eq_mask, exactness).ok()?; @@ -613,6 +644,114 @@ impl FileMetadata for CachedParquetMetaData { } } +/// Convert a [`PhysicalSortExpr`] to a Parquet [`SortingColumn`]. +/// +/// Returns `Err` if the expression is not a simple column reference. 
+pub(crate) fn sort_expr_to_sorting_column( + sort_expr: &PhysicalSortExpr, +) -> Result { + let column = sort_expr + .expr + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Plan(format!( + "Parquet sorting_columns only supports simple column references, \ + but got expression: {}", + sort_expr.expr + )) + })?; + + let column_idx: i32 = column.index().try_into().map_err(|_| { + DataFusionError::Plan(format!( + "Column index {} is too large to be represented as i32", + column.index() + )) + })?; + + Ok(SortingColumn { + column_idx, + descending: sort_expr.options.descending, + nulls_first: sort_expr.options.nulls_first, + }) +} + +/// Convert a [`LexOrdering`] to `Vec` for Parquet. +/// +/// Returns `Err` if any expression is not a simple column reference. +pub(crate) fn lex_ordering_to_sorting_columns( + ordering: &LexOrdering, +) -> Result> { + ordering.iter().map(sort_expr_to_sorting_column).collect() +} + +/// Extracts ordering information from Parquet metadata. +/// +/// This function reads the sorting_columns from the first row group's metadata +/// and converts them into a [`LexOrdering`] that can be used by the query engine. +/// +/// # Arguments +/// * `metadata` - The Parquet metadata containing sorting_columns information +/// * `schema` - The Arrow schema to use for column lookup +/// +/// # Returns +/// * `Ok(Some(ordering))` if valid ordering information was found +/// * `Ok(None)` if no sorting columns were specified or they couldn't be resolved +pub fn ordering_from_parquet_metadata( + metadata: &ParquetMetaData, + schema: &SchemaRef, +) -> Result> { + // Get the sorting columns from the first row group metadata. + // If no row groups exist or no sorting columns are specified, return None. 
+ let sorting_columns = metadata + .row_groups() + .first() + .and_then(|rg| rg.sorting_columns()) + .filter(|cols| !cols.is_empty()); + + let Some(sorting_columns) = sorting_columns else { + return Ok(None); + }; + + let parquet_schema = metadata.file_metadata().schema_descr(); + + let sort_exprs = + sorting_columns_to_physical_exprs(sorting_columns, parquet_schema, schema); + + if sort_exprs.is_empty() { + return Ok(None); + } + + Ok(LexOrdering::new(sort_exprs)) +} + +/// Converts Parquet sorting columns to physical sort expressions. +fn sorting_columns_to_physical_exprs( + sorting_columns: &[SortingColumn], + parquet_schema: &SchemaDescriptor, + arrow_schema: &SchemaRef, +) -> Vec { + use arrow::compute::SortOptions; + + sorting_columns + .iter() + .filter_map(|sc| { + let parquet_column = parquet_schema.column(sc.column_idx as usize); + let name = parquet_column.name(); + + // Find the column in the arrow schema + let (index, _) = arrow_schema.column_with_name(name)?; + + let expr = Arc::new(Column::new(name, index)); + let options = SortOptions { + descending: sc.descending, + nulls_first: sc.nulls_first, + }; + Some(PhysicalSortExpr::new(expr, options)) + }) + .collect() +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/datasource-parquet/src/metrics.rs b/datafusion/datasource-parquet/src/metrics.rs index 5eaa137e9a45..2d6fb69270bf 100644 --- a/datafusion/datasource-parquet/src/metrics.rs +++ b/datafusion/datasource-parquet/src/metrics.rs @@ -16,7 +16,7 @@ // under the License. 
use datafusion_physical_plan::metrics::{ - Count, ExecutionPlanMetricsSet, MetricBuilder, MetricType, PruningMetrics, + Count, ExecutionPlanMetricsSet, Gauge, MetricBuilder, MetricType, PruningMetrics, RatioMergeStrategy, RatioMetrics, Time, }; @@ -45,9 +45,11 @@ pub struct ParquetFileMetrics { pub files_ranges_pruned_statistics: PruningMetrics, /// Number of times the predicate could not be evaluated pub predicate_evaluation_errors: Count, - /// Number of row groups whose bloom filters were checked, tracked with matched/pruned counts + /// Number of row groups pruned by bloom filters pub row_groups_pruned_bloom_filter: PruningMetrics, - /// Number of row groups whose statistics were checked, tracked with matched/pruned counts + /// Number of row groups pruned due to limit pruning. + pub limit_pruned_row_groups: PruningMetrics, + /// Number of row groups pruned by statistics pub row_groups_pruned_statistics: PruningMetrics, /// Total number of bytes scanned pub bytes_scanned: Count, @@ -63,19 +65,32 @@ pub struct ParquetFileMetrics { pub bloom_filter_eval_time: Time, /// Total rows filtered or matched by parquet page index pub page_index_rows_pruned: PruningMetrics, + /// Total pages filtered or matched by parquet page index + pub page_index_pages_pruned: PruningMetrics, /// Total time spent evaluating parquet page index filters pub page_index_eval_time: Time, /// Total time spent reading and parsing metadata from the footer pub metadata_load_time: Time, /// Scan Efficiency Ratio, calculated as bytes_scanned / total_file_size pub scan_efficiency_ratio: RatioMetrics, - /// Predicate Cache: number of records read directly from the inner reader. - /// This is the number of rows decoded while evaluating predicates - pub predicate_cache_inner_records: Count, + /// Predicate Cache: Total number of rows physically read and decoded from the Parquet file. + /// + /// This metric tracks "cache misses" in the predicate pushdown optimization. 
+ /// When the specialized predicate reader cannot find the requested data in its cache, + /// it must fall back to the "inner reader" to physically decode the data from the + /// Parquet. + /// + /// This is the expensive path (IO + Decompression + Decoding). + /// + /// We use a Gauge here as arrow-rs reports absolute numbers rather + /// than incremental readings, we want a `set` operation here rather + /// than `add`. Earlier it was `Count`, which led to this issue: + /// github.com/apache/datafusion/issues/19334 + pub predicate_cache_inner_records: Gauge, /// Predicate Cache: number of records read from the cache. This is the /// number of rows that were stored in the cache after evaluating predicates /// reused for the output. - pub predicate_cache_records: Count, + pub predicate_cache_records: Gauge, } impl ParquetFileMetrics { @@ -93,15 +108,20 @@ impl ParquetFileMetrics { .with_type(MetricType::SUMMARY) .pruning_metrics("row_groups_pruned_bloom_filter", partition); + let limit_pruned_row_groups = MetricBuilder::new(metrics) + .with_new_label("filename", filename.to_string()) + .with_type(MetricType::SUMMARY) + .pruning_metrics("limit_pruned_row_groups", partition); + let row_groups_pruned_statistics = MetricBuilder::new(metrics) .with_new_label("filename", filename.to_string()) .with_type(MetricType::SUMMARY) .pruning_metrics("row_groups_pruned_statistics", partition); - let page_index_rows_pruned = MetricBuilder::new(metrics) + let page_index_pages_pruned = MetricBuilder::new(metrics) .with_new_label("filename", filename.to_string()) .with_type(MetricType::SUMMARY) - .pruning_metrics("page_index_rows_pruned", partition); + .pruning_metrics("page_index_pages_pruned", partition); let bytes_scanned = MetricBuilder::new(metrics) .with_new_label("filename", filename.to_string()) @@ -154,24 +174,30 @@ impl ParquetFileMetrics { .with_new_label("filename", filename.to_string()) .subset_time("page_index_eval_time", partition); + let page_index_rows_pruned = 
MetricBuilder::new(metrics) + .with_new_label("filename", filename.to_string()) + .pruning_metrics("page_index_rows_pruned", partition); + let predicate_cache_inner_records = MetricBuilder::new(metrics) .with_new_label("filename", filename.to_string()) - .counter("predicate_cache_inner_records", partition); + .gauge("predicate_cache_inner_records", partition); let predicate_cache_records = MetricBuilder::new(metrics) .with_new_label("filename", filename.to_string()) - .counter("predicate_cache_records", partition); + .gauge("predicate_cache_records", partition); Self { files_ranges_pruned_statistics, predicate_evaluation_errors, row_groups_pruned_bloom_filter, row_groups_pruned_statistics, + limit_pruned_row_groups, bytes_scanned, pushdown_rows_pruned, pushdown_rows_matched, row_pushdown_eval_time, page_index_rows_pruned, + page_index_pages_pruned, statistics_eval_time, bloom_filter_eval_time, page_index_eval_time, diff --git a/datafusion/datasource-parquet/src/mod.rs b/datafusion/datasource-parquet/src/mod.rs index eb4cc9e9ad5a..0e137a706fad 100644 --- a/datafusion/datasource-parquet/src/mod.rs +++ b/datafusion/datasource-parquet/src/mod.rs @@ -19,7 +19,6 @@ // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] -#![deny(clippy::allow_attributes)] pub mod access_plan; pub mod file_format; @@ -32,6 +31,7 @@ mod row_filter; mod row_group_filter; mod sort; pub mod source; +mod supported_predicates; mod writer; pub use access_plan::{ParquetAccessPlan, RowGroupAccess}; diff --git a/datafusion/datasource-parquet/src/opener.rs b/datafusion/datasource-parquet/src/opener.rs index 83bdf79c8fcc..f657b709fe09 100644 --- a/datafusion/datasource-parquet/src/opener.rs +++ b/datafusion/datasource-parquet/src/opener.rs @@ -47,7 +47,7 @@ use datafusion_physical_expr_common::physical_expr::{ PhysicalExpr, is_dynamic_physical_expr, }; use 
datafusion_physical_plan::metrics::{ - Count, ExecutionPlanMetricsSet, MetricBuilder, PruningMetrics, + Count, ExecutionPlanMetricsSet, Gauge, MetricBuilder, PruningMetrics, }; use datafusion_pruning::{FilePruner, PruningPredicate, build_pruning_predicate}; @@ -69,13 +69,15 @@ use parquet::file::metadata::{PageIndexPolicy, ParquetMetaDataReader, RowGroupMe /// Implements [`FileOpener`] for a parquet file pub(super) struct ParquetOpener { /// Execution partition index - pub partition_index: usize, + pub(crate) partition_index: usize, /// Projection to apply on top of the table schema (i.e. can reference partition columns). pub projection: ProjectionExprs, /// Target number of rows in each output RecordBatch pub batch_size: usize, /// Optional limit on the number of rows to read - pub limit: Option, + pub(crate) limit: Option, + /// If should keep the output rows in order + pub preserve_order: bool, /// Optional predicate to apply during the scan pub predicate: Option>, /// Table schema, including partition columns. @@ -180,6 +182,9 @@ impl PreparedAccessPlan { impl FileOpener for ParquetOpener { fn open(&self, partitioned_file: PartitionedFile) -> Result { + // ----------------------------------- + // Step: prepare configurations, etc. 
+ // ----------------------------------- let file_range = partitioned_file.range.clone(); let extensions = partitioned_file.extensions.clone(); let file_location = partitioned_file.object_meta.location.clone(); @@ -274,12 +279,18 @@ impl FileOpener for ParquetOpener { let max_predicate_cache_size = self.max_predicate_cache_size; let reverse_row_groups = self.reverse_row_groups; + let preserve_order = self.preserve_order; + Ok(Box::pin(async move { #[cfg(feature = "parquet_encryption")] let file_decryption_properties = encryption_context .get_file_decryption_properties(&file_location) .await?; + // --------------------------------------------- + // Step: try to prune the current file partition + // --------------------------------------------- + // Prune this file using the file level statistics and partition values. // Since dynamic filters may have been updated since planning it is possible that we are able // to prune files now that we couldn't prune at planning time. @@ -328,12 +339,17 @@ impl FileOpener for ParquetOpener { file_metrics.files_ranges_pruned_statistics.add_matched(1); + // -------------------------------------------------------- + // Step: fetch Parquet metadata (and optionally page index) + // -------------------------------------------------------- + // Don't load the page index yet. Since it is not stored inline in // the footer, loading the page index if it is not needed will do // unnecessary I/O. We decide later if it is needed to evaluate the - // pruning predicates. Thus default to not requesting if from the + // pruning predicates. Thus default to not requesting it from the // underlying reader. 
- let mut options = ArrowReaderOptions::new().with_page_index(false); + let mut options = + ArrowReaderOptions::new().with_page_index_policy(PageIndexPolicy::Skip); #[cfg(feature = "parquet_encryption")] if let Some(fd_val) = file_decryption_properties { options = options.with_file_decryption_properties(Arc::clone(&fd_val)); @@ -394,17 +410,27 @@ impl FileOpener for ParquetOpener { // and we can avoid doing any more work on the file (bloom filters, loading the page index, etc.). // Additionally, if any casts were inserted we can move casts from the column to the literal side: // `CAST(col AS INT) = 5` can become `col = CAST(5 AS )`, which can be evaluated statically. - let rewriter = expr_adapter_factory.create( - Arc::clone(&logical_file_schema), - Arc::clone(&physical_file_schema), - ); - let simplifier = PhysicalExprSimplifier::new(&physical_file_schema); - predicate = predicate - .map(|p| simplifier.simplify(rewriter.rewrite(p)?)) - .transpose()?; - // Adapt projections to the physical file schema as well - projection = projection - .try_map_exprs(|p| simplifier.simplify(rewriter.rewrite(p)?))?; + // + // When the schemas are identical and there is no predicate, the + // rewriter is a no-op: column indices already match (partition + // columns are appended after file columns in the table schema), + // types are the same, and there are no missing columns. Skip the + // tree walk entirely in that case. 
+ let needs_rewrite = + predicate.is_some() || logical_file_schema != physical_file_schema; + if needs_rewrite { + let rewriter = expr_adapter_factory.create( + Arc::clone(&logical_file_schema), + Arc::clone(&physical_file_schema), + )?; + let simplifier = PhysicalExprSimplifier::new(&physical_file_schema); + predicate = predicate + .map(|p| simplifier.simplify(rewriter.rewrite(p)?)) + .transpose()?; + // Adapt projections to the physical file schema as well + projection = projection + .try_map_exprs(|p| simplifier.simplify(rewriter.rewrite(p)?))?; + } // Build predicates for this specific file let (pruning_predicate, page_pruning_predicate) = build_pruning_predicates( @@ -421,21 +447,28 @@ impl FileOpener for ParquetOpener { reader_metadata, &mut async_file_reader, // Since we're manually loading the page index the option here should not matter but we pass it in for consistency - options.with_page_index(true), + options.with_page_index_policy(PageIndexPolicy::Optional), ) .await?; } metadata_timer.stop(); + // --------------------------------------------------------- + // Step: construct builder for the final RecordBatch stream + // --------------------------------------------------------- + let mut builder = ParquetRecordBatchStreamBuilder::new_with_metadata( async_file_reader, reader_metadata, ); - let indices = projection.column_indices(); - - let mask = ProjectionMask::roots(builder.parquet_schema(), indices); + // --------------------------------------------------------------------- + // Step: optionally add row filter to the builder + // + // Row filter is used for late materialization in parquet decoding, see + // `row_filter` for details. 
+ // --------------------------------------------------------------------- // Filter pushdown: evaluate predicates during scan if let Some(predicate) = pushdown_filters.then_some(predicate).flatten() { @@ -464,6 +497,10 @@ impl FileOpener for ParquetOpener { builder.with_row_selection_policy(RowSelectionPolicy::Selectors); } + // ------------------------------------------------------------ + // Step: prune row groups by range, predicate and bloom filter + // ------------------------------------------------------------ + // Determine which row groups to actually read. The idea is to skip // as many row groups as possible based on the metadata and query let file_metadata = Arc::clone(builder.metadata()); @@ -523,11 +560,19 @@ impl FileOpener for ParquetOpener { .add_matched(n_remaining_row_groups); } - let mut access_plan = row_groups.build(); + // Prune by limit if limit is set and limit order is not sensitive + if let (Some(limit), false) = (limit, preserve_order) { + row_groups.prune_by_limit(limit, rg_metadata, &file_metrics); + } + // -------------------------------------------------------- + // Step: prune pages from the kept row groups + // + let mut access_plan = row_groups.build(); // page index pruning: if all data on individual pages can // be ruled using page metadata, rows from other columns // with that range can be skipped as well + // -------------------------------------------------------- if enable_page_index && !access_plan.is_empty() && let Some(p) = page_pruning_predicate @@ -545,7 +590,10 @@ impl FileOpener for ParquetOpener { let mut prepared_plan = PreparedAccessPlan::from_access_plan(access_plan, rg_metadata)?; - // If reverse scanning is enabled, reverse the prepared plan + // ---------------------------------------------------------- + // Step: potentially reverse the access plan for performance. + // See `ParquetSource::try_pushdown_sort` for the rationale. 
+ // ---------------------------------------------------------- if reverse_row_groups { prepared_plan = prepared_plan.reverse(file_metadata.as_ref())?; } @@ -564,6 +612,9 @@ impl FileOpener for ParquetOpener { // metrics from the arrow reader itself let arrow_reader_metrics = ArrowReaderMetrics::enabled(); + let indices = projection.column_indices(); + let mask = ProjectionMask::roots(builder.parquet_schema(), indices); + let stream = builder .with_projection(mask) .with_batch_size(batch_size) @@ -621,6 +672,9 @@ impl FileOpener for ParquetOpener { }) }); + // ---------------------------------------------------------------------- + // Step: wrap the stream so a dynamic filter can stop the file scan early + // ---------------------------------------------------------------------- if let Some(file_pruner) = file_pruner { Ok(EarlyStoppingStream::new( stream, @@ -639,15 +693,15 @@ impl FileOpener for ParquetOpener { /// arrow-rs parquet reader) to the parquet file metrics for DataFusion fn copy_arrow_reader_metrics( arrow_reader_metrics: &ArrowReaderMetrics, - predicate_cache_inner_records: &Count, - predicate_cache_records: &Count, + predicate_cache_inner_records: &Gauge, + predicate_cache_records: &Gauge, ) { if let Some(v) = arrow_reader_metrics.records_read_from_inner() { - predicate_cache_inner_records.add(v); + predicate_cache_inner_records.set(v); } if let Some(v) = arrow_reader_metrics.records_read_from_cache() { - predicate_cache_records.add(v); + predicate_cache_records.set(v); } } @@ -696,6 +750,10 @@ fn constant_value_from_stats( && !min.is_null() && matches!(column_stats.null_count, Precision::Exact(0)) { + // Cast to the expected data type if needed (e.g., Utf8 -> Dictionary) + if min.data_type() != *data_type { + return min.cast_to(data_type).ok(); + } return Some(min.clone()); } @@ -990,7 +1048,7 @@ mod test { }; use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; use futures::{Stream, StreamExt}; - use object_store::{ObjectStore, 
memory::InMemory, path::Path}; + use object_store::{ObjectStore, ObjectStoreExt, memory::InMemory, path::Path}; use parquet::arrow::ArrowWriter; use parquet::file::properties::WriterProperties; @@ -1016,6 +1074,7 @@ mod test { coerce_int96: Option, max_predicate_cache_size: Option, reverse_row_groups: bool, + preserve_order: bool, } impl ParquetOpenerBuilder { @@ -1041,6 +1100,7 @@ mod test { coerce_int96: None, max_predicate_cache_size: None, reverse_row_groups: false, + preserve_order: false, } } @@ -1148,6 +1208,7 @@ mod test { encryption_factory: None, max_predicate_cache_size: self.max_predicate_cache_size, reverse_row_groups: self.reverse_row_groups, + preserve_order: self.preserve_order, } } } @@ -1684,7 +1745,7 @@ mod test { // Write parquet file with multiple row groups // Force small row groups by setting max_row_group_size let props = WriterProperties::builder() - .set_max_row_group_size(3) // Force each batch into its own row group + .set_max_row_group_row_count(Some(3)) // Force each batch into its own row group .build(); let data_len = write_parquet_batches( @@ -1784,7 +1845,7 @@ mod test { .unwrap(); // 4 rows let props = WriterProperties::builder() - .set_max_row_group_size(4) + .set_max_row_group_row_count(Some(4)) .build(); let data_len = write_parquet_batches( @@ -1871,7 +1932,7 @@ mod test { let batch3 = record_batch!(("a", Int32, vec![Some(7), Some(8)])).unwrap(); let props = WriterProperties::builder() - .set_max_row_group_size(2) + .set_max_row_group_row_count(Some(2)) .build(); let data_len = write_parquet_batches( diff --git a/datafusion/datasource-parquet/src/page_filter.rs b/datafusion/datasource-parquet/src/page_filter.rs index e25e33835f79..194e6e94fba3 100644 --- a/datafusion/datasource-parquet/src/page_filter.rs +++ b/datafusion/datasource-parquet/src/page_filter.rs @@ -189,6 +189,10 @@ impl PagePruningAccessPlanFilter { let mut total_skip = 0; // track the total number of rows that should not be skipped let mut total_select = 0; + // 
track the total number of pages that should be skipped + let mut total_pages_skip = 0; + // track the total number of pages that should not be skipped + let mut total_pages_select = 0; // for each row group specified in the access plan let row_group_indexes = access_plan.row_group_indexes(); @@ -226,10 +230,12 @@ impl PagePruningAccessPlanFilter { file_metrics, ); - let Some(selection) = selection else { + let Some((selection, total_pages, matched_pages)) = selection else { trace!("No pages pruned in prune_pages_in_one_row_group"); continue; }; + total_pages_select += matched_pages; + total_pages_skip += total_pages - matched_pages; debug!( "Use filter and page index to create RowSelection {:?} from predicate: {:?}", @@ -278,6 +284,12 @@ impl PagePruningAccessPlanFilter { file_metrics .page_index_rows_pruned .add_matched(total_select); + file_metrics + .page_index_pages_pruned + .add_pruned(total_pages_skip); + file_metrics + .page_index_pages_pruned + .add_matched(total_pages_select); access_plan } @@ -297,7 +309,8 @@ fn update_selection( } } -/// Returns a [`RowSelection`] for the rows in this row group to scan. +/// Returns a [`RowSelection`] for the rows in this row group to scan, in addition to the number of +/// total and matched pages. /// /// This Row Selection is formed from the page index and the predicate skips row /// ranges that can be ruled out based on the predicate. 
@@ -310,7 +323,7 @@ fn prune_pages_in_one_row_group( converter: StatisticsConverter<'_>, parquet_metadata: &ParquetMetaData, metrics: &ParquetFileMetrics, -) -> Option { +) -> Option<(RowSelection, usize, usize)> { let pruning_stats = PagesPruningStatistics::try_new(row_group_index, converter, parquet_metadata)?; @@ -362,7 +375,11 @@ fn prune_pages_in_one_row_group( RowSelector::skip(sum_row) }; vec.push(selector); - Some(RowSelection::from(vec)) + + let total_pages = values.len(); + let matched_pages = values.iter().filter(|v| **v).count(); + + Some((RowSelection::from(vec), total_pages, matched_pages)) } /// Implement [`PruningStatistics`] for one column's PageIndex (column_index + offset_index) diff --git a/datafusion/datasource-parquet/src/row_filter.rs b/datafusion/datasource-parquet/src/row_filter.rs index ba3b29be40d7..62ba53bb871e 100644 --- a/datafusion/datasource-parquet/src/row_filter.rs +++ b/datafusion/datasource-parquet/src/row_filter.rs @@ -58,6 +58,11 @@ //! 8. Build the `RowFilter` with the sorted predicates followed by //! the unsorted predicates. Within each partition, predicates are //! still be sorted by size. +//! +//! List-aware predicates (for example, `array_has`, `array_has_all`, and +//! `array_has_any`) can be evaluated directly during Parquet decoding. Struct +//! columns and other nested projections that are not explicitly supported will +//! continue to be evaluated after the batches are materialized. 
use std::cmp::Ordering; use std::collections::BTreeSet; @@ -70,6 +75,7 @@ use arrow::record_batch::RecordBatch; use parquet::arrow::ProjectionMask; use parquet::arrow::arrow_reader::{ArrowPredicate, RowFilter}; use parquet::file::metadata::ParquetMetaData; +use parquet::schema::types::SchemaDescriptor; use datafusion_common::Result; use datafusion_common::cast::as_boolean_array; @@ -81,6 +87,7 @@ use datafusion_physical_expr::{PhysicalExpr, split_conjunction}; use datafusion_physical_plan::metrics; use super::ParquetFileMetrics; +use super::supported_predicates::supports_list_predicates; /// A "compiled" predicate passed to `ParquetRecordBatchStream` to perform /// row-level filtering during parquet decoding. @@ -91,12 +98,14 @@ use super::ParquetFileMetrics; /// /// An expression can be evaluated as a `DatafusionArrowPredicate` if it: /// * Does not reference any projected columns -/// * Does not reference columns with non-primitive types (e.g. structs / lists) +/// * References either primitive columns or list columns used by +/// supported predicates (such as `array_has_all` or NULL checks). Struct +/// columns are still evaluated after decoding. #[derive(Debug)] pub(crate) struct DatafusionArrowPredicate { /// the filter expression physical_expr: Arc, - /// Path to the columns in the parquet schema required to evaluate the + /// Path to the leaf columns in the parquet schema required to evaluate the /// expression projection_mask: ProjectionMask, /// how many rows were filtered out by this predicate @@ -121,9 +130,12 @@ impl DatafusionArrowPredicate { Ok(Self { physical_expr, - projection_mask: ProjectionMask::roots( + // Use leaf indices: when nested columns are involved, we must specify + // leaf (primitive) column indices in the Parquet schema so the decoder + // can properly project and filter nested structures. 
+ projection_mask: ProjectionMask::leaves( metadata.file_metadata().schema_descr(), - candidate.projection, + candidate.projection.leaf_indices.iter().copied(), ), rows_pruned, rows_matched, @@ -177,12 +189,23 @@ pub(crate) struct FilterCandidate { /// Can this filter use an index (e.g. a page index) to prune rows? can_use_index: bool, /// Column indices into the parquet file schema required to evaluate this filter. - projection: Vec, + projection: LeafProjection, /// The Arrow schema containing only the columns required by this filter, /// projected from the file's Arrow schema. filter_schema: SchemaRef, } +/// Projection specification for nested columns using Parquet leaf column indices. +/// +/// For nested types like List and Struct, Parquet stores data in leaf columns +/// (the primitive fields). This struct tracks which leaf columns are needed +/// to evaluate a filter expression. +#[derive(Debug, Clone)] +struct LeafProjection { + /// Leaf column indices in the Parquet schema descriptor. + leaf_indices: Vec, +} + /// Helper to build a `FilterCandidate`. /// /// This will do several things: @@ -212,23 +235,29 @@ impl FilterCandidateBuilder { /// * `Ok(None)` if the expression cannot be used as an ArrowFilter /// * `Err(e)` if an error occurs while building the candidate pub fn build(self, metadata: &ParquetMetaData) -> Result> { - let Some(required_column_indices) = - pushdown_columns(&self.expr, &self.file_schema)? + let Some(required_columns) = pushdown_columns(&self.expr, &self.file_schema)? 
else { return Ok(None); }; - let projected_schema = - Arc::new(self.file_schema.project(&required_column_indices)?); + let root_indices: Vec<_> = + required_columns.required_columns.into_iter().collect(); - let required_bytes = size_of_columns(&required_column_indices, metadata)?; - let can_use_index = columns_sorted(&required_column_indices, metadata)?; + let leaf_indices = leaf_indices_for_roots( + &root_indices, + metadata.file_metadata().schema_descr(), + ); + + let projected_schema = Arc::new(self.file_schema.project(&root_indices)?); + + let required_bytes = size_of_columns(&leaf_indices, metadata)?; + let can_use_index = columns_sorted(&leaf_indices, metadata)?; Ok(Some(FilterCandidate { expr: self.expr, required_bytes, can_use_index, - projection: required_column_indices, + projection: LeafProjection { leaf_indices }, filter_schema: projected_schema, })) } @@ -238,7 +267,8 @@ impl FilterCandidateBuilder { /// prevent the expression from being pushed down to the parquet decoder. /// /// An expression cannot be pushed down if it references: -/// - Non-primitive columns (like structs or lists) +/// - Unsupported nested columns (structs or list fields that are not covered by +/// the supported predicate set) /// - Columns that don't exist in the file schema struct PushdownChecker<'schema> { /// Does the expression require any non-primitive columns (like structs)? @@ -246,41 +276,92 @@ struct PushdownChecker<'schema> { /// Does the expression reference any columns not present in the file schema? projected_columns: bool, /// Indices into the file schema of columns required to evaluate the expression. - required_columns: BTreeSet, + required_columns: Vec, + /// Whether nested list columns are supported by the predicate semantics. + allow_list_columns: bool, /// The Arrow schema of the parquet file. 
file_schema: &'schema Schema, } impl<'schema> PushdownChecker<'schema> { - fn new(file_schema: &'schema Schema) -> Self { + fn new(file_schema: &'schema Schema, allow_list_columns: bool) -> Self { Self { non_primitive_columns: false, projected_columns: false, - required_columns: BTreeSet::default(), + required_columns: Vec::new(), + allow_list_columns, file_schema, } } fn check_single_column(&mut self, column_name: &str) -> Option { - if let Ok(idx) = self.file_schema.index_of(column_name) { - self.required_columns.insert(idx); - if DataType::is_nested(self.file_schema.field(idx).data_type()) { - self.non_primitive_columns = true; + let idx = match self.file_schema.index_of(column_name) { + Ok(idx) => idx, + Err(_) => { + // Column does not exist in the file schema, so we can't push this down. + self.projected_columns = true; return Some(TreeNodeRecursion::Jump); } + }; + + // Duplicates are handled by dedup() in into_sorted_columns() + self.required_columns.push(idx); + let data_type = self.file_schema.field(idx).data_type(); + + if DataType::is_nested(data_type) { + self.handle_nested_type(data_type) + } else { + None + } + } + + /// Determines whether a nested data type can be pushed down to Parquet decoding. + /// + /// Returns `Some(TreeNodeRecursion::Jump)` if the nested type prevents pushdown, + /// `None` if the type is supported and pushdown can continue. + fn handle_nested_type(&mut self, data_type: &DataType) -> Option { + if self.is_nested_type_supported(data_type) { + None } else { - // Column does not exist in the file schema, so we can't push this down. - self.projected_columns = true; - return Some(TreeNodeRecursion::Jump); + // Block pushdown for unsupported nested types: + // - Structs (regardless of predicate support) + // - Lists without supported predicates + self.non_primitive_columns = true; + Some(TreeNodeRecursion::Jump) } + } - None + /// Checks if a nested data type is supported for list column pushdown. 
+    ///
+    /// List columns are only supported if:
+    /// 1. The data type is a list variant (List, LargeList, or FixedSizeList)
+    /// 2. The expression contains supported list predicates (e.g., array_has_all)
+    fn is_nested_type_supported(&self, data_type: &DataType) -> bool {
+        let is_list = matches!(
+            data_type,
+            DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _)
+        );
+        self.allow_list_columns && is_list
     }
 
     #[inline]
     fn prevents_pushdown(&self) -> bool {
         self.non_primitive_columns || self.projected_columns
     }
+
+    /// Consumes the checker and returns sorted, deduplicated column indices
+    /// wrapped in a `PushdownColumns` struct.
+    ///
+    /// This method sorts the column indices and removes duplicates. The sort
+    /// is required because downstream code relies on column indices being in
+    /// ascending order for correct schema projection.
+    fn into_sorted_columns(mut self) -> PushdownColumns {
+        self.required_columns.sort_unstable();
+        self.required_columns.dedup();
+        PushdownColumns {
+            required_columns: self.required_columns,
+        }
+    }
 }
 
 impl TreeNodeVisitor<'_> for PushdownChecker<'_> {
@@ -297,34 +378,121 @@
     }
 }
 
+/// Result of checking which columns are required for filter pushdown.
+///
+/// Wraps the root column indices (into the file schema) that a filter
+/// expression needs; produced by [`pushdown_columns`] when the
+/// expression is eligible for pushdown to the Parquet decoder.
+#[derive(Debug)]
+struct PushdownColumns {
+    /// Sorted, unique column indices into the file schema required to evaluate
+    /// the filter expression. Must be in ascending order for correct schema
+    /// projection matching.
+    required_columns: Vec<usize>,
+}
+
 /// Checks if a given expression can be pushed down to the parquet decoder.
/// -/// Returns `Some(column_indices)` if the expression can be pushed down, -/// where `column_indices` are the indices into the file schema of all columns +/// Returns `Some(PushdownColumns)` if the expression can be pushed down, +/// where the struct contains the indices into the file schema of all columns /// required to evaluate the expression. /// /// Returns `None` if the expression cannot be pushed down (e.g., references -/// non-primitive types or columns not in the file). +/// unsupported nested types or columns not in the file). fn pushdown_columns( expr: &Arc, file_schema: &Schema, -) -> Result>> { - let mut checker = PushdownChecker::new(file_schema); +) -> Result> { + let allow_list_columns = supports_list_predicates(expr); + let mut checker = PushdownChecker::new(file_schema, allow_list_columns); expr.visit(&mut checker)?; - Ok((!checker.prevents_pushdown()) - .then_some(checker.required_columns.into_iter().collect())) + Ok((!checker.prevents_pushdown()).then(|| checker.into_sorted_columns())) +} + +fn leaf_indices_for_roots( + root_indices: &[usize], + schema_descr: &SchemaDescriptor, +) -> Vec { + // Always map root (Arrow) indices to Parquet leaf indices via the schema + // descriptor. Arrow root indices only equal Parquet leaf indices when the + // schema has no group columns (Struct, Map, etc.); when group columns + // exist, their children become separate leaves and shift all subsequent + // leaf indices. + // Struct columns are unsupported. + let root_set: BTreeSet<_> = root_indices.iter().copied().collect(); + + (0..schema_descr.num_columns()) + .filter(|leaf_idx| { + root_set.contains(&schema_descr.get_column_root_idx(*leaf_idx)) + }) + .collect() } /// Checks if a predicate expression can be pushed down to the parquet decoder. /// /// Returns `true` if all columns referenced by the expression: /// - Exist in the provided schema -/// - Are primitive types (not structs, lists, etc.) 
+/// - Are primitive types OR list columns with supported predicates +/// (e.g., `array_has`, `array_has_all`, `array_has_any`, IS NULL, IS NOT NULL) +/// - Struct columns are not supported and will prevent pushdown /// /// # Arguments /// * `expr` - The filter expression to check /// * `file_schema` - The Arrow schema of the parquet file (or table schema when /// the file schema is not yet available during planning) +/// +/// # Examples +/// +/// Primitive column filters can be pushed down: +/// ```ignore +/// use datafusion_expr::{col, Expr}; +/// use datafusion_common::ScalarValue; +/// use arrow::datatypes::{DataType, Field, Schema}; +/// use std::sync::Arc; +/// +/// let schema = Arc::new(Schema::new(vec![ +/// Field::new("age", DataType::Int32, false), +/// ])); +/// +/// // Primitive filter: can be pushed down +/// let expr = col("age").gt(Expr::Literal(ScalarValue::Int32(Some(30)), None)); +/// let expr = logical2physical(&expr, &schema); +/// assert!(can_expr_be_pushed_down_with_schemas(&expr, &schema)); +/// ``` +/// +/// Struct column filters cannot be pushed down: +/// ```ignore +/// use arrow::datatypes::Fields; +/// +/// let schema = Arc::new(Schema::new(vec![ +/// Field::new("person", DataType::Struct( +/// Fields::from(vec![Field::new("name", DataType::Utf8, true)]) +/// ), true), +/// ])); +/// +/// // Struct filter: cannot be pushed down +/// let expr = col("person").is_not_null(); +/// let expr = logical2physical(&expr, &schema); +/// assert!(!can_expr_be_pushed_down_with_schemas(&expr, &schema)); +/// ``` +/// +/// List column filters with supported predicates can be pushed down: +/// ```ignore +/// use datafusion_functions_nested::expr_fn::{array_has_all, make_array}; +/// +/// let schema = Arc::new(Schema::new(vec![ +/// Field::new("tags", DataType::List( +/// Arc::new(Field::new("item", DataType::Utf8, true)) +/// ), true), +/// ])); +/// +/// // Array filter with supported predicate: can be pushed down +/// let expr = 
array_has_all(col("tags"), make_array(vec![ +/// Expr::Literal(ScalarValue::Utf8(Some("rust".to_string())), None) +/// ])); +/// let expr = logical2physical(&expr, &schema); +/// assert!(can_expr_be_pushed_down_with_schemas(&expr, &schema)); +/// ``` pub fn can_expr_be_pushed_down_with_schemas( expr: &Arc, file_schema: &Schema, @@ -335,7 +503,7 @@ pub fn can_expr_be_pushed_down_with_schemas( } } -/// Calculate the total compressed size of all `Column`'s required for +/// Calculate the total compressed size of all leaf columns required for /// predicate `Expr`. /// /// This value represents the total amount of IO required to evaluate the @@ -464,21 +632,30 @@ mod test { use super::*; use datafusion_common::ScalarValue; + use arrow::array::{ListBuilder, StringBuilder}; use arrow::datatypes::{Field, TimeUnit::Nanosecond}; use datafusion_expr::{Expr, col}; + use datafusion_functions_nested::array_has::{ + array_has_all_udf, array_has_any_udf, array_has_udf, + }; + use datafusion_functions_nested::expr_fn::{ + array_has, array_has_all, array_has_any, make_array, + }; use datafusion_physical_expr::planner::logical2physical; use datafusion_physical_expr_adapter::{ DefaultPhysicalExprAdapterFactory, PhysicalExprAdapterFactory, }; - use datafusion_physical_plan::metrics::{Count, Time}; + use datafusion_physical_plan::metrics::{Count, ExecutionPlanMetricsSet, Time}; + use parquet::arrow::ArrowWriter; use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; use parquet::arrow::parquet_to_arrow_schema; use parquet::file::reader::{FileReader, SerializedFileReader}; + use tempfile::NamedTempFile; - // We should ignore predicate that read non-primitive columns + // List predicates used by the decoder should be accepted for pushdown #[test] - fn test_filter_candidate_builder_ignore_complex_types() { + fn test_filter_candidate_builder_supports_list_types() { let testdata = datafusion_common::test_util::parquet_test_data(); let file = 
std::fs::File::open(format!("{testdata}/list_columns.parquet")) .expect("opening file"); @@ -496,11 +673,16 @@ mod test { let table_schema = Arc::new(table_schema.clone()); + let list_index = table_schema + .index_of("int64_list") + .expect("list column should exist"); + let candidate = FilterCandidateBuilder::new(expr, table_schema) .build(metadata) - .expect("building candidate"); + .expect("building candidate") + .expect("list pushdown should be supported"); - assert!(candidate.is_none()); + assert_eq!(candidate.projection.leaf_indices, vec![list_index]); } #[test] @@ -530,6 +712,7 @@ mod test { let expr = logical2physical(&expr, &table_schema); let expr = DefaultPhysicalExprAdapterFactory {} .create(Arc::new(table_schema.clone()), Arc::clone(&file_schema)) + .expect("creating expr adapter") .rewrite(expr) .expect("rewriting expression"); let candidate = FilterCandidateBuilder::new(expr, file_schema.clone()) @@ -569,6 +752,7 @@ mod test { // Rewrite the expression to add CastExpr for type coercion let expr = DefaultPhysicalExprAdapterFactory {} .create(Arc::new(table_schema), Arc::clone(&file_schema)) + .expect("creating expr adapter") .rewrite(expr) .expect("rewriting expression"); let candidate = FilterCandidateBuilder::new(expr, file_schema) @@ -590,14 +774,233 @@ mod test { } #[test] - fn nested_data_structures_prevent_pushdown() { + fn struct_data_structures_prevent_pushdown() { + let table_schema = Arc::new(Schema::new(vec![Field::new( + "struct_col", + DataType::Struct( + vec![Arc::new(Field::new("a", DataType::Int32, true))].into(), + ), + true, + )])); + + let expr = col("struct_col").is_not_null(); + let expr = logical2physical(&expr, &table_schema); + + assert!(!can_expr_be_pushed_down_with_schemas(&expr, &table_schema)); + } + + #[test] + fn mixed_primitive_and_struct_prevents_pushdown() { + // Even when a predicate contains both primitive and unsupported nested columns, + // the entire predicate should not be pushed down because the struct column + 
// cannot be evaluated during Parquet decoding. + let table_schema = Arc::new(Schema::new(vec![ + Field::new( + "struct_col", + DataType::Struct( + vec![Arc::new(Field::new("a", DataType::Int32, true))].into(), + ), + true, + ), + Field::new("int_col", DataType::Int32, false), + ])); + + // Expression: (struct_col IS NOT NULL) AND (int_col = 5) + // Even though int_col is primitive, the presence of struct_col in the + // conjunction should prevent pushdown of the entire expression. + let expr = col("struct_col") + .is_not_null() + .and(col("int_col").eq(Expr::Literal(ScalarValue::Int32(Some(5)), None))); + let expr = logical2physical(&expr, &table_schema); + + // The entire expression should not be pushed down + assert!(!can_expr_be_pushed_down_with_schemas(&expr, &table_schema)); + + // However, just the int_col predicate alone should be pushable + let expr_int_only = + col("int_col").eq(Expr::Literal(ScalarValue::Int32(Some(5)), None)); + let expr_int_only = logical2physical(&expr_int_only, &table_schema); + assert!(can_expr_be_pushed_down_with_schemas( + &expr_int_only, + &table_schema + )); + } + + #[test] + fn nested_lists_allow_pushdown_checks() { let table_schema = Arc::new(get_lists_table_schema()); let expr = col("utf8_list").is_not_null(); let expr = logical2physical(&expr, &table_schema); check_expression_can_evaluate_against_schema(&expr, &table_schema); - assert!(!can_expr_be_pushed_down_with_schemas(&expr, &table_schema)); + assert!(can_expr_be_pushed_down_with_schemas(&expr, &table_schema)); + } + + #[test] + fn array_has_all_pushdown_filters_rows() { + // Test array_has_all: checks if array contains all of ["c"] + // Rows with "c": row 1 and row 2 + let expr = array_has_all( + col("letters"), + make_array(vec![Expr::Literal( + ScalarValue::Utf8(Some("c".to_string())), + None, + )]), + ); + test_array_predicate_pushdown("array_has_all", expr, 1, 2, true); + } + + /// Helper function to test array predicate pushdown functionality. 
+ /// + /// Creates a Parquet file with a list column, applies the given predicate, + /// and verifies that rows are correctly filtered during decoding. + fn test_array_predicate_pushdown( + func_name: &str, + predicate_expr: Expr, + expected_pruned: usize, + expected_matched: usize, + expect_list_support: bool, + ) { + let item_field = Arc::new(Field::new("item", DataType::Utf8, true)); + let schema = Arc::new(Schema::new(vec![Field::new( + "letters", + DataType::List(item_field), + true, + )])); + + let mut builder = ListBuilder::new(StringBuilder::new()); + // Row 0: ["a", "b"] + builder.values().append_value("a"); + builder.values().append_value("b"); + builder.append(true); + + // Row 1: ["c"] + builder.values().append_value("c"); + builder.append(true); + + // Row 2: ["c", "d"] + builder.values().append_value("c"); + builder.values().append_value("d"); + builder.append(true); + + let batch = + RecordBatch::try_new(schema.clone(), vec![Arc::new(builder.finish())]) + .expect("record batch"); + + let file = NamedTempFile::new().expect("temp file"); + let mut writer = + ArrowWriter::try_new(file.reopen().unwrap(), schema, None).expect("writer"); + writer.write(&batch).expect("write batch"); + writer.close().expect("close writer"); + + let reader_file = file.reopen().expect("reopen file"); + let parquet_reader_builder = + ParquetRecordBatchReaderBuilder::try_new(reader_file) + .expect("reader builder"); + let metadata = parquet_reader_builder.metadata().clone(); + let file_schema = parquet_reader_builder.schema().clone(); + + let expr = logical2physical(&predicate_expr, &file_schema); + if expect_list_support { + assert!(supports_list_predicates(&expr)); + } + + let metrics = ExecutionPlanMetricsSet::new(); + let file_metrics = + ParquetFileMetrics::new(0, &format!("{func_name}.parquet"), &metrics); + + let row_filter = + build_row_filter(&expr, &file_schema, &metadata, false, &file_metrics) + .expect("building row filter") + .expect("row filter should exist"); + 
+ let reader = parquet_reader_builder + .with_row_filter(row_filter) + .build() + .expect("build reader"); + + let mut total_rows = 0; + for batch in reader { + let batch = batch.expect("record batch"); + total_rows += batch.num_rows(); + } + + assert_eq!( + file_metrics.pushdown_rows_pruned.value(), + expected_pruned, + "{func_name}: expected {expected_pruned} pruned rows" + ); + assert_eq!( + file_metrics.pushdown_rows_matched.value(), + expected_matched, + "{func_name}: expected {expected_matched} matched rows" + ); + assert_eq!( + total_rows, expected_matched, + "{func_name}: expected {expected_matched} total rows" + ); + } + + #[test] + fn array_has_pushdown_filters_rows() { + // Test array_has: checks if "c" is in the array + // Rows with "c": row 1 and row 2 + let expr = array_has( + col("letters"), + Expr::Literal(ScalarValue::Utf8(Some("c".to_string())), None), + ); + test_array_predicate_pushdown("array_has", expr, 1, 2, true); + } + + #[test] + fn array_has_any_pushdown_filters_rows() { + // Test array_has_any: checks if array contains any of ["a", "d"] + // Row 0 has "a", row 2 has "d" - both should match + let expr = array_has_any( + col("letters"), + make_array(vec![ + Expr::Literal(ScalarValue::Utf8(Some("a".to_string())), None), + Expr::Literal(ScalarValue::Utf8(Some("d".to_string())), None), + ]), + ); + test_array_predicate_pushdown("array_has_any", expr, 1, 2, true); + } + + #[test] + fn array_has_udf_pushdown_filters_rows() { + let expr = array_has_udf().call(vec![ + col("letters"), + Expr::Literal(ScalarValue::Utf8(Some("c".to_string())), None), + ]); + + test_array_predicate_pushdown("array_has_udf", expr, 1, 2, true); + } + + #[test] + fn array_has_all_udf_pushdown_filters_rows() { + let expr = array_has_all_udf().call(vec![ + col("letters"), + make_array(vec![Expr::Literal( + ScalarValue::Utf8(Some("c".to_string())), + None, + )]), + ]); + + test_array_predicate_pushdown("array_has_all_udf", expr, 1, 2, true); + } + + #[test] + fn 
array_has_any_udf_pushdown_filters_rows() { + let expr = array_has_any_udf().call(vec![ + col("letters"), + make_array(vec![ + Expr::Literal(ScalarValue::Utf8(Some("a".to_string())), None), + Expr::Literal(ScalarValue::Utf8(Some("d".to_string())), None), + ]), + ]); + + test_array_predicate_pushdown("array_has_any_udf", expr, 1, 2, true); } #[test] @@ -658,6 +1061,91 @@ mod test { .expect("parsing schema") } + /// Regression test: when a schema has Struct columns, Arrow field indices diverge + /// from Parquet leaf indices (Struct children become separate leaves). The + /// `PrimitiveOnly` fast-path in `leaf_indices_for_roots` assumes they are equal, + /// so a filter on a primitive column *after* a Struct gets the wrong leaf index. + /// + /// Schema: + /// Arrow indices: col_a=0 struct_col=1 col_b=2 + /// Parquet leaves: col_a=0 struct_col.x=1 struct_col.y=2 col_b=3 + /// + /// A filter on col_b should project Parquet leaf 3, but the bug causes it to + /// project leaf 2 (struct_col.y). 
+ #[test] + fn test_filter_pushdown_leaf_index_with_struct_in_schema() { + use arrow::array::{Int32Array, StringArray, StructArray}; + + let schema = Arc::new(Schema::new(vec![ + Field::new("col_a", DataType::Int32, false), + Field::new( + "struct_col", + DataType::Struct( + vec![ + Arc::new(Field::new("x", DataType::Int32, true)), + Arc::new(Field::new("y", DataType::Int32, true)), + ] + .into(), + ), + true, + ), + Field::new("col_b", DataType::Utf8, false), + ])); + + let col_a = Arc::new(Int32Array::from(vec![1, 2, 3])); + let struct_col = Arc::new(StructArray::from(vec![ + ( + Arc::new(Field::new("x", DataType::Int32, true)), + Arc::new(Int32Array::from(vec![10, 20, 30])) as _, + ), + ( + Arc::new(Field::new("y", DataType::Int32, true)), + Arc::new(Int32Array::from(vec![100, 200, 300])) as _, + ), + ])); + let col_b = Arc::new(StringArray::from(vec!["aaa", "target", "zzz"])); + + let batch = + RecordBatch::try_new(Arc::clone(&schema), vec![col_a, struct_col, col_b]) + .unwrap(); + + let file = NamedTempFile::new().expect("temp file"); + let mut writer = + ArrowWriter::try_new(file.reopen().unwrap(), Arc::clone(&schema), None) + .expect("writer"); + writer.write(&batch).expect("write batch"); + writer.close().expect("close writer"); + + let reader_file = file.reopen().expect("reopen file"); + let builder = ParquetRecordBatchReaderBuilder::try_new(reader_file) + .expect("reader builder"); + let metadata = builder.metadata().clone(); + let file_schema = builder.schema().clone(); + + // sanity check: 4 Parquet leaves, 3 Arrow fields + assert_eq!(metadata.file_metadata().schema_descr().num_columns(), 4); + assert_eq!(file_schema.fields().len(), 3); + + // build a filter candidate for `col_b = 'target'` through the public API + let expr = col("col_b").eq(Expr::Literal( + ScalarValue::Utf8(Some("target".to_string())), + None, + )); + let expr = logical2physical(&expr, &file_schema); + + let candidate = FilterCandidateBuilder::new(expr, file_schema) + 
.build(&metadata) + .expect("building candidate") + .expect("filter on primitive col_b should be pushable"); + + // col_b is Parquet leaf 3 (shifted by struct_col's two children). + assert_eq!( + candidate.projection.leaf_indices, + vec![3], + "leaf_indices should be [3] for col_b" + ); + } + /// Sanity check that the given expression could be evaluated against the given schema without any errors. /// This will fail if the expression references columns that are not in the schema or if the types of the columns are incompatible, etc. fn check_expression_can_evaluate_against_schema( diff --git a/datafusion/datasource-parquet/src/row_group_filter.rs b/datafusion/datasource-parquet/src/row_group_filter.rs index 1264197609f3..932988af051e 100644 --- a/datafusion/datasource-parquet/src/row_group_filter.rs +++ b/datafusion/datasource-parquet/src/row_group_filter.rs @@ -24,6 +24,8 @@ use arrow::datatypes::Schema; use datafusion_common::pruning::PruningStatistics; use datafusion_common::{Column, Result, ScalarValue}; use datafusion_datasource::FileRange; +use datafusion_physical_expr::PhysicalExprSimplifier; +use datafusion_physical_expr::expressions::NotExpr; use datafusion_pruning::PruningPredicate; use parquet::arrow::arrow_reader::statistics::StatisticsConverter; use parquet::arrow::parquet_column; @@ -46,13 +48,20 @@ use parquet::{ pub struct RowGroupAccessPlanFilter { /// which row groups should be accessed access_plan: ParquetAccessPlan, + /// Row groups where ALL rows are known to match the pruning predicate + /// (the predicate does not filter any rows) + is_fully_matched: Vec, } impl RowGroupAccessPlanFilter { /// Create a new `RowGroupPlanBuilder` for pruning out the groups to scan /// based on metadata and statistics pub fn new(access_plan: ParquetAccessPlan) -> Self { - Self { access_plan } + let num_row_groups = access_plan.len(); + Self { + access_plan, + is_fully_matched: vec![false; num_row_groups], + } } /// Return true if there are no row groups @@ -70,6 
+79,139 @@ impl RowGroupAccessPlanFilter { self.access_plan } + /// Returns the is_fully_matched vector + pub fn is_fully_matched(&self) -> &Vec { + &self.is_fully_matched + } + + /// Prunes the access plan based on the limit and fully contained row groups. + /// + /// The pruning works by leveraging the concept of fully matched row groups. Consider a query like: + /// `WHERE species LIKE 'Alpine%' AND s >= 50 LIMIT N` + /// + /// After initial filtering, row groups can be classified into three states: + /// + /// 1. Not Matching / Pruned + /// 2. Partially Matching (Row Group/Page contains some matches) + /// 3. Fully Matching (Entire range is within predicate) + /// + /// +-----------------------------------------------------------------------+ + /// | NOT MATCHING | + /// | Row group 1 | + /// | +-----------------------------------+-----------------------------+ | + /// | | SPECIES | S | | + /// | +-----------------------------------+-----------------------------+ | + /// | | Snow Vole | 7 | | + /// | | Brown Bear | 133 ✅ | | + /// | | Gray Wolf | 82 ✅ | | + /// | +-----------------------------------+-----------------------------+ | + /// +-----------------------------------------------------------------------+ + /// + /// +---------------------------------------------------------------------------+ + /// | PARTIALLY MATCHING | + /// | | + /// | Row group 2 Row group 4 | + /// | +------------------+--------------+ +------------------+----------+ | + /// | | SPECIES | S | | SPECIES | S | | + /// | +------------------+--------------+ +------------------+----------+ | + /// | | Lynx | 71 ✅ | | Europ. 
Mole | 4 | | + /// | | Red Fox | 40 | | Polecat | 16 | | + /// | | Alpine Bat ✅ | 6 | | Alpine Ibex ✅ | 97 ✅ | | + /// | +------------------+--------------+ +------------------+----------+ | + /// +---------------------------------------------------------------------------+ + /// + /// +-----------------------------------------------------------------------+ + /// | FULLY MATCHING | + /// | Row group 3 | + /// | +-----------------------------------+-----------------------------+ | + /// | | SPECIES | S | | + /// | +-----------------------------------+-----------------------------+ | + /// | | Alpine Ibex ✅ | 101 ✅ | | + /// | | Alpine Goat ✅ | 76 ✅ | | + /// | | Alpine Sheep ✅ | 83 ✅ | | + /// | +-----------------------------------+-----------------------------+ | + /// +-----------------------------------------------------------------------+ + /// + /// ### Identification of Fully Matching Row Groups + /// + /// DataFusion identifies row groups where ALL rows satisfy the filter by inverting the + /// predicate and checking if statistics prove the inverted version is false for the group. + /// + /// For example, prefix matches like `species LIKE 'Alpine%'` are pruned using ranges: + /// 1. Candidate Range: `species >= 'Alpine' AND species < 'Alpinf'` + /// 2. Inverted Condition (to prove full match): `species < 'Alpine' OR species >= 'Alpinf'` + /// 3. Statistical Evaluation (check if any row *could* satisfy the inverted condition): + /// `min < 'Alpine' OR max >= 'Alpinf'` + /// + /// If this evaluation is **false**, it proves no row can fail the original filter, + /// so the row group is **FULLY MATCHING**. + /// + /// ### Impact of Statistics Truncation + /// + /// The precision of pruning depends on the metadata quality. Truncated statistics + /// may prevent the system from proving a full match. 
+ /// + /// **Example**: `WHERE species LIKE 'Alpine%'` (Target range: `['Alpine', 'Alpinf')`) + /// + /// | Truncation Length | min / max | Inverted Evaluation | Status | + /// |-------------------|---------------------|---------------------------------------------------------------------|------------------------| + /// | **Length 6** | `Alpine` / `Alpine` | `"Alpine" < "Alpine" (F) OR "Alpine" >= "Alpinf" (F)` -> **false** | **FULLY MATCHING** | + /// | **Length 3** | `Alp` / `Alq` | `"Alp" < "Alpine" (T) OR "Alq" >= "Alpinf" (T)` -> **true** | **PARTIALLY MATCHING** | + /// + /// Even though Row Group 3 only contains matching rows, truncation to length 3 makes + /// the statistics `[Alp, Alq]` too broad to prove it (they could include "Alpha"). + /// The system must conservatively scan the group. + /// + /// Without limit pruning: Scan Partition 2 → Partition 3 → Partition 4 (until limit reached) + /// With limit pruning: If Partition 3 contains enough rows to satisfy the limit, + /// skip Partitions 2 and 4 entirely and go directly to Partition 3. 
+ /// + /// This optimization is particularly effective when: + /// - The limit is small relative to the total dataset size + /// - There are row groups that are fully matched by the filter predicates + /// - The fully matched row groups contain sufficient rows to satisfy the limit + /// + /// For more information, see the [paper](https://arxiv.org/pdf/2504.11540)'s "Pruning for LIMIT Queries" part + pub fn prune_by_limit( + &mut self, + limit: usize, + rg_metadata: &[RowGroupMetaData], + metrics: &ParquetFileMetrics, + ) { + let mut fully_matched_row_group_indexes: Vec = Vec::new(); + let mut fully_matched_rows_count: usize = 0; + + // Iterate through the currently accessible row groups and try to + // find a set of matching row groups that can satisfy the limit + for &idx in self.access_plan.row_group_indexes().iter() { + if self.is_fully_matched[idx] { + let row_group_row_count = rg_metadata[idx].num_rows() as usize; + fully_matched_row_group_indexes.push(idx); + fully_matched_rows_count += row_group_row_count; + if fully_matched_rows_count >= limit { + break; + } + } + } + + // If we can satisfy the limit with fully matching row groups, + // rewrite the plan to do so + if fully_matched_rows_count >= limit { + let original_num_accessible_row_groups = + self.access_plan.row_group_indexes().len(); + let new_num_accessible_row_groups = fully_matched_row_group_indexes.len(); + let pruned_count = original_num_accessible_row_groups + .saturating_sub(new_num_accessible_row_groups); + metrics.limit_pruned_row_groups.add_pruned(pruned_count); + + let mut new_access_plan = ParquetAccessPlan::new_none(rg_metadata.len()); + for &idx in &fully_matched_row_group_indexes { + new_access_plan.scan(idx); + } + self.access_plan = new_access_plan; + } + } + /// Prune remaining row groups to only those within the specified range. 
/// /// Updates this set to mark row groups that should not be scanned @@ -135,15 +277,26 @@ impl RowGroupAccessPlanFilter { // try to prune the row groups in a single call match predicate.prune(&pruning_stats) { Ok(values) => { - // values[i] is false means the predicate could not be true for row group i + let mut fully_contained_candidates_original_idx: Vec = Vec::new(); for (idx, &value) in row_group_indexes.iter().zip(values.iter()) { if !value { self.access_plan.skip(*idx); metrics.row_groups_pruned_statistics.add_pruned(1); } else { metrics.row_groups_pruned_statistics.add_matched(1); + fully_contained_candidates_original_idx.push(*idx); } } + + // Check if any of the matched row groups are fully contained by the predicate + self.identify_fully_matched_row_groups( + &fully_contained_candidates_original_idx, + arrow_schema, + parquet_schema, + groups, + predicate, + metrics, + ); } // stats filter array could not be built, so we can't prune Err(e) => { @@ -153,6 +306,68 @@ impl RowGroupAccessPlanFilter { } } + /// Identifies row groups that are fully matched by the predicate. + /// + /// This optimization checks whether all rows in a row group satisfy the predicate + /// by inverting the predicate and checking if it prunes the row group. If the + /// inverted predicate prunes a row group, it means no rows match the inverted + /// predicate, which implies all rows match the original predicate. + /// + /// Note: This optimization is relatively inexpensive for a limited number of row groups. 
+ fn identify_fully_matched_row_groups( + &mut self, + candidate_row_group_indices: &[usize], + arrow_schema: &Schema, + parquet_schema: &SchemaDescriptor, + groups: &[RowGroupMetaData], + predicate: &PruningPredicate, + metrics: &ParquetFileMetrics, + ) { + if candidate_row_group_indices.is_empty() { + return; + } + + // Use NotExpr to create the inverted predicate + let inverted_expr = Arc::new(NotExpr::new(Arc::clone(predicate.orig_expr()))); + + // Simplify the NOT expression (e.g., NOT(c1 = 0) -> c1 != 0) + // before building the pruning predicate + let simplifier = PhysicalExprSimplifier::new(arrow_schema); + let Ok(inverted_expr) = simplifier.simplify(inverted_expr) else { + return; + }; + + let Ok(inverted_predicate) = + PruningPredicate::try_new(inverted_expr, Arc::clone(predicate.schema())) + else { + return; + }; + + let inverted_pruning_stats = RowGroupPruningStatistics { + parquet_schema, + row_group_metadatas: candidate_row_group_indices + .iter() + .map(|&i| &groups[i]) + .collect::>(), + arrow_schema, + }; + + let Ok(inverted_values) = inverted_predicate.prune(&inverted_pruning_stats) + else { + return; + }; + + for (i, &original_row_group_idx) in candidate_row_group_indices.iter().enumerate() + { + // If the inverted predicate *also* prunes this row group (meaning inverted_values[i] is false), + // it implies that *all* rows in this group satisfy the original predicate. + if !inverted_values[i] { + self.is_fully_matched[original_row_group_idx] = true; + metrics.row_groups_pruned_statistics.add_fully_matched(1); + } + } + } + /// Prune remaining row groups using available bloom filters and the /// [`PruningPredicate`]. 
/// @@ -447,6 +662,7 @@ mod tests { use datafusion_expr::{Expr, cast, col, lit}; use datafusion_physical_expr::planner::logical2physical; use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; + use object_store::ObjectStoreExt; use parquet::arrow::ArrowSchemaConverter; use parquet::arrow::async_reader::ParquetObjectReader; use parquet::basic::LogicalType; @@ -1537,7 +1753,7 @@ mod tests { pruning_predicate: &PruningPredicate, ) -> Result { use datafusion_datasource::PartitionedFile; - use object_store::{ObjectMeta, ObjectStore}; + use object_store::ObjectMeta; let object_meta = ObjectMeta { location: object_store::path::Path::parse(file_name).expect("creating path"), @@ -1559,14 +1775,7 @@ mod tests { ParquetObjectReader::new(Arc::new(in_memory), object_meta.location.clone()) .with_file_size(object_meta.size); - let partitioned_file = PartitionedFile { - object_meta, - partition_values: vec![], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - }; + let partitioned_file = PartitionedFile::new_from_meta(object_meta); let reader = ParquetFileReader { inner, diff --git a/datafusion/datasource-parquet/src/source.rs b/datafusion/datasource-parquet/src/source.rs index 2e0919b1447d..75d87a4cd16f 100644 --- a/datafusion/datasource-parquet/src/source.rs +++ b/datafusion/datasource-parquet/src/source.rs @@ -548,6 +548,7 @@ impl FileSource for ParquetSource { .batch_size .expect("Batch size must set before creating ParquetOpener"), limit: base_config.limit, + preserve_order: base_config.preserve_order, predicate: self.predicate.clone(), table_schema: self.table_schema.clone(), metadata_size_hint: self.metadata_size_hint, @@ -756,7 +757,7 @@ impl FileSource for ParquetSource { /// # Returns /// - `Inexact`: Created an optimized source (e.g., reversed scan) that approximates the order /// - `Unsupported`: Cannot optimize for this ordering - fn try_reverse_output( + fn try_pushdown_sort( &self, order: &[PhysicalSortExpr], 
eq_properties: &EquivalenceProperties, diff --git a/datafusion/datasource-parquet/src/supported_predicates.rs b/datafusion/datasource-parquet/src/supported_predicates.rs new file mode 100644 index 000000000000..a205c12dd06a --- /dev/null +++ b/datafusion/datasource-parquet/src/supported_predicates.rs @@ -0,0 +1,144 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Registry of physical expressions that support nested list column pushdown +//! to the Parquet decoder. +//! +//! This module provides a trait-based approach for determining which predicates +//! can be safely evaluated on nested list columns during Parquet decoding. + +use std::sync::Arc; + +use datafusion_physical_expr::expressions::{IsNotNullExpr, IsNullExpr}; +use datafusion_physical_expr::{PhysicalExpr, ScalarFunctionExpr}; + +/// Trait for physical expressions that support list column pushdown during +/// Parquet decoding. +/// +/// This trait provides a type-safe mechanism for identifying expressions that +/// can be safely pushed down to the Parquet decoder for evaluation on nested +/// list columns. 
+/// +/// # Implementation Notes +/// +/// Expression types in external crates cannot directly implement this trait +/// due to Rust's orphan rules. Instead, we use a blanket implementation that +/// delegates to a registration mechanism. +/// +/// # Examples +/// +/// ```ignore +/// use datafusion_physical_expr::PhysicalExpr; +/// use datafusion_datasource_parquet::SupportsListPushdown; +/// +/// let expr: Arc = ...; +/// if expr.supports_list_pushdown() { +/// // Can safely push down to Parquet decoder +/// } +/// ``` +pub trait SupportsListPushdown { + /// Returns `true` if this expression supports list column pushdown. + fn supports_list_pushdown(&self) -> bool; +} + +/// Blanket implementation for all physical expressions. +/// +/// This delegates to specialized predicates that check whether the concrete +/// expression type is registered as supporting list pushdown. This design +/// allows the trait to work with expression types defined in external crates. +impl SupportsListPushdown for dyn PhysicalExpr { + fn supports_list_pushdown(&self) -> bool { + is_null_check(self) || is_supported_scalar_function(self) + } +} + +/// Checks if an expression is a NULL or NOT NULL check. +/// +/// These checks are universally supported for all column types. +fn is_null_check(expr: &dyn PhysicalExpr) -> bool { + expr.as_any().downcast_ref::().is_some() + || expr.as_any().downcast_ref::().is_some() +} + +/// Checks if an expression is a scalar function registered for list pushdown. +/// +/// Returns `true` if the expression is a `ScalarFunctionExpr` whose function +/// is in the registry of supported operations. 
+fn is_supported_scalar_function(expr: &dyn PhysicalExpr) -> bool { + scalar_function_name(expr).is_some_and(|name| { + // Registry of verified array functions + matches!(name, "array_has" | "array_has_all" | "array_has_any") + }) +} + +fn scalar_function_name(expr: &dyn PhysicalExpr) -> Option<&str> { + expr.as_any() + .downcast_ref::() + .map(ScalarFunctionExpr::name) +} + +/// Checks whether the given physical expression contains a supported nested +/// predicate (for example, `array_has_all`). +/// +/// This function recursively traverses the expression tree to determine if +/// any node contains predicates that support list column pushdown to the +/// Parquet decoder. +/// +/// # Supported predicates +/// +/// - `IS NULL` and `IS NOT NULL` checks on any column type +/// - Array functions: `array_has`, `array_has_all`, `array_has_any` +/// +/// # Returns +/// +/// `true` if the expression or any of its children contain supported predicates. +pub fn supports_list_predicates(expr: &Arc) -> bool { + expr.supports_list_pushdown() + || expr + .children() + .iter() + .any(|child| supports_list_predicates(child)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_null_check_detection() { + use datafusion_physical_expr::expressions::Column; + + let col_expr: Arc = Arc::new(Column::new("test", 0)); + assert!(!is_null_check(col_expr.as_ref())); + + // IsNullExpr and IsNotNullExpr detection requires actual instances + // which need schema setup - tested in integration tests + } + + #[test] + fn test_supported_scalar_functions() { + use datafusion_physical_expr::expressions::Column; + + let col_expr: Arc = Arc::new(Column::new("test", 0)); + + // Non-function expressions should return false + assert!(!is_supported_scalar_function(col_expr.as_ref())); + + // Testing with actual ScalarFunctionExpr requires function setup + // and is better suited for integration tests + } +} diff --git a/datafusion/datasource/Cargo.toml b/datafusion/datasource/Cargo.toml 
index 48bf30f7a448..1315f871a68f 100644 --- a/datafusion/datasource/Cargo.toml +++ b/datafusion/datasource/Cargo.toml @@ -36,7 +36,7 @@ default = ["compression"] [dependencies] arrow = { workspace = true } -async-compression = { version = "0.4.35", features = [ +async-compression = { version = "0.4.40", features = [ "bzip2", "gzip", "xz", diff --git a/datafusion/datasource/src/display.rs b/datafusion/datasource/src/display.rs index 15fe8679acda..0f59e33ff9ea 100644 --- a/datafusion/datasource/src/display.rs +++ b/datafusion/datasource/src/display.rs @@ -287,13 +287,6 @@ mod tests { version: None, }; - PartitionedFile { - object_meta, - partition_values: vec![], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - } + PartitionedFile::new_from_meta(object_meta) } } diff --git a/datafusion/datasource/src/file.rs b/datafusion/datasource/src/file.rs index f5380c27ecc2..a0f82ff7a9b5 100644 --- a/datafusion/datasource/src/file.rs +++ b/datafusion/datasource/src/file.rs @@ -39,12 +39,19 @@ use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use object_store::ObjectStore; -/// Helper function to convert any type implementing FileSource to Arc<dyn FileSource> +/// Helper function to convert any type implementing [`FileSource`] to `Arc` pub fn as_file_source(source: T) -> Arc { Arc::new(source) } -/// file format specific behaviors for elements in [`DataSource`] +/// File format specific behaviors for [`DataSource`] +/// +/// # Schema information +/// There are two important schemas for a [`FileSource`]: +/// 1. [`Self::table_schema`] -- the schema for the overall table +/// (file data plus partition columns) +/// 2. 
The logical output schema, comprised of [`Self::table_schema`] with +/// [`Self::projection`] applied /// /// See more details on specific implementations: /// * [`ArrowSource`](https://docs.rs/datafusion/latest/datafusion/datasource/physical_plan/struct.ArrowSource.html) @@ -64,24 +71,44 @@ pub trait FileSource: Send + Sync { ) -> Result>; /// Any fn as_any(&self) -> &dyn Any; - /// Returns the table schema for this file source. + + /// Returns the table schema for the overall table (including partition columns, if any) + /// + /// This method returns the unprojected schema: the full schema of the data + /// without [`Self::projection`] applied. /// - /// This always returns the unprojected schema (the full schema of the data). + /// The output schema of this `FileSource` is this TableSchema + /// with [`Self::projection`] applied. + /// + /// Use [`ProjectionExprs::project_schema`] to get the projected schema + /// after applying the projection. fn table_schema(&self) -> &crate::table_schema::TableSchema; + /// Initialize new type with batch size configuration fn with_batch_size(&self, batch_size: usize) -> Arc; - /// Returns the filter expression that will be applied during the file scan. + + /// Returns the filter expression that will be applied *during* the file scan. + /// + /// These expressions are in terms of the unprojected [`Self::table_schema`]. fn filter(&self) -> Option> { None } - /// Return the projection that will be applied to the output stream on top of the table schema. + + /// Return the projection that will be applied to the output stream on top + /// of [`Self::table_schema`]. + /// + /// Note you can use [`ProjectionExprs::project_schema`] on the table + /// schema to get the effective output schema of this source. 
fn projection(&self) -> Option<&ProjectionExprs> { None } + /// Return execution plan metrics fn metrics(&self) -> &ExecutionPlanMetricsSet; + /// String representation of file source such as "csv", "json", "parquet" fn file_type(&self) -> &str; + /// Format FileType specific information fn fmt_extra(&self, _t: DisplayFormatType, _f: &mut Formatter) -> fmt::Result { Ok(()) @@ -135,6 +162,19 @@ pub trait FileSource: Send + Sync { } /// Try to push down filters into this FileSource. + /// + /// `filters` must be in terms of the unprojected table schema (file schema + /// plus partition columns), before any projection is applied. + /// + /// Any filters that this FileSource chooses to evaluate itself should be + /// returned as `PushedDown::Yes` in the result, along with a FileSource + /// instance that incorporates those filters. Such filters are logically + /// applied "during" the file scan, meaning they may refer to columns not + /// included in the final output projection. + /// + /// Filters that cannot be pushed down should be marked as `PushedDown::No`, + /// and will be evaluated by an execution plan after the file source. + /// /// See [`ExecutionPlan::handle_child_pushdown_result`] for more details. /// /// [`ExecutionPlan::handle_child_pushdown_result`]: datafusion_physical_plan::ExecutionPlan::handle_child_pushdown_result @@ -189,7 +229,29 @@ pub trait FileSource: Send + Sync { /// * `Inexact` - Created a source optimized for ordering (e.g., reversed row groups) but not perfectly sorted /// * `Unsupported` - Cannot optimize for this ordering /// - /// Default implementation returns `Unsupported`. + /// # Deprecation / migration notes + /// - [`Self::try_reverse_output`] was renamed to this method and deprecated since `53.0.0`. + /// Per DataFusion's deprecation guidelines, it will be removed in `59.0.0` or later + /// (6 major versions or 6 months, whichever is longer). + /// - New implementations should override [`Self::try_pushdown_sort`] directly. 
+ /// - For backwards compatibility, the default implementation of + /// [`Self::try_pushdown_sort`] delegates to the deprecated + /// [`Self::try_reverse_output`] until it is removed. After that point, the + /// default implementation will return [`SortOrderPushdownResult::Unsupported`]. + fn try_pushdown_sort( + &self, + order: &[PhysicalSortExpr], + eq_properties: &EquivalenceProperties, + ) -> Result>> { + #[expect(deprecated)] + self.try_reverse_output(order, eq_properties) + } + + /// Deprecated: Renamed to [`Self::try_pushdown_sort`]. + #[deprecated( + since = "53.0.0", + note = "Renamed to try_pushdown_sort. This method was never limited to reversing output. It will be removed in 59.0.0 or later." + )] fn try_reverse_output( &self, _order: &[PhysicalSortExpr], @@ -198,7 +260,7 @@ pub trait FileSource: Send + Sync { Ok(SortOrderPushdownResult::Unsupported) } - /// Try to push down a projection into a this FileSource. + /// Try to push down a projection into this FileSource. /// /// `FileSource` implementations that support projection pushdown should /// override this method and return a new `FileSource` instance with the @@ -232,7 +294,7 @@ pub trait FileSource: Send + Sync { /// `SchemaAdapterFactory` has been removed. Use `PhysicalExprAdapterFactory` instead. /// See `upgrading.md` for more details. #[deprecated( - since = "52.0.0", + since = "53.0.0", note = "SchemaAdapterFactory has been removed. Use PhysicalExprAdapterFactory instead. See upgrading.md for more details." )] #[expect(deprecated)] @@ -250,7 +312,7 @@ pub trait FileSource: Send + Sync { /// `SchemaAdapterFactory` has been removed. Use `PhysicalExprAdapterFactory` instead. /// See `upgrading.md` for more details. #[deprecated( - since = "52.0.0", + since = "53.0.0", note = "SchemaAdapterFactory has been removed. Use PhysicalExprAdapterFactory instead. See upgrading.md for more details." 
)] #[expect(deprecated)] diff --git a/datafusion/datasource/src/file_format.rs b/datafusion/datasource/src/file_format.rs index 54389ecd214e..9f8fa622d258 100644 --- a/datafusion/datasource/src/file_format.rs +++ b/datafusion/datasource/src/file_format.rs @@ -32,6 +32,7 @@ use arrow::datatypes::SchemaRef; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{GetExt, Result, Statistics, internal_err, not_impl_err}; use datafusion_physical_expr::LexRequirement; +use datafusion_physical_expr_common::sort_expr::LexOrdering; use datafusion_physical_plan::ExecutionPlan; use datafusion_session::Session; @@ -41,6 +42,35 @@ use object_store::{ObjectMeta, ObjectStore}; /// Default max records to scan to infer the schema pub const DEFAULT_SCHEMA_INFER_MAX_RECORD: usize = 1000; +/// Metadata fetched from a file, including statistics and ordering. +/// +/// This struct is returned by [`FileFormat::infer_stats_and_ordering`] to +/// provide all metadata in a single read, avoiding duplicate I/O operations. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub struct FileMeta { + /// Statistics for the file (row counts, byte sizes, column statistics). + pub statistics: Statistics, + /// The ordering (sort order) of the file, if known. + pub ordering: Option, +} + +impl FileMeta { + /// Creates a new `FileMeta` with the given statistics and no ordering. + pub fn new(statistics: Statistics) -> Self { + Self { + statistics, + ordering: None, + } + } + + /// Sets the ordering for this file metadata. + pub fn with_ordering(mut self, ordering: Option) -> Self { + self.ordering = ordering; + self + } +} + /// This trait abstracts all the file format specific implementations /// from the [`TableProvider`]. This helps code re-utilization across /// providers that support the same file formats. 
@@ -90,6 +120,52 @@ pub trait FileFormat: Send + Sync + fmt::Debug { object: &ObjectMeta, ) -> Result; + /// Infer the ordering (sort order) for the provided object from file metadata. + /// + /// Returns `Ok(None)` if the file format does not support ordering inference + /// or if the file does not have ordering information. + /// + /// `table_schema` is the (combined) schema of the overall table + /// and may be a superset of the schema contained in this file. + /// + /// The default implementation returns `Ok(None)`. + async fn infer_ordering( + &self, + _state: &dyn Session, + _store: &Arc, + _table_schema: SchemaRef, + _object: &ObjectMeta, + ) -> Result> { + Ok(None) + } + + /// Infer both statistics and ordering from a single metadata read. + /// + /// This is more efficient than calling [`Self::infer_stats`] and + /// [`Self::infer_ordering`] separately when both are needed, as it avoids + /// reading file metadata twice. + /// + /// The default implementation calls both methods separately. File formats + /// that can extract both from a single read should override this method. + async fn infer_stats_and_ordering( + &self, + state: &dyn Session, + store: &Arc, + table_schema: SchemaRef, + object: &ObjectMeta, + ) -> Result { + let statistics = self + .infer_stats(state, store, Arc::clone(&table_schema), object) + .await?; + let ordering = self + .infer_ordering(state, store, table_schema, object) + .await?; + Ok(FileMeta { + statistics, + ordering, + }) + } + /// Take a list of files and convert it to the appropriate executor /// according to this file format. 
async fn create_physical_plan( diff --git a/datafusion/datasource/src/file_scan_config.rs b/datafusion/datasource/src/file_scan_config.rs index c8636343ccc5..c3e5cabce7bc 100644 --- a/datafusion/datasource/src/file_scan_config.rs +++ b/datafusion/datasource/src/file_scan_config.rs @@ -55,10 +55,21 @@ use datafusion_physical_plan::{ use log::{debug, warn}; use std::{any::Any, fmt::Debug, fmt::Formatter, fmt::Result as FmtResult, sync::Arc}; -/// The base configurations for a [`DataSourceExec`], the a physical plan for -/// any given file format. +/// [`FileScanConfig`] represents scanning data from a group of files /// -/// Use [`DataSourceExec::from_data_source`] to create a [`DataSourceExec`] from a ``FileScanConfig`. +/// `FileScanConfig` is used to create a [`DataSourceExec`], the physical plan +/// for scanning files with a particular file format. +/// +/// The [`FileSource`] (e.g. `ParquetSource`, `CsvSource`, etc.) is responsible +/// for creating the actual execution plan to read the files based on a +/// `FileScanConfig`. Fields in a `FileScanConfig` such as Statistics represent +/// information about the files **before** any projection or filtering is +/// applied in the file source. +/// +/// Use [`FileScanConfigBuilder`] to construct a `FileScanConfig`. +/// +/// Use [`DataSourceExec::from_data_source`] to create a [`DataSourceExec`] from +/// a `FileScanConfig`. /// /// # Example /// ``` @@ -152,7 +163,18 @@ pub struct FileScanConfig { /// The maximum number of records to read from this plan. If `None`, /// all records after filtering are returned. pub limit: Option, - /// All equivalent lexicographical orderings that describe the schema. + /// Whether the scan's limit is order sensitive + /// When `true`, files must be read in the exact order specified to produce + /// correct results (e.g., for `ORDER BY ... LIMIT` queries). When `false`, + /// DataFusion may reorder file processing for optimization without affecting correctness. 
+ pub preserve_order: bool, + /// All equivalent lexicographical output orderings of this file scan, in terms of + /// [`FileSource::table_schema`]. See [`FileScanConfigBuilder::with_output_ordering`] for more + /// details. + /// + /// [`Self::eq_properties`] uses this information along with projection + /// and filtering information to compute the effective + /// [`EquivalenceProperties`] pub output_ordering: Vec, /// File compression type pub file_compression_type: FileCompressionType, @@ -164,8 +186,11 @@ pub struct FileScanConfig { /// Expression adapter used to adapt filters and projections that are pushed down into the scan /// from the logical schema to the physical schema of the file. pub expr_adapter_factory: Option>, - /// Unprojected statistics for the table (file schema + partition columns). - /// These are projected on-demand via `projected_stats()`. + /// Statistics for the entire table (file schema + partition columns). + /// See [`FileScanConfigBuilder::with_statistics`] for more details. + /// + /// The effective statistics are computed on-demand via + /// [`ProjectionExprs::project_statistics`]. /// /// Note that this field is pub(crate) because accessing it directly from outside /// would be incorrect if there are filters being applied, thus this should be accessed @@ -240,6 +265,7 @@ pub struct FileScanConfigBuilder { object_store_url: ObjectStoreUrl, file_source: Arc, limit: Option, + preserve_order: bool, constraints: Option, file_groups: Vec, statistics: Option, @@ -269,6 +295,7 @@ impl FileScanConfigBuilder { output_ordering: vec![], file_compression_type: None, limit: None, + preserve_order: false, constraints: None, batch_size: None, expr_adapter_factory: None, @@ -276,22 +303,35 @@ impl FileScanConfigBuilder { } } - /// Set the maximum number of records to read from this plan. If `None`, - /// all records after filtering are returned. + /// Set the maximum number of records to read from this plan. 
+ /// + /// If `None`, all records after filtering are returned. pub fn with_limit(mut self, limit: Option) -> Self { self.limit = limit; self } + /// Set whether the limit should be order-sensitive. + /// + /// When `true`, files must be read in the exact order specified to produce + /// correct results (e.g., for `ORDER BY ... LIMIT` queries). When `false`, + /// DataFusion may reorder file processing for optimization without + /// affecting correctness. + pub fn with_preserve_order(mut self, order_sensitive: bool) -> Self { + self.preserve_order = order_sensitive; + self + } + /// Set the file source for scanning files. /// - /// This method allows you to change the file source implementation (e.g. ParquetSource, CsvSource, etc.) - /// after the builder has been created. + /// This method allows you to change the file source implementation (e.g. + /// ParquetSource, CsvSource, etc.) after the builder has been created. pub fn with_source(mut self, file_source: Arc) -> Self { self.file_source = file_source; self } + /// Return the table schema pub fn table_schema(&self) -> &SchemaRef { self.file_source.table_schema().table_schema() } @@ -316,7 +356,12 @@ impl FileScanConfigBuilder { /// Set the columns on which to project the data using column indices. /// - /// Indexes that are higher than the number of columns of `file_schema` refer to `table_partition_cols`. + /// This method attempts to push down the projection to the underlying file + /// source if supported. If the file source does not support projection + /// pushdown, an error is returned. + /// + /// Indexes that are higher than the number of columns of `file_schema` + /// refer to `table_partition_cols`. pub fn with_projection_indices( mut self, indices: Option>, @@ -355,8 +400,18 @@ impl FileScanConfigBuilder { self } - /// Set the estimated overall statistics of the files, taking `filters` into account. - /// Defaults to [`Statistics::new_unknown`]. 
+ /// Set the statistics of the files, including partition + /// columns. Defaults to [`Statistics::new_unknown`]. + /// + /// These statistics are for the entire table (file schema + partition + /// columns) before any projection or filtering is applied. Projections are + /// applied when statistics are retrieved, and if a filter is present, + /// [`FileScanConfig::statistics`] will mark the statistics as inexact + /// (counts are not adjusted). + /// + /// Projections and filters may be applied by the file source, either by + /// [`Self::with_projection_indices`] or a preexisting + /// [`FileSource::projection`] or [`FileSource::filter`]. pub fn with_statistics(mut self, statistics: Statistics) -> Self { self.statistics = Some(statistics); self @@ -392,6 +447,13 @@ impl FileScanConfigBuilder { } /// Set the output ordering of the files + /// + /// The expressions are in terms of the entire table schema (file schema + + /// partition columns), before any projection or filtering from the file + /// scan is applied. + /// + /// This is used for optimization purposes, e.g. to determine if a file scan + /// can satisfy an `ORDER BY` without an additional sort. pub fn with_output_ordering(mut self, output_ordering: Vec) -> Self { self.output_ordering = output_ordering; self @@ -450,6 +512,7 @@ impl FileScanConfigBuilder { object_store_url, file_source, limit, + preserve_order, constraints, file_groups, statistics, @@ -467,10 +530,14 @@ impl FileScanConfigBuilder { let file_compression_type = file_compression_type.unwrap_or(FileCompressionType::UNCOMPRESSED); + // If there is an output ordering, we should preserve it. 
+ let preserve_order = preserve_order || !output_ordering.is_empty(); + FileScanConfig { object_store_url, file_source, limit, + preserve_order, constraints, file_groups, output_ordering, @@ -493,6 +560,7 @@ impl From for FileScanConfigBuilder { output_ordering: config.output_ordering, file_compression_type: Some(config.file_compression_type), limit: config.limit, + preserve_order: config.preserve_order, constraints: Some(config.constraints), batch_size: config.batch_size, expr_adapter_factory: config.expr_adapter_factory, @@ -661,11 +729,14 @@ impl DataSource for FileScanConfig { Partitioning::UnknownPartitioning(self.file_groups.len()) } + /// Computes the effective equivalence properties of this file scan, taking + /// into account the file schema, any projections or filters applied by the + /// file source, and the output ordering. fn eq_properties(&self) -> EquivalenceProperties { let schema = self.file_source.table_schema().table_schema(); let mut eq_properties = EquivalenceProperties::new_with_orderings( Arc::clone(schema), - self.output_ordering.clone(), + self.validated_output_ordering(), ) .with_constraints(self.constraints.clone()); @@ -771,37 +842,27 @@ impl DataSource for FileScanConfig { config: &ConfigOptions, ) -> Result>> { // Remap filter Column indices to match the table schema (file + partition columns). - // This is necessary because filters may have been created against a different schema - // (e.g., after projection pushdown) and need to be remapped to the table schema - // before being passed to the file source and ultimately serialized. - // For example, the filter being pushed down is `c1_c2 > 5` and it was created - // against the output schema of the this `DataSource` which has projection `c1 + c2 as c1_c2`. - // Thus we need to rewrite the filter back to `c1 + c2 > 5` before passing it to the file source. 
+ // This is necessary because filters refer to the output schema of this `DataSource` + // (e.g., after projection pushdown has been applied) and need to be remapped to the table schema + // before being passed to the file source + // + // For example, consider a filter `c1_c2 > 5` being pushed down. If the + // `DataSource` has a projection `c1 + c2 as c1_c2`, the filter must be rewritten + // to refer to the table schema `c1 + c2 > 5` let table_schema = self.file_source.table_schema().table_schema(); - // If there's a projection with aliases, first map the filters back through - // the projection expressions before remapping to the table schema. let filters_to_remap = if let Some(projection) = self.file_source.projection() { - use datafusion_physical_plan::projection::update_expr; filters .into_iter() - .map(|filter| { - update_expr(&filter, projection.as_ref(), true)?.ok_or_else(|| { - internal_datafusion_err!( - "Failed to map filter expression through projection: {}", - filter - ) - }) - }) + .map(|filter| projection.unproject_expr(&filter)) .collect::>>()? } else { filters }; // Now remap column indices to match the table schema. - let remapped_filters: Result> = filters_to_remap + let remapped_filters = filters_to_remap .into_iter() - .map(|filter| reassign_expr_columns(filter, table_schema.as_ref())) - .collect(); - let remapped_filters = remapped_filters?; + .map(|filter| reassign_expr_columns(filter, table_schema)) + .collect::>>()?; let result = self .file_source @@ -829,20 +890,20 @@ impl DataSource for FileScanConfig { &self, order: &[PhysicalSortExpr], ) -> Result>> { - // Delegate to FileSource to check if reverse scanning can satisfy the request. + // Delegate to FileSource to see if it can optimize for the requested ordering. 
let pushdown_result = self .file_source - .try_reverse_output(order, &self.eq_properties())?; + .try_pushdown_sort(order, &self.eq_properties())?; match pushdown_result { SortOrderPushdownResult::Exact { inner } => { Ok(SortOrderPushdownResult::Exact { - inner: self.rebuild_with_source(inner, true)?, + inner: self.rebuild_with_source(inner, true, order)?, }) } SortOrderPushdownResult::Inexact { inner } => { Ok(SortOrderPushdownResult::Inexact { - inner: self.rebuild_with_source(inner, false)?, + inner: self.rebuild_with_source(inner, false, order)?, }) } SortOrderPushdownResult::Unsupported => { @@ -850,9 +911,55 @@ impl DataSource for FileScanConfig { } } } + + fn with_preserve_order(&self, preserve_order: bool) -> Option> { + if self.preserve_order == preserve_order { + return Some(Arc::new(self.clone())); + } + + let new_config = FileScanConfig { + preserve_order, + ..self.clone() + }; + Some(Arc::new(new_config)) + } } impl FileScanConfig { + /// Returns only the output orderings that are validated against actual + /// file group statistics. + /// + /// For example, individual files may be ordered by `col1 ASC`, + /// but if we have files with these min/max statistics in a single partition / file group: + /// + /// - file1: min(col1) = 10, max(col1) = 20 + /// - file2: min(col1) = 5, max(col1) = 15 + /// + /// Because reading file1 followed by file2 would produce out-of-order output (there is overlap + /// in the ranges), we cannot retain `col1 ASC` as a valid output ordering. + /// + /// Similarly this would not be a valid order (non-overlapping ranges but not ordered): + /// + /// - file1: min(col1) = 20, max(col1) = 30 + /// - file2: min(col1) = 10, max(col1) = 15 + /// + /// On the other hand if we had: + /// + /// - file1: min(col1) = 5, max(col1) = 15 + /// - file2: min(col1) = 16, max(col1) = 25 + /// + /// Then we know that reading file1 followed by file2 will produce ordered output, + /// so `col1 ASC` would be retained. 
+ /// + /// Note that we are checking for ordering *within* *each* file group / partition, + /// files in different partitions are read independently and do not affect each other's ordering. + /// Merging of the multiple partition streams into a single ordered stream is handled + /// upstream e.g. by `SortPreservingMergeExec`. + fn validated_output_ordering(&self) -> Vec { + let schema = self.file_source.table_schema().table_schema(); + validate_orderings(&self.output_ordering, schema, &self.file_groups, None) + } + /// Get the file schema (schema of the files without partition columns) pub fn file_schema(&self) -> &SchemaRef { self.file_source.table_schema().file_schema() @@ -1123,19 +1230,44 @@ impl FileScanConfig { &self, new_file_source: Arc, is_exact: bool, + order: &[PhysicalSortExpr], ) -> Result> { let mut new_config = self.clone(); - // Reverse file groups (FileScanConfig's responsibility) - new_config.file_groups = new_config - .file_groups - .into_iter() - .map(|group| { - let mut files = group.into_inner(); - files.reverse(); - files.into() - }) - .collect(); + // Reverse file order (within each group) if the caller is requesting a reversal of this + // scan's declared output ordering. + // + // Historically this function always reversed `file_groups` because it was only reached + // via `FileSource::try_reverse_output` (where a reversal was the only supported + // optimization). + // + // Now that `FileSource::try_pushdown_sort` is generic, we must not assume reversal: other + // optimizations may become possible (e.g. already-sorted data, statistics-based file + // reordering). Therefore we only reverse files when it is known to help satisfy the + // requested ordering. 
+ let reverse_file_groups = if self.output_ordering.is_empty() { + false + } else if let Some(requested) = LexOrdering::new(order.iter().cloned()) { + let projected_schema = self.projected_schema()?; + let orderings = project_orderings(&self.output_ordering, &projected_schema); + orderings + .iter() + .any(|ordering| ordering.is_reverse(&requested)) + } else { + false + }; + + if reverse_file_groups { + new_config.file_groups = new_config + .file_groups + .into_iter() + .map(|group| { + let mut files = group.into_inner(); + files.reverse(); + files.into() + }) + .collect(); + } new_config.file_source = new_file_source; @@ -1202,6 +1334,51 @@ fn ordered_column_indices_from_projection( .collect::>>() } +/// Check whether a given ordering is valid for all file groups by verifying +/// that files within each group are sorted according to their min/max statistics. +/// +/// For single-file (or empty) groups, the ordering is trivially valid. +/// For multi-file groups, we check that the min/max statistics for the sort +/// columns are in order and non-overlapping (or touching at boundaries). +/// +/// `projection` maps projected column indices back to table-schema indices +/// when validating after projection; pass `None` when validating at +/// table-schema level. +fn is_ordering_valid_for_file_groups( + file_groups: &[FileGroup], + ordering: &LexOrdering, + schema: &SchemaRef, + projection: Option<&[usize]>, +) -> bool { + file_groups.iter().all(|group| { + if group.len() <= 1 { + return true; // single-file groups are trivially sorted + } + match MinMaxStatistics::new_from_files(ordering, schema, projection, group.iter()) + { + Ok(stats) => stats.is_sorted(), + Err(_) => false, // can't prove sorted → reject + } + }) +} + +/// Filters orderings to retain only those valid for all file groups, +/// verified via min/max statistics. 
+fn validate_orderings( + orderings: &[LexOrdering], + schema: &SchemaRef, + file_groups: &[FileGroup], + projection: Option<&[usize]>, +) -> Vec { + orderings + .iter() + .filter(|ordering| { + is_ordering_valid_for_file_groups(file_groups, ordering, schema, projection) + }) + .cloned() + .collect() +} + /// The various listing tables does not attempt to read all files /// concurrently, instead they will read files in sequence within a /// partition. This is an important property as it allows plans to @@ -1268,52 +1445,47 @@ fn get_projected_output_ordering( let projected_orderings = project_orderings(&base_config.output_ordering, projected_schema); - let mut all_orderings = vec![]; - for new_ordering in projected_orderings { - // Check if any file groups are not sorted - if base_config.file_groups.iter().any(|group| { - if group.len() <= 1 { - // File groups with <= 1 files are always sorted - return false; - } - - let Some(indices) = base_config - .file_source - .projection() - .as_ref() - .map(|p| ordered_column_indices_from_projection(p)) - else { - // Can't determine if ordered without a simple projection - return true; - }; - - let statistics = match MinMaxStatistics::new_from_files( - &new_ordering, + let indices = base_config + .file_source + .projection() + .as_ref() + .map(|p| ordered_column_indices_from_projection(p)); + + match indices { + Some(Some(indices)) => { + // Simple column projection — validate with statistics + validate_orderings( + &projected_orderings, projected_schema, - indices.as_deref(), - group.iter(), - ) { - Ok(statistics) => statistics, - Err(e) => { - log::trace!("Error fetching statistics for file group: {e}"); - // we can't prove that it's ordered, so we have to reject it - return true; - } - }; - - !statistics.is_sorted() - }) { - debug!( - "Skipping specified output ordering {:?}. 
\ - Some file groups couldn't be determined to be sorted: {:?}", - base_config.output_ordering[0], base_config.file_groups - ); - continue; + &base_config.file_groups, + Some(indices.as_slice()), + ) + } + None => { + // No projection — validate with statistics (no remapping needed) + validate_orderings( + &projected_orderings, + projected_schema, + &base_config.file_groups, + None, + ) + } + Some(None) => { + // Complex projection (expressions, not simple columns) — can't + // determine column indices for statistics. Still valid if all + // file groups have at most one file. + if base_config.file_groups.iter().all(|g| g.len() <= 1) { + projected_orderings + } else { + debug!( + "Skipping specified output orderings. \ + Some file groups couldn't be determined to be sorted: {:?}", + base_config.file_groups + ); + vec![] + } } - - all_orderings.push(new_ordering); } - all_orderings } /// Convert type to a type suitable for use as a `ListingTable` @@ -1358,6 +1530,62 @@ mod tests { use datafusion_physical_expr::projection::ProjectionExpr; use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; + #[derive(Clone)] + struct InexactSortPushdownSource { + metrics: ExecutionPlanMetricsSet, + table_schema: TableSchema, + } + + impl InexactSortPushdownSource { + fn new(table_schema: TableSchema) -> Self { + Self { + metrics: ExecutionPlanMetricsSet::new(), + table_schema, + } + } + } + + impl FileSource for InexactSortPushdownSource { + fn create_file_opener( + &self, + _object_store: Arc, + _base_config: &FileScanConfig, + _partition: usize, + ) -> Result> { + unimplemented!() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn table_schema(&self) -> &TableSchema { + &self.table_schema + } + + fn with_batch_size(&self, _batch_size: usize) -> Arc { + Arc::new(self.clone()) + } + + fn metrics(&self) -> &ExecutionPlanMetricsSet { + &self.metrics + } + + fn file_type(&self) -> &str { + "mock" + } + + fn try_pushdown_sort( + &self, + _order: &[PhysicalSortExpr], + 
_eq_properties: &EquivalenceProperties, + ) -> Result>> { + Ok(SortOrderPushdownResult::Inexact { + inner: Arc::new(self.clone()) as Arc, + }) + } + } + #[test] fn physical_plan_config_no_projection_tab_cols_as_field() { let file_schema = aggr_test_schema(); @@ -1661,43 +1889,40 @@ mod tests { impl From for PartitionedFile { fn from(file: File) -> Self { - PartitionedFile { - object_meta: ObjectMeta { - location: Path::from(format!( - "data/date={}/{}.parquet", - file.date, file.name - )), - last_modified: chrono::Utc.timestamp_nanos(0), - size: 0, - e_tag: None, - version: None, - }, - partition_values: vec![ScalarValue::from(file.date)], - range: None, - statistics: Some(Arc::new(Statistics { - num_rows: Precision::Absent, - total_byte_size: Precision::Absent, - column_statistics: file - .statistics - .into_iter() - .map(|stats| { - stats - .map(|(min, max)| ColumnStatistics { - min_value: Precision::Exact( - ScalarValue::Float64(min), - ), - max_value: Precision::Exact( - ScalarValue::Float64(max), - ), - ..Default::default() - }) - .unwrap_or_default() - }) - .collect::>(), - })), - extensions: None, - metadata_size_hint: None, - } + let object_meta = ObjectMeta { + location: Path::from(format!( + "data/date={}/{}.parquet", + file.date, file.name + )), + last_modified: chrono::Utc.timestamp_nanos(0), + size: 0, + e_tag: None, + version: None, + }; + let statistics = Arc::new(Statistics { + num_rows: Precision::Absent, + total_byte_size: Precision::Absent, + column_statistics: file + .statistics + .into_iter() + .map(|stats| { + stats + .map(|(min, max)| ColumnStatistics { + min_value: Precision::Exact(ScalarValue::Float64( + min, + )), + max_value: Precision::Exact(ScalarValue::Float64( + max, + )), + ..Default::default() + }) + .unwrap_or_default() + }) + .collect::>(), + }); + PartitionedFile::new_from_meta(object_meta) + .with_partition_values(vec![ScalarValue::from(file.date)]) + .with_statistics(statistics) } } } @@ -2306,4 +2531,56 @@ mod tests { _ => 
panic!("Expected Hash partitioning"), } } + + #[test] + fn try_pushdown_sort_reverses_file_groups_only_when_requested_is_reverse() + -> Result<()> { + let file_schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + + let table_schema = TableSchema::new(Arc::clone(&file_schema), vec![]); + let file_source = Arc::new(InexactSortPushdownSource::new(table_schema)); + + let file_groups = vec![FileGroup::new(vec![ + PartitionedFile::new("file1", 1), + PartitionedFile::new("file2", 1), + ])]; + + let sort_expr_asc = PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))); + let config = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), file_source) + .with_file_groups(file_groups) + .with_output_ordering(vec![ + LexOrdering::new(vec![sort_expr_asc.clone()]).unwrap(), + ]) + .build(); + + let requested_asc = vec![sort_expr_asc.clone()]; + let result = config.try_pushdown_sort(&requested_asc)?; + let SortOrderPushdownResult::Inexact { inner } = result else { + panic!("Expected Inexact result"); + }; + let pushed_config = inner + .as_any() + .downcast_ref::() + .expect("Expected FileScanConfig"); + let pushed_files = pushed_config.file_groups[0].files(); + assert_eq!(pushed_files[0].object_meta.location.as_ref(), "file1"); + assert_eq!(pushed_files[1].object_meta.location.as_ref(), "file2"); + + let requested_desc = vec![sort_expr_asc.reverse()]; + let result = config.try_pushdown_sort(&requested_desc)?; + let SortOrderPushdownResult::Inexact { inner } = result else { + panic!("Expected Inexact result"); + }; + let pushed_config = inner + .as_any() + .downcast_ref::() + .expect("Expected FileScanConfig"); + let pushed_files = pushed_config.file_groups[0].files(); + assert_eq!(pushed_files[0].object_meta.location.as_ref(), "file2"); + assert_eq!(pushed_files[1].object_meta.location.as_ref(), "file1"); + + Ok(()) + } } diff --git a/datafusion/datasource/src/file_sink_config.rs b/datafusion/datasource/src/file_sink_config.rs index 
643831a1199f..1abce86a3565 100644 --- a/datafusion/datasource/src/file_sink_config.rs +++ b/datafusion/datasource/src/file_sink_config.rs @@ -32,6 +32,52 @@ use datafusion_expr::dml::InsertOp; use async_trait::async_trait; use object_store::ObjectStore; +/// Determines how `FileSink` output paths are interpreted. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum FileOutputMode { + /// Infer output mode from the output URL (for example, by extension / trailing `/`). + #[default] + Automatic, + /// Write to a single output file at the exact output path. + SingleFile, + /// Write to a directory under the output path with generated filenames. + Directory, +} + +impl FileOutputMode { + /// Resolve this mode into a `single_file_output` boolean for the demuxer. + pub fn single_file_output(self, base_output_path: &ListingTableUrl) -> bool { + match self { + Self::Automatic => { + !base_output_path.is_collection() + && base_output_path.file_extension().is_some() + } + Self::SingleFile => true, + Self::Directory => false, + } + } +} + +impl From> for FileOutputMode { + fn from(value: Option) -> Self { + match value { + None => Self::Automatic, + Some(true) => Self::SingleFile, + Some(false) => Self::Directory, + } + } +} + +impl From for Option { + fn from(value: FileOutputMode) -> Self { + match value { + FileOutputMode::Automatic => None, + FileOutputMode::SingleFile => Some(true), + FileOutputMode::Directory => Some(false), + } + } +} + /// General behaviors for files that do `DataSink` operations #[async_trait] pub trait FileSink: DataSink { @@ -112,6 +158,8 @@ pub struct FileSinkConfig { pub keep_partition_by_columns: bool, /// File extension without a dot(.) pub file_extension: String, + /// Determines how the output path is interpreted. 
+ pub file_output_mode: FileOutputMode, } impl FileSinkConfig { diff --git a/datafusion/datasource/src/mod.rs b/datafusion/datasource/src/mod.rs index 744ec667d50e..d19d20ec1ff3 100644 --- a/datafusion/datasource/src/mod.rs +++ b/datafusion/datasource/src/mod.rs @@ -24,7 +24,6 @@ // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] -#![deny(clippy::allow_attributes)] //! A table that uses the `ObjectStore` listing capability //! to get the list of files to process. @@ -58,6 +57,7 @@ use chrono::TimeZone; use datafusion_common::stats::Precision; use datafusion_common::{ColumnStatistics, Result, exec_datafusion_err}; use datafusion_common::{ScalarValue, Statistics}; +use datafusion_physical_expr::LexOrdering; use futures::{Stream, StreamExt}; use object_store::{GetOptions, GetRange, ObjectStore}; use object_store::{ObjectMeta, path::Path}; @@ -133,6 +133,16 @@ pub struct PartitionedFile { /// When set via [`Self::with_statistics`], partition column statistics are automatically /// computed from [`Self::partition_values`] with exact min/max/null_count/distinct_count. pub statistics: Option>, + /// The known lexicographical ordering of the rows in this file, if any. + /// + /// This describes how the data within the file is sorted with respect to one or more + /// columns, and is used by the optimizer for planning operations that depend on input + /// ordering (e.g. merges, sorts, and certain aggregations). + /// + /// When available, this is typically inferred from file-level metadata exposed by the + /// underlying format (for example, Parquet `sorting_columns`), but it may also be set + /// explicitly via [`Self::with_ordering`]. 
+ pub ordering: Option, /// An optional field for user defined per object metadata pub extensions: Option>, /// The estimated size of the parquet metadata, in bytes @@ -153,6 +163,20 @@ impl PartitionedFile { partition_values: vec![], range: None, statistics: None, + ordering: None, + extensions: None, + metadata_size_hint: None, + } + } + + /// Create a file from a known ObjectMeta without partition + pub fn new_from_meta(object_meta: ObjectMeta) -> Self { + Self { + object_meta, + partition_values: vec![], + range: None, + statistics: None, + ordering: None, extensions: None, metadata_size_hint: None, } @@ -171,12 +195,20 @@ impl PartitionedFile { partition_values: vec![], range: Some(FileRange { start, end }), statistics: None, + ordering: None, extensions: None, metadata_size_hint: None, } .with_range(start, end) } + /// Attach partition values to this file. + /// This replaces any existing partition values. + pub fn with_partition_values(mut self, partition_values: Vec) -> Self { + self.partition_values = partition_values; + self + } + /// Size of the file to be scanned (taking into account the range, if present). pub fn effective_size(&self) -> u64 { if let Some(range) = &self.range { @@ -282,6 +314,15 @@ impl PartitionedFile { false } } + + /// Set the known ordering of data in this file. + /// + /// The ordering represents the lexicographical sort order of the data, + /// typically inferred from file metadata (e.g., Parquet sorting_columns). 
+ pub fn with_ordering(mut self, ordering: Option) -> Self { + self.ordering = ordering; + self + } } impl From for PartitionedFile { @@ -291,6 +332,7 @@ impl From for PartitionedFile { partition_values: vec![], range: None, statistics: None, + ordering: None, extensions: None, metadata_size_hint: None, } @@ -487,6 +529,7 @@ pub fn generate_test_files(num_files: usize, overlap_factor: f64) -> Vec, - cache: PlanProperties, + cache: Arc, } impl Debug for DataSinkExec { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "DataSinkExec schema: {:?}", self.count_schema) + write!(f, "DataSinkExec schema: {}", self.count_schema) } } @@ -117,7 +117,7 @@ impl DataSinkExec { sink, count_schema: make_count_schema(), sort_order, - cache, + cache: Arc::new(cache), } } @@ -174,7 +174,7 @@ impl ExecutionPlan for DataSinkExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } diff --git a/datafusion/datasource/src/source.rs b/datafusion/datasource/src/source.rs index a3892dfac977..05028ed0f468 100644 --- a/datafusion/datasource/src/source.rs +++ b/datafusion/datasource/src/source.rs @@ -74,8 +74,8 @@ use datafusion_physical_plan::filter_pushdown::{ /// ```text /// ┌─────────────────────┐ -----► execute path /// │ │ ┄┄┄┄┄► init path -/// │ DataSourceExec │ -/// │ │ +/// │ DataSourceExec │ +/// │ │ /// └───────▲─────────────┘ /// ┊ │ /// ┊ │ @@ -158,16 +158,6 @@ pub trait DataSource: Send + Sync + Debug { /// across all partitions if `partition` is `None`. fn partition_statistics(&self, partition: Option) -> Result; - /// Returns aggregate statistics across all partitions. - /// - /// # Deprecated - /// Use [`Self::partition_statistics`] instead, which provides more fine-grained - /// control over statistics retrieval (per-partition or aggregate). 
- #[deprecated(since = "51.0.0", note = "Use partition_statistics instead")] - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - /// Return a copy of this DataSource with a new fetch limit fn with_fetch(&self, _limit: Option) -> Option>; fn fetch(&self) -> Option; @@ -178,7 +168,13 @@ pub trait DataSource: Send + Sync + Debug { &self, _projection: &ProjectionExprs, ) -> Result>>; + /// Try to push down filters into this DataSource. + /// + /// These filters are in terms of the output schema of this DataSource (e.g. + /// [`Self::eq_properties`] and output of any projections pushed into the + /// source), not the original table schema. + /// /// See [`ExecutionPlan::handle_child_pushdown_result`] for more details. /// /// [`ExecutionPlan::handle_child_pushdown_result`]: datafusion_physical_plan::ExecutionPlan::handle_child_pushdown_result @@ -210,6 +206,11 @@ pub trait DataSource: Send + Sync + Debug { ) -> Result>> { Ok(SortOrderPushdownResult::Unsupported) } + + /// Returns a variant of this `DataSource` that is aware of order-sensitivity. 
+ fn with_preserve_order(&self, _preserve_order: bool) -> Option> { + None + } } /// [`ExecutionPlan`] that reads one or more files @@ -229,7 +230,7 @@ pub struct DataSourceExec { /// The source of the data -- for example, `FileScanConfig` or `MemorySourceConfig` data_source: Arc, /// Cached plan properties such as sort order - cache: PlanProperties, + cache: Arc, } impl DisplayAs for DataSourceExec { @@ -253,7 +254,7 @@ impl ExecutionPlan for DataSourceExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -323,7 +324,7 @@ impl ExecutionPlan for DataSourceExec { fn with_fetch(&self, limit: Option) -> Option> { let data_source = self.data_source.with_fetch(limit)?; - let cache = self.cache.clone(); + let cache = Arc::clone(&self.cache); Some(Arc::new(Self { data_source, cache })) } @@ -367,7 +368,8 @@ impl ExecutionPlan for DataSourceExec { let mut new_node = self.clone(); new_node.data_source = data_source; // Re-compute properties since we have new filters which will impact equivalence info - new_node.cache = Self::compute_properties(&new_node.data_source); + new_node.cache = + Arc::new(Self::compute_properties(&new_node.data_source)); Ok(FilterPushdownPropagation { filters: res.filters, @@ -393,6 +395,18 @@ impl ExecutionPlan for DataSourceExec { Ok(Arc::new(new_exec) as Arc) }) } + + fn with_preserve_order( + &self, + preserve_order: bool, + ) -> Option> { + self.data_source + .with_preserve_order(preserve_order) + .map(|new_data_source| { + Arc::new(self.clone().with_data_source(new_data_source)) + as Arc + }) + } } impl DataSourceExec { @@ -403,7 +417,10 @@ impl DataSourceExec { // Default constructor for `DataSourceExec`, setting the `cooperative` flag to `true`. 
pub fn new(data_source: Arc) -> Self { let cache = Self::compute_properties(&data_source); - Self { data_source, cache } + Self { + data_source, + cache: Arc::new(cache), + } } /// Return the source object @@ -412,20 +429,20 @@ impl DataSourceExec { } pub fn with_data_source(mut self, data_source: Arc) -> Self { - self.cache = Self::compute_properties(&data_source); + self.cache = Arc::new(Self::compute_properties(&data_source)); self.data_source = data_source; self } /// Assign constraints pub fn with_constraints(mut self, constraints: Constraints) -> Self { - self.cache = self.cache.with_constraints(constraints); + Arc::make_mut(&mut self.cache).set_constraints(constraints); self } /// Assign output partitioning pub fn with_partitioning(mut self, partitioning: Partitioning) -> Self { - self.cache = self.cache.with_partitioning(partitioning); + Arc::make_mut(&mut self.cache).partitioning = partitioning; self } diff --git a/datafusion/datasource/src/statistics.rs b/datafusion/datasource/src/statistics.rs index 2f34ca032e13..b1a56e096c22 100644 --- a/datafusion/datasource/src/statistics.rs +++ b/datafusion/datasource/src/statistics.rs @@ -266,11 +266,12 @@ impl MinMaxStatistics { } /// Check if the min/max statistics are in order and non-overlapping + /// (or touching at boundaries) pub fn is_sorted(&self) -> bool { self.max_by_sort_order .iter() .zip(self.min_by_sort_order.iter().skip(1)) - .all(|(max, next_min)| max < next_min) + .all(|(max, next_min)| max <= next_min) } } diff --git a/datafusion/datasource/src/table_schema.rs b/datafusion/datasource/src/table_schema.rs index a45cdbaaea07..5b7fc4727df0 100644 --- a/datafusion/datasource/src/table_schema.rs +++ b/datafusion/datasource/src/table_schema.rs @@ -20,13 +20,13 @@ use arrow::datatypes::{FieldRef, SchemaBuilder, SchemaRef}; use std::sync::Arc; -/// Helper to hold table schema information for partitioned data sources. +/// The overall schema for potentially partitioned data sources. 
/// -/// When reading partitioned data (such as Hive-style partitioning), a table's schema +/// When reading partitioned data (such as Hive-style partitioning), a [`TableSchema`] /// consists of two parts: /// 1. **File schema**: The schema of the actual data files on disk -/// 2. **Partition columns**: Columns that are encoded in the directory structure, -/// not stored in the files themselves +/// 2. **Partition columns**: Columns whose values are encoded in the directory structure, +/// but not stored in the files themselves /// /// # Example: Partitioned Table /// diff --git a/datafusion/datasource/src/url.rs b/datafusion/datasource/src/url.rs index 155d6efe462c..39d1047984ff 100644 --- a/datafusion/datasource/src/url.rs +++ b/datafusion/datasource/src/url.rs @@ -17,7 +17,9 @@ use std::sync::Arc; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{DataFusionError, Result, TableReference}; +use datafusion_execution::cache::TableScopedPath; +use datafusion_execution::cache::cache_manager::CachedFileList; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_session::Session; @@ -28,7 +30,7 @@ use itertools::Itertools; use log::debug; use object_store::path::DELIMITER; use object_store::path::Path; -use object_store::{ObjectMeta, ObjectStore}; +use object_store::{ObjectMeta, ObjectStore, ObjectStoreExt}; use url::Url; /// A parsed URL identifying files for a listing table, see [`ListingTableUrl::parse`] @@ -41,6 +43,8 @@ pub struct ListingTableUrl { prefix: Path, /// An optional glob expression used to filter files glob: Option, + /// Optional table reference for the table this url belongs to + table_ref: Option, } impl ListingTableUrl { @@ -145,7 +149,12 @@ impl ListingTableUrl { /// to create a [`ListingTableUrl`]. 
pub fn try_new(url: Url, glob: Option) -> Result { let prefix = Path::from_url_path(url.path())?; - Ok(Self { url, prefix, glob }) + Ok(Self { + url, + prefix, + glob, + table_ref: None, + }) } /// Returns the URL scheme @@ -255,7 +264,14 @@ impl ListingTableUrl { }; let list: BoxStream<'a, Result> = if self.is_collection() { - list_with_cache(ctx, store, &self.prefix, prefix.as_ref()).await? + list_with_cache( + ctx, + store, + self.table_ref.as_ref(), + &self.prefix, + prefix.as_ref(), + ) + .await? } else { match store.head(&full_prefix).await { Ok(meta) => futures::stream::once(async { Ok(meta) }) @@ -264,7 +280,14 @@ impl ListingTableUrl { // If the head command fails, it is likely that object doesn't exist. // Retry as though it were a prefix (aka a collection) Err(object_store::Error::NotFound { .. }) => { - list_with_cache(ctx, store, &self.prefix, prefix.as_ref()).await? + list_with_cache( + ctx, + store, + self.table_ref.as_ref(), + &self.prefix, + prefix.as_ref(), + ) + .await? 
} Err(e) => return Err(e.into()), } @@ -318,10 +341,21 @@ impl ListingTableUrl { } /// Returns a copy of current [`ListingTableUrl`] with a specified `glob` - pub fn with_glob(self, glob: &str) -> Result { - let glob = - Pattern::new(glob).map_err(|e| DataFusionError::External(Box::new(e)))?; - Self::try_new(self.url, Some(glob)) + pub fn with_glob(mut self, glob: &str) -> Result { + self.glob = + Some(Pattern::new(glob).map_err(|e| DataFusionError::External(Box::new(e)))?); + Ok(self) + } + + /// Set the table reference for this [`ListingTableUrl`] + pub fn with_table_ref(mut self, table_ref: TableReference) -> Self { + self.table_ref = Some(table_ref); + self + } + + /// Return the table reference for this [`ListingTableUrl`] + pub fn get_table_ref(&self) -> &Option { + &self.table_ref } } @@ -345,6 +379,7 @@ impl ListingTableUrl { async fn list_with_cache<'b>( ctx: &'b dyn Session, store: &'b dyn ObjectStore, + table_ref: Option<&TableReference>, table_base_path: &Path, prefix: Option<&Path>, ) -> Result>> { @@ -364,37 +399,35 @@ async fn list_with_cache<'b>( .map(|res| res.map_err(|e| DataFusionError::ObjectStore(Box::new(e)))) .boxed()), Some(cache) => { - // Convert prefix to Option for cache lookup - let prefix_filter = prefix.cloned(); + // Build the filter prefix (only Some if prefix was requested) + let filter_prefix = prefix.is_some().then(|| full_prefix.clone()); + + let table_scoped_base_path = TableScopedPath { + table: table_ref.cloned(), + path: table_base_path.clone(), + }; - // Try cache lookup with optional prefix filter - let vec = if let Some(res) = - cache.get_with_extra(table_base_path, &prefix_filter) - { + // Try cache lookup - get returns CachedFileList + let vec = if let Some(cached) = cache.get(&table_scoped_base_path) { debug!("Hit list files cache"); - res.as_ref().clone() + cached.files_matching_prefix(&filter_prefix) } else { // Cache miss - always list and cache the full table // This ensures we have complete data for future prefix 
queries - let vec = store + let mut vec = store .list(Some(table_base_path)) .try_collect::>() .await?; - cache.put(table_base_path, Arc::new(vec.clone())); - - // If a prefix filter was requested, apply it to the results - if prefix.is_some() { - let full_prefix_str = full_prefix.as_ref(); - vec.into_iter() - .filter(|meta| { - meta.location.as_ref().starts_with(full_prefix_str) - }) - .collect() - } else { - vec - } + vec.shrink_to_fit(); // Right-size before caching + let cached: CachedFileList = vec.into(); + let result = cached.files_matching_prefix(&filter_prefix); + cache.put(&table_scoped_base_path, cached); + result }; - Ok(futures::stream::iter(vec.into_iter().map(Ok)).boxed()) + Ok( + futures::stream::iter(Arc::unwrap_or_clone(vec).into_iter().map(Ok)) + .boxed(), + ) } } } @@ -488,12 +521,13 @@ mod tests { use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_plan::ExecutionPlan; use object_store::{ - GetOptions, GetResult, ListResult, MultipartUpload, PutMultipartOptions, - PutPayload, + CopyOptions, GetOptions, GetResult, ListResult, MultipartUpload, + PutMultipartOptions, PutPayload, }; use std::any::Any; use std::collections::HashMap; use std::ops::Range; + use std::sync::Arc; use tempfile::tempdir; #[test] @@ -1074,7 +1108,14 @@ mod tests { location: &Path, options: GetOptions, ) -> object_store::Result { - self.in_mem.get_opts(location, options).await + if options.head && self.forbidden_paths.contains(location) { + Err(object_store::Error::PermissionDenied { + path: location.to_string(), + source: "forbidden".into(), + }) + } else { + self.in_mem.get_opts(location, options).await + } } async fn get_ranges( @@ -1085,19 +1126,11 @@ mod tests { self.in_mem.get_ranges(location, ranges).await } - async fn head(&self, location: &Path) -> object_store::Result { - if self.forbidden_paths.contains(location) { - Err(object_store::Error::PermissionDenied { - path: location.to_string(), - source: "forbidden".into(), - }) - 
} else { - self.in_mem.head(location).await - } - } - - async fn delete(&self, location: &Path) -> object_store::Result<()> { - self.in_mem.delete(location).await + fn delete_stream( + &self, + locations: BoxStream<'static, object_store::Result>, + ) -> BoxStream<'static, object_store::Result> { + self.in_mem.delete_stream(locations) } fn list( @@ -1114,16 +1147,13 @@ mod tests { self.in_mem.list_with_delimiter(prefix).await } - async fn copy(&self, from: &Path, to: &Path) -> object_store::Result<()> { - self.in_mem.copy(from, to).await - } - - async fn copy_if_not_exists( + async fn copy_opts( &self, from: &Path, to: &Path, + options: CopyOptions, ) -> object_store::Result<()> { - self.in_mem.copy_if_not_exists(from, to).await + self.in_mem.copy_opts(from, to, options).await } } diff --git a/datafusion/datasource/src/write/demux.rs b/datafusion/datasource/src/write/demux.rs index bec5b8b0bff0..1648624747af 100644 --- a/datafusion/datasource/src/write/demux.rs +++ b/datafusion/datasource/src/write/demux.rs @@ -35,8 +35,8 @@ use arrow::datatypes::{DataType, Schema}; use datafusion_common::cast::{ as_boolean_array, as_date32_array, as_date64_array, as_float16_array, as_float32_array, as_float64_array, as_int8_array, as_int16_array, as_int32_array, - as_int64_array, as_string_array, as_string_view_array, as_uint8_array, - as_uint16_array, as_uint32_array, as_uint64_array, + as_int64_array, as_large_string_array, as_string_array, as_string_view_array, + as_uint8_array, as_uint16_array, as_uint32_array, as_uint64_array, }; use datafusion_common::{exec_datafusion_err, internal_datafusion_err, not_impl_err}; use datafusion_common_runtime::SpawnedTask; @@ -106,8 +106,9 @@ pub(crate) fn start_demuxer_task( let file_extension = config.file_extension.clone(); let base_output_path = config.table_paths[0].clone(); let task = if config.table_partition_cols.is_empty() { - let single_file_output = !base_output_path.is_collection() - && base_output_path.file_extension().is_some(); 
+ let single_file_output = config + .file_output_mode + .single_file_output(&base_output_path); SpawnedTask::spawn(async move { row_count_demuxer( tx, @@ -397,6 +398,12 @@ fn compute_partition_keys_by_row<'a>( partition_values.push(Cow::from(array.value(i))); } } + DataType::LargeUtf8 => { + let array = as_large_string_array(col_array)?; + for i in 0..rb.num_rows() { + partition_values.push(Cow::from(array.value(i))); + } + } DataType::Utf8View => { let array = as_string_view_array(col_array)?; for i in 0..rb.num_rows() { diff --git a/datafusion/doc/src/lib.rs b/datafusion/doc/src/lib.rs index 836cb9345b51..591a5a62f3b2 100644 --- a/datafusion/doc/src/lib.rs +++ b/datafusion/doc/src/lib.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -#![deny(clippy::allow_attributes)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] #![doc( html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", diff --git a/datafusion/execution/Cargo.toml b/datafusion/execution/Cargo.toml index ca1fba07cae2..c8371d2eddb8 100644 --- a/datafusion/execution/Cargo.toml +++ b/datafusion/execution/Cargo.toml @@ -46,15 +46,20 @@ default = ["sql"] parquet_encryption = [ "parquet/encryption", ] +arrow_buffer_pool = [ + "arrow-buffer/pool", +] sql = [] [dependencies] arrow = { workspace = true } +arrow-buffer = { workspace = true } async-trait = { workspace = true } chrono = { workspace = true } dashmap = { workspace = true } datafusion-common = { workspace = true, default-features = false } datafusion-expr = { workspace = true, default-features = false } +datafusion-physical-expr-common = { workspace = true, default-features = false } futures = { workspace = true } log = { workspace = true } object_store = { workspace = true, features = ["fs"] } diff --git a/datafusion/execution/src/cache/cache_manager.rs 
b/datafusion/execution/src/cache/cache_manager.rs index c76a68c651eb..bd34c441bdbd 100644 --- a/datafusion/execution/src/cache/cache_manager.rs +++ b/datafusion/execution/src/cache/cache_manager.rs @@ -15,15 +15,21 @@ // specific language governing permissions and limitations // under the License. +use crate::cache::CacheAccessor; +use crate::cache::DefaultListFilesCache; use crate::cache::cache_unit::DefaultFilesMetadataCache; -use crate::cache::{CacheAccessor, DefaultListFilesCache}; +use crate::cache::list_files_cache::ListFilesEntry; +use crate::cache::list_files_cache::TableScopedPath; +use datafusion_common::TableReference; use datafusion_common::stats::Precision; use datafusion_common::{Result, Statistics}; +use datafusion_physical_expr_common::sort_expr::LexOrdering; use object_store::ObjectMeta; use object_store::path::Path; use std::any::Any; use std::collections::HashMap; use std::fmt::{Debug, Formatter}; +use std::ops::Deref; use std::sync::Arc; use std::time::Duration; @@ -31,16 +37,61 @@ pub use super::list_files_cache::{ DEFAULT_LIST_FILES_CACHE_MEMORY_LIMIT, DEFAULT_LIST_FILES_CACHE_TTL, }; -/// A cache for [`Statistics`]. +/// Cached metadata for a file, including statistics and ordering. +/// +/// This struct embeds the [`ObjectMeta`] used for cache validation, +/// along with the cached statistics and ordering information. +#[derive(Debug, Clone)] +pub struct CachedFileMetadata { + /// File metadata used for cache validation (size, last_modified). + pub meta: ObjectMeta, + /// Cached statistics for the file, if available. + pub statistics: Arc, + /// Cached ordering for the file. + pub ordering: Option, +} + +impl CachedFileMetadata { + /// Create a new cached file metadata entry. + pub fn new( + meta: ObjectMeta, + statistics: Arc, + ordering: Option, + ) -> Self { + Self { + meta, + statistics, + ordering, + } + } + + /// Check if this cached entry is still valid for the given metadata. 
+ /// + /// Returns true if the file size and last modified time match. + pub fn is_valid_for(&self, current_meta: &ObjectMeta) -> bool { + self.meta.size == current_meta.size + && self.meta.last_modified == current_meta.last_modified + } +} + +/// A cache for file statistics and orderings. +/// +/// This cache stores [`CachedFileMetadata`] which includes: +/// - File metadata for validation (size, last_modified) +/// - Statistics for the file +/// - Ordering information for the file /// /// If enabled via [`CacheManagerConfig::with_files_statistics_cache`] this /// cache avoids inferring the same file statistics repeatedly during the /// session lifetime. /// +/// The typical usage pattern is: +/// 1. Call `get(path)` to check for cached value +/// 2. If `Some(cached)`, validate with `cached.is_valid_for(¤t_meta)` +/// 3. If invalid or missing, compute new value and call `put(path, new_value)` +/// /// See [`crate::runtime_env::RuntimeEnv`] for more details -pub trait FileStatisticsCache: - CacheAccessor, Extra = ObjectMeta> -{ +pub trait FileStatisticsCache: CacheAccessor { /// Retrieves the information about the entries currently cached. fn list_entries(&self) -> HashMap; } @@ -58,6 +109,63 @@ pub struct FileStatisticsCacheEntry { pub table_size_bytes: Precision, /// Size of the statistics entry, in bytes. pub statistics_size_bytes: usize, + /// Whether ordering information is cached for this file. + pub has_ordering: bool, +} + +/// Cached file listing. +/// +/// TTL expiration is handled internally by the cache implementation. +#[derive(Debug, Clone, PartialEq)] +pub struct CachedFileList { + /// The cached file list. + pub files: Arc>, +} + +impl CachedFileList { + /// Create a new cached file list. + pub fn new(files: Vec) -> Self { + Self { + files: Arc::new(files), + } + } + + /// Filter the files by prefix. 
+ fn filter_by_prefix(&self, prefix: &Option) -> Vec { + match prefix { + Some(prefix) => self + .files + .iter() + .filter(|meta| meta.location.as_ref().starts_with(prefix.as_ref())) + .cloned() + .collect(), + None => self.files.as_ref().clone(), + } + } + + /// Returns files matching the given prefix. + /// + /// When prefix is `None`, returns a clone of the `Arc` (no data copy). + /// When filtering is needed, returns a new `Arc` with filtered results (clones each matching [`ObjectMeta`]). + pub fn files_matching_prefix(&self, prefix: &Option) -> Arc> { + match prefix { + None => Arc::clone(&self.files), + Some(p) => Arc::new(self.filter_by_prefix(&Some(p.clone()))), + } + } +} + +impl Deref for CachedFileList { + type Target = Arc>; + fn deref(&self) -> &Self::Target { + &self.files + } +} + +impl From> for CachedFileList { + fn from(files: Vec) -> Self { + Self::new(files) + } } /// Cache for storing the [`ObjectMeta`]s that result from listing a path @@ -67,21 +175,12 @@ pub struct FileStatisticsCacheEntry { /// especially when done over remote object stores. /// /// The cache key is always the table's base path, ensuring a stable cache key. -/// The `Extra` type is `Option`, representing an optional prefix filter -/// (relative to the table base path) for partition-aware lookups. -/// -/// When `get_with_extra(key, Some(prefix))` is called: -/// - The cache entry for `key` (table base path) is fetched -/// - Results are filtered to only include files matching `key/prefix` -/// - Filtered results are returned without making a storage call +/// The cached value is a [`CachedFileList`] containing the files and a timestamp. /// -/// This enables efficient partition pruning: a single cached listing of the -/// full table can serve queries for any partition subset. +/// Partition filtering is done after retrieval using [`CachedFileList::files_matching_prefix`]. /// /// See [`crate::runtime_env::RuntimeEnv`] for more details. 
-pub trait ListFilesCache: - CacheAccessor>, Extra = Option> -{ +pub trait ListFilesCache: CacheAccessor { /// Returns the cache's memory limit in bytes. fn cache_limit(&self) -> usize; @@ -93,6 +192,12 @@ pub trait ListFilesCache: /// Updates the cache with a new TTL (time-to-live). fn update_cache_ttl(&self, ttl: Option); + + /// Retrieves the information about the entries currently cached. + fn list_entries(&self) -> HashMap; + + /// Drop all entries for the given table reference. + fn drop_table_entries(&self, table_ref: &Option) -> Result<()>; } /// Generic file-embedded metadata used with [`FileMetadataCache`]. @@ -113,9 +218,44 @@ pub trait FileMetadata: Any + Send + Sync { fn extra_info(&self) -> HashMap; } +/// Cached file metadata entry with validation information. +#[derive(Clone)] +pub struct CachedFileMetadataEntry { + /// File metadata used for cache validation (size, last_modified). + pub meta: ObjectMeta, + /// The cached file metadata. + pub file_metadata: Arc, +} + +impl CachedFileMetadataEntry { + /// Create a new cached file metadata entry. + pub fn new(meta: ObjectMeta, file_metadata: Arc) -> Self { + Self { + meta, + file_metadata, + } + } + + /// Check if this cached entry is still valid for the given metadata. + pub fn is_valid_for(&self, current_meta: &ObjectMeta) -> bool { + self.meta.size == current_meta.size + && self.meta.last_modified == current_meta.last_modified + } +} + +impl Debug for CachedFileMetadataEntry { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CachedFileMetadataEntry") + .field("meta", &self.meta) + .field("memory_size", &self.file_metadata.memory_size()) + .finish() + } +} + /// Cache for file-embedded metadata. /// -/// This cache stores per-file metadata in the form of [`FileMetadata`], +/// This cache stores per-file metadata in the form of [`CachedFileMetadataEntry`], +/// which includes the [`ObjectMeta`] for validation. 
/// /// For example, the built in [`ListingTable`] uses this cache to avoid parsing /// Parquet footers multiple times for the same file. @@ -124,12 +264,15 @@ pub trait FileMetadata: Any + Send + Sync { /// and users can also provide their own implementations to implement custom /// caching strategies. /// +/// The typical usage pattern is: +/// 1. Call `get(path)` to check for cached value +/// 2. If `Some(cached)`, validate with `cached.is_valid_for(¤t_meta)` +/// 3. If invalid or missing, compute new value and call `put(path, new_value)` +/// /// See [`crate::runtime_env::RuntimeEnv`] for more details. /// /// [`ListingTable`]: https://docs.rs/datafusion/latest/datafusion/datasource/listing/struct.ListingTable.html -pub trait FileMetadataCache: - CacheAccessor, Extra = ObjectMeta> -{ +pub trait FileMetadataCache: CacheAccessor { /// Returns the cache's memory limit in bytes. fn cache_limit(&self) -> usize; diff --git a/datafusion/execution/src/cache/cache_unit.rs b/datafusion/execution/src/cache/cache_unit.rs index 5351df449a7c..d98d23821ec7 100644 --- a/datafusion/execution/src/cache/cache_unit.rs +++ b/datafusion/execution/src/cache/cache_unit.rs @@ -16,46 +16,80 @@ // under the License. use std::collections::HashMap; -use std::sync::Arc; use crate::cache::CacheAccessor; -use crate::cache::cache_manager::{FileStatisticsCache, FileStatisticsCacheEntry}; - -use datafusion_common::Statistics; +use crate::cache::cache_manager::{ + CachedFileMetadata, FileStatisticsCache, FileStatisticsCacheEntry, +}; use dashmap::DashMap; -use object_store::ObjectMeta; use object_store::path::Path; pub use crate::cache::DefaultFilesMetadataCache; /// Default implementation of [`FileStatisticsCache`] /// -/// Stores collected statistics for files +/// Stores cached file metadata (statistics and orderings) for files. +/// +/// The typical usage pattern is: +/// 1. Call `get(path)` to check for cached value +/// 2. 
If `Some(cached)`, validate with `cached.is_valid_for(¤t_meta)` +/// 3. If invalid or missing, compute new value and call `put(path, new_value)` /// -/// Cache is invalided when file size or last modification has changed +/// Uses DashMap for lock-free concurrent access. /// /// [`FileStatisticsCache`]: crate::cache::cache_manager::FileStatisticsCache #[derive(Default)] pub struct DefaultFileStatisticsCache { - statistics: DashMap)>, + cache: DashMap, +} + +impl CacheAccessor for DefaultFileStatisticsCache { + fn get(&self, key: &Path) -> Option { + self.cache.get(key).map(|entry| entry.value().clone()) + } + + fn put(&self, key: &Path, value: CachedFileMetadata) -> Option { + self.cache.insert(key.clone(), value) + } + + fn remove(&self, k: &Path) -> Option { + self.cache.remove(k).map(|(_, entry)| entry) + } + + fn contains_key(&self, k: &Path) -> bool { + self.cache.contains_key(k) + } + + fn len(&self) -> usize { + self.cache.len() + } + + fn clear(&self) { + self.cache.clear(); + } + + fn name(&self) -> String { + "DefaultFileStatisticsCache".to_string() + } } impl FileStatisticsCache for DefaultFileStatisticsCache { fn list_entries(&self) -> HashMap { let mut entries = HashMap::::new(); - for entry in &self.statistics { + for entry in self.cache.iter() { let path = entry.key(); - let (object_meta, stats) = entry.value(); + let cached = entry.value(); entries.insert( path.clone(), FileStatisticsCacheEntry { - object_meta: object_meta.clone(), - num_rows: stats.num_rows, - num_columns: stats.column_statistics.len(), - table_size_bytes: stats.total_byte_size, + object_meta: cached.meta.clone(), + num_rows: cached.statistics.num_rows, + num_columns: cached.statistics.column_statistics.len(), + table_size_bytes: cached.statistics.total_byte_size, statistics_size_bytes: 0, // TODO: set to the real size in the future + has_ordering: cached.ordering.is_some(), }, ); } @@ -64,141 +98,319 @@ impl FileStatisticsCache for DefaultFileStatisticsCache { } } -impl 
CacheAccessor> for DefaultFileStatisticsCache { - type Extra = ObjectMeta; +#[cfg(test)] +mod tests { + use super::*; + use crate::cache::CacheAccessor; + use crate::cache::cache_manager::{ + CachedFileMetadata, FileStatisticsCache, FileStatisticsCacheEntry, + }; + use arrow::array::RecordBatch; + use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; + use chrono::DateTime; + use datafusion_common::Statistics; + use datafusion_common::stats::Precision; + use datafusion_expr::ColumnarValue; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; + use object_store::ObjectMeta; + use object_store::path::Path; + use std::sync::Arc; - /// Get `Statistics` for file location. - fn get(&self, k: &Path) -> Option> { - self.statistics - .get(k) - .map(|s| Some(Arc::clone(&s.value().1))) - .unwrap_or(None) + fn create_test_meta(path: &str, size: u64) -> ObjectMeta { + ObjectMeta { + location: Path::from(path), + last_modified: DateTime::parse_from_rfc3339("2022-09-27T22:36:00+02:00") + .unwrap() + .into(), + size, + e_tag: None, + version: None, + } } - /// Get `Statistics` for file location. Returns None if file has changed or not found. 
- fn get_with_extra(&self, k: &Path, e: &Self::Extra) -> Option> { - self.statistics - .get(k) - .map(|s| { - let (saved_meta, statistics) = s.value(); - if saved_meta.size != e.size - || saved_meta.last_modified != e.last_modified - { - // file has changed - None - } else { - Some(Arc::clone(statistics)) - } - }) - .unwrap_or(None) - } + #[test] + fn test_statistics_cache() { + let meta = create_test_meta("test", 1024); + let cache = DefaultFileStatisticsCache::default(); - /// Save collected file statistics - fn put(&self, _key: &Path, _value: Arc) -> Option> { - panic!("Put cache in DefaultFileStatisticsCache without Extra not supported.") - } + let schema = Schema::new(vec![Field::new( + "test_column", + DataType::Timestamp(TimeUnit::Second, None), + false, + )]); - fn put_with_extra( - &self, - key: &Path, - value: Arc, - e: &Self::Extra, - ) -> Option> { - self.statistics - .insert(key.clone(), (e.clone(), value)) - .map(|x| x.1) - } + // Cache miss + assert!(cache.get(&meta.location).is_none()); + + // Put a value + let cached_value = CachedFileMetadata::new( + meta.clone(), + Arc::new(Statistics::new_unknown(&schema)), + None, + ); + cache.put(&meta.location, cached_value); + + // Cache hit + let result = cache.get(&meta.location); + assert!(result.is_some()); + let cached = result.unwrap(); + assert!(cached.is_valid_for(&meta)); + + // File size changed - validation should fail + let meta2 = create_test_meta("test", 2048); + let cached = cache.get(&meta2.location).unwrap(); + assert!(!cached.is_valid_for(&meta2)); + + // Update with new value + let cached_value2 = CachedFileMetadata::new( + meta2.clone(), + Arc::new(Statistics::new_unknown(&schema)), + None, + ); + cache.put(&meta2.location, cached_value2); - fn remove(&self, k: &Path) -> Option> { - self.statistics.remove(k).map(|x| x.1.1) + // Test list_entries + let entries = cache.list_entries(); + assert_eq!(entries.len(), 1); + let entry = entries.get(&Path::from("test")).unwrap(); + 
assert_eq!(entry.object_meta.size, 2048); // Should be updated value } - fn contains_key(&self, k: &Path) -> bool { - self.statistics.contains_key(k) + #[derive(Clone, Debug, PartialEq, Eq, Hash)] + struct MockExpr {} + + impl std::fmt::Display for MockExpr { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "MockExpr") + } } - fn len(&self) -> usize { - self.statistics.len() + impl PhysicalExpr for MockExpr { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn data_type( + &self, + _input_schema: &Schema, + ) -> datafusion_common::Result { + Ok(DataType::Int32) + } + + fn nullable(&self, _input_schema: &Schema) -> datafusion_common::Result { + Ok(false) + } + + fn evaluate( + &self, + _batch: &RecordBatch, + ) -> datafusion_common::Result { + unimplemented!() + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> datafusion_common::Result> { + assert!(children.is_empty()); + Ok(self) + } + + fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "MockExpr") + } } - fn clear(&self) { - self.statistics.clear() + fn ordering() -> LexOrdering { + let expr = Arc::new(MockExpr {}) as Arc; + LexOrdering::new(vec![PhysicalSortExpr::new_default(expr)]).unwrap() } - fn name(&self) -> String { - "DefaultFileStatisticsCache".to_string() + + #[test] + fn test_ordering_cache() { + let meta = create_test_meta("test.parquet", 100); + let cache = DefaultFileStatisticsCache::default(); + + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + + // Cache statistics with no ordering + let cached_value = CachedFileMetadata::new( + meta.clone(), + Arc::new(Statistics::new_unknown(&schema)), + None, // No ordering yet + ); + cache.put(&meta.location, cached_value); + + let result = cache.get(&meta.location).unwrap(); + assert!(result.ordering.is_none()); + + // Update to add ordering + let mut cached = 
cache.get(&meta.location).unwrap(); + if cached.is_valid_for(&meta) && cached.ordering.is_none() { + cached.ordering = Some(ordering()); + } + cache.put(&meta.location, cached); + + let result2 = cache.get(&meta.location).unwrap(); + assert!(result2.ordering.is_some()); + + // Verify list_entries shows has_ordering = true + let entries = cache.list_entries(); + assert_eq!(entries.len(), 1); + assert!(entries.get(&meta.location).unwrap().has_ordering); } -} -#[cfg(test)] -mod tests { - use super::*; - use crate::cache::CacheAccessor; - use crate::cache::cache_manager::{FileStatisticsCache, FileStatisticsCacheEntry}; - use crate::cache::cache_unit::DefaultFileStatisticsCache; - use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; - use chrono::DateTime; - use datafusion_common::Statistics; - use datafusion_common::stats::Precision; - use object_store::ObjectMeta; - use object_store::path::Path; + #[test] + fn test_cache_invalidation_on_file_modification() { + let cache = DefaultFileStatisticsCache::default(); + let path = Path::from("test.parquet"); + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + + let meta_v1 = create_test_meta("test.parquet", 100); + + // Cache initial value + let cached_value = CachedFileMetadata::new( + meta_v1.clone(), + Arc::new(Statistics::new_unknown(&schema)), + None, + ); + cache.put(&path, cached_value); + + // File modified (size changed) + let meta_v2 = create_test_meta("test.parquet", 200); + + let cached = cache.get(&path).unwrap(); + // Should not be valid for new meta + assert!(!cached.is_valid_for(&meta_v2)); + + // Compute new value and update + let new_cached = CachedFileMetadata::new( + meta_v2.clone(), + Arc::new(Statistics::new_unknown(&schema)), + None, + ); + cache.put(&path, new_cached); + + // Should have new metadata + let result = cache.get(&path).unwrap(); + assert_eq!(result.meta.size, 200); + } #[test] - fn test_statistics_cache() { - let meta = ObjectMeta { - location: 
Path::from("test"), + fn test_ordering_cache_invalidation_on_file_modification() { + let cache = DefaultFileStatisticsCache::default(); + let path = Path::from("test.parquet"); + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + + // Cache with original metadata and ordering + let meta_v1 = ObjectMeta { + location: path.clone(), last_modified: DateTime::parse_from_rfc3339("2022-09-27T22:36:00+02:00") .unwrap() .into(), - size: 1024, + size: 100, + e_tag: None, + version: None, + }; + let ordering_v1 = ordering(); + let cached_v1 = CachedFileMetadata::new( + meta_v1.clone(), + Arc::new(Statistics::new_unknown(&schema)), + Some(ordering_v1), + ); + cache.put(&path, cached_v1); + + // Verify cached ordering is valid + let cached = cache.get(&path).unwrap(); + assert!(cached.is_valid_for(&meta_v1)); + assert!(cached.ordering.is_some()); + + // File modified (size changed) + let meta_v2 = ObjectMeta { + location: path.clone(), + last_modified: DateTime::parse_from_rfc3339("2022-09-28T10:00:00+02:00") + .unwrap() + .into(), + size: 200, // Changed e_tag: None, version: None, }; + + // Cache entry exists but should be invalid for new metadata + let cached = cache.get(&path).unwrap(); + assert!(!cached.is_valid_for(&meta_v2)); + + // Cache new version with different ordering + let ordering_v2 = ordering(); // New ordering instance + let cached_v2 = CachedFileMetadata::new( + meta_v2.clone(), + Arc::new(Statistics::new_unknown(&schema)), + Some(ordering_v2), + ); + cache.put(&path, cached_v2); + + // Old metadata should be invalid + let cached = cache.get(&path).unwrap(); + assert!(!cached.is_valid_for(&meta_v1)); + + // New metadata should be valid + assert!(cached.is_valid_for(&meta_v2)); + assert!(cached.ordering.is_some()); + } + + #[test] + fn test_list_entries() { let cache = DefaultFileStatisticsCache::default(); - assert!(cache.get_with_extra(&meta.location, &meta).is_none()); - - cache.put_with_extra( - &meta.location, - 
Statistics::new_unknown(&Schema::new(vec![Field::new( - "test_column", - DataType::Timestamp(TimeUnit::Second, None), - false, - )])) - .into(), - &meta, + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + + let meta1 = create_test_meta("test1.parquet", 100); + + let cached_value = CachedFileMetadata::new( + meta1.clone(), + Arc::new(Statistics::new_unknown(&schema)), + None, + ); + cache.put(&meta1.location, cached_value); + let meta2 = create_test_meta("test2.parquet", 200); + let cached_value = CachedFileMetadata::new( + meta2.clone(), + Arc::new(Statistics::new_unknown(&schema)), + Some(ordering()), ); - assert!(cache.get_with_extra(&meta.location, &meta).is_some()); - - // file size changed - let mut meta2 = meta.clone(); - meta2.size = 2048; - assert!(cache.get_with_extra(&meta2.location, &meta2).is_none()); - - // file last_modified changed - let mut meta2 = meta.clone(); - meta2.last_modified = DateTime::parse_from_rfc3339("2022-09-27T22:40:00+02:00") - .unwrap() - .into(); - assert!(cache.get_with_extra(&meta2.location, &meta2).is_none()); - - // different file - let mut meta2 = meta.clone(); - meta2.location = Path::from("test2"); - assert!(cache.get_with_extra(&meta2.location, &meta2).is_none()); - - // test the list_entries method + cache.put(&meta2.location, cached_value); + let entries = cache.list_entries(); assert_eq!( entries, - HashMap::from([( - Path::from("test"), - FileStatisticsCacheEntry { - object_meta: meta.clone(), - num_rows: Precision::Absent, - num_columns: 1, - table_size_bytes: Precision::Absent, - statistics_size_bytes: 0, - } - )]) + HashMap::from([ + ( + Path::from("test1.parquet"), + FileStatisticsCacheEntry { + object_meta: meta1, + num_rows: Precision::Absent, + num_columns: 1, + table_size_bytes: Precision::Absent, + statistics_size_bytes: 0, + has_ordering: false, + } + ), + ( + Path::from("test2.parquet"), + FileStatisticsCacheEntry { + object_meta: meta2, + num_rows: Precision::Absent, + num_columns: 
1, + table_size_bytes: Precision::Absent, + statistics_size_bytes: 0, + has_ordering: true, + } + ), + ]) ); } } diff --git a/datafusion/execution/src/cache/file_metadata_cache.rs b/datafusion/execution/src/cache/file_metadata_cache.rs index c7a24dd878e4..5e899d7dd9f8 100644 --- a/datafusion/execution/src/cache/file_metadata_cache.rs +++ b/datafusion/execution/src/cache/file_metadata_cache.rs @@ -15,22 +15,19 @@ // specific language governing permissions and limitations // under the License. -use std::{ - collections::HashMap, - sync::{Arc, Mutex}, -}; +use std::{collections::HashMap, sync::Mutex}; -use object_store::{ObjectMeta, path::Path}; +use object_store::path::Path; use crate::cache::{ CacheAccessor, - cache_manager::{FileMetadata, FileMetadataCache, FileMetadataCacheEntry}, + cache_manager::{CachedFileMetadataEntry, FileMetadataCache, FileMetadataCacheEntry}, lru_queue::LruQueue, }; /// Handles the inner state of the [`DefaultFilesMetadataCache`] struct. struct DefaultFilesMetadataCacheState { - lru_queue: LruQueue)>, + lru_queue: LruQueue, memory_limit: usize, memory_used: usize, cache_hits: HashMap, @@ -46,35 +43,18 @@ impl DefaultFilesMetadataCacheState { } } - /// Returns the respective entry from the cache, if it exists and the `size` and `last_modified` - /// properties from [`ObjectMeta`] match. + /// Returns the respective entry from the cache, if it exists. /// If the entry exists, it becomes the most recently used. 
- fn get(&mut self, k: &ObjectMeta) -> Option> { - self.lru_queue - .get(&k.location) - .map(|(object_meta, metadata)| { - if object_meta.size != k.size - || object_meta.last_modified != k.last_modified - { - None - } else { - *self.cache_hits.entry(k.location.clone()).or_insert(0) += 1; - Some(Arc::clone(metadata)) - } - }) - .unwrap_or(None) + fn get(&mut self, k: &Path) -> Option { + self.lru_queue.get(k).cloned().inspect(|_| { + *self.cache_hits.entry(k.clone()).or_insert(0) += 1; + }) } - /// Checks if the metadata is currently cached (entry exists and the `size` and `last_modified` - /// properties of [`ObjectMeta`] match). + /// Checks if the metadata is currently cached. /// The LRU queue is not updated. - fn contains_key(&self, k: &ObjectMeta) -> bool { - self.lru_queue - .peek(&k.location) - .map(|(object_meta, _)| { - object_meta.size == k.size && object_meta.last_modified == k.last_modified - }) - .unwrap_or(false) + fn contains_key(&self, k: &Path) -> bool { + self.lru_queue.peek(k).is_some() } /// Adds a new key-value pair to cache, meaning LRU entries might be evicted if required. @@ -82,35 +62,34 @@ impl DefaultFilesMetadataCacheState { /// If the size of the metadata is greater than the `memory_limit`, the value is not inserted. 
fn put( &mut self, - key: ObjectMeta, - value: Arc, - ) -> Option> { - let value_size = value.memory_size(); + key: Path, + value: CachedFileMetadataEntry, + ) -> Option { + let value_size = value.file_metadata.memory_size(); // no point in trying to add this value to the cache if it cannot fit entirely if value_size > self.memory_limit { return None; } - self.cache_hits.insert(key.location.clone(), 0); + self.cache_hits.insert(key.clone(), 0); // if the key is already in the cache, the old value is removed - let old_value = self.lru_queue.put(key.location.clone(), (key, value)); + let old_value = self.lru_queue.put(key, value); self.memory_used += value_size; - if let Some((_, ref old_metadata)) = old_value { - self.memory_used -= old_metadata.memory_size(); + if let Some(ref old_entry) = old_value { + self.memory_used -= old_entry.file_metadata.memory_size(); } self.evict_entries(); - old_value.map(|v| v.1) + old_value } /// Evicts entries from the LRU cache until `memory_used` is lower than `memory_limit`. fn evict_entries(&mut self) { while self.memory_used > self.memory_limit { if let Some(removed) = self.lru_queue.pop() { - let metadata: Arc = removed.1.1; - self.memory_used -= metadata.memory_size(); + self.memory_used -= removed.1.file_metadata.memory_size(); } else { // cache is empty while memory_used > memory_limit, cannot happen debug_assert!( @@ -123,11 +102,11 @@ impl DefaultFilesMetadataCacheState { } /// Removes an entry from the cache and returns it, if it exists. 
- fn remove(&mut self, k: &ObjectMeta) -> Option> { - if let Some((_, old_metadata)) = self.lru_queue.remove(&k.location) { - self.memory_used -= old_metadata.memory_size(); - self.cache_hits.remove(&k.location); - Some(old_metadata) + fn remove(&mut self, k: &Path) -> Option { + if let Some(old_entry) = self.lru_queue.remove(k) { + self.memory_used -= old_entry.file_metadata.memory_size(); + self.cache_hits.remove(k); + Some(old_entry) } else { None } @@ -150,8 +129,8 @@ impl DefaultFilesMetadataCacheState { /// /// Collected file embedded metadata cache. /// -/// The metadata for each file is invalidated when the file size or last -/// modification time have been changed. +/// The metadata for each file is validated by comparing the cached [`ObjectMeta`] +/// (size and last_modified) against the current file state using `cached.is_valid_for(¤t_meta)`. /// /// # Internal details /// @@ -160,11 +139,7 @@ impl DefaultFilesMetadataCacheState { /// size of the cached entries exceeds `memory_limit`, the least recently used entries /// are evicted until the total size is lower than `memory_limit`. /// -/// # `Extra` Handling -/// -/// Users should use the [`Self::get`] and [`Self::put`] methods. The -/// [`Self::get_with_extra`] and [`Self::put_with_extra`] methods simply call -/// `get` and `put`, respectively. 
+/// [`ObjectMeta`]: object_store::ObjectMeta pub struct DefaultFilesMetadataCache { // the state is wrapped in a Mutex to ensure the operations are atomic state: Mutex, @@ -189,78 +164,27 @@ impl DefaultFilesMetadataCache { } } -impl FileMetadataCache for DefaultFilesMetadataCache { - fn cache_limit(&self) -> usize { - let state = self.state.lock().unwrap(); - state.memory_limit - } - - fn update_cache_limit(&self, limit: usize) { - let mut state = self.state.lock().unwrap(); - state.memory_limit = limit; - state.evict_entries(); - } - - fn list_entries(&self) -> HashMap { - let state = self.state.lock().unwrap(); - let mut entries = HashMap::::new(); - - for (path, (object_meta, metadata)) in state.lru_queue.list_entries() { - entries.insert( - path.clone(), - FileMetadataCacheEntry { - object_meta: object_meta.clone(), - size_bytes: metadata.memory_size(), - hits: *state.cache_hits.get(path).expect("entry must exist"), - extra: metadata.extra_info(), - }, - ); - } - - entries - } -} - -impl CacheAccessor> for DefaultFilesMetadataCache { - type Extra = ObjectMeta; - - fn get(&self, k: &ObjectMeta) -> Option> { +impl CacheAccessor for DefaultFilesMetadataCache { + fn get(&self, key: &Path) -> Option { let mut state = self.state.lock().unwrap(); - state.get(k) - } - - fn get_with_extra( - &self, - k: &ObjectMeta, - _e: &Self::Extra, - ) -> Option> { - self.get(k) + state.get(key) } fn put( &self, - key: &ObjectMeta, - value: Arc, - ) -> Option> { + key: &Path, + value: CachedFileMetadataEntry, + ) -> Option { let mut state = self.state.lock().unwrap(); state.put(key.clone(), value) } - fn put_with_extra( - &self, - key: &ObjectMeta, - value: Arc, - _e: &Self::Extra, - ) -> Option> { - self.put(key, value) - } - - fn remove(&self, k: &ObjectMeta) -> Option> { + fn remove(&self, k: &Path) -> Option { let mut state = self.state.lock().unwrap(); state.remove(k) } - fn contains_key(&self, k: &ObjectMeta) -> bool { + fn contains_key(&self, k: &Path) -> bool { let state = 
self.state.lock().unwrap(); state.contains_key(k) } @@ -280,6 +204,38 @@ impl CacheAccessor> for DefaultFilesMetadataCa } } +impl FileMetadataCache for DefaultFilesMetadataCache { + fn cache_limit(&self) -> usize { + let state = self.state.lock().unwrap(); + state.memory_limit + } + + fn update_cache_limit(&self, limit: usize) { + let mut state = self.state.lock().unwrap(); + state.memory_limit = limit; + state.evict_entries(); + } + + fn list_entries(&self) -> HashMap { + let state = self.state.lock().unwrap(); + let mut entries = HashMap::::new(); + + for (path, entry) in state.lru_queue.list_entries() { + entries.insert( + path.clone(), + FileMetadataCacheEntry { + object_meta: entry.meta.clone(), + size_bytes: entry.file_metadata.memory_size(), + hits: *state.cache_hits.get(path).expect("entry must exist"), + extra: entry.file_metadata.extra_info(), + }, + ); + } + + entries + } +} + #[cfg(test)] mod tests { use std::collections::HashMap; @@ -287,7 +243,7 @@ mod tests { use crate::cache::CacheAccessor; use crate::cache::cache_manager::{ - FileMetadata, FileMetadataCache, FileMetadataCacheEntry, + CachedFileMetadataEntry, FileMetadata, FileMetadataCache, FileMetadataCacheEntry, }; use crate::cache::file_metadata_cache::DefaultFilesMetadataCache; use object_store::ObjectMeta; @@ -311,67 +267,77 @@ mod tests { } } - #[test] - fn test_default_file_metadata_cache() { - let object_meta = ObjectMeta { - location: Path::from("test"), + fn create_test_object_meta(path: &str, size: usize) -> ObjectMeta { + ObjectMeta { + location: Path::from(path), last_modified: chrono::DateTime::parse_from_rfc3339( "2025-07-29T12:12:12+00:00", ) .unwrap() .into(), - size: 1024, + size: size as u64, e_tag: None, version: None, - }; + } + } + + #[test] + fn test_default_file_metadata_cache() { + let object_meta = create_test_object_meta("test", 1024); let metadata: Arc = Arc::new(TestFileMetadata { metadata: "retrieved_metadata".to_owned(), }); let cache = 
DefaultFilesMetadataCache::new(1024 * 1024); - assert!(cache.get(&object_meta).is_none()); - // put - cache.put(&object_meta, Arc::clone(&metadata)); + // Cache miss + assert!(cache.get(&object_meta.location).is_none()); + + // Put a value + let cached_entry = + CachedFileMetadataEntry::new(object_meta.clone(), Arc::clone(&metadata)); + cache.put(&object_meta.location, cached_entry); - // get and contains of a valid entry - assert!(cache.contains_key(&object_meta)); - let value = cache.get(&object_meta); - assert!(value.is_some()); - let test_file_metadata = Arc::downcast::(value.unwrap()); + // Verify the cached value + assert!(cache.contains_key(&object_meta.location)); + let result = cache.get(&object_meta.location).unwrap(); + let test_file_metadata = Arc::downcast::(result.file_metadata); assert!(test_file_metadata.is_ok()); assert_eq!(test_file_metadata.unwrap().metadata, "retrieved_metadata"); - // file size changed - let mut object_meta2 = object_meta.clone(); - object_meta2.size = 2048; - assert!(cache.get(&object_meta2).is_none()); - assert!(!cache.contains_key(&object_meta2)); - - // file last_modified changed - let mut object_meta2 = object_meta.clone(); - object_meta2.last_modified = - chrono::DateTime::parse_from_rfc3339("2025-07-29T13:13:13+00:00") - .unwrap() - .into(); - assert!(cache.get(&object_meta2).is_none()); - assert!(!cache.contains_key(&object_meta2)); - - // different file - let mut object_meta2 = object_meta.clone(); - object_meta2.location = Path::from("test2"); - assert!(cache.get(&object_meta2).is_none()); - assert!(!cache.contains_key(&object_meta2)); + // Cache hit - check validation + let result2 = cache.get(&object_meta.location).unwrap(); + assert!(result2.is_valid_for(&object_meta)); + + // File size changed - closure should detect invalidity + let object_meta2 = create_test_object_meta("test", 2048); + let result3 = cache.get(&object_meta2.location).unwrap(); + // Cached entry should NOT be valid for new meta + 
assert!(!result3.is_valid_for(&object_meta2)); + + // Return new entry + let new_entry = + CachedFileMetadataEntry::new(object_meta2.clone(), Arc::clone(&metadata)); + cache.put(&object_meta2.location, new_entry); + + let result4 = cache.get(&object_meta2.location).unwrap(); + assert_eq!(result4.meta.size, 2048); // remove - cache.remove(&object_meta); - assert!(cache.get(&object_meta).is_none()); - assert!(!cache.contains_key(&object_meta)); + cache.remove(&object_meta.location); + assert!(!cache.contains_key(&object_meta.location)); // len and clear - cache.put(&object_meta, Arc::clone(&metadata)); - cache.put(&object_meta2, metadata); + let object_meta3 = create_test_object_meta("test3", 100); + cache.put( + &object_meta.location, + CachedFileMetadataEntry::new(object_meta.clone(), Arc::clone(&metadata)), + ); + cache.put( + &object_meta3.location, + CachedFileMetadataEntry::new(object_meta3.clone(), Arc::clone(&metadata)), + ); assert_eq!(cache.len(), 2); cache.clear(); assert_eq!(cache.len(), 0); @@ -402,92 +368,129 @@ mod tests { let (object_meta2, metadata2) = generate_test_metadata_with_size("2", 500); let (object_meta3, metadata3) = generate_test_metadata_with_size("3", 300); - cache.put(&object_meta1, metadata1); - cache.put(&object_meta2, metadata2); - cache.put(&object_meta3, metadata3); + cache.put( + &object_meta1.location, + CachedFileMetadataEntry::new(object_meta1.clone(), metadata1), + ); + cache.put( + &object_meta2.location, + CachedFileMetadataEntry::new(object_meta2.clone(), metadata2), + ); + cache.put( + &object_meta3.location, + CachedFileMetadataEntry::new(object_meta3.clone(), metadata3), + ); // all entries will fit assert_eq!(cache.len(), 3); assert_eq!(cache.memory_used(), 900); - assert!(cache.contains_key(&object_meta1)); - assert!(cache.contains_key(&object_meta2)); - assert!(cache.contains_key(&object_meta3)); + assert!(cache.contains_key(&object_meta1.location)); + assert!(cache.contains_key(&object_meta2.location)); + 
assert!(cache.contains_key(&object_meta3.location)); // add a new entry which will remove the least recently used ("1") let (object_meta4, metadata4) = generate_test_metadata_with_size("4", 200); - cache.put(&object_meta4, metadata4); + cache.put( + &object_meta4.location, + CachedFileMetadataEntry::new(object_meta4.clone(), metadata4), + ); assert_eq!(cache.len(), 3); assert_eq!(cache.memory_used(), 1000); - assert!(!cache.contains_key(&object_meta1)); - assert!(cache.contains_key(&object_meta4)); + assert!(!cache.contains_key(&object_meta1.location)); + assert!(cache.contains_key(&object_meta4.location)); // get entry "2", which will move it to the top of the queue, and add a new one which will // remove the new least recently used ("3") - cache.get(&object_meta2); + let _ = cache.get(&object_meta2.location); let (object_meta5, metadata5) = generate_test_metadata_with_size("5", 100); - cache.put(&object_meta5, metadata5); + cache.put( + &object_meta5.location, + CachedFileMetadataEntry::new(object_meta5.clone(), metadata5), + ); assert_eq!(cache.len(), 3); assert_eq!(cache.memory_used(), 800); - assert!(!cache.contains_key(&object_meta3)); - assert!(cache.contains_key(&object_meta5)); + assert!(!cache.contains_key(&object_meta3.location)); + assert!(cache.contains_key(&object_meta5.location)); // new entry which will not be able to fit in the 1000 bytes allocated let (object_meta6, metadata6) = generate_test_metadata_with_size("6", 1200); - cache.put(&object_meta6, metadata6); + cache.put( + &object_meta6.location, + CachedFileMetadataEntry::new(object_meta6.clone(), metadata6), + ); assert_eq!(cache.len(), 3); assert_eq!(cache.memory_used(), 800); - assert!(!cache.contains_key(&object_meta6)); + assert!(!cache.contains_key(&object_meta6.location)); // new entry which is able to fit without removing any entry let (object_meta7, metadata7) = generate_test_metadata_with_size("7", 200); - cache.put(&object_meta7, metadata7); + cache.put( + &object_meta7.location, + 
CachedFileMetadataEntry::new(object_meta7.clone(), metadata7), + ); assert_eq!(cache.len(), 4); assert_eq!(cache.memory_used(), 1000); - assert!(cache.contains_key(&object_meta7)); + assert!(cache.contains_key(&object_meta7.location)); // new entry which will remove all other entries let (object_meta8, metadata8) = generate_test_metadata_with_size("8", 999); - cache.put(&object_meta8, metadata8); + cache.put( + &object_meta8.location, + CachedFileMetadataEntry::new(object_meta8.clone(), metadata8), + ); assert_eq!(cache.len(), 1); assert_eq!(cache.memory_used(), 999); - assert!(cache.contains_key(&object_meta8)); + assert!(cache.contains_key(&object_meta8.location)); // when updating an entry, the previous ones are not unnecessarily removed let (object_meta9, metadata9) = generate_test_metadata_with_size("9", 300); let (object_meta10, metadata10) = generate_test_metadata_with_size("10", 200); let (object_meta11_v1, metadata11_v1) = generate_test_metadata_with_size("11", 400); - cache.put(&object_meta9, metadata9); - cache.put(&object_meta10, metadata10); - cache.put(&object_meta11_v1, metadata11_v1); + cache.put( + &object_meta9.location, + CachedFileMetadataEntry::new(object_meta9.clone(), metadata9), + ); + cache.put( + &object_meta10.location, + CachedFileMetadataEntry::new(object_meta10.clone(), metadata10), + ); + cache.put( + &object_meta11_v1.location, + CachedFileMetadataEntry::new(object_meta11_v1.clone(), metadata11_v1), + ); assert_eq!(cache.memory_used(), 900); assert_eq!(cache.len(), 3); let (object_meta11_v2, metadata11_v2) = generate_test_metadata_with_size("11", 500); - cache.put(&object_meta11_v2, metadata11_v2); + cache.put( + &object_meta11_v2.location, + CachedFileMetadataEntry::new(object_meta11_v2.clone(), metadata11_v2), + ); assert_eq!(cache.memory_used(), 1000); assert_eq!(cache.len(), 3); - assert!(cache.contains_key(&object_meta9)); - assert!(cache.contains_key(&object_meta10)); - assert!(cache.contains_key(&object_meta11_v2)); - 
assert!(!cache.contains_key(&object_meta11_v1)); + assert!(cache.contains_key(&object_meta9.location)); + assert!(cache.contains_key(&object_meta10.location)); + assert!(cache.contains_key(&object_meta11_v2.location)); // when updating an entry that now exceeds the limit, the LRU ("9") needs to be removed let (object_meta11_v3, metadata11_v3) = generate_test_metadata_with_size("11", 501); - cache.put(&object_meta11_v3, metadata11_v3); + cache.put( + &object_meta11_v3.location, + CachedFileMetadataEntry::new(object_meta11_v3.clone(), metadata11_v3), + ); assert_eq!(cache.memory_used(), 701); assert_eq!(cache.len(), 2); - assert!(cache.contains_key(&object_meta10)); - assert!(cache.contains_key(&object_meta11_v3)); - assert!(!cache.contains_key(&object_meta11_v2)); + assert!(cache.contains_key(&object_meta10.location)); + assert!(cache.contains_key(&object_meta11_v3.location)); // manually removing an entry that is not the LRU - cache.remove(&object_meta11_v3); + cache.remove(&object_meta11_v3.location); assert_eq!(cache.len(), 1); assert_eq!(cache.memory_used(), 200); - assert!(cache.contains_key(&object_meta10)); - assert!(!cache.contains_key(&object_meta11_v3)); + assert!(cache.contains_key(&object_meta10.location)); + assert!(!cache.contains_key(&object_meta11_v3.location)); // clear cache.clear(); @@ -498,17 +501,26 @@ mod tests { let (object_meta12, metadata12) = generate_test_metadata_with_size("12", 300); let (object_meta13, metadata13) = generate_test_metadata_with_size("13", 200); let (object_meta14, metadata14) = generate_test_metadata_with_size("14", 500); - cache.put(&object_meta12, metadata12); - cache.put(&object_meta13, metadata13); - cache.put(&object_meta14, metadata14); + cache.put( + &object_meta12.location, + CachedFileMetadataEntry::new(object_meta12.clone(), metadata12), + ); + cache.put( + &object_meta13.location, + CachedFileMetadataEntry::new(object_meta13.clone(), metadata13), + ); + cache.put( + &object_meta14.location, + 
CachedFileMetadataEntry::new(object_meta14.clone(), metadata14), + ); assert_eq!(cache.len(), 3); assert_eq!(cache.memory_used(), 1000); cache.update_cache_limit(600); assert_eq!(cache.len(), 1); assert_eq!(cache.memory_used(), 500); - assert!(!cache.contains_key(&object_meta12)); - assert!(!cache.contains_key(&object_meta13)); - assert!(cache.contains_key(&object_meta14)); + assert!(!cache.contains_key(&object_meta12.location)); + assert!(!cache.contains_key(&object_meta13.location)); + assert!(cache.contains_key(&object_meta14.location)); } #[test] @@ -519,9 +531,18 @@ mod tests { let (object_meta3, metadata3) = generate_test_metadata_with_size("3", 300); // initial entries, all will have hits = 0 - cache.put(&object_meta1, metadata1); - cache.put(&object_meta2, metadata2); - cache.put(&object_meta3, metadata3); + cache.put( + &object_meta1.location, + CachedFileMetadataEntry::new(object_meta1.clone(), metadata1), + ); + cache.put( + &object_meta2.location, + CachedFileMetadataEntry::new(object_meta2.clone(), metadata2), + ); + cache.put( + &object_meta3.location, + CachedFileMetadataEntry::new(object_meta3.clone(), metadata3), + ); assert_eq!( cache.list_entries(), HashMap::from([ @@ -565,7 +586,7 @@ mod tests { ); // new hit on "1" - cache.get(&object_meta1); + let _ = cache.get(&object_meta1.location); assert_eq!( cache.list_entries(), HashMap::from([ @@ -610,7 +631,10 @@ mod tests { // new entry, will evict "2" let (object_meta4, metadata4) = generate_test_metadata_with_size("4", 600); - cache.put(&object_meta4, metadata4); + cache.put( + &object_meta4.location, + CachedFileMetadataEntry::new(object_meta4.clone(), metadata4), + ); assert_eq!( cache.list_entries(), HashMap::from([ @@ -655,7 +679,10 @@ mod tests { // replace entry "1" let (object_meta1_new, metadata1_new) = generate_test_metadata_with_size("1", 50); - cache.put(&object_meta1_new, metadata1_new); + cache.put( + &object_meta1_new.location, + CachedFileMetadataEntry::new(object_meta1_new.clone(), 
metadata1_new), + ); assert_eq!( cache.list_entries(), HashMap::from([ @@ -699,7 +726,7 @@ mod tests { ); // remove entry "4" - cache.remove(&object_meta4); + cache.remove(&object_meta4.location); assert_eq!( cache.list_entries(), HashMap::from([ diff --git a/datafusion/execution/src/cache/list_files_cache.rs b/datafusion/execution/src/cache/list_files_cache.rs index 661bc47b5468..b1b8e6b50016 100644 --- a/datafusion/execution/src/cache/list_files_cache.rs +++ b/datafusion/execution/src/cache/list_files_cache.rs @@ -17,14 +17,20 @@ use std::mem::size_of; use std::{ + collections::HashMap, sync::{Arc, Mutex}, time::Duration, }; +use datafusion_common::TableReference; use datafusion_common::instant::Instant; use object_store::{ObjectMeta, path::Path}; -use crate::cache::{CacheAccessor, cache_manager::ListFilesCache, lru_queue::LruQueue}; +use crate::cache::{ + CacheAccessor, + cache_manager::{CachedFileList, ListFilesCache}, + lru_queue::LruQueue, +}; pub trait TimeProvider: Send + Sync + 'static { fn now(&self) -> Instant; @@ -50,11 +56,10 @@ impl TimeProvider for SystemTimeProvider { /// the cache exceeds `memory_limit`, the least recently used entries are evicted until the total /// size is lower than the `memory_limit`. /// -/// # `Extra` Handling +/// # Cache API /// -/// Users should use the [`Self::get`] and [`Self::put`] methods. The -/// [`Self::get_with_extra`] and [`Self::put_with_extra`] methods simply call -/// `get` and `put`, respectively. +/// Uses `get` and `put` methods for cache operations. TTL validation is handled internally - +/// expired entries return `None` from `get`. pub struct DefaultListFilesCache { state: Mutex, time_provider: Arc, @@ -84,42 +89,30 @@ impl DefaultListFilesCache { self.time_provider = provider; self } - - /// Returns the cache's memory limit in bytes. - pub fn cache_limit(&self) -> usize { - self.state.lock().unwrap().memory_limit - } - - /// Updates the cache with a new memory limit in bytes. 
- pub fn update_cache_limit(&self, limit: usize) { - let mut state = self.state.lock().unwrap(); - state.memory_limit = limit; - state.evict_entries(); - } - - /// Returns the TTL (time-to-live) applied to cache entries. - pub fn cache_ttl(&self) -> Option { - self.state.lock().unwrap().ttl - } } -struct ListFilesEntry { - metas: Arc>, - size_bytes: usize, - expires: Option, +#[derive(Clone, PartialEq, Debug)] +pub struct ListFilesEntry { + pub metas: CachedFileList, + pub size_bytes: usize, + pub expires: Option, } impl ListFilesEntry { fn try_new( - metas: Arc>, + cached_file_list: CachedFileList, ttl: Option, now: Instant, ) -> Option { - let size_bytes = (metas.capacity() * size_of::()) - + metas.iter().map(meta_heap_bytes).reduce(|acc, b| acc + b)?; + let size_bytes = (cached_file_list.files.capacity() * size_of::()) + + cached_file_list + .files + .iter() + .map(meta_heap_bytes) + .reduce(|acc, b| acc + b)?; Some(Self { - metas, + metas: cached_file_list, size_bytes, expires: ttl.map(|t| now + t), }) @@ -146,9 +139,20 @@ pub const DEFAULT_LIST_FILES_CACHE_MEMORY_LIMIT: usize = 1024 * 1024; // 1MiB /// The default cache TTL for the [`DefaultListFilesCache`] pub const DEFAULT_LIST_FILES_CACHE_TTL: Option = None; // Infinite +/// Key for [`DefaultListFilesCache`] +/// +/// Each entry is scoped to its use within a specific table so that the cache +/// can differentiate between identical paths in different tables, and +/// table-level cache invalidation. +#[derive(PartialEq, Eq, Hash, Clone, Debug)] +pub struct TableScopedPath { + pub table: Option, + pub path: Path, +} + /// Handles the inner state of the [`DefaultListFilesCache`] struct. pub struct DefaultListFilesCacheState { - lru_queue: LruQueue, + lru_queue: LruQueue, memory_limit: usize, memory_used: usize, ttl: Option, @@ -175,65 +179,22 @@ impl DefaultListFilesCacheState { } } - /// Performs a prefix-aware cache lookup. 
- /// - /// # Arguments - /// * `table_base` - The table's base path (the cache key) - /// * `prefix` - Optional prefix filter relative to the table base path - /// * `now` - Current time for expiration checking - /// - /// # Behavior - /// - Fetches the cache entry for `table_base` - /// - If `prefix` is `Some`, filters results to only files matching `table_base/prefix` - /// - Returns the (potentially filtered) results + /// Gets an entry from the cache, checking for expiration. /// - /// # Example - /// ```text - /// get_with_prefix("my_table", Some("a=1"), now) - /// → Fetch cache entry for "my_table" - /// → Filter to files matching "my_table/a=1/*" - /// → Return filtered results - /// ``` - fn get_with_prefix( - &mut self, - table_base: &Path, - prefix: Option<&Path>, - now: Instant, - ) -> Option>> { - let entry = self.lru_queue.get(table_base)?; + /// Returns the cached file list if it exists and hasn't expired. + /// If the entry has expired, it is removed from the cache. + fn get(&mut self, key: &TableScopedPath, now: Instant) -> Option { + let entry = self.lru_queue.get(key)?; // Check expiration if let Some(exp) = entry.expires && now > exp { - self.remove(table_base); + self.remove(key); return None; } - // Early return if no prefix filter - return all files - let Some(prefix) = prefix else { - return Some(Arc::clone(&entry.metas)); - }; - - // Build the full prefix path: table_base/prefix - let mut parts: Vec<_> = table_base.parts().collect(); - parts.extend(prefix.parts()); - let full_prefix = Path::from_iter(parts); - let full_prefix_str = full_prefix.as_ref(); - - // Filter files to only those matching the prefix - let filtered: Vec = entry - .metas - .iter() - .filter(|meta| meta.location.as_ref().starts_with(full_prefix_str)) - .cloned() - .collect(); - - if filtered.is_empty() { - None - } else { - Some(Arc::new(filtered)) - } + Some(entry.metas.clone()) } /// Checks if the respective entry is currently cached. 
@@ -241,7 +202,7 @@ impl DefaultListFilesCacheState { /// If the entry has expired by `now` it is removed from the cache. /// /// The LRU queue is not updated. - fn contains_key(&mut self, k: &Path, now: Instant) -> bool { + fn contains_key(&mut self, k: &TableScopedPath, now: Instant) -> bool { let Some(entry) = self.lru_queue.peek(k) else { return false; }; @@ -262,10 +223,10 @@ impl DefaultListFilesCacheState { /// If the size of the entry is greater than the `memory_limit`, the value is not inserted. fn put( &mut self, - key: &Path, - value: Arc>, + key: &TableScopedPath, + value: CachedFileList, now: Instant, - ) -> Option>> { + ) -> Option { let entry = ListFilesEntry::try_new(value, self.ttl, now)?; let entry_size = entry.size_bytes; @@ -304,7 +265,7 @@ impl DefaultListFilesCacheState { } /// Removes an entry from the cache and returns it, if it exists. - fn remove(&mut self, k: &Path) -> Option>> { + fn remove(&mut self, k: &TableScopedPath) -> Option { if let Some(entry) = self.lru_queue.remove(k) { self.memory_used -= entry.size_bytes; Some(entry.metas) @@ -325,88 +286,29 @@ impl DefaultListFilesCacheState { } } -impl ListFilesCache for DefaultListFilesCache { - fn cache_limit(&self) -> usize { - let state = self.state.lock().unwrap(); - state.memory_limit - } - - fn cache_ttl(&self) -> Option { - let state = self.state.lock().unwrap(); - state.ttl - } - - fn update_cache_limit(&self, limit: usize) { - let mut state = self.state.lock().unwrap(); - state.memory_limit = limit; - state.evict_entries(); - } - - fn update_cache_ttl(&self, ttl: Option) { - let mut state = self.state.lock().unwrap(); - state.ttl = ttl; - state.evict_entries(); - } -} - -impl CacheAccessor>> for DefaultListFilesCache { - type Extra = Option; - - /// Gets all files for the given table base path. - /// - /// This is equivalent to calling `get_with_extra(k, &None)`. 
- fn get(&self, k: &Path) -> Option>> { - self.get_with_extra(k, &None) - } - - /// Performs a prefix-aware cache lookup. - /// - /// # Arguments - /// * `table_base` - The table's base path (the cache key) - /// * `prefix` - Optional prefix filter (relative to table base) for partition filtering - /// - /// # Behavior - /// - Fetches the cache entry for `table_base` - /// - If `prefix` is `Some`, filters results to only files matching `table_base/prefix` - /// - Returns the (potentially filtered) results - /// - /// This enables efficient partition pruning - a single cached listing of the full table - /// can serve queries for any partition subset without additional storage calls. - fn get_with_extra( - &self, - table_base: &Path, - prefix: &Self::Extra, - ) -> Option>> { +impl CacheAccessor for DefaultListFilesCache { + fn get(&self, key: &TableScopedPath) -> Option { let mut state = self.state.lock().unwrap(); let now = self.time_provider.now(); - state.get_with_prefix(table_base, prefix.as_ref(), now) + state.get(key, now) } fn put( &self, - key: &Path, - value: Arc>, - ) -> Option>> { + key: &TableScopedPath, + value: CachedFileList, + ) -> Option { let mut state = self.state.lock().unwrap(); let now = self.time_provider.now(); state.put(key, value, now) } - fn put_with_extra( - &self, - key: &Path, - value: Arc>, - _e: &Self::Extra, - ) -> Option>> { - self.put(key, value) - } - - fn remove(&self, k: &Path) -> Option>> { + fn remove(&self, k: &TableScopedPath) -> Option { let mut state = self.state.lock().unwrap(); state.remove(k) } - fn contains_key(&self, k: &Path) -> bool { + fn contains_key(&self, k: &TableScopedPath) -> bool { let mut state = self.state.lock().unwrap(); let now = self.time_provider.now(); state.contains_key(k, now) @@ -427,6 +329,56 @@ impl CacheAccessor>> for DefaultListFilesCache { } } +impl ListFilesCache for DefaultListFilesCache { + fn cache_limit(&self) -> usize { + let state = self.state.lock().unwrap(); + state.memory_limit + } + 
+ fn cache_ttl(&self) -> Option { + let state = self.state.lock().unwrap(); + state.ttl + } + + fn update_cache_limit(&self, limit: usize) { + let mut state = self.state.lock().unwrap(); + state.memory_limit = limit; + state.evict_entries(); + } + + fn update_cache_ttl(&self, ttl: Option) { + let mut state = self.state.lock().unwrap(); + state.ttl = ttl; + state.evict_entries(); + } + + fn list_entries(&self) -> HashMap { + let state = self.state.lock().unwrap(); + let mut entries = HashMap::::new(); + for (path, entry) in state.lru_queue.list_entries() { + entries.insert(path.clone(), entry.clone()); + } + entries + } + + fn drop_table_entries( + &self, + table_ref: &Option, + ) -> datafusion_common::Result<()> { + let mut state = self.state.lock().unwrap(); + let mut table_paths = vec![]; + for (path, _) in state.lru_queue.list_entries() { + if path.table == *table_ref { + table_paths.push(path.clone()); + } + } + for path in table_paths { + state.remove(&path); + } + Ok(()) + } +} + #[cfg(test)] mod tests { use super::*; @@ -478,64 +430,99 @@ mod tests { } } - /// Helper function to create a vector of ObjectMeta with at least meta_size bytes + /// Helper function to create a CachedFileList with at least meta_size bytes fn create_test_list_files_entry( path: &str, count: usize, meta_size: usize, - ) -> (Path, Arc>, usize) { + ) -> (Path, CachedFileList, usize) { let metas: Vec = (0..count) .map(|i| create_test_object_meta(&format!("file{i}"), meta_size)) .collect(); - let metas = Arc::new(metas); // Calculate actual size using the same logic as ListFilesEntry::try_new let size = (metas.capacity() * size_of::()) + metas.iter().map(meta_heap_bytes).sum::(); - (Path::from(path), metas, size) + (Path::from(path), CachedFileList::new(metas), size) } #[test] fn test_basic_operations() { let cache = DefaultListFilesCache::default(); + let table_ref = Some(TableReference::from("table")); let path = Path::from("test_path"); + let key = TableScopedPath { + table: 
table_ref.clone(), + path, + }; // Initially cache is empty - assert!(cache.get(&path).is_none()); - assert!(!cache.contains_key(&path)); + assert!(!cache.contains_key(&key)); assert_eq!(cache.len(), 0); - // Put an entry + // Cache miss - get returns None + assert!(cache.get(&key).is_none()); + + // Put a value let meta = create_test_object_meta("file1", 50); - let value = Arc::new(vec![meta.clone()]); - cache.put(&path, Arc::clone(&value)); + cache.put(&key, CachedFileList::new(vec![meta])); - // Entry should be retrievable - assert!(cache.contains_key(&path)); + // Entry should be cached + assert!(cache.contains_key(&key)); assert_eq!(cache.len(), 1); - let retrieved = cache.get(&path).unwrap(); - assert_eq!(retrieved.len(), 1); - assert_eq!(retrieved[0].location, meta.location); + let result = cache.get(&key).unwrap(); + assert_eq!(result.files.len(), 1); // Remove the entry - let removed = cache.remove(&path).unwrap(); - assert_eq!(removed.len(), 1); - assert!(!cache.contains_key(&path)); + let removed = cache.remove(&key).unwrap(); + assert_eq!(removed.files.len(), 1); + assert!(!cache.contains_key(&key)); assert_eq!(cache.len(), 0); // Put multiple entries - let (path1, value1, _) = create_test_list_files_entry("path1", 2, 50); - let (path2, value2, _) = create_test_list_files_entry("path2", 3, 50); - cache.put(&path1, value1); - cache.put(&path2, value2); + let (path1, value1, size1) = create_test_list_files_entry("path1", 2, 50); + let (path2, value2, size2) = create_test_list_files_entry("path2", 3, 50); + let key1 = TableScopedPath { + table: table_ref.clone(), + path: path1, + }; + let key2 = TableScopedPath { + table: table_ref, + path: path2, + }; + cache.put(&key1, value1.clone()); + cache.put(&key2, value2.clone()); assert_eq!(cache.len(), 2); + // List cache entries + assert_eq!( + cache.list_entries(), + HashMap::from([ + ( + key1.clone(), + ListFilesEntry { + metas: value1, + size_bytes: size1, + expires: None, + } + ), + ( + key2.clone(), + 
ListFilesEntry { + metas: value2, + size_bytes: size2, + expires: None, + } + ) + ]) + ); + // Clear all entries cache.clear(); assert_eq!(cache.len(), 0); - assert!(!cache.contains_key(&path1)); - assert!(!cache.contains_key(&path2)); + assert!(!cache.contains_key(&key1)); + assert!(!cache.contains_key(&key2)); } #[test] @@ -547,24 +534,42 @@ mod tests { // Set cache limit to exactly fit all three entries let cache = DefaultListFilesCache::new(size * 3, None); + let table_ref = Some(TableReference::from("table")); + let key1 = TableScopedPath { + table: table_ref.clone(), + path: path1, + }; + let key2 = TableScopedPath { + table: table_ref.clone(), + path: path2, + }; + let key3 = TableScopedPath { + table: table_ref.clone(), + path: path3, + }; + // All three entries should fit - cache.put(&path1, value1); - cache.put(&path2, value2); - cache.put(&path3, value3); + cache.put(&key1, value1); + cache.put(&key2, value2); + cache.put(&key3, value3); assert_eq!(cache.len(), 3); - assert!(cache.contains_key(&path1)); - assert!(cache.contains_key(&path2)); - assert!(cache.contains_key(&path3)); + assert!(cache.contains_key(&key1)); + assert!(cache.contains_key(&key2)); + assert!(cache.contains_key(&key3)); // Adding a new entry should evict path1 (LRU) let (path4, value4, _) = create_test_list_files_entry("path4", 1, 100); - cache.put(&path4, value4); + let key4 = TableScopedPath { + table: table_ref, + path: path4, + }; + cache.put(&key4, value4); assert_eq!(cache.len(), 3); - assert!(!cache.contains_key(&path1)); // Evicted - assert!(cache.contains_key(&path2)); - assert!(cache.contains_key(&path3)); - assert!(cache.contains_key(&path4)); + assert!(!cache.contains_key(&key1)); // Evicted + assert!(cache.contains_key(&key2)); + assert!(cache.contains_key(&key3)); + assert!(cache.contains_key(&key4)); } #[test] @@ -576,24 +581,42 @@ mod tests { // Set cache limit to fit exactly three entries let cache = DefaultListFilesCache::new(size * 3, None); - cache.put(&path1, 
value1); - cache.put(&path2, value2); - cache.put(&path3, value3); + let table_ref = Some(TableReference::from("table")); + let key1 = TableScopedPath { + table: table_ref.clone(), + path: path1, + }; + let key2 = TableScopedPath { + table: table_ref.clone(), + path: path2, + }; + let key3 = TableScopedPath { + table: table_ref.clone(), + path: path3, + }; + + cache.put(&key1, value1); + cache.put(&key2, value2); + cache.put(&key3, value3); assert_eq!(cache.len(), 3); // Access path1 to move it to front (MRU) // Order is now: path2 (LRU), path3, path1 (MRU) - cache.get(&path1); + let _ = cache.get(&key1); // Adding a new entry should evict path2 (the LRU) let (path4, value4, _) = create_test_list_files_entry("path4", 1, 100); - cache.put(&path4, value4); + let key4 = TableScopedPath { + table: table_ref, + path: path4, + }; + cache.put(&key4, value4); assert_eq!(cache.len(), 3); - assert!(cache.contains_key(&path1)); // Still present (recently accessed) - assert!(!cache.contains_key(&path2)); // Evicted (was LRU) - assert!(cache.contains_key(&path3)); - assert!(cache.contains_key(&path4)); + assert!(cache.contains_key(&key1)); // Still present (recently accessed) + assert!(!cache.contains_key(&key2)); // Evicted (was LRU) + assert!(cache.contains_key(&key3)); + assert!(cache.contains_key(&key4)); } #[test] @@ -604,19 +627,33 @@ mod tests { // Set cache limit to fit both entries let cache = DefaultListFilesCache::new(size * 2, None); - cache.put(&path1, value1); - cache.put(&path2, value2); + let table_ref = Some(TableReference::from("table")); + let key1 = TableScopedPath { + table: table_ref.clone(), + path: path1, + }; + let key2 = TableScopedPath { + table: table_ref.clone(), + path: path2, + }; + cache.put(&key1, value1); + cache.put(&key2, value2); assert_eq!(cache.len(), 2); // Try to add an entry that's too large to fit in the cache + // The entry is not stored (too large) let (path_large, value_large, _) = create_test_list_files_entry("large", 1, 1000); - 
cache.put(&path_large, value_large); + let key_large = TableScopedPath { + table: table_ref, + path: path_large, + }; + cache.put(&key_large, value_large); // Large entry should not be added - assert!(!cache.contains_key(&path_large)); + assert!(!cache.contains_key(&key_large)); assert_eq!(cache.len(), 2); - assert!(cache.contains_key(&path1)); - assert!(cache.contains_key(&path2)); + assert!(cache.contains_key(&key1)); + assert!(cache.contains_key(&key2)); } #[test] @@ -628,21 +665,38 @@ mod tests { // Set cache limit for exactly 3 entries let cache = DefaultListFilesCache::new(size * 3, None); - cache.put(&path1, value1); - cache.put(&path2, value2); - cache.put(&path3, value3); + let table_ref = Some(TableReference::from("table")); + let key1 = TableScopedPath { + table: table_ref.clone(), + path: path1, + }; + let key2 = TableScopedPath { + table: table_ref.clone(), + path: path2, + }; + let key3 = TableScopedPath { + table: table_ref.clone(), + path: path3, + }; + cache.put(&key1, value1); + cache.put(&key2, value2); + cache.put(&key3, value3); assert_eq!(cache.len(), 3); // Add a large entry that requires evicting 2 entries let (path_large, value_large, _) = create_test_list_files_entry("large", 1, 200); - cache.put(&path_large, value_large); + let key_large = TableScopedPath { + table: table_ref, + path: path_large, + }; + cache.put(&key_large, value_large); // path1 and path2 should be evicted (both LRU), path3 and path_large remain assert_eq!(cache.len(), 2); - assert!(!cache.contains_key(&path1)); // Evicted - assert!(!cache.contains_key(&path2)); // Evicted - assert!(cache.contains_key(&path3)); - assert!(cache.contains_key(&path_large)); + assert!(!cache.contains_key(&key1)); // Evicted + assert!(!cache.contains_key(&key2)); // Evicted + assert!(cache.contains_key(&key3)); + assert!(cache.contains_key(&key_large)); } #[test] @@ -653,10 +707,23 @@ mod tests { let cache = DefaultListFilesCache::new(size * 3, None); + let table_ref = 
Some(TableReference::from("table")); + let key1 = TableScopedPath { + table: table_ref.clone(), + path: path1, + }; + let key2 = TableScopedPath { + table: table_ref.clone(), + path: path2, + }; + let key3 = TableScopedPath { + table: table_ref, + path: path3, + }; // Add three entries - cache.put(&path1, value1); - cache.put(&path2, value2); - cache.put(&path3, value3); + cache.put(&key1, value1); + cache.put(&key2, value2); + cache.put(&key3, value3); assert_eq!(cache.len(), 3); // Resize cache to only fit one entry @@ -664,71 +731,137 @@ mod tests { // Should keep only the most recent entry (path3, the MRU) assert_eq!(cache.len(), 1); - assert!(cache.contains_key(&path3)); + assert!(cache.contains_key(&key3)); // Earlier entries (LRU) should be evicted - assert!(!cache.contains_key(&path1)); - assert!(!cache.contains_key(&path2)); + assert!(!cache.contains_key(&key1)); + assert!(!cache.contains_key(&key2)); } #[test] fn test_entry_update_with_size_change() { let (path1, value1, size) = create_test_list_files_entry("path1", 1, 100); - let (path2, value2, _) = create_test_list_files_entry("path2", 1, 100); + let (path2, value2, size2) = create_test_list_files_entry("path2", 1, 100); let (path3, value3_v1, _) = create_test_list_files_entry("path3", 1, 100); let cache = DefaultListFilesCache::new(size * 3, None); + let table_ref = Some(TableReference::from("table")); + let key1 = TableScopedPath { + table: table_ref.clone(), + path: path1, + }; + let key2 = TableScopedPath { + table: table_ref.clone(), + path: path2, + }; + let key3 = TableScopedPath { + table: table_ref, + path: path3, + }; // Add three entries - cache.put(&path1, value1); - cache.put(&path2, value2); - cache.put(&path3, value3_v1); + cache.put(&key1, value1); + cache.put(&key2, value2.clone()); + cache.put(&key3, value3_v1); assert_eq!(cache.len(), 3); // Update path3 with same size - should not cause eviction let (_, value3_v2, _) = create_test_list_files_entry("path3", 1, 100); - 
cache.put(&path3, value3_v2); + cache.put(&key3, value3_v2); assert_eq!(cache.len(), 3); - assert!(cache.contains_key(&path1)); - assert!(cache.contains_key(&path2)); - assert!(cache.contains_key(&path3)); + assert!(cache.contains_key(&key1)); + assert!(cache.contains_key(&key2)); + assert!(cache.contains_key(&key3)); // Update path3 with larger size that requires evicting path1 (LRU) - let (_, value3_v3, _) = create_test_list_files_entry("path3", 1, 200); - cache.put(&path3, value3_v3); + let (_, value3_v3, size3_v3) = create_test_list_files_entry("path3", 1, 200); + cache.put(&key3, value3_v3.clone()); assert_eq!(cache.len(), 2); - assert!(!cache.contains_key(&path1)); // Evicted (was LRU) - assert!(cache.contains_key(&path2)); - assert!(cache.contains_key(&path3)); + assert!(!cache.contains_key(&key1)); // Evicted (was LRU) + assert!(cache.contains_key(&key2)); + assert!(cache.contains_key(&key3)); + + // List cache entries + assert_eq!( + cache.list_entries(), + HashMap::from([ + ( + key2, + ListFilesEntry { + metas: value2, + size_bytes: size2, + expires: None, + } + ), + ( + key3, + ListFilesEntry { + metas: value3_v3, + size_bytes: size3_v3, + expires: None, + } + ) + ]) + ); } #[test] fn test_cache_with_ttl() { let ttl = Duration::from_millis(100); - let cache = DefaultListFilesCache::new(10000, Some(ttl)); - let (path1, value1, _) = create_test_list_files_entry("path1", 2, 50); - let (path2, value2, _) = create_test_list_files_entry("path2", 2, 50); + let mock_time = Arc::new(MockTimeProvider::new()); + let cache = DefaultListFilesCache::new(10000, Some(ttl)) + .with_time_provider(Arc::clone(&mock_time) as Arc); - cache.put(&path1, value1); - cache.put(&path2, value2); + let (path1, value1, size1) = create_test_list_files_entry("path1", 2, 50); + let (path2, value2, size2) = create_test_list_files_entry("path2", 2, 50); - // Entries should be accessible immediately - assert!(cache.get(&path1).is_some()); - assert!(cache.get(&path2).is_some()); - 
assert!(cache.contains_key(&path1)); - assert!(cache.contains_key(&path2)); - assert_eq!(cache.len(), 2); + let table_ref = Some(TableReference::from("table")); + let key1 = TableScopedPath { + table: table_ref.clone(), + path: path1, + }; + let key2 = TableScopedPath { + table: table_ref, + path: path2, + }; + cache.put(&key1, value1.clone()); + cache.put(&key2, value2.clone()); + // Entries should be accessible immediately + assert!(cache.get(&key1).is_some()); + assert!(cache.get(&key2).is_some()); + // List cache entries + assert_eq!( + cache.list_entries(), + HashMap::from([ + ( + key1.clone(), + ListFilesEntry { + metas: value1, + size_bytes: size1, + expires: mock_time.now().checked_add(ttl), + } + ), + ( + key2.clone(), + ListFilesEntry { + metas: value2, + size_bytes: size2, + expires: mock_time.now().checked_add(ttl), + } + ) + ]) + ); // Wait for TTL to expire - thread::sleep(Duration::from_millis(150)); + mock_time.inc(Duration::from_millis(150)); - // Entries should now return None and be removed when observed through get or contains_key - assert!(cache.get(&path1).is_none()); - assert_eq!(cache.len(), 1); // path1 was removed by get() - assert!(!cache.contains_key(&path2)); - assert_eq!(cache.len(), 0); // path2 was removed by contains_key() + // Entries should now return None when observed through contains_key + assert!(!cache.contains_key(&key1)); + assert_eq!(cache.len(), 1); // key1 was removed by contains_key() + assert!(!cache.contains_key(&key2)); + assert_eq!(cache.len(), 0); // key2 was removed by contains_key() } #[test] @@ -743,21 +876,62 @@ mod tests { let (path2, value2, _) = create_test_list_files_entry("path2", 1, 400); let (path3, value3, _) = create_test_list_files_entry("path3", 1, 400); - cache.put(&path1, value1); + let table_ref = Some(TableReference::from("table")); + let key1 = TableScopedPath { + table: table_ref.clone(), + path: path1, + }; + let key2 = TableScopedPath { + table: table_ref.clone(), + path: path2, + }; + let 
key3 = TableScopedPath { + table: table_ref, + path: path3, + }; + cache.put(&key1, value1); mock_time.inc(Duration::from_millis(50)); - cache.put(&path2, value2); + cache.put(&key2, value2); mock_time.inc(Duration::from_millis(50)); // path3 should evict path1 due to size limit - cache.put(&path3, value3); - assert!(!cache.contains_key(&path1)); // Evicted by LRU - assert!(cache.contains_key(&path2)); - assert!(cache.contains_key(&path3)); + cache.put(&key3, value3); + assert!(!cache.contains_key(&key1)); // Evicted by LRU + assert!(cache.contains_key(&key2)); + assert!(cache.contains_key(&key3)); mock_time.inc(Duration::from_millis(151)); - assert!(!cache.contains_key(&path2)); // Expired - assert!(cache.contains_key(&path3)); // Still valid + assert!(!cache.contains_key(&key2)); // Expired + assert!(cache.contains_key(&key3)); // Still valid + } + + #[test] + fn test_ttl_expiration_in_get() { + let ttl = Duration::from_millis(100); + let cache = DefaultListFilesCache::new(10000, Some(ttl)); + + let (path, value, _) = create_test_list_files_entry("path", 2, 50); + let table_ref = Some(TableReference::from("table")); + let key = TableScopedPath { + table: table_ref, + path, + }; + + // Cache the entry + cache.put(&key, value.clone()); + + // Entry should be accessible immediately + let result = cache.get(&key); + assert!(result.is_some()); + assert_eq!(result.unwrap().files.len(), 2); + + // Wait for TTL to expire + thread::sleep(Duration::from_millis(150)); + + // Get should return None because entry expired + let result2 = cache.get(&key); + assert!(result2.is_none()); } #[test] @@ -806,28 +980,28 @@ mod tests { #[test] fn test_entry_creation() { // Test with empty vector - let empty_vec: Arc> = Arc::new(vec![]); + let empty_list = CachedFileList::new(vec![]); let now = Instant::now(); - let entry = ListFilesEntry::try_new(empty_vec, None, now); + let entry = ListFilesEntry::try_new(empty_list, None, now); assert!(entry.is_none()); // Validate entry size let 
metas: Vec = (0..5) .map(|i| create_test_object_meta(&format!("file{i}"), 30)) .collect(); - let metas = Arc::new(metas); - let entry = ListFilesEntry::try_new(metas, None, now).unwrap(); - assert_eq!(entry.metas.len(), 5); + let cached_list = CachedFileList::new(metas); + let entry = ListFilesEntry::try_new(cached_list, None, now).unwrap(); + assert_eq!(entry.metas.files.len(), 5); // Size should be: capacity * sizeof(ObjectMeta) + (5 * 30) for heap bytes - let expected_size = - (entry.metas.capacity() * size_of::()) + (entry.metas.len() * 30); + let expected_size = (entry.metas.files.capacity() * size_of::()) + + (entry.metas.files.len() * 30); assert_eq!(entry.size_bytes, expected_size); // Test with TTL let meta = create_test_object_meta("file", 50); let ttl = Duration::from_secs(10); - let entry = - ListFilesEntry::try_new(Arc::new(vec![meta]), Some(ttl), now).unwrap(); + let cached_list = CachedFileList::new(vec![meta]); + let entry = ListFilesEntry::try_new(cached_list, Some(ttl), now).unwrap(); assert!(entry.expires.unwrap() > now); } @@ -843,7 +1017,12 @@ mod tests { // Add entry and verify memory tracking let (path1, value1, size1) = create_test_list_files_entry("path1", 1, 100); - cache.put(&path1, value1); + let table_ref = Some(TableReference::from("table")); + let key1 = TableScopedPath { + table: table_ref.clone(), + path: path1, + }; + cache.put(&key1, value1); { let state = cache.state.lock().unwrap(); assert_eq!(state.memory_used, size1); @@ -851,14 +1030,18 @@ mod tests { // Add another entry let (path2, value2, size2) = create_test_list_files_entry("path2", 1, 200); - cache.put(&path2, value2); + let key2 = TableScopedPath { + table: table_ref.clone(), + path: path2, + }; + cache.put(&key2, value2); { let state = cache.state.lock().unwrap(); assert_eq!(state.memory_used, size1 + size2); } // Remove first entry and verify memory decreases - cache.remove(&path1); + cache.remove(&key1); { let state = cache.state.lock().unwrap(); 
assert_eq!(state.memory_used, size2); @@ -872,7 +1055,7 @@ mod tests { } } - // Prefix-aware cache tests + // Prefix filtering tests using CachedFileList::filter_by_prefix /// Helper function to create ObjectMeta with a specific location path fn create_object_meta_with_path(location: &str) -> ObjectMeta { @@ -888,30 +1071,31 @@ mod tests { } #[test] - fn test_prefix_aware_cache_hit() { - // Scenario: Cache has full table listing, query for partition returns filtered results + fn test_prefix_filtering() { let cache = DefaultListFilesCache::new(100000, None); // Create files for a partitioned table let table_base = Path::from("my_table"); - let files = Arc::new(vec![ + let files = vec![ create_object_meta_with_path("my_table/a=1/file1.parquet"), create_object_meta_with_path("my_table/a=1/file2.parquet"), create_object_meta_with_path("my_table/a=2/file3.parquet"), create_object_meta_with_path("my_table/a=2/file4.parquet"), - ]); + ]; // Cache the full table listing - cache.put(&table_base, files); + let table_ref = Some(TableReference::from("table")); + let key = TableScopedPath { + table: table_ref, + path: table_base, + }; + cache.put(&key, CachedFileList::new(files)); - // Query for partition a=1 using get_with_extra - // New API: get_with_extra(table_base, Some(relative_prefix)) - let prefix_a1 = Some(Path::from("a=1")); - let result = cache.get_with_extra(&table_base, &prefix_a1); + let result = cache.get(&key).unwrap(); - // Should return filtered results (only files from a=1) - assert!(result.is_some()); - let filtered = result.unwrap(); + // Filter for partition a=1 + let prefix_a1 = Some(Path::from("my_table/a=1")); + let filtered = result.files_matching_prefix(&prefix_a1); assert_eq!(filtered.len(), 2); assert!( filtered @@ -919,92 +1103,51 @@ mod tests { .all(|m| m.location.as_ref().starts_with("my_table/a=1")) ); - // Query for partition a=2 - let prefix_a2 = Some(Path::from("a=2")); - let result_2 = cache.get_with_extra(&table_base, &prefix_a2); - - 
assert!(result_2.is_some()); - let filtered_2 = result_2.unwrap(); + // Filter for partition a=2 + let prefix_a2 = Some(Path::from("my_table/a=2")); + let filtered_2 = result.files_matching_prefix(&prefix_a2); assert_eq!(filtered_2.len(), 2); assert!( filtered_2 .iter() .all(|m| m.location.as_ref().starts_with("my_table/a=2")) ); - } - - #[test] - fn test_prefix_aware_cache_no_filter_returns_all() { - // Scenario: Query with no prefix filter should return all files - let cache = DefaultListFilesCache::new(100000, None); - - let table_base = Path::from("my_table"); - - // Cache full table listing with 4 files - let full_files = Arc::new(vec![ - create_object_meta_with_path("my_table/a=1/file1.parquet"), - create_object_meta_with_path("my_table/a=1/file2.parquet"), - create_object_meta_with_path("my_table/a=2/file3.parquet"), - create_object_meta_with_path("my_table/a=2/file4.parquet"), - ]); - cache.put(&table_base, full_files); - - // Query with no prefix filter (None) should return all 4 files - let result = cache.get_with_extra(&table_base, &None); - assert!(result.is_some()); - let files = result.unwrap(); - assert_eq!(files.len(), 4); - - // Also test using get() which delegates to get_with_extra(&None) - let result_get = cache.get(&table_base); - assert!(result_get.is_some()); - assert_eq!(result_get.unwrap().len(), 4); - } - - #[test] - fn test_prefix_aware_cache_miss_no_entry() { - // Scenario: Table not cached, query should miss - let cache = DefaultListFilesCache::new(100000, None); - - let table_base = Path::from("my_table"); - // Query for full table should miss (nothing cached) - let result = cache.get_with_extra(&table_base, &None); - assert!(result.is_none()); - - // Query with prefix should also miss - let prefix = Some(Path::from("a=1")); - let result_2 = cache.get_with_extra(&table_base, &prefix); - assert!(result_2.is_none()); + // No filter returns all + let all = result.files_matching_prefix(&None); + assert_eq!(all.len(), 4); } #[test] - fn 
test_prefix_aware_cache_no_matching_files() { - // Scenario: Cache has table listing but no files match the requested partition + fn test_prefix_no_matching_files() { let cache = DefaultListFilesCache::new(100000, None); let table_base = Path::from("my_table"); - let files = Arc::new(vec![ + let files = vec![ create_object_meta_with_path("my_table/a=1/file1.parquet"), create_object_meta_with_path("my_table/a=2/file2.parquet"), - ]); - cache.put(&table_base, files); + ]; - // Query for partition a=3 which doesn't exist - let prefix_a3 = Some(Path::from("a=3")); - let result = cache.get_with_extra(&table_base, &prefix_a3); + let table_ref = Some(TableReference::from("table")); + let key = TableScopedPath { + table: table_ref, + path: table_base, + }; + cache.put(&key, CachedFileList::new(files)); + let result = cache.get(&key).unwrap(); - // Should return None since no files match - assert!(result.is_none()); + // Query for partition a=3 which doesn't exist + let prefix_a3 = Some(Path::from("my_table/a=3")); + let filtered = result.files_matching_prefix(&prefix_a3); + assert!(filtered.is_empty()); } #[test] - fn test_prefix_aware_nested_partitions() { - // Scenario: Table with multiple partition levels (e.g., year/month/day) + fn test_nested_partitions() { let cache = DefaultListFilesCache::new(100000, None); let table_base = Path::from("events"); - let files = Arc::new(vec![ + let files = vec![ create_object_meta_with_path( "events/year=2024/month=01/day=01/file1.parquet", ), @@ -1017,56 +1160,59 @@ mod tests { create_object_meta_with_path( "events/year=2025/month=01/day=01/file4.parquet", ), - ]); - cache.put(&table_base, files); + ]; - // Query for year=2024/month=01 (should get 2 files) - let prefix_month = Some(Path::from("year=2024/month=01")); - let result = cache.get_with_extra(&table_base, &prefix_month); - assert!(result.is_some()); - assert_eq!(result.unwrap().len(), 2); - - // Query for year=2024 (should get 3 files) - let prefix_year = 
Some(Path::from("year=2024")); - let result_year = cache.get_with_extra(&table_base, &prefix_year); - assert!(result_year.is_some()); - assert_eq!(result_year.unwrap().len(), 3); - - // Query for specific day (should get 1 file) - let prefix_day = Some(Path::from("year=2024/month=01/day=01")); - let result_day = cache.get_with_extra(&table_base, &prefix_day); - assert!(result_day.is_some()); - assert_eq!(result_day.unwrap().len(), 1); + let table_ref = Some(TableReference::from("table")); + let key = TableScopedPath { + table: table_ref, + path: table_base, + }; + cache.put(&key, CachedFileList::new(files)); + let result = cache.get(&key).unwrap(); + + // Filter for year=2024/month=01 + let prefix_month = Some(Path::from("events/year=2024/month=01")); + let filtered = result.files_matching_prefix(&prefix_month); + assert_eq!(filtered.len(), 2); + + // Filter for year=2024 + let prefix_year = Some(Path::from("events/year=2024")); + let filtered_year = result.files_matching_prefix(&prefix_year); + assert_eq!(filtered_year.len(), 3); } #[test] - fn test_prefix_aware_different_tables() { - // Scenario: Multiple tables cached, queries should not cross-contaminate - let cache = DefaultListFilesCache::new(100000, None); + fn test_drop_table_entries() { + let cache = DefaultListFilesCache::default(); + + let (path1, value1, _) = create_test_list_files_entry("path1", 1, 100); + let (path2, value2, _) = create_test_list_files_entry("path2", 1, 100); + let (path3, value3, _) = create_test_list_files_entry("path3", 1, 100); + + let table_ref1 = Some(TableReference::from("table1")); + let key1 = TableScopedPath { + table: table_ref1.clone(), + path: path1, + }; + let key2 = TableScopedPath { + table: table_ref1.clone(), + path: path2, + }; + + let table_ref2 = Some(TableReference::from("table2")); + let key3 = TableScopedPath { + table: table_ref2.clone(), + path: path3, + }; + + cache.put(&key1, value1); + cache.put(&key2, value2); + cache.put(&key3, value3); + + 
cache.drop_table_entries(&table_ref1).unwrap(); - let table_a = Path::from("table_a"); - let table_b = Path::from("table_b"); - - let files_a = Arc::new(vec![create_object_meta_with_path( - "table_a/part=1/file1.parquet", - )]); - let files_b = Arc::new(vec![ - create_object_meta_with_path("table_b/part=1/file1.parquet"), - create_object_meta_with_path("table_b/part=2/file2.parquet"), - ]); - - cache.put(&table_a, files_a); - cache.put(&table_b, files_b); - - // Query table_a should only return table_a files - let result_a = cache.get(&table_a); - assert!(result_a.is_some()); - assert_eq!(result_a.unwrap().len(), 1); - - // Query table_b with prefix should only return matching table_b files - let prefix = Some(Path::from("part=1")); - let result_b = cache.get_with_extra(&table_b, &prefix); - assert!(result_b.is_some()); - assert_eq!(result_b.unwrap().len(), 1); + assert!(!cache.contains_key(&key1)); + assert!(!cache.contains_key(&key2)); + assert!(cache.contains_key(&key3)); } } diff --git a/datafusion/execution/src/cache/mod.rs b/datafusion/execution/src/cache/mod.rs index 8172069fdbab..0380e50c0935 100644 --- a/datafusion/execution/src/cache/mod.rs +++ b/datafusion/execution/src/cache/mod.rs @@ -24,36 +24,57 @@ mod list_files_cache; pub use file_metadata_cache::DefaultFilesMetadataCache; pub use list_files_cache::DefaultListFilesCache; +pub use list_files_cache::ListFilesEntry; +pub use list_files_cache::TableScopedPath; -/// A trait that can be implemented to provide custom cache behavior for the caches managed by -/// [`cache_manager::CacheManager`]. +/// Base trait for cache implementations with common operations. +/// +/// This trait provides the fundamental cache operations (`get`, `put`, `remove`, etc.) +/// that all cache types share. Specific cache traits like [`cache_manager::FileStatisticsCache`], +/// [`cache_manager::ListFilesCache`], and [`cache_manager::FileMetadataCache`] extend this +/// trait with their specialized methods. 
+/// +/// ## Thread Safety /// /// Implementations must handle their own locking via internal mutability, as methods do not /// take mutable references and may be accessed by multiple concurrent queries. +/// +/// ## Validation Pattern +/// +/// Validation metadata (e.g., file size, last modified time) should be embedded in the +/// value type `V`. The typical usage pattern is: +/// 1. Call `get(key)` to check for cached value +/// 2. If `Some(cached)`, validate with `cached.is_valid_for(¤t_meta)` +/// 3. If invalid or missing, compute new value and call `put(key, new_value)` pub trait CacheAccessor: Send + Sync { - // Extra info but not part of the cache key or cache value. - type Extra: Clone; - - /// Get value from cache. - fn get(&self, k: &K) -> Option; - /// Get value from cache. - fn get_with_extra(&self, k: &K, e: &Self::Extra) -> Option; - /// Put value into cache. Returns the old value associated with the key if there was one. + /// Get a cached entry if it exists. + /// + /// Returns the cached value without any validation. The caller should + /// validate the returned value if freshness matters. + fn get(&self, key: &K) -> Option; + + /// Store a value in the cache. + /// + /// Returns the previous value if one existed. fn put(&self, key: &K, value: V) -> Option; - /// Put value into cache. Returns the old value associated with the key if there was one. - fn put_with_extra(&self, key: &K, value: V, e: &Self::Extra) -> Option; - /// Remove an entry from the cache, returning value if they existed in the map. + + /// Remove an entry from the cache, returning the value if it existed. fn remove(&self, k: &K) -> Option; + /// Check if the cache contains a specific key. fn contains_key(&self, k: &K) -> bool; + /// Fetch the total number of cache entries. fn len(&self) -> usize; - /// Check if the Cache collection is empty or not. + + /// Check if the cache collection is empty. 
fn is_empty(&self) -> bool { self.len() == 0 } + /// Remove all entries from the cache. fn clear(&self); + /// Return the cache name. fn name(&self) -> String; } diff --git a/datafusion/execution/src/config.rs b/datafusion/execution/src/config.rs index 30ba7de76a47..854d23923676 100644 --- a/datafusion/execution/src/config.rs +++ b/datafusion/execution/src/config.rs @@ -480,6 +480,12 @@ impl SessionConfig { self.options.execution.enforce_batch_size_in_joins } + /// Toggle SQL ANSI mode for expressions, casting, and error handling + pub fn with_enable_ansi_mode(mut self, enable_ansi_mode: bool) -> Self { + self.options_mut().execution.enable_ansi_mode = enable_ansi_mode; + self + } + /// Convert configuration options to name-value pairs with values /// converted to strings. /// diff --git a/datafusion/execution/src/disk_manager.rs b/datafusion/execution/src/disk_manager.rs index cb87053d8d03..d878fdcf66a4 100644 --- a/datafusion/execution/src/disk_manager.rs +++ b/datafusion/execution/src/disk_manager.rs @@ -25,7 +25,7 @@ use parking_lot::Mutex; use rand::{Rng, rng}; use std::path::{Path, PathBuf}; use std::sync::Arc; -use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; use tempfile::{Builder, NamedTempFile, TempDir}; use datafusion_common::human_readable_size; @@ -77,6 +77,7 @@ impl DiskManagerBuilder { local_dirs: Mutex::new(Some(vec![])), max_temp_directory_size: self.max_temp_directory_size, used_disk_space: Arc::new(AtomicU64::new(0)), + active_files_count: Arc::new(AtomicUsize::new(0)), }), DiskManagerMode::Directories(conf_dirs) => { let local_dirs = create_local_dirs(&conf_dirs)?; @@ -87,12 +88,14 @@ impl DiskManagerBuilder { local_dirs: Mutex::new(Some(local_dirs)), max_temp_directory_size: self.max_temp_directory_size, used_disk_space: Arc::new(AtomicU64::new(0)), + active_files_count: Arc::new(AtomicUsize::new(0)), }) } DiskManagerMode::Disabled => Ok(DiskManager { local_dirs: Mutex::new(None), 
max_temp_directory_size: self.max_temp_directory_size, used_disk_space: Arc::new(AtomicU64::new(0)), + active_files_count: Arc::new(AtomicUsize::new(0)), }), } } @@ -169,6 +172,17 @@ pub struct DiskManager { /// Used disk space in the temporary directories. Now only spilled data for /// external executors are counted. used_disk_space: Arc, + /// Number of active temporary files created by this disk manager + active_files_count: Arc, +} + +/// Information about the current disk usage for spilling +#[derive(Debug, Clone, Copy)] +pub struct SpillingProgress { + /// Total bytes currently used on disk for spilling + pub current_bytes: u64, + /// Total number of active spill files + pub active_files_count: usize, } impl DiskManager { @@ -187,6 +201,7 @@ impl DiskManager { local_dirs: Mutex::new(Some(vec![])), max_temp_directory_size: DEFAULT_MAX_TEMP_DIRECTORY_SIZE, used_disk_space: Arc::new(AtomicU64::new(0)), + active_files_count: Arc::new(AtomicUsize::new(0)), })), DiskManagerConfig::NewSpecified(conf_dirs) => { let local_dirs = create_local_dirs(&conf_dirs)?; @@ -197,12 +212,14 @@ impl DiskManager { local_dirs: Mutex::new(Some(local_dirs)), max_temp_directory_size: DEFAULT_MAX_TEMP_DIRECTORY_SIZE, used_disk_space: Arc::new(AtomicU64::new(0)), + active_files_count: Arc::new(AtomicUsize::new(0)), })) } DiskManagerConfig::Disabled => Ok(Arc::new(Self { local_dirs: Mutex::new(None), max_temp_directory_size: DEFAULT_MAX_TEMP_DIRECTORY_SIZE, used_disk_space: Arc::new(AtomicU64::new(0)), + active_files_count: Arc::new(AtomicUsize::new(0)), })), } } @@ -252,6 +269,14 @@ impl DiskManager { self.max_temp_directory_size } + /// Returns the current spilling progress + pub fn spilling_progress(&self) -> SpillingProgress { + SpillingProgress { + current_bytes: self.used_disk_space.load(Ordering::Relaxed), + active_files_count: self.active_files_count.load(Ordering::Relaxed), + } + } + /// Returns the temporary directory paths pub fn temp_dir_paths(&self) -> Vec { self.local_dirs 
@@ -301,6 +326,7 @@ impl DiskManager { } let dir_index = rng().random_range(0..local_dirs.len()); + self.active_files_count.fetch_add(1, Ordering::Relaxed); Ok(RefCountedTempFile { parent_temp_dir: Arc::clone(&local_dirs[dir_index]), tempfile: Arc::new( @@ -422,6 +448,9 @@ impl Drop for RefCountedTempFile { self.disk_manager .used_disk_space .fetch_sub(current_usage, Ordering::Relaxed); + self.disk_manager + .active_files_count + .fetch_sub(1, Ordering::Relaxed); } } } diff --git a/datafusion/execution/src/lib.rs b/datafusion/execution/src/lib.rs index aced2f46d722..1a8da9459ae1 100644 --- a/datafusion/execution/src/lib.rs +++ b/datafusion/execution/src/lib.rs @@ -24,7 +24,6 @@ // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] -#![deny(clippy::allow_attributes)] //! DataFusion execution configuration and runtime structures diff --git a/datafusion/execution/src/memory_pool/arrow.rs b/datafusion/execution/src/memory_pool/arrow.rs new file mode 100644 index 000000000000..4e8d986f1f5e --- /dev/null +++ b/datafusion/execution/src/memory_pool/arrow.rs @@ -0,0 +1,142 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Adapter for integrating DataFusion's [`MemoryPool`] with Arrow's memory tracking APIs. + +use crate::memory_pool::{MemoryConsumer, MemoryLimit, MemoryPool, MemoryReservation}; +use std::fmt::Debug; +use std::sync::Arc; + +/// An adapter that implements Arrow's [`arrow_buffer::MemoryPool`] trait +/// by wrapping a DataFusion [`MemoryPool`]. +/// +/// This allows DataFusion's memory management system to be used with Arrow's +/// memory allocation APIs. Each reservation made through this pool will be +/// tracked using the provided [`MemoryConsumer`], enabling DataFusion to +/// monitor and limit memory usage across Arrow operations. +/// +/// This is useful when you want Arrow operations (such as array builders +/// or compute kernels) to participate in DataFusion's memory management +/// and respect the same memory limits as DataFusion operators. +#[derive(Debug)] +pub struct ArrowMemoryPool { + inner: Arc, + consumer: MemoryConsumer, +} + +impl ArrowMemoryPool { + /// Creates a new [`ArrowMemoryPool`] that wraps the given DataFusion [`MemoryPool`] + /// and tracks allocations under the specified [`MemoryConsumer`]. + pub fn new(inner: Arc, consumer: MemoryConsumer) -> Self { + Self { inner, consumer } + } +} + +impl arrow_buffer::MemoryReservation for MemoryReservation { + fn size(&self) -> usize { + MemoryReservation::size(self) + } + + fn resize(&mut self, new_size: usize) { + MemoryReservation::resize(self, new_size) + } +} + +impl arrow_buffer::MemoryPool for ArrowMemoryPool { + fn reserve(&self, size: usize) -> Box { + let consumer = self.consumer.clone_with_new_id(); + let mut reservation = consumer.register(&self.inner); + reservation.grow(size); + + Box::new(reservation) + } + + fn available(&self) -> isize { + // The pool may be overfilled, so this method might return a negative value. 
+ (self.capacity() as i128 - self.used() as i128) + .try_into() + .unwrap_or(isize::MIN) + } + + fn used(&self) -> usize { + self.inner.reserved() + } + + fn capacity(&self) -> usize { + match self.inner.memory_limit() { + MemoryLimit::Infinite | MemoryLimit::Unknown => usize::MAX, + MemoryLimit::Finite(capacity) => capacity, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::memory_pool::{GreedyMemoryPool, UnboundedMemoryPool}; + use arrow::array::{Array, Int32Array}; + use arrow_buffer::MemoryPool; + + // Until https://github.com/apache/arrow-rs/pull/8918 lands, we need to iterate all + // buffers in the array. Change once the PR is released. + fn claim_array(array: &dyn Array, pool: &dyn MemoryPool) { + for buffer in array.to_data().buffers() { + buffer.claim(pool); + } + } + + #[test] + pub fn can_claim_array() { + let pool = Arc::new(UnboundedMemoryPool::default()); + + let consumer = MemoryConsumer::new("arrow"); + let arrow_pool = ArrowMemoryPool::new(pool, consumer); + + let array = Int32Array::from(vec![1, 2, 3, 4, 5]); + claim_array(&array, &arrow_pool); + + assert_eq!(arrow_pool.used(), array.get_buffer_memory_size()); + + let slice = array.slice(0, 2); + + // This should be a no-op + claim_array(&slice, &arrow_pool); + + assert_eq!(arrow_pool.used(), array.get_buffer_memory_size()); + } + + #[test] + pub fn can_claim_array_with_finite_limit() { + let pool_capacity = 1024; + let pool = Arc::new(GreedyMemoryPool::new(pool_capacity)); + + let consumer = MemoryConsumer::new("arrow"); + let arrow_pool = ArrowMemoryPool::new(pool, consumer); + + assert_eq!(arrow_pool.capacity(), pool_capacity); + assert_eq!(arrow_pool.available(), pool_capacity as isize); + + let array = Int32Array::from(vec![1, 2, 3, 4, 5]); + claim_array(&array, &arrow_pool); + + assert_eq!(arrow_pool.used(), array.get_buffer_memory_size()); + assert_eq!( + arrow_pool.available(), + (pool_capacity - array.get_buffer_memory_size()) as isize + ); + } +} diff --git 
a/datafusion/execution/src/memory_pool/mod.rs b/datafusion/execution/src/memory_pool/mod.rs index fbf9ce41da8f..a544cdfdb02e 100644 --- a/datafusion/execution/src/memory_pool/mod.rs +++ b/datafusion/execution/src/memory_pool/mod.rs @@ -18,11 +18,15 @@ //! [`MemoryPool`] for memory management during query execution, [`proxy`] for //! help with allocation accounting. -use datafusion_common::{Result, internal_err}; +use datafusion_common::{Result, internal_datafusion_err}; use std::hash::{Hash, Hasher}; use std::{cmp::Ordering, sync::Arc, sync::atomic}; mod pool; + +#[cfg(feature = "arrow_buffer_pool")] +pub mod arrow; + pub mod proxy { pub use datafusion_common::utils::proxy::{HashTableAllocExt, VecAllocExt}; } @@ -322,7 +326,7 @@ impl MemoryConsumer { pool: Arc::clone(pool), consumer: self, }), - size: 0, + size: atomic::AtomicUsize::new(0), } } } @@ -351,13 +355,13 @@ impl Drop for SharedRegistration { #[derive(Debug)] pub struct MemoryReservation { registration: Arc, - size: usize, + size: atomic::AtomicUsize, } impl MemoryReservation { /// Returns the size of this reservation in bytes pub fn size(&self) -> usize { - self.size + self.size.load(atomic::Ordering::Relaxed) } /// Returns [MemoryConsumer] for this [MemoryReservation] @@ -367,10 +371,10 @@ impl MemoryReservation { /// Frees all bytes from this reservation back to the underlying /// pool, returning the number of bytes freed. 
- pub fn free(&mut self) -> usize { - let size = self.size; + pub fn free(&self) -> usize { + let size = self.size.swap(0, atomic::Ordering::Relaxed); if size != 0 { - self.shrink(size) + self.registration.pool.shrink(self, size); } size } @@ -380,60 +384,76 @@ impl MemoryReservation { /// # Panics /// /// Panics if `capacity` exceeds [`Self::size`] - pub fn shrink(&mut self, capacity: usize) { - let new_size = self.size.checked_sub(capacity).unwrap(); + pub fn shrink(&self, capacity: usize) { + self.size + .fetch_update( + atomic::Ordering::Relaxed, + atomic::Ordering::Relaxed, + |prev| prev.checked_sub(capacity), + ) + .unwrap_or_else(|prev| { + panic!("Cannot free the capacity {capacity} out of allocated size {prev}") + }); self.registration.pool.shrink(self, capacity); - self.size = new_size } /// Tries to free `capacity` bytes from this reservation - /// if `capacity` does not exceed [`Self::size`] - /// Returns new reservation size - /// or error if shrinking capacity is more than allocated size - pub fn try_shrink(&mut self, capacity: usize) -> Result { - if let Some(new_size) = self.size.checked_sub(capacity) { - self.registration.pool.shrink(self, capacity); - self.size = new_size; - Ok(new_size) - } else { - internal_err!( - "Cannot free the capacity {capacity} out of allocated size {}", - self.size + /// if `capacity` does not exceed [`Self::size`]. + /// Returns new reservation size, + /// or error if shrinking capacity is more than allocated size. 
+ pub fn try_shrink(&self, capacity: usize) -> Result { + let prev = self + .size + .fetch_update( + atomic::Ordering::Relaxed, + atomic::Ordering::Relaxed, + |prev| prev.checked_sub(capacity), ) - } + .map_err(|prev| { + internal_datafusion_err!( + "Cannot free the capacity {capacity} out of allocated size {prev}" + ) + })?; + + self.registration.pool.shrink(self, capacity); + Ok(prev - capacity) } /// Sets the size of this reservation to `capacity` - pub fn resize(&mut self, capacity: usize) { - match capacity.cmp(&self.size) { - Ordering::Greater => self.grow(capacity - self.size), - Ordering::Less => self.shrink(self.size - capacity), + pub fn resize(&self, capacity: usize) { + let size = self.size.load(atomic::Ordering::Relaxed); + match capacity.cmp(&size) { + Ordering::Greater => self.grow(capacity - size), + Ordering::Less => self.shrink(size - capacity), _ => {} } } /// Try to set the size of this reservation to `capacity` - pub fn try_resize(&mut self, capacity: usize) -> Result<()> { - match capacity.cmp(&self.size) { - Ordering::Greater => self.try_grow(capacity - self.size)?, - Ordering::Less => self.shrink(self.size - capacity), + pub fn try_resize(&self, capacity: usize) -> Result<()> { + let size = self.size.load(atomic::Ordering::Relaxed); + match capacity.cmp(&size) { + Ordering::Greater => self.try_grow(capacity - size)?, + Ordering::Less => { + self.try_shrink(size - capacity)?; + } _ => {} }; Ok(()) } /// Increase the size of this reservation by `capacity` bytes - pub fn grow(&mut self, capacity: usize) { + pub fn grow(&self, capacity: usize) { self.registration.pool.grow(self, capacity); - self.size += capacity; + self.size.fetch_add(capacity, atomic::Ordering::Relaxed); } /// Try to increase the size of this reservation by `capacity` /// bytes, returning error if there is insufficient capacity left /// in the pool. 
- pub fn try_grow(&mut self, capacity: usize) -> Result<()> { + pub fn try_grow(&self, capacity: usize) -> Result<()> { self.registration.pool.try_grow(self, capacity)?; - self.size += capacity; + self.size.fetch_add(capacity, atomic::Ordering::Relaxed); Ok(()) } @@ -447,10 +467,16 @@ impl MemoryReservation { /// # Panics /// /// Panics if `capacity` exceeds [`Self::size`] - pub fn split(&mut self, capacity: usize) -> MemoryReservation { - self.size = self.size.checked_sub(capacity).unwrap(); + pub fn split(&self, capacity: usize) -> MemoryReservation { + self.size + .fetch_update( + atomic::Ordering::Relaxed, + atomic::Ordering::Relaxed, + |prev| prev.checked_sub(capacity), + ) + .unwrap(); Self { - size: capacity, + size: atomic::AtomicUsize::new(capacity), registration: Arc::clone(&self.registration), } } @@ -458,7 +484,7 @@ impl MemoryReservation { /// Returns a new empty [`MemoryReservation`] with the same [`MemoryConsumer`] pub fn new_empty(&self) -> Self { Self { - size: 0, + size: atomic::AtomicUsize::new(0), registration: Arc::clone(&self.registration), } } @@ -466,7 +492,7 @@ impl MemoryReservation { /// Splits off all the bytes from this [`MemoryReservation`] into /// a new [`MemoryReservation`] with the same [`MemoryConsumer`] pub fn take(&mut self) -> MemoryReservation { - self.split(self.size) + self.split(self.size.load(atomic::Ordering::Relaxed)) } } @@ -492,7 +518,7 @@ mod tests { #[test] fn test_memory_pool_underflow() { let pool = Arc::new(GreedyMemoryPool::new(50)) as _; - let mut a1 = MemoryConsumer::new("a1").register(&pool); + let a1 = MemoryConsumer::new("a1").register(&pool); assert_eq!(pool.reserved(), 0); a1.grow(100); @@ -507,7 +533,7 @@ mod tests { a1.try_grow(30).unwrap(); assert_eq!(pool.reserved(), 30); - let mut a2 = MemoryConsumer::new("a2").register(&pool); + let a2 = MemoryConsumer::new("a2").register(&pool); a2.try_grow(25).unwrap_err(); assert_eq!(pool.reserved(), 30); @@ -521,7 +547,7 @@ mod tests { #[test] fn test_split() { 
let pool = Arc::new(GreedyMemoryPool::new(50)) as _; - let mut r1 = MemoryConsumer::new("r1").register(&pool); + let r1 = MemoryConsumer::new("r1").register(&pool); r1.try_grow(20).unwrap(); assert_eq!(r1.size(), 20); @@ -542,10 +568,10 @@ mod tests { #[test] fn test_new_empty() { let pool = Arc::new(GreedyMemoryPool::new(50)) as _; - let mut r1 = MemoryConsumer::new("r1").register(&pool); + let r1 = MemoryConsumer::new("r1").register(&pool); r1.try_grow(20).unwrap(); - let mut r2 = r1.new_empty(); + let r2 = r1.new_empty(); r2.try_grow(5).unwrap(); assert_eq!(r1.size(), 20); @@ -559,7 +585,7 @@ mod tests { let mut r1 = MemoryConsumer::new("r1").register(&pool); r1.try_grow(20).unwrap(); - let mut r2 = r1.take(); + let r2 = r1.take(); r2.try_grow(5).unwrap(); assert_eq!(r1.size(), 0); @@ -572,4 +598,37 @@ mod tests { assert_eq!(r2.size(), 25); assert_eq!(pool.reserved(), 28); } + + #[test] + fn test_try_shrink() { + let pool = Arc::new(GreedyMemoryPool::new(100)) as _; + let r1 = MemoryConsumer::new("r1").register(&pool); + + r1.try_grow(50).unwrap(); + assert_eq!(r1.size(), 50); + assert_eq!(pool.reserved(), 50); + + // Successful shrink returns new size and frees pool memory + let new_size = r1.try_shrink(30).unwrap(); + assert_eq!(new_size, 20); + assert_eq!(r1.size(), 20); + assert_eq!(pool.reserved(), 20); + + // Freed pool memory is now available to other consumers + let r2 = MemoryConsumer::new("r2").register(&pool); + r2.try_grow(80).unwrap(); + assert_eq!(pool.reserved(), 100); + + // Shrinking more than allocated fails without changing state + let err = r1.try_shrink(25); + assert!(err.is_err()); + assert_eq!(r1.size(), 20); + assert_eq!(pool.reserved(), 100); + + // Shrink to exactly zero + let new_size = r1.try_shrink(20).unwrap(); + assert_eq!(new_size, 0); + assert_eq!(r1.size(), 0); + assert_eq!(pool.reserved(), 80); + } } diff --git a/datafusion/execution/src/memory_pool/pool.rs b/datafusion/execution/src/memory_pool/pool.rs index 
bf74b5f6f4c6..b10270851cc0 100644 --- a/datafusion/execution/src/memory_pool/pool.rs +++ b/datafusion/execution/src/memory_pool/pool.rs @@ -212,7 +212,7 @@ impl MemoryPool for FairSpillPool { .checked_div(state.num_spill) .unwrap_or(spill_available); - if reservation.size + additional > available { + if reservation.size() + additional > available { return Err(insufficient_capacity_err( reservation, additional, @@ -264,7 +264,7 @@ fn insufficient_capacity_err( "Failed to allocate additional {} for {} with {} already allocated for this reservation - {} remain available for the total pool", human_readable_size(additional), reservation.registration.consumer.name, - human_readable_size(reservation.size), + human_readable_size(reservation.size()), human_readable_size(available) ) } @@ -526,12 +526,12 @@ mod tests { fn test_fair() { let pool = Arc::new(FairSpillPool::new(100)) as _; - let mut r1 = MemoryConsumer::new("unspillable").register(&pool); + let r1 = MemoryConsumer::new("unspillable").register(&pool); // Can grow beyond capacity of pool r1.grow(2000); assert_eq!(pool.reserved(), 2000); - let mut r2 = MemoryConsumer::new("r2") + let r2 = MemoryConsumer::new("r2") .with_can_spill(true) .register(&pool); // Can grow beyond capacity of pool @@ -563,7 +563,7 @@ mod tests { assert_eq!(r2.size(), 10); assert_eq!(pool.reserved(), 30); - let mut r3 = MemoryConsumer::new("r3") + let r3 = MemoryConsumer::new("r3") .with_can_spill(true) .register(&pool); @@ -584,7 +584,7 @@ mod tests { r1.free(); assert_eq!(pool.reserved(), 80); - let mut r4 = MemoryConsumer::new("s4").register(&pool); + let r4 = MemoryConsumer::new("s4").register(&pool); let err = r4.try_grow(30).unwrap_err().strip_backtrace(); assert_snapshot!(err, @"Resources exhausted: Failed to allocate additional 30.0 B for s4 with 0.0 B already allocated for this reservation - 20.0 B remain available for the total pool"); } @@ -601,18 +601,18 @@ mod tests { // Test: use all the different interfaces to change 
reservation size // set r1=50, using grow and shrink - let mut r1 = MemoryConsumer::new("r1").register(&pool); + let r1 = MemoryConsumer::new("r1").register(&pool); r1.grow(50); r1.grow(20); r1.shrink(20); // set r2=15 using try_grow - let mut r2 = MemoryConsumer::new("r2").register(&pool); + let r2 = MemoryConsumer::new("r2").register(&pool); r2.try_grow(15) .expect("should succeed in memory allotment for r2"); // set r3=20 using try_resize - let mut r3 = MemoryConsumer::new("r3").register(&pool); + let r3 = MemoryConsumer::new("r3").register(&pool); r3.try_resize(25) .expect("should succeed in memory allotment for r3"); r3.try_resize(20) @@ -620,12 +620,12 @@ mod tests { // set r4=10 // this should not be reported in top 3 - let mut r4 = MemoryConsumer::new("r4").register(&pool); + let r4 = MemoryConsumer::new("r4").register(&pool); r4.grow(10); // Test: reports if new reservation causes error // using the previously set sizes for other consumers - let mut r5 = MemoryConsumer::new("r5").register(&pool); + let r5 = MemoryConsumer::new("r5").register(&pool); let res = r5.try_grow(150); assert!(res.is_err()); let error = res.unwrap_err().strip_backtrace(); @@ -650,7 +650,7 @@ mod tests { let same_name = "foo"; // Test: see error message when no consumers recorded yet - let mut r0 = MemoryConsumer::new(same_name).register(&pool); + let r0 = MemoryConsumer::new(same_name).register(&pool); let res = r0.try_grow(150); assert!(res.is_err()); let error = res.unwrap_err().strip_backtrace(); @@ -665,7 +665,7 @@ mod tests { r0.grow(10); // make r0=10, pool available=90 let new_consumer_same_name = MemoryConsumer::new(same_name); - let mut r1 = new_consumer_same_name.register(&pool); + let r1 = new_consumer_same_name.register(&pool); // TODO: the insufficient_capacity_err() message is per reservation, not per consumer. 
// a followup PR will clarify this message "0 bytes already allocated for this reservation" let res = r1.try_grow(150); @@ -695,7 +695,7 @@ mod tests { // will be recognized as different in the TrackConsumersPool let consumer_with_same_name_but_different_hash = MemoryConsumer::new(same_name).with_can_spill(true); - let mut r2 = consumer_with_same_name_but_different_hash.register(&pool); + let r2 = consumer_with_same_name_but_different_hash.register(&pool); let res = r2.try_grow(150); assert!(res.is_err()); let error = res.unwrap_err().strip_backtrace(); @@ -714,10 +714,10 @@ mod tests { // Baseline: see the 2 memory consumers let setting = make_settings(); let _bound = setting.bind_to_scope(); - let mut r0 = MemoryConsumer::new("r0").register(&pool); + let r0 = MemoryConsumer::new("r0").register(&pool); r0.grow(10); let r1_consumer = MemoryConsumer::new("r1"); - let mut r1 = r1_consumer.register(&pool); + let r1 = r1_consumer.register(&pool); r1.grow(20); let res = r0.try_grow(150); @@ -791,13 +791,13 @@ mod tests { .downcast::>() .unwrap(); // set r1=20 - let mut r1 = MemoryConsumer::new("r1").register(&pool); + let r1 = MemoryConsumer::new("r1").register(&pool); r1.grow(20); // set r2=15 - let mut r2 = MemoryConsumer::new("r2").register(&pool); + let r2 = MemoryConsumer::new("r2").register(&pool); r2.grow(15); // set r3=45 - let mut r3 = MemoryConsumer::new("r3").register(&pool); + let r3 = MemoryConsumer::new("r3").register(&pool); r3.grow(45); let downcasted = upcasted diff --git a/datafusion/execution/src/runtime_env.rs b/datafusion/execution/src/runtime_env.rs index 67398d59f137..67604c424c76 100644 --- a/datafusion/execution/src/runtime_env.rs +++ b/datafusion/execution/src/runtime_env.rs @@ -19,7 +19,7 @@ //! store, memory manager, disk manager. 
#[expect(deprecated)] -use crate::disk_manager::DiskManagerConfig; +use crate::disk_manager::{DiskManagerConfig, SpillingProgress}; use crate::{ disk_manager::{DiskManager, DiskManagerBuilder, DiskManagerMode}, memory_pool::{ @@ -199,6 +199,11 @@ impl RuntimeEnv { self.object_store_registry.get_store(url.as_ref()) } + /// Returns the current spilling progress + pub fn spilling_progress(&self) -> SpillingProgress { + self.disk_manager.spilling_progress() + } + /// Register an [`EncryptionFactory`] with an associated identifier that can be later /// used to configure encryption when reading or writing Parquet. /// If an encryption factory with the same identifier was already registered, it is replaced and returned. diff --git a/datafusion/expr-common/src/accumulator.rs b/datafusion/expr-common/src/accumulator.rs index fc4e90114bee..3acf110a0bfc 100644 --- a/datafusion/expr-common/src/accumulator.rs +++ b/datafusion/expr-common/src/accumulator.rs @@ -58,17 +58,30 @@ pub trait Accumulator: Send + Sync + Debug { /// running sum. fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>; - /// Returns the final aggregate value, consuming the internal state. + /// Returns the final aggregate value. /// /// For example, the `SUM` accumulator maintains a running sum, /// and `evaluate` will produce that running sum as its output. /// - /// This function should not be called twice, otherwise it will - /// result in potentially non-deterministic behavior. - /// /// This function gets `&mut self` to allow for the accumulator to build /// arrow-compatible internal state that can be returned without copying - /// when possible (for example distinct strings) + /// when possible (for example distinct strings). + /// + /// ## Correctness + /// + /// This function must not consume the internal state, as it is also used in window + /// aggregate functions where it can be executed multiple times depending on the + /// current window frame. 
Consuming the internal state can cause the next invocation + /// to have incorrect results. + /// + /// - Even if this accumulator doesn't implement [`retract_batch`] it may still be used + /// in window aggregate functions where the window frame is + /// `ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW` + /// + /// It is fine to modify the state (e.g. re-order elements within internal state vec) so long + /// as this doesn't cause an incorrect computation on the next call of evaluate. + /// + /// [`retract_batch`]: Self::retract_batch fn evaluate(&mut self) -> Result; /// Returns the allocated size required for this accumulator, in diff --git a/datafusion/expr-common/src/columnar_value.rs b/datafusion/expr-common/src/columnar_value.rs index 99c21d4abdb6..1aa42470a148 100644 --- a/datafusion/expr-common/src/columnar_value.rs +++ b/datafusion/expr-common/src/columnar_value.rs @@ -20,7 +20,7 @@ use arrow::{ array::{Array, ArrayRef, Date32Array, Date64Array, NullArray}, compute::{CastOptions, kernels, max, min}, - datatypes::DataType, + datatypes::{DataType, Field}, util::pretty::pretty_format_columns, }; use datafusion_common::internal_datafusion_err; @@ -274,7 +274,17 @@ impl ColumnarValue { Ok(args) } - /// Cast's this [ColumnarValue] to the specified `DataType` + /// Cast this [ColumnarValue] to the specified `DataType` + /// + /// # Struct Casting Behavior + /// + /// When casting struct types, fields are matched **by name** rather than position: + /// - Source fields are matched to target fields using case-sensitive name comparison + /// - Fields are reordered to match the target schema + /// - Missing target fields are filled with null arrays + /// - Extra source fields are ignored + /// + /// For non-struct types, uses Arrow's standard positional casting. 
pub fn cast_to( &self, cast_type: &DataType, @@ -283,12 +293,8 @@ impl ColumnarValue { let cast_options = cast_options.cloned().unwrap_or(DEFAULT_CAST_OPTIONS); match self { ColumnarValue::Array(array) => { - ensure_date_array_timestamp_bounds(array, cast_type)?; - Ok(ColumnarValue::Array(kernels::cast::cast_with_options( - array, - cast_type, - &cast_options, - )?)) + let casted = cast_array_by_name(array, cast_type, &cast_options)?; + Ok(ColumnarValue::Array(casted)) } ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar( scalar.cast_to_with_options(cast_type, &cast_options)?, @@ -297,6 +303,37 @@ impl ColumnarValue { } } +fn cast_array_by_name( + array: &ArrayRef, + cast_type: &DataType, + cast_options: &CastOptions<'static>, +) -> Result { + // If types are already equal, no cast needed + if array.data_type() == cast_type { + return Ok(Arc::clone(array)); + } + + match cast_type { + DataType::Struct(_) => { + // Field name is unused; only the struct's inner field names matter + let target_field = Field::new("_", cast_type.clone(), true); + datafusion_common::nested_struct::cast_column( + array, + &target_field, + cast_options, + ) + } + _ => { + ensure_date_array_timestamp_bounds(array, cast_type)?; + Ok(kernels::cast::cast_with_options( + array, + cast_type, + cast_options, + )?) 
+ } + } +} + fn ensure_date_array_timestamp_bounds( array: &ArrayRef, cast_type: &DataType, @@ -378,8 +415,8 @@ impl fmt::Display for ColumnarValue { mod tests { use super::*; use arrow::{ - array::{Date64Array, Int32Array}, - datatypes::TimeUnit, + array::{Date64Array, Int32Array, StructArray}, + datatypes::{Field, Fields, TimeUnit}, }; #[test] @@ -553,6 +590,102 @@ mod tests { ); } + #[test] + fn cast_struct_by_field_name() { + let source_fields = Fields::from(vec![ + Field::new("b", DataType::Int32, true), + Field::new("a", DataType::Int32, true), + ]); + + let target_fields = Fields::from(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ]); + + let struct_array = StructArray::new( + source_fields, + vec![ + Arc::new(Int32Array::from(vec![Some(3)])), + Arc::new(Int32Array::from(vec![Some(4)])), + ], + None, + ); + + let value = ColumnarValue::Array(Arc::new(struct_array)); + let casted = value + .cast_to(&DataType::Struct(target_fields.clone()), None) + .expect("struct cast should succeed"); + + let ColumnarValue::Array(arr) = casted else { + panic!("expected array after cast"); + }; + + let struct_array = arr + .as_any() + .downcast_ref::() + .expect("expected StructArray"); + + let field_a = struct_array + .column_by_name("a") + .expect("expected field a in cast result"); + let field_b = struct_array + .column_by_name("b") + .expect("expected field b in cast result"); + + assert_eq!( + field_a + .as_any() + .downcast_ref::() + .expect("expected Int32 array") + .value(0), + 4 + ); + assert_eq!( + field_b + .as_any() + .downcast_ref::() + .expect("expected Int32 array") + .value(0), + 3 + ); + } + + #[test] + fn cast_struct_missing_field_inserts_nulls() { + let source_fields = Fields::from(vec![Field::new("a", DataType::Int32, true)]); + + let target_fields = Fields::from(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ]); + + let struct_array = StructArray::new( + 
source_fields, + vec![Arc::new(Int32Array::from(vec![Some(5)]))], + None, + ); + + let value = ColumnarValue::Array(Arc::new(struct_array)); + let casted = value + .cast_to(&DataType::Struct(target_fields.clone()), None) + .expect("struct cast should succeed"); + + let ColumnarValue::Array(arr) = casted else { + panic!("expected array after cast"); + }; + + let struct_array = arr + .as_any() + .downcast_ref::() + .expect("expected StructArray"); + + let field_b = struct_array + .column_by_name("b") + .expect("expected missing field to be added"); + + assert!(field_b.is_null(0)); + } + #[test] fn cast_date64_array_to_timestamp_overflow() { let overflow_value = i64::MAX / 1_000_000 + 1; diff --git a/datafusion/expr-common/src/groups_accumulator.rs b/datafusion/expr-common/src/groups_accumulator.rs index 860e69245a7f..08c9f01f13c4 100644 --- a/datafusion/expr-common/src/groups_accumulator.rs +++ b/datafusion/expr-common/src/groups_accumulator.rs @@ -89,6 +89,9 @@ impl EmitTo { /// optional and is harder to implement than `Accumulator`, but can be much /// faster for queries with many group values. See the [Aggregating Millions of /// Groups Fast blog] for more background. +/// For more background, please also see the [Aggregating Millions of Groups Fast in Apache Arrow DataFusion 28.0.0 blog] +/// +/// [Aggregating Millions of Groups Fast in Apache Arrow DataFusion 28.0.0 blog]: https://datafusion.apache.org/blog/2023/08/05/datafusion_fast_grouping /// /// [`NullState`] can help keep the state for groups that have not seen any /// values and produce the correct output for those groups. 
diff --git a/datafusion/expr-common/src/lib.rs b/datafusion/expr-common/src/lib.rs index 2be066beaad2..c9a95fd29450 100644 --- a/datafusion/expr-common/src/lib.rs +++ b/datafusion/expr-common/src/lib.rs @@ -32,7 +32,6 @@ // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] -#![deny(clippy::allow_attributes)] pub mod accumulator; pub mod casts; @@ -41,7 +40,10 @@ pub mod dyn_eq; pub mod groups_accumulator; pub mod interval_arithmetic; pub mod operator; +pub mod placement; pub mod signature; pub mod sort_properties; pub mod statistics; pub mod type_coercion; + +pub use placement::ExpressionPlacement; diff --git a/datafusion/expr-common/src/operator.rs b/datafusion/expr-common/src/operator.rs index 33512b0c354d..427069b326f9 100644 --- a/datafusion/expr-common/src/operator.rs +++ b/datafusion/expr-common/src/operator.rs @@ -140,6 +140,10 @@ pub enum Operator { /// /// Not implemented in DataFusion yet. QuestionPipe, + /// Colon operator, like `:` + /// + /// Not implemented in DataFusion yet. 
+ Colon, } impl Operator { @@ -188,7 +192,8 @@ impl Operator { | Operator::AtQuestion | Operator::Question | Operator::QuestionAnd - | Operator::QuestionPipe => None, + | Operator::QuestionPipe + | Operator::Colon => None, } } @@ -283,7 +288,8 @@ impl Operator { | Operator::AtQuestion | Operator::Question | Operator::QuestionAnd - | Operator::QuestionPipe => None, + | Operator::QuestionPipe + | Operator::Colon => None, } } @@ -323,7 +329,8 @@ impl Operator { | Operator::AtQuestion | Operator::Question | Operator::QuestionAnd - | Operator::QuestionPipe => 30, + | Operator::QuestionPipe + | Operator::Colon => 30, Operator::Plus | Operator::Minus => 40, Operator::Multiply | Operator::Divide | Operator::Modulo => 45, } @@ -369,7 +376,8 @@ impl Operator { | Operator::AtQuestion | Operator::Question | Operator::QuestionAnd - | Operator::QuestionPipe => true, + | Operator::QuestionPipe + | Operator::Colon => true, // E.g. `TRUE OR NULL` is `TRUE` Operator::Or @@ -429,6 +437,7 @@ impl fmt::Display for Operator { Operator::Question => "?", Operator::QuestionAnd => "?&", Operator::QuestionPipe => "?|", + Operator::Colon => ":", }; write!(f, "{display}") } diff --git a/datafusion/expr-common/src/placement.rs b/datafusion/expr-common/src/placement.rs new file mode 100644 index 000000000000..8212ba618e32 --- /dev/null +++ b/datafusion/expr-common/src/placement.rs @@ -0,0 +1,62 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Expression placement information for optimization decisions. + +/// Describes where an expression should be placed in the query plan for +/// optimal execution. This is used by optimizers to make decisions about +/// expression placement, such as whether to push expressions down through +/// projections. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum ExpressionPlacement { + /// A constant literal value. + Literal, + /// A simple column reference. + Column, + /// A cheap expression that can be pushed to leaf nodes in the plan. + /// Examples include `get_field` for struct field access. + /// Pushing these expressions down in the plan can reduce data early + /// at low compute cost. + /// See [`ExpressionPlacement::should_push_to_leaves`] for details. + MoveTowardsLeafNodes, + /// An expensive expression that should stay where it is in the plan. + /// Examples include complex scalar functions or UDFs. + KeepInPlace, +} + +impl ExpressionPlacement { + /// Returns true if the expression can be pushed down to leaf nodes + /// in the query plan. + /// + /// This returns true for: + /// - [`ExpressionPlacement::Column`]: Simple column references can be pushed down. They do no compute and do not increase or + /// decrease the amount of data being processed. + /// A projection that reduces the number of columns can eliminate unnecessary data early, + /// but this method only considers one expression at a time, not a projection as a whole. 
+ /// - [`ExpressionPlacement::MoveTowardsLeafNodes`]: Cheap expressions can be pushed down to leaves to take advantage of + /// early computation and potential optimizations at the data source level. + /// For example `struct_col['field']` is cheap to compute (just an Arc clone of the nested array for `'field'`) + /// and thus can reduce data early in the plan at very low compute cost. + /// It may even be possible to eliminate the expression entirely if the data source can project only the needed field + /// (as e.g. Parquet can). + pub fn should_push_to_leaves(&self) -> bool { + matches!( + self, + ExpressionPlacement::Column | ExpressionPlacement::MoveTowardsLeafNodes + ) + } +} diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 54bb84f03d3d..857e9dc5d42d 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -1416,7 +1416,7 @@ impl Signature { Arity::Variable => { // For UserDefined signatures, allow parameter names // The function implementer is responsible for validating the names match the actual arguments - if !matches!(self.type_signature, TypeSignature::UserDefined) { + if self.type_signature != TypeSignature::UserDefined { return plan_err!( "Cannot specify parameter names for variable arity signature: {:?}", self.type_signature @@ -1585,6 +1585,7 @@ mod tests { vec![DataType::UInt16, DataType::UInt16], vec![DataType::UInt32, DataType::UInt32], vec![DataType::UInt64, DataType::UInt64], + vec![DataType::Float16, DataType::Float16], vec![DataType::Float32, DataType::Float32], vec![DataType::Float64, DataType::Float64] ] diff --git a/datafusion/expr-common/src/type_coercion/aggregates.rs b/datafusion/expr-common/src/type_coercion/aggregates.rs index 01d093950d47..ab4d086e4ca5 100644 --- a/datafusion/expr-common/src/type_coercion/aggregates.rs +++ b/datafusion/expr-common/src/type_coercion/aggregates.rs @@ -42,6 +42,7 @@ pub static NUMERICS: &[DataType] = &[ 
DataType::UInt16, DataType::UInt32, DataType::UInt64, + DataType::Float16, DataType::Float32, DataType::Float64, ]; diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index de16e9e01073..e696545ea6ca 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -17,6 +17,7 @@ //! Coercion rules for matching argument types for binary operators +use std::collections::HashMap; use std::collections::HashSet; use std::sync::Arc; @@ -260,8 +261,16 @@ impl<'a> BinaryTypeCoercer<'a> { ) }) } + Minus if is_date_minus_date(lhs, rhs) => { + return Ok(Signature { + lhs: lhs.clone(), + rhs: rhs.clone(), + ret: Int64, + }); + } Plus | Minus | Multiply | Divide | Modulo => { if let Ok(ret) = self.get_result(lhs, rhs) { + // Temporal arithmetic, e.g. Date32 + Interval Ok(Signature{ lhs: lhs.clone(), @@ -281,6 +290,7 @@ impl<'a> BinaryTypeCoercer<'a> { ret, }) } else if let Some(coerced) = temporal_coercion_strict_timezone(lhs, rhs) { + // Temporal arithmetic by first coercing to a common time representation // e.g. Date32 - Timestamp let ret = self.get_result(&coerced, &coerced).map_err(|e| { @@ -314,6 +324,9 @@ impl<'a> BinaryTypeCoercer<'a> { ) } }, + Colon => { + Ok(Signature { lhs: lhs.clone(), rhs: rhs.clone(), ret: lhs.clone() }) + }, IntegerDivide | Arrow | LongArrow | HashArrow | HashLongArrow | HashMinus | AtQuestion | Question | QuestionAnd | QuestionPipe => { not_impl_err!("Operator {} is not yet supported", self.op) @@ -341,13 +354,12 @@ impl<'a> BinaryTypeCoercer<'a> { // TODO Move the rest inside of BinaryTypeCoercer -fn is_decimal(data_type: &DataType) -> bool { +/// Returns true if both operands are Date types (Date32 or Date64) +/// Used to detect Date - Date operations which should return Int64 (days difference) +fn is_date_minus_date(lhs: &DataType, rhs: &DataType) -> bool { matches!( - data_type, - DataType::Decimal32(..) 
- | DataType::Decimal64(..) - | DataType::Decimal128(..) - | DataType::Decimal256(..) + (lhs, rhs), + (DataType::Date32, DataType::Date32) | (DataType::Date64, DataType::Date64) ) } @@ -383,8 +395,8 @@ fn math_decimal_coercion( } // Cross-variant decimal coercion - choose larger variant with appropriate precision/scale (lhs, rhs) - if is_decimal(lhs) - && is_decimal(rhs) + if lhs.is_decimal() + && rhs.is_decimal() && std::mem::discriminant(lhs) != std::mem::discriminant(rhs) => { let coerced_type = get_wider_decimal_type_cross_variant(lhs_type, rhs_type)?; @@ -461,7 +473,9 @@ fn bitwise_coercion(left_type: &DataType, right_type: &DataType) -> Option for TypeCategory { return TypeCategory::Numeric; } - if matches!(data_type, DataType::Boolean) { + if *data_type == DataType::Boolean { return TypeCategory::Boolean; } @@ -999,8 +1013,8 @@ pub fn decimal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { @@ -1008,8 +1022,8 @@ pub fn decimal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { @@ -1218,30 +1232,123 @@ fn coerce_numeric_type_to_decimal256(numeric_type: &DataType) -> Option Option { use arrow::datatypes::DataType::*; + match (lhs_type, rhs_type) { (Struct(lhs_fields), Struct(rhs_fields)) => { + // Field count must match for coercion if lhs_fields.len() != rhs_fields.len() { return None; } - let coerced_types = std::iter::zip(lhs_fields.iter(), rhs_fields.iter()) - .map(|(lhs, rhs)| comparison_coercion(lhs.data_type(), rhs.data_type())) - .collect::>>()?; - - // preserve the field name and nullability - let orig_fields = std::iter::zip(lhs_fields.iter(), rhs_fields.iter()); + // If the two structs have exactly the same set of field names (possibly in + // different order), prefer name-based coercion. Otherwise fall back to + // positional coercion which preserves backward compatibility. + // + // Name-based coercion is used in: + // 1. Array construction: [s1, s2] where s1 and s2 have reordered fields + // 2. 
UNION operations: different field orders unified by name + // 3. VALUES clauses: heterogeneous struct rows unified by field name + // 4. JOIN conditions: structs with matching field names + // 5. Window functions: partitions/orders by struct fields + // 6. Aggregate functions: collecting structs with reordered fields + // + // See docs/source/user-guide/sql/struct_coercion.md for detailed examples. + if fields_have_same_names(lhs_fields, rhs_fields) { + return coerce_struct_by_name(lhs_fields, rhs_fields); + } - let fields: Vec = coerced_types - .into_iter() - .zip(orig_fields) - .map(|(datatype, (lhs, rhs))| coerce_fields(datatype, lhs, rhs)) - .collect(); - Some(Struct(fields.into())) + coerce_struct_by_position(lhs_fields, rhs_fields) } _ => None, } } +/// Return true if every left-field name exists in the right fields (and lengths are equal). +/// +/// # Assumptions +/// **This function assumes field names within each struct are unique.** This assumption is safe +/// because field name uniqueness is enforced at multiple levels: +/// - **Arrow level:** `StructType` construction enforces unique field names at the schema level +/// - **DataFusion level:** SQL parser rejects duplicate field names in `CREATE TABLE` and struct type definitions +/// - **Runtime level:** `StructArray::try_new()` validates field uniqueness +/// +/// Therefore, we don't need to handle degenerate cases like: +/// - `struct -> struct` (target has duplicate field names) +/// - `struct -> struct` (source has duplicate field names) +fn fields_have_same_names(lhs_fields: &Fields, rhs_fields: &Fields) -> bool { + // Debug assertions: field names should be unique within each struct + #[cfg(debug_assertions)] + { + let lhs_names: HashSet<_> = lhs_fields.iter().map(|f| f.name()).collect(); + assert_eq!( + lhs_names.len(), + lhs_fields.len(), + "Struct has duplicate field names (should be caught by Arrow schema validation)" + ); + + let rhs_names_check: HashSet<_> = rhs_fields.iter().map(|f| 
f.name()).collect(); + assert_eq!( + rhs_names_check.len(), + rhs_fields.len(), + "Struct has duplicate field names (should be caught by Arrow schema validation)" + ); + } + + let rhs_names: HashSet<&str> = rhs_fields.iter().map(|f| f.name().as_str()).collect(); + lhs_fields + .iter() + .all(|lf| rhs_names.contains(lf.name().as_str())) +} + +/// Coerce two structs by matching fields by name. Assumes the name-sets match. +fn coerce_struct_by_name(lhs_fields: &Fields, rhs_fields: &Fields) -> Option { + use arrow::datatypes::DataType::*; + + let rhs_by_name: HashMap<&str, &FieldRef> = + rhs_fields.iter().map(|f| (f.name().as_str(), f)).collect(); + + let mut coerced: Vec = Vec::with_capacity(lhs_fields.len()); + + for lhs in lhs_fields.iter() { + let rhs = rhs_by_name.get(lhs.name().as_str()).unwrap(); // safe: caller ensured names match + let coerced_type = comparison_coercion(lhs.data_type(), rhs.data_type())?; + let is_nullable = lhs.is_nullable() || rhs.is_nullable(); + coerced.push(Arc::new(Field::new( + lhs.name().clone(), + coerced_type, + is_nullable, + ))); + } + + Some(Struct(coerced.into())) +} + +/// Coerce two structs positionally (left-to-right). This preserves field names from +/// the left struct and uses the combined nullability. +fn coerce_struct_by_position( + lhs_fields: &Fields, + rhs_fields: &Fields, +) -> Option { + use arrow::datatypes::DataType::*; + + // First coerce individual types; fail early if any pair cannot be coerced. + let coerced_types: Vec = lhs_fields + .iter() + .zip(rhs_fields.iter()) + .map(|(l, r)| comparison_coercion(l.data_type(), r.data_type())) + .collect::>>()?; + + // Build final fields preserving left-side names and combined nullability. 
+ let orig_pairs = lhs_fields.iter().zip(rhs_fields.iter()); + let fields: Vec = coerced_types + .into_iter() + .zip(orig_pairs) + .map(|(datatype, (lhs, rhs))| coerce_fields(datatype, lhs, rhs)) + .collect(); + + Some(Struct(fields.into())) +} + /// returns the result of coercing two fields to a common type fn coerce_fields(common_type: DataType, lhs: &FieldRef, rhs: &FieldRef) -> FieldRef { let is_nullable = lhs.is_nullable() || rhs.is_nullable(); @@ -1691,9 +1798,10 @@ fn binary_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option /// Coercion rules for like operations. /// This is a union of string coercion rules, dictionary coercion rules, and REE coercion rules +/// Note: list_coercion is intentionally NOT included here because LIKE is a string pattern +/// matching operation and is not supported for nested types (List, Struct, etc.) pub fn like_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { string_coercion(lhs_type, rhs_type) - .or_else(|| list_coercion(lhs_type, rhs_type)) .or_else(|| binary_to_string_coercion(lhs_type, rhs_type)) .or_else(|| dictionary_comparison_coercion(lhs_type, rhs_type, false)) .or_else(|| ree_comparison_coercion(lhs_type, rhs_type, false)) diff --git a/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs b/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs index bb9d44953b9f..eb5622fedb8a 100644 --- a/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs +++ b/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs @@ -228,6 +228,53 @@ fn test_type_coercion_arithmetic() -> Result<()> { Ok(()) } +#[test] +fn test_bitwise_coercion_non_integer_types() -> Result<()> { + let err = BinaryTypeCoercer::new( + &DataType::Float32, + &Operator::BitwiseAnd, + &DataType::Float32, + ) + .get_input_types() + .unwrap_err() + .to_string(); + assert_contains!( + &err, + "Cannot infer common type for bitwise operation Float32 & Float32" + ); + + let err = 
BinaryTypeCoercer::new( + &DataType::Float32, + &Operator::BitwiseAnd, + &DataType::Float64, + ) + .get_input_types() + .unwrap_err() + .to_string(); + assert_contains!( + &err, + "Cannot infer common type for bitwise operation Float32 & Float64" + ); + + let err = BinaryTypeCoercer::new( + &DataType::Decimal128(10, 2), + &Operator::BitwiseAnd, + &DataType::Decimal128(10, 2), + ) + .get_input_types() + .unwrap_err() + .to_string(); + assert_contains!( + &err, + "Cannot infer common type for bitwise operation Decimal128(10, 2) & Decimal128(10, 2)" + ); + + let dict_int8 = DataType::Dictionary(DataType::Int8.into(), DataType::Int8.into()); + test_coercion_binary_rule!(dict_int8, dict_int8, Operator::BitwiseAnd, dict_int8); + + Ok(()) +} + fn test_math_decimal_coercion_rule( lhs_type: DataType, rhs_type: DataType, diff --git a/datafusion/expr/src/execution_props.rs b/datafusion/expr/src/execution_props.rs index 20d8f82bf48a..3bf6978eb60e 100644 --- a/datafusion/expr/src/execution_props.rs +++ b/datafusion/expr/src/execution_props.rs @@ -16,7 +16,7 @@ // under the License. use crate::var_provider::{VarProvider, VarType}; -use chrono::{DateTime, TimeZone, Utc}; +use chrono::{DateTime, Utc}; use datafusion_common::HashMap; use datafusion_common::alias::AliasGenerator; use datafusion_common::config::ConfigOptions; @@ -33,7 +33,9 @@ use std::sync::Arc; /// done so during predicate pruning and expression simplification #[derive(Clone, Debug)] pub struct ExecutionProps { - pub query_execution_start_time: DateTime, + /// The time at which the query execution started. If `None`, + /// functions like `now()` will not be simplified during optimization. 
+ pub query_execution_start_time: Option>, /// Alias generator used by subquery optimizer rules pub alias_generator: Arc, /// Snapshot of config options when the query started @@ -52,9 +54,7 @@ impl ExecutionProps { /// Creates a new execution props pub fn new() -> Self { ExecutionProps { - // Set this to a fixed sentinel to make it obvious if this is - // not being updated / propagated correctly - query_execution_start_time: Utc.timestamp_nanos(0), + query_execution_start_time: None, alias_generator: Arc::new(AliasGenerator::new()), config_options: None, var_providers: None, @@ -66,7 +66,7 @@ impl ExecutionProps { mut self, query_execution_start_time: DateTime, ) -> Self { - self.query_execution_start_time = query_execution_start_time; + self.query_execution_start_time = Some(query_execution_start_time); self } @@ -79,7 +79,7 @@ impl ExecutionProps { /// Marks the execution of query started timestamp. /// This also instantiates a new alias generator. pub fn mark_start_execution(&mut self, config_options: Arc) -> &Self { - self.query_execution_start_time = Utc::now(); + self.query_execution_start_time = Some(Utc::now()); self.alias_generator = Arc::new(AliasGenerator::new()); self.config_options = Some(config_options); &*self @@ -126,7 +126,7 @@ mod test { fn debug() { let props = ExecutionProps::new(); assert_eq!( - "ExecutionProps { query_execution_start_time: 1970-01-01T00:00:00Z, alias_generator: AliasGenerator { next_id: 1 }, config_options: None, var_providers: None }", + "ExecutionProps { query_execution_start_time: None, alias_generator: AliasGenerator { next_id: 1 }, config_options: None, var_providers: None }", format!("{props:?}") ); } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index c7d825ce1d52..87e8e029a6ee 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -38,11 +38,12 @@ use datafusion_common::tree_node::{ use datafusion_common::{ Column, DFSchema, HashMap, Result, ScalarValue, Spans, 
TableReference, }; +use datafusion_expr_common::placement::ExpressionPlacement; use datafusion_functions_window_common::field::WindowUDFFieldArgs; #[cfg(feature = "sql")] use sqlparser::ast::{ ExceptSelectItem, ExcludeSelectItem, IlikeSelectItem, RenameSelectItem, - ReplaceSelectElement, display_comma_separated, + ReplaceSelectElement, }; // Moved in 51.0.0 to datafusion_common @@ -309,6 +310,7 @@ impl From for NullTreatment { /// assert!(rewritten.transformed); /// // to 42 = 5 AND b = 6 /// assert_eq!(rewritten.data, lit(42).eq(lit(5)).and(col("b").eq(lit(6)))); +/// ``` #[derive(Clone, PartialEq, PartialOrd, Eq, Debug, Hash)] pub enum Expr { /// An expression with a specific name. @@ -372,6 +374,8 @@ pub enum Expr { Exists(Exists), /// IN subquery InSubquery(InSubquery), + /// Set comparison subquery (e.g. `= ANY`, `> ALL`) + SetComparison(SetComparison), /// Scalar subquery ScalarSubquery(Subquery), /// Represents a reference to all available fields in a specific schema, @@ -953,7 +957,7 @@ impl AggregateFunction { pub enum WindowFunctionDefinition { /// A user defined aggregate function AggregateUDF(Arc), - /// A user defined aggregate function + /// A user defined window function WindowUDF(Arc), } @@ -1101,6 +1105,54 @@ impl Exists { } } +/// Whether the set comparison uses `ANY`/`SOME` or `ALL` +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Hash, Debug)] +pub enum SetQuantifier { + /// `ANY` (or `SOME`) + Any, + /// `ALL` + All, +} + +impl Display for SetQuantifier { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + SetQuantifier::Any => write!(f, "ANY"), + SetQuantifier::All => write!(f, "ALL"), + } + } +} + +/// Set comparison subquery (e.g. `= ANY`, `> ALL`) +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +pub struct SetComparison { + /// The expression to compare + pub expr: Box, + /// Subquery that will produce a single column of data to compare against + pub subquery: Subquery, + /// Comparison operator (e.g. 
`=`, `>`, `<`) + pub op: Operator, + /// Quantifier (`ANY`/`ALL`) + pub quantifier: SetQuantifier, +} + +impl SetComparison { + /// Create a new set comparison expression + pub fn new( + expr: Box, + subquery: Subquery, + op: Operator, + quantifier: SetQuantifier, + ) -> Self { + Self { + expr, + subquery, + op, + quantifier, + } + } +} + /// InList expression #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct InList { @@ -1268,7 +1320,6 @@ impl Display for ExceptSelectItem { } } -#[cfg(not(feature = "sql"))] pub fn display_comma_separated(slice: &[T]) -> String where T: Display, @@ -1487,6 +1538,24 @@ impl Expr { } } + /// Returns placement information for this expression. + /// + /// This is used by optimizers to make decisions about expression placement, + /// such as whether to push expressions down through projections. + pub fn placement(&self) -> ExpressionPlacement { + match self { + Expr::Column(_) => ExpressionPlacement::Column, + Expr::Literal(_, _) => ExpressionPlacement::Literal, + Expr::Alias(inner) => inner.expr.placement(), + Expr::ScalarFunction(func) => { + let arg_placements: Vec<_> = + func.args.iter().map(|arg| arg.placement()).collect(); + func.func.placement(&arg_placements) + } + _ => ExpressionPlacement::KeepInPlace, + } + } + /// Return String representation of the variant represented by `self` /// Useful for non-rust based bindings pub fn variant_name(&self) -> &str { @@ -1503,6 +1572,7 @@ impl Expr { Expr::GroupingSet(..) => "GroupingSet", Expr::InList { .. } => "InList", Expr::InSubquery(..) => "InSubquery", + Expr::SetComparison(..) => "SetComparison", Expr::IsNotNull(..) => "IsNotNull", Expr::IsNull(..) => "IsNull", Expr::Like { .. } => "Like", @@ -2058,6 +2128,7 @@ impl Expr { | Expr::GroupingSet(..) | Expr::InList(..) | Expr::InSubquery(..) + | Expr::SetComparison(..) | Expr::IsFalse(..) | Expr::IsNotFalse(..) | Expr::IsNotNull(..) 
@@ -2651,6 +2722,16 @@ impl HashNode for Expr { subquery.hash(state); negated.hash(state); } + Expr::SetComparison(SetComparison { + expr: _, + subquery, + op, + quantifier, + }) => { + subquery.hash(state); + op.hash(state); + quantifier.hash(state); + } Expr::ScalarSubquery(subquery) => { subquery.hash(state); } @@ -2841,6 +2922,12 @@ impl Display for SchemaDisplay<'_> { write!(f, "NOT IN") } Expr::InSubquery(InSubquery { negated: false, .. }) => write!(f, "IN"), + Expr::SetComparison(SetComparison { + expr, + op, + quantifier, + .. + }) => write!(f, "{} {op} {quantifier}", SchemaDisplay(expr.as_ref())), Expr::IsTrue(expr) => write!(f, "{} IS TRUE", SchemaDisplay(expr)), Expr::IsFalse(expr) => write!(f, "{} IS FALSE", SchemaDisplay(expr)), Expr::IsNotTrue(expr) => { @@ -3316,6 +3403,12 @@ impl Display for Expr { subquery, negated: false, }) => write!(f, "{expr} IN ({subquery:?})"), + Expr::SetComparison(SetComparison { + expr, + subquery, + op, + quantifier, + }) => write!(f, "{expr} {op} {quantifier} ({subquery:?})"), Expr::ScalarSubquery(subquery) => write!(f, "({subquery:?})"), Expr::BinaryExpr(expr) => write!(f, "{expr}"), Expr::ScalarFunction(fun) => { @@ -3799,6 +3892,7 @@ mod test { } use super::*; + use crate::logical_plan::{EmptyRelation, LogicalPlan}; #[test] fn test_display_wildcard() { @@ -3889,6 +3983,28 @@ mod test { ) } + #[test] + fn test_display_set_comparison() { + let subquery = Subquery { + subquery: Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: Arc::new(DFSchema::empty()), + })), + outer_ref_columns: vec![], + spans: Spans::new(), + }; + + let expr = Expr::SetComparison(SetComparison::new( + Box::new(Expr::Column(Column::from_name("a"))), + subquery, + Operator::Gt, + SetQuantifier::Any, + )); + + assert_eq!(format!("{expr}"), "a > ANY ()"); + assert_eq!(format!("{}", expr.human_display()), "a > ANY ()"); + } + #[test] fn test_schema_display_alias_with_relation() { assert_eq!( diff --git 
a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index a0faca76e91e..32a88ab8cf31 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -261,9 +261,16 @@ fn coerce_exprs_for_schema( #[expect(deprecated)] Expr::Wildcard { .. } => Ok(expr), _ => { - // maintain the original name when casting - let name = dst_schema.field(idx).name(); - Ok(expr.cast_to(new_type, src_schema)?.alias(name)) + match expr { + // maintain the original name when casting a column, to avoid the + // tablename being added to it when not explicitly set by the query + // (see: https://github.com/apache/datafusion/issues/18818) + Expr::Column(ref column) => { + let name = column.name().to_owned(); + Ok(expr.cast_to(new_type, src_schema)?.alias(name)) + } + _ => Ok(expr.cast_to(new_type, src_schema)?), + } } } } else { diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index dbba0f2914a6..f4e4f014f533 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -21,11 +21,11 @@ use crate::expr::{ InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, WindowFunctionParams, }; -use crate::type_coercion::functions::fields_with_udf; +use crate::type_coercion::functions::{UDFCoercionExt, fields_with_udf}; use crate::udf::ReturnFieldArgs; use crate::{LogicalPlan, Projection, Subquery, WindowFunctionDefinition, utils}; use arrow::compute::can_cast_types; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::datatype::FieldExt; use datafusion_common::metadata::FieldMetadata; use datafusion_common::{ @@ -152,48 +152,16 @@ impl ExprSchemable for Expr { } } } - Expr::ScalarFunction(_func) => { - let return_type = self.to_field(schema)?.1.data_type().clone(); - Ok(return_type) - } - Expr::WindowFunction(window_function) => self - 
.data_type_and_nullable_with_window_function(schema, window_function) - .map(|(return_type, _)| return_type), - Expr::AggregateFunction(AggregateFunction { - func, - params: AggregateFunctionParams { args, .. }, - }) => { - let fields = args - .iter() - .map(|e| e.to_field(schema).map(|(_, f)| f)) - .collect::>>()?; - let new_fields = fields_with_udf(&fields, func.as_ref()) - .map_err(|err| { - let data_types = fields - .iter() - .map(|f| f.data_type().clone()) - .collect::>(); - plan_datafusion_err!( - "{} {}", - match err { - DataFusionError::Plan(msg) => msg, - err => err.to_string(), - }, - utils::generate_signature_error_msg( - func.name(), - func.signature().clone(), - &data_types - ) - ) - })? - .into_iter() - .collect::>(); - Ok(func.return_field(&new_fields)?.data_type().clone()) + Expr::ScalarFunction(_) + | Expr::WindowFunction(_) + | Expr::AggregateFunction(_) => { + Ok(self.to_field(schema)?.1.data_type().clone()) } Expr::Not(_) | Expr::IsNull(_) | Expr::Exists { .. } | Expr::InSubquery(_) + | Expr::SetComparison(_) | Expr::Between { .. } | Expr::InList { .. } | Expr::IsNotNull(_) @@ -348,21 +316,9 @@ impl ExprSchemable for Expr { } } Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema), - Expr::ScalarFunction(_func) => { - let field = self.to_field(input_schema)?.1; - - let nullable = field.is_nullable(); - Ok(nullable) - } - Expr::AggregateFunction(AggregateFunction { func, .. }) => { - Ok(func.is_nullable()) - } - Expr::WindowFunction(window_function) => self - .data_type_and_nullable_with_window_function( - input_schema, - window_function, - ) - .map(|(_, nullable)| nullable), + Expr::ScalarFunction(_) + | Expr::AggregateFunction(_) + | Expr::WindowFunction(_) => Ok(self.to_field(input_schema)?.1.is_nullable()), Expr::ScalarVariable(field, _) => Ok(field.is_nullable()), Expr::TryCast { .. 
} | Expr::Unnest(_) | Expr::Placeholder(_) => Ok(true), Expr::IsNull(_) @@ -374,6 +330,7 @@ impl ExprSchemable for Expr { | Expr::IsNotFalse(_) | Expr::IsNotUnknown(_) | Expr::Exists { .. } => Ok(false), + Expr::SetComparison(_) => Ok(true), Expr::InSubquery(InSubquery { expr, .. }) => expr.nullable(input_schema), Expr::ScalarSubquery(subquery) => { Ok(subquery.subquery.schema().field(0).is_nullable()) @@ -458,7 +415,7 @@ impl ExprSchemable for Expr { /// with the default implementation returning empty field metadata /// - **Aggregate functions**: Generate metadata via function's [`return_field`] method, /// with the default implementation returning empty field metadata - /// - **Window functions**: field metadata is empty + /// - **Window functions**: field metadata follows the function's return field /// /// ## Table Reference Scoping /// - Establishes proper qualified field references when columns belong to specific tables @@ -534,73 +491,49 @@ impl ExprSchemable for Expr { ))) } Expr::WindowFunction(window_function) => { - let (dt, nullable) = self.data_type_and_nullable_with_window_function( - schema, - window_function, - )?; - Ok(Arc::new(Field::new(&schema_name, dt, nullable))) - } - Expr::AggregateFunction(aggregate_function) => { - let AggregateFunction { - func, - params: AggregateFunctionParams { args, .. }, + let WindowFunction { + fun, + params: WindowFunctionParams { args, .. }, .. 
- } = aggregate_function; + } = window_function.as_ref(); let fields = args .iter() .map(|e| e.to_field(schema).map(|(_, f)| f)) .collect::>>()?; - // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` - let new_fields = fields_with_udf(&fields, func.as_ref()) - .map_err(|err| { - let arg_types = fields - .iter() - .map(|f| f.data_type()) - .cloned() - .collect::>(); - plan_datafusion_err!( - "{} {}", - match err { - DataFusionError::Plan(msg) => msg, - err => err.to_string(), - }, - utils::generate_signature_error_msg( - func.name(), - func.signature().clone(), - &arg_types, - ) - ) - })? - .into_iter() - .collect::>(); - + match fun { + WindowFunctionDefinition::AggregateUDF(udaf) => { + let new_fields = + verify_function_arguments(udaf.as_ref(), &fields)?; + let return_field = udaf.return_field(&new_fields)?; + Ok(return_field) + } + WindowFunctionDefinition::WindowUDF(udwf) => { + let new_fields = + verify_function_arguments(udwf.as_ref(), &fields)?; + let return_field = udwf + .field(WindowUDFFieldArgs::new(&new_fields, &schema_name))?; + Ok(return_field) + } + } + } + Expr::AggregateFunction(AggregateFunction { + func, + params: AggregateFunctionParams { args, .. }, + }) => { + let fields = args + .iter() + .map(|e| e.to_field(schema).map(|(_, f)| f)) + .collect::>>()?; + let new_fields = verify_function_arguments(func.as_ref(), &fields)?; func.return_field(&new_fields) } Expr::ScalarFunction(ScalarFunction { func, args }) => { - let (arg_types, fields): (Vec, Vec>) = args + let fields = args .iter() .map(|e| e.to_field(schema).map(|(_, f)| f)) - .collect::>>()? 
- .into_iter() - .map(|f| (f.data_type().clone(), f)) - .unzip(); - // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` - let new_fields = - fields_with_udf(&fields, func.as_ref()).map_err(|err| { - plan_datafusion_err!( - "{} {}", - match err { - DataFusionError::Plan(msg) => msg, - err => err.to_string(), - }, - utils::generate_signature_error_msg( - func.name(), - func.signature().clone(), - &arg_types, - ) - ) - })?; + .collect::>>()?; + let new_fields = verify_function_arguments(func.as_ref(), &fields)?; let arguments = args .iter() @@ -632,6 +565,7 @@ impl ExprSchemable for Expr { | Expr::TryCast(_) | Expr::InList(_) | Expr::InSubquery(_) + | Expr::SetComparison(_) | Expr::Wildcard { .. } | Expr::GroupingSet(_) | Expr::Placeholder(_) @@ -665,7 +599,16 @@ impl ExprSchemable for Expr { // like all of the binary expressions below. Perhaps Expr should track the // type of the expression? - if can_cast_types(&this_type, cast_to_type) { + // Special handling for struct-to-struct casts with name-based field matching + let can_cast = match (&this_type, cast_to_type) { + (DataType::Struct(_), DataType::Struct(_)) => { + // Always allow struct-to-struct casts; field matching happens at runtime + true + } + _ => can_cast_types(&this_type, cast_to_type), + }; + + if can_cast { match self { Expr::ScalarSubquery(subquery) => { Ok(Expr::ScalarSubquery(cast_subquery(subquery, cast_to_type)?)) @@ -678,6 +621,33 @@ impl ExprSchemable for Expr { } } +/// Verify that function is invoked with correct number and type of arguments as +/// defined in `TypeSignature`. 
+fn verify_function_arguments( + function: &F, + input_fields: &[FieldRef], +) -> Result> { + fields_with_udf(input_fields, function).map_err(|err| { + let data_types = input_fields + .iter() + .map(|f| f.data_type()) + .cloned() + .collect::>(); + plan_datafusion_err!( + "{} {}", + match err { + DataFusionError::Plan(msg) => msg, + err => err.to_string(), + }, + utils::generate_signature_error_message( + function.name(), + function.signature(), + &data_types + ) + ) + }) +} + /// Returns the innermost [Expr] that is provably null if `expr` is null. fn unwrap_certainly_null_expr(expr: &Expr) -> &Expr { match expr { @@ -688,93 +658,6 @@ fn unwrap_certainly_null_expr(expr: &Expr) -> &Expr { } } -impl Expr { - /// Common method for window functions that applies type coercion - /// to all arguments of the window function to check if it matches - /// its signature. - /// - /// If successful, this method returns the data type and - /// nullability of the window function's result. - /// - /// Otherwise, returns an error if there's a type mismatch between - /// the window function's signature and the provided arguments. - fn data_type_and_nullable_with_window_function( - &self, - schema: &dyn ExprSchema, - window_function: &WindowFunction, - ) -> Result<(DataType, bool)> { - let WindowFunction { - fun, - params: WindowFunctionParams { args, .. }, - .. - } = window_function; - - let fields = args - .iter() - .map(|e| e.to_field(schema).map(|(_, f)| f)) - .collect::>>()?; - match fun { - WindowFunctionDefinition::AggregateUDF(udaf) => { - let data_types = fields - .iter() - .map(|f| f.data_type()) - .cloned() - .collect::>(); - let new_fields = fields_with_udf(&fields, udaf.as_ref()) - .map_err(|err| { - plan_datafusion_err!( - "{} {}", - match err { - DataFusionError::Plan(msg) => msg, - err => err.to_string(), - }, - utils::generate_signature_error_msg( - fun.name(), - fun.signature(), - &data_types - ) - ) - })? 
- .into_iter() - .collect::>(); - - let return_field = udaf.return_field(&new_fields)?; - - Ok((return_field.data_type().clone(), return_field.is_nullable())) - } - WindowFunctionDefinition::WindowUDF(udwf) => { - let data_types = fields - .iter() - .map(|f| f.data_type()) - .cloned() - .collect::>(); - let new_fields = fields_with_udf(&fields, udwf.as_ref()) - .map_err(|err| { - plan_datafusion_err!( - "{} {}", - match err { - DataFusionError::Plan(msg) => msg, - err => err.to_string(), - }, - utils::generate_signature_error_msg( - fun.name(), - fun.signature(), - &data_types - ) - ) - })? - .into_iter() - .collect::>(); - let (_, function_name) = self.qualified_name(); - let field_args = WindowUDFFieldArgs::new(&new_fields, &function_name); - - udwf.field(field_args) - .map(|field| (field.data_type().clone(), field.is_nullable())) - } - } - } -} - /// Cast subquery in InSubquery/ScalarSubquery to a given type. /// /// 1. **Projection plan**: If the subquery is a projection (i.e. a SELECT statement with specific diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index e0235d32292f..68d2c9073241 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -67,25 +67,25 @@ pub type StateTypeFunction = /// [crate::udaf::AggregateUDFImpl::simplify] simplifier closure /// A closure with two arguments: /// * 'aggregate_function': [crate::expr::AggregateFunction] for which simplified has been invoked -/// * 'info': [crate::simplify::SimplifyInfo] +/// * 'info': [crate::simplify::SimplifyContext] /// /// Closure returns simplified [Expr] or an error. 
pub type AggregateFunctionSimplification = Box< dyn Fn( crate::expr::AggregateFunction, - &dyn crate::simplify::SimplifyInfo, + &crate::simplify::SimplifyContext, ) -> Result, >; /// [crate::udwf::WindowUDFImpl::simplify] simplifier closure /// A closure with two arguments: /// * 'window_function': [crate::expr::WindowFunction] for which simplified has been invoked -/// * 'info': [crate::simplify::SimplifyInfo] +/// * 'info': [crate::simplify::SimplifyContext] /// /// Closure returns simplified [Expr] or an error. pub type WindowFunctionSimplification = Box< dyn Fn( crate::expr::WindowFunction, - &dyn crate::simplify::SimplifyInfo, + &crate::simplify::SimplifyContext, ) -> Result, >; diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 4fb78933d7a5..cb136229bf88 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -24,7 +24,6 @@ // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] -#![deny(clippy::allow_attributes)] //! [DataFusion](https://github.com/apache/datafusion) //! 
is an extensible query execution framework that uses @@ -77,6 +76,7 @@ pub mod statistics { pub use datafusion_expr_common::statistics::*; } mod predicate_bounds; +pub mod preimage; pub mod ptr_eq; pub mod test; pub mod tree_node; @@ -95,6 +95,7 @@ pub use datafusion_expr_common::accumulator::Accumulator; pub use datafusion_expr_common::columnar_value::ColumnarValue; pub use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; pub use datafusion_expr_common::operator::Operator; +pub use datafusion_expr_common::placement::ExpressionPlacement; pub use datafusion_expr_common::signature::{ ArrayFunctionArgument, ArrayFunctionSignature, Coercion, Signature, TIMEZONE_WILDCARD, TypeSignature, TypeSignatureClass, Volatility, diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 6f654428e41a..2e23fef1da76 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -1011,6 +1011,25 @@ impl LogicalPlanBuilder { join_keys: (Vec>, Vec>), filter: Option, null_equality: NullEquality, + ) -> Result { + self.join_detailed_with_options( + right, + join_type, + join_keys, + filter, + null_equality, + false, + ) + } + + pub fn join_detailed_with_options( + self, + right: LogicalPlan, + join_type: JoinType, + join_keys: (Vec>, Vec>), + filter: Option, + null_equality: NullEquality, + null_aware: bool, ) -> Result { if join_keys.0.len() != join_keys.1.len() { return plan_err!("left_keys and right_keys were not the same length"); @@ -1128,6 +1147,7 @@ impl LogicalPlanBuilder { join_constraint: JoinConstraint::On, schema: DFSchemaRef::new(join_schema), null_equality, + null_aware, }))) } @@ -1201,6 +1221,7 @@ impl LogicalPlanBuilder { join_type, JoinConstraint::Using, NullEquality::NullEqualsNothing, + false, // null_aware )?; Ok(Self::new(LogicalPlan::Join(join))) @@ -1217,6 +1238,7 @@ impl LogicalPlanBuilder { JoinType::Inner, JoinConstraint::On, 
NullEquality::NullEqualsNothing, + false, // null_aware )?; Ok(Self::new(LogicalPlan::Join(join))) @@ -1471,6 +1493,7 @@ impl LogicalPlanBuilder { join_type, JoinConstraint::On, NullEquality::NullEqualsNothing, + false, // null_aware )?; Ok(Self::new(LogicalPlan::Join(join))) @@ -2756,12 +2779,12 @@ mod tests { assert_snapshot!(plan, @r" Union - Cross Join: + Cross Join: SubqueryAlias: left Values: (Int32(1)) SubqueryAlias: right Values: (Int32(1)) - Cross Join: + Cross Join: SubqueryAlias: left Values: (Int32(1)) SubqueryAlias: right diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index 480974b055d1..58c7feb61617 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -117,13 +117,7 @@ pub fn display_schema(schema: &Schema) -> impl fmt::Display + '_ { write!(f, ", ")?; } let nullable_str = if field.is_nullable() { ";N" } else { "" }; - write!( - f, - "{}:{:?}{}", - field.name(), - field.data_type(), - nullable_str - )?; + write!(f, "{}:{}{}", field.name(), field.data_type(), nullable_str)?; } write!(f, "]") } diff --git a/datafusion/expr/src/logical_plan/dml.rs b/datafusion/expr/src/logical_plan/dml.rs index 6ac3b309aa0c..b668cbfe2cc3 100644 --- a/datafusion/expr/src/logical_plan/dml.rs +++ b/datafusion/expr/src/logical_plan/dml.rs @@ -237,6 +237,8 @@ pub enum WriteOp { Update, /// `CREATE TABLE AS SELECT` operation Ctas, + /// `TRUNCATE` operation + Truncate, } impl WriteOp { @@ -247,6 +249,7 @@ impl WriteOp { WriteOp::Delete => "Delete", WriteOp::Update => "Update", WriteOp::Ctas => "Ctas", + WriteOp::Truncate => "Truncate", } } } diff --git a/datafusion/expr/src/logical_plan/invariants.rs b/datafusion/expr/src/logical_plan/invariants.rs index 762491a255cb..0889afd08fee 100644 --- a/datafusion/expr/src/logical_plan/invariants.rs +++ b/datafusion/expr/src/logical_plan/invariants.rs @@ -22,7 +22,7 @@ use datafusion_common::{ use crate::{ Aggregate, Expr, 
Filter, Join, JoinType, LogicalPlan, Window, - expr::{Exists, InSubquery}, + expr::{Exists, InSubquery, SetComparison}, expr_rewriter::strip_outer_reference, utils::{collect_subquery_cols, split_conjunction}, }; @@ -81,6 +81,7 @@ fn assert_valid_extension_nodes(plan: &LogicalPlan, check: InvariantLevel) -> Re match expr { Expr::Exists(Exists { subquery, .. }) | Expr::InSubquery(InSubquery { subquery, .. }) + | Expr::SetComparison(SetComparison { subquery, .. }) | Expr::ScalarSubquery(subquery) => { assert_valid_extension_nodes(&subquery.subquery, check)?; } @@ -133,6 +134,7 @@ fn assert_subqueries_are_valid(plan: &LogicalPlan) -> Result<()> { match expr { Expr::Exists(Exists { subquery, .. }) | Expr::InSubquery(InSubquery { subquery, .. }) + | Expr::SetComparison(SetComparison { subquery, .. }) | Expr::ScalarSubquery(subquery) => { check_subquery_expr(plan, &subquery.subquery, expr)?; } @@ -206,14 +208,16 @@ pub fn check_subquery_expr( if group_expr.contains(expr) && !aggr_expr.contains(expr) { // TODO revisit this validation logic plan_err!( - "Correlated scalar subquery in the GROUP BY clause must also be in the aggregate expressions" + "Correlated scalar subquery in the GROUP BY clause must \ + also be in the aggregate expressions" ) } else { Ok(()) } } _ => plan_err!( - "Correlated scalar subquery can only be used in Projection, Filter, Aggregate plan nodes" + "Correlated scalar subquery can only be used in Projection, \ + Filter, Aggregate plan nodes" ), }?; } @@ -229,6 +233,20 @@ pub fn check_subquery_expr( ); } } + if let Expr::SetComparison(set_comparison) = expr + && set_comparison.subquery.subquery.schema().fields().len() > 1 + { + return plan_err!( + "Set comparison subquery should only return one column, but found {}: {}", + set_comparison.subquery.subquery.schema().fields().len(), + set_comparison + .subquery + .subquery + .schema() + .field_names() + .join(", ") + ); + } match outer_plan { LogicalPlan::Projection(_) | LogicalPlan::Filter(_) @@ -237,7 
+255,7 @@ pub fn check_subquery_expr( | LogicalPlan::Aggregate(_) | LogicalPlan::Join(_) => Ok(()), _ => plan_err!( - "In/Exist subquery can only be used in \ + "In/Exist/SetComparison subquery can only be used in \ Projection, Filter, TableScan, Window functions, Aggregate and Join plan nodes, \ but was used in [{}]", outer_plan.display() diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 4219c24bfc9c..99688a52a75c 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -661,6 +661,7 @@ impl LogicalPlan { on, schema: _, null_equality, + null_aware, }) => { let schema = build_join_schema(left.schema(), right.schema(), &join_type)?; @@ -682,6 +683,7 @@ impl LogicalPlan { filter, schema: DFSchemaRef::new(schema), null_equality, + null_aware, })) } LogicalPlan::Subquery(_) => Ok(self), @@ -901,6 +903,7 @@ impl LogicalPlan { join_constraint, on, null_equality, + null_aware, .. }) => { let (left, right) = self.only_two_inputs(inputs)?; @@ -942,6 +945,7 @@ impl LogicalPlan { filter: filter_expr, schema: DFSchemaRef::new(schema), null_equality: *null_equality, + null_aware: *null_aware, })) } LogicalPlan::Subquery(Subquery { @@ -1388,6 +1392,82 @@ impl LogicalPlan { } } + /// Returns the skip (offset) of this plan node, if it has one. + /// + /// Only [`LogicalPlan::Limit`] carries a skip value; all other variants + /// return `Ok(None)`. Returns `Ok(None)` for a zero skip. + pub fn skip(&self) -> Result> { + match self { + LogicalPlan::Limit(limit) => match limit.get_skip_type()? 
{ + SkipType::Literal(0) => Ok(None), + SkipType::Literal(n) => Ok(Some(n)), + SkipType::UnsupportedExpr => Ok(None), + }, + LogicalPlan::Sort(_) => Ok(None), + LogicalPlan::TableScan(_) => Ok(None), + LogicalPlan::Projection(_) => Ok(None), + LogicalPlan::Filter(_) => Ok(None), + LogicalPlan::Window(_) => Ok(None), + LogicalPlan::Aggregate(_) => Ok(None), + LogicalPlan::Join(_) => Ok(None), + LogicalPlan::Repartition(_) => Ok(None), + LogicalPlan::Union(_) => Ok(None), + LogicalPlan::EmptyRelation(_) => Ok(None), + LogicalPlan::Subquery(_) => Ok(None), + LogicalPlan::SubqueryAlias(_) => Ok(None), + LogicalPlan::Statement(_) => Ok(None), + LogicalPlan::Values(_) => Ok(None), + LogicalPlan::Explain(_) => Ok(None), + LogicalPlan::Analyze(_) => Ok(None), + LogicalPlan::Extension(_) => Ok(None), + LogicalPlan::Distinct(_) => Ok(None), + LogicalPlan::Dml(_) => Ok(None), + LogicalPlan::Ddl(_) => Ok(None), + LogicalPlan::Copy(_) => Ok(None), + LogicalPlan::DescribeTable(_) => Ok(None), + LogicalPlan::Unnest(_) => Ok(None), + LogicalPlan::RecursiveQuery(_) => Ok(None), + } + } + + /// Returns the fetch (limit) of this plan node, if it has one. + /// + /// [`LogicalPlan::Sort`], [`LogicalPlan::TableScan`], and + /// [`LogicalPlan::Limit`] may carry a fetch value; all other variants + /// return `Ok(None)`. + pub fn fetch(&self) -> Result> { + match self { + LogicalPlan::Sort(Sort { fetch, .. }) => Ok(*fetch), + LogicalPlan::TableScan(TableScan { fetch, .. }) => Ok(*fetch), + LogicalPlan::Limit(limit) => match limit.get_fetch_type()? 
{ + FetchType::Literal(s) => Ok(s), + FetchType::UnsupportedExpr => Ok(None), + }, + LogicalPlan::Projection(_) => Ok(None), + LogicalPlan::Filter(_) => Ok(None), + LogicalPlan::Window(_) => Ok(None), + LogicalPlan::Aggregate(_) => Ok(None), + LogicalPlan::Join(_) => Ok(None), + LogicalPlan::Repartition(_) => Ok(None), + LogicalPlan::Union(_) => Ok(None), + LogicalPlan::EmptyRelation(_) => Ok(None), + LogicalPlan::Subquery(_) => Ok(None), + LogicalPlan::SubqueryAlias(_) => Ok(None), + LogicalPlan::Statement(_) => Ok(None), + LogicalPlan::Values(_) => Ok(None), + LogicalPlan::Explain(_) => Ok(None), + LogicalPlan::Analyze(_) => Ok(None), + LogicalPlan::Extension(_) => Ok(None), + LogicalPlan::Distinct(_) => Ok(None), + LogicalPlan::Dml(_) => Ok(None), + LogicalPlan::Ddl(_) => Ok(None), + LogicalPlan::Copy(_) => Ok(None), + LogicalPlan::DescribeTable(_) => Ok(None), + LogicalPlan::Unnest(_) => Ok(None), + LogicalPlan::RecursiveQuery(_) => Ok(None), + } + } + /// If this node's expressions contains any references to an outer subquery pub fn contains_outer_reference(&self) -> bool { let mut contains = false; @@ -1961,7 +2041,7 @@ impl LogicalPlan { .unwrap_or_else(|| "".to_string()); let join_type = if filter.is_none() && keys.is_empty() - && matches!(join_type, JoinType::Inner) + && *join_type == JoinType::Inner { "Cross".to_string() } else { @@ -1969,13 +2049,16 @@ impl LogicalPlan { }; match join_constraint { JoinConstraint::On => { - write!( - f, - "{} Join: {}{}", - join_type, - join_expr.join(", "), - filter_expr - ) + write!(f, "{join_type} Join:",)?; + if !join_expr.is_empty() || !filter_expr.is_empty() { + write!( + f, + " {}{}", + join_expr.join(", "), + filter_expr + )?; + } + Ok(()) } JoinConstraint::Using => { write!( @@ -3781,6 +3864,14 @@ pub struct Join { pub schema: DFSchemaRef, /// Defines the null equality for the join. pub null_equality: NullEquality, + /// Whether this is a null-aware anti join (for NOT IN semantics). 
+ /// + /// Only applies to LeftAnti joins. When true, implements SQL NOT IN semantics where: + /// - If the right side (subquery) contains any NULL in join keys, no rows are output + /// - Left side rows with NULL in join keys are not output + /// + /// This is required for correct NOT IN subquery behavior with three-valued logic. + pub null_aware: bool, } impl Join { @@ -3798,10 +3889,12 @@ impl Join { /// * `join_type` - Type of join (Inner, Left, Right, etc.) /// * `join_constraint` - Join constraint (On, Using) /// * `null_equality` - How to handle nulls in join comparisons + /// * `null_aware` - Whether this is a null-aware anti join (for NOT IN semantics) /// /// # Returns /// /// A new Join operator with the computed schema + #[expect(clippy::too_many_arguments)] pub fn try_new( left: Arc, right: Arc, @@ -3810,6 +3903,7 @@ impl Join { join_type: JoinType, join_constraint: JoinConstraint, null_equality: NullEquality, + null_aware: bool, ) -> Result { let join_schema = build_join_schema(left.schema(), right.schema(), &join_type)?; @@ -3822,6 +3916,7 @@ impl Join { join_constraint, schema: Arc::new(join_schema), null_equality, + null_aware, }) } @@ -3877,6 +3972,7 @@ impl Join { join_constraint: original_join.join_constraint, schema: Arc::new(join_schema), null_equality: original_join.null_equality, + null_aware: original_join.null_aware, }, requalified, )) @@ -5329,6 +5425,7 @@ mod tests { join_constraint: JoinConstraint::On, schema: Arc::new(left_schema.join(&right_schema)?), null_equality: NullEquality::NullEqualsNothing, + null_aware: false, })) } @@ -5440,6 +5537,7 @@ mod tests { join_type, JoinConstraint::On, NullEquality::NullEqualsNothing, + false, )?; match join_type { @@ -5585,6 +5683,7 @@ mod tests { JoinType::Inner, JoinConstraint::Using, NullEquality::NullEqualsNothing, + false, )?; let fields = join.schema.fields(); @@ -5636,6 +5735,7 @@ mod tests { JoinType::Inner, JoinConstraint::On, NullEquality::NullEqualsNothing, + false, )?; let fields = 
join.schema.fields(); @@ -5685,6 +5785,7 @@ mod tests { JoinType::Inner, JoinConstraint::On, NullEquality::NullEqualsNull, + false, )?; assert_eq!(join.null_equality, NullEquality::NullEqualsNull); @@ -5727,6 +5828,7 @@ mod tests { join_type, JoinConstraint::On, NullEquality::NullEqualsNothing, + false, )?; let fields = join.schema.fields(); @@ -5766,6 +5868,7 @@ mod tests { JoinType::Inner, JoinConstraint::Using, NullEquality::NullEqualsNothing, + false, )?; assert_eq!( diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 62a27b0a025a..a1285510da56 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -46,7 +46,7 @@ use crate::{ }; use datafusion_common::tree_node::TreeNodeRefContainer; -use crate::expr::{Exists, InSubquery}; +use crate::expr::{Exists, InSubquery, SetComparison}; use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeContainer, TreeNodeIterator, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, @@ -133,6 +133,7 @@ impl TreeNode for LogicalPlan { join_constraint, schema, null_equality, + null_aware, }) => (left, right).map_elements(f)?.update_data(|(left, right)| { LogicalPlan::Join(Join { left, @@ -143,6 +144,7 @@ impl TreeNode for LogicalPlan { join_constraint, schema, null_equality, + null_aware, }) }), LogicalPlan::Limit(Limit { skip, fetch, input }) => input @@ -564,6 +566,7 @@ impl LogicalPlan { join_constraint, schema, null_equality, + null_aware, }) => (on, filter).map_elements(f)?.update_data(|(on, filter)| { LogicalPlan::Join(Join { left, @@ -574,6 +577,7 @@ impl LogicalPlan { join_constraint, schema, null_equality, + null_aware, }) }), LogicalPlan::Sort(Sort { expr, input, fetch }) => expr @@ -815,6 +819,7 @@ impl LogicalPlan { expr.apply(|expr| match expr { Expr::Exists(Exists { subquery, .. }) | Expr::InSubquery(InSubquery { subquery, .. }) + | Expr::SetComparison(SetComparison { subquery, .. 
}) | Expr::ScalarSubquery(subquery) => { // use a synthetic plan so the collector sees a // LogicalPlan::Subquery (even though it is @@ -856,6 +861,22 @@ impl LogicalPlan { })), _ => internal_err!("Transformation should return Subquery"), }), + Expr::SetComparison(SetComparison { + expr, + subquery, + op, + quantifier, + }) => f(LogicalPlan::Subquery(subquery))?.map_data(|s| match s { + LogicalPlan::Subquery(subquery) => { + Ok(Expr::SetComparison(SetComparison { + expr, + subquery, + op, + quantifier, + })) + } + _ => internal_err!("Transformation should return Subquery"), + }), Expr::ScalarSubquery(subquery) => f(LogicalPlan::Subquery(subquery))? .map_data(|s| match s { LogicalPlan::Subquery(subquery) => { diff --git a/datafusion/expr/src/partition_evaluator.rs b/datafusion/expr/src/partition_evaluator.rs index 3e7ba5d4f575..0671f31f6d15 100644 --- a/datafusion/expr/src/partition_evaluator.rs +++ b/datafusion/expr/src/partition_evaluator.rs @@ -86,6 +86,10 @@ use crate::window_state::WindowAggState; /// [`uses_window_frame`]: Self::uses_window_frame /// [`include_rank`]: Self::include_rank /// [`supports_bounded_execution`]: Self::supports_bounded_execution +/// +/// For more background, please also see the [User defined Window Functions in DataFusion blog] +/// +/// [User defined Window Functions in DataFusion blog]: https://datafusion.apache.org/blog/2025/04/19/user-defined-window-functions pub trait PartitionEvaluator: Debug + Send { /// When the window frame has a fixed beginning (e.g UNBOUNDED /// PRECEDING), some functions such as FIRST_VALUE, LAST_VALUE and diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index 954f511651ce..837a9eefe289 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -139,6 +139,10 @@ pub trait ContextProvider { } /// Customize planning of SQL AST expressions to [`Expr`]s +/// +/// For more background, please also see the [Extending SQL in DataFusion: from ->> to TABLESAMPLE 
blog] +/// +/// [Extending SQL in DataFusion: from ->> to TABLESAMPLE blog]: https://datafusion.apache.org/blog/2026/01/12/extending-sql pub trait ExprPlanner: Debug + Send + Sync { /// Plan the binary operation between two expressions, returns original /// BinaryExpr if not possible @@ -249,13 +253,6 @@ pub trait ExprPlanner: Debug + Send + Sync { ) } - /// Plans `ANY` expression, such as `expr = ANY(array_expr)` - /// - /// Returns origin binary expression if not possible - fn plan_any(&self, expr: RawBinaryExpr) -> Result> { - Ok(PlannerResult::Original(expr)) - } - /// Plans aggregate functions, such as `COUNT()` /// /// Returns original expression arguments if not possible @@ -369,13 +366,16 @@ impl PlannedRelation { #[derive(Debug)] pub enum RelationPlanning { /// The relation was successfully planned by an extension planner - Planned(PlannedRelation), + Planned(Box), /// No extension planner handled the relation, return it for default processing - Original(TableFactor), + Original(Box), } /// Customize planning SQL table factors to [`LogicalPlan`]s. #[cfg(feature = "sql")] +/// For more background, please also see the [Extending SQL in DataFusion: from ->> to TABLESAMPLE blog] +/// +/// [Extending SQL in DataFusion: from ->> to TABLESAMPLE blog]: https://datafusion.apache.org/blog/2026/01/12/extending-sql pub trait RelationPlanner: Debug + Send + Sync { /// Plan a table factor into a [`LogicalPlan`]. /// @@ -427,6 +427,9 @@ pub trait RelationPlannerContext { /// Customize planning SQL types to DataFusion (Arrow) types. 
#[cfg(feature = "sql")] +/// For more background, please also see the [Extending SQL in DataFusion: from ->> to TABLESAMPLE blog] +/// +/// [Extending SQL in DataFusion: from ->> to TABLESAMPLE blog]: https://datafusion.apache.org/blog/2026/01/12/extending-sql pub trait TypePlanner: Debug + Send + Sync { /// Plan SQL [`sqlparser::ast::DataType`] to DataFusion [`DataType`] /// diff --git a/datafusion/expr/src/preimage.rs b/datafusion/expr/src/preimage.rs new file mode 100644 index 000000000000..67ca7a91bbf3 --- /dev/null +++ b/datafusion/expr/src/preimage.rs @@ -0,0 +1,29 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use datafusion_expr_common::interval_arithmetic::Interval; + +use crate::Expr; + +/// Return from [`crate::ScalarUDFImpl::preimage`] +pub enum PreimageResult { + /// No preimage exists for the specified value + None, + /// The expression always evaluates to the specified constant + /// given that `expr` is within the interval + Range { expr: Expr, interval: Box }, +} diff --git a/datafusion/expr/src/simplify.rs b/datafusion/expr/src/simplify.rs index bbe65904fb77..8c68067a55a3 100644 --- a/datafusion/expr/src/simplify.rs +++ b/datafusion/expr/src/simplify.rs @@ -15,92 +15,98 @@ // specific language governing permissions and limitations // under the License. -//! Structs and traits to provide the information needed for expression simplification. +//! Structs to provide the information needed for expression simplification. + +use std::sync::Arc; use arrow::datatypes::DataType; -use datafusion_common::{DFSchemaRef, Result, internal_datafusion_err}; +use chrono::{DateTime, Utc}; +use datafusion_common::config::ConfigOptions; +use datafusion_common::{DFSchema, DFSchemaRef, Result}; -use crate::{Expr, ExprSchemable, execution_props::ExecutionProps}; +use crate::{Expr, ExprSchemable}; -/// Provides the information necessary to apply algebraic simplification to an -/// [Expr]. See [SimplifyContext] for one concrete implementation. -/// -/// This trait exists so that other systems can plug schema -/// information in without having to create `DFSchema` objects. 
If you -/// have a [`DFSchemaRef`] you can use [`SimplifyContext`] -pub trait SimplifyInfo { - /// Returns true if this Expr has boolean type - fn is_boolean_type(&self, expr: &Expr) -> Result; - - /// Returns true of this expr is nullable (could possibly be NULL) - fn nullable(&self, expr: &Expr) -> Result; - - /// Returns details needed for partial expression evaluation - fn execution_props(&self) -> &ExecutionProps; - - /// Returns data type of this expr needed for determining optimized int type of a value - fn get_data_type(&self, expr: &Expr) -> Result; -} - -/// Provides simplification information based on DFSchema and -/// [`ExecutionProps`]. This is the default implementation used by DataFusion +/// Provides simplification information based on schema, query execution time, +/// and configuration options. /// /// # Example /// See the `simplify_demo` in the [`expr_api` example] /// /// [`expr_api` example]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/query_planning/expr_api.rs #[derive(Debug, Clone)] -pub struct SimplifyContext<'a> { - schema: Option, - props: &'a ExecutionProps, +pub struct SimplifyContext { + schema: DFSchemaRef, + query_execution_start_time: Option>, + config_options: Arc, } -impl<'a> SimplifyContext<'a> { - /// Create a new SimplifyContext - pub fn new(props: &'a ExecutionProps) -> Self { +impl Default for SimplifyContext { + fn default() -> Self { Self { - schema: None, - props, + schema: Arc::new(DFSchema::empty()), + query_execution_start_time: None, + config_options: Arc::new(ConfigOptions::default()), } } +} + +impl SimplifyContext { + /// Set the [`ConfigOptions`] for this context + pub fn with_config_options(mut self, config_options: Arc) -> Self { + self.config_options = config_options; + self + } - /// Register a [`DFSchemaRef`] with this context + /// Set the schema for this context pub fn with_schema(mut self, schema: DFSchemaRef) -> Self { - self.schema = Some(schema); + self.schema = schema; 
self } -} -impl SimplifyInfo for SimplifyContext<'_> { - /// Returns true if this Expr has boolean type - fn is_boolean_type(&self, expr: &Expr) -> Result { - if let Some(schema) = &self.schema - && let Ok(DataType::Boolean) = expr.get_type(schema) - { - return Ok(true); - } + /// Set the query execution start time + pub fn with_query_execution_start_time( + mut self, + query_execution_start_time: Option>, + ) -> Self { + self.query_execution_start_time = query_execution_start_time; + self + } - Ok(false) + /// Set the query execution start to the current time + pub fn with_current_time(mut self) -> Self { + self.query_execution_start_time = Some(Utc::now()); + self + } + + /// Returns the schema + pub fn schema(&self) -> &DFSchemaRef { + &self.schema + } + + /// Returns true if this Expr has boolean type + pub fn is_boolean_type(&self, expr: &Expr) -> Result { + Ok(expr.get_type(&self.schema)? == DataType::Boolean) } /// Returns true if expr is nullable - fn nullable(&self, expr: &Expr) -> Result { - let schema = self.schema.as_ref().ok_or_else(|| { - internal_datafusion_err!("attempt to get nullability without schema") - })?; - expr.nullable(schema.as_ref()) + pub fn nullable(&self, expr: &Expr) -> Result { + expr.nullable(self.schema.as_ref()) } /// Returns data type of this expr needed for determining optimized int type of a value - fn get_data_type(&self, expr: &Expr) -> Result { - let schema = self.schema.as_ref().ok_or_else(|| { - internal_datafusion_err!("attempt to get data type without schema") - })?; - expr.get_type(schema) + pub fn get_data_type(&self, expr: &Expr) -> Result { + expr.get_type(&self.schema) + } + + /// Returns the time at which the query execution started. + /// If `None`, time-dependent functions like `now()` will not be simplified. 
+ pub fn query_execution_start_time(&self) -> Option> { + self.query_execution_start_time } - fn execution_props(&self) -> &ExecutionProps { - self.props + /// Returns the configuration options for the session. + pub fn config_options(&self) -> &Arc { + &self.config_options } } diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index 742bae5b2320..226c512a974d 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -20,8 +20,8 @@ use crate::Expr; use crate::expr::{ AggregateFunction, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Cast, - GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, TryCast, Unnest, - WindowFunction, WindowFunctionParams, + GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, SetComparison, + TryCast, Unnest, WindowFunction, WindowFunctionParams, }; use datafusion_common::Result; @@ -58,7 +58,8 @@ impl TreeNode for Expr { | Expr::Negative(expr) | Expr::Cast(Cast { expr, .. }) | Expr::TryCast(TryCast { expr, .. }) - | Expr::InSubquery(InSubquery { expr, .. }) => expr.apply_elements(f), + | Expr::InSubquery(InSubquery { expr, .. }) + | Expr::SetComparison(SetComparison { expr, .. }) => expr.apply_elements(f), Expr::GroupingSet(GroupingSet::Rollup(exprs)) | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.apply_elements(f), Expr::ScalarFunction(ScalarFunction { args, .. }) => { @@ -128,6 +129,19 @@ impl TreeNode for Expr { | Expr::ScalarSubquery(_) | Expr::ScalarVariable(_, _) | Expr::Literal(_, _) => Transformed::no(self), + Expr::SetComparison(SetComparison { + expr, + subquery, + op, + quantifier, + }) => expr.map_elements(f)?.update_data(|expr| { + Expr::SetComparison(SetComparison { + expr, + subquery, + op, + quantifier, + }) + }), Expr::Unnest(Unnest { expr, .. }) => expr .map_elements(f)? 
.update_data(|expr| Expr::Unnest(Unnest { expr })), diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index e1f2a1967282..fe259fb8c972 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -94,58 +94,6 @@ impl UDFCoercionExt for WindowUDF { } } -/// Performs type coercion for scalar function arguments. -/// -/// Returns the data types to which each argument must be coerced to -/// match `signature`. -/// -/// For more details on coercion in general, please see the -/// [`type_coercion`](crate::type_coercion) module. -#[deprecated(since = "52.0.0", note = "use fields_with_udf")] -pub fn data_types_with_scalar_udf( - current_types: &[DataType], - func: &ScalarUDF, -) -> Result> { - let current_fields = current_types - .iter() - .map(|dt| Arc::new(Field::new("f", dt.clone(), true))) - .collect::>(); - Ok(fields_with_udf(¤t_fields, func)? - .iter() - .map(|f| f.data_type().clone()) - .collect()) -} - -/// Performs type coercion for aggregate function arguments. -/// -/// Returns the fields to which each argument must be coerced to -/// match `signature`. -/// -/// For more details on coercion in general, please see the -/// [`type_coercion`](crate::type_coercion) module. -#[deprecated(since = "52.0.0", note = "use fields_with_udf")] -pub fn fields_with_aggregate_udf( - current_fields: &[FieldRef], - func: &AggregateUDF, -) -> Result> { - fields_with_udf(current_fields, func) -} - -/// Performs type coercion for window function arguments. -/// -/// Returns the data types to which each argument must be coerced to -/// match `signature`. -/// -/// For more details on coercion in general, please see the -/// [`type_coercion`](crate::type_coercion) module. 
-#[deprecated(since = "52.0.0", note = "use fields_with_udf")] -pub fn fields_with_window_udf( - current_fields: &[FieldRef], - func: &WindowUDF, -) -> Result> { - fields_with_udf(current_fields, func) -} - /// Performs type coercion for UDF arguments. /// /// Returns the data types to which each argument must be coerced to @@ -200,6 +148,58 @@ pub fn fields_with_udf( .collect()) } +/// Performs type coercion for scalar function arguments. +/// +/// Returns the data types to which each argument must be coerced to +/// match `signature`. +/// +/// For more details on coercion in general, please see the +/// [`type_coercion`](crate::type_coercion) module. +#[deprecated(since = "52.0.0", note = "use fields_with_udf")] +pub fn data_types_with_scalar_udf( + current_types: &[DataType], + func: &ScalarUDF, +) -> Result> { + let current_fields = current_types + .iter() + .map(|dt| Arc::new(Field::new("f", dt.clone(), true))) + .collect::>(); + Ok(fields_with_udf(¤t_fields, func)? + .iter() + .map(|f| f.data_type().clone()) + .collect()) +} + +/// Performs type coercion for aggregate function arguments. +/// +/// Returns the fields to which each argument must be coerced to +/// match `signature`. +/// +/// For more details on coercion in general, please see the +/// [`type_coercion`](crate::type_coercion) module. +#[deprecated(since = "52.0.0", note = "use fields_with_udf")] +pub fn fields_with_aggregate_udf( + current_fields: &[FieldRef], + func: &AggregateUDF, +) -> Result> { + fields_with_udf(current_fields, func) +} + +/// Performs type coercion for window function arguments. +/// +/// Returns the data types to which each argument must be coerced to +/// match `signature`. +/// +/// For more details on coercion in general, please see the +/// [`type_coercion`](crate::type_coercion) module. 
+#[deprecated(since = "52.0.0", note = "use fields_with_udf")] +pub fn fields_with_window_udf( + current_fields: &[FieldRef], + func: &WindowUDF, +) -> Result> { + fields_with_udf(current_fields, func) +} + /// Performs type coercion for function arguments. /// /// Returns the data types to which each argument must be coerced to @@ -487,7 +487,7 @@ fn get_valid_types( let valid_types = match signature { TypeSignature::Variadic(valid_types) => valid_types .iter() - .map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect()) + .map(|valid_type| vec![valid_type.clone(); current_types.len()]) .collect(), TypeSignature::String(number) => { function_length_check(function_name, current_types.len(), *number)?; @@ -635,8 +635,13 @@ fn get_valid_types( default_casted_type.default_cast_for(current_type)?; new_types.push(casted_type); } else { - return internal_err!( - "Expect {} but received NativeType::{}, DataType: {}", + let hint = if matches!(current_native_type, NativeType::Binary) { + "\n\nHint: Binary types are not automatically coerced to String. Use CAST(column AS VARCHAR) to convert Binary data to String." + } else { + "" + }; + return plan_err!( + "Function '{function_name}' requires {}, but received {} (DataType: {}).{hint}", param.desired_type(), current_native_type, current_type @@ -655,7 +660,7 @@ fn get_valid_types( valid_types .iter() - .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect()) + .map(|valid_type| vec![valid_type.clone(); *number]) .collect() } TypeSignature::UserDefined => { @@ -722,7 +727,7 @@ fn get_valid_types( current_types.len() ); } - vec![(0..*number).map(|i| current_types[i].clone()).collect()] + vec![current_types.to_vec()] } TypeSignature::OneOf(types) => types .iter() @@ -800,6 +805,7 @@ fn maybe_data_types_without_coercion( /// (losslessly converted) into a value of `type_to` /// /// See the module level documentation for more detail on coercion. 
+#[deprecated(since = "53.0.0", note = "Unused internal function")] pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool { if type_into == type_from { return true; @@ -846,10 +852,13 @@ fn coerced_from<'a>( (UInt16, Null | UInt8 | UInt16) => Some(type_into.clone()), (UInt32, Null | UInt8 | UInt16 | UInt32) => Some(type_into.clone()), (UInt64, Null | UInt8 | UInt16 | UInt32 | UInt64) => Some(type_into.clone()), + (Float16, Null | Int8 | Int16 | UInt8 | UInt16 | Float16) => { + Some(type_into.clone()) + } ( Float32, Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 - | Float32, + | Float16 | Float32, ) => Some(type_into.clone()), ( Float64, @@ -862,6 +871,7 @@ fn coerced_from<'a>( | UInt16 | UInt32 | UInt64 + | Float16 | Float32 | Float64 | Decimal32(_, _) @@ -873,7 +883,7 @@ fn coerced_from<'a>( Timestamp(TimeUnit::Nanosecond, None), Null | Timestamp(_, None) | Date32 | Utf8 | LargeUtf8, ) => Some(type_into.clone()), - (Interval(_), Utf8 | LargeUtf8) => Some(type_into.clone()), + (Interval(_), Null | Utf8 | LargeUtf8) => Some(type_into.clone()), // We can go into a Utf8View from a Utf8 or LargeUtf8 (Utf8View, Utf8 | LargeUtf8 | Null) => Some(type_into.clone()), // Any type can be coerced into strings @@ -928,18 +938,21 @@ mod tests { use super::*; use arrow::datatypes::Field; - use datafusion_common::{assert_contains, types::logical_binary}; + use datafusion_common::{ + assert_contains, + types::{logical_binary, logical_int64}, + }; use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; #[test] fn test_string_conversion() { let cases = vec![ - (DataType::Utf8View, DataType::Utf8, true), - (DataType::Utf8View, DataType::LargeUtf8, true), + (DataType::Utf8View, DataType::Utf8), + (DataType::Utf8View, DataType::LargeUtf8), ]; for case in cases { - assert_eq!(can_coerce_from(&case.0, &case.1), case.2); + assert_eq!(coerced_from(&case.0, &case.1), Some(case.0)); } } @@ -1063,7 +1076,7 @@ mod tests { 
.unwrap_err(); assert_contains!( got.to_string(), - "Function 'test' expects NativeType::Numeric but received NativeType::Timestamp(Second, None)" + "Function 'test' expects NativeType::Numeric but received NativeType::Timestamp(s)" ); Ok(()) @@ -1118,22 +1131,22 @@ mod tests { Ok(()) } - #[test] - fn test_fixed_list_wildcard_coerce() -> Result<()> { - struct MockUdf(Signature); + struct MockUdf(Signature); - impl UDFCoercionExt for MockUdf { - fn name(&self) -> &str { - "test" - } - fn signature(&self) -> &Signature { - &self.0 - } - fn coerce_types(&self, _arg_types: &[DataType]) -> Result> { - unimplemented!() - } + impl UDFCoercionExt for MockUdf { + fn name(&self) -> &str { + "test" + } + fn signature(&self) -> &Signature { + &self.0 } + fn coerce_types(&self, _arg_types: &[DataType]) -> Result> { + unimplemented!() + } + } + #[test] + fn test_fixed_list_wildcard_coerce() -> Result<()> { let inner = Arc::new(Field::new_list_field(DataType::Int32, false)); // able to coerce for any size let current_fields = vec![Arc::new(Field::new( @@ -1340,6 +1353,140 @@ mod tests { Ok(()) } + #[test] + fn test_coercible_nulls() -> Result<()> { + fn null_input(coercion: Coercion) -> Result> { + fields_with_udf( + &[Field::new("field", DataType::Null, true).into()], + &MockUdf(Signature::coercible(vec![coercion], Volatility::Immutable)), + ) + .map(|v| v.into_iter().map(|f| f.data_type().clone()).collect()) + } + + // Casts Null to Int64 if we use TypeSignatureClass::Native + let output = null_input(Coercion::new_exact(TypeSignatureClass::Native( + logical_int64(), + )))?; + assert_eq!(vec![DataType::Int64], output); + + let output = null_input(Coercion::new_implicit( + TypeSignatureClass::Native(logical_int64()), + vec![], + NativeType::Int64, + ))?; + assert_eq!(vec![DataType::Int64], output); + + // Null gets passed through if we use TypeSignatureClass apart from Native + let output = null_input(Coercion::new_exact(TypeSignatureClass::Integer))?; + 
assert_eq!(vec![DataType::Null], output); + + let output = null_input(Coercion::new_implicit( + TypeSignatureClass::Integer, + vec![], + NativeType::Int64, + ))?; + assert_eq!(vec![DataType::Null], output); + + Ok(()) + } + + #[test] + fn test_coercible_dictionary() -> Result<()> { + let dictionary = + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int64)); + fn dictionary_input(coercion: Coercion) -> Result> { + fields_with_udf( + &[Field::new( + "field", + DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::Int64), + ), + true, + ) + .into()], + &MockUdf(Signature::coercible(vec![coercion], Volatility::Immutable)), + ) + .map(|v| v.into_iter().map(|f| f.data_type().clone()).collect()) + } + + // Casts Dictionary to Int64 if we use TypeSignatureClass::Native + let output = dictionary_input(Coercion::new_exact(TypeSignatureClass::Native( + logical_int64(), + )))?; + assert_eq!(vec![DataType::Int64], output); + + let output = dictionary_input(Coercion::new_implicit( + TypeSignatureClass::Native(logical_int64()), + vec![], + NativeType::Int64, + ))?; + assert_eq!(vec![DataType::Int64], output); + + // Dictionary gets passed through if we use TypeSignatureClass apart from Native + let output = dictionary_input(Coercion::new_exact(TypeSignatureClass::Integer))?; + assert_eq!(vec![dictionary.clone()], output); + + let output = dictionary_input(Coercion::new_implicit( + TypeSignatureClass::Integer, + vec![], + NativeType::Int64, + ))?; + assert_eq!(vec![dictionary.clone()], output); + + Ok(()) + } + + #[test] + fn test_coercible_run_end_encoded() -> Result<()> { + let run_end_encoded = DataType::RunEndEncoded( + Field::new("run_ends", DataType::Int16, false).into(), + Field::new("values", DataType::Int64, true).into(), + ); + fn run_end_encoded_input(coercion: Coercion) -> Result> { + fields_with_udf( + &[Field::new( + "field", + DataType::RunEndEncoded( + Field::new("run_ends", DataType::Int16, false).into(), + Field::new("values", 
DataType::Int64, true).into(), + ), + true, + ) + .into()], + &MockUdf(Signature::coercible(vec![coercion], Volatility::Immutable)), + ) + .map(|v| v.into_iter().map(|f| f.data_type().clone()).collect()) + } + + // Casts REE to Int64 if we use TypeSignatureClass::Native + let output = run_end_encoded_input(Coercion::new_exact( + TypeSignatureClass::Native(logical_int64()), + ))?; + assert_eq!(vec![DataType::Int64], output); + + let output = run_end_encoded_input(Coercion::new_implicit( + TypeSignatureClass::Native(logical_int64()), + vec![], + NativeType::Int64, + ))?; + assert_eq!(vec![DataType::Int64], output); + + // REE gets passed through if we use TypeSignatureClass apart from Native + let output = + run_end_encoded_input(Coercion::new_exact(TypeSignatureClass::Integer))?; + assert_eq!(vec![run_end_encoded.clone()], output); + + let output = run_end_encoded_input(Coercion::new_implicit( + TypeSignatureClass::Integer, + vec![], + NativeType::Int64, + ))?; + assert_eq!(vec![run_end_encoded.clone()], output); + + Ok(()) + } + #[test] fn test_get_valid_types_coercible_binary() -> Result<()> { let signature = Signature::coercible( diff --git a/datafusion/expr/src/type_coercion/mod.rs b/datafusion/expr/src/type_coercion/mod.rs index bd1acd3f3a2e..c92d434e34ab 100644 --- a/datafusion/expr/src/type_coercion/mod.rs +++ b/datafusion/expr/src/type_coercion/mod.rs @@ -58,11 +58,6 @@ pub fn is_signed_numeric(dt: &DataType) -> bool { ) } -/// Determine whether the given data type `dt` is `Null`. -pub fn is_null(dt: &DataType) -> bool { - *dt == DataType::Null -} - /// Determine whether the given data type `dt` is a `Timestamp`. pub fn is_timestamp(dt: &DataType) -> bool { matches!(dt, DataType::Timestamp(_, _)) @@ -80,22 +75,3 @@ pub fn is_datetime(dt: &DataType) -> bool { DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, _) ) } - -/// Determine whether the given data type `dt` is a `Utf8` or `Utf8View` or `LargeUtf8`. 
-pub fn is_utf8_or_utf8view_or_large_utf8(dt: &DataType) -> bool { - matches!( - dt, - DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8 - ) -} - -/// Determine whether the given data type `dt` is a `Decimal`. -pub fn is_decimal(dt: &DataType) -> bool { - matches!( - dt, - DataType::Decimal32(_, _) - | DataType::Decimal64(_, _) - | DataType::Decimal128(_, _) - | DataType::Decimal256(_, _) - ) -} diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index a69176e1173a..ee38077dbf30 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -668,7 +668,7 @@ pub trait AggregateUDFImpl: Debug + DynEq + DynHash + Send + Sync { /// /// Or, a closure with two arguments: /// * 'aggregate_function': [AggregateFunction] for which simplified has been invoked - /// * 'info': [crate::simplify::SimplifyInfo] + /// * 'info': [crate::simplify::SimplifyContext] /// /// closure returns simplified [Expr] or an error. /// diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 28a07ad76101..405fb256803b 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -19,7 +19,8 @@ use crate::async_udf::AsyncScalarUDF; use crate::expr::schema_name_from_exprs_comma_separated_without_space; -use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; +use crate::preimage::PreimageResult; +use crate::simplify::{ExprSimplifyResult, SimplifyContext}; use crate::sort_properties::{ExprProperties, SortProperties}; use crate::udf_eq::UdfEq; use crate::{ColumnarValue, Documentation, Expr, Signature}; @@ -30,6 +31,7 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::{ExprSchema, Result, ScalarValue, not_impl_err}; use datafusion_expr_common::dyn_eq::{DynEq, DynHash}; use datafusion_expr_common::interval_arithmetic::Interval; +use datafusion_expr_common::placement::ExpressionPlacement; use std::any::Any; use std::cmp::Ordering; use std::fmt::Debug; @@ -221,7 +223,7 @@ impl ScalarUDF { pub fn simplify( 
&self, args: Vec, - info: &dyn SimplifyInfo, + info: &SimplifyContext, ) -> Result { self.inner.simplify(args, info) } @@ -232,6 +234,18 @@ impl ScalarUDF { self.inner.is_nullable(args, schema) } + /// Return a preimage + /// + /// See [`ScalarUDFImpl::preimage`] for more details. + pub fn preimage( + &self, + args: &[Expr], + lit_expr: &Expr, + info: &SimplifyContext, + ) -> Result { + self.inner.preimage(args, lit_expr, info) + } + /// Invoke the function on `args`, returning the appropriate result. /// /// See [`ScalarUDFImpl::invoke_with_args`] for details. @@ -348,6 +362,13 @@ impl ScalarUDF { pub fn as_async(&self) -> Option<&AsyncScalarUDF> { self.inner().as_any().downcast_ref::() } + + /// Returns placement information for this function. + /// + /// See [`ScalarUDFImpl::placement`] for more details. + pub fn placement(&self, args: &[ExpressionPlacement]) -> ExpressionPlacement { + self.inner.placement(args) + } } impl From for ScalarUDF @@ -691,11 +712,116 @@ pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync { fn simplify( &self, args: Vec, - _info: &dyn SimplifyInfo, + _info: &SimplifyContext, ) -> Result { Ok(ExprSimplifyResult::Original(args)) } + /// Returns a single contiguous preimage for this function and the specified + /// scalar expression, if any. + /// + /// Currently only applies to `=, !=, >, >=, <, <=, is distinct from, is not distinct from` predicates + /// # Return Value + /// + /// Implementations should return a half-open interval: inclusive lower + /// bound and exclusive upper bound. This is slightly different from normal + /// [`Interval`] semantics where the upper bound is closed (inclusive). + /// Typically this means the upper endpoint must be adjusted to the next + /// value not included in the preimage. See the Half-Open Intervals section + /// below for more details. 
+ /// + /// # Background + /// + /// Inspired by the [ClickHouse Paper], a "preimage rewrite" transforms a + /// predicate containing a function call into a predicate containing an + /// equivalent set of input literal (constant) values. The resulting + /// predicate can often be further optimized by other rewrites (see + /// Examples). + /// + /// From the paper: + /// + /// > some functions can compute the preimage of a given function result. + /// > This is used to replace comparisons of constants with function calls + /// > on the key columns by comparing the key column value with the preimage. + /// > For example, `toYear(k) = 2024` can be replaced by + /// > `k >= 2024-01-01 && k < 2025-01-01` + /// + /// For example, given an expression like + /// ```sql + /// date_part('YEAR', k) = 2024 + /// ``` + /// + /// The interval `[2024-01-01, 2025-12-31`]` contains all possible input + /// values (preimage values) for which the function `date_part(YEAR, k)` + /// produces the output value `2024` (image value). Returning the interval + /// (note upper bound adjusted up) `[2024-01-01, 2025-01-01]` the expression + /// can be rewritten to + /// + /// ```sql + /// k >= '2024-01-01' AND k < '2025-01-01' + /// ``` + /// + /// which is a simpler and a more canonical form, making it easier for other + /// optimizer passes to recognize and apply further transformations. 
+ /// + /// # Examples + /// + /// Case 1: + /// + /// Original: + /// ```sql + /// date_part('YEAR', k) = 2024 AND k >= '2024-06-01' + /// ``` + /// + /// After preimage rewrite: + /// ```sql + /// k >= '2024-01-01' AND k < '2025-01-01' AND k >= '2024-06-01' + /// ``` + /// + /// Since this form is much simpler, the optimizer can combine and simplify + /// sub-expressions further into: + /// ```sql + /// k >= '2024-06-01' AND k < '2025-01-01' + /// ``` + /// + /// Case 2: + /// + /// For min/max pruning, simpler predicates such as: + /// ```sql + /// k >= '2024-01-01' AND k < '2025-01-01' + /// ``` + /// are much easier for the pruner to reason about. See [PruningPredicate] + /// for the backgrounds of predicate pruning. + /// + /// The trade-off with the preimage rewrite is that evaluating the rewritten + /// form might be slightly more expensive than evaluating the original + /// expression. In practice, this cost is usually outweighed by the more + /// aggressive optimization opportunities it enables. + /// + /// # Half-Open Intervals + /// + /// The preimage API uses half-open intervals, which makes the rewrite + /// easier to implement by avoiding calculations to adjust the upper bound. + /// For example, if a function returns its input unchanged and the desired + /// output is the single value `5`, a closed interval could be represented + /// as `[5, 5]`, but then the rewrite would require adjusting the upper + /// bound to `6` to create a proper range predicate. With a half-open + /// interval, the same range is represented as `[5, 6)`, which already + /// forms a valid predicate. 
+ /// + /// [PruningPredicate]: https://docs.rs/datafusion/latest/datafusion/physical_optimizer/pruning/struct.PruningPredicate.html + /// [ClickHouse Paper]: https://www.vldb.org/pvldb/vol17/p3731-schulze.pdf + /// [image]: https://en.wikipedia.org/wiki/Image_(mathematics)#Image_of_an_element + /// [preimage]: https://en.wikipedia.org/wiki/Image_(mathematics)#Inverse_image + fn preimage( + &self, + _args: &[Expr], + _lit_expr: &Expr, + _info: &SimplifyContext, + ) -> Result { + Ok(PreimageResult::None) + } + /// Returns true if some of this `exprs` subexpressions may not be evaluated /// and thus any side effects (like divide by zero) may not be encountered. /// @@ -846,6 +972,20 @@ pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync { fn documentation(&self) -> Option<&Documentation> { None } + + /// Returns placement information for this function. + /// + /// This is used by optimizers to make decisions about expression placement, + /// such as whether to push expressions down through projections. + /// + /// The default implementation returns [`ExpressionPlacement::KeepInPlace`], + /// meaning the expression should be kept where it is in the plan. + /// + /// Override this method to indicate that the function can be pushed down + /// closer to the data source. + fn placement(&self, _args: &[ExpressionPlacement]) -> ExpressionPlacement { + ExpressionPlacement::KeepInPlace + } } /// ScalarUDF that adds an alias to the underlying function. 
It is better to @@ -921,11 +1061,20 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { fn simplify( &self, args: Vec, - info: &dyn SimplifyInfo, + info: &SimplifyContext, ) -> Result { self.inner.simplify(args, info) } + fn preimage( + &self, + args: &[Expr], + lit_expr: &Expr, + info: &SimplifyContext, + ) -> Result { + self.inner.preimage(args, lit_expr, info) + } + fn conditional_arguments<'a>( &self, args: &'a [Expr], @@ -964,6 +1113,10 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { fn documentation(&self) -> Option<&Documentation> { self.inner.documentation() } + + fn placement(&self, args: &[ExpressionPlacement]) -> ExpressionPlacement { + self.inner.placement(args) + } } #[cfg(test)] diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 37055daa1ca4..8f2b8a0d9bfe 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -362,7 +362,7 @@ pub trait WindowUDFImpl: Debug + DynEq + DynHash + Send + Sync { /// /// Or, a closure with two arguments: /// * 'window_function': [crate::expr::WindowFunction] for which simplified has been invoked - /// * 'info': [crate::simplify::SimplifyInfo] + /// * 'info': [crate::simplify::SimplifyContext] /// /// # Notes /// The returned expression must have the same schema as the original diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index de4ebf5fa96e..b19299981cef 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -312,6 +312,7 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::InList { .. } | Expr::Exists { .. } | Expr::InSubquery(_) + | Expr::SetComparison(_) | Expr::ScalarSubquery(_) | Expr::Wildcard { .. 
} | Expr::Placeholder(_) @@ -937,6 +938,7 @@ pub fn find_valid_equijoin_key_pair( /// round(Float32) /// ``` #[expect(clippy::needless_pass_by_value)] +#[deprecated(since = "53.0.0", note = "Internal function")] pub fn generate_signature_error_msg( func_name: &str, func_signature: Signature, @@ -958,6 +960,26 @@ pub fn generate_signature_error_msg( ) } +/// Creates a detailed error message for a function with wrong signature. +/// +/// For example, a query like `select round(3.14, 1.1);` would yield: +/// ```text +/// Error during planning: No function matches 'round(Float64, Float64)'. You might need to add explicit type casts. +/// Candidate functions: +/// round(Float64, Int64) +/// round(Float32, Int64) +/// round(Float64) +/// round(Float32) +/// ``` +pub(crate) fn generate_signature_error_message( + func_name: &str, + func_signature: &Signature, + input_expr_types: &[DataType], +) -> String { + #[expect(deprecated)] + generate_signature_error_msg(func_name, func_signature.clone(), input_expr_types) +} + /// Splits a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` /// /// See [`split_conjunction_owned`] for more details and an example. 
@@ -1734,7 +1756,8 @@ mod tests { .expect("valid parameter names"); // Generate error message with only 1 argument provided - let error_msg = generate_signature_error_msg("substr", sig, &[DataType::Utf8]); + let error_msg = + generate_signature_error_message("substr", &sig, &[DataType::Utf8]); assert!( error_msg.contains("str: Utf8, start_pos: Int64"), @@ -1753,7 +1776,8 @@ mod tests { Volatility::Immutable, ); - let error_msg = generate_signature_error_msg("my_func", sig, &[DataType::Int32]); + let error_msg = + generate_signature_error_message("my_func", &sig, &[DataType::Int32]); assert!( error_msg.contains("Any, Any"), diff --git a/datafusion/ffi/src/catalog_provider.rs b/datafusion/ffi/src/catalog_provider.rs index 61e26f166353..ff588a89a71b 100644 --- a/datafusion/ffi/src/catalog_provider.rs +++ b/datafusion/ffi/src/catalog_provider.rs @@ -250,6 +250,11 @@ impl FFI_CatalogProvider { runtime: Option, logical_codec: FFI_LogicalExtensionCodec, ) -> Self { + if let Some(provider) = provider.as_any().downcast_ref::() + { + return provider.0.clone(); + } + let private_data = Box::new(ProviderPrivateData { provider, runtime }); Self { diff --git a/datafusion/ffi/src/config/extension_options.rs b/datafusion/ffi/src/config/extension_options.rs new file mode 100644 index 000000000000..48fd4e710921 --- /dev/null +++ b/datafusion/ffi/src/config/extension_options.rs @@ -0,0 +1,288 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::collections::HashMap; +use std::ffi::c_void; + +use abi_stable::StableAbi; +use abi_stable::std_types::{RResult, RStr, RString, RVec, Tuple2}; +use datafusion_common::config::{ConfigEntry, ConfigExtension, ExtensionOptions}; +use datafusion_common::{Result, exec_err}; + +use crate::df_result; + +/// A stable struct for sharing [`ExtensionOptions`] across FFI boundaries. +/// +/// Unlike other FFI structs in this crate, we do not construct a foreign +/// variant of this object. This is due to the typical method for interacting +/// with extension options is by creating a local struct of your concrete type. +/// To support this methodology use the `to_extension` method instead. +/// +/// When using [`FFI_ExtensionOptions`] with multiple extensions, all extension +/// values are stored on a single [`FFI_ExtensionOptions`] object. The keys +/// are stored with the full path prefix to avoid overwriting values when using +/// multiple extensions. +#[repr(C)] +#[derive(Debug, StableAbi)] +pub struct FFI_ExtensionOptions { + /// Return a deep clone of this [`ExtensionOptions`] + pub cloned: unsafe extern "C" fn(&Self) -> FFI_ExtensionOptions, + + /// Set the given `key`, `value` pair + pub set: + unsafe extern "C" fn(&mut Self, key: RStr, value: RStr) -> RResult<(), RString>, + + /// Returns the [`ConfigEntry`] stored in this [`ExtensionOptions`] + pub entries: unsafe extern "C" fn(&Self) -> RVec>, + + /// Release the memory of the private data when it is no longer being used. 
+ pub release: unsafe extern "C" fn(&mut Self), + + /// Internal data. This is only to be accessed by the provider of the options. + pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_ExtensionOptions {} +unsafe impl Sync for FFI_ExtensionOptions {} + +pub struct ExtensionOptionsPrivateData { + pub options: HashMap, +} + +impl FFI_ExtensionOptions { + #[inline] + fn inner_mut(&mut self) -> &mut HashMap { + let private_data = self.private_data as *mut ExtensionOptionsPrivateData; + unsafe { &mut (*private_data).options } + } + + #[inline] + fn inner(&self) -> &HashMap { + let private_data = self.private_data as *const ExtensionOptionsPrivateData; + unsafe { &(*private_data).options } + } +} + +unsafe extern "C" fn cloned_fn_wrapper( + options: &FFI_ExtensionOptions, +) -> FFI_ExtensionOptions { + options + .inner() + .iter() + .map(|(k, v)| (k.to_owned(), v.to_owned())) + .collect::>() + .into() +} + +unsafe extern "C" fn set_fn_wrapper( + options: &mut FFI_ExtensionOptions, + key: RStr, + value: RStr, +) -> RResult<(), RString> { + let _ = options.inner_mut().insert(key.into(), value.into()); + RResult::ROk(()) +} + +unsafe extern "C" fn entries_fn_wrapper( + options: &FFI_ExtensionOptions, +) -> RVec> { + options + .inner() + .iter() + .map(|(key, value)| (key.to_owned().into(), value.to_owned().into()).into()) + .collect() +} + +unsafe extern "C" fn release_fn_wrapper(options: &mut FFI_ExtensionOptions) { + unsafe { + debug_assert!(!options.private_data.is_null()); + let private_data = + Box::from_raw(options.private_data as *mut ExtensionOptionsPrivateData); + drop(private_data); + options.private_data = std::ptr::null_mut(); + } +} + +impl Default for FFI_ExtensionOptions { + fn default() -> Self { + HashMap::new().into() + } +} + +impl From> for FFI_ExtensionOptions { + fn from(options: HashMap) -> Self { + let private_data = ExtensionOptionsPrivateData { options }; + + Self { + cloned: cloned_fn_wrapper, + set: set_fn_wrapper, + entries: 
entries_fn_wrapper, + release: release_fn_wrapper, + private_data: Box::into_raw(Box::new(private_data)) as *mut c_void, + } + } +} + +impl Drop for FFI_ExtensionOptions { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +impl Clone for FFI_ExtensionOptions { + fn clone(&self) -> Self { + unsafe { (self.cloned)(self) } + } +} + +impl ConfigExtension for FFI_ExtensionOptions { + const PREFIX: &'static str = + datafusion_common::config::DATAFUSION_FFI_CONFIG_NAMESPACE; +} + +impl ExtensionOptions for FFI_ExtensionOptions { + fn as_any(&self) -> &dyn Any { + self + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + + fn cloned(&self) -> Box { + let ffi_options = unsafe { (self.cloned)(self) }; + Box::new(ffi_options) + } + + fn set(&mut self, key: &str, value: &str) -> Result<()> { + if key.split_once('.').is_none() { + return exec_err!("Unable to set FFI config value without namespace set"); + }; + + df_result!(unsafe { (self.set)(self, key.into(), value.into()) }) + } + + fn entries(&self) -> Vec { + unsafe { + (self.entries)(self) + .into_iter() + .map(|entry_tuple| ConfigEntry { + key: entry_tuple.0.into(), + value: Some(entry_tuple.1.into()), + description: "ffi_config_options", + }) + .collect() + } + } +} + +impl FFI_ExtensionOptions { + /// Add all of the values in a concrete configuration extension to the + /// FFI variant. This is safe to call on either side of the FFI + /// boundary. + pub fn add_config(&mut self, config: &C) -> Result<()> { + for entry in config.entries() { + if let Some(value) = entry.value { + let key = format!("{}.{}", C::PREFIX, entry.key); + self.set(key.as_str(), value.as_str())?; + } + } + + Ok(()) + } + + /// Merge another `FFI_ExtensionOptions` configurations into this one. + /// This is safe to call on either side of the FFI boundary. 
+ pub fn merge(&mut self, other: &FFI_ExtensionOptions) -> Result<()> { + for entry in other.entries() { + if let Some(value) = entry.value { + self.set(entry.key.as_str(), value.as_str())?; + } + } + Ok(()) + } + + /// Create a concrete extension type from the FFI variant. + /// This is safe to call on either side of the FFI boundary. + pub fn to_extension(&self) -> Result { + let mut result = C::default(); + + unsafe { + for entry in (self.entries)(self) { + let key = entry.0.as_str(); + let value = entry.1.as_str(); + + if let Some((prefix, inner_key)) = key.split_once('.') + && prefix == C::PREFIX + { + result.set(inner_key, value)?; + } + } + } + + Ok(result) + } +} + +#[cfg(test)] +mod tests { + use datafusion_common::config::{ConfigExtension, ConfigOptions}; + use datafusion_common::extensions_options; + + use crate::config::extension_options::FFI_ExtensionOptions; + + // Define a new configuration struct using the `extensions_options` macro + extensions_options! { + /// My own config options. + pub struct MyConfig { + /// Should "foo" be replaced by "bar"? + pub foo_to_bar: bool, default = true + + /// How many "baz" should be created? 
+ pub baz_count: usize, default = 1337 + } + } + + impl ConfigExtension for MyConfig { + const PREFIX: &'static str = "my_config"; + } + + #[test] + fn round_trip_ffi_extension_options() { + // set up config struct and register extension + let mut config = ConfigOptions::default(); + let mut ffi_options = FFI_ExtensionOptions::default(); + ffi_options.add_config(&MyConfig::default()).unwrap(); + + config.extensions.insert(ffi_options); + + // overwrite config default + config.set("my_config.baz_count", "42").unwrap(); + + // check config state + let returned_ffi_config = + config.extensions.get::().unwrap(); + let my_config: MyConfig = returned_ffi_config.to_extension().unwrap(); + + // check default value + assert!(my_config.foo_to_bar); + + // check overwritten value + assert_eq!(my_config.baz_count, 42); + } +} diff --git a/datafusion/ffi/src/config/mod.rs b/datafusion/ffi/src/config/mod.rs new file mode 100644 index 000000000000..850a4dc33733 --- /dev/null +++ b/datafusion/ffi/src/config/mod.rs @@ -0,0 +1,169 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +pub mod extension_options; + +use abi_stable::StableAbi; +use abi_stable::std_types::{RHashMap, RString}; +use datafusion_common::config::{ + ConfigExtension, ConfigOptions, ExtensionOptions, TableOptions, +}; +use datafusion_common::{DataFusionError, Result}; + +use crate::config::extension_options::FFI_ExtensionOptions; + +/// A stable struct for sharing [`ConfigOptions`] across FFI boundaries. +/// +/// Accessing FFI extension options require a slightly different pattern +/// than local extensions. The trait [`ExtensionOptionsFFIProvider`] can +/// be used to simplify accessing FFI extensions. +#[repr(C)] +#[derive(Debug, Clone, StableAbi)] +pub struct FFI_ConfigOptions { + base_options: RHashMap, + + extensions: FFI_ExtensionOptions, +} + +impl From<&ConfigOptions> for FFI_ConfigOptions { + fn from(options: &ConfigOptions) -> Self { + let base_options: RHashMap = options + .entries() + .into_iter() + .filter_map(|entry| entry.value.map(|value| (entry.key, value))) + .map(|(key, value)| (key.into(), value.into())) + .collect(); + + let mut extensions = FFI_ExtensionOptions::default(); + for (extension_name, extension) in options.extensions.iter() { + for entry in extension.entries().iter() { + if let Some(value) = entry.value.as_ref() { + extensions + .set(format!("{extension_name}.{}", entry.key).as_str(), value) + .expect("FFI_ExtensionOptions set should always return Ok"); + } + } + } + + Self { + base_options, + extensions, + } + } +} + +impl TryFrom for ConfigOptions { + type Error = DataFusionError; + fn try_from(ffi_options: FFI_ConfigOptions) -> Result { + let mut options = ConfigOptions::default(); + options.extensions.insert(ffi_options.extensions); + + for kv_tuple in ffi_options.base_options.iter() { + options.set(kv_tuple.0.as_str(), kv_tuple.1.as_str())?; + } + + Ok(options) + } +} + +pub trait ExtensionOptionsFFIProvider { + /// Extract a [`ConfigExtension`]. 
This method should attempt to first extract + /// the extension from the local options when possible. Should that fail, it + /// should attempt to extract the FFI options and then convert them to the + /// desired [`ConfigExtension`]. + fn local_or_ffi_extension(&self) -> Option; +} + +impl ExtensionOptionsFFIProvider for ConfigOptions { + fn local_or_ffi_extension(&self) -> Option { + self.extensions + .get::() + .map(|v| v.to_owned()) + .or_else(|| { + self.extensions + .get::() + .and_then(|ffi_ext| ffi_ext.to_extension().ok()) + }) + } +} + +impl ExtensionOptionsFFIProvider for TableOptions { + fn local_or_ffi_extension(&self) -> Option { + self.extensions + .get::() + .map(|v| v.to_owned()) + .or_else(|| { + self.extensions + .get::() + .and_then(|ffi_ext| ffi_ext.to_extension().ok()) + }) + } +} + +/// A stable struct for sharing [`TableOptions`] across FFI boundaries. +/// +/// Accessing FFI extension options require a slightly different pattern +/// than local extensions. The trait [`ExtensionOptionsFFIProvider`] can +/// be used to simplify accessing FFI extensions. 
+#[repr(C)] +#[derive(Debug, Clone, StableAbi)] +pub struct FFI_TableOptions { + base_options: RHashMap, + + extensions: FFI_ExtensionOptions, +} + +impl From<&TableOptions> for FFI_TableOptions { + fn from(options: &TableOptions) -> Self { + let base_options: RHashMap = options + .entries() + .into_iter() + .filter_map(|entry| entry.value.map(|value| (entry.key, value))) + .map(|(key, value)| (key.into(), value.into())) + .collect(); + + let mut extensions = FFI_ExtensionOptions::default(); + for (extension_name, extension) in options.extensions.iter() { + for entry in extension.entries().iter() { + if let Some(value) = entry.value.as_ref() { + extensions + .set(format!("{extension_name}.{}", entry.key).as_str(), value) + .expect("FFI_ExtensionOptions set should always return Ok"); + } + } + } + + Self { + base_options, + extensions, + } + } +} + +impl TryFrom for TableOptions { + type Error = DataFusionError; + fn try_from(ffi_options: FFI_TableOptions) -> Result { + let mut options = TableOptions::default(); + options.extensions.insert(ffi_options.extensions); + + for kv_tuple in ffi_options.base_options.iter() { + options.set(kv_tuple.0.as_str(), kv_tuple.1.as_str())?; + } + + Ok(options) + } +} diff --git a/datafusion/ffi/src/execution_plan.rs b/datafusion/ffi/src/execution_plan.rs index c879b022067c..524d8b4b6b97 100644 --- a/datafusion/ffi/src/execution_plan.rs +++ b/datafusion/ffi/src/execution_plan.rs @@ -90,7 +90,7 @@ impl FFI_ExecutionPlan { unsafe extern "C" fn properties_fn_wrapper( plan: &FFI_ExecutionPlan, ) -> FFI_PlanProperties { - plan.inner().properties().into() + plan.inner().properties().as_ref().into() } unsafe extern "C" fn children_fn_wrapper( @@ -192,7 +192,7 @@ impl Drop for FFI_ExecutionPlan { pub struct ForeignExecutionPlan { name: String, plan: FFI_ExecutionPlan, - properties: PlanProperties, + properties: Arc, children: Vec>, } @@ -244,7 +244,7 @@ impl TryFrom<&FFI_ExecutionPlan> for Arc { let plan = ForeignExecutionPlan { name, plan: 
plan.clone(), - properties, + properties: Arc::new(properties), children, }; @@ -262,7 +262,7 @@ impl ExecutionPlan for ForeignExecutionPlan { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.properties } @@ -278,7 +278,7 @@ impl ExecutionPlan for ForeignExecutionPlan { plan: self.plan.clone(), name: self.name.clone(), children, - properties: self.properties.clone(), + properties: Arc::clone(&self.properties), })) } @@ -305,19 +305,19 @@ pub(crate) mod tests { #[derive(Debug)] pub struct EmptyExec { - props: PlanProperties, + props: Arc, children: Vec>, } impl EmptyExec { pub fn new(schema: arrow::datatypes::SchemaRef) -> Self { Self { - props: PlanProperties::new( + props: Arc::new(PlanProperties::new( datafusion::physical_expr::EquivalenceProperties::new(schema), Partitioning::UnknownPartitioning(3), EmissionType::Incremental, Boundedness::Bounded, - ), + )), children: Vec::default(), } } @@ -342,7 +342,7 @@ pub(crate) mod tests { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.props } @@ -355,7 +355,7 @@ pub(crate) mod tests { children: Vec>, ) -> Result> { Ok(Arc::new(EmptyExec { - props: self.props.clone(), + props: Arc::clone(&self.props), children, })) } @@ -367,10 +367,6 @@ pub(crate) mod tests { ) -> Result { unimplemented!() } - - fn statistics(&self) -> Result { - unimplemented!() - } } #[test] diff --git a/datafusion/ffi/src/lib.rs b/datafusion/ffi/src/lib.rs index bf0cf9b122c1..d7410e848373 100644 --- a/datafusion/ffi/src/lib.rs +++ b/datafusion/ffi/src/lib.rs @@ -24,11 +24,11 @@ // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] -#![deny(clippy::allow_attributes)] pub mod arrow_wrappers; pub mod catalog_provider; pub mod catalog_provider_list; +pub mod config; pub mod execution; pub mod execution_plan; pub mod expr; @@ -40,6 +40,7 @@ pub mod record_batch_stream; pub mod 
schema_provider; pub mod session; pub mod table_provider; +pub mod table_provider_factory; pub mod table_source; pub mod udaf; pub mod udf; diff --git a/datafusion/ffi/src/schema_provider.rs b/datafusion/ffi/src/schema_provider.rs index b8e44b134f87..5d1348e2328f 100644 --- a/datafusion/ffi/src/schema_provider.rs +++ b/datafusion/ffi/src/schema_provider.rs @@ -259,6 +259,11 @@ impl FFI_SchemaProvider { runtime: Option, logical_codec: FFI_LogicalExtensionCodec, ) -> Self { + if let Some(provider) = provider.as_any().downcast_ref::() + { + return provider.0.clone(); + } + let owner_name = provider.owner_name().map(|s| s.into()).into(); let private_data = Box::new(ProviderPrivateData { provider, runtime }); diff --git a/datafusion/ffi/src/session/config.rs b/datafusion/ffi/src/session/config.rs index eb9c4e2c6986..63f0f20ecc7d 100644 --- a/datafusion/ffi/src/session/config.rs +++ b/datafusion/ffi/src/session/config.rs @@ -15,11 +15,11 @@ // specific language governing permissions and limitations // under the License. -use std::collections::HashMap; use std::ffi::c_void; +use crate::config::FFI_ConfigOptions; use abi_stable::StableAbi; -use abi_stable::std_types::{RHashMap, RString}; +use datafusion_common::config::ConfigOptions; use datafusion_common::error::{DataFusionError, Result}; use datafusion_execution::config::SessionConfig; @@ -37,9 +37,8 @@ use datafusion_execution::config::SessionConfig; #[repr(C)] #[derive(Debug, StableAbi)] pub struct FFI_SessionConfig { - /// Return a hash map from key to value of the config options represented - /// by string values. - pub config_options: unsafe extern "C" fn(config: &Self) -> RHashMap, + /// FFI stable configuration options. + pub config_options: FFI_ConfigOptions, /// Used to create a clone on the provider of the execution plan. This should /// only need to be called by the receiver of the plan. 
@@ -67,21 +66,6 @@ impl FFI_SessionConfig { } } -unsafe extern "C" fn config_options_fn_wrapper( - config: &FFI_SessionConfig, -) -> RHashMap { - let config_options = config.inner().options(); - - let mut options = RHashMap::default(); - for config_entry in config_options.entries() { - if let Some(value) = config_entry.value { - options.insert(config_entry.key.into(), value.into()); - } - } - - options -} - unsafe extern "C" fn release_fn_wrapper(config: &mut FFI_SessionConfig) { unsafe { debug_assert!(!config.private_data.is_null()); @@ -100,7 +84,7 @@ unsafe extern "C" fn clone_fn_wrapper(config: &FFI_SessionConfig) -> FFI_Session let private_data = Box::new(SessionConfigPrivateData { config: old_config }); FFI_SessionConfig { - config_options: config_options_fn_wrapper, + config_options: config.config_options.clone(), private_data: Box::into_raw(private_data) as *mut c_void, clone: clone_fn_wrapper, release: release_fn_wrapper, @@ -119,8 +103,10 @@ impl From<&SessionConfig> for FFI_SessionConfig { config: session.clone(), }); + let config_options = FFI_ConfigOptions::from(session.options().as_ref()); + Self { - config_options: config_options_fn_wrapper, + config_options, private_data: Box::into_raw(private_data) as *mut c_void, clone: clone_fn_wrapper, release: release_fn_wrapper, @@ -149,14 +135,9 @@ impl TryFrom<&FFI_SessionConfig> for SessionConfig { return Ok(config.inner().clone()); } - let config_options = unsafe { (config.config_options)(config) }; - - let mut options_map = HashMap::new(); - config_options.iter().for_each(|kv_pair| { - options_map.insert(kv_pair.0.to_string(), kv_pair.1.to_string()); - }); + let config_options = ConfigOptions::try_from(config.config_options.clone())?; - SessionConfig::from_string_hash_map(&options_map) + Ok(SessionConfig::from(config_options)) } } diff --git a/datafusion/ffi/src/session/mod.rs b/datafusion/ffi/src/session/mod.rs index aa910abb9149..6b8664a43749 100644 --- a/datafusion/ffi/src/session/mod.rs +++ 
b/datafusion/ffi/src/session/mod.rs @@ -26,7 +26,7 @@ use arrow_schema::SchemaRef; use arrow_schema::ffi::FFI_ArrowSchema; use async_ffi::{FfiFuture, FutureExt}; use async_trait::async_trait; -use datafusion_common::config::{ConfigOptions, TableOptions}; +use datafusion_common::config::{ConfigFileType, ConfigOptions, TableOptions}; use datafusion_common::{DFSchema, DataFusionError}; use datafusion_execution::TaskContext; use datafusion_execution::config::SessionConfig; @@ -240,12 +240,30 @@ unsafe extern "C" fn window_functions_fn_wrapper( .collect() } -fn table_options_to_rhash(options: &TableOptions) -> RHashMap { - options +fn table_options_to_rhash(mut options: TableOptions) -> RHashMap { + // It is important that we mutate options here and set current format + // to None so that when we call `entries()` we get ALL format entries. + // We will pass current_format as a special case and strip it on the + // other side of the boundary. + let current_format = options.current_format.take(); + let mut options: HashMap = options .entries() .into_iter() .filter_map(|entry| entry.value.map(|v| (entry.key.into(), v.into()))) - .collect() + .collect(); + if let Some(current_format) = current_format { + options.insert( + "datafusion_ffi.table_current_format".into(), + match current_format { + ConfigFileType::JSON => "json", + ConfigFileType::PARQUET => "parquet", + ConfigFileType::CSV => "csv", + } + .into(), + ); + } + + options.into() } unsafe extern "C" fn table_options_fn_wrapper( @@ -253,7 +271,7 @@ unsafe extern "C" fn table_options_fn_wrapper( ) -> RHashMap { let session = session.inner(); let table_options = session.table_options(); - table_options_to_rhash(table_options) + table_options_to_rhash(table_options.clone()) } unsafe extern "C" fn default_table_options_fn_wrapper( @@ -262,7 +280,7 @@ unsafe extern "C" fn default_table_options_fn_wrapper( let session = session.inner(); let table_options = session.default_table_options(); - 
table_options_to_rhash(&table_options) + table_options_to_rhash(table_options) } unsafe extern "C" fn task_ctx_fn_wrapper(session: &FFI_SessionRef) -> FFI_TaskContext { @@ -438,15 +456,70 @@ impl Clone for FFI_SessionRef { } fn table_options_from_rhashmap(options: RHashMap) -> TableOptions { - let options = options + let mut options: HashMap = options .into_iter() .map(|kv_pair| (kv_pair.0.into_string(), kv_pair.1.into_string())) .collect(); + let current_format = options.remove("datafusion_ffi.table_current_format"); + + let mut table_options = TableOptions::default(); + let formats = [ + ConfigFileType::CSV, + ConfigFileType::JSON, + ConfigFileType::PARQUET, + ]; + for format in formats { + // It is imperative that if new enum variants are added below that they be + // included in the formats list above and in the extension check below. + let format_name = match &format { + ConfigFileType::CSV => "csv", + ConfigFileType::PARQUET => "parquet", + ConfigFileType::JSON => "json", + }; + let format_options: HashMap = options + .iter() + .filter_map(|(k, v)| { + let (prefix, key) = k.split_once(".")?; + if prefix == format_name { + Some((format!("format.{key}"), v.to_owned())) + } else { + None + } + }) + .collect(); + if !format_options.is_empty() { + table_options.current_format = Some(format.clone()); + table_options + .alter_with_string_hash_map(&format_options) + .unwrap_or_else(|err| log::warn!("Error parsing table options: {err}")); + } + } + + let extension_options: HashMap = options + .iter() + .filter_map(|(k, v)| { + let (prefix, _) = k.split_once(".")?; + if !["json", "parquet", "csv"].contains(&prefix) { + Some((k.to_owned(), v.to_owned())) + } else { + None + } + }) + .collect(); + if !extension_options.is_empty() { + table_options + .alter_with_string_hash_map(&extension_options) + .unwrap_or_else(|err| log::warn!("Error parsing table options: {err}")); + } - TableOptions::from_string_hash_map(&options).unwrap_or_else(|err| { - log::warn!("Error parsing 
default table options: {err}"); - TableOptions::default() - }) + table_options.current_format = + current_format.and_then(|format| match format.as_str() { + "csv" => Some(ConfigFileType::CSV), + "parquet" => Some(ConfigFileType::PARQUET), + "json" => Some(ConfigFileType::JSON), + _ => None, + }); + table_options } #[async_trait] @@ -556,6 +629,7 @@ mod tests { use std::sync::Arc; use arrow_schema::{DataType, Field, Schema}; + use datafusion::execution::SessionStateBuilder; use datafusion_common::DataFusionError; use datafusion_expr::col; use datafusion_expr::registry::FunctionRegistry; @@ -566,7 +640,16 @@ mod tests { #[tokio::test] async fn test_ffi_session() -> Result<(), DataFusionError> { let (ctx, task_ctx_provider) = crate::util::tests::test_session_and_ctx(); - let state = ctx.state(); + let mut table_options = TableOptions::default(); + table_options.csv.has_header = Some(true); + table_options.json.schema_infer_max_rec = Some(10); + table_options.parquet.global.coerce_int96 = Some("123456789".into()); + table_options.current_format = Some(ConfigFileType::JSON); + + let state = SessionStateBuilder::new_from_existing(ctx.state()) + .with_table_options(table_options) + .build(); + let logical_codec = FFI_LogicalExtensionCodec::new( Arc::new(DefaultLogicalExtensionCodec {}), None, diff --git a/datafusion/ffi/src/table_provider.rs b/datafusion/ffi/src/table_provider.rs index df8b648026d3..4a89bb025a56 100644 --- a/datafusion/ffi/src/table_provider.rs +++ b/datafusion/ffi/src/table_provider.rs @@ -108,7 +108,7 @@ pub struct FFI_TableProvider { scan: unsafe extern "C" fn( provider: &Self, session: FFI_SessionRef, - projections: RVec, + projections: ROption>, filters_serialized: RVec, limit: ROption, ) -> FfiFuture>, @@ -232,7 +232,7 @@ unsafe extern "C" fn supports_filters_pushdown_fn_wrapper( unsafe extern "C" fn scan_fn_wrapper( provider: &FFI_TableProvider, session: FFI_SessionRef, - projections: RVec, + projections: ROption>, filters_serialized: RVec, limit: 
ROption, ) -> FfiFuture> { @@ -269,11 +269,12 @@ unsafe extern "C" fn scan_fn_wrapper( } }; - let projections: Vec<_> = projections.into_iter().collect(); + let projections: Option> = + projections.into_option().map(|p| p.into_iter().collect()); let plan = rresult_return!( internal_provider - .scan(session, Some(&projections), &filters, limit.into()) + .scan(session, projections.as_ref(), &filters, limit.into()) .await ); @@ -390,6 +391,9 @@ impl FFI_TableProvider { runtime: Option, logical_codec: FFI_LogicalExtensionCodec, ) -> Self { + if let Some(provider) = provider.as_any().downcast_ref::() { + return provider.0.clone(); + } let private_data = Box::new(ProviderPrivateData { provider, runtime }); Self { @@ -461,8 +465,9 @@ impl TableProvider for ForeignTableProvider { ) -> Result> { let session = FFI_SessionRef::new(session, None, self.0.logical_codec.clone()); - let projections: Option> = - projection.map(|p| p.iter().map(|v| v.to_owned()).collect()); + let projections: ROption> = projection + .map(|p| p.iter().map(|v| v.to_owned()).collect()) + .into(); let codec: Arc = (&self.0.logical_codec).into(); let filter_list = LogicalExprList { @@ -474,7 +479,7 @@ impl TableProvider for ForeignTableProvider { let maybe_plan = (self.0.scan)( &self.0, session, - projections.unwrap_or_default(), + projections, filters_serialized, limit.into(), ) @@ -658,8 +663,9 @@ mod tests { let provider = Arc::new(MemTable::try_new(schema, vec![vec![batch1]])?); - let ffi_provider = + let mut ffi_provider = FFI_TableProvider::new(provider, true, None, task_ctx_provider, None); + ffi_provider.library_marker_id = crate::mock_foreign_marker_id; let foreign_table_provider: Arc = (&ffi_provider).into(); @@ -712,4 +718,62 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_scan_with_none_projection_returns_all_columns() -> Result<()> { + use arrow::datatypes::Field; + use datafusion::arrow::array::Float32Array; + use datafusion::arrow::datatypes::DataType; + use 
datafusion::arrow::record_batch::RecordBatch; + use datafusion::datasource::MemTable; + use datafusion::physical_plan::collect; + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Float32, false), + Field::new("b", DataType::Float32, false), + Field::new("c", DataType::Float32, false), + ])); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Float32Array::from(vec![1.0, 2.0])), + Arc::new(Float32Array::from(vec![3.0, 4.0])), + Arc::new(Float32Array::from(vec![5.0, 6.0])), + ], + )?; + + let provider = + Arc::new(MemTable::try_new(Arc::clone(&schema), vec![vec![batch]])?); + + let ctx = Arc::new(SessionContext::new()); + let task_ctx_provider = Arc::clone(&ctx) as Arc; + let task_ctx_provider = FFI_TaskContextProvider::from(&task_ctx_provider); + + // Wrap in FFI and force the foreign path (not local bypass) + let mut ffi_provider = + FFI_TableProvider::new(provider, true, None, task_ctx_provider, None); + ffi_provider.library_marker_id = crate::mock_foreign_marker_id; + + let foreign_table_provider: Arc = (&ffi_provider).into(); + + // Call scan with projection=None, meaning "return all columns" + let plan = foreign_table_provider + .scan(&ctx.state(), None, &[], None) + .await?; + assert_eq!( + plan.schema().fields().len(), + 3, + "scan(projection=None) should return all columns; got {}", + plan.schema().fields().len() + ); + + // Also verify we can execute and get correct data + let batches = collect(plan, ctx.task_ctx()).await?; + assert_eq!(batches.len(), 1); + assert_eq!(batches[0].num_columns(), 3); + assert_eq!(batches[0].num_rows(), 2); + + Ok(()) + } } diff --git a/datafusion/ffi/src/table_provider_factory.rs b/datafusion/ffi/src/table_provider_factory.rs new file mode 100644 index 000000000000..15789eeab042 --- /dev/null +++ b/datafusion/ffi/src/table_provider_factory.rs @@ -0,0 +1,429 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ffi::c_void, sync::Arc}; + +use abi_stable::{ + StableAbi, + std_types::{RResult, RString, RVec}, +}; +use async_ffi::{FfiFuture, FutureExt}; +use async_trait::async_trait; +use datafusion_catalog::{Session, TableProvider, TableProviderFactory}; +use datafusion_common::error::{DataFusionError, Result}; +use datafusion_execution::TaskContext; +use datafusion_expr::{CreateExternalTable, DdlStatement, LogicalPlan}; +use datafusion_proto::logical_plan::{ + AsLogicalPlan, DefaultLogicalExtensionCodec, LogicalExtensionCodec, +}; +use datafusion_proto::protobuf::LogicalPlanNode; +use prost::Message; +use tokio::runtime::Handle; + +use crate::execution::FFI_TaskContextProvider; +use crate::proto::logical_extension_codec::FFI_LogicalExtensionCodec; +use crate::session::{FFI_SessionRef, ForeignSession}; +use crate::table_provider::{FFI_TableProvider, ForeignTableProvider}; +use crate::{df_result, rresult_return}; + +/// A stable struct for sharing [`TableProviderFactory`] across FFI boundaries. 
+/// +/// Similar to [`FFI_TableProvider`], this struct uses the FFI-safe pattern where: +/// - The `FFI_*` struct exposes stable function pointers +/// - Private data is stored as an opaque pointer +/// - The `Foreign*` wrapper is used by consumers on the other side of the FFI boundary +/// +/// [`FFI_TableProvider`]: crate::table_provider::FFI_TableProvider +#[repr(C)] +#[derive(Debug, StableAbi)] +pub struct FFI_TableProviderFactory { + /// Create a TableProvider with the given command. + /// + /// # Arguments + /// + /// * `factory` - the table provider factory + /// * `session_config` - session configuration + /// * `cmd_serialized` - a ['CreateExternalTable`] encoded as a [`LogicalPlanNode`] protobuf message serialized into bytes + /// to pass across the FFI boundary. + create: unsafe extern "C" fn( + factory: &Self, + session: FFI_SessionRef, + cmd_serialized: RVec, + ) -> FfiFuture>, + + logical_codec: FFI_LogicalExtensionCodec, + + /// Used to create a clone of the factory. This should only need to be called + /// by the receiver of the factory. + clone: unsafe extern "C" fn(factory: &Self) -> Self, + + /// Release the memory of the private data when it is no longer being used. + release: unsafe extern "C" fn(factory: &mut Self), + + /// Return the major DataFusion version number of this factory. + version: unsafe extern "C" fn() -> u64, + + /// Internal data. This is only to be accessed by the provider of the factory. + /// A [`ForeignTableProviderFactory`] should never attempt to access this data. + private_data: *mut c_void, + + /// Utility to identify when FFI objects are accessed locally through + /// the foreign interface. See [`crate::get_library_marker_id`] and + /// the crate's `README.md` for more information. 
+ library_marker_id: extern "C" fn() -> usize, +} + +unsafe impl Send for FFI_TableProviderFactory {} +unsafe impl Sync for FFI_TableProviderFactory {} + +struct FactoryPrivateData { + factory: Arc, + runtime: Option, +} + +impl FFI_TableProviderFactory { + /// Creates a new [`FFI_TableProvider`]. + pub fn new( + factory: Arc, + runtime: Option, + task_ctx_provider: impl Into, + logical_codec: Option>, + ) -> Self { + let task_ctx_provider = task_ctx_provider.into(); + let logical_codec = + logical_codec.unwrap_or_else(|| Arc::new(DefaultLogicalExtensionCodec {})); + let logical_codec = FFI_LogicalExtensionCodec::new( + logical_codec, + runtime.clone(), + task_ctx_provider.clone(), + ); + Self::new_with_ffi_codec(factory, runtime, logical_codec) + } + + pub fn new_with_ffi_codec( + factory: Arc, + runtime: Option, + logical_codec: FFI_LogicalExtensionCodec, + ) -> Self { + let private_data = Box::new(FactoryPrivateData { factory, runtime }); + + Self { + create: create_fn_wrapper, + logical_codec, + clone: clone_fn_wrapper, + release: release_fn_wrapper, + version: super::version, + private_data: Box::into_raw(private_data) as *mut c_void, + library_marker_id: crate::get_library_marker_id, + } + } + + fn inner(&self) -> &Arc { + let private_data = self.private_data as *const FactoryPrivateData; + unsafe { &(*private_data).factory } + } + + fn runtime(&self) -> &Option { + let private_data = self.private_data as *const FactoryPrivateData; + unsafe { &(*private_data).runtime } + } + + fn deserialize_cmd( + &self, + cmd_serialized: &RVec, + ) -> Result { + let task_ctx: Arc = + (&self.logical_codec.task_ctx_provider).try_into()?; + let logical_codec: Arc = (&self.logical_codec).into(); + + let plan = LogicalPlanNode::decode(cmd_serialized.as_ref()) + .map_err(|e| DataFusionError::Internal(format!("{e:?}")))?; + match plan.try_into_logical_plan(&task_ctx, logical_codec.as_ref())? 
{ + LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) => Ok(cmd), + _ => Err(DataFusionError::Internal( + "Invalid logical plan in FFI_TableProviderFactory.".to_owned(), + )), + } + } +} + +impl Clone for FFI_TableProviderFactory { + fn clone(&self) -> Self { + unsafe { (self.clone)(self) } + } +} + +impl Drop for FFI_TableProviderFactory { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +impl From<&FFI_TableProviderFactory> for Arc { + fn from(factory: &FFI_TableProviderFactory) -> Self { + if (factory.library_marker_id)() == crate::get_library_marker_id() { + Arc::clone(factory.inner()) as Arc + } else { + Arc::new(ForeignTableProviderFactory(factory.clone())) + } + } +} + +unsafe extern "C" fn create_fn_wrapper( + factory: &FFI_TableProviderFactory, + session: FFI_SessionRef, + cmd_serialized: RVec, +) -> FfiFuture> { + let factory = factory.clone(); + + async move { + let provider = rresult_return!( + create_fn_wrapper_impl(factory, session, cmd_serialized).await + ); + RResult::ROk(provider) + } + .into_ffi() +} + +async fn create_fn_wrapper_impl( + factory: FFI_TableProviderFactory, + session: FFI_SessionRef, + cmd_serialized: RVec, +) -> Result { + let runtime = factory.runtime().clone(); + let ffi_logical_codec = factory.logical_codec.clone(); + let internal_factory = Arc::clone(factory.inner()); + let cmd = factory.deserialize_cmd(&cmd_serialized)?; + + let mut foreign_session = None; + let session = session + .as_local() + .map(Ok::<&(dyn Session + Send + Sync), DataFusionError>) + .unwrap_or_else(|| { + foreign_session = Some(ForeignSession::try_from(&session)?); + Ok(foreign_session.as_ref().unwrap()) + })?; + + let provider = internal_factory.create(session, &cmd).await?; + Ok(FFI_TableProvider::new_with_ffi_codec( + provider, + true, + runtime.clone(), + ffi_logical_codec, + )) +} + +unsafe extern "C" fn clone_fn_wrapper( + factory: &FFI_TableProviderFactory, +) -> FFI_TableProviderFactory { + let runtime = 
factory.runtime().clone(); + let old_factory = Arc::clone(factory.inner()); + + let private_data = Box::into_raw(Box::new(FactoryPrivateData { + factory: old_factory, + runtime, + })) as *mut c_void; + + FFI_TableProviderFactory { + create: create_fn_wrapper, + logical_codec: factory.logical_codec.clone(), + clone: clone_fn_wrapper, + release: release_fn_wrapper, + version: super::version, + private_data, + library_marker_id: crate::get_library_marker_id, + } +} + +unsafe extern "C" fn release_fn_wrapper(factory: &mut FFI_TableProviderFactory) { + unsafe { + debug_assert!(!factory.private_data.is_null()); + let private_data = Box::from_raw(factory.private_data as *mut FactoryPrivateData); + drop(private_data); + factory.private_data = std::ptr::null_mut(); + } +} + +/// This wrapper struct exists on the receiver side of the FFI interface, so it has +/// no guarantees about being able to access the data in `private_data`. Any functions +/// defined on this struct must only use the stable functions provided in +/// FFI_TableProviderFactory to interact with the foreign table provider factory. 
+#[derive(Debug)] +pub struct ForeignTableProviderFactory(pub FFI_TableProviderFactory); + +impl ForeignTableProviderFactory { + fn serialize_cmd( + &self, + cmd: CreateExternalTable, + ) -> Result, DataFusionError> { + let logical_codec: Arc = + (&self.0.logical_codec).into(); + + let plan = LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)); + let plan: LogicalPlanNode = + AsLogicalPlan::try_from_logical_plan(&plan, logical_codec.as_ref())?; + + let mut buf: Vec = Vec::new(); + plan.try_encode(&mut buf)?; + + Ok(buf.into()) + } +} + +unsafe impl Send for ForeignTableProviderFactory {} +unsafe impl Sync for ForeignTableProviderFactory {} + +#[async_trait] +impl TableProviderFactory for ForeignTableProviderFactory { + async fn create( + &self, + session: &dyn Session, + cmd: &CreateExternalTable, + ) -> Result> { + let session = FFI_SessionRef::new(session, None, self.0.logical_codec.clone()); + let cmd = self.serialize_cmd(cmd.clone())?; + + let provider = unsafe { + let maybe_provider = (self.0.create)(&self.0, session, cmd).await; + + let ffi_provider = df_result!(maybe_provider)?; + ForeignTableProvider(ffi_provider) + }; + + Ok(Arc::new(provider)) + } +} + +#[cfg(test)] +mod tests { + use arrow::datatypes::Schema; + use datafusion::prelude::SessionContext; + use datafusion_common::{TableReference, ToDFSchema}; + use datafusion_execution::TaskContextProvider; + use std::collections::HashMap; + + use super::*; + + #[derive(Debug)] + struct TestTableProviderFactory {} + + #[async_trait] + impl TableProviderFactory for TestTableProviderFactory { + async fn create( + &self, + _session: &dyn Session, + _cmd: &CreateExternalTable, + ) -> Result> { + use arrow::datatypes::Field; + use datafusion::arrow::array::Float32Array; + use datafusion::arrow::datatypes::DataType; + use datafusion::arrow::record_batch::RecordBatch; + use datafusion::datasource::MemTable; + + let schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)])); + + let 
batch1 = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0]))], + )?; + let batch2 = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Float32Array::from(vec![64.0]))], + )?; + + Ok(Arc::new(MemTable::try_new( + schema, + vec![vec![batch1], vec![batch2]], + )?)) + } + } + + #[tokio::test] + async fn test_round_trip_ffi_table_provider_factory() -> Result<()> { + let ctx = Arc::new(SessionContext::new()); + let task_ctx_provider = Arc::clone(&ctx) as Arc; + let task_ctx_provider = FFI_TaskContextProvider::from(&task_ctx_provider); + + let factory = Arc::new(TestTableProviderFactory {}); + let mut ffi_factory = + FFI_TableProviderFactory::new(factory, None, task_ctx_provider, None); + ffi_factory.library_marker_id = crate::mock_foreign_marker_id; + + let factory: Arc = (&ffi_factory).into(); + + let cmd = CreateExternalTable { + schema: Schema::empty().to_dfschema_ref()?, + name: TableReference::bare("test_table"), + location: "test".to_string(), + file_type: "test".to_string(), + table_partition_cols: vec![], + if_not_exists: false, + or_replace: false, + temporary: false, + definition: None, + order_exprs: vec![], + unbounded: false, + options: HashMap::new(), + constraints: Default::default(), + column_defaults: HashMap::new(), + }; + + let provider = factory.create(&ctx.state(), &cmd).await?; + + assert_eq!(provider.schema().fields().len(), 1); + + Ok(()) + } + + #[tokio::test] + async fn test_ffi_table_provider_factory_clone() -> Result<()> { + let ctx = Arc::new(SessionContext::new()); + let task_ctx_provider = Arc::clone(&ctx) as Arc; + let task_ctx_provider = FFI_TaskContextProvider::from(&task_ctx_provider); + + let factory = Arc::new(TestTableProviderFactory {}); + let ffi_factory = + FFI_TableProviderFactory::new(factory, None, task_ctx_provider, None); + + // Test that we can clone the factory + let cloned_factory = ffi_factory.clone(); + let factory: Arc = (&cloned_factory).into(); + + let cmd = 
CreateExternalTable { + schema: Schema::empty().to_dfschema_ref()?, + name: TableReference::bare("cloned_test"), + location: "test".to_string(), + file_type: "test".to_string(), + table_partition_cols: vec![], + if_not_exists: false, + or_replace: false, + temporary: false, + definition: None, + order_exprs: vec![], + unbounded: false, + options: HashMap::new(), + constraints: Default::default(), + column_defaults: HashMap::new(), + }; + + let provider = factory.create(&ctx.state(), &cmd).await?; + assert_eq!(provider.schema().fields().len(), 1); + + Ok(()) + } +} diff --git a/datafusion/ffi/src/tests/async_provider.rs b/datafusion/ffi/src/tests/async_provider.rs index 6149736c5855..8370cf19e658 100644 --- a/datafusion/ffi/src/tests/async_provider.rs +++ b/datafusion/ffi/src/tests/async_provider.rs @@ -162,7 +162,7 @@ impl Drop for AsyncTableProvider { #[derive(Debug)] struct AsyncTestExecutionPlan { - properties: datafusion_physical_plan::PlanProperties, + properties: Arc, batch_request: mpsc::Sender, batch_receiver: broadcast::Receiver>, } @@ -173,12 +173,12 @@ impl AsyncTestExecutionPlan { batch_receiver: broadcast::Receiver>, ) -> Self { Self { - properties: datafusion_physical_plan::PlanProperties::new( + properties: Arc::new(datafusion_physical_plan::PlanProperties::new( EquivalenceProperties::new(super::create_test_schema()), Partitioning::UnknownPartitioning(3), datafusion_physical_plan::execution_plan::EmissionType::Incremental, datafusion_physical_plan::execution_plan::Boundedness::Bounded, - ), + )), batch_request, batch_receiver, } @@ -194,7 +194,7 @@ impl ExecutionPlan for AsyncTestExecutionPlan { self } - fn properties(&self) -> &datafusion_physical_plan::PlanProperties { + fn properties(&self) -> &Arc { &self.properties } diff --git a/datafusion/ffi/src/tests/config.rs b/datafusion/ffi/src/tests/config.rs new file mode 100644 index 000000000000..46fc9756203e --- /dev/null +++ b/datafusion/ffi/src/tests/config.rs @@ -0,0 +1,51 @@ +// Licensed to the 
Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::config::ConfigExtension; +use datafusion_common::extensions_options; + +use crate::config::extension_options::FFI_ExtensionOptions; + +extensions_options! { + pub struct ExternalConfig { + /// Should "foo" be replaced by "bar"? 
+ pub is_enabled: bool, default = true + + /// Some value to be extracted + pub base_number: usize, default = 1000 + } +} + +impl PartialEq for ExternalConfig { + fn eq(&self, other: &Self) -> bool { + self.base_number == other.base_number && self.is_enabled == other.is_enabled + } +} +impl Eq for ExternalConfig {} + +impl ConfigExtension for ExternalConfig { + const PREFIX: &'static str = "external_config"; +} + +pub(crate) extern "C" fn create_extension_options() -> FFI_ExtensionOptions { + let mut extensions = FFI_ExtensionOptions::default(); + extensions + .add_config(&ExternalConfig::default()) + .expect("add_config should be infallible for ExternalConfig"); + + extensions +} diff --git a/datafusion/ffi/src/tests/mod.rs b/datafusion/ffi/src/tests/mod.rs index 9bcd7e003108..cbee5febdb35 100644 --- a/datafusion/ffi/src/tests/mod.rs +++ b/datafusion/ffi/src/tests/mod.rs @@ -34,19 +34,23 @@ use udf_udaf_udwf::{ create_ffi_stddev_func, create_ffi_sum_func, create_ffi_table_func, }; -use super::table_provider::FFI_TableProvider; -use super::udf::FFI_ScalarUDF; use crate::catalog_provider::FFI_CatalogProvider; use crate::catalog_provider_list::FFI_CatalogProviderList; +use crate::config::extension_options::FFI_ExtensionOptions; use crate::proto::logical_extension_codec::FFI_LogicalExtensionCodec; +use crate::table_provider::FFI_TableProvider; +use crate::table_provider_factory::FFI_TableProviderFactory; use crate::tests::catalog::create_catalog_provider_list; use crate::udaf::FFI_AggregateUDF; +use crate::udf::FFI_ScalarUDF; use crate::udtf::FFI_TableFunction; use crate::udwf::FFI_WindowUDF; mod async_provider; pub mod catalog; +pub mod config; mod sync_provider; +mod table_provider_factory; mod udf_udaf_udwf; pub mod utils; @@ -71,6 +75,10 @@ pub struct ForeignLibraryModule { codec: FFI_LogicalExtensionCodec, ) -> FFI_TableProvider, + /// Constructs the table provider factory + pub create_table_factory: + extern "C" fn(codec: FFI_LogicalExtensionCodec) -> 
FFI_TableProviderFactory, + /// Create a scalar UDF pub create_scalar_udf: extern "C" fn() -> FFI_ScalarUDF, @@ -87,6 +95,9 @@ pub struct ForeignLibraryModule { pub create_rank_udwf: extern "C" fn() -> FFI_WindowUDF, + /// Create extension options, for either ConfigOptions or TableOptions + pub create_extension_options: extern "C" fn() -> FFI_ExtensionOptions, + pub version: extern "C" fn() -> u64, } @@ -128,6 +139,14 @@ extern "C" fn construct_table_provider( } } +/// Here we only wish to create a simple table provider as an example. +/// We create an in-memory table and convert it to it's FFI counterpart. +extern "C" fn construct_table_provider_factory( + codec: FFI_LogicalExtensionCodec, +) -> FFI_TableProviderFactory { + table_provider_factory::create(codec) +} + #[export_root_module] /// This defines the entry point for using the module. pub fn get_foreign_library_module() -> ForeignLibraryModuleRef { @@ -135,12 +154,14 @@ pub fn get_foreign_library_module() -> ForeignLibraryModuleRef { create_catalog: create_catalog_provider, create_catalog_list: create_catalog_provider_list, create_table: construct_table_provider, + create_table_factory: construct_table_provider_factory, create_scalar_udf: create_ffi_abs_func, create_nullary_udf: create_ffi_random_func, create_table_function: create_ffi_table_func, create_sum_udaf: create_ffi_sum_func, create_stddev_udaf: create_ffi_stddev_func, create_rank_udwf: create_ffi_rank_func, + create_extension_options: config::create_extension_options, version: super::version, } .leak_into_prefix() diff --git a/datafusion/ffi/src/tests/table_provider_factory.rs b/datafusion/ffi/src/tests/table_provider_factory.rs new file mode 100644 index 000000000000..29af6aacf648 --- /dev/null +++ b/datafusion/ffi/src/tests/table_provider_factory.rs @@ -0,0 +1,58 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use async_trait::async_trait; +use datafusion_catalog::{MemTable, Session, TableProvider, TableProviderFactory}; +use datafusion_common::Result; +use datafusion_expr::CreateExternalTable; + +use super::{create_record_batch, create_test_schema}; +use crate::proto::logical_extension_codec::FFI_LogicalExtensionCodec; +use crate::table_provider_factory::FFI_TableProviderFactory; + +#[derive(Debug)] +pub struct TestTableProviderFactory {} + +#[async_trait] +impl TableProviderFactory for TestTableProviderFactory { + async fn create( + &self, + _session: &dyn Session, + _cmd: &CreateExternalTable, + ) -> Result> { + let schema = create_test_schema(); + + // It is useful to create these as multiple record batches + // so that we can demonstrate the FFI stream. 
+ let batches = vec![ + create_record_batch(1, 5), + create_record_batch(6, 1), + create_record_batch(7, 5), + ]; + + let table_provider = MemTable::try_new(schema, vec![batches]).unwrap(); + + Ok(Arc::new(table_provider)) + } +} + +pub(crate) fn create(codec: FFI_LogicalExtensionCodec) -> FFI_TableProviderFactory { + let factory = TestTableProviderFactory {}; + FFI_TableProviderFactory::new_with_ffi_codec(Arc::new(factory), None, codec) +} diff --git a/datafusion/ffi/tests/ffi_config.rs b/datafusion/ffi/tests/ffi_config.rs new file mode 100644 index 000000000000..ca0a3e31e8de --- /dev/null +++ b/datafusion/ffi/tests/ffi_config.rs @@ -0,0 +1,124 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +/// Add an additional module here for convenience to scope this to only +/// when the feature integration-tests is built +#[cfg(feature = "integration-tests")] +mod tests { + use datafusion::error::{DataFusionError, Result}; + use datafusion_common::ScalarValue; + use datafusion_common::config::{ConfigOptions, TableOptions}; + use datafusion_execution::config::SessionConfig; + use datafusion_ffi::config::ExtensionOptionsFFIProvider; + use datafusion_ffi::tests::config::ExternalConfig; + use datafusion_ffi::tests::utils::get_module; + + #[test] + fn test_ffi_config_options_extension() -> Result<()> { + let module = get_module()?; + + let extension_options = + module + .create_extension_options() + .ok_or(DataFusionError::NotImplemented( + "External test library failed to implement create_extension_options" + .to_string(), + ))?(); + + let mut config = ConfigOptions::new(); + config.extensions.insert(extension_options); + + // Verify default values are as expected + let returned_config: ExternalConfig = config + .local_or_ffi_extension() + .expect("should have external config extension"); + assert_eq!(returned_config, ExternalConfig::default()); + + config.set("external_config.is_enabled", "false")?; + let returned_config: ExternalConfig = config + .local_or_ffi_extension() + .expect("should have external config extension"); + assert!(!returned_config.is_enabled); + + Ok(()) + } + + #[test] + fn test_ffi_table_options_extension() -> Result<()> { + let module = get_module()?; + + let extension_options = + module + .create_extension_options() + .ok_or(DataFusionError::NotImplemented( + "External test library failed to implement create_extension_options" + .to_string(), + ))?(); + + let mut table_options = TableOptions::new(); + table_options.extensions.insert(extension_options); + + // Verify default values are as expected + let returned_options: ExternalConfig = table_options + .local_or_ffi_extension() + .expect("should have external config extension"); + + 
assert_eq!(returned_options, ExternalConfig::default()); + + table_options.set("external_config.is_enabled", "false")?; + let returned_options: ExternalConfig = table_options + .local_or_ffi_extension() + .expect("should have external config extension"); + assert!(!returned_options.is_enabled); + + Ok(()) + } + + #[test] + fn test_ffi_session_config_options_extension() -> Result<()> { + let module = get_module()?; + + let extension_options = + module + .create_extension_options() + .ok_or(DataFusionError::NotImplemented( + "External test library failed to implement create_extension_options" + .to_string(), + ))?(); + + let mut config = SessionConfig::new().with_option_extension(extension_options); + + // Verify default values are as expected + let returned_config: ExternalConfig = config + .options() + .local_or_ffi_extension() + .expect("should have external config extension"); + assert_eq!(returned_config, ExternalConfig::default()); + + config = config.set( + "external_config.is_enabled", + &ScalarValue::Boolean(Some(false)), + ); + let returned_config: ExternalConfig = config + .options() + .local_or_ffi_extension() + .expect("should have external config extension"); + assert!(!returned_config.is_enabled); + + Ok(()) + } +} diff --git a/datafusion/ffi/tests/ffi_integration.rs b/datafusion/ffi/tests/ffi_integration.rs index 2d18679cb018..80538d4f92fb 100644 --- a/datafusion/ffi/tests/ffi_integration.rs +++ b/datafusion/ffi/tests/ffi_integration.rs @@ -21,10 +21,15 @@ mod utils; /// when the feature integration-tests is built #[cfg(feature = "integration-tests")] mod tests { + use std::collections::HashMap; use std::sync::Arc; - use datafusion::catalog::TableProvider; + use arrow::datatypes::Schema; + use datafusion::catalog::{TableProvider, TableProviderFactory}; use datafusion::error::{DataFusionError, Result}; + use datafusion_common::TableReference; + use datafusion_common::ToDFSchema; + use datafusion_expr::CreateExternalTable; use 
datafusion_ffi::tests::create_record_batch; use datafusion_ffi::tests::utils::get_module; @@ -69,4 +74,43 @@ mod tests { async fn sync_test_table_provider() -> Result<()> { test_table_provider(true).await } + + #[tokio::test] + async fn test_table_provider_factory() -> Result<()> { + let table_provider_module = get_module()?; + let (ctx, codec) = super::utils::ctx_and_codec(); + + let ffi_table_provider_factory = table_provider_module + .create_table_factory() + .ok_or(DataFusionError::NotImplemented( + "External table provider factory failed to implement create".to_string(), + ))?(codec); + + let foreign_table_provider_factory: Arc = + (&ffi_table_provider_factory).into(); + + let cmd = CreateExternalTable { + schema: Schema::empty().to_dfschema_ref()?, + name: TableReference::bare("cloned_test"), + location: "test".to_string(), + file_type: "test".to_string(), + table_partition_cols: vec![], + if_not_exists: false, + or_replace: false, + temporary: false, + definition: None, + order_exprs: vec![], + unbounded: false, + options: HashMap::new(), + constraints: Default::default(), + column_defaults: HashMap::new(), + }; + + let provider = foreign_table_provider_factory + .create(&ctx.state(), &cmd) + .await?; + assert_eq!(provider.schema().fields().len(), 2); + + Ok(()) + } } diff --git a/datafusion/functions-aggregate-common/benches/accumulate.rs b/datafusion/functions-aggregate-common/benches/accumulate.rs index f1e4fe23cbb1..aceec57df966 100644 --- a/datafusion/functions-aggregate-common/benches/accumulate.rs +++ b/datafusion/functions-aggregate-common/benches/accumulate.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-extern crate criterion; - use std::sync::Arc; use arrow::array::{ArrayRef, BooleanArray, Int64Array}; diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs index 29b8752048c3..25f52df61136 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs @@ -20,10 +20,70 @@ //! [`GroupsAccumulator`]: datafusion_expr_common::groups_accumulator::GroupsAccumulator use arrow::array::{Array, BooleanArray, BooleanBufferBuilder, PrimitiveArray}; -use arrow::buffer::{BooleanBuffer, NullBuffer}; +use arrow::buffer::NullBuffer; use arrow::datatypes::ArrowPrimitiveType; use datafusion_expr_common::groups_accumulator::EmitTo; + +/// If the input has nulls, then the accumulator must potentially +/// handle each input null value specially (e.g. for `SUM` to mark the +/// corresponding sum as null) +/// +/// If there are filters present, `NullState` tracks if it has seen +/// *any* value for that group (as some values may be filtered +/// out). Without a filter, the accumulator is only passed groups that +/// had at least one value to accumulate so they do not need to track +/// if they have seen values for a particular group. +#[derive(Debug)] +pub enum SeenValues { + /// All groups seen so far have seen at least one non-null value + All { + num_values: usize, + }, + // Some groups have not yet seen a non-null value + Some { + values: BooleanBufferBuilder, + }, +} + +impl Default for SeenValues { + fn default() -> Self { + SeenValues::All { num_values: 0 } + } +} + +impl SeenValues { + /// Return a mutable reference to the `BooleanBufferBuilder` in `SeenValues::Some`. 
+ /// + /// If `self` is `SeenValues::All`, it is transitioned to `SeenValues::Some` + /// by creating a new `BooleanBufferBuilder` where the first `num_values` are true. + /// + /// The builder is then ensured to have at least `total_num_groups` length, + /// with any new entries initialized to false. + fn get_builder(&mut self, total_num_groups: usize) -> &mut BooleanBufferBuilder { + match self { + SeenValues::All { num_values } => { + let mut builder = BooleanBufferBuilder::new(total_num_groups); + builder.append_n(*num_values, true); + if total_num_groups > *num_values { + builder.append_n(total_num_groups - *num_values, false); + } + *self = SeenValues::Some { values: builder }; + match self { + SeenValues::Some { values } => values, + _ => unreachable!(), + } + } + SeenValues::Some { values } => { + if values.len() < total_num_groups { + values.append_n(total_num_groups - values.len(), false); + } + values + } + } + } +} + /// Track the accumulator null state per row: if any values for that /// group were null and if any values have been seen at all for that group. /// @@ -53,12 +113,14 @@ use datafusion_expr_common::groups_accumulator::EmitTo; pub struct NullState { /// Have we seen any non-filtered input values for `group_index`? 
/// - /// If `seen_values[i]` is true, have seen at least one non null + /// If `seen_values` is `SeenValues::Some(buffer)` and buffer\[i\] is true, have seen at least one non null /// value for group `i` /// - /// If `seen_values[i]` is false, have not seen any values that + /// If `seen_values` is `SeenValues::Some(buffer)` and buffer\[i\] is false, have not seen any values that /// pass the filter yet for group `i` - seen_values: BooleanBufferBuilder, + /// + /// If `seen_values` is `SeenValues::All`, all groups have seen at least one non null value + seen_values: SeenValues, } impl Default for NullState { @@ -70,14 +132,16 @@ impl Default for NullState { impl NullState { pub fn new() -> Self { Self { - seen_values: BooleanBufferBuilder::new(0), + seen_values: SeenValues::All { num_values: 0 }, } } /// return the size of all buffers allocated by this null state, not including self pub fn size(&self) -> usize { - // capacity is in bits, so convert to bytes - self.seen_values.capacity() / 8 + match &self.seen_values { + SeenValues::All { .. 
} => 0, + SeenValues::Some { values } => values.capacity() / 8, + } } /// Invokes `value_fn(group_index, value)` for each non null, non @@ -107,10 +171,17 @@ impl NullState { T: ArrowPrimitiveType + Send, F: FnMut(usize, T::Native) + Send, { - // ensure the seen_values is big enough (start everything at - // "not seen" valid) - let seen_values = - initialize_builder(&mut self.seen_values, total_num_groups, false); + // skip null handling if no nulls in input or accumulator + if let SeenValues::All { num_values } = &mut self.seen_values + && opt_filter.is_none() + && values.null_count() == 0 + { + accumulate(group_indices, values, None, value_fn); + *num_values = total_num_groups; + return; + } + + let seen_values = self.seen_values.get_builder(total_num_groups); accumulate(group_indices, values, opt_filter, |group_index, value| { seen_values.set_bit(group_index, true); value_fn(group_index, value); @@ -140,10 +211,21 @@ impl NullState { let data = values.values(); assert_eq!(data.len(), group_indices.len()); - // ensure the seen_values is big enough (start everything at - // "not seen" valid) - let seen_values = - initialize_builder(&mut self.seen_values, total_num_groups, false); + // skip null handling if no nulls in input or accumulator + if let SeenValues::All { num_values } = &mut self.seen_values + && opt_filter.is_none() + && values.null_count() == 0 + { + group_indices + .iter() + .zip(data.iter()) + .for_each(|(&group_index, new_value)| value_fn(group_index, new_value)); + *num_values = total_num_groups; + + return; + } + + let seen_values = self.seen_values.get_builder(total_num_groups); // These could be made more performant by iterating in chunks of 64 bits at a time match (values.null_count() > 0, opt_filter) { @@ -211,21 +293,39 @@ impl NullState { /// for the `emit_to` rows. 
/// /// resets the internal state appropriately - pub fn build(&mut self, emit_to: EmitTo) -> NullBuffer { - let nulls: BooleanBuffer = self.seen_values.finish(); - - let nulls = match emit_to { - EmitTo::All => nulls, - EmitTo::First(n) => { - // split off the first N values in seen_values - let first_n_null: BooleanBuffer = nulls.slice(0, n); - // reset the existing seen buffer - self.seen_values - .append_buffer(&nulls.slice(n, nulls.len() - n)); - first_n_null + pub fn build(&mut self, emit_to: EmitTo) -> Option { + match emit_to { + EmitTo::All => { + let old_seen = std::mem::take(&mut self.seen_values); + match old_seen { + SeenValues::All { .. } => None, + SeenValues::Some { mut values } => { + Some(NullBuffer::new(values.finish())) + } + } } - }; - NullBuffer::new(nulls) + EmitTo::First(n) => match &mut self.seen_values { + SeenValues::All { num_values } => { + *num_values = num_values.saturating_sub(n); + None + } + SeenValues::Some { .. } => { + let mut old_values = match std::mem::take(&mut self.seen_values) { + SeenValues::Some { values } => values, + _ => unreachable!(), + }; + let nulls = old_values.finish(); + let first_n_null = nulls.slice(0, n); + let remainder = nulls.slice(n, nulls.len() - n); + let mut new_builder = BooleanBufferBuilder::new(remainder.len()); + new_builder.append_buffer(&remainder); + self.seen_values = SeenValues::Some { + values: new_builder, + }; + Some(NullBuffer::new(first_n_null)) + } + }, + } } } @@ -573,27 +673,14 @@ pub fn accumulate_indices( } } -/// Ensures that `builder` contains a `BooleanBufferBuilder with at -/// least `total_num_groups`. 
-/// -/// All new entries are initialized to `default_value` -fn initialize_builder( - builder: &mut BooleanBufferBuilder, - total_num_groups: usize, - default_value: bool, -) -> &mut BooleanBufferBuilder { - if builder.len() < total_num_groups { - let new_groups = total_num_groups - builder.len(); - builder.append_n(new_groups, default_value); - } - builder -} - #[cfg(test)] mod test { use super::*; - use arrow::array::{Int32Array, UInt32Array}; + use arrow::{ + array::{Int32Array, UInt32Array}, + buffer::BooleanBuffer, + }; use rand::{Rng, rngs::ThreadRng}; use std::collections::HashSet; @@ -834,15 +921,24 @@ mod test { accumulated_values, expected_values, "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}" ); - let seen_values = null_state.seen_values.finish_cloned(); - mock.validate_seen_values(&seen_values); + + match &null_state.seen_values { + SeenValues::All { num_values } => { + assert_eq!(*num_values, total_num_groups); + } + SeenValues::Some { values } => { + let seen_values = values.finish_cloned(); + mock.validate_seen_values(&seen_values); + } + } // Validate the final buffer (one value per group) let expected_null_buffer = mock.expected_null_buffer(total_num_groups); let null_buffer = null_state.build(EmitTo::All); - - assert_eq!(null_buffer, expected_null_buffer); + if let Some(nulls) = &null_buffer { + assert_eq!(*nulls, expected_null_buffer); + } } // Calls `accumulate_indices` @@ -955,15 +1051,25 @@ mod test { "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}" ); - let seen_values = null_state.seen_values.finish_cloned(); - mock.validate_seen_values(&seen_values); + match &null_state.seen_values { + SeenValues::All { num_values } => { + assert_eq!(*num_values, total_num_groups); + } + SeenValues::Some { values } => { + let seen_values = values.finish_cloned(); + mock.validate_seen_values(&seen_values); + } + } // Validate the final buffer (one value per group) - let 
expected_null_buffer = mock.expected_null_buffer(total_num_groups); + let expected_null_buffer = Some(mock.expected_null_buffer(total_num_groups)); + let is_all_seen = matches!(null_state.seen_values, SeenValues::All { .. }); let null_buffer = null_state.build(EmitTo::All); - assert_eq!(null_buffer, expected_null_buffer); + if !is_all_seen { + assert_eq!(null_buffer, expected_null_buffer); + } } } diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/bool_op.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/bool_op.rs index 149312e5a9c0..f716b48f0ccc 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/bool_op.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/bool_op.rs @@ -120,7 +120,7 @@ where }; let nulls = self.null_state.build(emit_to); - let values = BooleanArray::new(values, Some(nulls)); + let values = BooleanArray::new(values, nulls); Ok(Arc::new(values)) } diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs index 74d361cf257b..435560721cd2 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs @@ -44,7 +44,7 @@ pub fn set_nulls( /// The `NullBuffer` is /// * `true` (representing valid) for values that were `true` in filter /// * `false` (representing null) for values that were `false` or `null` in filter -fn filter_to_nulls(filter: &BooleanArray) -> Option { +pub fn filter_to_nulls(filter: &BooleanArray) -> Option { let (filter_bools, filter_nulls) = filter.clone().into_parts(); let filter_bools = NullBuffer::from(filter_bools); NullBuffer::union(Some(&filter_bools), filter_nulls.as_ref()) diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs 
b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs index 656b95d140dd..acf875b68613 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs @@ -106,7 +106,8 @@ where opt_filter, total_num_groups, |group_index, new_value| { - let value = &mut self.values[group_index]; + // SAFETY: group_index is guaranteed to be in bounds + let value = unsafe { self.values.get_unchecked_mut(group_index) }; (self.prim_fn)(value, new_value); }, ); @@ -117,7 +118,7 @@ where fn evaluate(&mut self, emit_to: EmitTo) -> Result { let values = emit_to.take_needed(&mut self.values); let nulls = self.null_state.build(emit_to); - let values = PrimitiveArray::::new(values.into(), Some(nulls)) // no copy + let values = PrimitiveArray::::new(values.into(), nulls) // no copy .with_data_type(self.data_type.clone()); Ok(Arc::new(values)) } diff --git a/datafusion/functions-aggregate-common/src/lib.rs b/datafusion/functions-aggregate-common/src/lib.rs index 61b880095047..574d160d4214 100644 --- a/datafusion/functions-aggregate-common/src/lib.rs +++ b/datafusion/functions-aggregate-common/src/lib.rs @@ -31,8 +31,6 @@ // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] -// https://github.com/apache/datafusion/issues/18881 -#![deny(clippy::allow_attributes)] pub mod accumulator; pub mod aggregate; diff --git a/datafusion/functions-aggregate-common/src/tdigest.rs b/datafusion/functions-aggregate-common/src/tdigest.rs index 225c61b71939..a7450f0eb52e 100644 --- a/datafusion/functions-aggregate-common/src/tdigest.rs +++ b/datafusion/functions-aggregate-common/src/tdigest.rs @@ -49,17 +49,6 @@ macro_rules! cast_scalar_f64 { }; } -// Cast a non-null [`ScalarValue::UInt64`] to an [`u64`], or -// panic. -macro_rules! 
cast_scalar_u64 { - ($value:expr ) => { - match &$value { - ScalarValue::UInt64(Some(v)) => *v, - v => panic!("invalid type {}", v), - } - }; -} - /// Centroid implementation to the cluster mentioned in the paper. #[derive(Debug, PartialEq, Clone)] pub struct Centroid { @@ -110,7 +99,7 @@ pub struct TDigest { centroids: Vec, max_size: usize, sum: f64, - count: u64, + count: f64, max: f64, min: f64, } @@ -120,8 +109,8 @@ impl TDigest { TDigest { centroids: Vec::new(), max_size, - sum: 0_f64, - count: 0, + sum: 0.0, + count: 0.0, max: f64::NAN, min: f64::NAN, } @@ -133,14 +122,14 @@ impl TDigest { centroids: vec![centroid.clone()], max_size, sum: centroid.mean * centroid.weight, - count: 1, + count: centroid.weight, max: centroid.mean, min: centroid.mean, } } #[inline] - pub fn count(&self) -> u64 { + pub fn count(&self) -> f64 { self.count } @@ -170,8 +159,8 @@ impl Default for TDigest { TDigest { centroids: Vec::new(), max_size: 100, - sum: 0_f64, - count: 0, + sum: 0.0, + count: 0.0, max: f64::NAN, min: f64::NAN, } @@ -216,12 +205,12 @@ impl TDigest { } let mut result = TDigest::new(self.max_size()); - result.count = self.count() + sorted_values.len() as u64; + result.count = self.count() + sorted_values.len() as f64; let maybe_min = *sorted_values.first().unwrap(); let maybe_max = *sorted_values.last().unwrap(); - if self.count() > 0 { + if self.count() > 0.0 { result.min = self.min.min(maybe_min); result.max = self.max.max(maybe_max); } else { @@ -233,7 +222,7 @@ impl TDigest { let mut k_limit: u64 = 1; let mut q_limit_times_count = - Self::k_to_q(k_limit, self.max_size) * result.count() as f64; + Self::k_to_q(k_limit, self.max_size) * result.count(); k_limit += 1; let mut iter_centroids = self.centroids.iter().peekable(); @@ -281,7 +270,7 @@ impl TDigest { compressed.push(curr.clone()); q_limit_times_count = - Self::k_to_q(k_limit, self.max_size) * result.count() as f64; + Self::k_to_q(k_limit, self.max_size) * result.count(); k_limit += 1; curr = next; } @@ 
-353,7 +342,7 @@ impl TDigest { let mut centroids: Vec = Vec::with_capacity(n_centroids); let mut starts: Vec = Vec::with_capacity(digests.len()); - let mut count = 0; + let mut count = 0.0; let mut min = f64::INFINITY; let mut max = f64::NEG_INFINITY; @@ -362,7 +351,7 @@ impl TDigest { starts.push(start); let curr_count = digest.count(); - if curr_count > 0 { + if curr_count > 0.0 { min = min.min(digest.min); max = max.max(digest.max); count += curr_count; @@ -373,6 +362,11 @@ impl TDigest { } } + // If no centroids were added (all digests had zero count), return default + if centroids.is_empty() { + return TDigest::default(); + } + let mut digests_per_block: usize = 1; while digests_per_block < starts.len() { for i in (0..starts.len()).step_by(digests_per_block * 2) { @@ -397,7 +391,7 @@ impl TDigest { let mut compressed: Vec = Vec::with_capacity(max_size); let mut k_limit = 1; - let mut q_limit_times_count = Self::k_to_q(k_limit, max_size) * count as f64; + let mut q_limit_times_count = Self::k_to_q(k_limit, max_size) * count; let mut iter_centroids = centroids.iter_mut(); let mut curr = iter_centroids.next().unwrap(); @@ -416,7 +410,7 @@ impl TDigest { sums_to_merge = 0_f64; weights_to_merge = 0_f64; compressed.push(curr.clone()); - q_limit_times_count = Self::k_to_q(k_limit, max_size) * count as f64; + q_limit_times_count = Self::k_to_q(k_limit, max_size) * count; k_limit += 1; curr = centroid; } @@ -440,7 +434,7 @@ impl TDigest { return 0.0; } - let rank = q * self.count as f64; + let rank = q * self.count; let mut pos: usize; let mut t; @@ -450,7 +444,7 @@ impl TDigest { } pos = 0; - t = self.count as f64; + t = self.count; for (k, centroid) in self.centroids.iter().enumerate().rev() { t -= centroid.weight(); @@ -563,7 +557,7 @@ impl TDigest { vec![ ScalarValue::UInt64(Some(self.max_size as u64)), ScalarValue::Float64(Some(self.sum)), - ScalarValue::UInt64(Some(self.count)), + ScalarValue::Float64(Some(self.count)), ScalarValue::Float64(Some(self.max)), 
ScalarValue::Float64(Some(self.min)), ScalarValue::List(arr), @@ -611,7 +605,7 @@ impl TDigest { Self { max_size, sum: cast_scalar_f64!(state[1]), - count: cast_scalar_u64!(&state[2]), + count: cast_scalar_f64!(state[2]), max, min, centroids, diff --git a/datafusion/functions-aggregate/Cargo.toml b/datafusion/functions-aggregate/Cargo.toml index 8f8697fef0a1..39337e44bb05 100644 --- a/datafusion/functions-aggregate/Cargo.toml +++ b/datafusion/functions-aggregate/Cargo.toml @@ -53,6 +53,7 @@ datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } half = { workspace = true } log = { workspace = true } +num-traits = { workspace = true } paste = { workspace = true } [dev-dependencies] diff --git a/datafusion/functions-aggregate/benches/array_agg.rs b/datafusion/functions-aggregate/benches/array_agg.rs index d7f687386333..793c2aac9629 100644 --- a/datafusion/functions-aggregate/benches/array_agg.rs +++ b/datafusion/functions-aggregate/benches/array_agg.rs @@ -43,7 +43,7 @@ fn merge_batch_bench(c: &mut Criterion, name: &str, values: ArrayRef) { let list_item_data_type = values.as_list::().values().data_type().clone(); c.bench_function(name, |b| { b.iter(|| { - #[allow(clippy::unit_arg)] + #[expect(clippy::unit_arg)] black_box( ArrayAggAccumulator::try_new(&list_item_data_type, false) .unwrap() diff --git a/datafusion/functions-aggregate/benches/count.rs b/datafusion/functions-aggregate/benches/count.rs index 711bbe5a3c4d..48f71858c120 100644 --- a/datafusion/functions-aggregate/benches/count.rs +++ b/datafusion/functions-aggregate/benches/count.rs @@ -130,7 +130,7 @@ fn count_benchmark(c: &mut Criterion) { let mut accumulator = prepare_accumulator(); c.bench_function("count low cardinality dict 20% nulls, no filter", |b| { b.iter(|| { - #[allow(clippy::unit_arg)] + #[expect(clippy::unit_arg)] black_box( accumulator .update_batch(std::slice::from_ref(&values)) diff --git a/datafusion/functions-aggregate/src/approx_median.rs 
b/datafusion/functions-aggregate/src/approx_median.rs index 739e333b5461..2205b009ecb2 100644 --- a/datafusion/functions-aggregate/src/approx_median.rs +++ b/datafusion/functions-aggregate/src/approx_median.rs @@ -110,7 +110,7 @@ impl AggregateUDFImpl for ApproxMedian { Ok(vec![ Field::new(format_state_name(args.name, "max_size"), UInt64, false), Field::new(format_state_name(args.name, "sum"), Float64, false), - Field::new(format_state_name(args.name, "count"), UInt64, false), + Field::new(format_state_name(args.name, "count"), Float64, false), Field::new(format_state_name(args.name, "max"), Float64, false), Field::new(format_state_name(args.name, "min"), Float64, false), Field::new_list( diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index b1e649ec029f..392a044d0139 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -259,7 +259,7 @@ impl AggregateUDFImpl for ApproxPercentileCont { ), Field::new( format_state_name(args.name, "count"), - DataType::UInt64, + DataType::Float64, false, ), Field::new( @@ -436,7 +436,7 @@ impl Accumulator for ApproxPercentileAccumulator { } fn evaluate(&mut self) -> Result { - if self.digest.count() == 0 { + if self.digest.count() == 0.0 { return ScalarValue::try_from(self.return_type.clone()); } let q = self.digest.estimate_quantile(self.percentile); @@ -513,8 +513,8 @@ mod tests { ApproxPercentileAccumulator::new_with_max_size(0.5, DataType::Float64, 100); accumulator.merge_digests(&[t1]); - assert_eq!(accumulator.digest.count(), 50_000); + assert_eq!(accumulator.digest.count(), 50_000.0); accumulator.merge_digests(&[t2]); - assert_eq!(accumulator.digest.count(), 100_000); + assert_eq!(accumulator.digest.count(), 100_000.0); } } diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs 
b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs index ff7762e816ad..6fd90130e674 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs @@ -16,7 +16,7 @@ // under the License. use std::any::Any; -use std::fmt::{Debug, Formatter}; +use std::fmt::Debug; use std::hash::Hash; use std::mem::size_of_val; use std::sync::Arc; @@ -111,20 +111,12 @@ An alternative syntax is also supported: description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory." ) )] -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct ApproxPercentileContWithWeight { signature: Signature, approx_percentile_cont: ApproxPercentileCont, } -impl Debug for ApproxPercentileContWithWeight { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("ApproxPercentileContWithWeight") - .field("signature", &self.signature) - .finish() - } -} - impl Default for ApproxPercentileContWithWeight { fn default() -> Self { Self::new() diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index 9b2e7429ab3b..cd4cb9b19ff7 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -23,8 +23,10 @@ use std::mem::{size_of, size_of_val, take}; use std::sync::Arc; use arrow::array::{ - Array, ArrayRef, AsArray, BooleanArray, ListArray, StructArray, new_empty_array, + Array, ArrayRef, AsArray, BooleanArray, ListArray, NullBufferBuilder, StructArray, + UInt32Array, new_empty_array, }; +use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}; use arrow::compute::{SortOptions, filter}; use arrow::datatypes::{DataType, Field, FieldRef, Fields}; @@ -36,8 +38,10 @@ use datafusion_common::{Result, ScalarValue, 
assert_eq_or_internal_err, exec_err use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, + Accumulator, AggregateUDFImpl, Documentation, EmitTo, GroupsAccumulator, Signature, + Volatility, }; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filter_to_nulls; use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays; use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity; use datafusion_functions_aggregate_common::utils::ordering_fields; @@ -228,6 +232,23 @@ impl AggregateUDFImpl for ArrayAgg { datafusion_expr::ReversedUDAF::Reversed(array_agg_udaf()) } + fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { + !args.is_distinct && args.order_bys.is_empty() + } + + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + let field = &args.expr_fields[0]; + let data_type = field.data_type().clone(); + let ignore_nulls = args.ignore_nulls && field.is_nullable(); + Ok(Box::new(ArrayAggGroupsAccumulator::new( + data_type, + ignore_nulls, + ))) + } + fn supports_null_handling_clause(&self) -> bool { true } @@ -415,7 +436,332 @@ impl Accumulator for ArrayAggAccumulator { } #[derive(Debug)] -struct DistinctArrayAggAccumulator { +struct ArrayAggGroupsAccumulator { + datatype: DataType, + ignore_nulls: bool, + /// Source arrays — input arrays (from update_batch) or list backing + /// arrays (from merge_batch). + batches: Vec, + /// Per-batch list of (group_idx, row_idx) pairs. + batch_entries: Vec>, + /// Total number of groups tracked. 
+ num_groups: usize, +} + +impl ArrayAggGroupsAccumulator { + fn new(datatype: DataType, ignore_nulls: bool) -> Self { + Self { + datatype, + ignore_nulls, + batches: Vec::new(), + batch_entries: Vec::new(), + num_groups: 0, + } + } + + fn clear_state(&mut self) { + // `size()` measures Vec capacity rather than len, so allocate new + // buffers instead of using `clear()`. + self.batches = Vec::new(); + self.batch_entries = Vec::new(); + self.num_groups = 0; + } + + fn compact_retained_state(&mut self, emit_groups: usize) -> Result<()> { + // EmitTo::First is used to recover from memory pressure. Simply + // removing emitted entries in place is not enough because mixed batches + // would continue to pin their original Array arrays, even if only a few + // retained rows remain. + // + // Rebuild the retained state from scratch so fully emitted batches are + // dropped, mixed batches are compacted to arrays containing only the + // surviving rows, and retained metadata is right-sized. + let emit_groups = emit_groups as u32; + let old_batches = take(&mut self.batches); + let old_batch_entries = take(&mut self.batch_entries); + + let mut batches = Vec::new(); + let mut batch_entries = Vec::new(); + + for (batch, entries) in old_batches.into_iter().zip(old_batch_entries) { + let retained_len = entries.iter().filter(|(g, _)| *g >= emit_groups).count(); + + if retained_len == 0 { + continue; + } + + if retained_len == entries.len() { + // Nothing was emitted from this batch, so we keep the existing + // array and only renumber the remaining group IDs so that they + // start from 0. 
+ let mut retained_entries = entries; + for (g, _) in &mut retained_entries { + *g -= emit_groups; + } + retained_entries.shrink_to_fit(); + batches.push(batch); + batch_entries.push(retained_entries); + continue; + } + + let mut retained_entries = Vec::with_capacity(retained_len); + let mut retained_rows = Vec::with_capacity(retained_len); + + for (g, r) in entries { + if g >= emit_groups { + // Compute the new `(group_idx, row_idx)` pair for a + // retained row. `group_idx` is renumbered to start from + // 0, and `row_idx` points into the new dense batch we are + // building. + retained_entries.push((g - emit_groups, retained_rows.len() as u32)); + retained_rows.push(r); + } + } + + debug_assert_eq!(retained_entries.len(), retained_len); + debug_assert_eq!(retained_rows.len(), retained_len); + + let batch = if retained_len == batch.len() { + batch + } else { + // Compact mixed batches so retained rows no longer pin the + // original array. + let retained_rows = UInt32Array::from(retained_rows); + arrow::compute::take(batch.as_ref(), &retained_rows, None)? + }; + + batches.push(batch); + batch_entries.push(retained_entries); + } + + self.batches = batches; + self.batch_entries = batch_entries; + self.num_groups -= emit_groups as usize; + + Ok(()) + } +} + +impl GroupsAccumulator for ArrayAggGroupsAccumulator { + /// Store a reference to the input batch, plus a `(group_idx, row_idx)` pair + /// for every row. 
+ fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "single argument to update_batch"); + let input = &values[0]; + + self.num_groups = self.num_groups.max(total_num_groups); + + let nulls = if self.ignore_nulls { + input.logical_nulls() + } else { + None + }; + + let mut entries = Vec::new(); + + for (row_idx, &group_idx) in group_indices.iter().enumerate() { + // Skip filtered rows + if let Some(filter) = opt_filter + && (filter.is_null(row_idx) || !filter.value(row_idx)) + { + continue; + } + + // Skip null values when ignore_nulls is set + if let Some(ref nulls) = nulls + && nulls.is_null(row_idx) + { + continue; + } + + entries.push((group_idx as u32, row_idx as u32)); + } + + // We only need to record the batch if it was non-empty. + if !entries.is_empty() { + self.batches.push(Arc::clone(input)); + self.batch_entries.push(entries); + } + + Ok(()) + } + + /// Produce a `ListArray` ordered by group index: the list at + /// position N contains the aggregated values for group N. + /// + /// Uses a counting sort to rearrange the stored `(group, row)` + /// entries into group order, then calls `interleave` to gather + /// the values into a flat array that backs the output `ListArray`. + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let emit_groups = match emit_to { + EmitTo::All => self.num_groups, + EmitTo::First(n) => n, + }; + + // Step 1: Count entries per group. For EmitTo::First(n), only groups + // 0..n are counted; the rest are retained to be emitted in the future. + let mut counts = vec![0u32; emit_groups]; + for entries in &self.batch_entries { + for &(g, _) in entries { + let g = g as usize; + if g < emit_groups { + counts[g] += 1; + } + } + } + + // Step 2: Do a prefix sum over the counts and use it to build ListArray + // offsets, null buffer, and write positions for the counting sort. 
+ let mut offsets = Vec::::with_capacity(emit_groups + 1); + offsets.push(0); + let mut nulls_builder = NullBufferBuilder::new(emit_groups); + let mut write_positions = Vec::with_capacity(emit_groups); + let mut cur_offset = 0u32; + for &count in &counts { + if count == 0 { + nulls_builder.append_null(); + } else { + nulls_builder.append_non_null(); + } + write_positions.push(cur_offset); + cur_offset += count; + offsets.push(cur_offset as i32); + } + let total_rows = cur_offset as usize; + + // Step 3: Scatter entries into group order using the counting sort. The + // batch index is implicit from the outer loop position. + let flat_values = if total_rows == 0 { + new_empty_array(&self.datatype) + } else { + let mut interleave_indices = vec![(0usize, 0usize); total_rows]; + for (batch_idx, entries) in self.batch_entries.iter().enumerate() { + for &(g, r) in entries { + let g = g as usize; + if g < emit_groups { + let wp = write_positions[g] as usize; + interleave_indices[wp] = (batch_idx, r as usize); + write_positions[g] += 1; + } + } + } + + let sources: Vec<&dyn Array> = + self.batches.iter().map(|b| b.as_ref()).collect(); + arrow::compute::interleave(&sources, &interleave_indices)? + }; + + // Step 4: Release state for emitted groups. 
+ match emit_to { + EmitTo::All => self.clear_state(), + EmitTo::First(_) => self.compact_retained_state(emit_groups)?, + } + + let offsets = OffsetBuffer::new(ScalarBuffer::from(offsets)); + let field = Arc::new(Field::new_list_field(self.datatype.clone(), true)); + let result = ListArray::new(field, offsets, flat_values, nulls_builder.finish()); + + Ok(Arc::new(result)) + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + Ok(vec![self.evaluate(emit_to)?]) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + _opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "one argument to merge_batch"); + let input_list = values[0].as_list::(); + + self.num_groups = self.num_groups.max(total_num_groups); + + // Push the ListArray's backing values array as a single batch. + let list_values = input_list.values(); + let list_offsets = input_list.offsets(); + + let mut entries = Vec::new(); + + for (row_idx, &group_idx) in group_indices.iter().enumerate() { + if input_list.is_null(row_idx) { + continue; + } + let start = list_offsets[row_idx] as u32; + let end = list_offsets[row_idx + 1] as u32; + for pos in start..end { + entries.push((group_idx as u32, pos)); + } + } + + if !entries.is_empty() { + self.batches.push(Arc::clone(list_values)); + self.batch_entries.push(entries); + } + + Ok(()) + } + + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> Result> { + assert_eq!(values.len(), 1, "one argument to convert_to_state"); + + let input = &values[0]; + + // Each row becomes a 1-element list: offsets are [0, 1, 2, ..., n]. + let offsets = OffsetBuffer::from_repeated_length(1, input.len()); + + // Filtered rows become null list entries, which merge_batch will skip. + let filter_nulls = opt_filter.and_then(filter_to_nulls); + + // With ignore_nulls, null values also become null list entries. 
Without + // ignore_nulls, null values stay as [NULL] so merge_batch retains them. + let nulls = if self.ignore_nulls { + let logical = input.logical_nulls(); + NullBuffer::union(filter_nulls.as_ref(), logical.as_ref()) + } else { + filter_nulls + }; + + let field = Arc::new(Field::new_list_field(self.datatype.clone(), true)); + let list_array = ListArray::new(field, offsets, Arc::clone(input), nulls); + + Ok(vec![Arc::new(list_array)]) + } + + fn supports_convert_to_state(&self) -> bool { + true + } + + fn size(&self) -> usize { + self.batches + .iter() + .map(|arr| arr.to_data().get_slice_memory_size().unwrap_or_default()) + .sum::() + + self.batches.capacity() * size_of::() + + self + .batch_entries + .iter() + .map(|e| e.capacity() * size_of::<(u32, u32)>()) + .sum::() + + self.batch_entries.capacity() * size_of::>() + } +} + +#[derive(Debug)] +pub struct DistinctArrayAggAccumulator { values: HashSet, datatype: DataType, sort_options: Option, @@ -1227,4 +1573,372 @@ mod tests { acc1.merge_batch(&intermediate_state)?; Ok(acc1) } + + // ---- GroupsAccumulator tests ---- + + use arrow::array::Int32Array; + + fn list_array_to_i32_vecs(list: &ListArray) -> Vec>>> { + (0..list.len()) + .map(|i| { + if list.is_null(i) { + None + } else { + let arr = list.value(i); + let vals: Vec> = arr + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .collect(); + Some(vals) + } + }) + .collect() + } + + fn eval_i32_lists( + acc: &mut ArrayAggGroupsAccumulator, + emit_to: EmitTo, + ) -> Result>>>> { + let result = acc.evaluate(emit_to)?; + Ok(list_array_to_i32_vecs(result.as_list::())) + } + + #[test] + fn groups_accumulator_multiple_batches() -> Result<()> { + let mut acc = ArrayAggGroupsAccumulator::new(DataType::Int32, false); + + // First batch + let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); + acc.update_batch(&[values], &[0, 1, 0], None, 2)?; + + // Second batch + let values: ArrayRef = Arc::new(Int32Array::from(vec![4, 5])); + 
acc.update_batch(&[values], &[1, 0], None, 2)?; + + let vals = eval_i32_lists(&mut acc, EmitTo::All)?; + assert_eq!(vals[0], Some(vec![Some(1), Some(3), Some(5)])); + assert_eq!(vals[1], Some(vec![Some(2), Some(4)])); + + Ok(()) + } + + #[test] + fn groups_accumulator_emit_first() -> Result<()> { + let mut acc = ArrayAggGroupsAccumulator::new(DataType::Int32, false); + + let values: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 30])); + acc.update_batch(&[values], &[0, 1, 2], None, 3)?; + + // Emit first 2 groups + let vals = eval_i32_lists(&mut acc, EmitTo::First(2))?; + assert_eq!(vals.len(), 2); + assert_eq!(vals[0], Some(vec![Some(10)])); + assert_eq!(vals[1], Some(vec![Some(20)])); + + // Remaining group (was index 2, now shifted to 0) + let vals = eval_i32_lists(&mut acc, EmitTo::All)?; + assert_eq!(vals.len(), 1); + assert_eq!(vals[0], Some(vec![Some(30)])); + + Ok(()) + } + + #[test] + fn groups_accumulator_emit_first_frees_batches() -> Result<()> { + // Batch 0 has rows only for group 0; batch 1 has rows for + // both groups. After emitting group 0, batch 0 should be + // dropped entirely and batch 1 should be compacted to the + // retained row(s). + let mut acc = ArrayAggGroupsAccumulator::new(DataType::Int32, false); + + let batch0: ArrayRef = Arc::new(Int32Array::from(vec![10, 20])); + acc.update_batch(&[batch0], &[0, 0], None, 2)?; + + let batch1: ArrayRef = Arc::new(Int32Array::from(vec![30, 40])); + acc.update_batch(&[batch1], &[0, 1], None, 2)?; + + assert_eq!(acc.batches.len(), 2); + assert!(!acc.batches[0].is_empty()); + assert!(!acc.batches[1].is_empty()); + + // Emit group 0. Batch 0 is only referenced by group 0, so it + // should be removed. Batch 1 is mixed, so it should be compacted + // to contain only the retained row for group 1. 
+ let vals = eval_i32_lists(&mut acc, EmitTo::First(1))?; + assert_eq!(vals[0], Some(vec![Some(10), Some(20), Some(30)])); + + assert_eq!(acc.batches.len(), 1); + let retained = acc.batches[0] + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(retained.values(), &[40]); + assert_eq!(acc.batch_entries, vec![vec![(0, 0)]]); + + // Emit remaining group 1 + let vals = eval_i32_lists(&mut acc, EmitTo::All)?; + assert_eq!(vals[0], Some(vec![Some(40)])); + + assert!(acc.batches.is_empty()); + assert_eq!(acc.size(), 0); + + Ok(()) + } + + #[test] + fn groups_accumulator_emit_first_compacts_mixed_batches() -> Result<()> { + let mut acc = ArrayAggGroupsAccumulator::new(DataType::Int32, false); + + let batch: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 30, 40])); + acc.update_batch(&[batch], &[0, 1, 0, 1], None, 2)?; + + let size_before = acc.size(); + let vals = eval_i32_lists(&mut acc, EmitTo::First(1))?; + assert_eq!(vals[0], Some(vec![Some(10), Some(30)])); + + assert_eq!(acc.num_groups, 1); + assert_eq!(acc.batches.len(), 1); + let retained = acc.batches[0] + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(retained.values(), &[20, 40]); + assert_eq!(acc.batch_entries, vec![vec![(0, 0), (0, 1)]]); + assert!(acc.size() < size_before); + + let vals = eval_i32_lists(&mut acc, EmitTo::All)?; + assert_eq!(vals[0], Some(vec![Some(20), Some(40)])); + assert_eq!(acc.size(), 0); + + Ok(()) + } + + #[test] + fn groups_accumulator_emit_all_releases_capacity() -> Result<()> { + let mut acc = ArrayAggGroupsAccumulator::new(DataType::Int32, false); + + let batch: ArrayRef = Arc::new(Int32Array::from_iter_values(0..64)); + acc.update_batch( + &[batch], + &(0..64).map(|i| i % 4).collect::>(), + None, + 4, + )?; + + assert!(acc.size() > 0); + let _ = eval_i32_lists(&mut acc, EmitTo::All)?; + + assert_eq!(acc.size(), 0); + assert_eq!(acc.batches.capacity(), 0); + assert_eq!(acc.batch_entries.capacity(), 0); + + Ok(()) + } + + #[test] + fn 
groups_accumulator_null_groups() -> Result<()> { + // Groups that never receive values should produce null + let mut acc = ArrayAggGroupsAccumulator::new(DataType::Int32, false); + + let values: ArrayRef = Arc::new(Int32Array::from(vec![1])); + // Only group 0 gets a value, groups 1 and 2 are empty + acc.update_batch(&[values], &[0], None, 3)?; + + let vals = eval_i32_lists(&mut acc, EmitTo::All)?; + assert_eq!(vals, vec![Some(vec![Some(1)]), None, None]); + + Ok(()) + } + + #[test] + fn groups_accumulator_ignore_nulls() -> Result<()> { + let mut acc = ArrayAggGroupsAccumulator::new(DataType::Int32, true); + + let values: ArrayRef = + Arc::new(Int32Array::from(vec![Some(1), None, Some(3), None])); + acc.update_batch(&[values], &[0, 0, 1, 1], None, 2)?; + + let vals = eval_i32_lists(&mut acc, EmitTo::All)?; + // Group 0: only non-null value is 1 + assert_eq!(vals[0], Some(vec![Some(1)])); + // Group 1: only non-null value is 3 + assert_eq!(vals[1], Some(vec![Some(3)])); + + Ok(()) + } + + #[test] + fn groups_accumulator_opt_filter() -> Result<()> { + let mut acc = ArrayAggGroupsAccumulator::new(DataType::Int32, false); + + let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4])); + // Use a mix of false and null to filter out rows — both should + // be skipped. + let filter = BooleanArray::from(vec![Some(true), None, Some(true), Some(false)]); + acc.update_batch(&[values], &[0, 0, 1, 1], Some(&filter), 2)?; + + let vals = eval_i32_lists(&mut acc, EmitTo::All)?; + assert_eq!(vals[0], Some(vec![Some(1)])); // row 1 filtered (null) + assert_eq!(vals[1], Some(vec![Some(3)])); // row 3 filtered (false) + + Ok(()) + } + + #[test] + fn groups_accumulator_state_merge_roundtrip() -> Result<()> { + // Accumulator 1: update_batch, then merge, then update_batch again. + // Verifies that values appear in chronological insertion order. 
+ let mut acc1 = ArrayAggGroupsAccumulator::new(DataType::Int32, false); + let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2])); + acc1.update_batch(&[values], &[0, 1], None, 2)?; + + // Accumulator 2 + let mut acc2 = ArrayAggGroupsAccumulator::new(DataType::Int32, false); + let values: ArrayRef = Arc::new(Int32Array::from(vec![3, 4])); + acc2.update_batch(&[values], &[0, 1], None, 2)?; + + // Merge acc2's state into acc1 + let state = acc2.state(EmitTo::All)?; + acc1.merge_batch(&state, &[0, 1], None, 2)?; + + // Another update_batch on acc1 after the merge + let values: ArrayRef = Arc::new(Int32Array::from(vec![5, 6])); + acc1.update_batch(&[values], &[0, 1], None, 2)?; + + // Each group's values in insertion order: + // group 0: update(1), merge(3), update(5) → [1, 3, 5] + // group 1: update(2), merge(4), update(6) → [2, 4, 6] + let vals = eval_i32_lists(&mut acc1, EmitTo::All)?; + assert_eq!(vals[0], Some(vec![Some(1), Some(3), Some(5)])); + assert_eq!(vals[1], Some(vec![Some(2), Some(4), Some(6)])); + + Ok(()) + } + + #[test] + fn groups_accumulator_convert_to_state() -> Result<()> { + let acc = ArrayAggGroupsAccumulator::new(DataType::Int32, false); + + let values: ArrayRef = Arc::new(Int32Array::from(vec![Some(10), None, Some(30)])); + let state = acc.convert_to_state(&[values], None)?; + + assert_eq!(state.len(), 1); + let vals = list_array_to_i32_vecs(state[0].as_list::()); + assert_eq!( + vals, + vec![ + Some(vec![Some(10)]), + Some(vec![None]), // null preserved inside list, not promoted + Some(vec![Some(30)]), + ] + ); + + Ok(()) + } + + #[test] + fn groups_accumulator_convert_to_state_with_filter() -> Result<()> { + let acc = ArrayAggGroupsAccumulator::new(DataType::Int32, false); + + let values: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 30])); + let filter = BooleanArray::from(vec![true, false, true]); + let state = acc.convert_to_state(&[values], Some(&filter))?; + + let vals = list_array_to_i32_vecs(state[0].as_list::()); + 
assert_eq!( + vals, + vec![ + Some(vec![Some(10)]), + None, // filtered + Some(vec![Some(30)]), + ] + ); + + Ok(()) + } + + #[test] + fn groups_accumulator_convert_to_state_merge_preserves_nulls() -> Result<()> { + // Verifies that null values survive the convert_to_state -> merge_batch + // round-trip when ignore_nulls is false (default null handling). + let acc = ArrayAggGroupsAccumulator::new(DataType::Int32, false); + + let values: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])); + let state = acc.convert_to_state(&[values], None)?; + + // Feed state into a new accumulator via merge_batch + let mut acc2 = ArrayAggGroupsAccumulator::new(DataType::Int32, false); + acc2.merge_batch(&state, &[0, 0, 1], None, 2)?; + + // Group 0 received rows 0 ([1]) and 1 ([NULL]) → [1, NULL] + let vals = eval_i32_lists(&mut acc2, EmitTo::All)?; + assert_eq!(vals[0], Some(vec![Some(1), None])); + // Group 1 received row 2 ([3]) → [3] + assert_eq!(vals[1], Some(vec![Some(3)])); + + Ok(()) + } + + #[test] + fn groups_accumulator_convert_to_state_merge_ignore_nulls() -> Result<()> { + // Verifies that null values are dropped in the convert_to_state -> + // merge_batch round-trip when ignore_nulls is true. 
+ let acc = ArrayAggGroupsAccumulator::new(DataType::Int32, true); + + let values: ArrayRef = + Arc::new(Int32Array::from(vec![Some(1), None, Some(3), None])); + let state = acc.convert_to_state(&[values], None)?; + + let list = state[0].as_list::(); + // Rows 0 and 2 are valid lists; rows 1 and 3 are null list entries + assert!(!list.is_null(0)); + assert!(list.is_null(1)); + assert!(!list.is_null(2)); + assert!(list.is_null(3)); + + // Feed state into a new accumulator via merge_batch + let mut acc2 = ArrayAggGroupsAccumulator::new(DataType::Int32, true); + acc2.merge_batch(&state, &[0, 0, 1, 1], None, 2)?; + + // Group 0: received [1] and null (skipped) → [1] + let vals = eval_i32_lists(&mut acc2, EmitTo::All)?; + assert_eq!(vals[0], Some(vec![Some(1)])); + // Group 1: received [3] and null (skipped) → [3] + assert_eq!(vals[1], Some(vec![Some(3)])); + + Ok(()) + } + + #[test] + fn groups_accumulator_all_groups_empty() -> Result<()> { + let mut acc = ArrayAggGroupsAccumulator::new(DataType::Int32, false); + + // Create groups but don't add any values (all filtered out) + let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2])); + let filter = BooleanArray::from(vec![false, false]); + acc.update_batch(&[values], &[0, 1], Some(&filter), 2)?; + + let vals = eval_i32_lists(&mut acc, EmitTo::All)?; + assert_eq!(vals, vec![None, None]); + + Ok(()) + } + + #[test] + fn groups_accumulator_ignore_nulls_all_null_group() -> Result<()> { + // When ignore_nulls is true and a group receives only nulls, + // it should produce a null output + let mut acc = ArrayAggGroupsAccumulator::new(DataType::Int32, true); + + let values: ArrayRef = Arc::new(Int32Array::from(vec![None, Some(1), None])); + acc.update_batch(&[values], &[0, 1, 0], None, 2)?; + + let vals = eval_i32_lists(&mut acc, EmitTo::All)?; + assert_eq!(vals[0], None); // group 0 got only nulls, all filtered + assert_eq!(vals[1], Some(vec![Some(1)])); // group 1 got value 1 + + Ok(()) + } } diff --git 
a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 46a8dbf9540b..543116db1ddb 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -821,7 +821,8 @@ where opt_filter, total_num_groups, |group_index, new_value| { - let sum = &mut self.sums[group_index]; + // SAFETY: group_index is guaranteed to be in bounds + let sum = unsafe { self.sums.get_unchecked_mut(group_index) }; *sum = sum.add_wrapping(new_value); self.counts[group_index] += 1; @@ -836,12 +837,16 @@ where let sums = emit_to.take_needed(&mut self.sums); let nulls = self.null_state.build(emit_to); - assert_eq!(nulls.len(), sums.len()); + if let Some(nulls) = &nulls { + assert_eq!(nulls.len(), sums.len()); + } assert_eq!(counts.len(), sums.len()); // don't evaluate averages with null inputs to avoid errors on null values - let array: PrimitiveArray = if nulls.null_count() > 0 { + let array: PrimitiveArray = if let Some(nulls) = &nulls + && nulls.null_count() > 0 + { let mut builder = PrimitiveBuilder::::with_capacity(nulls.len()) .with_data_type(self.return_data_type.clone()); let iter = sums.into_iter().zip(counts).zip(nulls.iter()); @@ -860,7 +865,7 @@ where .zip(counts.into_iter()) .map(|(sum, count)| (self.avg_fn)(sum, count)) .collect::>>()?; - PrimitiveArray::new(averages.into(), Some(nulls)) // no copy + PrimitiveArray::new(averages.into(), nulls) // no copy .with_data_type(self.return_data_type.clone()) }; @@ -870,7 +875,6 @@ where // return arrays for sums and counts fn state(&mut self, emit_to: EmitTo) -> Result> { let nulls = self.null_state.build(emit_to); - let nulls = Some(nulls); let counts = emit_to.take_needed(&mut self.counts); let counts = UInt64Array::new(counts.into(), nulls.clone()); // zero copy @@ -904,7 +908,9 @@ where opt_filter, total_num_groups, |group_index, partial_count| { - self.counts[group_index] += partial_count; + // SAFETY: group_index is guaranteed to be in bounds 
+ let count = unsafe { self.counts.get_unchecked_mut(group_index) }; + *count += partial_count; }, ); @@ -916,7 +922,8 @@ where opt_filter, total_num_groups, |group_index, new_value: ::Native| { - let sum = &mut self.sums[group_index]; + // SAFETY: group_index is guaranteed to be in bounds + let sum = unsafe { self.sums.get_unchecked_mut(group_index) }; *sum = sum.add_wrapping(new_value); }, ); diff --git a/datafusion/functions-aggregate/src/bool_and_or.rs b/datafusion/functions-aggregate/src/bool_and_or.rs index a107024e2fb4..77b99cd1ae99 100644 --- a/datafusion/functions-aggregate/src/bool_and_or.rs +++ b/datafusion/functions-aggregate/src/bool_and_or.rs @@ -114,11 +114,7 @@ pub struct BoolAnd { impl BoolAnd { fn new() -> Self { Self { - signature: Signature::uniform( - 1, - vec![DataType::Boolean], - Volatility::Immutable, - ), + signature: Signature::exact(vec![DataType::Boolean], Volatility::Immutable), } } } @@ -251,11 +247,7 @@ pub struct BoolOr { impl BoolOr { fn new() -> Self { Self { - signature: Signature::uniform( - 1, - vec![DataType::Boolean], - Volatility::Immutable, - ), + signature: Signature::exact(vec![DataType::Boolean], Volatility::Immutable), } } } diff --git a/datafusion/functions-aggregate/src/correlation.rs b/datafusion/functions-aggregate/src/correlation.rs index 538311dfa263..6c76c6e94009 100644 --- a/datafusion/functions-aggregate/src/correlation.rs +++ b/datafusion/functions-aggregate/src/correlation.rs @@ -367,7 +367,7 @@ fn accumulate_correlation_states( /// where: /// n = number of observations /// sum_x = sum of x values -/// sum_y = sum of y values +/// sum_y = sum of y values /// sum_xy = sum of (x * y) /// sum_xx = sum of x^2 values /// sum_yy = sum of y^2 values @@ -411,11 +411,15 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator { } fn evaluate(&mut self, emit_to: EmitTo) -> Result { - let n = match emit_to { - EmitTo::All => self.count.len(), - EmitTo::First(n) => n, - }; - + // Drain the state vectors for the groups 
being emitted + let counts = emit_to.take_needed(&mut self.count); + let sum_xs = emit_to.take_needed(&mut self.sum_x); + let sum_ys = emit_to.take_needed(&mut self.sum_y); + let sum_xys = emit_to.take_needed(&mut self.sum_xy); + let sum_xxs = emit_to.take_needed(&mut self.sum_xx); + let sum_yys = emit_to.take_needed(&mut self.sum_yy); + + let n = counts.len(); let mut values = Vec::with_capacity(n); let mut nulls = NullBufferBuilder::new(n); @@ -427,14 +431,13 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator { // result should be `Null` (according to PostgreSQL's behavior). // - However, if any of the accumulated values contain NaN, the result should // be NaN regardless of the count (even for single-row groups). - // for i in 0..n { - let count = self.count[i]; - let sum_x = self.sum_x[i]; - let sum_y = self.sum_y[i]; - let sum_xy = self.sum_xy[i]; - let sum_xx = self.sum_xx[i]; - let sum_yy = self.sum_yy[i]; + let count = counts[i]; + let sum_x = sum_xs[i]; + let sum_y = sum_ys[i]; + let sum_xy = sum_xys[i]; + let sum_xx = sum_xxs[i]; + let sum_yy = sum_yys[i]; // If BOTH sum_x AND sum_y are NaN, then both input values are NaN → return NaN // If only ONE of them is NaN, then only one input value is NaN → return NULL @@ -470,18 +473,21 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator { } fn state(&mut self, emit_to: EmitTo) -> Result> { - let n = match emit_to { - EmitTo::All => self.count.len(), - EmitTo::First(n) => n, - }; + // Drain the state vectors for the groups being emitted + let count = emit_to.take_needed(&mut self.count); + let sum_x = emit_to.take_needed(&mut self.sum_x); + let sum_y = emit_to.take_needed(&mut self.sum_y); + let sum_xy = emit_to.take_needed(&mut self.sum_xy); + let sum_xx = emit_to.take_needed(&mut self.sum_xx); + let sum_yy = emit_to.take_needed(&mut self.sum_yy); Ok(vec![ - Arc::new(UInt64Array::from(self.count[0..n].to_vec())), - Arc::new(Float64Array::from(self.sum_x[0..n].to_vec())), - 
Arc::new(Float64Array::from(self.sum_y[0..n].to_vec())), - Arc::new(Float64Array::from(self.sum_xy[0..n].to_vec())), - Arc::new(Float64Array::from(self.sum_xx[0..n].to_vec())), - Arc::new(Float64Array::from(self.sum_yy[0..n].to_vec())), + Arc::new(UInt64Array::from(count)), + Arc::new(Float64Array::from(sum_x)), + Arc::new(Float64Array::from(sum_y)), + Arc::new(Float64Array::from(sum_xy)), + Arc::new(Float64Array::from(sum_xx)), + Arc::new(Float64Array::from(sum_yy)), ]) } @@ -537,12 +543,12 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator { } fn size(&self) -> usize { - size_of_val(&self.count) - + size_of_val(&self.sum_x) - + size_of_val(&self.sum_y) - + size_of_val(&self.sum_xy) - + size_of_val(&self.sum_xx) - + size_of_val(&self.sum_yy) + self.count.capacity() * size_of::() + + self.sum_x.capacity() * size_of::() + + self.sum_y.capacity() * size_of::() + + self.sum_xy.capacity() * size_of::() + + self.sum_xx.capacity() * size_of::() + + self.sum_yy.capacity() * size_of::() } } diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index a7c819acafea..376cf3974590 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -147,20 +147,11 @@ pub fn count_all_window() -> Expr { ```"#, standard_argument(name = "expression",) )] -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct Count { signature: Signature, } -impl Debug for Count { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("Count") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for Count { fn default() -> Self { Self::new() @@ -598,7 +589,9 @@ impl GroupsAccumulator for CountGroupsAccumulator { values.logical_nulls().as_ref(), opt_filter, |group_index| { - self.counts[group_index] += 1; + // SAFETY: group_index is guaranteed to be in bounds + let count = unsafe { 
self.counts.get_unchecked_mut(group_index) }; + *count += 1; }, ); diff --git a/datafusion/functions-aggregate/src/covariance.rs b/datafusion/functions-aggregate/src/covariance.rs index e86d742db3d4..8252cf1b19c4 100644 --- a/datafusion/functions-aggregate/src/covariance.rs +++ b/datafusion/functions-aggregate/src/covariance.rs @@ -17,19 +17,13 @@ //! [`CovarianceSample`]: covariance sample aggregations. -use arrow::datatypes::FieldRef; -use arrow::{ - array::{ArrayRef, Float64Array, UInt64Array}, - compute::kernels::cast, - datatypes::{DataType, Field}, -}; -use datafusion_common::{ - Result, ScalarValue, downcast_value, plan_err, unwrap_or_internal_err, -}; +use arrow::array::ArrayRef; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::cast::{as_float64_array, as_uint64_array}; +use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, function::{AccumulatorArgs, StateFieldsArgs}, - type_coercion::aggregates::NUMERICS, utils::format_state_name, }; use datafusion_functions_aggregate_common::stats::StatsType; @@ -69,21 +63,12 @@ make_udaf_expr_and_func!( standard_argument(name = "expression1", prefix = "First"), standard_argument(name = "expression2", prefix = "Second") )] -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct CovarianceSample { signature: Signature, aliases: Vec, } -impl Debug for CovarianceSample { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("CovarianceSample") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for CovarianceSample { fn default() -> Self { Self::new() @@ -94,7 +79,10 @@ impl CovarianceSample { pub fn new() -> Self { Self { aliases: vec![String::from("covar")], - signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), + signature: Signature::exact( + vec![DataType::Float64, 
DataType::Float64], + Volatility::Immutable, + ), } } } @@ -112,11 +100,7 @@ impl AggregateUDFImpl for CovarianceSample { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - if !arg_types[0].is_numeric() { - return plan_err!("Covariance requires numeric input types"); - } - + fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(DataType::Float64) } @@ -165,20 +149,11 @@ impl AggregateUDFImpl for CovarianceSample { standard_argument(name = "expression1", prefix = "First"), standard_argument(name = "expression2", prefix = "Second") )] -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct CovariancePopulation { signature: Signature, } -impl Debug for CovariancePopulation { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("CovariancePopulation") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for CovariancePopulation { fn default() -> Self { Self::new() @@ -188,7 +163,10 @@ impl Default for CovariancePopulation { impl CovariancePopulation { pub fn new() -> Self { Self { - signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), + signature: Signature::exact( + vec![DataType::Float64, DataType::Float64], + Volatility::Immutable, + ), } } } @@ -206,11 +184,7 @@ impl AggregateUDFImpl for CovariancePopulation { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - if !arg_types[0].is_numeric() { - return plan_err!("Covariance requires numeric input types"); - } - + fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(DataType::Float64) } @@ -304,30 +278,15 @@ impl Accumulator for CovarianceAccumulator { } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values1 = &cast(&values[0], &DataType::Float64)?; - let values2 = &cast(&values[1], &DataType::Float64)?; - - let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten(); - let mut arr2 
= downcast_value!(values2, Float64Array).iter().flatten(); + let values1 = as_float64_array(&values[0])?; + let values2 = as_float64_array(&values[1])?; - for i in 0..values1.len() { - let value1 = if values1.is_valid(i) { - arr1.next() - } else { - None - }; - let value2 = if values2.is_valid(i) { - arr2.next() - } else { - None + for (value1, value2) in values1.iter().zip(values2) { + let (value1, value2) = match (value1, value2) { + (Some(a), Some(b)) => (a, b), + _ => continue, }; - if value1.is_none() || value2.is_none() { - continue; - } - - let value1 = unwrap_or_internal_err!(value1); - let value2 = unwrap_or_internal_err!(value2); let new_count = self.count + 1; let delta1 = value1 - self.mean1; let new_mean1 = delta1 / new_count as f64 + self.mean1; @@ -345,29 +304,14 @@ impl Accumulator for CovarianceAccumulator { } fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values1 = &cast(&values[0], &DataType::Float64)?; - let values2 = &cast(&values[1], &DataType::Float64)?; - let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten(); - let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten(); - - for i in 0..values1.len() { - let value1 = if values1.is_valid(i) { - arr1.next() - } else { - None - }; - let value2 = if values2.is_valid(i) { - arr2.next() - } else { - None - }; - - if value1.is_none() || value2.is_none() { - continue; - } + let values1 = as_float64_array(&values[0])?; + let values2 = as_float64_array(&values[1])?; - let value1 = unwrap_or_internal_err!(value1); - let value2 = unwrap_or_internal_err!(value2); + for (value1, value2) in values1.iter().zip(values2) { + let (value1, value2) = match (value1, value2) { + (Some(a), Some(b)) => (a, b), + _ => continue, + }; let new_count = self.count - 1; let delta1 = self.mean1 - value1; @@ -386,10 +330,10 @@ impl Accumulator for CovarianceAccumulator { } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - let counts = 
downcast_value!(states[0], UInt64Array); - let means1 = downcast_value!(states[1], Float64Array); - let means2 = downcast_value!(states[2], Float64Array); - let cs = downcast_value!(states[3], Float64Array); + let counts = as_uint64_array(&states[0])?; + let means1 = as_float64_array(&states[1])?; + let means2 = as_float64_array(&states[2])?; + let cs = as_float64_array(&states[3])?; for i in 0..counts.len() { let c = counts.value(i); diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 5f3490f535a4..b339479b35e9 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -90,22 +90,12 @@ pub fn last_value(expression: Expr, order_by: Vec) -> Expr { ```"#, standard_argument(name = "expression",) )] -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct FirstValue { signature: Signature, is_input_pre_ordered: bool, } -impl Debug for FirstValue { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("FirstValue") - .field("name", &self.name()) - .field("signature", &self.signature) - .field("accumulator", &"") - .finish() - } -} - impl Default for FirstValue { fn default() -> Self { Self::new() @@ -1040,22 +1030,12 @@ impl Accumulator for FirstValueAccumulator { ```"#, standard_argument(name = "expression",) )] -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct LastValue { signature: Signature, is_input_pre_ordered: bool, } -impl Debug for LastValue { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("LastValue") - .field("name", &self.name()) - .field("signature", &self.signature) - .field("accumulator", &"") - .finish() - } -} - impl Default for LastValue { fn default() -> Self { Self::new() diff --git a/datafusion/functions-aggregate/src/grouping.rs b/datafusion/functions-aggregate/src/grouping.rs index 
43218b1147d3..c7af2df4b10f 100644 --- a/datafusion/functions-aggregate/src/grouping.rs +++ b/datafusion/functions-aggregate/src/grouping.rs @@ -18,7 +18,6 @@ //! Defines physical expressions that can evaluated at runtime during query execution use std::any::Any; -use std::fmt; use arrow::datatypes::Field; use arrow::datatypes::{DataType, FieldRef}; @@ -60,20 +59,11 @@ make_udaf_expr_and_func!( description = "Expression to evaluate whether data is aggregated across the specified column. Can be a constant, column, or function." ) )] -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct Grouping { signature: Signature, } -impl fmt::Debug for Grouping { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("Grouping") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for Grouping { fn default() -> Self { Self::new() diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index f364b785ddae..1b9996220d88 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -24,8 +24,6 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] -// https://github.com/apache/datafusion/issues/18881 -#![deny(clippy::allow_attributes)] //! Aggregate Function packages for [DataFusion]. //! 
diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index f137ae0801f0..db769918d135 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -85,20 +85,11 @@ make_udaf_expr_and_func!( /// If using the distinct variation, the memory usage will be similarly high if the /// cardinality is high as it stores all distinct values in memory before computing the /// result, but if cardinality is low then memory usage will also be lower. -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct Median { signature: Signature, } -impl Debug for Median { - fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { - f.debug_struct("Median") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for Median { fn default() -> Self { Self::new() @@ -566,10 +557,8 @@ impl Accumulator for DistinctMedianAccumulator { } fn evaluate(&mut self) -> Result { - let mut d = std::mem::take(&mut self.distinct_values.values) - .into_iter() - .map(|v| v.0) - .collect::>(); + let mut d: Vec = + self.distinct_values.values.iter().map(|v| v.0).collect(); let median = calculate_median::(&mut d); ScalarValue::new_primitive::(median, &self.data_type) } diff --git a/datafusion/functions-aggregate/src/percentile_cont.rs b/datafusion/functions-aggregate/src/percentile_cont.rs index a4e8332626b0..1aa150b56350 100644 --- a/datafusion/functions-aggregate/src/percentile_cont.rs +++ b/datafusion/functions-aggregate/src/percentile_cont.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
+use std::collections::HashMap; use std::fmt::Debug; use std::mem::{size_of, size_of_val}; use std::sync::Arc; @@ -25,11 +26,11 @@ use arrow::array::{ use arrow::buffer::{OffsetBuffer, ScalarBuffer}; use arrow::{ array::{Array, ArrayRef, AsArray}, - datatypes::{ - ArrowNativeType, DataType, Field, FieldRef, Float16Type, Float32Type, Float64Type, - }, + datatypes::{DataType, Field, FieldRef, Float16Type, Float32Type, Float64Type}, }; +use num_traits::AsPrimitive; + use arrow::array::ArrowNativeTypeOp; use datafusion_common::internal_err; use datafusion_common::types::{NativeType, logical_float64}; @@ -48,11 +49,11 @@ use datafusion_expr::{EmitTo, GroupsAccumulator}; use datafusion_expr::{ expr::{AggregateFunction, Sort}, function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs}, - simplify::SimplifyInfo, + simplify::SimplifyContext, }; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask; -use datafusion_functions_aggregate_common::utils::GenericDistinctBuffer; +use datafusion_functions_aggregate_common::utils::{GenericDistinctBuffer, Hashable}; use datafusion_macros::user_doc; use crate::utils::validate_percentile_expr; @@ -67,7 +68,10 @@ use crate::utils::validate_percentile_expr; /// The interpolation formula: `lower + (upper - lower) * fraction` /// is computed as: `lower + ((upper - lower) * (fraction * PRECISION)) / PRECISION` /// to avoid floating-point operations on integer types while maintaining precision. -const INTERPOLATION_PRECISION: usize = 1_000_000; +/// +/// The interpolation arithmetic is performed in f64 and then cast back to the +/// native type to avoid overflowing Float16 intermediates. 
+const INTERPOLATION_PRECISION: f64 = 1_000_000.0; create_func!(PercentileCont, percentile_cont_udaf); @@ -309,7 +313,7 @@ fn get_percentile(args: &AccumulatorArgs) -> Result { fn simplify_percentile_cont_aggregate( aggregate_function: AggregateFunction, - info: &dyn SimplifyInfo, + info: &SimplifyContext, ) -> Result { enum PercentileRewriteTarget { Min, @@ -388,7 +392,12 @@ impl PercentileContAccumulator { } } -impl Accumulator for PercentileContAccumulator { +impl Accumulator for PercentileContAccumulator +where + T: ArrowNumericType + Debug, + T::Native: Copy + AsPrimitive, + f64: AsPrimitive, +{ fn state(&mut self) -> Result> { // Convert `all_values` to `ListArray` and return a single List ScalarValue @@ -427,14 +436,48 @@ impl Accumulator for PercentileContAccumulator { } fn evaluate(&mut self) -> Result { - let d = std::mem::take(&mut self.all_values); - let value = calculate_percentile::(d, self.percentile); + let value = calculate_percentile::(&mut self.all_values, self.percentile); ScalarValue::new_primitive::(value, &T::DATA_TYPE) } fn size(&self) -> usize { size_of_val(self) + self.all_values.capacity() * size_of::() } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let mut to_remove: HashMap = HashMap::new(); + for i in 0..values[0].len() { + let v = ScalarValue::try_from_array(&values[0], i)?; + if !v.is_null() { + *to_remove.entry(v).or_default() += 1; + } + } + + let mut i = 0; + while i < self.all_values.len() { + let k = + ScalarValue::new_primitive::(Some(self.all_values[i]), &T::DATA_TYPE)?; + if let Some(count) = to_remove.get_mut(&k) + && *count > 0 + { + self.all_values.swap_remove(i); + *count -= 1; + if *count == 0 { + to_remove.remove(&k); + if to_remove.is_empty() { + break; + } + } + } else { + i += 1; + } + } + Ok(()) + } + + fn supports_retract_batch(&self) -> bool { + true + } } /// The percentile_cont groups accumulator accumulates the raw input values @@ -458,8 +501,11 @@ impl PercentileContGroupsAccumulator 
{ } } -impl GroupsAccumulator - for PercentileContGroupsAccumulator +impl GroupsAccumulator for PercentileContGroupsAccumulator +where + T: ArrowNumericType + Send, + T::Native: Copy + AsPrimitive, + f64: AsPrimitive, { fn update_batch( &mut self, @@ -549,13 +595,13 @@ impl GroupsAccumulator fn evaluate(&mut self, emit_to: EmitTo) -> Result { // Emit values - let emit_group_values = emit_to.take_needed(&mut self.group_values); + let mut emit_group_values = emit_to.take_needed(&mut self.group_values); // Calculate percentile for each group let mut evaluate_result_builder = PrimitiveBuilder::::with_capacity(emit_group_values.len()); - for values in emit_group_values { - let value = calculate_percentile::(values, self.percentile); + for values in &mut emit_group_values { + let value = calculate_percentile::(values.as_mut_slice(), self.percentile); evaluate_result_builder.append_option(value); } @@ -638,7 +684,12 @@ impl DistinctPercentileContAccumulator { } } -impl Accumulator for DistinctPercentileContAccumulator { +impl Accumulator for DistinctPercentileContAccumulator +where + T: ArrowNumericType + Debug, + T::Native: Copy + AsPrimitive, + f64: AsPrimitive, +{ fn state(&mut self) -> Result> { self.distinct_values.state() } @@ -652,17 +703,31 @@ impl Accumulator for DistinctPercentileContAccumula } fn evaluate(&mut self) -> Result { - let d = std::mem::take(&mut self.distinct_values.values) - .into_iter() - .map(|v| v.0) - .collect::>(); - let value = calculate_percentile::(d, self.percentile); + let mut values: Vec = + self.distinct_values.values.iter().map(|v| v.0).collect(); + let value = calculate_percentile::(&mut values, self.percentile); ScalarValue::new_primitive::(value, &T::DATA_TYPE) } fn size(&self) -> usize { size_of_val(self) + self.distinct_values.size() } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + + let arr = values[0].as_primitive::(); + for value in arr.iter().flatten() { + 
self.distinct_values.values.remove(&Hashable(value)); + } + Ok(()) + } + + fn supports_retract_batch(&self) -> bool { + true + } } /// Calculate the percentile value for a given set of values. @@ -672,10 +737,18 @@ impl Accumulator for DistinctPercentileContAccumula /// For percentile p and n values: /// - If p * (n-1) is an integer, return the value at that position /// - Otherwise, interpolate between the two closest values +/// +/// Note: This function takes a mutable slice and sorts it in place, but does not +/// consume the data. This is important for window frame queries where evaluate() +/// may be called multiple times on the same accumulator state. fn calculate_percentile( - mut values: Vec, + values: &mut [T::Native], percentile: f64, -) -> Option { +) -> Option +where + T::Native: Copy + AsPrimitive, + f64: AsPrimitive, +{ let cmp = |x: &T::Native, y: &T::Native| x.compare(*y); let len = values.len(); @@ -719,22 +792,47 @@ fn calculate_percentile( let (_, upper_value, _) = values.select_nth_unstable_by(upper_index, cmp); let upper_value = *upper_value; - // Linear interpolation using wrapping arithmetic - // We use wrapping operations here (matching the approach in median.rs) because: - // 1. Both values come from the input data, so diff is bounded by the value range - // 2. fraction is between 0 and 1, and INTERPOLATION_PRECISION is small enough - // to prevent overflow when combined with typical numeric ranges - // 3. The result is guaranteed to be between lower_value and upper_value - // 4. For floating-point types, wrapping ops behave the same as standard ops + // Linear interpolation. + // We compute a quantized interpolation weight using `INTERPOLATION_PRECISION` because: + // 1. Both values come from the input data, so (upper - lower) is bounded by the value range + // 2. fraction is between 0 and 1; quantizing it provides stable, predictable results + // 3. 
The result is guaranteed to be between lower_value and upper_value (modulo cast rounding) + // 4. Arithmetic is performed in f64 and cast back to avoid overflowing Float16 intermediates let fraction = index - (lower_index as f64); - let diff = upper_value.sub_wrapping(lower_value); - let interpolated = lower_value.add_wrapping( - diff.mul_wrapping(T::Native::usize_as( - (fraction * INTERPOLATION_PRECISION as f64) as usize, - )) - .div_wrapping(T::Native::usize_as(INTERPOLATION_PRECISION)), - ); - Some(interpolated) + let scaled = (fraction * INTERPOLATION_PRECISION) as usize; + let weight = scaled as f64 / INTERPOLATION_PRECISION; + + let lower_f: f64 = lower_value.as_(); + let upper_f: f64 = upper_value.as_(); + let interpolated_f = lower_f + (upper_f - lower_f) * weight; + Some(interpolated_f.as_()) } } } + +#[cfg(test)] +mod tests { + use super::calculate_percentile; + use half::f16; + + #[test] + fn f16_interpolation_does_not_overflow_to_nan() { + // Regression test for https://github.com/apache/datafusion/issues/18945 + // Interpolating between 0 and the max finite f16 value previously overflowed + // intermediate f16 computations and produced NaN. + let mut values = vec![f16::from_f32(0.0), f16::from_f32(65504.0)]; + let result = + calculate_percentile::(&mut values, 0.5) + .expect("non-empty input"); + let result_f = result.to_f32(); + assert!( + !result_f.is_nan(), + "expected non-NaN result, got {result_f}" + ); + // 0.5 percentile should be close to midpoint + assert!( + (result_f - 32752.0).abs() < 1.0, + "unexpected result {result_f}" + ); + } +} diff --git a/datafusion/functions-aggregate/src/regr.rs b/datafusion/functions-aggregate/src/regr.rs index bbc5567dab9d..7fef8ac981be 100644 --- a/datafusion/functions-aggregate/src/regr.rs +++ b/datafusion/functions-aggregate/src/regr.rs @@ -17,20 +17,12 @@ //! 
Defines physical expressions that can evaluated at runtime during query execution -use arrow::array::Float64Array; use arrow::datatypes::FieldRef; -use arrow::{ - array::{ArrayRef, UInt64Array}, - compute::cast, - datatypes::DataType, - datatypes::Field, -}; -use datafusion_common::{ - HashMap, Result, ScalarValue, downcast_value, plan_err, unwrap_or_internal_err, -}; +use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; +use datafusion_common::cast::{as_float64_array, as_uint64_array}; +use datafusion_common::{HashMap, Result, ScalarValue}; use datafusion_doc::aggregate_doc_sections::DOC_SECTION_STATISTICAL; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; -use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, @@ -58,26 +50,20 @@ make_regr_udaf_expr_and_func!(regr_sxx, regr_sxx_udaf, RegrType::SXX); make_regr_udaf_expr_and_func!(regr_syy, regr_syy_udaf, RegrType::SYY); make_regr_udaf_expr_and_func!(regr_sxy, regr_sxy_udaf, RegrType::SXY); -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct Regr { signature: Signature, regr_type: RegrType, func_name: &'static str, } -impl Debug for Regr { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("regr") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Regr { pub fn new(regr_type: RegrType, func_name: &'static str) -> Self { Self { - signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), + signature: Signature::exact( + vec![DataType::Float64, DataType::Float64], + Volatility::Immutable, + ), regr_type, func_name, } @@ -468,12 +454,8 @@ impl AggregateUDFImpl for Regr { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - if !arg_types[0].is_numeric() { - return plan_err!("Covariance requires 
numeric input types"); - } - - if matches!(self.regr_type, RegrType::Count) { + fn return_type(&self, _arg_types: &[DataType]) -> Result { + if self.regr_type == RegrType::Count { Ok(DataType::UInt64) } else { Ok(DataType::Float64) @@ -606,32 +588,18 @@ impl Accumulator for RegrAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { // regr_slope(Y, X) calculates k in y = k*x + b - let values_y = &cast(&values[0], &DataType::Float64)?; - let values_x = &cast(&values[1], &DataType::Float64)?; - - let mut arr_y = downcast_value!(values_y, Float64Array).iter().flatten(); - let mut arr_x = downcast_value!(values_x, Float64Array).iter().flatten(); + let values_y = as_float64_array(&values[0])?; + let values_x = as_float64_array(&values[1])?; - for i in 0..values_y.len() { + for (value_y, value_x) in values_y.iter().zip(values_x) { // skip either x or y is NULL - let value_y = if values_y.is_valid(i) { - arr_y.next() - } else { - None - }; - let value_x = if values_x.is_valid(i) { - arr_x.next() - } else { - None + let (value_y, value_x) = match (value_y, value_x) { + (Some(y), Some(x)) => (y, x), + // skip either x or y is NULL + _ => continue, }; - if value_y.is_none() || value_x.is_none() { - continue; - } // Update states for regr_slope(y,x) [using cov_pop(x,y)/var_pop(x)] - let value_y = unwrap_or_internal_err!(value_y); - let value_x = unwrap_or_internal_err!(value_x); - self.count += 1; let delta_x = value_x - self.mean_x; let delta_y = value_y - self.mean_y; @@ -652,32 +620,18 @@ impl Accumulator for RegrAccumulator { } fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values_y = &cast(&values[0], &DataType::Float64)?; - let values_x = &cast(&values[1], &DataType::Float64)?; - - let mut arr_y = downcast_value!(values_y, Float64Array).iter().flatten(); - let mut arr_x = downcast_value!(values_x, Float64Array).iter().flatten(); + let values_y = as_float64_array(&values[0])?; + let values_x = as_float64_array(&values[1])?; 
- for i in 0..values_y.len() { + for (value_y, value_x) in values_y.iter().zip(values_x) { // skip either x or y is NULL - let value_y = if values_y.is_valid(i) { - arr_y.next() - } else { - None + let (value_y, value_x) = match (value_y, value_x) { + (Some(y), Some(x)) => (y, x), + // skip either x or y is NULL + _ => continue, }; - let value_x = if values_x.is_valid(i) { - arr_x.next() - } else { - None - }; - if value_y.is_none() || value_x.is_none() { - continue; - } // Update states for regr_slope(y,x) [using cov_pop(x,y)/var_pop(x)] - let value_y = unwrap_or_internal_err!(value_y); - let value_x = unwrap_or_internal_err!(value_x); - if self.count > 1 { self.count -= 1; let delta_x = value_x - self.mean_x; @@ -703,12 +657,12 @@ impl Accumulator for RegrAccumulator { } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - let count_arr = downcast_value!(states[0], UInt64Array); - let mean_x_arr = downcast_value!(states[1], Float64Array); - let mean_y_arr = downcast_value!(states[2], Float64Array); - let m2_x_arr = downcast_value!(states[3], Float64Array); - let m2_y_arr = downcast_value!(states[4], Float64Array); - let algo_const_arr = downcast_value!(states[5], Float64Array); + let count_arr = as_uint64_array(&states[0])?; + let mean_x_arr = as_float64_array(&states[1])?; + let mean_y_arr = as_float64_array(&states[2])?; + let m2_x_arr = as_float64_array(&states[3])?; + let m2_y_arr = as_float64_array(&states[4])?; + let algo_const_arr = as_float64_array(&states[5])?; for i in 0..count_arr.len() { let count_b = count_arr.value(i); diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index 13eb5e1660b5..6f77e7df9254 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -18,7 +18,7 @@ //! 
Defines physical expressions that can evaluated at runtime during query execution use std::any::Any; -use std::fmt::{Debug, Formatter}; +use std::fmt::Debug; use std::hash::Hash; use std::mem::align_of_val; use std::sync::Arc; @@ -26,8 +26,8 @@ use std::sync::Arc; use arrow::array::Float64Array; use arrow::datatypes::FieldRef; use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; +use datafusion_common::ScalarValue; use datafusion_common::{Result, internal_err, not_impl_err}; -use datafusion_common::{ScalarValue, plan_err}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ @@ -62,21 +62,12 @@ make_udaf_expr_and_func!( standard_argument(name = "expression",) )] /// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct Stddev { signature: Signature, alias: Vec, } -impl Debug for Stddev { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Stddev") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for Stddev { fn default() -> Self { Self::new() @@ -87,7 +78,7 @@ impl Stddev { /// Create a new STDDEV aggregate function pub fn new() -> Self { Self { - signature: Signature::numeric(1, Volatility::Immutable), + signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable), alias: vec!["stddev_samp".to_string()], } } @@ -180,20 +171,11 @@ make_udaf_expr_and_func!( standard_argument(name = "expression",) )] /// STDDEV_POP population aggregate expression -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct StddevPop { signature: Signature, } -impl Debug for StddevPop { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("StddevPop") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default 
for StddevPop { fn default() -> Self { Self::new() @@ -204,7 +186,7 @@ impl StddevPop { /// Create a new STDDEV_POP aggregate function pub fn new() -> Self { Self { - signature: Signature::numeric(1, Volatility::Immutable), + signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable), } } } @@ -249,11 +231,7 @@ impl AggregateUDFImpl for StddevPop { Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?)) } - fn return_type(&self, arg_types: &[DataType]) -> Result { - if !arg_types[0].is_numeric() { - return plan_err!("StddevPop requires numeric input types"); - } - + fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(DataType::Float64) } @@ -318,13 +296,8 @@ impl Accumulator for StddevAccumulator { fn evaluate(&mut self) -> Result { let variance = self.variance.evaluate()?; match variance { - ScalarValue::Float64(e) => { - if e.is_none() { - Ok(ScalarValue::Float64(None)) - } else { - Ok(ScalarValue::Float64(e.map(|f| f.sqrt()))) - } - } + ScalarValue::Float64(None) => Ok(ScalarValue::Float64(None)), + ScalarValue::Float64(Some(f)) => Ok(ScalarValue::Float64(Some(f.sqrt()))), _ => internal_err!("Variance should be f64"), } } diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index 77e9f60afd3c..1c10818c091d 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -384,14 +384,13 @@ impl Accumulator for SimpleStringAggAccumulator { } fn evaluate(&mut self) -> Result { - let result = if self.has_value { - ScalarValue::LargeUtf8(Some(std::mem::take(&mut self.accumulated_string))) + if self.has_value { + Ok(ScalarValue::LargeUtf8(Some( + self.accumulated_string.clone(), + ))) } else { - ScalarValue::LargeUtf8(None) - }; - - self.has_value = false; - Ok(result) + Ok(ScalarValue::LargeUtf8(None)) + } } fn size(&self) -> usize { diff --git a/datafusion/functions-aggregate/src/variance.rs 
b/datafusion/functions-aggregate/src/variance.rs index e6978c15d0bf..fb089ba4f9ce 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -18,20 +18,21 @@ //! [`VarianceSample`]: variance sample aggregations. //! [`VariancePopulation`]: variance population aggregations. -use arrow::datatypes::FieldRef; +use arrow::datatypes::{FieldRef, Float64Type}; use arrow::{ array::{Array, ArrayRef, BooleanArray, Float64Array, UInt64Array}, buffer::NullBuffer, - compute::kernels::cast, datatypes::{DataType, Field}, }; -use datafusion_common::{Result, ScalarValue, downcast_value, not_impl_err, plan_err}; +use datafusion_common::cast::{as_float64_array, as_uint64_array}; +use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature, Volatility, function::{AccumulatorArgs, StateFieldsArgs}, utils::format_state_name, }; +use datafusion_functions_aggregate_common::utils::GenericDistinctBuffer; use datafusion_functions_aggregate_common::{ aggregate::groups_accumulator::accumulate::accumulate, stats::StatsType, }; @@ -61,21 +62,12 @@ make_udaf_expr_and_func!( syntax_example = "var(expression)", standard_argument(name = "expression", prefix = "Numeric") )] -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct VarianceSample { signature: Signature, aliases: Vec, } -impl Debug for VarianceSample { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("VarianceSample") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for VarianceSample { fn default() -> Self { Self::new() @@ -86,7 +78,7 @@ impl VarianceSample { pub fn new() -> Self { Self { aliases: vec![String::from("var_sample"), String::from("var_samp")], - signature: Signature::numeric(1, Volatility::Immutable), + signature: Signature::exact(vec![DataType::Float64], 
Volatility::Immutable), } } } @@ -110,19 +102,35 @@ impl AggregateUDFImpl for VarianceSample { fn state_fields(&self, args: StateFieldsArgs) -> Result> { let name = args.name; - Ok(vec![ - Field::new(format_state_name(name, "count"), DataType::UInt64, true), - Field::new(format_state_name(name, "mean"), DataType::Float64, true), - Field::new(format_state_name(name, "m2"), DataType::Float64, true), - ] - .into_iter() - .map(Arc::new) - .collect()) + match args.is_distinct { + false => Ok(vec![ + Field::new(format_state_name(name, "count"), DataType::UInt64, true), + Field::new(format_state_name(name, "mean"), DataType::Float64, true), + Field::new(format_state_name(name, "m2"), DataType::Float64, true), + ] + .into_iter() + .map(Arc::new) + .collect()), + true => { + let field = Field::new_list_field(DataType::Float64, true); + let state_name = "distinct_var"; + Ok(vec![ + Field::new( + format_state_name(name, state_name), + DataType::List(Arc::new(field)), + true, + ) + .into(), + ]) + } + } } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { if acc_args.is_distinct { - return not_impl_err!("VAR(DISTINCT) aggregations are not available"); + return Ok(Box::new(DistinctVarianceAccumulator::new( + StatsType::Sample, + ))); } Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?)) @@ -154,21 +162,12 @@ impl AggregateUDFImpl for VarianceSample { syntax_example = "var_pop(expression)", standard_argument(name = "expression", prefix = "Numeric") )] -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct VariancePopulation { signature: Signature, aliases: Vec, } -impl Debug for VariancePopulation { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("VariancePopulation") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for VariancePopulation { fn default() -> Self { Self::new() @@ -179,7 +178,7 @@ impl VariancePopulation { pub fn new() -> Self 
{ Self { aliases: vec![String::from("var_population")], - signature: Signature::numeric(1, Volatility::Immutable), + signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable), } } } @@ -197,29 +196,43 @@ impl AggregateUDFImpl for VariancePopulation { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - if !arg_types[0].is_numeric() { - return plan_err!("Variance requires numeric input types"); - } - + fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(DataType::Float64) } fn state_fields(&self, args: StateFieldsArgs) -> Result> { - let name = args.name; - Ok(vec![ - Field::new(format_state_name(name, "count"), DataType::UInt64, true), - Field::new(format_state_name(name, "mean"), DataType::Float64, true), - Field::new(format_state_name(name, "m2"), DataType::Float64, true), - ] - .into_iter() - .map(Arc::new) - .collect()) + match args.is_distinct { + false => { + let name = args.name; + Ok(vec![ + Field::new(format_state_name(name, "count"), DataType::UInt64, true), + Field::new(format_state_name(name, "mean"), DataType::Float64, true), + Field::new(format_state_name(name, "m2"), DataType::Float64, true), + ] + .into_iter() + .map(Arc::new) + .collect()) + } + true => { + let field = Field::new_list_field(DataType::Float64, true); + let state_name = "distinct_var"; + Ok(vec![ + Field::new( + format_state_name(args.name, state_name), + DataType::List(Arc::new(field)), + true, + ) + .into(), + ]) + } + } } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { if acc_args.is_distinct { - return not_impl_err!("VAR_POP(DISTINCT) aggregations are not available"); + return Ok(Box::new(DistinctVarianceAccumulator::new( + StatsType::Population, + ))); } Ok(Box::new(VarianceAccumulator::try_new( @@ -243,6 +256,7 @@ impl AggregateUDFImpl for VariancePopulation { StatsType::Population, ))) } + fn documentation(&self) -> Option<&Documentation> { self.doc() } @@ -330,10 +344,8 @@ impl Accumulator for 
VarianceAccumulator { } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &cast(&values[0], &DataType::Float64)?; - let arr = downcast_value!(values, Float64Array).iter().flatten(); - - for value in arr { + let arr = as_float64_array(&values[0])?; + for value in arr.iter().flatten() { (self.count, self.mean, self.m2) = update(self.count, self.mean, self.m2, value) } @@ -342,10 +354,8 @@ impl Accumulator for VarianceAccumulator { } fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &cast(&values[0], &DataType::Float64)?; - let arr = downcast_value!(values, Float64Array).iter().flatten(); - - for value in arr { + let arr = as_float64_array(&values[0])?; + for value in arr.iter().flatten() { let new_count = self.count - 1; let delta1 = self.mean - value; let new_mean = delta1 / new_count as f64 + self.mean; @@ -361,9 +371,9 @@ impl Accumulator for VarianceAccumulator { } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - let counts = downcast_value!(states[0], UInt64Array); - let means = downcast_value!(states[1], Float64Array); - let m2s = downcast_value!(states[2], Float64Array); + let counts = as_uint64_array(&states[0])?; + let means = as_float64_array(&states[1])?; + let m2s = as_float64_array(&states[2])?; for i in 0..counts.len() { let c = counts.value(i); @@ -498,8 +508,7 @@ impl GroupsAccumulator for VarianceGroupsAccumulator { total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 1, "single argument to update_batch"); - let values = &cast(&values[0], &DataType::Float64)?; - let values = downcast_value!(values, Float64Array); + let values = as_float64_array(&values[0])?; self.resize(total_num_groups); accumulate(group_indices, values, opt_filter, |group_index, value| { @@ -526,9 +535,9 @@ impl GroupsAccumulator for VarianceGroupsAccumulator { ) -> Result<()> { assert_eq!(values.len(), 3, "two arguments to merge_batch"); // first batch is counts, second is partial means, 
third is partial m2s - let partial_counts = downcast_value!(values[0], UInt64Array); - let partial_means = downcast_value!(values[1], Float64Array); - let partial_m2s = downcast_value!(values[2], Float64Array); + let partial_counts = as_uint64_array(&values[0])?; + let partial_means = as_float64_array(&values[1])?; + let partial_m2s = as_float64_array(&values[2])?; self.resize(total_num_groups); Self::merge( @@ -581,6 +590,71 @@ impl GroupsAccumulator for VarianceGroupsAccumulator { } } +#[derive(Debug)] +pub struct DistinctVarianceAccumulator { + distinct_values: GenericDistinctBuffer, + stat_type: StatsType, +} + +impl DistinctVarianceAccumulator { + pub fn new(stat_type: StatsType) -> Self { + Self { + distinct_values: GenericDistinctBuffer::::new(DataType::Float64), + stat_type, + } + } +} + +impl Accumulator for DistinctVarianceAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.distinct_values.update_batch(values) + } + + fn evaluate(&mut self) -> Result { + let values = self + .distinct_values + .values + .iter() + .map(|v| v.0) + .collect::>(); + + let count = match self.stat_type { + StatsType::Sample => { + if !values.is_empty() { + values.len() - 1 + } else { + 0 + } + } + StatsType::Population => values.len(), + }; + + let mean = values.iter().sum::() / values.len() as f64; + let m2 = values.iter().map(|x| (x - mean) * (x - mean)).sum::(); + + Ok(ScalarValue::Float64(match values.len() { + 0 => None, + 1 => match self.stat_type { + StatsType::Population => Some(0.0), + StatsType::Sample => None, + }, + _ => Some(m2 / count as f64), + })) + } + + fn size(&self) -> usize { + size_of_val(self) + self.distinct_values.size() + } + + fn state(&mut self) -> Result> { + self.distinct_values.state() + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.distinct_values.merge_batch(states) + } +} + #[cfg(test)] mod tests { use datafusion_expr::EmitTo; diff --git a/datafusion/functions-nested/Cargo.toml 
b/datafusion/functions-nested/Cargo.toml index 6b0241a10a54..0b26170dbb74 100644 --- a/datafusion/functions-nested/Cargo.toml +++ b/datafusion/functions-nested/Cargo.toml @@ -57,7 +57,9 @@ datafusion-functions-aggregate = { workspace = true } datafusion-functions-aggregate-common = { workspace = true } datafusion-macros = { workspace = true } datafusion-physical-expr-common = { workspace = true } +hashbrown = { workspace = true } itertools = { workspace = true, features = ["use_std"] } +itoa = { workspace = true } log = { workspace = true } paste = { workspace = true } @@ -84,3 +86,23 @@ name = "array_slice" [[bench]] harness = false name = "map" + +[[bench]] +harness = false +name = "array_remove" + +[[bench]] +harness = false +name = "array_repeat" + +[[bench]] +harness = false +name = "array_set_ops" + +[[bench]] +harness = false +name = "array_to_string" + +[[bench]] +harness = false +name = "array_position" diff --git a/datafusion/functions-nested/benches/array_expression.rs b/datafusion/functions-nested/benches/array_expression.rs index 8d72ffa3c1cd..ad9f565f4d64 100644 --- a/datafusion/functions-nested/benches/array_expression.rs +++ b/datafusion/functions-nested/benches/array_expression.rs @@ -15,11 +15,7 @@ // specific language governing permissions and limitations // under the License. -#[macro_use] -extern crate criterion; -extern crate arrow; - -use crate::criterion::Criterion; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_expr::lit; use datafusion_functions_nested::expr_fn::{array_replace_all, make_array}; use std::hint::black_box; diff --git a/datafusion/functions-nested/benches/array_has.rs b/datafusion/functions-nested/benches/array_has.rs index a44a80c6ae63..f5e66d56c0ef 100644 --- a/datafusion/functions-nested/benches/array_has.rs +++ b/datafusion/functions-nested/benches/array_has.rs @@ -15,20 +15,31 @@ // specific language governing permissions and limitations // under the License. 
-#[macro_use] -extern crate criterion; - -use criterion::{BenchmarkId, Criterion}; -use datafusion_expr::lit; -use datafusion_functions_nested::expr_fn::{ - array_has, array_has_all, array_has_any, make_array, +use arrow::array::{ArrayRef, Int64Array, ListArray, StringArray}; +use arrow::buffer::OffsetBuffer; +use arrow::datatypes::{DataType, Field}; +use criterion::{ + criterion_group, criterion_main, {BenchmarkId, Criterion}, }; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_functions_nested::array_has::{ArrayHas, ArrayHasAll, ArrayHasAny}; +use rand::Rng; +use rand::SeedableRng; +use rand::rngs::StdRng; use std::hint::black_box; +use std::sync::Arc; + +const NUM_ROWS: usize = 10000; +const SEED: u64 = 42; +const NULL_DENSITY: f64 = 0.1; +const NEEDLE_SIZE: usize = 3; // If not explicitly stated, `array` and `array_size` refer to the haystack array. fn criterion_benchmark(c: &mut Criterion) { // Test different array sizes - let array_sizes = vec![1, 10, 100, 1000, 10000]; + let array_sizes = vec![10, 100, 500]; for &size in &array_sizes { bench_array_has(c, size); @@ -41,49 +52,67 @@ fn criterion_benchmark(c: &mut Criterion) { bench_array_has_all_strings(c); bench_array_has_any_strings(c); - // Edge cases - bench_array_has_edge_cases(c); + // Benchmark for array_has_any with one scalar arg + bench_array_has_any_scalar(c); } fn bench_array_has(c: &mut Criterion, array_size: usize) { let mut group = c.benchmark_group("array_has_i64"); - - // Benchmark: element found at beginning - group.bench_with_input( - BenchmarkId::new("found_at_start", array_size), - &array_size, - |b, &size| { - let array = (0..size).map(|i| lit(i as i64)).collect::>(); - let list_array = make_array(array); - let needle = lit(0_i64); - - b.iter(|| black_box(array_has(list_array.clone(), needle.clone()))) - }, - ); - - // Benchmark: element found at end + let 
list_array = create_int64_list_array(NUM_ROWS, array_size, NULL_DENSITY); + let config_options = Arc::new(ConfigOptions::default()); + let return_field: Arc = Field::new("result", DataType::Boolean, true).into(); + let arg_fields: Vec> = vec![ + Field::new("arr", list_array.data_type().clone(), false).into(), + Field::new("el", DataType::Int64, false).into(), + ]; + + // Benchmark: element found + let args_found = vec![ + ColumnarValue::Array(list_array.clone()), + ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), + ]; group.bench_with_input( - BenchmarkId::new("found_at_end", array_size), + BenchmarkId::new("found", array_size), &array_size, - |b, &size| { - let array = (0..size).map(|i| lit(i as i64)).collect::>(); - let list_array = make_array(array); - let needle = lit((size - 1) as i64); - - b.iter(|| black_box(array_has(list_array.clone(), needle.clone()))) + |b, _| { + let udf = ArrayHas::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_found.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) }, ); // Benchmark: element not found + let args_not_found = vec![ + ColumnarValue::Array(list_array.clone()), + ColumnarValue::Scalar(ScalarValue::Int64(Some(-999))), + ]; group.bench_with_input( BenchmarkId::new("not_found", array_size), &array_size, - |b, &size| { - let array = (0..size).map(|i| lit(i as i64)).collect::>(); - let list_array = make_array(array); - let needle = lit(-1_i64); // Not in array - - b.iter(|| black_box(array_has(list_array.clone(), needle.clone()))) + |b, _| { + let udf = ArrayHas::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_not_found.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) }, ); @@ -92,90 +121,190 
@@ fn bench_array_has(c: &mut Criterion, array_size: usize) { fn bench_array_has_all(c: &mut Criterion, array_size: usize) { let mut group = c.benchmark_group("array_has_all"); + let haystack = create_int64_list_array(NUM_ROWS, array_size, NULL_DENSITY); + let list_type = haystack.data_type().clone(); + let config_options = Arc::new(ConfigOptions::default()); + let return_field: Arc = Field::new("result", DataType::Boolean, true).into(); + let arg_fields: Vec> = vec![ + Field::new("haystack", list_type.clone(), false).into(), + Field::new("needle", list_type.clone(), false).into(), + ]; // Benchmark: all elements found (small needle) + let needle_found = create_int64_list_array(NUM_ROWS, NEEDLE_SIZE, 0.0); + let args_found = vec![ + ColumnarValue::Array(haystack.clone()), + ColumnarValue::Array(needle_found), + ]; group.bench_with_input( BenchmarkId::new("all_found_small_needle", array_size), &array_size, - |b, &size| { - let array = (0..size).map(|i| lit(i as i64)).collect::>(); - let list_array = make_array(array); - let needle_array = make_array(vec![lit(0_i64), lit(1_i64), lit(2_i64)]); - - b.iter(|| black_box(array_has_all(list_array.clone(), needle_array.clone()))) + |b, _| { + let udf = ArrayHasAll::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_found.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) }, ); - // Benchmark: all elements found (medium needle - 10% of haystack) + // Benchmark: not all found (needle contains elements outside haystack range) + let needle_missing = + create_int64_list_array_with_offset(NUM_ROWS, NEEDLE_SIZE, array_size as i64); + let args_missing = vec![ + ColumnarValue::Array(haystack.clone()), + ColumnarValue::Array(needle_missing), + ]; group.bench_with_input( - BenchmarkId::new("all_found_medium_needle", array_size), + BenchmarkId::new("not_all_found", 
array_size), &array_size, - |b, &size| { - let array = (0..size).map(|i| lit(i as i64)).collect::>(); - let list_array = make_array(array); - let needle_size = (size / 10).max(1); - let needle = (0..needle_size).map(|i| lit(i as i64)).collect::>(); - let needle_array = make_array(needle); - - b.iter(|| black_box(array_has_all(list_array.clone(), needle_array.clone()))) - }, - ); - - // Benchmark: not all found (early exit) - group.bench_with_input( - BenchmarkId::new("early_exit", array_size), - &array_size, - |b, &size| { - let array = (0..size).map(|i| lit(i as i64)).collect::>(); - let list_array = make_array(array); - let needle_array = make_array(vec![lit(0_i64), lit(-1_i64)]); // -1 not in array - - b.iter(|| black_box(array_has_all(list_array.clone(), needle_array.clone()))) + |b, _| { + let udf = ArrayHasAll::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_missing.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) }, ); group.finish(); } +const SMALL_ARRAY_SIZE: usize = NEEDLE_SIZE; + fn bench_array_has_any(c: &mut Criterion, array_size: usize) { let mut group = c.benchmark_group("array_has_any"); - - // Benchmark: first element matches (best case) + let first_arr = create_int64_list_array(NUM_ROWS, array_size, NULL_DENSITY); + let list_type = first_arr.data_type().clone(); + let config_options = Arc::new(ConfigOptions::default()); + let return_field: Arc = Field::new("result", DataType::Boolean, true).into(); + let arg_fields: Vec> = vec![ + Field::new("first", list_type.clone(), false).into(), + Field::new("second", list_type.clone(), false).into(), + ]; + + // Benchmark: some elements match + let second_match = create_int64_list_array(NUM_ROWS, SMALL_ARRAY_SIZE, 0.0); + let args_match = vec![ + ColumnarValue::Array(first_arr.clone()), + ColumnarValue::Array(second_match), + ]; 
group.bench_with_input( - BenchmarkId::new("first_match", array_size), + BenchmarkId::new("some_match", array_size), &array_size, - |b, &size| { - let array = (0..size).map(|i| lit(i as i64)).collect::>(); - let list_array = make_array(array); - let needle_array = make_array(vec![lit(0_i64), lit(-1_i64), lit(-2_i64)]); - - b.iter(|| black_box(array_has_any(list_array.clone(), needle_array.clone()))) + |b, _| { + let udf = ArrayHasAny::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_match.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) }, ); - // Benchmark: last element matches (worst case) + // Benchmark: no match + let second_no_match = create_int64_list_array_with_offset( + NUM_ROWS, + SMALL_ARRAY_SIZE, + array_size as i64, + ); + let args_no_match = vec![ + ColumnarValue::Array(first_arr.clone()), + ColumnarValue::Array(second_no_match), + ]; group.bench_with_input( - BenchmarkId::new("last_match", array_size), + BenchmarkId::new("no_match", array_size), &array_size, - |b, &size| { - let array = (0..size).map(|i| lit(i as i64)).collect::>(); - let list_array = make_array(array); - let needle_array = make_array(vec![lit(-1_i64), lit(-2_i64), lit(0_i64)]); - - b.iter(|| black_box(array_has_any(list_array.clone(), needle_array.clone()))) + |b, _| { + let udf = ArrayHasAny::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_no_match.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) }, ); - // Benchmark: no match + // Benchmark: scalar second arg, some match + let scalar_second_match = create_int64_scalar_list(SMALL_ARRAY_SIZE, 0); + let args_scalar_match = vec![ + ColumnarValue::Array(first_arr.clone()), + 
ColumnarValue::Scalar(scalar_second_match), + ]; group.bench_with_input( - BenchmarkId::new("no_match", array_size), + BenchmarkId::new("scalar_some_match", array_size), &array_size, - |b, &size| { - let array = (0..size).map(|i| lit(i as i64)).collect::>(); - let list_array = make_array(array); - let needle_array = make_array(vec![lit(-1_i64), lit(-2_i64), lit(-3_i64)]); + |b, _| { + let udf = ArrayHasAny::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_scalar_match.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) + }, + ); - b.iter(|| black_box(array_has_any(list_array.clone(), needle_array.clone()))) + // Benchmark: scalar second arg, no match + let scalar_second_no_match = + create_int64_scalar_list(SMALL_ARRAY_SIZE, array_size as i64); + let args_scalar_no_match = vec![ + ColumnarValue::Array(first_arr.clone()), + ColumnarValue::Scalar(scalar_second_no_match), + ]; + group.bench_with_input( + BenchmarkId::new("scalar_no_match", array_size), + &array_size, + |b, _| { + let udf = ArrayHasAny::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_scalar_no_match.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) }, ); @@ -184,29 +313,56 @@ fn bench_array_has_any(c: &mut Criterion, array_size: usize) { fn bench_array_has_strings(c: &mut Criterion) { let mut group = c.benchmark_group("array_has_strings"); + let config_options = Arc::new(ConfigOptions::default()); + let return_field: Arc = Field::new("result", DataType::Boolean, true).into(); - // Benchmark with string arrays (common use case for tickers, tags, etc.) 
- let sizes = vec![10, 100, 1000]; + let sizes = vec![10, 100, 500]; for &size in &sizes { - group.bench_with_input(BenchmarkId::new("found", size), &size, |b, &size| { - let array = (0..size) - .map(|i| lit(format!("TICKER{i:04}"))) - .collect::>(); - let list_array = make_array(array); - let needle = lit("TICKER0005"); - - b.iter(|| black_box(array_has(list_array.clone(), needle.clone()))) + let list_array = create_string_list_array(NUM_ROWS, size, NULL_DENSITY); + let arg_fields: Vec> = vec![ + Field::new("arr", list_array.data_type().clone(), false).into(), + Field::new("el", DataType::Utf8, false).into(), + ]; + + let args_found = vec![ + ColumnarValue::Array(list_array.clone()), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("value_1".to_string()))), + ]; + group.bench_with_input(BenchmarkId::new("found", size), &size, |b, _| { + let udf = ArrayHas::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_found.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) }); - group.bench_with_input(BenchmarkId::new("not_found", size), &size, |b, &size| { - let array = (0..size) - .map(|i| lit(format!("TICKER{i:04}"))) - .collect::>(); - let list_array = make_array(array); - let needle = lit("NOTFOUND"); - - b.iter(|| black_box(array_has(list_array.clone(), needle.clone()))) + let args_not_found = vec![ + ColumnarValue::Array(list_array.clone()), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("NOTFOUND".to_string()))), + ]; + group.bench_with_input(BenchmarkId::new("not_found", size), &size, |b, _| { + let udf = ArrayHas::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_not_found.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) }); } @@ -215,49 
+371,173 @@ fn bench_array_has_strings(c: &mut Criterion) { fn bench_array_has_all_strings(c: &mut Criterion) { let mut group = c.benchmark_group("array_has_all_strings"); + let config_options = Arc::new(ConfigOptions::default()); + let return_field: Arc = Field::new("result", DataType::Boolean, true).into(); - // Realistic scenario: checking if a portfolio contains certain tickers - let portfolio_size = 100; - let check_sizes = vec![1, 3, 5, 10]; + let sizes = vec![10, 100, 500]; - for &check_size in &check_sizes { - group.bench_with_input( - BenchmarkId::new("all_found", check_size), - &check_size, - |b, &check_size| { - let portfolio = (0..portfolio_size) - .map(|i| lit(format!("TICKER{i:04}"))) - .collect::>(); - let list_array = make_array(portfolio); - - let checking = (0..check_size) - .map(|i| lit(format!("TICKER{i:04}"))) - .collect::>(); - let needle_array = make_array(checking); + for &size in &sizes { + let haystack = create_string_list_array(NUM_ROWS, size, NULL_DENSITY); + let list_type = haystack.data_type().clone(); + let arg_fields: Vec> = vec![ + Field::new("haystack", list_type.clone(), false).into(), + Field::new("needle", list_type.clone(), false).into(), + ]; + + let needle_found = create_string_list_array(NUM_ROWS, NEEDLE_SIZE, 0.0); + let args_found = vec![ + ColumnarValue::Array(haystack.clone()), + ColumnarValue::Array(needle_found), + ]; + group.bench_with_input(BenchmarkId::new("all_found", size), &size, |b, _| { + let udf = ArrayHasAll::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_found.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) + }); + let needle_missing = + create_string_list_array_with_prefix(NUM_ROWS, NEEDLE_SIZE, "missing_"); + let args_missing = vec![ + ColumnarValue::Array(haystack.clone()), + ColumnarValue::Array(needle_missing), + ]; + 
group.bench_with_input(BenchmarkId::new("not_all_found", size), &size, |b, _| { + let udf = ArrayHasAll::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_missing.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) + }); + } + + group.finish(); +} + +fn bench_array_has_any_strings(c: &mut Criterion) { + let mut group = c.benchmark_group("array_has_any_strings"); + let config_options = Arc::new(ConfigOptions::default()); + let return_field: Arc = Field::new("result", DataType::Boolean, true).into(); + + let sizes = vec![10, 100, 500]; + + for &size in &sizes { + let first_arr = create_string_list_array(NUM_ROWS, size, NULL_DENSITY); + let list_type = first_arr.data_type().clone(); + let arg_fields: Vec> = vec![ + Field::new("first", list_type.clone(), false).into(), + Field::new("second", list_type.clone(), false).into(), + ]; + + let second_match = create_string_list_array(NUM_ROWS, SMALL_ARRAY_SIZE, 0.0); + let args_match = vec![ + ColumnarValue::Array(first_arr.clone()), + ColumnarValue::Array(second_match), + ]; + group.bench_with_input(BenchmarkId::new("some_match", size), &size, |b, _| { + let udf = ArrayHasAny::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_match.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) + }); + + let second_no_match = + create_string_list_array_with_prefix(NUM_ROWS, SMALL_ARRAY_SIZE, "missing_"); + let args_no_match = vec![ + ColumnarValue::Array(first_arr.clone()), + ColumnarValue::Array(second_no_match), + ]; + group.bench_with_input(BenchmarkId::new("no_match", size), &size, |b, _| { + let udf = ArrayHasAny::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: 
args_no_match.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) + }); + + // Benchmark: scalar second arg, some match + let scalar_second_match = create_string_scalar_list(SMALL_ARRAY_SIZE, "value_"); + let args_scalar_match = vec![ + ColumnarValue::Array(first_arr.clone()), + ColumnarValue::Scalar(scalar_second_match), + ]; + group.bench_with_input( + BenchmarkId::new("scalar_some_match", size), + &size, + |b, _| { + let udf = ArrayHasAny::new(); b.iter(|| { - black_box(array_has_all(list_array.clone(), needle_array.clone())) + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_scalar_match.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) }) }, ); + // Benchmark: scalar second arg, no match + let scalar_second_no_match = + create_string_scalar_list(SMALL_ARRAY_SIZE, "missing_"); + let args_scalar_no_match = vec![ + ColumnarValue::Array(first_arr.clone()), + ColumnarValue::Scalar(scalar_second_no_match), + ]; group.bench_with_input( - BenchmarkId::new("some_missing", check_size), - &check_size, - |b, &check_size| { - let portfolio = (0..portfolio_size) - .map(|i| lit(format!("TICKER{i:04}"))) - .collect::>(); - let list_array = make_array(portfolio); - - let mut checking = (0..check_size - 1) - .map(|i| lit(format!("TICKER{i:04}"))) - .collect::>(); - checking.push(lit("NOTFOUND".to_string())); - let needle_array = make_array(checking); - + BenchmarkId::new("scalar_no_match", size), + &size, + |b, _| { + let udf = ArrayHasAny::new(); b.iter(|| { - black_box(array_has_all(list_array.clone(), needle_array.clone())) + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_scalar_no_match.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + 
config_options: config_options.clone(), + }) + .unwrap(), + ) }) }, ); @@ -266,48 +546,81 @@ fn bench_array_has_all_strings(c: &mut Criterion) { group.finish(); } -fn bench_array_has_any_strings(c: &mut Criterion) { - let mut group = c.benchmark_group("array_has_any_strings"); - - let portfolio_size = 100; - let check_sizes = vec![1, 3, 5, 10]; - - for &check_size in &check_sizes { +/// Benchmarks array_has_any with one scalar arg. Varies the scalar argument +/// size while keeping the columnar array small (3 elements per row). +fn bench_array_has_any_scalar(c: &mut Criterion) { + let mut group = c.benchmark_group("array_has_any_scalar"); + let config_options = Arc::new(ConfigOptions::default()); + let return_field: Arc = Field::new("result", DataType::Boolean, true).into(); + + let array_size = 3; + let scalar_sizes = vec![1, 10, 100, 1000]; + + // i64 benchmarks + let first_arr_i64 = create_int64_list_array(NUM_ROWS, array_size, NULL_DENSITY); + let list_type_i64 = first_arr_i64.data_type().clone(); + let arg_fields_i64: Vec> = vec![ + Field::new("first", list_type_i64.clone(), false).into(), + Field::new("second", list_type_i64.clone(), false).into(), + ]; + + for &scalar_size in &scalar_sizes { + let scalar_arg = create_int64_scalar_list(scalar_size, array_size as i64); + let args = vec![ + ColumnarValue::Array(first_arr_i64.clone()), + ColumnarValue::Scalar(scalar_arg), + ]; group.bench_with_input( - BenchmarkId::new("first_matches", check_size), - &check_size, - |b, &check_size| { - let portfolio = (0..portfolio_size) - .map(|i| lit(format!("TICKER{i:04}"))) - .collect::>(); - let list_array = make_array(portfolio); - - let mut checking = vec![lit("TICKER0000".to_string())]; - checking.extend((1..check_size).map(|_| lit("NOTFOUND".to_string()))); - let needle_array = make_array(checking); - + BenchmarkId::new("i64_no_match", scalar_size), + &scalar_size, + |b, _| { + let udf = ArrayHasAny::new(); b.iter(|| { - black_box(array_has_any(list_array.clone(), 
needle_array.clone())) + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields_i64.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) }) }, ); + } + // String benchmarks + let first_arr_str = create_string_list_array(NUM_ROWS, array_size, NULL_DENSITY); + let list_type_str = first_arr_str.data_type().clone(); + let arg_fields_str: Vec> = vec![ + Field::new("first", list_type_str.clone(), false).into(), + Field::new("second", list_type_str.clone(), false).into(), + ]; + + for &scalar_size in &scalar_sizes { + let scalar_arg = create_string_scalar_list(scalar_size, "missing_"); + let args = vec![ + ColumnarValue::Array(first_arr_str.clone()), + ColumnarValue::Scalar(scalar_arg), + ]; group.bench_with_input( - BenchmarkId::new("none_match", check_size), - &check_size, - |b, &check_size| { - let portfolio = (0..portfolio_size) - .map(|i| lit(format!("TICKER{i:04}"))) - .collect::>(); - let list_array = make_array(portfolio); - - let checking = (0..check_size) - .map(|i| lit(format!("NOTFOUND{i}"))) - .collect::>(); - let needle_array = make_array(checking); - + BenchmarkId::new("string_no_match", scalar_size), + &scalar_size, + |b, _| { + let udf = ArrayHasAny::new(); b.iter(|| { - black_box(array_has_any(list_array.clone(), needle_array.clone())) + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields_str.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) }) }, ); @@ -316,61 +629,152 @@ fn bench_array_has_any_strings(c: &mut Criterion) { group.finish(); } -fn bench_array_has_edge_cases(c: &mut Criterion) { - let mut group = c.benchmark_group("array_has_edge_cases"); - - // Empty array - group.bench_function("empty_array", |b| { - let list_array = make_array(vec![]); - let needle = lit(1_i64); - - b.iter(|| 
black_box(array_has(list_array.clone(), needle.clone()))) - }); - - // Single element array - found - group.bench_function("single_element_found", |b| { - let list_array = make_array(vec![lit(1_i64)]); - let needle = lit(1_i64); - - b.iter(|| black_box(array_has(list_array.clone(), needle.clone()))) - }); - - // Single element array - not found - group.bench_function("single_element_not_found", |b| { - let list_array = make_array(vec![lit(1_i64)]); - let needle = lit(2_i64); - - b.iter(|| black_box(array_has(list_array.clone(), needle.clone()))) - }); - - // Array with duplicates - group.bench_function("array_with_duplicates", |b| { - let array = vec![lit(1_i64); 1000]; - let list_array = make_array(array); - let needle = lit(1_i64); - - b.iter(|| black_box(array_has(list_array.clone(), needle.clone()))) - }); +fn create_int64_list_array( + num_rows: usize, + array_size: usize, + null_density: f64, +) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let values = (0..num_rows * array_size) + .map(|_| { + if rng.random::() < null_density { + None + } else { + Some(rng.random_range(0..array_size as i64)) + } + }) + .collect::(); + let offsets = (0..=num_rows) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( + ListArray::try_new( + Arc::new(Field::new("item", DataType::Int64, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} - // array_has_all: empty needle - group.bench_function("array_has_all_empty_needle", |b| { - let array = (0..1000).map(|i| lit(i as i64)).collect::>(); - let list_array = make_array(array); - let needle_array = make_array(vec![]); +/// Like `create_int64_list_array` but values are offset so they won't +/// appear in a standard list array (useful for "not found" benchmarks). 
+fn create_int64_list_array_with_offset( + num_rows: usize, + array_size: usize, + offset: i64, +) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED + 1); + let values = (0..num_rows * array_size) + .map(|_| Some(rng.random_range(0..array_size as i64) + offset)) + .collect::(); + let offsets = (0..=num_rows) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( + ListArray::try_new( + Arc::new(Field::new("item", DataType::Int64, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} - b.iter(|| black_box(array_has_all(list_array.clone(), needle_array.clone()))) - }); +fn create_string_list_array( + num_rows: usize, + array_size: usize, + null_density: f64, +) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let values = (0..num_rows * array_size) + .map(|_| { + if rng.random::() < null_density { + None + } else { + let idx = rng.random_range(0..array_size); + Some(format!("value_{idx}")) + } + }) + .collect::(); + let offsets = (0..=num_rows) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( + ListArray::try_new( + Arc::new(Field::new("item", DataType::Utf8, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} - // array_has_any: empty needle - group.bench_function("array_has_any_empty_needle", |b| { - let array = (0..1000).map(|i| lit(i as i64)).collect::>(); - let list_array = make_array(array); - let needle_array = make_array(vec![]); +/// Like `create_string_list_array` but values use a different prefix so +/// they won't appear in a standard string list array. 
+fn create_string_list_array_with_prefix( + num_rows: usize, + array_size: usize, + prefix: &str, +) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED + 1); + let values = (0..num_rows * array_size) + .map(|_| { + let idx = rng.random_range(0..array_size); + Some(format!("{prefix}{idx}")) + }) + .collect::(); + let offsets = (0..=num_rows) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( + ListArray::try_new( + Arc::new(Field::new("item", DataType::Utf8, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} - b.iter(|| black_box(array_has_any(list_array.clone(), needle_array.clone()))) - }); +/// Create a `ScalarValue::List` containing a single list of `size` i64 elements, +/// with values starting at `offset`. +fn create_int64_scalar_list(size: usize, offset: i64) -> ScalarValue { + let values = (0..size as i64) + .map(|i| Some(i + offset)) + .collect::(); + let list = ListArray::try_new( + Arc::new(Field::new("item", DataType::Int64, true)), + OffsetBuffer::new(vec![0, size as i32].into()), + Arc::new(values), + None, + ) + .unwrap(); + ScalarValue::List(Arc::new(list)) +} - group.finish(); +/// Create a `ScalarValue::List` containing a single list of `size` string elements, +/// with values like "{prefix}0", "{prefix}1", etc. 
+fn create_string_scalar_list(size: usize, prefix: &str) -> ScalarValue { + let values = (0..size) + .map(|i| Some(format!("{prefix}{i}"))) + .collect::(); + let list = ListArray::try_new( + Arc::new(Field::new("item", DataType::Utf8, true)), + OffsetBuffer::new(vec![0, size as i32].into()), + Arc::new(values), + None, + ) + .unwrap(); + ScalarValue::List(Arc::new(list)) } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/functions-nested/benches/array_position.rs b/datafusion/functions-nested/benches/array_position.rs new file mode 100644 index 000000000000..08367648449d --- /dev/null +++ b/datafusion/functions-nested/benches/array_position.rs @@ -0,0 +1,237 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::{ArrayRef, Int64Array, ListArray}; +use arrow::buffer::OffsetBuffer; +use arrow::datatypes::{DataType, Field}; +use criterion::{ + criterion_group, criterion_main, {BenchmarkId, Criterion}, +}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_functions_nested::position::ArrayPosition; +use rand::Rng; +use rand::SeedableRng; +use rand::rngs::StdRng; +use std::hint::black_box; +use std::sync::Arc; + +const NUM_ROWS: usize = 10000; +const SEED: u64 = 42; +const NULL_DENSITY: f64 = 0.1; +const SENTINEL_NEEDLE: i64 = -1; + +fn criterion_benchmark(c: &mut Criterion) { + for size in [10, 100, 500] { + bench_array_position(c, size); + } +} + +fn bench_array_position(c: &mut Criterion, array_size: usize) { + let mut group = c.benchmark_group("array_position_i64"); + let haystack_found_once = create_haystack_with_sentinel( + NUM_ROWS, + array_size, + NULL_DENSITY, + SENTINEL_NEEDLE, + 0, + ); + let haystack_found_many = create_haystack_with_sentinels( + NUM_ROWS, + array_size, + NULL_DENSITY, + SENTINEL_NEEDLE, + ); + let haystack_not_found = + create_haystack_without_sentinel(NUM_ROWS, array_size, NULL_DENSITY); + let num_rows = haystack_not_found.len(); + let arg_fields: Vec> = vec![ + Field::new("haystack", haystack_not_found.data_type().clone(), false).into(), + Field::new("needle", DataType::Int64, false).into(), + ]; + let return_field: Arc = Field::new("result", DataType::UInt64, true).into(); + let config_options = Arc::new(ConfigOptions::default()); + let needle = ScalarValue::Int64(Some(SENTINEL_NEEDLE)); + + // Benchmark: one match per row. 
+ let args_found_once = vec![ + ColumnarValue::Array(haystack_found_once.clone()), + ColumnarValue::Scalar(needle.clone()), + ]; + group.bench_with_input( + BenchmarkId::new("found_once", array_size), + &array_size, + |b, _| { + let udf = ArrayPosition::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_found_once.clone(), + arg_fields: arg_fields.clone(), + number_rows: num_rows, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) + }, + ); + + // Benchmark: many matches per row. + let args_found_many = vec![ + ColumnarValue::Array(haystack_found_many.clone()), + ColumnarValue::Scalar(needle.clone()), + ]; + group.bench_with_input( + BenchmarkId::new("found_many", array_size), + &array_size, + |b, _| { + let udf = ArrayPosition::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_found_many.clone(), + arg_fields: arg_fields.clone(), + number_rows: num_rows, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) + }, + ); + + // Benchmark: needle is not found in any row. 
+ let args_not_found = vec![ + ColumnarValue::Array(haystack_not_found.clone()), + ColumnarValue::Scalar(needle.clone()), + ]; + group.bench_with_input( + BenchmarkId::new("not_found", array_size), + &array_size, + |b, _| { + let udf = ArrayPosition::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_not_found.clone(), + arg_fields: arg_fields.clone(), + number_rows: num_rows, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) + }, + ); + + group.finish(); +} + +fn create_haystack_without_sentinel( + num_rows: usize, + array_size: usize, + null_density: f64, +) -> ArrayRef { + create_haystack_from_fn(num_rows, array_size, |_, _, rng| { + random_haystack_value(rng, array_size, null_density) + }) +} + +fn create_haystack_with_sentinel( + num_rows: usize, + array_size: usize, + null_density: f64, + sentinel: i64, + sentinel_index: usize, +) -> ArrayRef { + assert!(sentinel_index < array_size); + + create_haystack_from_fn(num_rows, array_size, |_, col, rng| { + if col == sentinel_index { + Some(sentinel) + } else { + random_haystack_value(rng, array_size, null_density) + } + }) +} + +fn create_haystack_with_sentinels( + num_rows: usize, + array_size: usize, + null_density: f64, + sentinel: i64, +) -> ArrayRef { + create_haystack_from_fn(num_rows, array_size, |_, col, rng| { + // Place the sentinel in half the positions to create many matches per row. 
+ if col % 2 == 0 { + Some(sentinel) + } else { + random_haystack_value(rng, array_size, null_density) + } + }) +} + +fn create_haystack_from_fn( + num_rows: usize, + array_size: usize, + mut value_at: F, +) -> ArrayRef +where + F: FnMut(usize, usize, &mut StdRng) -> Option, +{ + let mut rng = StdRng::seed_from_u64(SEED); + let mut values = Vec::with_capacity(num_rows * array_size); + for row in 0..num_rows { + for col in 0..array_size { + values.push(value_at(row, col, &mut rng)); + } + } + let values = values.into_iter().collect::(); + let offsets = (0..=num_rows) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( + ListArray::try_new( + Arc::new(Field::new("item", DataType::Int64, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} + +fn random_haystack_value( + rng: &mut StdRng, + array_size: usize, + null_density: f64, +) -> Option { + if rng.random::() < null_density { + None + } else { + Some(rng.random_range(0..array_size as i64)) + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions-nested/benches/array_remove.rs b/datafusion/functions-nested/benches/array_remove.rs new file mode 100644 index 000000000000..a494d322392a --- /dev/null +++ b/datafusion/functions-nested/benches/array_remove.rs @@ -0,0 +1,572 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + Array, ArrayRef, BinaryArray, BooleanArray, Decimal128Array, FixedSizeBinaryArray, + Float64Array, Int64Array, ListArray, StringArray, +}; +use arrow::buffer::OffsetBuffer; +use arrow::datatypes::{DataType, Field}; +use criterion::{ + criterion_group, criterion_main, {BenchmarkId, Criterion}, +}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_functions_nested::remove::ArrayRemove; +use rand::Rng; +use rand::SeedableRng; +use rand::rngs::StdRng; +use std::hint::black_box; +use std::sync::Arc; + +const NUM_ROWS: usize = 10000; +const ARRAY_SIZES: &[usize] = &[10, 100, 500]; +const SEED: u64 = 42; +const NULL_DENSITY: f64 = 0.1; + +fn criterion_benchmark(c: &mut Criterion) { + // Test array_remove with different data types and array sizes + // TODO: Add performance tests for nested datatypes + bench_array_remove_int64(c); + bench_array_remove_f64(c); + bench_array_remove_strings(c); + bench_array_remove_binary(c); + bench_array_remove_boolean(c); + bench_array_remove_decimal64(c); + bench_array_remove_fixed_size_binary(c); +} + +fn bench_array_remove_int64(c: &mut Criterion) { + let mut group = c.benchmark_group("array_remove_int64"); + + for &array_size in ARRAY_SIZES { + let list_array = create_int64_list_array(NUM_ROWS, array_size, NULL_DENSITY); + let element_to_remove = ScalarValue::Int64(Some(1)); + let args = create_args(list_array.clone(), element_to_remove.clone()); + + group.bench_with_input( 
+ BenchmarkId::new("remove", array_size), + &array_size, + |b, _| { + let udf = ArrayRemove::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: vec![ + Field::new("arr", list_array.data_type().clone(), false) + .into(), + Field::new("el", DataType::Int64, false).into(), + ], + number_rows: NUM_ROWS, + return_field: Field::new( + "result", + list_array.data_type().clone(), + false, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + + group.finish(); +} + +fn bench_array_remove_f64(c: &mut Criterion) { + let mut group = c.benchmark_group("array_remove_f64"); + + for &array_size in ARRAY_SIZES { + let list_array = create_f64_list_array(NUM_ROWS, array_size, NULL_DENSITY); + let element_to_remove = ScalarValue::Float64(Some(1.0)); + let args = create_args(list_array.clone(), element_to_remove.clone()); + + group.bench_with_input( + BenchmarkId::new("remove", array_size), + &array_size, + |b, _| { + let udf = ArrayRemove::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: vec![ + Field::new("arr", list_array.data_type().clone(), false) + .into(), + Field::new("el", DataType::Float64, false).into(), + ], + number_rows: NUM_ROWS, + return_field: Field::new( + "result", + list_array.data_type().clone(), + false, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + + group.finish(); +} + +fn bench_array_remove_strings(c: &mut Criterion) { + let mut group = c.benchmark_group("array_remove_strings"); + + for &array_size in ARRAY_SIZES { + let list_array = create_string_list_array(NUM_ROWS, array_size, NULL_DENSITY); + let element_to_remove = ScalarValue::Utf8(Some("value_1".to_string())); + let args = create_args(list_array.clone(), element_to_remove.clone()); + + group.bench_with_input( + BenchmarkId::new("remove", array_size), + 
&array_size, + |b, _| { + let udf = ArrayRemove::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: vec![ + Field::new("arr", list_array.data_type().clone(), false) + .into(), + Field::new("el", DataType::Utf8, false).into(), + ], + number_rows: NUM_ROWS, + return_field: Field::new( + "result", + list_array.data_type().clone(), + false, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + + group.finish(); +} + +fn bench_array_remove_binary(c: &mut Criterion) { + let mut group = c.benchmark_group("array_remove_binary"); + + for &array_size in ARRAY_SIZES { + let list_array = create_binary_list_array(NUM_ROWS, array_size, NULL_DENSITY); + let element_to_remove = ScalarValue::Binary(Some(b"value_1".to_vec())); + let args = create_args(list_array.clone(), element_to_remove.clone()); + + group.bench_with_input( + BenchmarkId::new("remove", array_size), + &array_size, + |b, _| { + let udf = ArrayRemove::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: vec![ + Field::new("arr", list_array.data_type().clone(), false) + .into(), + Field::new("el", DataType::Binary, false).into(), + ], + number_rows: NUM_ROWS, + return_field: Field::new( + "result", + list_array.data_type().clone(), + false, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + + group.finish(); +} + +fn bench_array_remove_boolean(c: &mut Criterion) { + let mut group = c.benchmark_group("array_remove_boolean"); + + for &array_size in ARRAY_SIZES { + let list_array = create_boolean_list_array(NUM_ROWS, array_size, NULL_DENSITY); + let element_to_remove = ScalarValue::Boolean(Some(true)); + let args = create_args(list_array.clone(), element_to_remove.clone()); + + group.bench_with_input( + BenchmarkId::new("remove", array_size), + &array_size, + |b, _| { + let udf = 
ArrayRemove::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: vec![ + Field::new("arr", list_array.data_type().clone(), false) + .into(), + Field::new("el", DataType::Boolean, false).into(), + ], + number_rows: NUM_ROWS, + return_field: Field::new( + "result", + list_array.data_type().clone(), + false, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + + group.finish(); +} + +fn bench_array_remove_decimal64(c: &mut Criterion) { + let mut group = c.benchmark_group("array_remove_decimal64"); + + for &array_size in ARRAY_SIZES { + let list_array = create_decimal64_list_array(NUM_ROWS, array_size, NULL_DENSITY); + let element_to_remove = ScalarValue::Decimal128(Some(100_i128), 10, 2); + let args = create_args(list_array.clone(), element_to_remove.clone()); + + group.bench_with_input( + BenchmarkId::new("remove", array_size), + &array_size, + |b, _| { + let udf = ArrayRemove::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: vec![ + Field::new("arr", list_array.data_type().clone(), false) + .into(), + Field::new("el", DataType::Decimal128(10, 2), false) + .into(), + ], + number_rows: NUM_ROWS, + return_field: Field::new( + "result", + list_array.data_type().clone(), + false, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + + group.finish(); +} + +fn bench_array_remove_fixed_size_binary(c: &mut Criterion) { + let mut group = c.benchmark_group("array_remove_fixed_size_binary"); + + for &array_size in ARRAY_SIZES { + let list_array = + create_fixed_size_binary_list_array(NUM_ROWS, array_size, NULL_DENSITY); + let element_to_remove = ScalarValue::FixedSizeBinary(16, Some(vec![1u8; 16])); + let args = create_args(list_array.clone(), element_to_remove.clone()); + + group.bench_with_input( + BenchmarkId::new("remove", 
array_size), + &array_size, + |b, _| { + let udf = ArrayRemove::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: vec![ + Field::new("arr", list_array.data_type().clone(), false) + .into(), + Field::new("el", DataType::FixedSizeBinary(16), false) + .into(), + ], + number_rows: NUM_ROWS, + return_field: Field::new( + "result", + list_array.data_type().clone(), + false, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + + group.finish(); +} + +fn create_args(list_array: ArrayRef, element: ScalarValue) -> Vec { + vec![ + ColumnarValue::Array(list_array), + ColumnarValue::Scalar(element), + ] +} + +fn create_int64_list_array( + num_rows: usize, + array_size: usize, + null_density: f64, +) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let values = (0..num_rows * array_size) + .map(|_| { + if rng.random::() < null_density { + None + } else { + Some(rng.random_range(0..array_size as i64)) + } + }) + .collect::(); + let offsets = (0..=num_rows) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( + ListArray::try_new( + Arc::new(Field::new("item", DataType::Int64, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} + +fn create_f64_list_array( + num_rows: usize, + array_size: usize, + null_density: f64, +) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let values = (0..num_rows * array_size) + .map(|_| { + if rng.random::() < null_density { + None + } else { + Some(rng.random_range(0.0..array_size as f64)) + } + }) + .collect::(); + let offsets = (0..=num_rows) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( + ListArray::try_new( + Arc::new(Field::new("item", DataType::Float64, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} + +fn create_string_list_array( + num_rows: usize, + array_size: usize, + 
null_density: f64, +) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let values = (0..num_rows * array_size) + .map(|_| { + if rng.random::() < null_density { + None + } else { + let idx = rng.random_range(0..array_size); + Some(format!("value_{idx}")) + } + }) + .collect::(); + let offsets = (0..=num_rows) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( + ListArray::try_new( + Arc::new(Field::new("item", DataType::Utf8, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} + +fn create_binary_list_array( + num_rows: usize, + array_size: usize, + null_density: f64, +) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let values = (0..num_rows * array_size) + .map(|_| { + if rng.random::() < null_density { + None + } else { + let idx = rng.random_range(0..array_size); + Some(format!("value_{idx}").into_bytes()) + } + }) + .collect::(); + let offsets = (0..=num_rows) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( + ListArray::try_new( + Arc::new(Field::new("item", DataType::Binary, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} + +fn create_boolean_list_array( + num_rows: usize, + array_size: usize, + null_density: f64, +) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let values = (0..num_rows * array_size) + .map(|_| { + if rng.random::() < null_density { + None + } else { + Some(rng.random::()) + } + }) + .collect::(); + let offsets = (0..=num_rows) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( + ListArray::try_new( + Arc::new(Field::new("item", DataType::Boolean, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} + +fn create_decimal64_list_array( + num_rows: usize, + array_size: usize, + null_density: f64, +) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let values = (0..num_rows * array_size) + .map(|_| { + if 
rng.random::() < null_density { + None + } else { + Some(rng.random_range(0..array_size) as i128 * 100) + } + }) + .collect::() + .with_precision_and_scale(10, 2) + .unwrap(); + let offsets = (0..=num_rows) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( + ListArray::try_new( + Arc::new(Field::new("item", DataType::Decimal128(10, 2), true)), + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} + +fn create_fixed_size_binary_list_array( + num_rows: usize, + array_size: usize, + null_density: f64, +) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let mut buffer = Vec::with_capacity(num_rows * array_size * 16); + let mut null_buffer = Vec::with_capacity(num_rows * array_size); + for _ in 0..num_rows * array_size { + if rng.random::() < null_density { + null_buffer.push(false); + buffer.extend_from_slice(&[0u8; 16]); + } else { + null_buffer.push(true); + let mut bytes = [0u8; 16]; + rng.fill(&mut bytes); + buffer.extend_from_slice(&bytes); + } + } + let nulls = arrow::buffer::NullBuffer::from_iter(null_buffer.iter().copied()); + let values = FixedSizeBinaryArray::new(16, buffer.into(), Some(nulls)); + let offsets = (0..=num_rows) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( + ListArray::try_new( + Arc::new(Field::new("item", DataType::FixedSizeBinary(16), true)), + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions-nested/benches/array_repeat.rs b/datafusion/functions-nested/benches/array_repeat.rs new file mode 100644 index 000000000000..0ce8db00ceb8 --- /dev/null +++ b/datafusion/functions-nested/benches/array_repeat.rs @@ -0,0 +1,476 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, BooleanArray, Float64Array, Int64Array, ListArray}; +use arrow::buffer::OffsetBuffer; +use arrow::datatypes::{DataType, Field}; +use criterion::{ + criterion_group, criterion_main, {BenchmarkId, Criterion}, +}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_functions_nested::repeat::ArrayRepeat; +use rand::Rng; +use rand::SeedableRng; +use rand::rngs::StdRng; +use std::hint::black_box; +use std::sync::Arc; + +const NUM_ROWS: &[usize] = &[100, 1000, 10000]; +const REPEAT_COUNTS: &[u64] = &[5, 50]; +const SEED: u64 = 42; +const NULL_DENSITY: f64 = 0.1; + +fn criterion_benchmark(c: &mut Criterion) { + // Test array_repeat with different element types + bench_array_repeat_int64(c); + bench_array_repeat_string(c); + bench_array_repeat_float64(c); + bench_array_repeat_boolean(c); + + // Test array_repeat with list element (nested arrays) + bench_array_repeat_nested_int64_list(c); + bench_array_repeat_nested_string_list(c); +} + +fn bench_array_repeat_int64(c: &mut Criterion) { + let mut group = c.benchmark_group("array_repeat_int64"); + + for &num_rows in NUM_ROWS { + let element_array = create_int64_array(num_rows, 
NULL_DENSITY); + + for &repeat_count in REPEAT_COUNTS { + let args = vec![ + ColumnarValue::Array(element_array.clone()), + ColumnarValue::Scalar(ScalarValue::from(repeat_count)), + ]; + + group.bench_with_input( + BenchmarkId::new(format!("repeat_{repeat_count}_count"), num_rows), + &num_rows, + |b, _| { + let udf = ArrayRepeat::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: vec![ + Field::new("element", DataType::Int64, false).into(), + Field::new("count", DataType::UInt64, false).into(), + ], + number_rows: num_rows, + return_field: Field::new( + "result", + DataType::List(Arc::new(Field::new_list_field( + DataType::Int64, + true, + ))), + false, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + } + + group.finish(); +} + +fn bench_array_repeat_string(c: &mut Criterion) { + let mut group = c.benchmark_group("array_repeat_string"); + + for &num_rows in NUM_ROWS { + let element_array = create_string_array(num_rows, NULL_DENSITY); + + for &repeat_count in REPEAT_COUNTS { + let args = vec![ + ColumnarValue::Array(element_array.clone()), + ColumnarValue::Scalar(ScalarValue::from(repeat_count)), + ]; + + group.bench_with_input( + BenchmarkId::new(format!("repeat_{repeat_count}_count"), num_rows), + &num_rows, + |b, _| { + let udf = ArrayRepeat::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: vec![ + Field::new("element", DataType::Utf8, false).into(), + Field::new("count", DataType::UInt64, false).into(), + ], + number_rows: num_rows, + return_field: Field::new( + "result", + DataType::List(Arc::new(Field::new_list_field( + DataType::Utf8, + true, + ))), + false, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + } + + group.finish(); +} + +fn bench_array_repeat_nested_int64_list(c: &mut Criterion) { + let mut group 
= c.benchmark_group("array_repeat_nested_int64"); + + for &num_rows in NUM_ROWS { + let list_array = create_int64_list_array(num_rows, 5, NULL_DENSITY); + + for &repeat_count in REPEAT_COUNTS { + let args = vec![ + ColumnarValue::Array(list_array.clone()), + ColumnarValue::Scalar(ScalarValue::from(repeat_count)), + ]; + + group.bench_with_input( + BenchmarkId::new(format!("repeat_{repeat_count}_count"), num_rows), + &num_rows, + |b, _| { + let udf = ArrayRepeat::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: vec![ + Field::new( + "element", + list_array.data_type().clone(), + false, + ) + .into(), + Field::new("count", DataType::UInt64, false).into(), + ], + number_rows: num_rows, + return_field: Field::new( + "result", + DataType::List(Arc::new(Field::new_list_field( + list_array.data_type().clone(), + true, + ))), + false, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + } + + group.finish(); +} + +fn bench_array_repeat_float64(c: &mut Criterion) { + let mut group = c.benchmark_group("array_repeat_float64"); + + for &num_rows in NUM_ROWS { + let element_array = create_float64_array(num_rows, NULL_DENSITY); + + for &repeat_count in REPEAT_COUNTS { + let args = vec![ + ColumnarValue::Array(element_array.clone()), + ColumnarValue::Scalar(ScalarValue::from(repeat_count)), + ]; + + group.bench_with_input( + BenchmarkId::new(format!("repeat_{repeat_count}_count"), num_rows), + &num_rows, + |b, _| { + let udf = ArrayRepeat::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: vec![ + Field::new("element", DataType::Float64, false) + .into(), + Field::new("count", DataType::UInt64, false).into(), + ], + number_rows: num_rows, + return_field: Field::new( + "result", + DataType::List(Arc::new(Field::new_list_field( + DataType::Float64, + true, + ))), + false, + ) + .into(), + 
config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + } + + group.finish(); +} + +fn bench_array_repeat_boolean(c: &mut Criterion) { + let mut group = c.benchmark_group("array_repeat_boolean"); + + for &num_rows in NUM_ROWS { + let element_array = create_boolean_array(num_rows, NULL_DENSITY); + + for &repeat_count in REPEAT_COUNTS { + let args = vec![ + ColumnarValue::Array(element_array.clone()), + ColumnarValue::Scalar(ScalarValue::from(repeat_count)), + ]; + + group.bench_with_input( + BenchmarkId::new(format!("repeat_{repeat_count}_count"), num_rows), + &num_rows, + |b, _| { + let udf = ArrayRepeat::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: vec![ + Field::new("element", DataType::Boolean, false) + .into(), + Field::new("count", DataType::UInt64, false).into(), + ], + number_rows: num_rows, + return_field: Field::new( + "result", + DataType::List(Arc::new(Field::new_list_field( + DataType::Boolean, + true, + ))), + false, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + } + + group.finish(); +} + +fn bench_array_repeat_nested_string_list(c: &mut Criterion) { + let mut group = c.benchmark_group("array_repeat_nested_string"); + + for &num_rows in NUM_ROWS { + let list_array = create_string_list_array(num_rows, 5, NULL_DENSITY); + + for &repeat_count in REPEAT_COUNTS { + let args = vec![ + ColumnarValue::Array(list_array.clone()), + ColumnarValue::Scalar(ScalarValue::from(repeat_count)), + ]; + + group.bench_with_input( + BenchmarkId::new(format!("repeat_{repeat_count}_count"), num_rows), + &num_rows, + |b, _| { + let udf = ArrayRepeat::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: vec![ + Field::new( + "element", + list_array.data_type().clone(), + false, + ) + .into(), + Field::new("count", DataType::UInt64, false).into(), 
+ ], + number_rows: num_rows, + return_field: Field::new( + "result", + DataType::List(Arc::new(Field::new_list_field( + list_array.data_type().clone(), + true, + ))), + false, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + } + + group.finish(); +} + +fn create_int64_array(num_rows: usize, null_density: f64) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let values = (0..num_rows) + .map(|_| { + if rng.random::() < null_density { + None + } else { + Some(rng.random_range(0..1000)) + } + }) + .collect::(); + + Arc::new(values) +} + +fn create_string_array(num_rows: usize, null_density: f64) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + use arrow::array::StringArray; + + let values = (0..num_rows) + .map(|_| { + if rng.random::() < null_density { + None + } else { + Some(format!("value_{}", rng.random_range(0..100))) + } + }) + .collect::(); + + Arc::new(values) +} + +fn create_int64_list_array( + num_rows: usize, + array_size: usize, + null_density: f64, +) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let values = (0..num_rows * array_size) + .map(|_| { + if rng.random::() < null_density { + None + } else { + Some(rng.random_range(0..1000)) + } + }) + .collect::(); + let offsets = (0..=num_rows) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( + ListArray::try_new( + Arc::new(Field::new("item", DataType::Int64, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} + +fn create_float64_array(num_rows: usize, null_density: f64) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let values = (0..num_rows) + .map(|_| { + if rng.random::() < null_density { + None + } else { + Some(rng.random_range(0.0..1000.0)) + } + }) + .collect::(); + + Arc::new(values) +} + +fn create_boolean_array(num_rows: usize, null_density: f64) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let values = 
(0..num_rows) + .map(|_| { + if rng.random::() < null_density { + None + } else { + Some(rng.random()) + } + }) + .collect::(); + + Arc::new(values) +} + +fn create_string_list_array( + num_rows: usize, + array_size: usize, + null_density: f64, +) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + use arrow::array::StringArray; + + let values = (0..num_rows * array_size) + .map(|_| { + if rng.random::() < null_density { + None + } else { + Some(format!("value_{}", rng.random_range(0..100))) + } + }) + .collect::(); + let offsets = (0..=num_rows) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( + ListArray::try_new( + Arc::new(Field::new("item", DataType::Utf8, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions-nested/benches/array_reverse.rs b/datafusion/functions-nested/benches/array_reverse.rs index 92a65128fe6b..0c3729618831 100644 --- a/datafusion/functions-nested/benches/array_reverse.rs +++ b/datafusion/functions-nested/benches/array_reverse.rs @@ -15,18 +15,14 @@ // specific language governing permissions and limitations // under the License. 
-#[macro_use] -extern crate criterion; -extern crate arrow; - use std::{hint::black_box, sync::Arc}; -use crate::criterion::Criterion; use arrow::{ array::{ArrayRef, FixedSizeListArray, Int32Array, ListArray, ListViewArray}, buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}, datatypes::{DataType, Field}, }; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_functions_nested::reverse::array_reverse_inner; fn array_reverse(array: &ArrayRef) -> ArrayRef { diff --git a/datafusion/functions-nested/benches/array_set_ops.rs b/datafusion/functions-nested/benches/array_set_ops.rs new file mode 100644 index 000000000000..e3146921d7fe --- /dev/null +++ b/datafusion/functions-nested/benches/array_set_ops.rs @@ -0,0 +1,259 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::{ArrayRef, Int64Array, ListArray}; +use arrow::buffer::OffsetBuffer; +use arrow::datatypes::{DataType, Field}; +use criterion::{ + criterion_group, criterion_main, {BenchmarkId, Criterion}, +}; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_functions_nested::set_ops::{ArrayDistinct, ArrayIntersect, ArrayUnion}; +use rand::SeedableRng; +use rand::prelude::SliceRandom; +use rand::rngs::StdRng; +use std::collections::HashSet; +use std::hint::black_box; +use std::sync::Arc; + +const NUM_ROWS: usize = 1000; +const ARRAY_SIZES: &[usize] = &[10, 50, 100]; +const SEED: u64 = 42; + +fn criterion_benchmark(c: &mut Criterion) { + bench_array_union(c); + bench_array_intersect(c); + bench_array_distinct(c); +} + +fn invoke_udf(udf: &impl ScalarUDFImpl, array1: &ArrayRef, array2: &ArrayRef) { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(array1.clone()), + ColumnarValue::Array(array2.clone()), + ], + arg_fields: vec![ + Field::new("arr1", array1.data_type().clone(), false).into(), + Field::new("arr2", array2.data_type().clone(), false).into(), + ], + number_rows: NUM_ROWS, + return_field: Field::new("result", array1.data_type().clone(), false).into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ); +} + +fn bench_array_union(c: &mut Criterion) { + let mut group = c.benchmark_group("array_union"); + let udf = ArrayUnion::new(); + + for (overlap_label, overlap_ratio) in &[("high_overlap", 0.8), ("low_overlap", 0.2)] { + for &array_size in ARRAY_SIZES { + let (array1, array2) = + create_arrays_with_overlap(NUM_ROWS, array_size, *overlap_ratio); + group.bench_with_input( + BenchmarkId::new(*overlap_label, array_size), + &array_size, + |b, _| b.iter(|| invoke_udf(&udf, &array1, &array2)), + ); + } + } + + group.finish(); +} + +fn bench_array_intersect(c: &mut Criterion) { + let mut group = 
c.benchmark_group("array_intersect"); + let udf = ArrayIntersect::new(); + + for (overlap_label, overlap_ratio) in &[("high_overlap", 0.8), ("low_overlap", 0.2)] { + for &array_size in ARRAY_SIZES { + let (array1, array2) = + create_arrays_with_overlap(NUM_ROWS, array_size, *overlap_ratio); + group.bench_with_input( + BenchmarkId::new(*overlap_label, array_size), + &array_size, + |b, _| b.iter(|| invoke_udf(&udf, &array1, &array2)), + ); + } + } + + group.finish(); +} + +fn bench_array_distinct(c: &mut Criterion) { + let mut group = c.benchmark_group("array_distinct"); + let udf = ArrayDistinct::new(); + + for (duplicate_label, duplicate_ratio) in + &[("high_duplicate", 0.8), ("low_duplicate", 0.2)] + { + for &array_size in ARRAY_SIZES { + let array = + create_array_with_duplicates(NUM_ROWS, array_size, *duplicate_ratio); + group.bench_with_input( + BenchmarkId::new(*duplicate_label, array_size), + &array_size, + |b, _| { + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: vec![ColumnarValue::Array(array.clone())], + arg_fields: vec![ + Field::new("arr", array.data_type().clone(), false) + .into(), + ], + number_rows: NUM_ROWS, + return_field: Field::new( + "result", + array.data_type().clone(), + false, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + } + + group.finish(); +} + +fn create_arrays_with_overlap( + num_rows: usize, + array_size: usize, + overlap_ratio: f64, +) -> (ArrayRef, ArrayRef) { + assert!((0.0..=1.0).contains(&overlap_ratio)); + let overlap_count = ((array_size as f64) * overlap_ratio).round() as usize; + + let mut rng = StdRng::seed_from_u64(SEED); + + let mut values1 = Vec::with_capacity(num_rows * array_size); + let mut values2 = Vec::with_capacity(num_rows * array_size); + + for row in 0..num_rows { + let base = (row as i64) * (array_size as i64) * 2; + + for i in 0..array_size { + values1.push(base + i as i64); + } + + let mut positions: Vec = 
(0..array_size).collect(); + positions.shuffle(&mut rng); + + let overlap_positions: HashSet<_> = + positions[..overlap_count].iter().copied().collect(); + + for i in 0..array_size { + if overlap_positions.contains(&i) { + values2.push(base + i as i64); + } else { + values2.push(base + array_size as i64 + i as i64); + } + } + } + + let values1 = Int64Array::from(values1); + let values2 = Int64Array::from(values2); + + let field = Arc::new(Field::new("item", DataType::Int64, true)); + + let offsets = (0..=num_rows) + .map(|i| (i * array_size) as i32) + .collect::>(); + + let array1 = Arc::new( + ListArray::try_new( + field.clone(), + OffsetBuffer::new(offsets.clone().into()), + Arc::new(values1), + None, + ) + .unwrap(), + ); + + let array2 = Arc::new( + ListArray::try_new( + field, + OffsetBuffer::new(offsets.into()), + Arc::new(values2), + None, + ) + .unwrap(), + ); + + (array1, array2) +} + +fn create_array_with_duplicates( + num_rows: usize, + array_size: usize, + duplicate_ratio: f64, +) -> ArrayRef { + assert!((0.0..=1.0).contains(&duplicate_ratio)); + let unique_count = ((array_size as f64) * (1.0 - duplicate_ratio)).round() as usize; + let duplicate_count = array_size - unique_count; + + let mut rng = StdRng::seed_from_u64(SEED); + let mut values = Vec::with_capacity(num_rows * array_size); + + for row in 0..num_rows { + let base = (row as i64) * (array_size as i64) * 2; + + // Add unique values first + for i in 0..unique_count { + values.push(base + i as i64); + } + + // Fill the rest with duplicates randomly picked from the unique values + let mut unique_indices: Vec = + (0..unique_count).map(|i| base + i as i64).collect(); + unique_indices.shuffle(&mut rng); + + for i in 0..duplicate_count { + values.push(unique_indices[i % unique_count]); + } + } + + let values = Int64Array::from(values); + let field = Arc::new(Field::new("item", DataType::Int64, true)); + + let offsets = (0..=num_rows) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( 
+ ListArray::try_new( + field, + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions-nested/benches/array_slice.rs b/datafusion/functions-nested/benches/array_slice.rs index 858e43899619..b95fe47575e5 100644 --- a/datafusion/functions-nested/benches/array_slice.rs +++ b/datafusion/functions-nested/benches/array_slice.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::{ Int64Array, ListArray, ListViewArray, NullBufferBuilder, PrimitiveArray, }; diff --git a/datafusion/functions-nested/benches/array_to_string.rs b/datafusion/functions-nested/benches/array_to_string.rs new file mode 100644 index 000000000000..286ed4eeb000 --- /dev/null +++ b/datafusion/functions-nested/benches/array_to_string.rs @@ -0,0 +1,188 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::{ArrayRef, Float64Array, Int64Array, ListArray, StringArray}; +use arrow::buffer::OffsetBuffer; +use arrow::datatypes::{DataType, Field}; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_functions_nested::string::ArrayToString; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::hint::black_box; +use std::sync::Arc; + +const NUM_ROWS: usize = 1000; +const ARRAY_SIZES: &[usize] = &[5, 20, 100]; +const NESTED_ARRAY_SIZE: usize = 3; +const SEED: u64 = 42; +const NULL_DENSITY: f64 = 0.1; + +fn criterion_benchmark(c: &mut Criterion) { + bench_array_to_string(c, "array_to_string_int64", create_int64_list_array); + bench_array_to_string(c, "array_to_string_float64", create_float64_list_array); + bench_array_to_string(c, "array_to_string_string", create_string_list_array); + bench_array_to_string( + c, + "array_to_string_nested_int64", + create_nested_int64_list_array, + ); +} + +fn bench_array_to_string( + c: &mut Criterion, + group_name: &str, + make_array: impl Fn(usize) -> ArrayRef, +) { + let mut group = c.benchmark_group(group_name); + + for &array_size in ARRAY_SIZES { + let list_array = make_array(array_size); + let args = vec![ + ColumnarValue::Array(list_array.clone()), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string()))), + ]; + let arg_fields = vec![ + Field::new("array", list_array.data_type().clone(), true).into(), + Field::new("delimiter", DataType::Utf8, false).into(), + ]; + + group.bench_with_input( + BenchmarkId::from_parameter(array_size), + &array_size, + |b, _| { + let udf = ArrayToString::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: Field::new("result", DataType::Utf8, true) + 
.into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + + group.finish(); +} + +fn create_int64_list_array(array_size: usize) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let values = (0..NUM_ROWS * array_size) + .map(|_| { + if rng.random::() < NULL_DENSITY { + None + } else { + Some(rng.random_range(0..1000)) + } + }) + .collect::(); + let offsets = (0..=NUM_ROWS) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( + ListArray::try_new( + Arc::new(Field::new("item", DataType::Int64, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} + +fn create_nested_int64_list_array(array_size: usize) -> ArrayRef { + let inner = create_int64_list_array(array_size); + let inner_rows = NUM_ROWS; + let outer_rows = inner_rows / NESTED_ARRAY_SIZE; + let offsets = (0..=outer_rows) + .map(|i| (i * NESTED_ARRAY_SIZE) as i32) + .collect::>(); + Arc::new( + ListArray::try_new( + Arc::new(Field::new("item", inner.data_type().clone(), true)), + OffsetBuffer::new(offsets.into()), + inner, + None, + ) + .unwrap(), + ) +} + +fn create_float64_list_array(array_size: usize) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let values = (0..NUM_ROWS * array_size) + .map(|_| { + if rng.random::() < NULL_DENSITY { + None + } else { + Some(rng.random_range(-1000.0..1000.0)) + } + }) + .collect::(); + let offsets = (0..=NUM_ROWS) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( + ListArray::try_new( + Arc::new(Field::new("item", DataType::Float64, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} + +fn create_string_list_array(array_size: usize) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let values = (0..NUM_ROWS * array_size) + .map(|_| { + if rng.random::() < NULL_DENSITY { + None + } else { + Some(format!("value_{}", rng.random_range(0..100))) + } + }) + .collect::(); + let 
offsets = (0..=NUM_ROWS) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( + ListArray::try_new( + Arc::new(Field::new("item", DataType::Utf8, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions-nested/benches/map.rs b/datafusion/functions-nested/benches/map.rs index 75b4045a193d..e50c4659b17c 100644 --- a/datafusion/functions-nested/benches/map.rs +++ b/datafusion/functions-nested/benches/map.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::{Int32Array, ListArray, StringArray}; use arrow::buffer::{OffsetBuffer, ScalarBuffer}; use arrow::datatypes::{DataType, Field}; diff --git a/datafusion/functions-nested/src/array_has.rs b/datafusion/functions-nested/src/array_has.rs index 54b94abafb99..76cf786c954d 100644 --- a/datafusion/functions-nested/src/array_has.rs +++ b/datafusion/functions-nested/src/array_has.rs @@ -17,7 +17,10 @@ //! [`ScalarUDFImpl`] definitions for array_has, array_has_all and array_has_any functions. 
-use arrow::array::{Array, ArrayRef, BooleanArray, Datum, Scalar}; +use arrow::array::{ + Array, ArrayRef, AsArray, BooleanArray, BooleanBufferBuilder, Datum, Scalar, + StringArrayType, +}; use arrow::buffer::BooleanBuffer; use arrow::datatypes::DataType; use arrow::row::{RowConverter, Rows, SortField}; @@ -37,6 +40,7 @@ use itertools::Itertools; use crate::make_array::make_array_udf; use crate::utils::make_scalar_function; +use hashbrown::HashSet; use std::any::Any; use std::sync::Arc; @@ -55,7 +59,7 @@ make_udf_expr_and_func!(ArrayHasAll, ); make_udf_expr_and_func!(ArrayHasAny, array_has_any, - haystack_array needle_array, // arg names + first_array second_array, // arg names "returns true if at least one element of the second array appears in the first array; otherwise, it returns false.", // doc array_has_any_udf // internal function name ); @@ -125,7 +129,7 @@ impl ScalarUDFImpl for ArrayHas { fn simplify( &self, mut args: Vec, - _info: &dyn datafusion_expr::simplify::SimplifyInfo, + _info: &datafusion_expr::simplify::SimplifyContext, ) -> Result { let [haystack, needle] = take_function_args(self.name(), &mut args)?; @@ -262,7 +266,7 @@ impl<'a> TryFrom<&'a dyn Array> for ArrayWrapper<'a> { DataType::FixedSizeList(_, _) => Ok(ArrayWrapper::FixedSizeList( as_fixed_size_list_array(value)?, )), - _ => exec_err!("array_has does not support type '{:?}'.", value.data_type()), + _ => exec_err!("array_has does not support type '{}'.", value.data_type()), } } } @@ -303,10 +307,8 @@ impl<'a> ArrayWrapper<'a> { fn offsets(&self) -> Box + 'a> { match self { ArrayWrapper::FixedSizeList(arr) => { - let offsets = (0..=arr.len()) - .step_by(arr.value_length() as usize) - .collect::>(); - Box::new(offsets.into_iter()) + let value_length = arr.value_length() as usize; + Box::new((0..=arr.len()).map(move |i| i * value_length)) } ArrayWrapper::List(arr) => { Box::new(arr.offsets().iter().map(|o| (*o) as usize)) @@ -316,6 +318,14 @@ impl<'a> ArrayWrapper<'a> { } } } + + fn 
nulls(&self) -> Option<&arrow::buffer::NullBuffer> { + match self { + ArrayWrapper::FixedSizeList(arr) => arr.nulls(), + ArrayWrapper::List(arr) => arr.nulls(), + ArrayWrapper::LargeList(arr) => arr.nulls(), + } + } } fn array_has_dispatch_for_array<'a>( @@ -342,8 +352,6 @@ fn array_has_dispatch_for_scalar( haystack: ArrayWrapper<'_>, needle: &dyn Datum, ) -> Result { - let values = haystack.values(); - let is_nested = values.data_type().is_nested(); // If first argument is empty list (second argument is non-null), return false // i.e. array_has([], non-null element) -> false if haystack.len() == 0 { @@ -352,37 +360,62 @@ fn array_has_dispatch_for_scalar( None, ))); } - let eq_array = compare_with_eq(values, needle, is_nested)?; - let mut final_contained = vec![None; haystack.len()]; - // Check validity buffer to distinguish between null and empty arrays + // For sliced ListArrays, values() returns the full underlying array but + // only elements between the first and last offset are visible. + let offsets: Vec = haystack.offsets().collect(); + let first_offset = offsets[0]; + let visible_values = haystack + .values() + .slice(first_offset, offsets[offsets.len() - 1] - first_offset); + + let is_nested = visible_values.data_type().is_nested(); + let eq_array = compare_with_eq(&visible_values, needle, is_nested)?; + + // When a haystack element is null, `eq()` returns null (not false). + // In Arrow, a null BooleanArray entry has validity=0 but an + // undefined value bit that may happen to be 1. Since set_indices() + // operates on the raw value buffer and ignores validity, we AND the + // values with the validity bitmap to clear any undefined bits at + // null positions. This ensures set_indices() only yields positions + // where the comparison genuinely returned true. 
+ let eq_bits = match eq_array.nulls() { + Some(nulls) => eq_array.values() & nulls.inner(), + None => eq_array.values().clone(), + }; + let validity = match &haystack { ArrayWrapper::FixedSizeList(arr) => arr.nulls(), ArrayWrapper::List(arr) => arr.nulls(), ArrayWrapper::LargeList(arr) => arr.nulls(), }; + let mut matches = eq_bits.set_indices().peekable(); + let mut result = BooleanBufferBuilder::new(haystack.len()); + result.append_n(haystack.len(), false); - for (i, (start, end)) in haystack.offsets().tuple_windows().enumerate() { - let length = end - start; + // Match positions are relative to visible_values (0-based), so + // subtract first_offset from each offset when comparing. + for (i, window) in offsets.windows(2).enumerate() { + let end = window[1] - first_offset; - // Check if the array at this position is null - if let Some(validity_buffer) = validity - && !validity_buffer.is_valid(i) - { - final_contained[i] = None; // null array -> null result - continue; + let has_match = matches.peek().is_some_and(|&p| p < end); + + // Advance past all match positions in this row's range. + while matches.peek().is_some_and(|&p| p < end) { + matches.next(); } - // For non-null arrays: length is 0 for empty arrays - if length == 0 { - final_contained[i] = Some(false); // empty array -> false - } else { - let sliced_array = eq_array.slice(start, length); - final_contained[i] = Some(sliced_array.true_count() > 0); + if has_match && validity.is_none_or(|v| v.is_valid(i)) { + result.set_bit(i, true); } } - Ok(Arc::new(BooleanArray::from(final_contained))) + // A null haystack row always produces a null output, so we can + // reuse the haystack's null buffer directly. 
+ Ok(Arc::new(BooleanArray::new( + result.finish(), + validity.cloned(), + ))) } fn array_has_all_inner(args: &[ArrayRef]) -> Result { @@ -476,6 +509,218 @@ fn array_has_any_inner(args: &[ArrayRef]) -> Result { array_has_all_and_any_inner(args, ComparisonType::Any) } +/// Fast path for `array_has_any` when exactly one argument is a scalar. +fn array_has_any_with_scalar( + columnar_arg: &ColumnarValue, + scalar_arg: &ScalarValue, +) -> Result { + if scalar_arg.is_null() { + return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))); + } + + // Convert the scalar to a 1-element ListArray, then extract the inner values + let scalar_array = scalar_arg.to_array_of_size(1)?; + let scalar_list: ArrayWrapper = scalar_array.as_ref().try_into()?; + let offsets: Vec = scalar_list.offsets().collect(); + let scalar_values = scalar_list + .values() + .slice(offsets[0], offsets[1] - offsets[0]); + + // If scalar list is empty, result is always false + if scalar_values.is_empty() { + return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(false)))); + } + + match scalar_values.data_type() { + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => { + array_has_any_with_scalar_string(columnar_arg, &scalar_values) + } + _ => array_has_any_with_scalar_general(columnar_arg, &scalar_values), + } +} + +/// When the scalar argument has more elements than this, the scalar fast path +/// builds a HashSet for O(1) lookups. At or below this threshold, it falls +/// back to a linear scan, since hashing every columnar element is more +/// expensive than a linear scan over a short array. +const SCALAR_SMALL_THRESHOLD: usize = 8; + +/// String-specialized scalar fast path for `array_has_any`. 
+fn array_has_any_with_scalar_string( + columnar_arg: &ColumnarValue, + scalar_values: &ArrayRef, +) -> Result { + let (col_arr, is_scalar_output) = match columnar_arg { + ColumnarValue::Array(arr) => (Arc::clone(arr), false), + ColumnarValue::Scalar(s) => (s.to_array_of_size(1)?, true), + }; + + let col_list: ArrayWrapper = col_arr.as_ref().try_into()?; + let col_values = col_list.values(); + let col_offsets: Vec = col_list.offsets().collect(); + let col_nulls = col_list.nulls(); + + let scalar_lookup = ScalarStringLookup::new(scalar_values); + let has_null_scalar = scalar_values.null_count() > 0; + + let result = match col_values.data_type() { + DataType::Utf8 => array_has_any_string_inner( + col_values.as_string::(), + &col_offsets, + col_nulls, + has_null_scalar, + &scalar_lookup, + ), + DataType::LargeUtf8 => array_has_any_string_inner( + col_values.as_string::(), + &col_offsets, + col_nulls, + has_null_scalar, + &scalar_lookup, + ), + DataType::Utf8View => array_has_any_string_inner( + col_values.as_string_view(), + &col_offsets, + col_nulls, + has_null_scalar, + &scalar_lookup, + ), + _ => unreachable!("array_has_any_with_scalar_string called with non-string type"), + }; + + if is_scalar_output { + Ok(ColumnarValue::Scalar(ScalarValue::try_from_array( + &result, 0, + )?)) + } else { + Ok(ColumnarValue::Array(result)) + } +} + +/// Pre-computed lookup structure for the scalar string fastpath. +enum ScalarStringLookup<'a> { + /// Large scalar: HashSet for O(1) lookups. + Set(HashSet<&'a str>), + /// Small scalar: Vec for linear scan. 
+ List(Vec>), +} + +impl<'a> ScalarStringLookup<'a> { + fn new(scalar_values: &'a ArrayRef) -> Self { + let strings = string_array_to_vec(scalar_values.as_ref()); + if strings.len() > SCALAR_SMALL_THRESHOLD { + ScalarStringLookup::Set(strings.into_iter().flatten().collect()) + } else { + ScalarStringLookup::List(strings) + } + } + + fn contains(&self, value: &str) -> bool { + match self { + ScalarStringLookup::Set(set) => set.contains(value), + ScalarStringLookup::List(list) => list.contains(&Some(value)), + } + } +} + +/// Inner implementation of the string scalar fast path, generic over string +/// array type to allow direct element access by index. +fn array_has_any_string_inner<'a, C: StringArrayType<'a> + Copy>( + col_strings: C, + col_offsets: &[usize], + col_nulls: Option<&arrow::buffer::NullBuffer>, + has_null_scalar: bool, + scalar_lookup: &ScalarStringLookup<'_>, +) -> ArrayRef { + let num_rows = col_offsets.len() - 1; + let mut builder = BooleanArray::builder(num_rows); + + for i in 0..num_rows { + if col_nulls.is_some_and(|v| v.is_null(i)) { + builder.append_null(); + continue; + } + let start = col_offsets[i]; + let end = col_offsets[i + 1]; + let found = (start..end).any(|j| { + if col_strings.is_null(j) { + has_null_scalar + } else { + scalar_lookup.contains(col_strings.value(j)) + } + }); + builder.append_value(found); + } + + Arc::new(builder.finish()) +} + +/// General scalar fast path for `array_has_any`, using RowConverter for +/// type-erased comparison. 
+fn array_has_any_with_scalar_general( + columnar_arg: &ColumnarValue, + scalar_values: &ArrayRef, +) -> Result { + let converter = + RowConverter::new(vec![SortField::new(scalar_values.data_type().clone())])?; + let scalar_rows = converter.convert_columns(&[Arc::clone(scalar_values)])?; + + let (col_arr, is_scalar_output) = match columnar_arg { + ColumnarValue::Array(arr) => (Arc::clone(arr), false), + ColumnarValue::Scalar(s) => (s.to_array_of_size(1)?, true), + }; + + let col_list: ArrayWrapper = col_arr.as_ref().try_into()?; + let col_rows = converter.convert_columns(&[Arc::clone(col_list.values())])?; + let col_offsets: Vec = col_list.offsets().collect(); + let col_nulls = col_list.nulls(); + + let mut builder = BooleanArray::builder(col_list.len()); + let num_scalar = scalar_rows.num_rows(); + + if num_scalar > SCALAR_SMALL_THRESHOLD { + // Large scalar: build HashSet for O(1) lookups + let scalar_set: HashSet> = (0..num_scalar) + .map(|i| Box::from(scalar_rows.row(i).as_ref())) + .collect(); + + for i in 0..col_list.len() { + if col_nulls.is_some_and(|v| v.is_null(i)) { + builder.append_null(); + continue; + } + let start = col_offsets[i]; + let end = col_offsets[i + 1]; + let found = + (start..end).any(|j| scalar_set.contains(col_rows.row(j).as_ref())); + builder.append_value(found); + } + } else { + // Small scalar: linear scan avoids HashSet hashing overhead + for i in 0..col_list.len() { + if col_nulls.is_some_and(|v| v.is_null(i)) { + builder.append_null(); + continue; + } + let start = col_offsets[i]; + let end = col_offsets[i + 1]; + let found = (start..end) + .any(|j| (0..num_scalar).any(|k| col_rows.row(j) == scalar_rows.row(k))); + builder.append_value(found); + } + } + + let result: ArrayRef = Arc::new(builder.finish()); + + if is_scalar_output { + Ok(ColumnarValue::Scalar(ScalarValue::try_from_array( + &result, 0, + )?)) + } else { + Ok(ColumnarValue::Array(result)) + } +} + #[user_doc( doc_section(label = "Array Functions"), description = 
"Returns true if all elements of sub-array exist in array.", @@ -552,8 +797,8 @@ impl ScalarUDFImpl for ArrayHasAll { #[user_doc( doc_section(label = "Array Functions"), - description = "Returns true if any elements exist in both arrays.", - syntax_example = "array_has_any(array, sub-array)", + description = "Returns true if the arrays have any elements in common.", + syntax_example = "array_has_any(array1, array2)", sql_example = r#"```sql > select array_has_any([1, 2, 3], [3, 4]); +------------------------------------------+ @@ -563,11 +808,11 @@ impl ScalarUDFImpl for ArrayHasAll { +------------------------------------------+ ```"#, argument( - name = "array", + name = "array1", description = "Array expression. Can be a constant, column, or function, and any combination of array operators." ), argument( - name = "sub-array", + name = "array2", description = "Array expression. Can be a constant, column, or function, and any combination of array operators." ) )] @@ -612,7 +857,15 @@ impl ScalarUDFImpl for ArrayHasAny { &self, args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_has_any_inner)(&args.args) + let [first_arg, second_arg] = take_function_args(self.name(), &args.args)?; + + // If either argument is scalar, use the fast path. 
+ match (&first_arg, &second_arg) { + (cv, ColumnarValue::Scalar(scalar)) | (ColumnarValue::Scalar(scalar), cv) => { + array_has_any_with_scalar(cv, scalar) + } + _ => make_scalar_function(array_has_any_inner)(&args.args), + } } fn aliases(&self) -> &[String] { @@ -684,8 +937,8 @@ mod tests { utils::SingleRowListArrayBuilder, }; use datafusion_expr::{ - ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDFImpl, col, - execution_props::ExecutionProps, lit, simplify::ExprSimplifyResult, + ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDFImpl, col, lit, + simplify::ExprSimplifyResult, }; use crate::expr_fn::make_array; @@ -701,8 +954,7 @@ mod tests { .build_list_scalar()); let needle = col("c"); - let props = ExecutionProps::new(); - let context = datafusion_expr::simplify::SimplifyContext::new(&props); + let context = datafusion_expr::simplify::SimplifyContext::default(); let Ok(ExprSimplifyResult::Simplified(Expr::InList(in_list))) = ArrayHas::new().simplify(vec![haystack, needle.clone()], &context) @@ -725,8 +977,7 @@ mod tests { let haystack = make_array(vec![lit(1), lit(2), lit(3)]); let needle = col("c"); - let props = ExecutionProps::new(); - let context = datafusion_expr::simplify::SimplifyContext::new(&props); + let context = datafusion_expr::simplify::SimplifyContext::default(); let Ok(ExprSimplifyResult::Simplified(Expr::InList(in_list))) = ArrayHas::new().simplify(vec![haystack, needle.clone()], &context) @@ -749,8 +1000,7 @@ mod tests { let haystack = Expr::Literal(ScalarValue::Null, None); let needle = col("c"); - let props = ExecutionProps::new(); - let context = datafusion_expr::simplify::SimplifyContext::new(&props); + let context = datafusion_expr::simplify::SimplifyContext::default(); let Ok(ExprSimplifyResult::Simplified(simplified)) = ArrayHas::new().simplify(vec![haystack, needle], &context) else { @@ -767,8 +1017,7 @@ mod tests { let haystack = Expr::Literal(ScalarValue::List(Arc::new(haystack)), None); let needle = col("c"); - let props = 
ExecutionProps::new(); - let context = datafusion_expr::simplify::SimplifyContext::new(&props); + let context = datafusion_expr::simplify::SimplifyContext::default(); let Ok(ExprSimplifyResult::Simplified(simplified)) = ArrayHas::new().simplify(vec![haystack, needle], &context) else { @@ -783,8 +1032,7 @@ mod tests { let haystack = col("c1"); let needle = col("c2"); - let props = ExecutionProps::new(); - let context = datafusion_expr::simplify::SimplifyContext::new(&props); + let context = datafusion_expr::simplify::SimplifyContext::default(); let Ok(ExprSimplifyResult::Original(args)) = ArrayHas::new().simplify(vec![haystack, needle.clone()], &context) @@ -830,6 +1078,52 @@ mod tests { Ok(()) } + #[test] + fn test_array_has_sliced_list() -> Result<(), DataFusionError> { + // [[10, 20], [30, 40], [50, 60], [70, 80]] → slice(1,2) → [[30, 40], [50, 60]] + let list = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(10), Some(20)]), + Some(vec![Some(30), Some(40)]), + Some(vec![Some(50), Some(60)]), + Some(vec![Some(70), Some(80)]), + ]); + let sliced = list.slice(1, 2); + let haystack_field = + Arc::new(Field::new("haystack", sliced.data_type().clone(), true)); + let needle_field = Arc::new(Field::new("needle", DataType::Int32, true)); + let return_field = Arc::new(Field::new("return", DataType::Boolean, true)); + + // Search for elements that exist only in sliced-away rows: + // 10 is in the prefix row, 70 is in the suffix row. + let invoke = |needle: i32| -> Result { + ArrayHas::new() + .invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new(sliced.clone())), + ColumnarValue::Scalar(ScalarValue::Int32(Some(needle))), + ], + arg_fields: vec![ + Arc::clone(&haystack_field), + Arc::clone(&needle_field), + ], + number_rows: 2, + return_field: Arc::clone(&return_field), + config_options: Arc::new(ConfigOptions::default()), + })? 
+ .into_array(2) + }; + + let output = invoke(10)?.as_boolean().clone(); + assert!(!output.value(0)); + assert!(!output.value(1)); + + let output = invoke(70)?.as_boolean().clone(); + assert!(!output.value(0)); + assert!(!output.value(1)); + + Ok(()) + } + #[test] fn test_array_has_list_null_haystack() -> Result<(), DataFusionError> { let haystack_field = Arc::new(Field::new("haystack", DataType::Null, true)); diff --git a/datafusion/functions-nested/src/arrays_zip.rs b/datafusion/functions-nested/src/arrays_zip.rs new file mode 100644 index 000000000000..2ac30d07046e --- /dev/null +++ b/datafusion/functions-nested/src/arrays_zip.rs @@ -0,0 +1,336 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for arrays_zip function. 
+ +use crate::utils::make_scalar_function; +use arrow::array::{ + Array, ArrayRef, Capacities, ListArray, MutableArrayData, StructArray, new_null_array, +}; +use arrow::buffer::{NullBuffer, OffsetBuffer}; +use arrow::datatypes::DataType::{FixedSizeList, LargeList, List, Null}; +use arrow::datatypes::{DataType, Field, Fields}; +use datafusion_common::cast::{ + as_fixed_size_list_array, as_large_list_array, as_list_array, +}; +use datafusion_common::{Result, exec_err}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; +use std::any::Any; +use std::sync::Arc; + +/// Type-erased view of a list column (works for both List and LargeList). +/// Stores the information needed to iterate rows without re-downcasting. +struct ListColumnView { + /// The flat values array backing this list column. + values: ArrayRef, + /// Pre-computed per-row start offsets (length = num_rows + 1). + offsets: Vec, + /// Pre-computed null bitmap: true means the row is null. + is_null: Vec, +} + +make_udf_expr_and_func!( + ArraysZip, + arrays_zip, + "combines multiple arrays into a single array of structs.", + arrays_zip_udf +); + +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns an array of structs created by combining the elements of each input array at the same index. 
If the arrays have different lengths, shorter arrays are padded with NULLs.", + syntax_example = "arrays_zip(array1, array2[, ..., array_n])", + sql_example = r#"```sql +> select arrays_zip([1, 2, 3], ['a', 'b', 'c']); ++---------------------------------------------------+ +| arrays_zip([1, 2, 3], ['a', 'b', 'c']) | ++---------------------------------------------------+ +| [{c0: 1, c1: a}, {c0: 2, c1: b}, {c0: 3, c1: c}] | ++---------------------------------------------------+ +> select arrays_zip([1, 2], [3, 4, 5]); ++---------------------------------------------------+ +| arrays_zip([1, 2], [3, 4, 5]) | ++---------------------------------------------------+ +| [{c0: 1, c1: 3}, {c0: 2, c1: 4}, {c0: , c1: 5}] | ++---------------------------------------------------+ +```"#, + argument(name = "array1", description = "First array expression."), + argument(name = "array2", description = "Second array expression."), + argument(name = "array_n", description = "Subsequent array expressions.") +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ArraysZip { + signature: Signature, + aliases: Vec, +} + +impl Default for ArraysZip { + fn default() -> Self { + Self::new() + } +} + +impl ArraysZip { + pub fn new() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + aliases: vec![String::from("list_zip")], + } + } +} + +impl ScalarUDFImpl for ArraysZip { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "arrays_zip" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if arg_types.is_empty() { + return exec_err!("arrays_zip requires at least two arguments"); + } + + let mut fields = Vec::with_capacity(arg_types.len()); + for (i, arg_type) in arg_types.iter().enumerate() { + let element_type = match arg_type { + List(field) | LargeList(field) | FixedSizeList(field, _) => { + field.data_type().clone() + } + Null => Null, + dt => { + return 
exec_err!("arrays_zip expects array arguments, got {dt}"); + } + }; + fields.push(Field::new(format!("c{i}"), element_type, true)); + } + + Ok(List(Arc::new(Field::new_list_field( + DataType::Struct(Fields::from(fields)), + true, + )))) + } + + fn invoke_with_args( + &self, + args: datafusion_expr::ScalarFunctionArgs, + ) -> Result { + make_scalar_function(arrays_zip_inner)(&args.args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +/// Core implementation for arrays_zip. +/// +/// Takes N list arrays and produces a list of structs where each struct +/// has one field per input array. If arrays within a row have different +/// lengths, shorter arrays are padded with NULLs. +/// Supports List, LargeList, and Null input types. +fn arrays_zip_inner(args: &[ArrayRef]) -> Result { + if args.len() < 2 { + return exec_err!("arrays_zip requires at least two arguments"); + } + + let num_rows = args[0].len(); + + // Build a type-erased ListColumnView for each argument. + // None means the argument is Null-typed (all nulls, no backing data). 
+ let mut views: Vec> = Vec::with_capacity(args.len()); + let mut element_types: Vec = Vec::with_capacity(args.len()); + + for (i, arg) in args.iter().enumerate() { + match arg.data_type() { + List(field) => { + let arr = as_list_array(arg)?; + let raw_offsets = arr.value_offsets(); + let offsets: Vec = + raw_offsets.iter().map(|&o| o as usize).collect(); + let is_null = (0..num_rows).map(|row| arr.is_null(row)).collect(); + element_types.push(field.data_type().clone()); + views.push(Some(ListColumnView { + values: Arc::clone(arr.values()), + offsets, + is_null, + })); + } + LargeList(field) => { + let arr = as_large_list_array(arg)?; + let raw_offsets = arr.value_offsets(); + let offsets: Vec = + raw_offsets.iter().map(|&o| o as usize).collect(); + let is_null = (0..num_rows).map(|row| arr.is_null(row)).collect(); + element_types.push(field.data_type().clone()); + views.push(Some(ListColumnView { + values: Arc::clone(arr.values()), + offsets, + is_null, + })); + } + FixedSizeList(field, size) => { + let arr = as_fixed_size_list_array(arg)?; + let size = *size as usize; + let offsets: Vec = (0..=num_rows).map(|row| row * size).collect(); + let is_null = (0..num_rows).map(|row| arr.is_null(row)).collect(); + element_types.push(field.data_type().clone()); + views.push(Some(ListColumnView { + values: Arc::clone(arr.values()), + offsets, + is_null, + })); + } + Null => { + element_types.push(Null); + views.push(None); + } + dt => { + return exec_err!("arrays_zip argument {i} expected list type, got {dt}"); + } + } + } + + // Collect per-column values data for MutableArrayData builders. + let values_data: Vec<_> = views + .iter() + .map(|v| v.as_ref().map(|view| view.values.to_data())) + .collect(); + + let struct_fields: Fields = element_types + .iter() + .enumerate() + .map(|(i, dt)| Field::new(format!("c{i}"), dt.clone(), true)) + .collect::>() + .into(); + + // Create a MutableArrayData builder per column. 
For None (Null-typed) + // args we only need extend_nulls, so we track them separately. + let mut builders: Vec> = values_data + .iter() + .map(|vd| { + vd.as_ref().map(|data| { + MutableArrayData::with_capacities(vec![data], true, Capacities::Array(0)) + }) + }) + .collect(); + + let mut offsets: Vec = Vec::with_capacity(num_rows + 1); + offsets.push(0); + let mut null_mask: Vec = Vec::with_capacity(num_rows); + let mut total_values: usize = 0; + + // Process each row: compute per-array lengths, then copy values + // and pad shorter arrays with NULLs. + for row_idx in 0..num_rows { + let mut max_len: usize = 0; + let mut all_null = true; + + for view in views.iter().flatten() { + if !view.is_null[row_idx] { + all_null = false; + let len = view.offsets[row_idx + 1] - view.offsets[row_idx]; + max_len = max_len.max(len); + } + } + + if all_null { + null_mask.push(true); + offsets.push(*offsets.last().unwrap()); + continue; + } + null_mask.push(false); + + // Extend each column builder for this row. + for (col_idx, view) in views.iter().enumerate() { + match view { + Some(v) if !v.is_null[row_idx] => { + let start = v.offsets[row_idx]; + let end = v.offsets[row_idx + 1]; + let len = end - start; + let builder = builders[col_idx].as_mut().unwrap(); + builder.extend(0, start, end); + if len < max_len { + builder.extend_nulls(max_len - len); + } + } + _ => { + // Null list entry or None (Null-typed) arg — all nulls. + if let Some(builder) = builders[col_idx].as_mut() { + builder.extend_nulls(max_len); + } + } + } + } + + total_values += max_len; + let last = *offsets.last().unwrap(); + offsets.push(last + max_len as i32); + } + + // Assemble struct columns from builders. 
+ let struct_columns: Vec = builders + .into_iter() + .zip(element_types.iter()) + .map(|(builder, elem_type)| match builder { + Some(b) => arrow::array::make_array(b.freeze()), + None => new_null_array( + if elem_type.is_null() { + &Null + } else { + elem_type + }, + total_values, + ), + }) + .collect(); + + let struct_array = StructArray::try_new(struct_fields, struct_columns, None)?; + + let null_buffer = if null_mask.iter().any(|&v| v) { + Some(NullBuffer::from( + null_mask.iter().map(|v| !v).collect::>(), + )) + } else { + None + }; + + let result = ListArray::try_new( + Arc::new(Field::new_list_field( + struct_array.data_type().clone(), + true, + )), + OffsetBuffer::new(offsets.into()), + Arc::new(struct_array), + null_buffer, + )?; + + Ok(Arc::new(result)) +} diff --git a/datafusion/functions-nested/src/cardinality.rs b/datafusion/functions-nested/src/cardinality.rs index c467686b865c..8953a8568f4a 100644 --- a/datafusion/functions-nested/src/cardinality.rs +++ b/datafusion/functions-nested/src/cardinality.rs @@ -120,7 +120,7 @@ impl ScalarUDFImpl for Cardinality { fn cardinality_inner(args: &[ArrayRef]) -> Result { let [array] = take_function_args("cardinality", args)?; match array.data_type() { - Null => Ok(Arc::new(UInt64Array::from_value(0, array.len()))), + Null => Ok(Arc::new(UInt64Array::new_null(array.len()))), List(_) => { let list_array = as_list_array(array)?; generic_list_cardinality::(list_array) @@ -152,9 +152,14 @@ fn generic_list_cardinality( ) -> Result { let result = array .iter() - .map(|arr| match crate::utils::compute_array_dims(arr)? { - Some(vector) => Ok(Some(vector.iter().map(|x| x.unwrap()).product::())), - None => Ok(None), + .map(|arr| match arr { + Some(arr) if arr.is_empty() => Ok(Some(0u64)), + arr => match crate::utils::compute_array_dims(arr)? 
{ + Some(vector) => { + Ok(Some(vector.iter().map(|x| x.unwrap()).product::())) + } + None => Ok(None), + }, }) .collect::>()?; Ok(Arc::new(result) as ArrayRef) diff --git a/datafusion/functions-nested/src/except.rs b/datafusion/functions-nested/src/except.rs index a8ac997ce33e..19a4e9573e35 100644 --- a/datafusion/functions-nested/src/except.rs +++ b/datafusion/functions-nested/src/except.rs @@ -15,11 +15,12 @@ // specific language governing permissions and limitations // under the License. -//! [`ScalarUDFImpl`] definitions for array_except function. +//! [`ScalarUDFImpl`] definition for array_except function. use crate::utils::{check_datatypes, make_scalar_function}; +use arrow::array::new_null_array; use arrow::array::{Array, ArrayRef, GenericListArray, OffsetSizeTrait, cast::AsArray}; -use arrow::buffer::OffsetBuffer; +use arrow::buffer::{NullBuffer, OffsetBuffer}; use arrow::datatypes::{DataType, FieldRef}; use arrow::row::{RowConverter, SortField}; use datafusion_common::utils::{ListCoercion, take_function_args}; @@ -28,6 +29,7 @@ use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; +use itertools::Itertools; use std::any::Any; use std::sync::Arc; @@ -104,8 +106,11 @@ impl ScalarUDFImpl for ArrayExcept { } fn return_type(&self, arg_types: &[DataType]) -> Result { - match (&arg_types[0].clone(), &arg_types[1].clone()) { - (DataType::Null, _) | (_, DataType::Null) => Ok(arg_types[0].clone()), + match (&arg_types[0], &arg_types[1]) { + (DataType::Null, DataType::Null) => { + Ok(DataType::new_list(DataType::Null, true)) + } + (DataType::Null, dt) | (dt, DataType::Null) => Ok(dt.clone()), (dt, _) => Ok(dt.clone()), } } @@ -129,8 +134,16 @@ impl ScalarUDFImpl for ArrayExcept { fn array_except_inner(args: &[ArrayRef]) -> Result { let [array1, array2] = take_function_args("array_except", args)?; + let len = array1.len(); match (array1.data_type(), array2.data_type()) { - (DataType::Null, _) | 
(_, DataType::Null) => Ok(array1.to_owned()), + (DataType::Null, DataType::Null) => Ok(new_null_array( + &DataType::new_list(DataType::Null, true), + len, + )), + (DataType::Null, dt @ DataType::List(_)) + | (DataType::Null, dt @ DataType::LargeList(_)) + | (dt @ DataType::List(_), DataType::Null) + | (dt @ DataType::LargeList(_), DataType::Null) => Ok(new_null_array(dt, len)), (DataType::List(field), DataType::List(_)) => { check_datatypes("array_except", &[array1, array2])?; let list1 = array1.as_list::(); @@ -169,15 +182,27 @@ fn general_except( let mut rows = Vec::with_capacity(l_values.num_rows()); let mut dedup = HashSet::new(); - for (l_w, r_w) in l.offsets().windows(2).zip(r.offsets().windows(2)) { - let l_slice = l_w[0].as_usize()..l_w[1].as_usize(); - let r_slice = r_w[0].as_usize()..r_w[1].as_usize(); - for i in r_slice { - let right_row = r_values.row(i); + let nulls = NullBuffer::union(l.nulls(), r.nulls()); + + let l_offsets_iter = l.offsets().iter().tuple_windows(); + let r_offsets_iter = r.offsets().iter().tuple_windows(); + for (list_index, ((l_start, l_end), (r_start, r_end))) in + l_offsets_iter.zip(r_offsets_iter).enumerate() + { + if nulls + .as_ref() + .is_some_and(|nulls| nulls.is_null(list_index)) + { + offsets.push(OffsetSize::usize_as(rows.len())); + continue; + } + + for element_index in r_start.as_usize()..r_end.as_usize() { + let right_row = r_values.row(element_index); dedup.insert(right_row); } - for i in l_slice { - let left_row = l_values.row(i); + for element_index in l_start.as_usize()..l_end.as_usize() { + let left_row = l_values.row(element_index); if dedup.insert(left_row) { rows.push(left_row); } @@ -192,7 +217,7 @@ fn general_except( field.to_owned(), OffsetBuffer::new(offsets.into()), values.to_owned(), - l.nulls().cloned(), + nulls, )) } else { internal_err!("array_except failed to convert rows") diff --git a/datafusion/functions-nested/src/flatten.rs b/datafusion/functions-nested/src/flatten.rs index 
33b3e102ae0b..8c21348507d2 100644 --- a/datafusion/functions-nested/src/flatten.rs +++ b/datafusion/functions-nested/src/flatten.rs @@ -208,7 +208,7 @@ fn flatten_inner(args: &[ArrayRef]) -> Result { } Null => Ok(Arc::clone(array)), _ => { - exec_err!("flatten does not support type '{:?}'", array.data_type()) + exec_err!("flatten does not support type '{}'", array.data_type()) } } } diff --git a/datafusion/functions-nested/src/lib.rs b/datafusion/functions-nested/src/lib.rs index ed9e1af4eaa8..99b25ec96454 100644 --- a/datafusion/functions-nested/src/lib.rs +++ b/datafusion/functions-nested/src/lib.rs @@ -24,8 +24,6 @@ // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] -// https://github.com/apache/datafusion/issues/18881 -#![deny(clippy::allow_attributes)] //! Nested type Functions for [DataFusion]. //! @@ -40,6 +38,7 @@ pub mod macros; pub mod array_has; +pub mod arrays_zip; pub mod cardinality; pub mod concat; pub mod dimension; @@ -81,6 +80,7 @@ pub mod expr_fn { pub use super::array_has::array_has; pub use super::array_has::array_has_all; pub use super::array_has::array_has_any; + pub use super::arrays_zip::arrays_zip; pub use super::cardinality::cardinality; pub use super::concat::array_append; pub use super::concat::array_concat; @@ -161,6 +161,7 @@ pub fn all_default_nested_functions() -> Vec> { set_ops::array_distinct_udf(), set_ops::array_intersect_udf(), set_ops::array_union_udf(), + arrays_zip::arrays_zip_udf(), position::array_position_udf(), position::array_positions_udf(), remove::array_remove_udf(), diff --git a/datafusion/functions-nested/src/make_array.rs b/datafusion/functions-nested/src/make_array.rs index 410a545853ac..bc899126fb64 100644 --- a/datafusion/functions-nested/src/make_array.rs +++ b/datafusion/functions-nested/src/make_array.rs @@ -31,7 +31,6 @@ use arrow::datatypes::DataType; use arrow::datatypes::{DataType::Null, Field}; use 
datafusion_common::utils::SingleRowListArrayBuilder; use datafusion_common::{Result, plan_err}; -use datafusion_expr::TypeSignature; use datafusion_expr::binary::{ try_type_union_resolution_with_struct, type_union_resolution, }; @@ -80,10 +79,7 @@ impl Default for MakeArray { impl MakeArray { pub fn new() -> Self { Self { - signature: Signature::one_of( - vec![TypeSignature::Nullary, TypeSignature::UserDefined], - Volatility::Immutable, - ), + signature: Signature::user_defined(Volatility::Immutable), aliases: vec![String::from("make_list")], } } @@ -125,7 +121,11 @@ impl ScalarUDFImpl for MakeArray { } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - coerce_types_inner(arg_types, self.name()) + if arg_types.is_empty() { + Ok(vec![]) + } else { + coerce_types_inner(arg_types, self.name()) + } } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions-nested/src/map.rs b/datafusion/functions-nested/src/map.rs index a96bbc0589e3..7df131cf5e27 100644 --- a/datafusion/functions-nested/src/map.rs +++ b/datafusion/functions-nested/src/map.rs @@ -119,7 +119,7 @@ fn get_first_array_ref(columnar_value: &ColumnarValue) -> Result { ScalarValue::List(array) => Ok(array.value(0)), ScalarValue::LargeList(array) => Ok(array.value(0)), ScalarValue::FixedSizeList(array) => Ok(array.value(0)), - _ => exec_err!("Expected array, got {:?}", value), + _ => exec_err!("Expected array, got {}", value), }, ColumnarValue::Array(array) => Ok(array.to_owned()), } diff --git a/datafusion/functions-nested/src/planner.rs b/datafusion/functions-nested/src/planner.rs index afb18a44f48a..e96fdb7d4bac 100644 --- a/datafusion/functions-nested/src/planner.rs +++ b/datafusion/functions-nested/src/planner.rs @@ -37,7 +37,7 @@ use std::sync::Arc; use crate::map::map_udf; use crate::{ - array_has::{array_has_all, array_has_udf}, + array_has::array_has_all, expr_fn::{array_append, array_concat, array_prepend}, extract::{array_element, array_slice}, 
make_array::make_array, @@ -120,20 +120,6 @@ impl ExprPlanner for NestedFunctionPlanner { ScalarFunction::new_udf(map_udf(), vec![keys, values]), ))) } - - fn plan_any(&self, expr: RawBinaryExpr) -> Result> { - if expr.op == BinaryOperator::Eq { - Ok(PlannerResult::Planned(Expr::ScalarFunction( - ScalarFunction::new_udf( - array_has_udf(), - // left and right are reversed here so `needle=any(haystack)` -> `array_has(haystack, needle)` - vec![expr.right, expr.left], - ), - ))) - } else { - plan_err!("Unsupported AnyOp: '{}', only '=' is supported", expr.op) - } - } } #[derive(Debug)] diff --git a/datafusion/functions-nested/src/position.rs b/datafusion/functions-nested/src/position.rs index d085fa29cc7e..0214b1552bc9 100644 --- a/datafusion/functions-nested/src/position.rs +++ b/datafusion/functions-nested/src/position.rs @@ -17,11 +17,13 @@ //! [`ScalarUDFImpl`] definitions for array_position and array_positions functions. +use arrow::array::Scalar; use arrow::datatypes::DataType; use arrow::datatypes::{ DataType::{LargeList, List, UInt64}, Field, }; +use datafusion_common::ScalarValue; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; @@ -37,9 +39,7 @@ use arrow::array::{ use datafusion_common::cast::{ as_generic_list_array, as_int64_array, as_large_list_array, as_list_array, }; -use datafusion_common::{ - Result, assert_or_internal_err, exec_err, utils::take_function_args, -}; +use datafusion_common::{Result, exec_err, utils::take_function_args}; use itertools::Itertools; use crate::utils::{compare_element_to_list, make_scalar_function}; @@ -54,7 +54,7 @@ make_udf_expr_and_func!( #[user_doc( doc_section(label = "Array Functions"), - description = "Returns the position of the first occurrence of the specified element in the array, or NULL if not found.", + description = "Returns the position of the first occurrence of the specified element in the array, or NULL if not found. 
Comparisons are done using `IS DISTINCT FROM` semantics, so NULL is considered to match NULL.", syntax_example = "array_position(array, element)\narray_position(array, element, index)", sql_example = r#"```sql > select array_position([1, 2, 2, 3, 1, 4], 2); @@ -74,10 +74,7 @@ make_udf_expr_and_func!( name = "array", description = "Array expression. Can be a constant, column, or function, and any combination of array operators." ), - argument( - name = "element", - description = "Element to search for position in the array." - ), + argument(name = "element", description = "Element to search for in the array."), argument( name = "index", description = "Index at which to start searching (1-indexed)." @@ -129,7 +126,54 @@ impl ScalarUDFImpl for ArrayPosition { &self, args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_position_inner)(&args.args) + let [first_arg, second_arg, third_arg @ ..] = args.args.as_slice() else { + return exec_err!("array_position expects two or three arguments"); + }; + + match second_arg { + ColumnarValue::Scalar(scalar_element) => { + // Nested element types (List, Struct) can't use the fast path + // (because Arrow's `non_distinct` does not support them). + if scalar_element.data_type().is_nested() { + return make_scalar_function(array_position_inner)(&args.args); + } + + // Determine batch length from whichever argument is columnar; + // if all inputs are scalar, batch length is 1. 
+ let (num_rows, all_inputs_scalar) = match (first_arg, third_arg.first()) { + (ColumnarValue::Array(a), _) => (a.len(), false), + (_, Some(ColumnarValue::Array(a))) => (a.len(), false), + _ => (1, true), + }; + + let element_arr = scalar_element.to_array_of_size(1)?; + let haystack = first_arg.to_array(num_rows)?; + let arr_from = resolve_start_from(third_arg.first(), num_rows)?; + + let result = match haystack.data_type() { + List(_) => { + let list = as_generic_list_array::(&haystack)?; + array_position_scalar::(list, &element_arr, &arr_from) + } + LargeList(_) => { + let list = as_generic_list_array::(&haystack)?; + array_position_scalar::(list, &element_arr, &arr_from) + } + t => exec_err!("array_position does not support type '{t}'."), + }?; + + if all_inputs_scalar { + Ok(ColumnarValue::Scalar(ScalarValue::try_from_array( + &result, 0, + )?)) + } else { + Ok(ColumnarValue::Array(result)) + } + } + ColumnarValue::Array(_) => { + make_scalar_function(array_position_inner)(&args.args) + } + } } fn aliases(&self) -> &[String] { @@ -152,6 +196,109 @@ fn array_position_inner(args: &[ArrayRef]) -> Result { } } +/// Resolves the optional `start_from` argument into a `Vec` of +/// 0-indexed starting positions. +fn resolve_start_from( + third_arg: Option<&ColumnarValue>, + num_rows: usize, +) -> Result> { + match third_arg { + None => Ok(vec![0i64; num_rows]), + Some(ColumnarValue::Scalar(ScalarValue::Int64(Some(v)))) => { + Ok(vec![v - 1; num_rows]) + } + Some(ColumnarValue::Scalar(s)) => { + exec_err!("array_position expected Int64 for start_from, got {s}") + } + Some(ColumnarValue::Array(a)) => { + Ok(as_int64_array(a)?.values().iter().map(|&x| x - 1).collect()) + } + } +} + +/// Fast path for `array_position` when the element is a scalar. +/// +/// Performs a single bulk `not_distinct` comparison of the scalar element +/// against the entire flattened values buffer, then walks the result bitmap +/// using offsets to find per-row first-match positions. 
+fn array_position_scalar( + list_array: &GenericListArray, + element_array: &ArrayRef, + arr_from: &[i64], // 0-indexed +) -> Result { + crate::utils::check_datatypes( + "array_position", + &[list_array.values(), element_array], + )?; + + if list_array.len() == 0 { + return Ok(Arc::new(UInt64Array::new_null(0))); + } + + let element_datum = Scalar::new(Arc::clone(element_array)); + let validity = list_array.nulls(); + + // Only compare the visible portion of the values buffer, which avoids + // wasted work for sliced ListArrays. + let offsets = list_array.offsets(); + let first_offset = offsets[0].as_usize(); + let last_offset = offsets[list_array.len()].as_usize(); + let visible_values = list_array + .values() + .slice(first_offset, last_offset - first_offset); + + // `not_distinct` treats NULL=NULL as true, matching the semantics of + // `array_position` + let eq_array = arrow_ord::cmp::not_distinct(&visible_values, &element_datum)?; + let eq_bits = eq_array.values(); + + let mut result: Vec> = Vec::with_capacity(list_array.len()); + let mut matches = eq_bits.set_indices().peekable(); + + // Match positions are relative to visible_values (0-based), so + // subtract first_offset from each offset when comparing. + for i in 0..list_array.len() { + let start = offsets[i].as_usize() - first_offset; + let end = offsets[i + 1].as_usize() - first_offset; + + if validity.is_some_and(|v| v.is_null(i)) { + // Null row -> null output; advance past matches in range + while matches.peek().is_some_and(|&p| p < end) { + matches.next(); + } + result.push(None); + continue; + } + + let from = arr_from[i]; + let row_len = end - start; + if !(from >= 0 && (from as usize) <= row_len) { + return exec_err!("start_from out of bounds: {}", from + 1); + } + let search_start = start + from as usize; + + // Advance past matches before search_start + while matches.peek().is_some_and(|&p| p < search_start) { + matches.next(); + } + + // First match in [search_start, end)? 
+ if matches.peek().is_some_and(|&p| p < end) { + let pos = *matches.peek().unwrap(); + result.push(Some((pos - start + 1) as u64)); + // Advance past remaining matches in this row + while matches.peek().is_some_and(|&p| p < end) { + matches.next(); + } + } else { + result.push(None); + } + } + + debug_assert_eq!(result.len(), list_array.len()); + Ok(Arc::new(UInt64Array::from(result))) +} + fn general_position_dispatch(args: &[ArrayRef]) -> Result { let list_array = as_generic_list_array::(&args[0])?; let element_array = &args[1]; @@ -164,7 +311,6 @@ fn general_position_dispatch(args: &[ArrayRef]) -> Result>() @@ -172,13 +318,11 @@ fn general_position_dispatch(args: &[ArrayRef]) -> Result= 0 && (from as usize) <= arr.len()), - "start_from index out of bounds" - ); + if !arr.is_none_or(|arr| from >= 0 && (from as usize) <= arr.len()) { + return exec_err!("start_from out of bounds: {}", from + 1); + } } generic_position::(list_array, element_array, &arr_from) @@ -340,3 +484,60 @@ fn general_positions( ListArray::from_iter_primitive::(data), )) } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::AsArray; + use arrow::datatypes::Int32Type; + use datafusion_common::config::ConfigOptions; + use datafusion_expr::ScalarFunctionArgs; + + #[test] + fn test_array_position_sliced_list() -> Result<()> { + // [[10, 20], [30, 40], [50, 60], [70, 80]] → slice(1,2) → [[30, 40], [50, 60]] + let list = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(10), Some(20)]), + Some(vec![Some(30), Some(40)]), + Some(vec![Some(50), Some(60)]), + Some(vec![Some(70), Some(80)]), + ]); + let sliced = list.slice(1, 2); + let haystack_field = + Arc::new(Field::new("haystack", sliced.data_type().clone(), true)); + let needle_field = Arc::new(Field::new("needle", DataType::Int32, true)); + let return_field = Arc::new(Field::new("return", UInt64, true)); + + // Search for elements that exist only in sliced-away rows: + // 10 is in the prefix row, 70 is in the suffix row. 
+ let invoke = |needle: i32| -> Result { + ArrayPosition::new() + .invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new(sliced.clone())), + ColumnarValue::Scalar(ScalarValue::Int32(Some(needle))), + ], + arg_fields: vec![ + Arc::clone(&haystack_field), + Arc::clone(&needle_field), + ], + number_rows: 2, + return_field: Arc::clone(&return_field), + config_options: Arc::new(ConfigOptions::default()), + })? + .into_array(2) + }; + + let output = invoke(10)?; + let output = output.as_primitive::(); + assert!(output.is_null(0)); + assert!(output.is_null(1)); + + let output = invoke(70)?; + let output = output.as_primitive::(); + assert!(output.is_null(0)); + assert!(output.is_null(1)); + + Ok(()) + } +} diff --git a/datafusion/functions-nested/src/range.rs b/datafusion/functions-nested/src/range.rs index aae641ceeb35..307067b9c997 100644 --- a/datafusion/functions-nested/src/range.rs +++ b/datafusion/functions-nested/src/range.rs @@ -392,20 +392,27 @@ impl Range { } let stop = if !self.include_upper_bound { - Date32Type::subtract_month_day_nano(stop, step) + Date32Type::subtract_month_day_nano_opt(stop, step).ok_or_else(|| { + exec_datafusion_err!( + "Cannot generate date range where stop {} - {step:?}) overflows", + date32_to_string(stop) + ) + })? 
} else { stop }; let neg = months < 0 || days < 0; - let mut new_date = start; + let mut new_date = Some(start); let values = from_fn(|| { - if (neg && new_date < stop) || (!neg && new_date > stop) { + let Some(current_date) = new_date else { + return None; // previous overflow + }; + if (neg && current_date < stop) || (!neg && current_date > stop) { None } else { - let current_date = new_date; - new_date = Date32Type::add_month_day_nano(new_date, step); + new_date = Date32Type::add_month_day_nano_opt(current_date, step); Some(Some(current_date)) } }); @@ -578,3 +585,11 @@ fn parse_tz(tz: &Option<&str>) -> Result { Tz::from_str(tz) .map_err(|op| exec_datafusion_err!("failed to parse timezone {tz}: {:?}", op)) } + +fn date32_to_string(value: i32) -> String { + if let Some(d) = Date32Type::to_naive_date_opt(value) { + format!("{value} ({d})") + } else { + format!("{value} (unknown date)") + } +} diff --git a/datafusion/functions-nested/src/remove.rs b/datafusion/functions-nested/src/remove.rs index 41c06cb9c4cb..3d4076800e1e 100644 --- a/datafusion/functions-nested/src/remove.rs +++ b/datafusion/functions-nested/src/remove.rs @@ -20,8 +20,8 @@ use crate::utils; use crate::utils::make_scalar_function; use arrow::array::{ - Array, ArrayRef, BooleanArray, GenericListArray, OffsetSizeTrait, cast::AsArray, - new_empty_array, + Array, ArrayRef, Capacities, GenericListArray, MutableArrayData, NullBufferBuilder, + OffsetSizeTrait, cast::AsArray, make_array, }; use arrow::buffer::OffsetBuffer; use arrow::datatypes::{DataType, FieldRef}; @@ -377,73 +377,84 @@ fn general_remove( ); } }; - let data_type = list_field.data_type(); - let mut new_values = vec![]; + let original_data = list_array.values().to_data(); // Build up the offsets for the final output array let mut offsets = Vec::::with_capacity(arr_n.len() + 1); offsets.push(OffsetSize::zero()); - // n is the number of elements to remove in this row - for (row_index, (list_array_row, n)) in - 
list_array.iter().zip(arr_n.iter()).enumerate() - { - match list_array_row { - Some(list_array_row) => { - let eq_array = utils::compare_element_to_list( - &list_array_row, - element_array, - row_index, - false, - )?; - - // We need to keep at most first n elements as `false`, which represent the elements to remove. - let eq_array = if eq_array.false_count() < *n as usize { - eq_array - } else { - let mut count = 0; - eq_array - .iter() - .map(|e| { - // Keep first n `false` elements, and reverse other elements to `true`. - if let Some(false) = e { - if count < *n { - count += 1; - e - } else { - Some(true) - } - } else { - e - } - }) - .collect::() - }; - - let filtered_array = arrow::compute::filter(&list_array_row, &eq_array)?; - offsets.push( - offsets[row_index] + OffsetSize::usize_as(filtered_array.len()), - ); - new_values.push(filtered_array); - } - None => { - // Null element results in a null row (no new offsets) - offsets.push(offsets[row_index]); + let mut mutable = MutableArrayData::with_capacities( + vec![&original_data], + false, + Capacities::Array(original_data.len()), + ); + let mut valid = NullBufferBuilder::new(list_array.len()); + + for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { + if list_array.is_null(row_index) || element_array.is_null(row_index) { + offsets.push(offsets[row_index]); + valid.append_null(); + continue; + } + + let start = offset_window[0].to_usize().unwrap(); + let end = offset_window[1].to_usize().unwrap(); + // n is the number of elements to remove in this row + let n = arr_n[row_index]; + + // compare each element in the list, `false` means the element matches and should be removed + let eq_array = utils::compare_element_to_list( + &list_array.value(row_index), + element_array, + row_index, + false, + )?; + + let num_to_remove = eq_array.false_count(); + + // Fast path: no elements to remove, copy entire row + if num_to_remove == 0 { + mutable.extend(0, start, end); + 
offsets.push(offsets[row_index] + OffsetSize::usize_as(end - start)); + valid.append_non_null(); + continue; + } + + // Remove at most `n` matching elements + let max_removals = n.min(num_to_remove as i64); + let mut removed = 0i64; + let mut copied = 0usize; + // marks the beginning of a range of elements pending to be copied. + let mut pending_batch_to_retain: Option = None; + for (i, keep) in eq_array.iter().enumerate() { + if keep == Some(false) && removed < max_removals { + // Flush pending batch before skipping this element + if let Some(bs) = pending_batch_to_retain { + mutable.extend(0, start + bs, start + i); + copied += i - bs; + pending_batch_to_retain = None; + } + removed += 1; + } else if pending_batch_to_retain.is_none() { + pending_batch_to_retain = Some(i); } } - } - let values = if new_values.is_empty() { - new_empty_array(data_type) - } else { - let new_values = new_values.iter().map(|x| x.as_ref()).collect::>(); - arrow::compute::concat(&new_values)? - }; + // Flush remaining batch + if let Some(bs) = pending_batch_to_retain { + mutable.extend(0, start + bs, start + eq_array.len()); + copied += eq_array.len() - bs; + } + + offsets.push(offsets[row_index] + OffsetSize::usize_as(copied)); + valid.append_non_null(); + } + let new_values = make_array(mutable.freeze()); Ok(Arc::new(GenericListArray::::try_new( Arc::clone(list_field), OffsetBuffer::new(offsets.into()), - values, - list_array.nulls().cloned(), + new_values, + valid.finish(), )?)) } diff --git a/datafusion/functions-nested/src/repeat.rs b/datafusion/functions-nested/src/repeat.rs index a121b5f03162..5e78a4d0f601 100644 --- a/datafusion/functions-nested/src/repeat.rs +++ b/datafusion/functions-nested/src/repeat.rs @@ -19,22 +19,23 @@ use crate::utils::make_scalar_function; use arrow::array::{ - Array, ArrayRef, Capacities, GenericListArray, ListArray, MutableArrayData, - OffsetSizeTrait, UInt64Array, new_null_array, + Array, ArrayRef, BooleanBufferBuilder, GenericListArray, Int64Array, 
OffsetSizeTrait, + UInt64Array, }; -use arrow::buffer::OffsetBuffer; +use arrow::buffer::{NullBuffer, OffsetBuffer}; use arrow::compute; -use arrow::compute::cast; use arrow::datatypes::DataType; use arrow::datatypes::{ DataType::{LargeList, List}, Field, }; -use datafusion_common::cast::{as_large_list_array, as_list_array, as_uint64_array}; -use datafusion_common::{Result, exec_err, utils::take_function_args}; +use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array}; +use datafusion_common::types::{NativeType, logical_int64}; +use datafusion_common::{DataFusionError, Result}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; +use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; use datafusion_macros::user_doc; use std::any::Any; use std::sync::Arc; @@ -89,7 +90,17 @@ impl Default for ArrayRepeat { impl ArrayRepeat { pub fn new() -> Self { Self { - signature: Signature::user_defined(Volatility::Immutable), + signature: Signature::coercible( + vec![ + Coercion::new_exact(TypeSignatureClass::Any), + Coercion::new_implicit( + TypeSignatureClass::Native(logical_int64()), + vec![TypeSignatureClass::Integer], + NativeType::Int64, + ), + ], + Volatility::Immutable, + ), aliases: vec![String::from("list_repeat")], } } @@ -109,10 +120,17 @@ impl ScalarUDFImpl for ArrayRepeat { } fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(List(Arc::new(Field::new_list_field( - arg_types[0].clone(), - true, - )))) + let element_type = &arg_types[0]; + match element_type { + LargeList(_) => Ok(LargeList(Arc::new(Field::new_list_field( + element_type.clone(), + true, + )))), + _ => Ok(List(Arc::new(Field::new_list_field( + element_type.clone(), + true, + )))), + } } fn invoke_with_args( @@ -126,23 +144,6 @@ impl ScalarUDFImpl for ArrayRepeat { &self.aliases } - fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - let [first_type, second_type] = take_function_args(self.name(), 
arg_types)?; - - // Coerce the second argument to Int64/UInt64 if it's a numeric type - let second = match second_type { - DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { - DataType::Int64 - } - DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { - DataType::UInt64 - } - _ => return exec_err!("count must be an integer type"), - }; - - Ok(vec![first_type.clone(), second]) - } - fn documentation(&self) -> Option<&Documentation> { self.doc() } @@ -150,15 +151,7 @@ impl ScalarUDFImpl for ArrayRepeat { fn array_repeat_inner(args: &[ArrayRef]) -> Result { let element = &args[0]; - let count_array = &args[1]; - - let count_array = match count_array.data_type() { - DataType::Int64 => &cast(count_array, &DataType::UInt64)?, - DataType::UInt64 => count_array, - _ => return exec_err!("count must be an integer type"), - }; - - let count_array = as_uint64_array(count_array)?; + let count_array = as_int64_array(&args[1])?; match element.data_type() { List(_) => { @@ -187,45 +180,46 @@ fn array_repeat_inner(args: &[ArrayRef]) -> Result { /// ``` fn general_repeat( array: &ArrayRef, - count_array: &UInt64Array, + count_array: &Int64Array, ) -> Result { - let data_type = array.data_type(); - let mut new_values = vec![]; - - let count_vec = count_array - .values() - .to_vec() - .iter() - .map(|x| *x as usize) - .collect::>(); - - for (row_index, &count) in count_vec.iter().enumerate() { - let repeated_array = if array.is_null(row_index) { - new_null_array(data_type, count) - } else { - let original_data = array.to_data(); - let capacity = Capacities::Array(count); - let mut mutable = - MutableArrayData::with_capacities(vec![&original_data], false, capacity); - - for _ in 0..count { - mutable.extend(0, row_index, row_index + 1); - } - - let data = mutable.freeze(); - arrow::array::make_array(data) - }; - new_values.push(repeated_array); + let total_repeated_values: usize = (0..count_array.len()) + .map(|i| 
get_count_with_validity(count_array, i)) + .sum(); + + let mut take_indices = Vec::with_capacity(total_repeated_values); + let mut offsets = Vec::with_capacity(count_array.len() + 1); + offsets.push(O::zero()); + let mut running_offset = 0usize; + + for idx in 0..count_array.len() { + let count = get_count_with_validity(count_array, idx); + running_offset = running_offset.checked_add(count).ok_or_else(|| { + DataFusionError::Execution( + "array_repeat: running_offset overflowed usize".to_string(), + ) + })?; + let offset = O::from_usize(running_offset).ok_or_else(|| { + DataFusionError::Execution(format!( + "array_repeat: offset {running_offset} exceeds the maximum value for offset type" + )) + })?; + offsets.push(offset); + take_indices.extend(std::iter::repeat_n(idx as u64, count)); } - let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); - let values = compute::concat(&new_values)?; + // Build the flattened values + let repeated_values = compute::take( + array.as_ref(), + &UInt64Array::from_iter_values(take_indices), + None, + )?; + // Construct final ListArray Ok(Arc::new(GenericListArray::::try_new( - Arc::new(Field::new_list_field(data_type.to_owned(), true)), - OffsetBuffer::from_lengths(count_vec), - values, - None, + Arc::new(Field::new_list_field(array.data_type().to_owned(), true)), + OffsetBuffer::new(offsets.into()), + repeated_values, + count_array.nulls().cloned(), )?)) } @@ -241,58 +235,95 @@ fn general_repeat( /// ``` fn general_list_repeat( list_array: &GenericListArray, - count_array: &UInt64Array, + count_array: &Int64Array, ) -> Result { - let data_type = list_array.data_type(); - let value_type = list_array.value_type(); - let mut new_values = vec![]; - - let count_vec = count_array - .values() - .to_vec() - .iter() - .map(|x| *x as usize) - .collect::>(); - - for (list_array_row, &count) in list_array.iter().zip(count_vec.iter()) { - let list_arr = match list_array_row { - Some(list_array_row) => { - let original_data = 
list_array_row.to_data(); - let capacity = Capacities::Array(original_data.len() * count); - let mut mutable = MutableArrayData::with_capacities( - vec![&original_data], - false, - capacity, - ); - - for _ in 0..count { - mutable.extend(0, 0, original_data.len()); - } - - let data = mutable.freeze(); - let repeated_array = arrow::array::make_array(data); - - let list_arr = GenericListArray::::try_new( - Arc::new(Field::new_list_field(value_type.clone(), true)), - OffsetBuffer::::from_lengths(vec![original_data.len(); count]), - repeated_array, - None, - )?; - Arc::new(list_arr) as ArrayRef + let list_offsets = list_array.value_offsets(); + + // calculate capacities for pre-allocation + let mut outer_total = 0usize; + let mut inner_total = 0usize; + for i in 0..count_array.len() { + let count = get_count_with_validity(count_array, i); + if count > 0 { + outer_total += count; + if list_array.is_valid(i) { + let len = list_offsets[i + 1].to_usize().unwrap() + - list_offsets[i].to_usize().unwrap(); + inner_total += len * count; } - None => new_null_array(data_type, count), - }; - new_values.push(list_arr); + } } - let lengths = new_values.iter().map(|a| a.len()).collect::>(); - let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); - let values = compute::concat(&new_values)?; + // Build inner structures + let mut inner_offsets = Vec::with_capacity(outer_total + 1); + let mut take_indices = Vec::with_capacity(inner_total); + let mut inner_nulls = BooleanBufferBuilder::new(outer_total); + let mut inner_running = 0usize; + inner_offsets.push(O::zero()); + + for row_idx in 0..count_array.len() { + let count = get_count_with_validity(count_array, row_idx); + let list_is_valid = list_array.is_valid(row_idx); + let start = list_offsets[row_idx].to_usize().unwrap(); + let end = list_offsets[row_idx + 1].to_usize().unwrap(); + let row_len = end - start; + + for _ in 0..count { + inner_running = inner_running.checked_add(row_len).ok_or_else(|| { + 
DataFusionError::Execution( + "array_repeat: inner offset overflowed usize".to_string(), + ) + })?; + let offset = O::from_usize(inner_running).ok_or_else(|| { + DataFusionError::Execution(format!( + "array_repeat: offset {inner_running} exceeds the maximum value for offset type" + )) + })?; + inner_offsets.push(offset); + inner_nulls.append(list_is_valid); + if list_is_valid { + take_indices.extend(start as u64..end as u64); + } + } + } - Ok(Arc::new(ListArray::try_new( - Arc::new(Field::new_list_field(data_type.to_owned(), true)), - OffsetBuffer::::from_lengths(lengths), - values, + // Build inner ListArray + let inner_values = compute::take( + list_array.values().as_ref(), + &UInt64Array::from_iter_values(take_indices), None, + )?; + let inner_list = GenericListArray::::try_new( + Arc::new(Field::new_list_field(list_array.value_type().clone(), true)), + OffsetBuffer::new(inner_offsets.into()), + inner_values, + Some(NullBuffer::new(inner_nulls.finish())), + )?; + + // Build outer ListArray + Ok(Arc::new(GenericListArray::::try_new( + Arc::new(Field::new_list_field( + list_array.data_type().to_owned(), + true, + )), + OffsetBuffer::::from_lengths( + count_array + .iter() + .map(|c| c.map(|v| if v > 0 { v as usize } else { 0 }).unwrap_or(0)), + ), + Arc::new(inner_list), + count_array.nulls().cloned(), )?)) } + +/// Helper function to get count from count_array at given index +/// Return 0 for null values or non-positive count. 
+#[inline] +fn get_count_with_validity(count_array: &Int64Array, idx: usize) -> usize { + if count_array.is_null(idx) { + 0 + } else { + let c = count_array.value(idx); + if c > 0 { c as usize } else { 0 } + } +} diff --git a/datafusion/functions-nested/src/set_ops.rs b/datafusion/functions-nested/src/set_ops.rs index 69a220e125c0..150559111fef 100644 --- a/datafusion/functions-nested/src/set_ops.rs +++ b/datafusion/functions-nested/src/set_ops.rs @@ -19,11 +19,9 @@ use crate::utils::make_scalar_function; use arrow::array::{ - Array, ArrayRef, GenericListArray, LargeListArray, ListArray, OffsetSizeTrait, - new_null_array, + Array, ArrayRef, GenericListArray, OffsetSizeTrait, new_empty_array, new_null_array, }; use arrow::buffer::{NullBuffer, OffsetBuffer}; -use arrow::compute; use arrow::datatypes::DataType::{LargeList, List, Null}; use arrow::datatypes::{DataType, Field, FieldRef}; use arrow::row::{RowConverter, SortField}; @@ -36,9 +34,8 @@ use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; -use itertools::Itertools; +use hashbrown::HashSet; use std::any::Any; -use std::collections::HashSet; use std::fmt::{Display, Formatter}; use std::sync::Arc; @@ -69,7 +66,7 @@ make_udf_expr_and_func!( #[user_doc( doc_section(label = "Array Functions"), - description = "Returns an array of elements that are present in both arrays (all elements from both arrays) with out duplicates.", + description = "Returns an array of elements that are present in both arrays (all elements from both arrays) without duplicates.", syntax_example = "array_union(array1, array2)", sql_example = r#"```sql > select array_union([1, 2, 3, 4], [5, 6, 3, 4]); @@ -136,8 +133,7 @@ impl ScalarUDFImpl for ArrayUnion { let [array1, array2] = take_function_args(self.name(), arg_types)?; match (array1, array2) { (Null, Null) => Ok(DataType::new_list(Null, true)), - (Null, dt) => Ok(dt.clone()), - (dt, Null) => Ok(dt.clone()), + (Null, 
dt) | (dt, Null) => Ok(dt.clone()), (dt, _) => Ok(dt.clone()), } } @@ -186,11 +182,17 @@ impl ScalarUDFImpl for ArrayUnion { ) )] #[derive(Debug, PartialEq, Eq, Hash)] -pub(super) struct ArrayIntersect { +pub struct ArrayIntersect { signature: Signature, aliases: Vec, } +impl Default for ArrayIntersect { + fn default() -> Self { + Self::new() + } +} + impl ArrayIntersect { pub fn new() -> Self { Self { @@ -221,8 +223,7 @@ impl ScalarUDFImpl for ArrayIntersect { let [array1, array2] = take_function_args(self.name(), arg_types)?; match (array1, array2) { (Null, Null) => Ok(DataType::new_list(Null, true)), - (Null, dt) => Ok(dt.clone()), - (dt, Null) => Ok(dt.clone()), + (Null, dt) | (dt, Null) => Ok(dt.clone()), (dt, _) => Ok(dt.clone()), } } @@ -261,7 +262,7 @@ impl ScalarUDFImpl for ArrayIntersect { ) )] #[derive(Debug, PartialEq, Eq, Hash)] -pub(super) struct ArrayDistinct { +pub struct ArrayDistinct { signature: Signature, aliases: Vec, } @@ -275,6 +276,12 @@ impl ArrayDistinct { } } +impl Default for ArrayDistinct { + fn default() -> Self { + Self::new() + } +} + impl ScalarUDFImpl for ArrayDistinct { fn as_any(&self) -> &dyn Any { self @@ -361,76 +368,118 @@ fn generic_set_lists( "{set_op:?} is not implemented for '{l:?}' and '{r:?}'" ); - let mut offsets = vec![OffsetSize::usize_as(0)]; - let mut new_arrays = vec![]; - let mut new_null_buf = vec![]; + // Convert all values to rows in batch for performance. let converter = RowConverter::new(vec![SortField::new(l.value_type())])?; - for (first_arr, second_arr) in l.iter().zip(r.iter()) { - let mut ele_should_be_null = false; + let rows_l = converter.convert_columns(&[Arc::clone(l.values())])?; + let rows_r = converter.convert_columns(&[Arc::clone(r.values())])?; - let l_values = if let Some(first_arr) = first_arr { - converter.convert_columns(&[first_arr])? 
- } else { - ele_should_be_null = true; - converter.empty_rows(0, 0) - }; + match set_op { + SetOp::Union => generic_set_loop::( + l, r, &rows_l, &rows_r, field, &converter, + ), + SetOp::Intersect => generic_set_loop::( + l, r, &rows_l, &rows_r, field, &converter, + ), + } +} - let r_values = if let Some(second_arr) = second_arr { - converter.convert_columns(&[second_arr])? - } else { - ele_should_be_null = true; - converter.empty_rows(0, 0) - }; - - let l_iter = l_values.iter().sorted().dedup(); - let values_set: HashSet<_> = l_iter.clone().collect(); - let mut rows = if set_op == SetOp::Union { - l_iter.collect() - } else { - vec![] - }; - - for r_val in r_values.iter().sorted().dedup() { - match set_op { - SetOp::Union => { - if !values_set.contains(&r_val) { - rows.push(r_val); - } +/// Inner loop for set operations, parameterized by const generic to +/// avoid branching inside the hot loop. +fn generic_set_loop( + l: &GenericListArray, + r: &GenericListArray, + rows_l: &arrow::row::Rows, + rows_r: &arrow::row::Rows, + field: Arc, + converter: &RowConverter, +) -> Result { + let l_offsets = l.value_offsets(); + let r_offsets = r.value_offsets(); + + let mut result_offsets = Vec::with_capacity(l.len() + 1); + result_offsets.push(OffsetSize::usize_as(0)); + let initial_capacity = if IS_UNION { + // Union can include all elements from both sides + rows_l.num_rows() + } else { + // Intersect result is bounded by the smaller side + rows_l.num_rows().min(rows_r.num_rows()) + }; + + let mut final_rows = Vec::with_capacity(initial_capacity); + + // Reuse hash sets across iterations + let mut seen = HashSet::new(); + let mut lookup_set = HashSet::new(); + for i in 0..l.len() { + let last_offset = *result_offsets.last().unwrap(); + + if l.is_null(i) || r.is_null(i) { + result_offsets.push(last_offset); + continue; + } + + let l_start = l_offsets[i].as_usize(); + let l_end = l_offsets[i + 1].as_usize(); + let r_start = r_offsets[i].as_usize(); + let r_end = r_offsets[i + 
1].as_usize(); + + seen.clear(); + + if IS_UNION { + for idx in l_start..l_end { + let row = rows_l.row(idx); + if seen.insert(row) { + final_rows.push(row); } - SetOp::Intersect => { - if values_set.contains(&r_val) { - rows.push(r_val); - } + } + for idx in r_start..r_end { + let row = rows_r.row(idx); + if seen.insert(row) { + final_rows.push(row); } } - } - - let last_offset = match offsets.last() { - Some(offset) => *offset, - None => return internal_err!("offsets should not be empty"), - }; - - offsets.push(last_offset + OffsetSize::usize_as(rows.len())); - let arrays = converter.convert_rows(rows)?; - let array = match arrays.first() { - Some(array) => Arc::clone(array), - None => { - return internal_err!("{set_op}: failed to get array from rows"); + } else { + let l_len = l_end - l_start; + let r_len = r_end - r_start; + + // Select shorter side for lookup, longer side for probing + let (lookup_rows, lookup_range, probe_rows, probe_range) = if l_len < r_len { + (rows_l, l_start..l_end, rows_r, r_start..r_end) + } else { + (rows_r, r_start..r_end, rows_l, l_start..l_end) + }; + lookup_set.clear(); + lookup_set.reserve(lookup_range.len()); + + // Build lookup table + for idx in lookup_range { + lookup_set.insert(lookup_rows.row(idx)); } - }; - new_null_buf.push(!ele_should_be_null); - new_arrays.push(array); + // Probe and emit distinct intersected rows + for idx in probe_range { + let row = probe_rows.row(idx); + if lookup_set.contains(&row) && seen.insert(row) { + final_rows.push(row); + } + } + } + result_offsets.push(last_offset + OffsetSize::usize_as(seen.len())); } - let offsets = OffsetBuffer::new(offsets.into()); - let new_arrays_ref: Vec<_> = new_arrays.iter().map(|v| v.as_ref()).collect(); - let values = compute::concat(&new_arrays_ref)?; + let final_values = if final_rows.is_empty() { + new_empty_array(&l.value_type()) + } else { + let arrays = converter.convert_rows(final_rows)?; + Arc::clone(&arrays[0]) + }; + let arr = 
GenericListArray::::try_new( field, - offsets, - values, - Some(NullBuffer::new(new_null_buf.into())), + OffsetBuffer::new(result_offsets.into()), + final_values, + NullBuffer::union(l.nulls(), r.nulls()), )?; Ok(Arc::new(arr)) } @@ -440,59 +489,13 @@ fn general_set_op( array2: &ArrayRef, set_op: SetOp, ) -> Result { - fn empty_array(data_type: &DataType, len: usize, large: bool) -> Result { - let field = Arc::new(Field::new_list_field(data_type.clone(), true)); - let values = new_null_array(data_type, len); - if large { - Ok(Arc::new(LargeListArray::try_new( - field, - OffsetBuffer::new_zeroed(len), - values, - None, - )?)) - } else { - Ok(Arc::new(ListArray::try_new( - field, - OffsetBuffer::new_zeroed(len), - values, - None, - )?)) - } - } - + let len = array1.len(); match (array1.data_type(), array2.data_type()) { - (Null, Null) => Ok(Arc::new(ListArray::new_null( - Arc::new(Field::new_list_field(Null, true)), - array1.len(), - ))), - (Null, List(field)) => { - if set_op == SetOp::Intersect { - return empty_array(field.data_type(), array1.len(), false); - } - let array = as_list_array(&array2)?; - general_array_distinct::(array, field) - } - (List(field), Null) => { - if set_op == SetOp::Intersect { - return empty_array(field.data_type(), array1.len(), false); - } - let array = as_list_array(&array1)?; - general_array_distinct::(array, field) - } - (Null, LargeList(field)) => { - if set_op == SetOp::Intersect { - return empty_array(field.data_type(), array1.len(), true); - } - let array = as_large_list_array(&array2)?; - general_array_distinct::(array, field) - } - (LargeList(field), Null) => { - if set_op == SetOp::Intersect { - return empty_array(field.data_type(), array1.len(), true); - } - let array = as_large_list_array(&array1)?; - general_array_distinct::(array, field) - } + (Null, Null) => Ok(new_null_array(&DataType::new_list(Null, true), len)), + (Null, dt @ List(_)) + | (Null, dt @ LargeList(_)) + | (dt @ List(_), Null) + | (dt @ LargeList(_), Null) 
=> Ok(new_null_array(dt, len)), (List(field), List(_)) => { let array1 = as_list_array(&array1)?; let array2 = as_list_array(&array2)?; @@ -528,42 +531,52 @@ fn general_array_distinct( if array.is_empty() { return Ok(Arc::new(array.clone()) as ArrayRef); } + let value_offsets = array.value_offsets(); let dt = array.value_type(); - let mut offsets = Vec::with_capacity(array.len()); + let mut offsets = Vec::with_capacity(array.len() + 1); offsets.push(OffsetSize::usize_as(0)); - let mut new_arrays = Vec::with_capacity(array.len()); - let converter = RowConverter::new(vec![SortField::new(dt)])?; - // distinct for each list in ListArray - for arr in array.iter() { - let last_offset: OffsetSize = offsets.last().copied().unwrap(); - let Some(arr) = arr else { - // Add same offset for null + + // Convert all values to row format in a single batch for performance + let converter = RowConverter::new(vec![SortField::new(dt.clone())])?; + let rows = converter.convert_columns(&[Arc::clone(array.values())])?; + let mut final_rows = Vec::with_capacity(rows.num_rows()); + let mut seen = HashSet::new(); + for i in 0..array.len() { + let last_offset = *offsets.last().unwrap(); + + // Null list entries produce no output; just carry forward the offset. + if array.is_null(i) { offsets.push(last_offset); continue; - }; - let values = converter.convert_columns(&[arr])?; - // sort elements in list and remove duplicates - let rows = values.iter().sorted().dedup().collect::>(); - offsets.push(last_offset + OffsetSize::usize_as(rows.len())); - let arrays = converter.convert_rows(rows)?; - let array = match arrays.first() { - Some(array) => Arc::clone(array), - None => { - return internal_err!("array_distinct: failed to get array from rows"); + } + + let start = value_offsets[i].as_usize(); + let end = value_offsets[i + 1].as_usize(); + seen.clear(); + seen.reserve(end - start); + + // Walk the sub-array and keep only the first occurrence of each value. 
+ for idx in start..end { + let row = rows.row(idx); + if seen.insert(row) { + final_rows.push(row); } - }; - new_arrays.push(array); - } - if new_arrays.is_empty() { - return Ok(Arc::new(array.clone()) as ArrayRef); + } + offsets.push(last_offset + OffsetSize::usize_as(seen.len())); } - let offsets = OffsetBuffer::new(offsets.into()); - let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::>(); - let values = compute::concat(&new_arrays_ref)?; + + // Convert all collected distinct rows back + let final_values = if final_rows.is_empty() { + new_empty_array(&dt) + } else { + let arrays = converter.convert_rows(final_rows)?; + Arc::clone(&arrays[0]) + }; + Ok(Arc::new(GenericListArray::::try_new( Arc::clone(field), - offsets, - values, + OffsetBuffer::new(offsets.into()), + final_values, // Keep the list nulls array.nulls().cloned(), )?)) diff --git a/datafusion/functions-nested/src/sort.rs b/datafusion/functions-nested/src/sort.rs index ba2da0f760ee..cbe101f111b2 100644 --- a/datafusion/functions-nested/src/sort.rs +++ b/datafusion/functions-nested/src/sort.rs @@ -18,16 +18,14 @@ //! [`ScalarUDFImpl`] definitions for array_sort function. 
use crate::utils::make_scalar_function; -use arrow::array::{ - Array, ArrayRef, GenericListArray, NullBufferBuilder, OffsetSizeTrait, new_null_array, -}; +use arrow::array::{Array, ArrayRef, GenericListArray, OffsetSizeTrait, new_null_array}; use arrow::buffer::OffsetBuffer; use arrow::compute::SortColumn; use arrow::datatypes::{DataType, FieldRef}; use arrow::{compute, compute::SortOptions}; use datafusion_common::cast::{as_large_list_array, as_list_array, as_string_array}; use datafusion_common::utils::ListCoercion; -use datafusion_common::{Result, exec_err, plan_err}; +use datafusion_common::{Result, exec_err}; use datafusion_expr::{ ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature, Volatility, @@ -134,18 +132,7 @@ impl ScalarUDFImpl for ArraySort { } fn return_type(&self, arg_types: &[DataType]) -> Result { - match &arg_types[0] { - DataType::Null => Ok(DataType::Null), - DataType::List(field) => { - Ok(DataType::new_list(field.data_type().clone(), true)) - } - DataType::LargeList(field) => { - Ok(DataType::new_large_list(field.data_type().clone(), true)) - } - arg_type => { - plan_err!("{} does not support type {arg_type}", self.name()) - } - } + Ok(arg_types[0].clone()) } fn invoke_with_args( @@ -206,11 +193,11 @@ fn array_sort_inner(args: &[ArrayRef]) -> Result { } DataType::List(field) => { let array = as_list_array(&args[0])?; - array_sort_generic(array, field, sort_options) + array_sort_generic(array, Arc::clone(field), sort_options) } DataType::LargeList(field) => { let array = as_large_list_array(&args[0])?; - array_sort_generic(array, field, sort_options) + array_sort_generic(array, Arc::clone(field), sort_options) } // Signature should prevent this arm ever occurring _ => exec_err!("array_sort expects list for first argument"), @@ -219,18 +206,16 @@ fn array_sort_inner(args: &[ArrayRef]) -> Result { fn array_sort_generic( list_array: &GenericListArray, - field: &FieldRef, + field: 
FieldRef, sort_options: Option, ) -> Result { let row_count = list_array.len(); let mut array_lengths = vec![]; let mut arrays = vec![]; - let mut valid = NullBufferBuilder::new(row_count); for i in 0..row_count { if list_array.is_null(i) { array_lengths.push(0); - valid.append_null(); } else { let arr_ref = list_array.value(i); @@ -253,25 +238,22 @@ fn array_sort_generic( }; array_lengths.push(sorted_array.len()); arrays.push(sorted_array); - valid.append_non_null(); } } - let buffer = valid.finish(); - let elements = arrays .iter() .map(|a| a.as_ref()) .collect::>(); let list_arr = if elements.is_empty() { - GenericListArray::::new_null(Arc::clone(field), row_count) + GenericListArray::::new_null(field, row_count) } else { GenericListArray::::new( - Arc::clone(field), + field, OffsetBuffer::from_lengths(array_lengths), Arc::new(compute::concat(elements.as_slice())?), - buffer, + list_array.nulls().cloned(), ) }; Ok(Arc::new(list_arr)) diff --git a/datafusion/functions-nested/src/string.rs b/datafusion/functions-nested/src/string.rs index 1c8d58fca80d..8aabc4930956 100644 --- a/datafusion/functions-nested/src/string.rs +++ b/datafusion/functions-nested/src/string.rs @@ -29,6 +29,7 @@ use datafusion_common::utils::ListCoercion; use datafusion_common::{DataFusionError, Result, not_impl_err}; use std::any::Any; +use std::fmt::Write; use crate::utils::make_scalar_function; use arrow::array::{ @@ -36,7 +37,7 @@ use arrow::array::{ builder::{ArrayBuilder, LargeStringBuilder, StringViewBuilder}, cast::AsArray, }; -use arrow::compute::cast; +use arrow::compute::{can_cast_types, cast}; use arrow::datatypes::DataType::{ Dictionary, FixedSizeList, LargeList, LargeUtf8, List, Null, Utf8, Utf8View, }; @@ -54,69 +55,6 @@ use datafusion_functions::downcast_arg; use datafusion_macros::user_doc; use std::sync::Arc; -macro_rules! 
call_array_function { - ($DATATYPE:expr, false) => { - match $DATATYPE { - DataType::Utf8 => array_function!(StringArray), - DataType::Utf8View => array_function!(StringViewArray), - DataType::LargeUtf8 => array_function!(LargeStringArray), - DataType::Boolean => array_function!(BooleanArray), - DataType::Float32 => array_function!(Float32Array), - DataType::Float64 => array_function!(Float64Array), - DataType::Int8 => array_function!(Int8Array), - DataType::Int16 => array_function!(Int16Array), - DataType::Int32 => array_function!(Int32Array), - DataType::Int64 => array_function!(Int64Array), - DataType::UInt8 => array_function!(UInt8Array), - DataType::UInt16 => array_function!(UInt16Array), - DataType::UInt32 => array_function!(UInt32Array), - DataType::UInt64 => array_function!(UInt64Array), - dt => not_impl_err!("Unsupported data type in array_to_string: {dt}"), - } - }; - ($DATATYPE:expr, $INCLUDE_LIST:expr) => {{ - match $DATATYPE { - DataType::List(_) => array_function!(ListArray), - DataType::Utf8 => array_function!(StringArray), - DataType::Utf8View => array_function!(StringViewArray), - DataType::LargeUtf8 => array_function!(LargeStringArray), - DataType::Boolean => array_function!(BooleanArray), - DataType::Float32 => array_function!(Float32Array), - DataType::Float64 => array_function!(Float64Array), - DataType::Int8 => array_function!(Int8Array), - DataType::Int16 => array_function!(Int16Array), - DataType::Int32 => array_function!(Int32Array), - DataType::Int64 => array_function!(Int64Array), - DataType::UInt8 => array_function!(UInt8Array), - DataType::UInt16 => array_function!(UInt16Array), - DataType::UInt32 => array_function!(UInt32Array), - DataType::UInt64 => array_function!(UInt64Array), - dt => not_impl_err!("Unsupported data type in array_to_string: {dt}"), - } - }}; -} - -macro_rules! 
to_string { - ($ARG:expr, $ARRAY:expr, $DELIMITER:expr, $NULL_STRING:expr, $WITH_NULL_STRING:expr, $ARRAY_TYPE:ident) => {{ - let arr = downcast_arg!($ARRAY, $ARRAY_TYPE); - for x in arr { - match x { - Some(x) => { - $ARG.push_str(&x.to_string()); - $ARG.push_str($DELIMITER); - } - None => { - if $WITH_NULL_STRING { - $ARG.push_str($NULL_STRING); - $ARG.push_str($DELIMITER); - } - } - } - } - Ok($ARG) - }}; -} - // Create static instances of ScalarUDFs for each function make_udf_expr_and_func!( ArrayToString, @@ -145,7 +83,7 @@ make_udf_expr_and_func!( argument(name = "delimiter", description = "Array element separator."), argument( name = "null_string", - description = "Optional. String to replace null values in the array. If not provided, nulls will be handled by default behavior." + description = "Optional. String to use for null values in the output. If not provided, nulls will be omitted." ) )] #[derive(Debug, PartialEq, Eq, Hash)] @@ -347,181 +285,257 @@ fn array_to_string_inner(args: &[ArrayRef]) -> Result { } }; - let mut null_string = String::from(""); - let mut with_null_string = false; - if args.len() == 3 { - null_string = match args[2].data_type() { - Utf8 => args[2].as_string::().value(0).to_string(), - Utf8View => args[2].as_string_view().value(0).to_string(), - LargeUtf8 => args[2].as_string::().value(0).to_string(), + let null_strings: Vec> = if args.len() == 3 { + match args[2].data_type() { + Utf8 => args[2].as_string::().iter().collect(), + Utf8View => args[2].as_string_view().iter().collect(), + LargeUtf8 => args[2].as_string::().iter().collect(), other => { return exec_err!( - "unsupported type for second argument to array_to_string function as {other:?}" + "unsupported type for third argument to array_to_string function as {other:?}" ); } + } + } else { + // If `null_strings` is not specified, we treat it as equivalent to + // explicitly passing a NULL value for `null_strings` in every row. 
+ vec![None; args[0].len()] + }; + + let string_arr = match arr.data_type() { + List(_) => { + let list_array = as_list_array(&arr)?; + generate_string_array::(list_array, &delimiters, &null_strings)? + } + LargeList(_) => { + let list_array = as_large_list_array(&arr)?; + generate_string_array::(list_array, &delimiters, &null_strings)? + } + // Signature guards against this arm + _ => return exec_err!("array_to_string expects list as first argument"), + }; + + Ok(Arc::new(string_arr)) +} + +fn generate_string_array( + list_arr: &GenericListArray, + delimiters: &[Option<&str>], + null_strings: &[Option<&str>], +) -> Result { + let mut builder = StringBuilder::with_capacity(list_arr.len(), 0); + let mut buf = String::new(); + + for ((arr, &delimiter), &null_string) in list_arr + .iter() + .zip(delimiters.iter()) + .zip(null_strings.iter()) + { + let (Some(arr), Some(delimiter)) = (arr, delimiter) else { + builder.append_null(); + continue; }; - with_null_string = true; - } - - /// Creates a single string from single element of a ListArray (which is - /// itself another Array) - fn compute_array_to_string<'a>( - arg: &'a mut String, - arr: &ArrayRef, - delimiter: String, - null_string: String, - with_null_string: bool, - ) -> Result<&'a mut String> { - match arr.data_type() { - List(..) => { - let list_array = as_list_array(&arr)?; - for i in 0..list_array.len() { - if !list_array.is_null(i) { - compute_array_to_string( - arg, - &list_array.value(i), - delimiter.clone(), - null_string.clone(), - with_null_string, - )?; - } else if with_null_string { - arg.push_str(&null_string); - arg.push_str(&delimiter); - } - } - Ok(arg) - } - FixedSizeList(..) 
=> { - let list_array = as_fixed_size_list_array(&arr)?; - - for i in 0..list_array.len() { - if !list_array.is_null(i) { - compute_array_to_string( - arg, - &list_array.value(i), - delimiter.clone(), - null_string.clone(), - with_null_string, - )?; - } else if with_null_string { - arg.push_str(&null_string); - arg.push_str(&delimiter); - } - } + buf.clear(); + let mut first = true; + compute_array_to_string(&mut buf, &arr, delimiter, null_string, &mut first)?; + builder.append_value(&buf); + } - Ok(arg) - } - LargeList(..) => { - let list_array = as_large_list_array(&arr)?; - for i in 0..list_array.len() { - if !list_array.is_null(i) { - compute_array_to_string( - arg, - &list_array.value(i), - delimiter.clone(), - null_string.clone(), - with_null_string, - )?; - } else if with_null_string { - arg.push_str(&null_string); - arg.push_str(&delimiter); + Ok(builder.finish()) +} + +fn compute_array_to_string( + buf: &mut String, + arr: &ArrayRef, + delimiter: &str, + null_string: Option<&str>, + first: &mut bool, +) -> Result<()> { + // Handle lists by recursing on each list element. + macro_rules! handle_list { + ($list_array:expr) => { + for i in 0..$list_array.len() { + if !$list_array.is_null(i) { + compute_array_to_string( + buf, + &$list_array.value(i), + delimiter, + null_string, + first, + )?; + } else if let Some(ns) = null_string { + if *first { + *first = false; + } else { + buf.push_str(delimiter); } + buf.push_str(ns); } + } + }; + } - Ok(arg) + match arr.data_type() { + List(..) => { + let list_array = as_list_array(arr)?; + handle_list!(list_array); + Ok(()) + } + FixedSizeList(..) => { + let list_array = as_fixed_size_list_array(arr)?; + handle_list!(list_array); + Ok(()) + } + LargeList(..) => { + let list_array = as_large_list_array(arr)?; + handle_list!(list_array); + Ok(()) + } + Dictionary(_key_type, value_type) => { + // Call cast to unwrap the dictionary. 
This could be optimized if we wanted + // to accept the overhead of extra code + let values = cast(arr, value_type.as_ref()).map_err(|e| { + DataFusionError::from(e) + .context("Casting dictionary to values in compute_array_to_string") + })?; + compute_array_to_string(buf, &values, delimiter, null_string, first) + } + Null => Ok(()), + data_type => { + macro_rules! str_leaf { + ($ARRAY_TYPE:ident) => { + write_leaf_to_string( + buf, + downcast_arg!(arr, $ARRAY_TYPE), + delimiter, + null_string, + first, + |buf, x: &str| buf.push_str(x), + ) + }; } - Dictionary(_key_type, value_type) => { - // Call cast to unwrap the dictionary. This could be optimized if we wanted - // to accept the overhead of extra code - let values = cast(&arr, value_type.as_ref()).map_err(|e| { - DataFusionError::from(e).context( - "Casting dictionary to values in compute_array_to_string", + macro_rules! bool_leaf { + ($ARRAY_TYPE:ident) => { + write_leaf_to_string( + buf, + downcast_arg!(arr, $ARRAY_TYPE), + delimiter, + null_string, + first, + |buf, x: bool| { + if x { + buf.push_str("true"); + } else { + buf.push_str("false"); + } + }, ) - })?; - compute_array_to_string( - arg, - &values, - delimiter, - null_string, - with_null_string, - ) + }; } - Null => Ok(arg), - data_type => { - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - to_string!( - arg, - arr, - &delimiter, - &null_string, - with_null_string, - $ARRAY_TYPE - ) - }; - } - call_array_function!(data_type, false) + macro_rules! 
int_leaf { + ($ARRAY_TYPE:ident) => { + write_leaf_to_string( + buf, + downcast_arg!(arr, $ARRAY_TYPE), + delimiter, + null_string, + first, + |buf, x| { + let mut itoa_buf = itoa::Buffer::new(); + buf.push_str(itoa_buf.format(x)); + }, + ) + }; } - } - } - - fn generate_string_array( - list_arr: &GenericListArray, - delimiters: &[Option<&str>], - null_string: &str, - with_null_string: bool, - ) -> Result { - let mut res: Vec> = Vec::new(); - for (arr, &delimiter) in list_arr.iter().zip(delimiters.iter()) { - if let (Some(arr), Some(delimiter)) = (arr, delimiter) { - let mut arg = String::from(""); - let s = compute_array_to_string( - &mut arg, - &arr, - delimiter.to_string(), - null_string.to_string(), - with_null_string, - )? - .clone(); - - if let Some(s) = s.strip_suffix(delimiter) { - res.push(Some(s.to_string())); - } else { - res.push(Some(s)); + macro_rules! float_leaf { + ($ARRAY_TYPE:ident) => { + write_leaf_to_string( + buf, + downcast_arg!(arr, $ARRAY_TYPE), + delimiter, + null_string, + first, + |buf, x| { + // TODO: Consider switching to a more efficient + // floating point display library (e.g., ryu). This + // might result in some differences in the output + // format, however. 
+ write!(buf, "{}", x).unwrap(); + }, + ) + }; + } + match data_type { + Utf8 => str_leaf!(StringArray), + Utf8View => str_leaf!(StringViewArray), + LargeUtf8 => str_leaf!(LargeStringArray), + DataType::Boolean => bool_leaf!(BooleanArray), + DataType::Float32 => float_leaf!(Float32Array), + DataType::Float64 => float_leaf!(Float64Array), + DataType::Int8 => int_leaf!(Int8Array), + DataType::Int16 => int_leaf!(Int16Array), + DataType::Int32 => int_leaf!(Int32Array), + DataType::Int64 => int_leaf!(Int64Array), + DataType::UInt8 => int_leaf!(UInt8Array), + DataType::UInt16 => int_leaf!(UInt16Array), + DataType::UInt32 => int_leaf!(UInt32Array), + DataType::UInt64 => int_leaf!(UInt64Array), + data_type if can_cast_types(data_type, &Utf8) => { + let str_arr = cast(arr, &Utf8).map_err(|e| { + DataFusionError::from(e) + .context("Casting to string in array_to_string") + })?; + return compute_array_to_string( + buf, + &str_arr, + delimiter, + null_string, + first, + ); + } + data_type => { + return not_impl_err!( + "Unsupported data type in array_to_string: {data_type}" + ); } - } else { - res.push(None); } + Ok(()) } - - Ok(StringArray::from(res)) } +} - let string_arr = match arr.data_type() { - List(_) => { - let list_array = as_list_array(&arr)?; - generate_string_array::( - list_array, - &delimiters, - &null_string, - with_null_string, - )? +/// Appends the string representation of each element in a leaf (non-list) +/// array to `buf`, separated by `delimiter`. Null elements are rendered +/// using `null_string` if provided, or skipped otherwise. The `append` +/// closure controls how each non-null element is written to the buffer. 
+fn write_leaf_to_string<'a, A, T>( + buf: &mut String, + arr: &'a A, + delimiter: &str, + null_string: Option<&str>, + first: &mut bool, + append: impl Fn(&mut String, T), +) where + &'a A: IntoIterator>, +{ + for x in arr { + // Skip nulls when no null_string is provided + if x.is_none() && null_string.is_none() { + continue; } - LargeList(_) => { - let list_array = as_large_list_array(&arr)?; - generate_string_array::( - list_array, - &delimiters, - &null_string, - with_null_string, - )? + + if *first { + *first = false; + } else { + buf.push_str(delimiter); } - // Signature guards against this arm - _ => return exec_err!("array_to_string expects list as first argument"), - }; - Ok(Arc::new(string_arr)) + match x { + Some(x) => append(buf, x), + None => buf.push_str(null_string.unwrap()), + } + } } /// String_to_array SQL function diff --git a/datafusion/functions-nested/src/utils.rs b/datafusion/functions-nested/src/utils.rs index d2a69c010e8e..9f46917a87eb 100644 --- a/datafusion/functions-nested/src/utils.rs +++ b/datafusion/functions-nested/src/utils.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use arrow::datatypes::{DataType, Field, Fields}; use arrow::array::{ - Array, ArrayRef, BooleanArray, GenericListArray, OffsetSizeTrait, Scalar, UInt32Array, + Array, ArrayRef, BooleanArray, GenericListArray, OffsetSizeTrait, Scalar, }; use arrow::buffer::OffsetBuffer; use datafusion_common::cast::{ @@ -161,8 +161,7 @@ pub(crate) fn compare_element_to_list( ); } - let indices = UInt32Array::from(vec![row_index as u32]); - let element_array_row = arrow::compute::take(element_array, &indices, None)?; + let element_array_row = element_array.slice(row_index, 1); // Compute all positions in list_row_array (that is itself an // array) that are equal to `from_array_row` @@ -260,7 +259,7 @@ pub(crate) fn get_map_entry_field(data_type: &DataType) -> Result<&Fields> { match field_data_type { DataType::Struct(fields) => Ok(fields), _ => { - internal_err!("Expected a Struct type, got 
{:?}", field_data_type) + internal_err!("Expected a Struct type, got {}", field_data_type) } } } diff --git a/datafusion/functions-table/src/generate_series.rs b/datafusion/functions-table/src/generate_series.rs index b806798bcecc..342269fbc299 100644 --- a/datafusion/functions-table/src/generate_series.rs +++ b/datafusion/functions-table/src/generate_series.rs @@ -433,30 +433,11 @@ fn reach_end_int64(val: i64, end: i64, step: i64, include_end: bool) -> bool { } } -fn validate_interval_step( - step: IntervalMonthDayNano, - start: i64, - end: i64, -) -> Result<()> { +fn validate_interval_step(step: IntervalMonthDayNano) -> Result<()> { if step.months == 0 && step.days == 0 && step.nanoseconds == 0 { return plan_err!("Step interval cannot be zero"); } - let step_is_positive = step.months > 0 || step.days > 0 || step.nanoseconds > 0; - let step_is_negative = step.months < 0 || step.days < 0 || step.nanoseconds < 0; - - if start > end && step_is_positive { - return plan_err!( - "Start is bigger than end, but increment is positive: Cannot generate infinite series" - ); - } - - if start < end && step_is_negative { - return plan_err!( - "Start is smaller than end, but increment is negative: Cannot generate infinite series" - ); - } - Ok(()) } @@ -567,18 +548,6 @@ impl GenerateSeriesFuncImpl { } }; - if start > end && step > 0 { - return plan_err!( - "Start is bigger than end, but increment is positive: Cannot generate infinite series" - ); - } - - if start < end && step < 0 { - return plan_err!( - "Start is smaller than end, but increment is negative: Cannot generate infinite series" - ); - } - if step == 0 { return plan_err!("Step cannot be zero"); } @@ -656,7 +625,7 @@ impl GenerateSeriesFuncImpl { }; // Validate step interval - validate_interval_step(step, start, end)?; + validate_interval_step(step)?; Ok(Arc::new(GenerateSeriesTable { schema, @@ -749,7 +718,7 @@ impl GenerateSeriesFuncImpl { let end_ts = end_date as i64 * NANOS_PER_DAY; // Validate step interval - 
validate_interval_step(step_interval, start_ts, end_ts)?; + validate_interval_step(step_interval)?; Ok(Arc::new(GenerateSeriesTable { schema, diff --git a/datafusion/functions-table/src/lib.rs b/datafusion/functions-table/src/lib.rs index 1783c15b14b5..cd9ade041acb 100644 --- a/datafusion/functions-table/src/lib.rs +++ b/datafusion/functions-table/src/lib.rs @@ -24,8 +24,6 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] -// https://github.com/apache/datafusion/issues/18881 -#![deny(clippy::allow_attributes)] pub mod generate_series; diff --git a/datafusion/functions-window-common/src/lib.rs b/datafusion/functions-window-common/src/lib.rs index 210e54d67289..301f2c34a6c9 100644 --- a/datafusion/functions-window-common/src/lib.rs +++ b/datafusion/functions-window-common/src/lib.rs @@ -24,8 +24,6 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] -// https://github.com/apache/datafusion/issues/18881 -#![deny(clippy::allow_attributes)] //! Common user-defined window functionality for [DataFusion] //! 
diff --git a/datafusion/functions-window/Cargo.toml b/datafusion/functions-window/Cargo.toml index 42690907ae26..fae71e180e34 100644 --- a/datafusion/functions-window/Cargo.toml +++ b/datafusion/functions-window/Cargo.toml @@ -51,3 +51,11 @@ datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } log = { workspace = true } paste = { workspace = true } + +[dev-dependencies] +arrow = { workspace = true, features = ["test_utils"] } +criterion = { workspace = true } + +[[bench]] +name = "nth_value" +harness = false diff --git a/datafusion/functions-window/benches/nth_value.rs b/datafusion/functions-window/benches/nth_value.rs new file mode 100644 index 000000000000..00daf9fa4f9b --- /dev/null +++ b/datafusion/functions-window/benches/nth_value.rs @@ -0,0 +1,263 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::hint::black_box; +use std::ops::Range; +use std::slice; +use std::sync::Arc; + +use arrow::array::ArrayRef; +use arrow::datatypes::{DataType, Field, FieldRef, Int64Type}; +use arrow::util::bench_util::create_primitive_array; + +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_expr::{PartitionEvaluator, WindowUDFImpl}; +use datafusion_functions_window::nth_value::{NthValue, NthValueKind}; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use datafusion_physical_expr::expressions::{Column, Literal}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + +const ARRAY_SIZE: usize = 8192; + +/// Creates a partition evaluator for FIRST_VALUE, LAST_VALUE, or NTH_VALUE +fn create_evaluator( + kind: NthValueKind, + ignore_nulls: bool, + n: Option, +) -> Box { + let expr = Arc::new(Column::new("c", 0)) as Arc; + let input_field: FieldRef = Field::new("c", DataType::Int64, true).into(); + let input_fields = vec![input_field]; + + let (nth_value, exprs): (NthValue, Vec>) = match kind { + NthValueKind::First => (NthValue::first(), vec![expr]), + NthValueKind::Last => (NthValue::last(), vec![expr]), + NthValueKind::Nth => { + let n_value = + Arc::new(Literal::new(ScalarValue::Int64(n))) as Arc; + (NthValue::nth(), vec![expr, n_value]) + } + }; + + let args = PartitionEvaluatorArgs::new(&exprs, &input_fields, false, ignore_nulls); + nth_value.partition_evaluator(args).unwrap() +} + +fn bench_nth_value_ignore_nulls(c: &mut Criterion) { + let mut group = c.benchmark_group("nth_value_ignore_nulls"); + + // Test different null densities + let null_densities = [0.0, 0.3, 0.5, 0.8]; + + for null_density in null_densities { + let values = Arc::new(create_primitive_array::( + ARRAY_SIZE, + null_density, + )) as ArrayRef; + let null_pct = (null_density * 100.0) as u32; + + // FIRST_VALUE with ignore_nulls - expanding window + group.bench_function( + 
BenchmarkId::new("first_value_expanding", format!("{null_pct}%_nulls")), + |b| { + b.iter(|| { + let mut evaluator = create_evaluator(NthValueKind::First, true, None); + let values_slice = slice::from_ref(&values); + for i in 0..values.len() { + let range = Range { + start: 0, + end: i + 1, + }; + black_box(evaluator.evaluate(values_slice, &range).unwrap()); + } + }) + }, + ); + + // LAST_VALUE with ignore_nulls - expanding window + group.bench_function( + BenchmarkId::new("last_value_expanding", format!("{null_pct}%_nulls")), + |b| { + b.iter(|| { + let mut evaluator = create_evaluator(NthValueKind::Last, true, None); + let values_slice = slice::from_ref(&values); + for i in 0..values.len() { + let range = Range { + start: 0, + end: i + 1, + }; + black_box(evaluator.evaluate(values_slice, &range).unwrap()); + } + }) + }, + ); + + // NTH_VALUE(col, 10) with ignore_nulls - get 10th non-null value + group.bench_function( + BenchmarkId::new("nth_value_10_expanding", format!("{null_pct}%_nulls")), + |b| { + b.iter(|| { + let mut evaluator = + create_evaluator(NthValueKind::Nth, true, Some(10)); + let values_slice = slice::from_ref(&values); + for i in 0..values.len() { + let range = Range { + start: 0, + end: i + 1, + }; + black_box(evaluator.evaluate(values_slice, &range).unwrap()); + } + }) + }, + ); + + // NTH_VALUE(col, -10) with ignore_nulls - get 10th from last non-null value + group.bench_function( + BenchmarkId::new("nth_value_neg10_expanding", format!("{null_pct}%_nulls")), + |b| { + b.iter(|| { + let mut evaluator = + create_evaluator(NthValueKind::Nth, true, Some(-10)); + let values_slice = slice::from_ref(&values); + for i in 0..values.len() { + let range = Range { + start: 0, + end: i + 1, + }; + black_box(evaluator.evaluate(values_slice, &range).unwrap()); + } + }) + }, + ); + + // Sliding window benchmarks with 100-row window + let window_size: usize = 100; + + group.bench_function( + BenchmarkId::new("first_value_sliding_100", 
format!("{null_pct}%_nulls")), + |b| { + b.iter(|| { + let mut evaluator = create_evaluator(NthValueKind::First, true, None); + let values_slice = slice::from_ref(&values); + for i in 0..values.len() { + let start = i.saturating_sub(window_size - 1); + let range = Range { start, end: i + 1 }; + black_box(evaluator.evaluate(values_slice, &range).unwrap()); + } + }) + }, + ); + + group.bench_function( + BenchmarkId::new("last_value_sliding_100", format!("{null_pct}%_nulls")), + |b| { + b.iter(|| { + let mut evaluator = create_evaluator(NthValueKind::Last, true, None); + let values_slice = slice::from_ref(&values); + for i in 0..values.len() { + let start = i.saturating_sub(window_size - 1); + let range = Range { start, end: i + 1 }; + black_box(evaluator.evaluate(values_slice, &range).unwrap()); + } + }) + }, + ); + } + + group.finish(); + + // Comparison benchmarks: ignore_nulls vs respect_nulls + let mut comparison_group = c.benchmark_group("nth_value_nulls_comparison"); + let values_with_nulls = + Arc::new(create_primitive_array::(ARRAY_SIZE, 0.5)) as ArrayRef; + + // FIRST_VALUE comparison + comparison_group.bench_function( + BenchmarkId::new("first_value", "ignore_nulls"), + |b| { + b.iter(|| { + let mut evaluator = create_evaluator(NthValueKind::First, true, None); + let values_slice = slice::from_ref(&values_with_nulls); + for i in 0..values_with_nulls.len() { + let range = Range { + start: 0, + end: i + 1, + }; + black_box(evaluator.evaluate(values_slice, &range).unwrap()); + } + }) + }, + ); + + comparison_group.bench_function( + BenchmarkId::new("first_value", "respect_nulls"), + |b| { + b.iter(|| { + let mut evaluator = create_evaluator(NthValueKind::First, false, None); + let values_slice = slice::from_ref(&values_with_nulls); + for i in 0..values_with_nulls.len() { + let range = Range { + start: 0, + end: i + 1, + }; + black_box(evaluator.evaluate(values_slice, &range).unwrap()); + } + }) + }, + ); + + // NTH_VALUE comparison + 
comparison_group.bench_function( + BenchmarkId::new("nth_value_10", "ignore_nulls"), + |b| { + b.iter(|| { + let mut evaluator = create_evaluator(NthValueKind::Nth, true, Some(10)); + let values_slice = slice::from_ref(&values_with_nulls); + for i in 0..values_with_nulls.len() { + let range = Range { + start: 0, + end: i + 1, + }; + black_box(evaluator.evaluate(values_slice, &range).unwrap()); + } + }) + }, + ); + + comparison_group.bench_function( + BenchmarkId::new("nth_value_10", "respect_nulls"), + |b| { + b.iter(|| { + let mut evaluator = create_evaluator(NthValueKind::Nth, false, Some(10)); + let values_slice = slice::from_ref(&values_with_nulls); + for i in 0..values_with_nulls.len() { + let range = Range { + start: 0, + end: i + 1, + }; + black_box(evaluator.evaluate(values_slice, &range).unwrap()); + } + }) + }, + ); + + comparison_group.finish(); +} + +criterion_group!(benches, bench_nth_value_ignore_nulls); +criterion_main!(benches); diff --git a/datafusion/functions-window/src/lib.rs b/datafusion/functions-window/src/lib.rs index 300313387388..6edfb92744f5 100644 --- a/datafusion/functions-window/src/lib.rs +++ b/datafusion/functions-window/src/lib.rs @@ -25,7 +25,6 @@ // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] // https://github.com/apache/datafusion/issues/18881 -#![deny(clippy::allow_attributes)] //! Window Function packages for [DataFusion]. //! 
diff --git a/datafusion/functions-window/src/nth_value.rs b/datafusion/functions-window/src/nth_value.rs index be08f25ec404..8d37cf7e604a 100644 --- a/datafusion/functions-window/src/nth_value.rs +++ b/datafusion/functions-window/src/nth_value.rs @@ -19,6 +19,7 @@ use crate::utils::{get_scalar_value_from_args, get_signed_integer}; +use arrow::buffer::NullBuffer; use arrow::datatypes::FieldRef; use datafusion_common::arrow::array::ArrayRef; use datafusion_common::arrow::datatypes::{DataType, Field}; @@ -96,7 +97,7 @@ impl NthValue { Self { signature: Signature::one_of( vec![ - TypeSignature::Any(0), + TypeSignature::Nullary, TypeSignature::Any(1), TypeSignature::Any(2), ], @@ -268,7 +269,7 @@ impl WindowUDFImpl for NthValue { kind: self.kind, }; - if !matches!(self.kind, NthValueKind::Nth) { + if self.kind != NthValueKind::Nth { return Ok(Box::new(NthValueEvaluator { state, ignore_nulls: partition_evaluator_args.ignore_nulls(), @@ -370,6 +371,33 @@ impl PartitionEvaluator for NthValueEvaluator { fn memoize(&mut self, state: &mut WindowAggState) -> Result<()> { let out = &state.out_col; let size = out.len(); + if self.ignore_nulls { + match self.state.kind { + // Prune on first non-null output in case of FIRST_VALUE + NthValueKind::First => { + if let Some(nulls) = out.nulls() { + if self.state.finalized_result.is_none() { + if let Some(valid_index) = nulls.valid_indices().next() { + let result = + ScalarValue::try_from_array(out, valid_index)?; + self.state.finalized_result = Some(result); + } else { + // The output is empty or all nulls, ignore + } + } + if state.window_frame_range.start < state.window_frame_range.end { + state.window_frame_range.start = + state.window_frame_range.end - 1; + } + return Ok(()); + } else { + // Fall through to the main case because there are no nulls + } + } + // Do not memoize for other kinds when nulls are ignored + NthValueKind::Last | NthValueKind::Nth => return Ok(()), + } + } let mut buffer_size = 1; // Decide if we arrived at 
a final result yet: let (is_prunable, is_reverse_direction) = match self.state.kind { @@ -397,8 +425,7 @@ impl PartitionEvaluator for NthValueEvaluator { } } }; - // Do not memoize results when nulls are ignored. - if is_prunable && !self.ignore_nulls { + if is_prunable { if self.state.finalized_result.is_none() && !is_reverse_direction { let result = ScalarValue::try_from_array(out, size - 1)?; self.state.finalized_result = Some(result); @@ -424,99 +451,90 @@ impl PartitionEvaluator for NthValueEvaluator { // We produce None if the window is empty. return ScalarValue::try_from(arr.data_type()); } + match self.valid_index(arr, range) { + Some(index) => ScalarValue::try_from_array(arr, index), + None => ScalarValue::try_from(arr.data_type()), + } + } + } - // If null values exist and need to be ignored, extract the valid indices. - let valid_indices = if self.ignore_nulls { - // Calculate valid indices, inside the window frame boundaries. - let slice = arr.slice(range.start, n_range); - match slice.nulls() { - Some(nulls) => { - let valid_indices = nulls - .valid_indices() - .map(|idx| { - // Add offset `range.start` to valid indices, to point correct index in the original arr. - idx + range.start - }) - .collect::>(); - if valid_indices.is_empty() { - // If all values are null, return directly. - return ScalarValue::try_from(arr.data_type()); - } - Some(valid_indices) - } - None => None, - } - } else { - None - }; - match self.state.kind { - NthValueKind::First => { - if let Some(valid_indices) = &valid_indices { - ScalarValue::try_from_array(arr, valid_indices[0]) + fn supports_bounded_execution(&self) -> bool { + true + } + + fn uses_window_frame(&self) -> bool { + true + } +} + +impl NthValueEvaluator { + fn valid_index(&self, array: &ArrayRef, range: &Range) -> Option { + let n_range = range.end - range.start; + if self.ignore_nulls { + // Calculate valid indices, inside the window frame boundaries. 
+ let slice = array.slice(range.start, n_range); + if let Some(nulls) = slice.nulls() + && nulls.null_count() > 0 + { + return self.valid_index_with_nulls(nulls, range.start); + } + } + // Either no nulls, or nulls are regarded as valid rows + match self.state.kind { + NthValueKind::First => Some(range.start), + NthValueKind::Last => Some(range.end - 1), + NthValueKind::Nth => match self.n.cmp(&0) { + Ordering::Greater => { + // SQL indices are not 0-based. + let index = (self.n as usize) - 1; + if index >= n_range { + // Outside the range, return NULL: + None } else { - ScalarValue::try_from_array(arr, range.start) + Some(range.start + index) } } - NthValueKind::Last => { - if let Some(valid_indices) = &valid_indices { - ScalarValue::try_from_array( - arr, - valid_indices[valid_indices.len() - 1], - ) + Ordering::Less => { + let reverse_index = (-self.n) as usize; + if n_range < reverse_index { + // Outside the range, return NULL: + None } else { - ScalarValue::try_from_array(arr, range.end - 1) + Some(range.end - reverse_index) } } - NthValueKind::Nth => { - match self.n.cmp(&0) { - Ordering::Greater => { - // SQL indices are not 0-based. 
- let index = (self.n as usize) - 1; - if index >= n_range { - // Outside the range, return NULL: - ScalarValue::try_from(arr.data_type()) - } else if let Some(valid_indices) = valid_indices { - if index >= valid_indices.len() { - return ScalarValue::try_from(arr.data_type()); - } - ScalarValue::try_from_array(&arr, valid_indices[index]) - } else { - ScalarValue::try_from_array(arr, range.start + index) - } - } - Ordering::Less => { - let reverse_index = (-self.n) as usize; - if n_range < reverse_index { - // Outside the range, return NULL: - ScalarValue::try_from(arr.data_type()) - } else if let Some(valid_indices) = valid_indices { - if reverse_index > valid_indices.len() { - return ScalarValue::try_from(arr.data_type()); - } - let new_index = - valid_indices[valid_indices.len() - reverse_index]; - ScalarValue::try_from_array(&arr, new_index) - } else { - ScalarValue::try_from_array( - arr, - range.start + n_range - reverse_index, - ) - } + Ordering::Equal => None, + }, + } + } + + fn valid_index_with_nulls(&self, nulls: &NullBuffer, offset: usize) -> Option { + match self.state.kind { + NthValueKind::First => nulls.valid_indices().next().map(|idx| idx + offset), + NthValueKind::Last => nulls.valid_indices().last().map(|idx| idx + offset), + NthValueKind::Nth => { + match self.n.cmp(&0) { + Ordering::Greater => { + // SQL indices are not 0-based. 
+ let index = (self.n as usize) - 1; + nulls.valid_indices().nth(index).map(|idx| idx + offset) + } + Ordering::Less => { + let reverse_index = (-self.n) as usize; + let valid_indices_len = nulls.len() - nulls.null_count(); + if reverse_index > valid_indices_len { + return None; } - Ordering::Equal => ScalarValue::try_from(arr.data_type()), + nulls + .valid_indices() + .nth(valid_indices_len - reverse_index) + .map(|idx| idx + offset) } + Ordering::Equal => None, } } } } - - fn supports_bounded_execution(&self) -> bool { - true - } - - fn uses_window_frame(&self) -> bool { - true - } } #[cfg(test)] diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 2bdc05abe380..1940f1378b63 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -82,12 +82,13 @@ hex = { workspace = true, optional = true } itertools = { workspace = true } log = { workspace = true } md-5 = { version = "^0.10.0", optional = true } +memchr = { workspace = true } num-traits = { workspace = true } rand = { workspace = true } regex = { workspace = true, optional = true } -sha2 = { version = "^0.10.9", optional = true } +sha2 = { workspace = true, optional = true } unicode-segmentation = { version = "^1.7.1", optional = true } -uuid = { version = "1.19", features = ["v4"], optional = true } +uuid = { workspace = true, features = ["v4"], optional = true } [dev-dependencies] arrow = { workspace = true, features = ["test_utils"] } @@ -107,6 +108,11 @@ harness = false name = "concat" required-features = ["string_expressions"] +[[bench]] +harness = false +name = "concat_ws" +required-features = ["string_expressions"] + [[bench]] harness = false name = "to_timestamp" @@ -127,6 +133,11 @@ harness = false name = "gcd" required-features = ["math_expressions"] +[[bench]] +harness = false +name = "nanvl" +required-features = ["math_expressions"] + [[bench]] harness = false name = "uuid" @@ -181,6 +192,11 @@ harness = false name = "signum" 
required-features = ["math_expressions"] +[[bench]] +harness = false +name = "atan2" +required-features = ["math_expressions"] + [[bench]] harness = false name = "substr_index" @@ -298,10 +314,25 @@ required-features = ["unicode_expressions"] [[bench]] harness = false -name = "left" +name = "split_part" +required-features = ["string_expressions"] + +[[bench]] +harness = false +name = "left_right" required-features = ["unicode_expressions"] [[bench]] harness = false name = "factorial" required-features = ["math_expressions"] + +[[bench]] +harness = false +name = "floor_ceil" +required-features = ["math_expressions"] + +[[bench]] +harness = false +name = "round" +required-features = ["math_expressions"] diff --git a/datafusion/functions/benches/ascii.rs b/datafusion/functions/benches/ascii.rs index 66d81261bfe8..a2424ed352af 100644 --- a/datafusion/functions/benches/ascii.rs +++ b/datafusion/functions/benches/ascii.rs @@ -15,19 +15,47 @@ // specific language governing permissions and limitations // under the License. 
-extern crate criterion; mod helper; use arrow::datatypes::{DataType, Field}; use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; use datafusion_common::config::ConfigOptions; -use datafusion_expr::ScalarFunctionArgs; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use helper::gen_string_array; use std::hint::black_box; use std::sync::Arc; fn criterion_benchmark(c: &mut Criterion) { let ascii = datafusion_functions::string::ascii(); + let config_options = Arc::new(ConfigOptions::default()); + + // Scalar benchmarks (outside loop) + c.bench_function("ascii/scalar_utf8", |b| { + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( + "hello".to_string(), + )))], + arg_fields: vec![Field::new("a", DataType::Utf8, false).into()], + number_rows: 1, + return_field: Field::new("f", DataType::Int32, true).into(), + config_options: Arc::clone(&config_options), + }; + b.iter(|| black_box(ascii.invoke_with_args(args.clone()).unwrap())) + }); + + c.bench_function("ascii/scalar_utf8view", |b| { + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + "hello".to_string(), + )))], + arg_fields: vec![Field::new("a", DataType::Utf8View, false).into()], + number_rows: 1, + return_field: Field::new("f", DataType::Int32, true).into(), + config_options: Arc::clone(&config_options), + }; + b.iter(|| black_box(ascii.invoke_with_args(args.clone()).unwrap())) + }); // All benches are single batch run with 8192 rows const N_ROWS: usize = 8192; diff --git a/datafusion/functions/benches/atan2.rs b/datafusion/functions/benches/atan2.rs new file mode 100644 index 000000000000..f1c9756a0cc0 --- /dev/null +++ b/datafusion/functions/benches/atan2.rs @@ -0,0 +1,146 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use arrow::datatypes::{DataType, Field, Float32Type, Float64Type}; +use arrow::util::bench_util::create_primitive_array; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_functions::math::atan2; +use std::hint::black_box; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let atan2_fn = atan2(); + let config_options = Arc::new(ConfigOptions::default()); + + for size in [1024, 4096, 8192] { + let y_f32 = Arc::new(create_primitive_array::(size, 0.2)); + let x_f32 = Arc::new(create_primitive_array::(size, 0.2)); + let f32_args = vec![ColumnarValue::Array(y_f32), ColumnarValue::Array(x_f32)]; + let f32_arg_fields = f32_args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let return_field_f32 = Field::new("f", DataType::Float32, true).into(); + + c.bench_function(&format!("atan2 f32 array: {size}"), |b| { + b.iter(|| { + black_box( + atan2_fn + .invoke_with_args(ScalarFunctionArgs { + args: f32_args.clone(), + arg_fields: f32_arg_fields.clone(), + number_rows: size, + 
return_field: Arc::clone(&return_field_f32), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + let y_f64 = Arc::new(create_primitive_array::(size, 0.2)); + let x_f64 = Arc::new(create_primitive_array::(size, 0.2)); + let f64_args = vec![ColumnarValue::Array(y_f64), ColumnarValue::Array(x_f64)]; + let f64_arg_fields = f64_args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let return_field_f64 = Field::new("f", DataType::Float64, true).into(); + + c.bench_function(&format!("atan2 f64 array: {size}"), |b| { + b.iter(|| { + black_box( + atan2_fn + .invoke_with_args(ScalarFunctionArgs { + args: f64_args.clone(), + arg_fields: f64_arg_fields.clone(), + number_rows: size, + return_field: Arc::clone(&return_field_f64), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + } + + let scalar_f32_args = vec![ + ColumnarValue::Scalar(ScalarValue::Float32(Some(1.0))), + ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))), + ]; + let scalar_f32_arg_fields = vec![ + Field::new("a", DataType::Float32, false).into(), + Field::new("b", DataType::Float32, false).into(), + ]; + let return_field_f32 = Field::new("f", DataType::Float32, false).into(); + + c.bench_function("atan2 f32 scalar", |b| { + b.iter(|| { + black_box( + atan2_fn + .invoke_with_args(ScalarFunctionArgs { + args: scalar_f32_args.clone(), + arg_fields: scalar_f32_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&return_field_f32), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + let scalar_f64_args = vec![ + ColumnarValue::Scalar(ScalarValue::Float64(Some(1.0))), + ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), + ]; + let scalar_f64_arg_fields = vec![ + Field::new("a", DataType::Float64, false).into(), + Field::new("b", DataType::Float64, false).into(), + ]; + let return_field_f64 = Field::new("f", 
DataType::Float64, false).into(); + + c.bench_function("atan2 f64 scalar", |b| { + b.iter(|| { + black_box( + atan2_fn + .invoke_with_args(ScalarFunctionArgs { + args: scalar_f64_args.clone(), + arg_fields: scalar_f64_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&return_field_f64), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/character_length.rs b/datafusion/functions/benches/character_length.rs index 35a0cf886b7f..4927627ec2f0 100644 --- a/datafusion/functions/benches/character_length.rs +++ b/datafusion/functions/benches/character_length.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::datatypes::{DataType, Field}; use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_common::config::ConfigOptions; diff --git a/datafusion/functions/benches/chr.rs b/datafusion/functions/benches/chr.rs index 9a6342ca40bb..a702dc161ae0 100644 --- a/datafusion/functions/benches/chr.rs +++ b/datafusion/functions/benches/chr.rs @@ -15,10 +15,9 @@ // specific language governing permissions and limitations // under the License. 
-extern crate criterion; - use arrow::{array::PrimitiveArray, datatypes::Int64Type}; use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string::chr; use rand::{Rng, SeedableRng}; @@ -35,11 +34,32 @@ pub fn seedable_rng() -> StdRng { } fn criterion_benchmark(c: &mut Criterion) { - let cot_fn = chr(); + let chr_fn = chr(); + let config_options = Arc::new(ConfigOptions::default()); + + // Scalar benchmarks + c.bench_function("chr/scalar", |b| { + let args = vec![ColumnarValue::Scalar(ScalarValue::Int64(Some(65)))]; + let arg_fields = vec![Field::new("arg_0", DataType::Int64, true).into()]; + b.iter(|| { + black_box( + chr_fn + .invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: 1, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + let size = 1024; let input: PrimitiveArray = { let null_density = 0.2; - let mut rng = StdRng::seed_from_u64(42); + let mut rng = seedable_rng(); (0..size) .map(|_| { if rng.random::() < null_density { @@ -57,12 +77,11 @@ fn criterion_benchmark(c: &mut Criterion) { .enumerate() .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) .collect::>(); - let config_options = Arc::new(ConfigOptions::default()); - c.bench_function("chr", |b| { + c.bench_function("chr/array", |b| { b.iter(|| { black_box( - cot_fn + chr_fn .invoke_with_args(ScalarFunctionArgs { args: args.clone(), arg_fields: arg_fields.clone(), diff --git a/datafusion/functions/benches/concat.rs b/datafusion/functions/benches/concat.rs index f7ef97892090..0fb910800e3b 100644 --- a/datafusion/functions/benches/concat.rs +++ b/datafusion/functions/benches/concat.rs @@ -17,16 +17,18 @@ use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Field}; -use 
arrow::util::bench_util::create_string_array_with_len; +use arrow::util::bench_util::{create_string_array_with_len, create_string_view_array}; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; use datafusion_common::ScalarValue; use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string::concat; +use rand::Rng; +use rand::distr::Alphanumeric; use std::hint::black_box; use std::sync::Arc; -fn create_args(size: usize, str_len: usize) -> Vec { +fn create_array_args(size: usize, str_len: usize) -> Vec { let array = Arc::new(create_string_array_with_len::(size, 0.2, str_len)); let scalar = ScalarValue::Utf8(Some(", ".to_string())); vec![ @@ -36,9 +38,37 @@ fn create_args(size: usize, str_len: usize) -> Vec { ] } +fn create_array_args_view(size: usize) -> Vec { + let array = Arc::new(create_string_view_array(size, 0.2)); + let scalar = ScalarValue::Utf8(Some(", ".to_string())); + vec![ + ColumnarValue::Array(Arc::clone(&array) as ArrayRef), + ColumnarValue::Scalar(scalar), + ColumnarValue::Array(array), + ] +} + +fn generate_random_string(str_len: usize) -> String { + rand::rng() + .sample_iter(&Alphanumeric) + .take(str_len) + .map(char::from) + .collect() +} + +fn create_scalar_args(count: usize, str_len: usize) -> Vec { + std::iter::repeat_with(|| { + let s = generate_random_string(str_len); + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) + }) + .take(count) + .collect() +} + fn criterion_benchmark(c: &mut Criterion) { + // Benchmark for array concat for size in [1024, 4096, 8192] { - let args = create_args(size, 32); + let args = create_array_args(size, 32); let arg_fields = args .iter() .enumerate() @@ -67,6 +97,70 @@ fn criterion_benchmark(c: &mut Criterion) { }); group.finish(); } + + // Benchmark for StringViewArray concat + for size in [1024, 4096, 8192] { + let args = create_array_args_view(size); + let arg_fields = args + .iter() + .enumerate() + 
.map(|(idx, arg)| { + // Use Utf8View for array args + let dt = if matches!(arg, ColumnarValue::Array(_)) { + DataType::Utf8View + } else { + DataType::Utf8 // scalar remains Utf8 + }; + Field::new(format!("arg_{idx}"), dt, true).into() + }) + .collect::>(); + let config_options = Arc::new(ConfigOptions::default()); + + let mut group = c.benchmark_group("concat function"); + group.bench_function(BenchmarkId::new("concat_view", size), |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box( + concat() + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8View, true) + .into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + group.finish(); + } + + // Benchmark for scalar concat + let scalar_args = create_scalar_args(10, 100); + let scalar_arg_fields = scalar_args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect::>(); + let mut group = c.benchmark_group("concat function"); + group.bench_function(BenchmarkId::new("concat", "scalar"), |b| { + b.iter(|| { + let args_cloned = scalar_args.clone(); + black_box( + concat() + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: scalar_arg_fields.clone(), + number_rows: 1, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }); } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/functions/benches/concat_ws.rs b/datafusion/functions/benches/concat_ws.rs new file mode 100644 index 000000000000..97d6d96411d7 --- /dev/null +++ b/datafusion/functions/benches/concat_ws.rs @@ -0,0 +1,123 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::ArrayRef; +use arrow::datatypes::{DataType, Field}; +use arrow::util::bench_util::create_string_array_with_len; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_functions::string::concat_ws; +use rand::Rng; +use rand::distr::Alphanumeric; +use std::hint::black_box; +use std::sync::Arc; + +fn create_array_args(size: usize, str_len: usize) -> Vec { + let array = Arc::new(create_string_array_with_len::(size, 0.2, str_len)); + let scalar = ScalarValue::Utf8(Some(", ".to_string())); + vec![ + ColumnarValue::Scalar(scalar), + ColumnarValue::Array(Arc::clone(&array) as ArrayRef), + ColumnarValue::Array(array), + ] +} + +fn generate_random_string(str_len: usize) -> String { + rand::rng() + .sample_iter(&Alphanumeric) + .take(str_len) + .map(char::from) + .collect() +} + +fn create_scalar_args(count: usize, str_len: usize) -> Vec { + let mut args = Vec::with_capacity(count + 1); + + args.push(ColumnarValue::Scalar(ScalarValue::Utf8(Some( + ",".to_string(), + )))); + + for _ in 0..count { + let s = generate_random_string(str_len); + 
args.push(ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)))); + } + args +} + +fn criterion_benchmark(c: &mut Criterion) { + // Benchmark for array concat_ws + for size in [1024, 4096, 8192] { + let args = create_array_args(size, 32); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let config_options = Arc::new(ConfigOptions::default()); + + let mut group = c.benchmark_group("concat_ws function"); + group.bench_function(BenchmarkId::new("concat_ws", size), |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box( + concat_ws() + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + group.finish(); + } + + // Benchmark for scalar concat_ws + let scalar_args = create_scalar_args(10, 100); + let scalar_arg_fields = scalar_args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect::>(); + let mut group = c.benchmark_group("concat_ws function"); + group.bench_function(BenchmarkId::new("concat_ws", "scalar"), |b| { + b.iter(|| { + let args_cloned = scalar_args.clone(); + black_box( + concat_ws() + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: scalar_arg_fields.clone(), + number_rows: 1, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/contains.rs b/datafusion/functions/benches/contains.rs index 052eff38869d..6c39f45e14fa 100644 --- a/datafusion/functions/benches/contains.rs +++ b/datafusion/functions/benches/contains.rs @@ -15,8 +15,6 @@ // 
specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::{StringArray, StringViewArray}; use arrow::datatypes::{DataType, Field}; use criterion::{Criterion, criterion_group, criterion_main}; diff --git a/datafusion/functions/benches/cot.rs b/datafusion/functions/benches/cot.rs index c47198d4a620..16c3fba2175f 100644 --- a/datafusion/functions/benches/cot.rs +++ b/datafusion/functions/benches/cot.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::{ datatypes::{Float32Type, Float64Type}, util::bench_util::create_primitive_array, @@ -27,11 +25,15 @@ use datafusion_functions::math::cot; use std::hint::black_box; use arrow::datatypes::{DataType, Field}; +use datafusion_common::ScalarValue; use datafusion_common::config::ConfigOptions; use std::sync::Arc; fn criterion_benchmark(c: &mut Criterion) { let cot_fn = cot(); + let config_options = Arc::new(ConfigOptions::default()); + + // Array benchmarks - run for different sizes for size in [1024, 4096, 8192] { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let f32_args = vec![ColumnarValue::Array(f32_array)]; @@ -42,7 +44,6 @@ fn criterion_benchmark(c: &mut Criterion) { Field::new(format!("arg_{idx}"), arg.data_type(), true).into() }) .collect::>(); - let config_options = Arc::new(ConfigOptions::default()); c.bench_function(&format!("cot f32 array: {size}"), |b| { b.iter(|| { @@ -59,6 +60,7 @@ fn criterion_benchmark(c: &mut Criterion) { ) }) }); + let f64_array = Arc::new(create_primitive_array::(size, 0.2)); let f64_args = vec![ColumnarValue::Array(f64_array)]; let arg_fields = f64_args @@ -86,6 +88,47 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); } + + // Scalar benchmarks - run only once since size doesn't affect scalar performance + let scalar_f32_args = vec![ColumnarValue::Scalar(ScalarValue::Float32(Some(1.0)))]; + let scalar_f32_arg_fields = 
vec![Field::new("a", DataType::Float32, false).into()]; + let return_field_f32 = Field::new("f", DataType::Float32, false).into(); + + c.bench_function("cot f32 scalar", |b| { + b.iter(|| { + black_box( + cot_fn + .invoke_with_args(ScalarFunctionArgs { + args: scalar_f32_args.clone(), + arg_fields: scalar_f32_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&return_field_f32), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + let scalar_f64_args = vec![ColumnarValue::Scalar(ScalarValue::Float64(Some(1.0)))]; + let scalar_f64_arg_fields = vec![Field::new("a", DataType::Float64, false).into()]; + let return_field_f64 = Field::new("f", DataType::Float64, false).into(); + + c.bench_function("cot f64 scalar", |b| { + b.iter(|| { + black_box( + cot_fn + .invoke_with_args(ScalarFunctionArgs { + args: scalar_f64_args.clone(), + arg_fields: scalar_f64_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&return_field_f64), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/functions/benches/crypto.rs b/datafusion/functions/benches/crypto.rs index bf30cc9a0c44..9a86efbff9ed 100644 --- a/datafusion/functions/benches/crypto.rs +++ b/datafusion/functions/benches/crypto.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::create_string_array_with_len; use criterion::{Criterion, criterion_group, criterion_main}; diff --git a/datafusion/functions/benches/date_bin.rs b/datafusion/functions/benches/date_bin.rs index eb4e960d8312..28dee9698726 100644 --- a/datafusion/functions/benches/date_bin.rs +++ b/datafusion/functions/benches/date_bin.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-extern crate criterion; - use std::hint::black_box; use std::sync::Arc; diff --git a/datafusion/functions/benches/date_trunc.rs b/datafusion/functions/benches/date_trunc.rs index f5c8ceb5fe9d..0668a1cc5085 100644 --- a/datafusion/functions/benches/date_trunc.rs +++ b/datafusion/functions/benches/date_trunc.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use std::hint::black_box; use std::sync::Arc; @@ -25,7 +23,7 @@ use arrow::datatypes::Field; use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_common::ScalarValue; use datafusion_common::config::ConfigOptions; -use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_expr::{ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs}; use datafusion_functions::datetime::date_trunc; use rand::Rng; use rand::rngs::ThreadRng; @@ -57,10 +55,13 @@ fn criterion_benchmark(c: &mut Criterion) { }) .collect::>(); - let return_type = udf - .return_type(&args.iter().map(|arg| arg.data_type()).collect::>()) + let scalar_arguments = vec![None; arg_fields.len()]; + let return_field = udf + .return_field_from_args(ReturnFieldArgs { + arg_fields: &arg_fields, + scalar_arguments: &scalar_arguments, + }) .unwrap(); - let return_field = Arc::new(Field::new("f", return_type, true)); let config_options = Arc::new(ConfigOptions::default()); b.iter(|| { diff --git a/datafusion/functions/benches/encoding.rs b/datafusion/functions/benches/encoding.rs index 8a7c2b7b664b..0b8f0c5c51a5 100644 --- a/datafusion/functions/benches/encoding.rs +++ b/datafusion/functions/benches/encoding.rs @@ -15,11 +15,9 @@ // specific language governing permissions and limitations // under the License. 
-extern crate criterion; - use arrow::array::Array; use arrow::datatypes::{DataType, Field}; -use arrow::util::bench_util::create_string_array_with_len; +use arrow::util::bench_util::create_binary_array; use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; @@ -32,20 +30,22 @@ fn criterion_benchmark(c: &mut Criterion) { let config_options = Arc::new(ConfigOptions::default()); for size in [1024, 4096, 8192] { - let str_array = Arc::new(create_string_array_with_len::(size, 0.2, 32)); + let bin_array = Arc::new(create_binary_array::(size, 0.2)); c.bench_function(&format!("base64_decode/{size}"), |b| { let method = ColumnarValue::Scalar("base64".into()); let encoded = encoding::encode() .invoke_with_args(ScalarFunctionArgs { - args: vec![ColumnarValue::Array(str_array.clone()), method.clone()], + args: vec![ColumnarValue::Array(bin_array.clone()), method.clone()], arg_fields: vec![ - Field::new("a", str_array.data_type().to_owned(), true).into(), + Field::new("a", bin_array.data_type().to_owned(), true).into(), Field::new("b", method.data_type().to_owned(), true).into(), ], number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), }) + .unwrap() + .cast_to(&DataType::Binary, None) .unwrap(); let arg_fields = vec![ @@ -61,7 +61,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: arg_fields.clone(), number_rows: size, - return_field: Field::new("f", DataType::Utf8, true).into(), + return_field: Field::new("f", DataType::Binary, true).into(), config_options: Arc::clone(&config_options), }) .unwrap(), @@ -72,24 +72,26 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function(&format!("hex_decode/{size}"), |b| { let method = ColumnarValue::Scalar("hex".into()); let arg_fields = vec![ - Field::new("a", str_array.data_type().to_owned(), true).into(), + Field::new("a", 
bin_array.data_type().to_owned(), true).into(), Field::new("b", method.data_type().to_owned(), true).into(), ]; let encoded = encoding::encode() .invoke_with_args(ScalarFunctionArgs { - args: vec![ColumnarValue::Array(str_array.clone()), method.clone()], + args: vec![ColumnarValue::Array(bin_array.clone()), method.clone()], arg_fields, number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), }) + .unwrap() + .cast_to(&DataType::Binary, None) .unwrap(); let arg_fields = vec![ Field::new("a", encoded.data_type().to_owned(), true).into(), Field::new("b", method.data_type().to_owned(), true).into(), ]; - let return_field = Field::new("f", DataType::Utf8, true).into(); + let return_field = Field::new("f", DataType::Binary, true).into(); let args = vec![encoded, method]; b.iter(|| { diff --git a/datafusion/functions/benches/ends_with.rs b/datafusion/functions/benches/ends_with.rs index 926fd9ff72a5..474e8a1555cf 100644 --- a/datafusion/functions/benches/ends_with.rs +++ b/datafusion/functions/benches/ends_with.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::{StringArray, StringViewArray}; use arrow::datatypes::{DataType, Field}; use criterion::{Criterion, criterion_group, criterion_main}; diff --git a/datafusion/functions/benches/factorial.rs b/datafusion/functions/benches/factorial.rs index 5c5ff991d745..c441b50c288c 100644 --- a/datafusion/functions/benches/factorial.rs +++ b/datafusion/functions/benches/factorial.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-extern crate criterion; - use arrow::array::Int64Array; use arrow::datatypes::{DataType, Field}; use criterion::{Criterion, criterion_group, criterion_main}; diff --git a/datafusion/functions/benches/find_in_set.rs b/datafusion/functions/benches/find_in_set.rs index e207c1fa48ab..9ee20ecd14fd 100644 --- a/datafusion/functions/benches/find_in_set.rs +++ b/datafusion/functions/benches/find_in_set.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::{StringArray, StringViewArray}; use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ diff --git a/datafusion/functions/benches/floor_ceil.rs b/datafusion/functions/benches/floor_ceil.rs new file mode 100644 index 000000000000..dc095e0152c4 --- /dev/null +++ b/datafusion/functions/benches/floor_ceil.rs @@ -0,0 +1,133 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::datatypes::{DataType, Field, Float64Type}; +use arrow::util::bench_util::create_primitive_array; +use criterion::{Criterion, SamplingMode, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_functions::math::{ceil, floor}; +use std::hint::black_box; +use std::sync::Arc; +use std::time::Duration; + +fn criterion_benchmark(c: &mut Criterion) { + let floor_fn = floor(); + let ceil_fn = ceil(); + let config_options = Arc::new(ConfigOptions::default()); + + for size in [1024, 4096, 8192] { + let mut group = c.benchmark_group(format!("floor_ceil size={size}")); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + group.measurement_time(Duration::from_secs(10)); + + // Float64 array benchmark + let f64_array = Arc::new(create_primitive_array::(size, 0.1)); + let batch_len = f64_array.len(); + let f64_args = vec![ColumnarValue::Array(f64_array)]; + + group.bench_function("floor_f64_array", |b| { + b.iter(|| { + let args_cloned = f64_args.clone(); + black_box( + floor_fn + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![ + Field::new("a", DataType::Float64, true).into(), + ], + number_rows: batch_len, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + group.bench_function("ceil_f64_array", |b| { + b.iter(|| { + let args_cloned = f64_args.clone(); + black_box( + ceil_fn + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![ + Field::new("a", DataType::Float64, true).into(), + ], + number_rows: batch_len, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + // Scalar benchmark (the optimization we added) + let scalar_args = 
vec![ColumnarValue::Scalar(ScalarValue::Float64(Some( + std::f64::consts::PI, + )))]; + + group.bench_function("floor_f64_scalar", |b| { + b.iter(|| { + let args_cloned = scalar_args.clone(); + black_box( + floor_fn + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![ + Field::new("a", DataType::Float64, false).into(), + ], + number_rows: 1, + return_field: Field::new("f", DataType::Float64, false) + .into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + group.bench_function("ceil_f64_scalar", |b| { + b.iter(|| { + let args_cloned = scalar_args.clone(); + black_box( + ceil_fn + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![ + Field::new("a", DataType::Float64, false).into(), + ], + number_rows: 1, + return_field: Field::new("f", DataType::Float64, false) + .into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + group.finish(); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/gcd.rs b/datafusion/functions/benches/gcd.rs index 9705af8a2fcd..3c72a46e6643 100644 --- a/datafusion/functions/benches/gcd.rs +++ b/datafusion/functions/benches/gcd.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::datatypes::Field; use arrow::{ array::{ArrayRef, Int64Array}, diff --git a/datafusion/functions/benches/initcap.rs b/datafusion/functions/benches/initcap.rs index ba055d58f566..b5e653e4136a 100644 --- a/datafusion/functions/benches/initcap.rs +++ b/datafusion/functions/benches/initcap.rs @@ -15,19 +15,19 @@ // specific language governing permissions and limitations // under the License. 
-extern crate criterion; - -use arrow::array::OffsetSizeTrait; +use arrow::array::{ArrayRef, OffsetSizeTrait, StringArray, StringViewBuilder}; use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; -use criterion::{Criterion, criterion_group, criterion_main}; +use criterion::{Criterion, SamplingMode, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::unicode; use std::hint::black_box; use std::sync::Arc; +use std::time::Duration; fn create_args( size: usize, @@ -47,62 +47,161 @@ fn create_args( } } +/// Create a Utf8 array where every value contains non-ASCII Unicode text. +fn create_unicode_utf8_args(size: usize) -> Vec { + let array = Arc::new(StringArray::from_iter_values(std::iter::repeat_n( + "ñAnDÚ ÁrBOL ОлЕГ ÍslENsku", + size, + ))) as ArrayRef; + vec![ColumnarValue::Array(array)] +} + +/// Create a Utf8View array where every value contains non-ASCII Unicode text. 
+fn create_unicode_utf8view_args(size: usize) -> Vec { + let mut builder = StringViewBuilder::with_capacity(size); + for _ in 0..size { + builder.append_value("ñAnDÚ ÁrBOL ОлЕГ ÍslENsku"); + } + let array = Arc::new(builder.finish()) as ArrayRef; + vec![ColumnarValue::Array(array)] +} + fn criterion_benchmark(c: &mut Criterion) { let initcap = unicode::initcap(); - for size in [1024, 4096] { - let args = create_args::(size, 8, true); - let arg_fields = args - .iter() - .enumerate() - .map(|(idx, arg)| { - Field::new(format!("arg_{idx}"), arg.data_type(), true).into() - }) - .collect::>(); - let config_options = Arc::new(ConfigOptions::default()); + let config_options = Arc::new(ConfigOptions::default()); + + // Array benchmarks: vary both row count and string length + for size in [1024, 4096, 8192] { + for str_len in [16, 128] { + let mut group = + c.benchmark_group(format!("initcap size={size} str_len={str_len}")); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + group.measurement_time(Duration::from_secs(10)); + + // Utf8 + let array_args = create_args::(size, str_len, false); + let array_arg_fields = vec![Field::new("arg_0", DataType::Utf8, true).into()]; - c.bench_function( - format!("initcap string view shorter than 12 [size={size}]").as_str(), - |b| { + group.bench_function("array_utf8", |b| { b.iter(|| { black_box(initcap.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - arg_fields: arg_fields.clone(), + args: array_args.clone(), + arg_fields: array_arg_fields.clone(), number_rows: size, - return_field: Field::new("f", DataType::Utf8View, true).into(), + return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), })) }) - }, - ); + }); + + // Utf8View + let array_view_args = create_args::(size, str_len, true); + let array_view_arg_fields = + vec![Field::new("arg_0", DataType::Utf8View, true).into()]; - let args = create_args::(size, 16, true); - c.bench_function( - format!("initcap 
string view longer than 12 [size={size}]").as_str(), - |b| { + group.bench_function("array_utf8view", |b| { b.iter(|| { black_box(initcap.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - arg_fields: arg_fields.clone(), + args: array_view_args.clone(), + arg_fields: array_view_arg_fields.clone(), number_rows: size, return_field: Field::new("f", DataType::Utf8View, true).into(), config_options: Arc::clone(&config_options), })) }) - }, - ); + }); - let args = create_args::(size, 16, false); - c.bench_function(format!("initcap string [size={size}]").as_str(), |b| { + group.finish(); + } + } + + // Unicode array benchmarks + for size in [1024, 4096, 8192] { + let mut group = c.benchmark_group(format!("initcap unicode size={size}")); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + group.measurement_time(Duration::from_secs(10)); + + let unicode_args = create_unicode_utf8_args(size); + let unicode_arg_fields = vec![Field::new("arg_0", DataType::Utf8, true).into()]; + + group.bench_function("array_utf8", |b| { b.iter(|| { black_box(initcap.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - arg_fields: arg_fields.clone(), + args: unicode_args.clone(), + arg_fields: unicode_arg_fields.clone(), number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), })) }) }); + + let unicode_view_args = create_unicode_utf8view_args(size); + let unicode_view_arg_fields = + vec![Field::new("arg_0", DataType::Utf8View, true).into()]; + + group.bench_function("array_utf8view", |b| { + b.iter(|| { + black_box(initcap.invoke_with_args(ScalarFunctionArgs { + args: unicode_view_args.clone(), + arg_fields: unicode_view_arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8View, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }); + + group.finish(); + } + + // Scalar benchmarks: independent of array size, run once + { + let mut 
group = c.benchmark_group("initcap scalar"); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + group.measurement_time(Duration::from_secs(10)); + + // Utf8 + let scalar_args = vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( + "hello world test string".to_string(), + )))]; + let scalar_arg_fields = vec![Field::new("arg_0", DataType::Utf8, false).into()]; + + group.bench_function("scalar_utf8", |b| { + b.iter(|| { + black_box(initcap.invoke_with_args(ScalarFunctionArgs { + args: scalar_args.clone(), + arg_fields: scalar_arg_fields.clone(), + number_rows: 1, + return_field: Field::new("f", DataType::Utf8, false).into(), + config_options: Arc::clone(&config_options), + })) + }) + }); + + // Utf8View + let scalar_view_args = vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + "hello world test string".to_string(), + )))]; + let scalar_view_arg_fields = + vec![Field::new("arg_0", DataType::Utf8View, false).into()]; + + group.bench_function("scalar_utf8view", |b| { + b.iter(|| { + black_box(initcap.invoke_with_args(ScalarFunctionArgs { + args: scalar_view_args.clone(), + arg_fields: scalar_view_arg_fields.clone(), + number_rows: 1, + return_field: Field::new("f", DataType::Utf8View, false).into(), + config_options: Arc::clone(&config_options), + })) + }) + }); + + group.finish(); } } diff --git a/datafusion/functions/benches/isnan.rs b/datafusion/functions/benches/isnan.rs index d4e41e882fe2..e353b9d27a0a 100644 --- a/datafusion/functions/benches/isnan.rs +++ b/datafusion/functions/benches/isnan.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-extern crate criterion; - use arrow::datatypes::{DataType, Field}; use arrow::{ datatypes::{Float32Type, Float64Type}, diff --git a/datafusion/functions/benches/iszero.rs b/datafusion/functions/benches/iszero.rs index 53e38745afa9..c6d0aed4c615 100644 --- a/datafusion/functions/benches/iszero.rs +++ b/datafusion/functions/benches/iszero.rs @@ -15,14 +15,13 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::datatypes::{DataType, Field}; use arrow::{ datatypes::{Float32Type, Float64Type}, util::bench_util::create_primitive_array, }; use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::math::iszero; @@ -31,6 +30,8 @@ use std::sync::Arc; fn criterion_benchmark(c: &mut Criterion) { let iszero = iszero(); + let config_options = Arc::new(ConfigOptions::default()); + for size in [1024, 4096, 8192] { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let batch_len = f32_array.len(); @@ -43,7 +44,6 @@ fn criterion_benchmark(c: &mut Criterion) { }) .collect::>(); let return_field = Arc::new(Field::new("f", DataType::Boolean, true)); - let config_options = Arc::new(ConfigOptions::default()); c.bench_function(&format!("iszero f32 array: {size}"), |b| { b.iter(|| { @@ -60,6 +60,7 @@ fn criterion_benchmark(c: &mut Criterion) { ) }) }); + let f64_array = Arc::new(create_primitive_array::(size, 0.2)); let batch_len = f64_array.len(); let f64_args = vec![ColumnarValue::Array(f64_array)]; @@ -88,6 +89,46 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); } + + // Scalar benchmarks - run once since size doesn't affect scalar performance + let scalar_f32_args = vec![ColumnarValue::Scalar(ScalarValue::Float32(Some(1.0)))]; + let scalar_f32_arg_fields = vec![Field::new("a", DataType::Float32, false).into()]; + let 
return_field_scalar = Arc::new(Field::new("f", DataType::Boolean, false)); + + c.bench_function("iszero f32 scalar", |b| { + b.iter(|| { + black_box( + iszero + .invoke_with_args(ScalarFunctionArgs { + args: scalar_f32_args.clone(), + arg_fields: scalar_f32_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&return_field_scalar), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + let scalar_f64_args = vec![ColumnarValue::Scalar(ScalarValue::Float64(Some(1.0)))]; + let scalar_f64_arg_fields = vec![Field::new("a", DataType::Float64, false).into()]; + + c.bench_function("iszero f64 scalar", |b| { + b.iter(|| { + black_box( + iszero + .invoke_with_args(ScalarFunctionArgs { + args: scalar_f64_args.clone(), + arg_fields: scalar_f64_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&return_field_scalar), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/functions/benches/left.rs b/datafusion/functions/benches/left.rs deleted file mode 100644 index 3ea628fe2987..000000000000 --- a/datafusion/functions/benches/left.rs +++ /dev/null @@ -1,111 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. 
See the License for the -// specific language governing permissions and limitations -// under the License. - -extern crate criterion; - -use std::hint::black_box; -use std::sync::Arc; - -use arrow::array::{ArrayRef, Int64Array}; -use arrow::datatypes::{DataType, Field}; -use arrow::util::bench_util::create_string_array_with_len; -use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; -use datafusion_common::config::ConfigOptions; -use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; -use datafusion_functions::unicode::left; - -fn create_args(size: usize, str_len: usize, use_negative: bool) -> Vec { - let string_array = Arc::new(create_string_array_with_len::(size, 0.1, str_len)); - - // For negative n, we want to trigger the double-iteration code path - let n_values: Vec = if use_negative { - (0..size).map(|i| -((i % 10 + 1) as i64)).collect() - } else { - (0..size).map(|i| (i % 10 + 1) as i64).collect() - }; - let n_array = Arc::new(Int64Array::from(n_values)); - - vec![ - ColumnarValue::Array(string_array), - ColumnarValue::Array(Arc::clone(&n_array) as ArrayRef), - ] -} - -fn criterion_benchmark(c: &mut Criterion) { - for size in [1024, 4096] { - let mut group = c.benchmark_group(format!("left size={size}")); - - // Benchmark with positive n (no optimization needed) - let args = create_args(size, 32, false); - group.bench_function(BenchmarkId::new("positive n", size), |b| { - let arg_fields = args - .iter() - .enumerate() - .map(|(idx, arg)| { - Field::new(format!("arg_{idx}"), arg.data_type(), true).into() - }) - .collect::>(); - let config_options = Arc::new(ConfigOptions::default()); - - b.iter(|| { - black_box( - left() - .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - arg_fields: arg_fields.clone(), - number_rows: size, - return_field: Field::new("f", DataType::Utf8, true).into(), - config_options: Arc::clone(&config_options), - }) - .expect("left should work"), - ) - }) - }); - - // Benchmark with negative n (triggers 
optimization) - let args = create_args(size, 32, true); - group.bench_function(BenchmarkId::new("negative n", size), |b| { - let arg_fields = args - .iter() - .enumerate() - .map(|(idx, arg)| { - Field::new(format!("arg_{idx}"), arg.data_type(), true).into() - }) - .collect::>(); - let config_options = Arc::new(ConfigOptions::default()); - - b.iter(|| { - black_box( - left() - .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - arg_fields: arg_fields.clone(), - number_rows: size, - return_field: Field::new("f", DataType::Utf8, true).into(), - config_options: Arc::clone(&config_options), - }) - .expect("left should work"), - ) - }) - }); - - group.finish(); - } -} - -criterion_group!(benches, criterion_benchmark); -criterion_main!(benches); diff --git a/datafusion/functions/benches/left_right.rs b/datafusion/functions/benches/left_right.rs new file mode 100644 index 000000000000..59f8d8a75f74 --- /dev/null +++ b/datafusion/functions/benches/left_right.rs @@ -0,0 +1,128 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::hint::black_box; +use std::sync::Arc; + +use arrow::array::{ArrayRef, Int64Array}; +use arrow::datatypes::{DataType, Field}; +use arrow::util::bench_util::{ + create_string_array_with_len, create_string_view_array_with_len, +}; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_functions::unicode::{left, right}; + +fn create_args( + size: usize, + str_len: usize, + use_negative: bool, + is_string_view: bool, +) -> Vec { + let string_arg = if is_string_view { + ColumnarValue::Array(Arc::new(create_string_view_array_with_len( + size, 0.1, str_len, true, + ))) + } else { + ColumnarValue::Array(Arc::new(create_string_array_with_len::( + size, 0.1, str_len, + ))) + }; + + // For negative n, we want to trigger the double-iteration code path + let n_values: Vec = if use_negative { + (0..size).map(|i| -((i % 10 + 1) as i64)).collect() + } else { + (0..size).map(|i| (i % 10 + 1) as i64).collect() + }; + let n_array = Arc::new(Int64Array::from(n_values)); + + vec![ + string_arg, + ColumnarValue::Array(Arc::clone(&n_array) as ArrayRef), + ] +} + +fn criterion_benchmark(c: &mut Criterion) { + let left_function = left(); + let right_function = right(); + + for function in [left_function, right_function] { + for is_string_view in [false, true] { + for is_negative in [false, true] { + for size in [1024, 4096] { + let function_name = function.name(); + let mut group = + c.benchmark_group(format!("{function_name} size={size}")); + + let bench_name = format!( + "{} {} n", + if is_string_view { + "string_view_array" + } else { + "string_array" + }, + if is_negative { "negative" } else { "positive" }, + ); + let return_type = if is_string_view { + DataType::Utf8View + } else { + DataType::Utf8 + }; + + let args = create_args(size, 32, is_negative, is_string_view); + group.bench_function(BenchmarkId::new(bench_name, size), |b| { + 
let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true) + .into() + }) + .collect::>(); + let config_options = Arc::new(ConfigOptions::default()); + + b.iter(|| { + black_box( + function + .invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new( + "f", + return_type.clone(), + true, + ) + .into(), + config_options: Arc::clone(&config_options), + }) + .expect("should work"), + ) + }) + }); + + group.finish(); + } + } + } + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/levenshtein.rs b/datafusion/functions/benches/levenshtein.rs index 19f81b6cafcb..08733b245ffb 100644 --- a/datafusion/functions/benches/levenshtein.rs +++ b/datafusion/functions/benches/levenshtein.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::OffsetSizeTrait; use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::create_string_array_with_len; diff --git a/datafusion/functions/benches/lower.rs b/datafusion/functions/benches/lower.rs index 333dca390054..6dbc8dcb7d14 100644 --- a/datafusion/functions/benches/lower.rs +++ b/datafusion/functions/benches/lower.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::{ArrayRef, StringArray, StringViewBuilder}; use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ diff --git a/datafusion/functions/benches/make_date.rs b/datafusion/functions/benches/make_date.rs index 8b1b32edfc9c..42b5b1019538 100644 --- a/datafusion/functions/benches/make_date.rs +++ b/datafusion/functions/benches/make_date.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-extern crate criterion; - use std::hint::black_box; use std::sync::Arc; diff --git a/datafusion/functions/benches/nanvl.rs b/datafusion/functions/benches/nanvl.rs new file mode 100644 index 000000000000..206eebd81eb8 --- /dev/null +++ b/datafusion/functions/benches/nanvl.rs @@ -0,0 +1,114 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +extern crate criterion; + +use arrow::array::{ArrayRef, Float32Array, Float64Array}; +use arrow::datatypes::{DataType, Field}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_functions::math::nanvl; +use std::hint::black_box; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let nanvl_fn = nanvl(); + let config_options = Arc::new(ConfigOptions::default()); + + // Scalar benchmarks + c.bench_function("nanvl/scalar_f64", |b| { + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Float64(Some(f64::NAN))), + ColumnarValue::Scalar(ScalarValue::Float64(Some(1.0))), + ], + arg_fields: vec![ + Field::new("a", DataType::Float64, true).into(), + Field::new("b", DataType::Float64, true).into(), + ], + number_rows: 1, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::clone(&config_options), + }; + + b.iter(|| black_box(nanvl_fn.invoke_with_args(args.clone()).unwrap())) + }); + + c.bench_function("nanvl/scalar_f32", |b| { + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Float32(Some(f32::NAN))), + ColumnarValue::Scalar(ScalarValue::Float32(Some(1.0))), + ], + arg_fields: vec![ + Field::new("a", DataType::Float32, true).into(), + Field::new("b", DataType::Float32, true).into(), + ], + number_rows: 1, + return_field: Field::new("f", DataType::Float32, true).into(), + config_options: Arc::clone(&config_options), + }; + + b.iter(|| black_box(nanvl_fn.invoke_with_args(args.clone()).unwrap())) + }); + + // Array benchmarks + for size in [1024, 4096, 8192] { + let a64: ArrayRef = Arc::new(Float64Array::from(vec![f64::NAN; size])); + let b64: ArrayRef = Arc::new(Float64Array::from(vec![1.0; size])); + c.bench_function(&format!("nanvl/array_f64/{size}"), |bench| { + let args = 
ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::clone(&a64)), + ColumnarValue::Array(Arc::clone(&b64)), + ], + arg_fields: vec![ + Field::new("a", DataType::Float64, true).into(), + Field::new("b", DataType::Float64, true).into(), + ], + number_rows: size, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::clone(&config_options), + }; + bench.iter(|| black_box(nanvl_fn.invoke_with_args(args.clone()).unwrap())) + }); + + let a32: ArrayRef = Arc::new(Float32Array::from(vec![f32::NAN; size])); + let b32: ArrayRef = Arc::new(Float32Array::from(vec![1.0; size])); + c.bench_function(&format!("nanvl/array_f32/{size}"), |bench| { + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::clone(&a32)), + ColumnarValue::Array(Arc::clone(&b32)), + ], + arg_fields: vec![ + Field::new("a", DataType::Float32, true).into(), + Field::new("b", DataType::Float32, true).into(), + ], + number_rows: size, + return_field: Field::new("f", DataType::Float32, true).into(), + config_options: Arc::clone(&config_options), + }; + bench.iter(|| black_box(nanvl_fn.invoke_with_args(args.clone()).unwrap())) + }); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/nullif.rs b/datafusion/functions/benches/nullif.rs index f937d19421e8..f9f063c52d0d 100644 --- a/datafusion/functions/benches/nullif.rs +++ b/datafusion/functions/benches/nullif.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-extern crate criterion; - use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::create_string_array_with_len; use criterion::{Criterion, criterion_group, criterion_main}; diff --git a/datafusion/functions/benches/pad.rs b/datafusion/functions/benches/pad.rs index f6b2ed7636bf..0f856f0fef38 100644 --- a/datafusion/functions/benches/pad.rs +++ b/datafusion/functions/benches/pad.rs @@ -15,9 +15,10 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - -use arrow::array::{ArrowPrimitiveType, OffsetSizeTrait, PrimitiveArray}; +use arrow::array::{ + ArrowPrimitiveType, GenericStringBuilder, OffsetSizeTrait, PrimitiveArray, + StringViewBuilder, +}; use arrow::datatypes::{DataType, Field, Int64Type}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, @@ -32,6 +33,51 @@ use std::hint::black_box; use std::sync::Arc; use std::time::Duration; +const UNICODE_STRINGS: &[&str] = &[ + "Ñandú", + "Íslensku", + "Þjóðarinnar", + "Ελληνική", + "Иванович", + "データフュージョン", + "José García", + "Ölçü bïrïmï", + "Ÿéšṱëṟḏàÿ", + "Ährenstraße", +]; + +fn create_unicode_string_array( + size: usize, + null_density: f32, +) -> arrow::array::GenericStringArray { + let mut rng = rand::rng(); + let mut builder = GenericStringBuilder::::new(); + for i in 0..size { + if rng.random::() < null_density { + builder.append_null(); + } else { + builder.append_value(UNICODE_STRINGS[i % UNICODE_STRINGS.len()]); + } + } + builder.finish() +} + +fn create_unicode_string_view_array( + size: usize, + null_density: f32, +) -> arrow::array::StringViewArray { + let mut rng = rand::rng(); + let mut builder = StringViewBuilder::with_capacity(size); + for i in 0..size { + if rng.random::() < null_density { + builder.append_null(); + } else { + builder.append_value(UNICODE_STRINGS[i % UNICODE_STRINGS.len()]); + } + } + builder.finish() +} + struct Filter { dist: Dist, } @@ -69,6 +115,34 @@ where 
.collect() } +/// Create args for pad benchmark with Unicode strings +fn create_unicode_pad_args( + size: usize, + target_len: usize, + use_string_view: bool, +) -> Vec { + let length_array = + Arc::new(create_primitive_array::(size, 0.0, target_len)); + + if use_string_view { + let string_array = create_unicode_string_view_array(size, 0.1); + let fill_array = create_unicode_string_view_array(size, 0.1); + vec![ + ColumnarValue::Array(Arc::new(string_array)), + ColumnarValue::Array(length_array), + ColumnarValue::Array(Arc::new(fill_array)), + ] + } else { + let string_array = create_unicode_string_array::(size, 0.1); + let fill_array = create_unicode_string_array::(size, 0.1); + vec![ + ColumnarValue::Array(Arc::new(string_array)), + ColumnarValue::Array(length_array), + ColumnarValue::Array(Arc::new(fill_array)), + ] + } +} + /// Create args for pad benchmark fn create_pad_args( size: usize, @@ -210,6 +284,58 @@ fn criterion_benchmark(c: &mut Criterion) { }, ); + // Utf8 type with Unicode strings + let args = create_unicode_pad_args(size, 20, false); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + + group.bench_function( + format!("lpad utf8 unicode [size={size}, target=20]"), + |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(unicode::lpad().invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); + + // StringView type with Unicode strings + let args = create_unicode_pad_args(size, 20, true); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + + group.bench_function( + format!("lpad stringview unicode [size={size}, target=20]"), + |b| { + 
b.iter(|| { + let args_cloned = args.clone(); + black_box(unicode::lpad().invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8View, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); + group.finish(); } @@ -324,6 +450,58 @@ fn criterion_benchmark(c: &mut Criterion) { }, ); + // Utf8 type with Unicode strings + let args = create_unicode_pad_args(size, 20, false); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + + group.bench_function( + format!("rpad utf8 unicode [size={size}, target=20]"), + |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(unicode::rpad().invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); + + // StringView type with Unicode strings + let args = create_unicode_pad_args(size, 20, true); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + + group.bench_function( + format!("rpad stringview unicode [size={size}, target=20]"), + |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(unicode::rpad().invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8View, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); + group.finish(); } } diff --git a/datafusion/functions/benches/random.rs b/datafusion/functions/benches/random.rs index 3d8631140c05..71ded120eb51 100644 --- a/datafusion/functions/benches/random.rs +++ b/datafusion/functions/benches/random.rs @@ 
-15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::datatypes::{DataType, Field}; use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_common::config::ConfigOptions; diff --git a/datafusion/functions/benches/regexp_count.rs b/datafusion/functions/benches/regexp_count.rs index eae7ef00f16b..bce76c05585b 100644 --- a/datafusion/functions/benches/regexp_count.rs +++ b/datafusion/functions/benches/regexp_count.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::Int64Array; use arrow::array::OffsetSizeTrait; use arrow::datatypes::{DataType, Field}; diff --git a/datafusion/functions/benches/regx.rs b/datafusion/functions/benches/regx.rs index 32378ccd126e..a46b548236d0 100644 --- a/datafusion/functions/benches/regx.rs +++ b/datafusion/functions/benches/regx.rs @@ -15,25 +15,27 @@ // specific language governing permissions and limitations // under the License. 
-extern crate criterion; +use std::hint::black_box; +use std::iter; +use std::sync::Arc; use arrow::array::builder::StringBuilder; use arrow::array::{ArrayRef, AsArray, Int64Array, StringArray, StringViewArray}; use arrow::compute::cast; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; use datafusion_functions::regex::regexpcount::regexp_count_func; use datafusion_functions::regex::regexpinstr::regexp_instr_func; -use datafusion_functions::regex::regexplike::regexp_like; +use datafusion_functions::regex::regexplike::{RegexpLikeFunc, regexp_like}; use datafusion_functions::regex::regexpmatch::regexp_match; use datafusion_functions::regex::regexpreplace::regexp_replace; use rand::Rng; use rand::distr::Alphanumeric; use rand::prelude::IndexedRandom; use rand::rngs::ThreadRng; -use std::hint::black_box; -use std::iter; -use std::sync::Arc; fn data(rng: &mut ThreadRng) -> StringArray { let mut data: Vec = vec![]; for _ in 0..1000 { @@ -107,6 +109,8 @@ fn subexp(rng: &mut ThreadRng) -> Int64Array { } fn criterion_benchmark(c: &mut Criterion) { + let regexp_like_func = RegexpLikeFunc::new(); + let config_options = Arc::new(ConfigOptions::default()); c.bench_function("regexp_count_1000 string", |b| { let mut rng = rand::rng(); let data = Arc::new(data(&mut rng)) as ArrayRef; @@ -221,6 +225,32 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); + let scalar_args = vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("foobarbequebaz".to_string()))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("(bar)(beque)".to_string()))), + ]; + let scalar_arg_fields = vec![ + Field::new("arg_0", DataType::Utf8, false).into(), + Field::new("arg_1", DataType::Utf8, false).into(), + ]; + let return_field = Field::new("f", DataType::Boolean, 
true).into(); + + c.bench_function("regexp_like scalar utf8", |b| { + b.iter(|| { + black_box( + regexp_like_func + .invoke_with_args(ScalarFunctionArgs { + args: scalar_args.clone(), + arg_fields: scalar_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + }) + .expect("regexp_like scalar should work on valid values"), + ) + }) + }); + c.bench_function("regexp_match_1000", |b| { let mut rng = rand::rng(); let data = Arc::new(data(&mut rng)) as ArrayRef; diff --git a/datafusion/functions/benches/repeat.rs b/datafusion/functions/benches/repeat.rs index 304739b42f5f..354812c0d2ea 100644 --- a/datafusion/functions/benches/repeat.rs +++ b/datafusion/functions/benches/repeat.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::{ArrayRef, Int64Array, OffsetSizeTrait}; use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ @@ -24,6 +22,7 @@ use arrow::util::bench_util::{ }; use criterion::{Criterion, SamplingMode, criterion_group, criterion_main}; use datafusion_common::DataFusionError; +use datafusion_common::ScalarValue; use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string; @@ -80,6 +79,44 @@ fn invoke_repeat_with_args( } fn criterion_benchmark(c: &mut Criterion) { + let repeat_fn = string::repeat(); + let config_options = Arc::new(ConfigOptions::default()); + + // Scalar benchmarks (outside loop) + c.bench_function("repeat/scalar_utf8", |b| { + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("hello".to_string()))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), + ], + arg_fields: vec![ + Field::new("a", DataType::Utf8, false).into(), + Field::new("b", DataType::Int64, false).into(), + ], + number_rows: 1, + return_field: Field::new("f", 
DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + }; + b.iter(|| black_box(repeat_fn.invoke_with_args(args.clone()).unwrap())) + }); + + c.bench_function("repeat/scalar_utf8view", |b| { + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("hello".to_string()))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), + ], + arg_fields: vec![ + Field::new("a", DataType::Utf8View, false).into(), + Field::new("b", DataType::Int64, false).into(), + ], + number_rows: 1, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + }; + b.iter(|| black_box(repeat_fn.invoke_with_args(args.clone()).unwrap())) + }); + for size in [1024, 4096] { // REPEAT 3 TIMES let repeat_times = 3; diff --git a/datafusion/functions/benches/replace.rs b/datafusion/functions/benches/replace.rs index deadbfeb99a8..55fbd6ae57af 100644 --- a/datafusion/functions/benches/replace.rs +++ b/datafusion/functions/benches/replace.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::OffsetSizeTrait; use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ diff --git a/datafusion/functions/benches/reverse.rs b/datafusion/functions/benches/reverse.rs index 73f5be5b45df..f2e2898bbfe4 100644 --- a/datafusion/functions/benches/reverse.rs +++ b/datafusion/functions/benches/reverse.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; mod helper; use arrow::datatypes::{DataType, Field}; diff --git a/datafusion/functions/benches/round.rs b/datafusion/functions/benches/round.rs new file mode 100644 index 000000000000..7010aa3507db --- /dev/null +++ b/datafusion/functions/benches/round.rs @@ -0,0 +1,152 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::{DataType, Field, Float32Type, Float64Type}; +use arrow::util::bench_util::create_primitive_array; +use criterion::{Criterion, SamplingMode, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_functions::math::round; +use std::hint::black_box; +use std::sync::Arc; +use std::time::Duration; + +fn criterion_benchmark(c: &mut Criterion) { + let round_fn = round(); + let config_options = Arc::new(ConfigOptions::default()); + + for size in [1024, 4096, 8192] { + let mut group = c.benchmark_group(format!("round size={size}")); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + group.measurement_time(Duration::from_secs(10)); + + // Float64 array benchmark + let f64_array = Arc::new(create_primitive_array::(size, 0.1)); + let batch_len = f64_array.len(); + let f64_args = vec![ + ColumnarValue::Array(f64_array), + ColumnarValue::Scalar(ScalarValue::Int32(Some(2))), + ]; + + group.bench_function("round_f64_array", |b| { + b.iter(|| { + let args_cloned = f64_args.clone(); + black_box( + round_fn + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![ + Field::new("a", 
DataType::Float64, true).into(), + Field::new("b", DataType::Int32, false).into(), + ], + number_rows: batch_len, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + // Float32 array benchmark + let f32_array = Arc::new(create_primitive_array::(size, 0.1)); + let f32_args = vec![ + ColumnarValue::Array(f32_array), + ColumnarValue::Scalar(ScalarValue::Int32(Some(2))), + ]; + + group.bench_function("round_f32_array", |b| { + b.iter(|| { + let args_cloned = f32_args.clone(); + black_box( + round_fn + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![ + Field::new("a", DataType::Float32, true).into(), + Field::new("b", DataType::Int32, false).into(), + ], + number_rows: batch_len, + return_field: Field::new("f", DataType::Float32, true).into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + // Scalar benchmark (the optimization we added) + let scalar_f64_args = vec![ + ColumnarValue::Scalar(ScalarValue::Float64(Some(std::f64::consts::PI))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(2))), + ]; + + group.bench_function("round_f64_scalar", |b| { + b.iter(|| { + let args_cloned = scalar_f64_args.clone(); + black_box( + round_fn + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![ + Field::new("a", DataType::Float64, false).into(), + Field::new("b", DataType::Int32, false).into(), + ], + number_rows: 1, + return_field: Field::new("f", DataType::Float64, false) + .into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + let scalar_f32_args = vec![ + ColumnarValue::Scalar(ScalarValue::Float32(Some(std::f32::consts::PI))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(2))), + ]; + + group.bench_function("round_f32_scalar", |b| { + b.iter(|| { + let args_cloned = scalar_f32_args.clone(); + black_box( + round_fn + 
.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![ + Field::new("a", DataType::Float32, false).into(), + Field::new("b", DataType::Int32, false).into(), + ], + number_rows: 1, + return_field: Field::new("f", DataType::Float32, false) + .into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + group.finish(); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/signum.rs b/datafusion/functions/benches/signum.rs index 08a197a60eb7..e98d1b2c22ea 100644 --- a/datafusion/functions/benches/signum.rs +++ b/datafusion/functions/benches/signum.rs @@ -15,14 +15,13 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::datatypes::DataType; use arrow::{ datatypes::{Field, Float32Type, Float64Type}, util::bench_util::create_primitive_array, }; use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::math::signum; @@ -88,6 +87,51 @@ fn criterion_benchmark(c: &mut Criterion) { ) }) }); + + // Scalar benchmarks (the optimization we added) + let scalar_f32_args = + vec![ColumnarValue::Scalar(ScalarValue::Float32(Some(-42.5)))]; + let scalar_f32_arg_fields = + vec![Field::new("a", DataType::Float32, false).into()]; + let return_field_f32 = Field::new("f", DataType::Float32, false).into(); + + c.bench_function(&format!("signum f32 scalar: {size}"), |b| { + b.iter(|| { + black_box( + signum + .invoke_with_args(ScalarFunctionArgs { + args: scalar_f32_args.clone(), + arg_fields: scalar_f32_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&return_field_f32), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + let scalar_f64_args = + 
vec![ColumnarValue::Scalar(ScalarValue::Float64(Some(-42.5)))]; + let scalar_f64_arg_fields = + vec![Field::new("a", DataType::Float64, false).into()]; + let return_field_f64 = Field::new("f", DataType::Float64, false).into(); + + c.bench_function(&format!("signum f64 scalar: {size}"), |b| { + b.iter(|| { + black_box( + signum + .invoke_with_args(ScalarFunctionArgs { + args: scalar_f64_args.clone(), + arg_fields: scalar_f64_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&return_field_f64), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); } } diff --git a/datafusion/functions/benches/split_part.rs b/datafusion/functions/benches/split_part.rs new file mode 100644 index 000000000000..7ef84a058920 --- /dev/null +++ b/datafusion/functions/benches/split_part.rs @@ -0,0 +1,380 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::{ArrayRef, Int64Array, StringArray, StringViewArray}; +use arrow::datatypes::{DataType, Field}; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_functions::string::split_part; +use rand::distr::Alphanumeric; +use rand::prelude::StdRng; +use rand::{Rng, SeedableRng}; +use std::hint::black_box; +use std::sync::Arc; + +const N_ROWS: usize = 8192; + +/// Generate test data for split_part benchmarks +/// Creates strings with multiple parts separated by the delimiter +fn gen_split_part_data( + n_rows: usize, + num_parts: usize, // number of parts in each string (separated by delimiter) + part_len: usize, // length of each part + delimiter: &str, // the delimiter to use + use_string_view: bool, // false -> StringArray, true -> StringViewArray +) -> (ColumnarValue, ColumnarValue) { + let mut rng = StdRng::seed_from_u64(42); + + let mut strings: Vec = Vec::with_capacity(n_rows); + for _ in 0..n_rows { + let mut parts: Vec = Vec::with_capacity(num_parts); + for _ in 0..num_parts { + let part: String = (&mut rng) + .sample_iter(&Alphanumeric) + .take(part_len) + .map(char::from) + .collect(); + parts.push(part); + } + strings.push(parts.join(delimiter)); + } + + let delimiters: Vec = vec![delimiter.to_string(); n_rows]; + + if use_string_view { + let string_array: StringViewArray = strings.into_iter().map(Some).collect(); + let delimiter_array: StringViewArray = delimiters.into_iter().map(Some).collect(); + ( + ColumnarValue::Array(Arc::new(string_array) as ArrayRef), + ColumnarValue::Array(Arc::new(delimiter_array) as ArrayRef), + ) + } else { + let string_array: StringArray = strings.into_iter().map(Some).collect(); + let delimiter_array: StringArray = delimiters.into_iter().map(Some).collect(); + ( + ColumnarValue::Array(Arc::new(string_array) as ArrayRef), + ColumnarValue::Array(Arc::new(delimiter_array) 
as ArrayRef), + ) + } +} + +fn gen_positions(n_rows: usize, position: i64) -> ColumnarValue { + let positions: Vec = vec![position; n_rows]; + ColumnarValue::Array(Arc::new(Int64Array::from(positions)) as ArrayRef) +} + +fn criterion_benchmark(c: &mut Criterion) { + let split_part_func = split_part(); + let config_options = Arc::new(ConfigOptions::default()); + + let mut group = c.benchmark_group("split_part"); + + // Test different scenarios + // Scenario 1: Single-char delimiter, first position (should be fastest with optimization) + { + let (strings, delimiters) = gen_split_part_data(N_ROWS, 10, 8, ".", false); + let positions = gen_positions(N_ROWS, 1); + let args = vec![strings, delimiters, positions]; + let arg_fields: Vec<_> = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect(); + let return_field = Field::new("f", DataType::Utf8, true).into(); + + group.bench_function(BenchmarkId::new("single_char_delim", "pos_first"), |b| { + b.iter(|| { + black_box( + split_part_func + .invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: N_ROWS, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + }) + .expect("split_part should work"), + ) + }) + }); + } + + // Scenario 2: Single-char delimiter, middle position + { + let (strings, delimiters) = gen_split_part_data(N_ROWS, 10, 8, ".", false); + let positions = gen_positions(N_ROWS, 5); + let args = vec![strings, delimiters, positions]; + let arg_fields: Vec<_> = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect(); + let return_field = Field::new("f", DataType::Utf8, true).into(); + + group.bench_function(BenchmarkId::new("single_char_delim", "pos_middle"), |b| { + b.iter(|| { + black_box( + split_part_func + .invoke_with_args(ScalarFunctionArgs { + args: 
args.clone(), + arg_fields: arg_fields.clone(), + number_rows: N_ROWS, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + }) + .expect("split_part should work"), + ) + }) + }); + } + + // Scenario 3: Single-char delimiter, last position + { + let (strings, delimiters) = gen_split_part_data(N_ROWS, 10, 8, ".", false); + let positions = gen_positions(N_ROWS, 10); + let args = vec![strings, delimiters, positions]; + let arg_fields: Vec<_> = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect(); + let return_field = Field::new("f", DataType::Utf8, true).into(); + + group.bench_function(BenchmarkId::new("single_char_delim", "pos_last"), |b| { + b.iter(|| { + black_box( + split_part_func + .invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: N_ROWS, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + }) + .expect("split_part should work"), + ) + }) + }); + } + + // Scenario 4: Single-char delimiter, negative position (last element) + { + let (strings, delimiters) = gen_split_part_data(N_ROWS, 10, 8, ".", false); + let positions = gen_positions(N_ROWS, -1); + let args = vec![strings, delimiters, positions]; + let arg_fields: Vec<_> = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect(); + let return_field = Field::new("f", DataType::Utf8, true).into(); + + group.bench_function( + BenchmarkId::new("single_char_delim", "pos_negative"), + |b| { + b.iter(|| { + black_box( + split_part_func + .invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: N_ROWS, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + }) + .expect("split_part should work"), + ) + }) + }, + ); + } + + // Scenario 5: 
Multi-char delimiter, first position + { + let (strings, delimiters) = gen_split_part_data(N_ROWS, 10, 8, "~@~", false); + let positions = gen_positions(N_ROWS, 1); + let args = vec![strings, delimiters, positions]; + let arg_fields: Vec<_> = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect(); + let return_field = Field::new("f", DataType::Utf8, true).into(); + + group.bench_function(BenchmarkId::new("multi_char_delim", "pos_first"), |b| { + b.iter(|| { + black_box( + split_part_func + .invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: N_ROWS, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + }) + .expect("split_part should work"), + ) + }) + }); + } + + // Scenario 6: Multi-char delimiter, middle position + { + let (strings, delimiters) = gen_split_part_data(N_ROWS, 10, 8, "~@~", false); + let positions = gen_positions(N_ROWS, 5); + let args = vec![strings, delimiters, positions]; + let arg_fields: Vec<_> = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect(); + let return_field = Field::new("f", DataType::Utf8, true).into(); + + group.bench_function(BenchmarkId::new("multi_char_delim", "pos_middle"), |b| { + b.iter(|| { + black_box( + split_part_func + .invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: N_ROWS, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + }) + .expect("split_part should work"), + ) + }) + }); + } + + // Scenario 7: StringViewArray, single-char delimiter, first position + { + let (strings, delimiters) = gen_split_part_data(N_ROWS, 10, 8, ".", true); + let positions = gen_positions(N_ROWS, 1); + let args = vec![strings, delimiters, positions]; + let arg_fields: Vec<_> = args + 
.iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect(); + let return_field = Field::new("f", DataType::Utf8, true).into(); + + group.bench_function( + BenchmarkId::new("string_view_single_char", "pos_first"), + |b| { + b.iter(|| { + black_box( + split_part_func + .invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: N_ROWS, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + }) + .expect("split_part should work"), + ) + }) + }, + ); + } + + // Scenario 8: Many parts (20), position near end - shows benefit of early termination + { + let (strings, delimiters) = gen_split_part_data(N_ROWS, 20, 8, ".", false); + let positions = gen_positions(N_ROWS, 2); + let args = vec![strings, delimiters, positions]; + let arg_fields: Vec<_> = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect(); + let return_field = Field::new("f", DataType::Utf8, true).into(); + + group.bench_function(BenchmarkId::new("many_parts_20", "pos_second"), |b| { + b.iter(|| { + black_box( + split_part_func + .invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: N_ROWS, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + }) + .expect("split_part should work"), + ) + }) + }); + } + + // Scenario 9: Long strings with many parts - worst case for old implementation + { + let (strings, delimiters) = gen_split_part_data(N_ROWS, 50, 16, "/", false); + let positions = gen_positions(N_ROWS, 1); + let args = vec![strings, delimiters, positions]; + let arg_fields: Vec<_> = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect(); + let return_field = Field::new("f", DataType::Utf8, true).into(); + + 
group.bench_function( + BenchmarkId::new("long_strings_50_parts", "pos_first"), + |b| { + b.iter(|| { + black_box( + split_part_func + .invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: N_ROWS, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + }) + .expect("split_part should work"), + ) + }) + }, + ); + } + + group.finish(); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/starts_with.rs b/datafusion/functions/benches/starts_with.rs index 9ee39b694539..17483f0da7a0 100644 --- a/datafusion/functions/benches/starts_with.rs +++ b/datafusion/functions/benches/starts_with.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::{StringArray, StringViewArray}; use arrow::datatypes::{DataType, Field}; use criterion::{Criterion, criterion_group, criterion_main}; diff --git a/datafusion/functions/benches/strpos.rs b/datafusion/functions/benches/strpos.rs index 9babf1d05c05..94ce919c3d80 100644 --- a/datafusion/functions/benches/strpos.rs +++ b/datafusion/functions/benches/strpos.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::{StringArray, StringViewArray}; use arrow::datatypes::{DataType, Field}; use criterion::{Criterion, criterion_group, criterion_main}; @@ -29,9 +27,12 @@ use std::hint::black_box; use std::str::Chars; use std::sync::Arc; -/// gen_arr(4096, 128, 0.1, 0.1, true) will generate a StringViewArray with -/// 4096 rows, each row containing a string with 128 random characters. -/// around 10% of the rows are null, around 10% of the rows are non-ASCII. +/// Returns a `Vec` with two elements: a haystack array and a +/// needle array. Each haystack is a random string of `str_len_chars` +/// characters. 
Each needle is a random contiguous substring of its +/// corresponding haystack (i.e., the needle is always present in the haystack). +/// Around `null_density` fraction of rows are null and `utf8_density` fraction +/// contain non-ASCII characters; the remaining rows are ASCII-only. fn gen_string_array( n_rows: usize, str_len_chars: usize, diff --git a/datafusion/functions/benches/substr.rs b/datafusion/functions/benches/substr.rs index a6989c1bca45..37a1e178f561 100644 --- a/datafusion/functions/benches/substr.rs +++ b/datafusion/functions/benches/substr.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::{ArrayRef, Int64Array, OffsetSizeTrait}; use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ diff --git a/datafusion/functions/benches/substr_index.rs b/datafusion/functions/benches/substr_index.rs index 88600317c996..663e7928bfd9 100644 --- a/datafusion/functions/benches/substr_index.rs +++ b/datafusion/functions/benches/substr_index.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-extern crate criterion; - use std::hint::black_box; use std::sync::Arc; @@ -50,7 +48,10 @@ where } } -fn data() -> (StringArray, StringArray, Int64Array) { +fn data( + batch_size: usize, + single_char_delimiter: bool, +) -> (StringArray, StringArray, Int64Array) { let dist = Filter { dist: Uniform::new(-4, 5), test: |x: &i64| x != &0, @@ -60,19 +61,39 @@ fn data() -> (StringArray, StringArray, Int64Array) { let mut delimiters: Vec = vec![]; let mut counts: Vec = vec![]; - for _ in 0..1000 { + for _ in 0..batch_size { let length = rng.random_range(20..50); - let text: String = (&mut rng) + let base: String = (&mut rng) .sample_iter(&Alphanumeric) .take(length) .map(char::from) .collect(); - let char = rng.random_range(0..text.len()); - let delimiter = &text.chars().nth(char).unwrap(); + + let (string_value, delimiter): (String, String) = if single_char_delimiter { + let char_idx = rng.random_range(0..base.chars().count()); + let delimiter = base.chars().nth(char_idx).unwrap().to_string(); + (base, delimiter) + } else { + let long_delimiters = ["|||", "***", "&&&", "###", "@@@", "$$$"]; + let delimiter = + long_delimiters[rng.random_range(0..long_delimiters.len())].to_string(); + + let delimiter_count = rng.random_range(1..4); + let mut result = String::new(); + + for i in 0..delimiter_count { + result.push_str(&base); + if i < delimiter_count - 1 { + result.push_str(&delimiter); + } + } + (result, delimiter) + }; + let count = rng.sample(dist.dist.unwrap()); - strings.push(text); - delimiters.push(delimiter.to_string()); + strings.push(string_value); + delimiters.push(delimiter); counts.push(count); } @@ -83,38 +104,63 @@ fn data() -> (StringArray, StringArray, Int64Array) { ) } -fn criterion_benchmark(c: &mut Criterion) { - c.bench_function("substr_index_array_array_1000", |b| { - let (strings, delimiters, counts) = data(); - let batch_len = counts.len(); - let strings = ColumnarValue::Array(Arc::new(strings) as ArrayRef); - let delimiters = 
ColumnarValue::Array(Arc::new(delimiters) as ArrayRef); - let counts = ColumnarValue::Array(Arc::new(counts) as ArrayRef); - - let args = vec![strings, delimiters, counts]; - let arg_fields = args - .iter() - .enumerate() - .map(|(idx, arg)| { - Field::new(format!("arg_{idx}"), arg.data_type(), true).into() - }) - .collect::>(); - let config_options = Arc::new(ConfigOptions::default()); - - b.iter(|| { - black_box( - substr_index() - .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - arg_fields: arg_fields.clone(), - number_rows: batch_len, - return_field: Field::new("f", DataType::Utf8, true).into(), - config_options: Arc::clone(&config_options), - }) - .expect("substr_index should work on valid values"), - ) +fn run_benchmark( + b: &mut criterion::Bencher, + strings: StringArray, + delimiters: StringArray, + counts: Int64Array, + batch_size: usize, +) { + let strings = ColumnarValue::Array(Arc::new(strings) as ArrayRef); + let delimiters = ColumnarValue::Array(Arc::new(delimiters) as ArrayRef); + let counts = ColumnarValue::Array(Arc::new(counts) as ArrayRef); + + let args = vec![strings, delimiters, counts]; + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type().clone(), true).into() }) - }); + .collect::>(); + let config_options = Arc::new(ConfigOptions::default()); + + b.iter(|| { + black_box( + substr_index() + .invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: batch_size, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + }) + .expect("substr_index should work on valid values"), + ) + }) +} + +fn criterion_benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("substr_index"); + + let batch_sizes = [100, 1000, 10_000]; + + for batch_size in batch_sizes { + group.bench_function( + format!("substr_index_{batch_size}_single_delimiter"), + |b| { + let 
(strings, delimiters, counts) = data(batch_size, true); + run_benchmark(b, strings, delimiters, counts, batch_size); + }, + ); + + group.bench_function(format!("substr_index_{batch_size}_long_delimiter"), |b| { + let (strings, delimiters, counts) = data(batch_size, false); + run_benchmark(b, strings, delimiters, counts, batch_size); + }); + } + + group.finish(); } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/functions/benches/to_char.rs b/datafusion/functions/benches/to_char.rs index ac5b5dc7e03a..4d866570b7dd 100644 --- a/datafusion/functions/benches/to_char.rs +++ b/datafusion/functions/benches/to_char.rs @@ -15,18 +15,15 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use std::hint::black_box; use std::sync::Arc; -use arrow::array::{ArrayRef, Date32Array, StringArray}; +use arrow::array::{ArrayRef, Date32Array, Date64Array, StringArray}; use arrow::datatypes::{DataType, Field}; use chrono::TimeDelta; use chrono::prelude::*; use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_common::ScalarValue; -use datafusion_common::ScalarValue::TimestampNanosecond; use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::datetime::to_char; @@ -65,6 +62,26 @@ fn generate_date32_array(rng: &mut ThreadRng) -> Date32Array { Date32Array::from(data) } +fn generate_date64_array(rng: &mut ThreadRng) -> Date64Array { + let start_date = "1970-01-01" + .parse::() + .expect("Date should parse"); + let end_date = "2050-12-31" + .parse::() + .expect("Date should parse"); + let mut data: Vec = Vec::with_capacity(1000); + for _ in 0..1000 { + let date = pick_date_in_range(rng, start_date, end_date); + let millis = date + .and_hms_opt(0, 0, 0) + .unwrap() + .and_utc() + .timestamp_millis(); + data.push(millis); + } + Date64Array::from(data) +} + const DATE_PATTERNS: [&str; 5] = ["%Y:%m:%d", "%d-%m-%Y", 
"%d%m%Y", "%Y%m%d", "%Y...%m...%d"]; @@ -157,7 +174,7 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("to_char_array_datetime_patterns_1000", |b| { let mut rng = rand::rng(); - let data_arr = generate_date32_array(&mut rng); + let data_arr = generate_date64_array(&mut rng); let batch_len = data_arr.len(); let data = ColumnarValue::Array(Arc::new(data_arr) as ArrayRef); let patterns = ColumnarValue::Array(Arc::new(generate_datetime_pattern_array( @@ -184,7 +201,7 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("to_char_array_mixed_patterns_1000", |b| { let mut rng = rand::rng(); - let data_arr = generate_date32_array(&mut rng); + let data_arr = generate_date64_array(&mut rng); let batch_len = data_arr.len(); let data = ColumnarValue::Array(Arc::new(data_arr) as ArrayRef); let patterns = ColumnarValue::Array(Arc::new(generate_mixed_pattern_array( @@ -237,7 +254,7 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("to_char_scalar_datetime_pattern_1000", |b| { let mut rng = rand::rng(); - let data_arr = generate_date32_array(&mut rng); + let data_arr = generate_date64_array(&mut rng); let batch_len = data_arr.len(); let data = ColumnarValue::Array(Arc::new(data_arr) as ArrayRef); let patterns = ColumnarValue::Scalar(ScalarValue::Utf8(Some( @@ -261,38 +278,6 @@ fn criterion_benchmark(c: &mut Criterion) { ) }) }); - - c.bench_function("to_char_scalar_1000", |b| { - let mut rng = rand::rng(); - let timestamp = "2026-07-08T09:10:11" - .parse::() - .unwrap() - .with_nanosecond(56789) - .unwrap() - .and_utc() - .timestamp_nanos_opt() - .unwrap(); - let data = ColumnarValue::Scalar(TimestampNanosecond(Some(timestamp), None)); - let pattern = - ColumnarValue::Scalar(ScalarValue::Utf8(Some(pick_date_pattern(&mut rng)))); - - b.iter(|| { - black_box( - to_char() - .invoke_with_args(ScalarFunctionArgs { - args: vec![data.clone(), pattern.clone()], - arg_fields: vec![ - Field::new("a", data.data_type(), true).into(), - Field::new("b", 
pattern.data_type(), true).into(), - ], - number_rows: 1, - return_field: Field::new("f", DataType::Utf8, true).into(), - config_options: Arc::clone(&config_options), - }) - .expect("to_char should work on valid values"), - ) - }) - }); } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/functions/benches/to_hex.rs b/datafusion/functions/benches/to_hex.rs index 1c6757a291b2..33f8d9c49e8e 100644 --- a/datafusion/functions/benches/to_hex.rs +++ b/datafusion/functions/benches/to_hex.rs @@ -15,12 +15,11 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::Int64Array; use arrow::datatypes::{DataType, Field, Int32Type, Int64Type}; use arrow::util::bench_util::create_primitive_array; use criterion::{Criterion, SamplingMode, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string; @@ -32,6 +31,42 @@ fn criterion_benchmark(c: &mut Criterion) { let hex = string::to_hex(); let config_options = Arc::new(ConfigOptions::default()); + c.bench_function("to_hex/scalar_i32", |b| { + let args = vec![ColumnarValue::Scalar(ScalarValue::Int32(Some(2147483647)))]; + let arg_fields = vec![Field::new("a", DataType::Int32, true).into()]; + b.iter(|| { + black_box( + hex.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: 1, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + c.bench_function("to_hex/scalar_i64", |b| { + let args = vec![ColumnarValue::Scalar(ScalarValue::Int64(Some( + 9223372036854775807, + )))]; + let arg_fields = vec![Field::new("a", DataType::Int64, true).into()]; + b.iter(|| { + black_box( + hex.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: 
arg_fields.clone(), + number_rows: 1, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + for size in [1024, 4096, 8192] { let mut group = c.benchmark_group(format!("to_hex size={size}")); group.sampling_mode(SamplingMode::Flat); diff --git a/datafusion/functions/benches/to_timestamp.rs b/datafusion/functions/benches/to_timestamp.rs index ed865fa6e8d5..90ea145d5d2c 100644 --- a/datafusion/functions/benches/to_timestamp.rs +++ b/datafusion/functions/benches/to_timestamp.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use std::hint::black_box; use std::sync::Arc; diff --git a/datafusion/functions/benches/translate.rs b/datafusion/functions/benches/translate.rs index 601bdec7cd36..d0568ba0f535 100644 --- a/datafusion/functions/benches/translate.rs +++ b/datafusion/functions/benches/translate.rs @@ -15,23 +15,23 @@ // specific language governing permissions and limitations // under the License. 
-extern crate criterion; - use arrow::array::OffsetSizeTrait; use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::create_string_array_with_len; use criterion::{Criterion, SamplingMode, criterion_group, criterion_main}; -use datafusion_common::DataFusionError; use datafusion_common::config::ConfigOptions; +use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::unicode; use std::hint::black_box; use std::sync::Arc; use std::time::Duration; -fn create_args(size: usize, str_len: usize) -> Vec { +fn create_args_array_from_to( + size: usize, + str_len: usize, +) -> Vec { let string_array = Arc::new(create_string_array_with_len::(size, 0.1, str_len)); - // Create simple from/to strings for translation let from_array = Arc::new(create_string_array_with_len::(size, 0.1, 3)); let to_array = Arc::new(create_string_array_with_len::(size, 0.1, 2)); @@ -42,6 +42,19 @@ fn create_args(size: usize, str_len: usize) -> Vec( + size: usize, + str_len: usize, +) -> Vec { + let string_array = Arc::new(create_string_array_with_len::(size, 0.1, str_len)); + + vec![ + ColumnarValue::Array(string_array), + ColumnarValue::Scalar(ScalarValue::from("aeiou")), + ColumnarValue::Scalar(ScalarValue::from("AEIOU")), + ] +} + fn invoke_translate_with_args( args: Vec, number_rows: usize, @@ -69,17 +82,22 @@ fn criterion_benchmark(c: &mut Criterion) { group.sample_size(10); group.measurement_time(Duration::from_secs(10)); - for str_len in [8, 32] { - let args = create_args::(size, str_len); - group.bench_function( - format!("translate_string [size={size}, str_len={str_len}]"), - |b| { - b.iter(|| { - let args_cloned = args.clone(); - black_box(invoke_translate_with_args(args_cloned, size)) - }) - }, - ); + for str_len in [8, 32, 128, 1024] { + let args = create_args_array_from_to::(size, str_len); + group.bench_function(format!("array_from_to [str_len={str_len}]"), |b| { + b.iter(|| { + let args_cloned 
= args.clone(); + black_box(invoke_translate_with_args(args_cloned, size)) + }) + }); + + let args = create_args_scalar_from_to::(size, str_len); + group.bench_function(format!("scalar_from_to [str_len={str_len}]"), |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(invoke_translate_with_args(args_cloned, size)) + }) + }); } group.finish(); diff --git a/datafusion/functions/benches/trim.rs b/datafusion/functions/benches/trim.rs index 29bbc3f7dcb4..21d99592d182 100644 --- a/datafusion/functions/benches/trim.rs +++ b/datafusion/functions/benches/trim.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::{ArrayRef, LargeStringArray, StringArray, StringViewArray}; use arrow::datatypes::{DataType, Field}; use criterion::{ @@ -143,7 +141,46 @@ fn create_args( ] } -#[allow(clippy::too_many_arguments)] +/// Create args for trim benchmark where space characters are being trimmed +fn create_space_trim_args( + size: usize, + pad_len: usize, + remaining_len: usize, + string_array_type: StringArrayType, + trim_type: TrimType, +) -> Vec { + let rng = &mut StdRng::seed_from_u64(42); + let spaces = " ".repeat(pad_len); + + let string_iter = (0..size).map(|_| { + if rng.random::() < 0.1 { + None + } else { + let content: String = rng + .sample_iter(&Alphanumeric) + .take(remaining_len) + .map(char::from) + .collect(); + + let value = match trim_type { + TrimType::Ltrim => format!("{spaces}{content}"), + TrimType::Rtrim => format!("{content}{spaces}"), + TrimType::Btrim => format!("{spaces}{content}{spaces}"), + }; + Some(value) + } + }); + + let string_array: ArrayRef = match string_array_type { + StringArrayType::Utf8View => Arc::new(string_iter.collect::()), + StringArrayType::Utf8 => Arc::new(string_iter.collect::()), + StringArrayType::LargeUtf8 => Arc::new(string_iter.collect::()), + }; + + vec![ColumnarValue::Array(string_array)] +} + 
+#[expect(clippy::too_many_arguments)] fn run_with_string_type( group: &mut BenchmarkGroup<'_, M>, trim_func: &ScalarUDF, @@ -189,7 +226,7 @@ fn run_with_string_type( ); } -#[allow(clippy::too_many_arguments)] +#[expect(clippy::too_many_arguments)] fn run_trim_benchmark( c: &mut Criterion, group_name: &str, @@ -223,6 +260,60 @@ fn run_trim_benchmark( group.finish(); } +#[expect(clippy::too_many_arguments)] +fn run_space_trim_benchmark( + c: &mut Criterion, + group_name: &str, + trim_func: &ScalarUDF, + trim_type: TrimType, + string_types: &[StringArrayType], + size: usize, + pad_len: usize, + remaining_len: usize, +) { + let mut group = c.benchmark_group(group_name); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + + let total_len = match trim_type { + TrimType::Btrim => 2 * pad_len + remaining_len, + _ => pad_len + remaining_len, + }; + + for string_type in string_types { + let args = + create_space_trim_args(size, pad_len, remaining_len, *string_type, trim_type); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let config_options = Arc::new(ConfigOptions::default()); + + group.bench_function( + format!( + "{trim_type} {string_type} [size={size}, len={total_len}, pad={pad_len}]", + ), + |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(trim_func.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); + } + + group.finish(); +} + fn criterion_benchmark(c: &mut Criterion) { let ltrim = string::ltrim(); let rtrim = string::rtrim(); @@ -297,6 +388,45 @@ fn criterion_benchmark(c: &mut Criterion) { &trimmed, remaining_len, ); + + // Scenario 4: Trim spaces, short strings (len <= 12) + // pad_len=4, remaining_len=8 + 
run_space_trim_benchmark( + c, + "trim spaces, short strings (len <= 12)", + trim_func, + *trim_type, + &string_types, + size, + 4, + 8, + ); + + // Scenario 5: Trim spaces, long strings (len > 12) + // pad_len=4, remaining_len=60 + run_space_trim_benchmark( + c, + "trim spaces, long strings", + trim_func, + *trim_type, + &string_types, + size, + 4, + 60, + ); + + // Scenario 6: Trim spaces, long strings, heavy padding + // pad_len=56, remaining_len=8 + run_space_trim_benchmark( + c, + "trim spaces, heavy padding", + trim_func, + *trim_type, + &string_types, + size, + 56, + 8, + ); } } } diff --git a/datafusion/functions/benches/trunc.rs b/datafusion/functions/benches/trunc.rs index d0a6e2be75e0..ffbedcb142c7 100644 --- a/datafusion/functions/benches/trunc.rs +++ b/datafusion/functions/benches/trunc.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::{ datatypes::{Field, Float32Type, Float64Type}, util::bench_util::create_primitive_array, @@ -32,12 +30,13 @@ use std::sync::Arc; fn criterion_benchmark(c: &mut Criterion) { let trunc = trunc(); + let config_options = Arc::new(ConfigOptions::default()); + for size in [1024, 4096, 8192] { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let f32_args = vec![ColumnarValue::Array(f32_array)]; let arg_fields = vec![Field::new("a", DataType::Float32, false).into()]; let return_field = Field::new("f", DataType::Float32, true).into(); - let config_options = Arc::new(ConfigOptions::default()); c.bench_function(&format!("trunc f32 array: {size}"), |b| { b.iter(|| { @@ -74,6 +73,51 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); } + + // Scalar benchmarks - to measure optimized performance + let scalar_f64_args = vec![ColumnarValue::Scalar( + datafusion_common::ScalarValue::Float64(Some(std::f64::consts::PI)), + )]; + let scalar_arg_fields = vec![Field::new("a", DataType::Float64, false).into()]; + let scalar_return_field = 
Field::new("f", DataType::Float64, false).into(); + + c.bench_function("trunc f64 scalar", |b| { + b.iter(|| { + black_box( + trunc + .invoke_with_args(ScalarFunctionArgs { + args: scalar_f64_args.clone(), + arg_fields: scalar_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&scalar_return_field), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + let scalar_f32_args = vec![ColumnarValue::Scalar( + datafusion_common::ScalarValue::Float32(Some(std::f32::consts::PI)), + )]; + let scalar_f32_arg_fields = vec![Field::new("a", DataType::Float32, false).into()]; + let scalar_f32_return_field = Field::new("f", DataType::Float32, false).into(); + + c.bench_function("trunc f32 scalar", |b| { + b.iter(|| { + black_box( + trunc + .invoke_with_args(ScalarFunctionArgs { + args: scalar_f32_args.clone(), + arg_fields: scalar_f32_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&scalar_f32_return_field), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/functions/benches/upper.rs b/datafusion/functions/benches/upper.rs index 51ce1da0fa1f..3f6fa36b18c1 100644 --- a/datafusion/functions/benches/upper.rs +++ b/datafusion/functions/benches/upper.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::create_string_array_with_len; use criterion::{Criterion, criterion_group, criterion_main}; diff --git a/datafusion/functions/benches/uuid.rs b/datafusion/functions/benches/uuid.rs index df9b2bed4be2..629fb950dd9f 100644 --- a/datafusion/functions/benches/uuid.rs +++ b/datafusion/functions/benches/uuid.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-extern crate criterion; - use arrow::datatypes::{DataType, Field}; use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_common::config::ConfigOptions; diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index 04189c0c6f36..7c24450adf18 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -26,7 +26,7 @@ use datafusion_common::{ use datafusion_common::{exec_datafusion_err, utils::take_function_args}; use std::any::Any; -use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; use datafusion_expr::{ Coercion, ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, @@ -160,7 +160,7 @@ impl ScalarUDFImpl for ArrowCastFunc { fn simplify( &self, mut args: Vec, - info: &dyn SimplifyInfo, + info: &SimplifyContext, ) -> Result { // convert this into a real cast let target_type = data_type_from_args(&args)?; diff --git a/datafusion/functions/src/core/coalesce.rs b/datafusion/functions/src/core/coalesce.rs index 1404f6857097..359a6f6c9c84 100644 --- a/datafusion/functions/src/core/coalesce.rs +++ b/datafusion/functions/src/core/coalesce.rs @@ -19,7 +19,7 @@ use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::{Result, exec_err, internal_err, plan_err}; use datafusion_expr::binary::try_type_union_resolution; use datafusion_expr::conditional_expressions::CaseBuilder; -use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; use datafusion_expr::{ ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, }; @@ -97,7 +97,7 @@ impl ScalarUDFImpl for CoalesceFunc { fn simplify( &self, args: Vec, - _info: &dyn SimplifyInfo, + _info: &SimplifyContext, ) -> Result { if args.is_empty() { 
return plan_err!("coalesce must have at least one argument"); diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index 3e961e4da4e7..d57ba46fb56a 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -33,8 +33,8 @@ use datafusion_common::{ use datafusion_expr::expr::ScalarFunction; use datafusion_expr::simplify::ExprSimplifyResult; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, - ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, Expr, ExpressionPlacement, ReturnFieldArgs, + ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; @@ -421,7 +421,7 @@ impl ScalarUDFImpl for GetFieldFunc { fn simplify( &self, args: Vec, - _info: &dyn datafusion_expr::simplify::SimplifyInfo, + _info: &datafusion_expr::simplify::SimplifyContext, ) -> Result { // Need at least 2 args (base + field) if args.len() < 2 { @@ -499,6 +499,32 @@ impl ScalarUDFImpl for GetFieldFunc { fn documentation(&self) -> Option<&Documentation> { self.doc() } + + fn placement(&self, args: &[ExpressionPlacement]) -> ExpressionPlacement { + // get_field can be pushed to leaves if: + // 1. The base (first arg) is a column or already placeable at leaves + // 2. 
All field keys (remaining args) are literals + if args.is_empty() { + return ExpressionPlacement::KeepInPlace; + } + + let base_placement = args[0]; + let base_is_pushable = matches!( + base_placement, + ExpressionPlacement::Column | ExpressionPlacement::MoveTowardsLeafNodes + ); + + let all_keys_are_literals = args + .iter() + .skip(1) + .all(|p| *p == ExpressionPlacement::Literal); + + if base_is_pushable && all_keys_are_literals { + ExpressionPlacement::MoveTowardsLeafNodes + } else { + ExpressionPlacement::KeepInPlace + } + } } #[cfg(test)] @@ -542,4 +568,92 @@ mod tests { Ok(()) } + + #[test] + fn test_placement_literal_key() { + let func = GetFieldFunc::new(); + + // get_field(col, 'literal') -> leaf-pushable (static field access) + let args = vec![ExpressionPlacement::Column, ExpressionPlacement::Literal]; + assert_eq!( + func.placement(&args), + ExpressionPlacement::MoveTowardsLeafNodes + ); + + // get_field(col, 'a', 'b') -> leaf-pushable (nested static field access) + let args = vec![ + ExpressionPlacement::Column, + ExpressionPlacement::Literal, + ExpressionPlacement::Literal, + ]; + assert_eq!( + func.placement(&args), + ExpressionPlacement::MoveTowardsLeafNodes + ); + + // get_field(get_field(col, 'a'), 'b') represented as MoveTowardsLeafNodes for base + let args = vec![ + ExpressionPlacement::MoveTowardsLeafNodes, + ExpressionPlacement::Literal, + ]; + assert_eq!( + func.placement(&args), + ExpressionPlacement::MoveTowardsLeafNodes + ); + } + + #[test] + fn test_placement_column_key() { + let func = GetFieldFunc::new(); + + // get_field(col, other_col) -> NOT leaf-pushable (dynamic per-row lookup) + let args = vec![ExpressionPlacement::Column, ExpressionPlacement::Column]; + assert_eq!(func.placement(&args), ExpressionPlacement::KeepInPlace); + + // get_field(col, 'a', other_col) -> NOT leaf-pushable (dynamic nested lookup) + let args = vec![ + ExpressionPlacement::Column, + ExpressionPlacement::Literal, + ExpressionPlacement::Column, + ]; + 
assert_eq!(func.placement(&args), ExpressionPlacement::KeepInPlace); + } + + #[test] + fn test_placement_root() { + let func = GetFieldFunc::new(); + + // get_field(root_expr, 'literal') -> NOT leaf-pushable + let args = vec![ + ExpressionPlacement::KeepInPlace, + ExpressionPlacement::Literal, + ]; + assert_eq!(func.placement(&args), ExpressionPlacement::KeepInPlace); + + // get_field(col, root_expr) -> NOT leaf-pushable + let args = vec![ + ExpressionPlacement::Column, + ExpressionPlacement::KeepInPlace, + ]; + assert_eq!(func.placement(&args), ExpressionPlacement::KeepInPlace); + } + + #[test] + fn test_placement_edge_cases() { + let func = GetFieldFunc::new(); + + // Empty args -> NOT leaf-pushable + assert_eq!(func.placement(&[]), ExpressionPlacement::KeepInPlace); + + // Just base, no key -> MoveTowardsLeafNodes (not a valid call but should handle gracefully) + let args = vec![ExpressionPlacement::Column]; + assert_eq!( + func.placement(&args), + ExpressionPlacement::MoveTowardsLeafNodes + ); + + // Literal base with literal key -> NOT leaf-pushable (would be constant-folded) + let args = vec![ExpressionPlacement::Literal, ExpressionPlacement::Literal]; + assert_eq!(func.placement(&args), ExpressionPlacement::KeepInPlace); + } } diff --git a/datafusion/functions/src/core/nvl.rs b/datafusion/functions/src/core/nvl.rs index 0b9968a88fc9..0b4966d4fbdc 100644 --- a/datafusion/functions/src/core/nvl.rs +++ b/datafusion/functions/src/core/nvl.rs @@ -18,7 +18,7 @@ use crate::core::coalesce::CoalesceFunc; use arrow::datatypes::{DataType, FieldRef}; use datafusion_common::Result; -use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; use datafusion_expr::{ ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, @@ -124,7 +124,7 @@ impl ScalarUDFImpl for NVLFunc { fn simplify( &self, args: Vec, - info: &dyn SimplifyInfo, + info: 
&SimplifyContext, ) -> Result { self.coalesce.simplify(args, info) } diff --git a/datafusion/functions/src/core/nvl2.rs b/datafusion/functions/src/core/nvl2.rs index eda59fe07f57..0b092c44d502 100644 --- a/datafusion/functions/src/core/nvl2.rs +++ b/datafusion/functions/src/core/nvl2.rs @@ -21,7 +21,7 @@ use datafusion_expr::{ ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, conditional_expressions::CaseBuilder, - simplify::{ExprSimplifyResult, SimplifyInfo}, + simplify::{ExprSimplifyResult, SimplifyContext}, type_coercion::binary::comparison_coercion, }; use datafusion_macros::user_doc; @@ -108,7 +108,7 @@ impl ScalarUDFImpl for NVL2Func { fn simplify( &self, args: Vec, - _info: &dyn SimplifyInfo, + _info: &SimplifyContext, ) -> Result { let [test, if_non_null, if_null] = take_function_args(self.name(), args)?; diff --git a/datafusion/functions/src/core/union_extract.rs b/datafusion/functions/src/core/union_extract.rs index 56d4f23cc4e2..8d915fb2e2c0 100644 --- a/datafusion/functions/src/core/union_extract.rs +++ b/datafusion/functions/src/core/union_extract.rs @@ -189,13 +189,14 @@ mod tests { fn test_scalar_value() -> Result<()> { let fun = UnionExtractFun::new(); - let fields = UnionFields::new( + let fields = UnionFields::try_new( vec![1, 3], vec![ Field::new("str", DataType::Utf8, false), Field::new("int", DataType::Int32, false), ], - ); + ) + .unwrap(); let args = vec![ ColumnarValue::Scalar(ScalarValue::Union( diff --git a/datafusion/functions/src/core/union_tag.rs b/datafusion/functions/src/core/union_tag.rs index 809679dea646..fac5c82691ad 100644 --- a/datafusion/functions/src/core/union_tag.rs +++ b/datafusion/functions/src/core/union_tag.rs @@ -143,7 +143,7 @@ impl ScalarUDFImpl for UnionTagFunc { args.return_field.data_type(), )?)), }, - v => exec_err!("union_tag only support unions, got {:?}", v.data_type()), + v => exec_err!("union_tag only support unions, got {}", v.data_type()), } } 
diff --git a/datafusion/functions/src/crypto/basic.rs b/datafusion/functions/src/crypto/basic.rs index bda16684c8b6..abb86b8246fc 100644 --- a/datafusion/functions/src/crypto/basic.rs +++ b/datafusion/functions/src/crypto/basic.rs @@ -17,19 +17,13 @@ //! "crypto" DataFusion functions -use arrow::array::{ - Array, ArrayRef, AsArray, BinaryArray, BinaryArrayType, StringViewArray, -}; +use arrow::array::{Array, ArrayRef, AsArray, BinaryArray, BinaryArrayType}; use arrow::datatypes::DataType; use blake2::{Blake2b512, Blake2s256, Digest}; use blake3::Hasher as Blake3; -use datafusion_common::cast::as_binary_array; use arrow::compute::StringArrayType; -use datafusion_common::{ - DataFusionError, Result, ScalarValue, exec_err, internal_err, plan_err, - utils::take_function_args, -}; +use datafusion_common::{DataFusionError, Result, ScalarValue, exec_err, plan_err}; use datafusion_expr::ColumnarValue; use md5::Md5; use sha2::{Sha224, Sha256, Sha384, Sha512}; @@ -37,53 +31,8 @@ use std::fmt; use std::str::FromStr; use std::sync::Arc; -macro_rules! 
define_digest_function { - ($NAME: ident, $METHOD: ident, $DOC: expr) => { - #[doc = $DOC] - pub fn $NAME(args: &[ColumnarValue]) -> Result { - let [data] = take_function_args(&DigestAlgorithm::$METHOD.to_string(), args)?; - digest_process(data, DigestAlgorithm::$METHOD) - } - }; -} -define_digest_function!( - sha224, - Sha224, - "computes sha224 hash digest of the given input" -); -define_digest_function!( - sha256, - Sha256, - "computes sha256 hash digest of the given input" -); -define_digest_function!( - sha384, - Sha384, - "computes sha384 hash digest of the given input" -); -define_digest_function!( - sha512, - Sha512, - "computes sha512 hash digest of the given input" -); -define_digest_function!( - blake2b, - Blake2b, - "computes blake2b hash digest of the given input" -); -define_digest_function!( - blake2s, - Blake2s, - "computes blake2s hash digest of the given input" -); -define_digest_function!( - blake3, - Blake3, - "computes blake3 hash digest of the given input" -); - #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] -pub enum DigestAlgorithm { +pub(crate) enum DigestAlgorithm { Md5, Sha224, Sha256, @@ -135,44 +84,6 @@ impl fmt::Display for DigestAlgorithm { } } -/// computes md5 hash digest of the given input -pub fn md5(args: &[ColumnarValue]) -> Result { - let [data] = take_function_args("md5", args)?; - let value = digest_process(data, DigestAlgorithm::Md5)?; - - // md5 requires special handling because of its unique utf8view return type - Ok(match value { - ColumnarValue::Array(array) => { - let binary_array = as_binary_array(&array)?; - let string_array: StringViewArray = binary_array - .iter() - .map(|opt| opt.map(hex_encode::<_>)) - .collect(); - ColumnarValue::Array(Arc::new(string_array)) - } - ColumnarValue::Scalar(ScalarValue::Binary(opt)) => { - ColumnarValue::Scalar(ScalarValue::Utf8View(opt.map(hex_encode::<_>))) - } - _ => return internal_err!("Impossibly got invalid results from digest"), - }) -} - -/// Hex encoding lookup table 
for fast byte-to-hex conversion -const HEX_CHARS_LOWER: &[u8; 16] = b"0123456789abcdef"; - -/// Fast hex encoding using a lookup table instead of format strings. -/// This is significantly faster than using `write!("{:02x}")` for each byte. -#[inline] -fn hex_encode>(data: T) -> String { - let bytes = data.as_ref(); - let mut s = String::with_capacity(bytes.len() * 2); - for &b in bytes { - s.push(HEX_CHARS_LOWER[(b >> 4) as usize] as char); - s.push(HEX_CHARS_LOWER[(b & 0x0f) as usize] as char); - } - s -} - macro_rules! digest_to_array { ($METHOD:ident, $INPUT:expr) => {{ let binary_array: BinaryArray = $INPUT @@ -269,7 +180,7 @@ impl DigestAlgorithm { } } -pub fn digest_process( +pub(crate) fn digest_process( value: &ColumnarValue, digest_algorithm: DigestAlgorithm, ) -> Result { diff --git a/datafusion/functions/src/crypto/md5.rs b/datafusion/functions/src/crypto/md5.rs index 728e0d4a3309..355e3e287ad2 100644 --- a/datafusion/functions/src/crypto/md5.rs +++ b/datafusion/functions/src/crypto/md5.rs @@ -15,11 +15,13 @@ // specific language governing permissions and limitations // under the License. 
-use crate::crypto::basic::md5; -use arrow::datatypes::DataType; +use arrow::{array::StringViewArray, datatypes::DataType}; use datafusion_common::{ - Result, + Result, ScalarValue, + cast::as_binary_array, + internal_err, types::{logical_binary, logical_string}, + utils::take_function_args, }; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, @@ -27,7 +29,9 @@ use datafusion_expr::{ }; use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; use datafusion_macros::user_doc; -use std::any::Any; +use std::{any::Any, sync::Arc}; + +use crate::crypto::basic::{DigestAlgorithm, digest_process}; #[user_doc( doc_section(label = "Hashing Functions"), @@ -97,3 +101,38 @@ impl ScalarUDFImpl for Md5Func { self.doc() } } + +/// Hex encoding lookup table for fast byte-to-hex conversion +const HEX_CHARS_LOWER: &[u8; 16] = b"0123456789abcdef"; + +/// Fast hex encoding using a lookup table instead of format strings. +/// This is significantly faster than using `write!("{:02x}")` for each byte. 
+#[inline] +fn hex_encode(data: impl AsRef<[u8]>) -> String { + let bytes = data.as_ref(); + let mut s = String::with_capacity(bytes.len() * 2); + for &b in bytes { + s.push(HEX_CHARS_LOWER[(b >> 4) as usize] as char); + s.push(HEX_CHARS_LOWER[(b & 0x0f) as usize] as char); + } + s +} + +fn md5(args: &[ColumnarValue]) -> Result { + let [data] = take_function_args("md5", args)?; + let value = digest_process(data, DigestAlgorithm::Md5)?; + + // md5 requires special handling because of its unique utf8view return type + Ok(match value { + ColumnarValue::Array(array) => { + let binary_array = as_binary_array(&array)?; + let string_array: StringViewArray = + binary_array.iter().map(|opt| opt.map(hex_encode)).collect(); + ColumnarValue::Array(Arc::new(string_array)) + } + ColumnarValue::Scalar(ScalarValue::Binary(opt)) => { + ColumnarValue::Scalar(ScalarValue::Utf8View(opt.map(hex_encode))) + } + _ => return internal_err!("Impossibly got invalid results from digest"), + }) +} diff --git a/datafusion/functions/src/datetime/current_date.rs b/datafusion/functions/src/datetime/current_date.rs index 7edc1a58d9cb..3e3877272097 100644 --- a/datafusion/functions/src/datetime/current_date.rs +++ b/datafusion/functions/src/datetime/current_date.rs @@ -23,7 +23,7 @@ use arrow::datatypes::DataType::Date32; use chrono::{Datelike, NaiveDate, TimeZone}; use datafusion_common::{Result, ScalarValue, internal_err}; -use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; use datafusion_expr::{ ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, }; @@ -99,23 +99,20 @@ impl ScalarUDFImpl for CurrentDateFunc { fn simplify( &self, - _args: Vec, - info: &dyn SimplifyInfo, + args: Vec, + info: &SimplifyContext, ) -> Result { - let now_ts = info.execution_props().query_execution_start_time; + let Some(now_ts) = info.query_execution_start_time() else { + return 
Ok(ExprSimplifyResult::Original(args)); + }; // Get timezone from config and convert to local time let days = info - .execution_props() .config_options() - .and_then(|config| { - config - .execution - .time_zone - .as_ref() - .map(|tz| tz.parse::().ok()) - }) - .flatten() + .execution + .time_zone + .as_ref() + .and_then(|tz| tz.parse::().ok()) .map_or_else( || datetime_to_days(&now_ts), |tz| { diff --git a/datafusion/functions/src/datetime/current_time.rs b/datafusion/functions/src/datetime/current_time.rs index 2c9bcdfe49db..855c0c13dc6b 100644 --- a/datafusion/functions/src/datetime/current_time.rs +++ b/datafusion/functions/src/datetime/current_time.rs @@ -22,7 +22,7 @@ use arrow::datatypes::TimeUnit::Nanosecond; use chrono::TimeZone; use chrono::Timelike; use datafusion_common::{Result, ScalarValue, internal_err}; -use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; use datafusion_expr::{ ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, }; @@ -95,23 +95,20 @@ impl ScalarUDFImpl for CurrentTimeFunc { fn simplify( &self, - _args: Vec, - info: &dyn SimplifyInfo, + args: Vec, + info: &SimplifyContext, ) -> Result { - let now_ts = info.execution_props().query_execution_start_time; + let Some(now_ts) = info.query_execution_start_time() else { + return Ok(ExprSimplifyResult::Original(args)); + }; // Try to get timezone from config and convert to local time let nano = info - .execution_props() .config_options() - .and_then(|config| { - config - .execution - .time_zone - .as_ref() - .map(|tz| tz.parse::().ok()) - }) - .flatten() + .execution + .time_zone + .as_ref() + .and_then(|tz| tz.parse::().ok()) .map_or_else( || datetime_to_time_nanos(&now_ts), |tz| { @@ -143,46 +140,24 @@ fn datetime_to_time_nanos(dt: &chrono::DateTime) -> Option Result { - Ok(false) - } - - fn nullable(&self, _expr: &Expr) -> Result { - Ok(true) - } - - fn execution_props(&self) -> 
&ExecutionProps { - &self.execution_props - } - - fn get_data_type(&self, _expr: &Expr) -> Result { - Ok(Time64(Nanosecond)) - } - } - - fn set_session_timezone_env(tz: &str, start_time: DateTime) -> MockSimplifyInfo { - let mut config = datafusion_common::config::ConfigOptions::default(); + fn set_session_timezone_env(tz: &str, start_time: DateTime) -> SimplifyContext { + let mut config = ConfigOptions::default(); config.execution.time_zone = if tz.is_empty() { None } else { Some(tz.to_string()) }; - let mut execution_props = - ExecutionProps::new().with_query_execution_start_time(start_time); - execution_props.config_options = Some(Arc::new(config)); - MockSimplifyInfo { execution_props } + let schema = Arc::new(DFSchema::empty()); + SimplifyContext::default() + .with_schema(schema) + .with_config_options(Arc::new(config)) + .with_query_execution_start_time(Some(start_time)) } #[test] diff --git a/datafusion/functions/src/datetime/date_bin.rs b/datafusion/functions/src/datetime/date_bin.rs index 6c67fbad34a1..c0984c1ea64e 100644 --- a/datafusion/functions/src/datetime/date_bin.rs +++ b/datafusion/functions/src/datetime/date_bin.rs @@ -295,7 +295,15 @@ impl ScalarUDFImpl for DateBinFunc { const NANOS_PER_MICRO: i64 = 1_000; const NANOS_PER_MILLI: i64 = 1_000_000; const NANOS_PER_SEC: i64 = NANOSECONDS; - +/// Function type for binning timestamps into intervals +/// +/// Arguments: +/// * `stride` - Interval width (nanoseconds for time-based, months for month-based) +/// * `source` - Timestamp to bin (nanoseconds since epoch) +/// * `origin` - Origin timestamp (nanoseconds since epoch) +/// +/// Returns: Binned timestamp in nanoseconds, or error if out of range +type BinFunction = fn(i64, i64, i64) -> Result; enum Interval { Nanoseconds(i64), Months(i64), @@ -310,7 +318,7 @@ impl Interval { /// `source` is the timestamp being binned /// /// `origin` is the time, in nanoseconds, where windows are measured from - fn bin_fn(&self) -> (i64, fn(i64, i64, i64) -> i64) { 
+ fn bin_fn(&self) -> (i64, BinFunction) { match self { Interval::Nanoseconds(nanos) => (*nanos, date_bin_nanos_interval), Interval::Months(months) => (*months, date_bin_months_interval), @@ -319,13 +327,13 @@ impl Interval { } // return time in nanoseconds that the source timestamp falls into based on the stride and origin -fn date_bin_nanos_interval(stride_nanos: i64, source: i64, origin: i64) -> i64 { +fn date_bin_nanos_interval(stride_nanos: i64, source: i64, origin: i64) -> Result { let time_diff = source - origin; // distance from origin to bin let time_delta = compute_distance(time_diff, stride_nanos); - origin + time_delta + Ok(origin + time_delta) } // distance from origin to bin @@ -341,10 +349,10 @@ fn compute_distance(time_diff: i64, stride: i64) -> i64 { } // return time in nanoseconds that the source timestamp falls into based on the stride and origin -fn date_bin_months_interval(stride_months: i64, source: i64, origin: i64) -> i64 { +fn date_bin_months_interval(stride_months: i64, source: i64, origin: i64) -> Result { // convert source and origin to DateTime - let source_date = to_utc_date_time(source); - let origin_date = to_utc_date_time(origin); + let source_date = to_utc_date_time(source)?; + let origin_date = to_utc_date_time(origin)?; // calculate the number of months between the source and origin let month_diff = (source_date.year() - origin_date.year()) * 12 @@ -355,9 +363,17 @@ fn date_bin_months_interval(stride_months: i64, source: i64, origin: i64) -> i64 let month_delta = compute_distance(month_diff as i64, stride_months); let mut bin_time = if month_delta < 0 { - origin_date - Months::new(month_delta.unsigned_abs() as u32) + match origin_date + .checked_sub_months(Months::new(month_delta.unsigned_abs() as u32)) + { + Some(dt) => dt, + None => return exec_err!("DATE_BIN month subtraction out of range"), + } } else { - origin_date + Months::new(month_delta as u32) + match origin_date.checked_add_months(Months::new(month_delta as u32)) { + 
Some(dt) => dt, + None => return exec_err!("DATE_BIN month addition out of range"), + } }; // If origin is not midnight of first date of the month, the bin_time may be larger than the source @@ -365,19 +381,32 @@ fn date_bin_months_interval(stride_months: i64, source: i64, origin: i64) -> i64 if bin_time > source_date { let month_delta = month_delta - stride_months; bin_time = if month_delta < 0 { - origin_date - Months::new(month_delta.unsigned_abs() as u32) + match origin_date + .checked_sub_months(Months::new(month_delta.unsigned_abs() as u32)) + { + Some(dt) => dt, + None => return exec_err!("DATE_BIN month subtraction out of range"), + } } else { - origin_date + Months::new(month_delta as u32) + match origin_date.checked_add_months(Months::new(month_delta as u32)) { + Some(dt) => dt, + None => return exec_err!("DATE_BIN month addition out of range"), + } }; } - - bin_time.timestamp_nanos_opt().unwrap() + match bin_time.timestamp_nanos_opt() { + Some(nanos) => Ok(nanos), + None => exec_err!("DATE_BIN result timestamp out of range"), + } } -fn to_utc_date_time(nanos: i64) -> DateTime { +fn to_utc_date_time(nanos: i64) -> Result> { let secs = nanos / NANOS_PER_SEC; let nsec = (nanos % NANOS_PER_SEC) as u32; - DateTime::from_timestamp(secs, nsec).unwrap() + match DateTime::from_timestamp(secs, nsec) { + Some(dt) => Ok(dt), + None => exec_err!("Invalid timestamp value"), + } } // Supported intervals: @@ -392,6 +421,12 @@ fn date_bin_impl( origin: &ColumnarValue, ) -> Result { let stride = match stride { + ColumnarValue::Scalar(s) if s.is_null() => { + // NULL stride -> NULL result (standard SQL NULL propagation) + return Ok(ColumnarValue::Scalar(ScalarValue::try_from( + array.data_type(), + )?)); + } ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(v))) => { let (days, ms) = IntervalDayTimeType::to_parts(*v); let nanos = (TimeDelta::try_days(days as i64).unwrap() @@ -546,15 +581,18 @@ fn date_bin_impl( fn stride_map_fn( origin: i64, stride: i64, - stride_fn: 
fn(i64, i64, i64) -> i64, - ) -> impl Fn(i64) -> i64 { + stride_fn: BinFunction, + ) -> impl Fn(i64) -> Result { let scale = match T::UNIT { Nanosecond => 1, Microsecond => NANOS_PER_MICRO, Millisecond => NANOS_PER_MILLI, Second => NANOSECONDS, }; - move |x: i64| stride_fn(stride, x * scale, origin) / scale + move |x: i64| match stride_fn(stride, x * scale, origin) { + Ok(result) => Ok(result / scale), + Err(e) => Err(e), + } } Ok(match array { @@ -562,7 +600,7 @@ fn date_bin_impl( let apply_stride_fn = stride_map_fn::(origin, stride, stride_fn); ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( - v.map(apply_stride_fn), + v.and_then(|val| apply_stride_fn(val).ok()), tz_opt.clone(), )) } @@ -570,7 +608,7 @@ fn date_bin_impl( let apply_stride_fn = stride_map_fn::(origin, stride, stride_fn); ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond( - v.map(apply_stride_fn), + v.and_then(|val| apply_stride_fn(val).ok()), tz_opt.clone(), )) } @@ -578,7 +616,7 @@ fn date_bin_impl( let apply_stride_fn = stride_map_fn::(origin, stride, stride_fn); ColumnarValue::Scalar(ScalarValue::TimestampMillisecond( - v.map(apply_stride_fn), + v.and_then(|val| apply_stride_fn(val).ok()), tz_opt.clone(), )) } @@ -586,7 +624,7 @@ fn date_bin_impl( let apply_stride_fn = stride_map_fn::(origin, stride, stride_fn); ColumnarValue::Scalar(ScalarValue::TimestampSecond( - v.map(apply_stride_fn), + v.and_then(|val| apply_stride_fn(val).ok()), tz_opt.clone(), )) } @@ -594,50 +632,61 @@ fn date_bin_impl( if !is_time { return exec_err!("DATE_BIN with Time32 source requires Time32 origin"); } - let apply_stride_fn = move |x: i32| { - let binned_nanos = stride_fn(stride, x as i64 * NANOS_PER_MILLI, origin); - let nanos = binned_nanos % (NANOSECONDS_IN_DAY); - (nanos / NANOS_PER_MILLI) as i32 - }; - ColumnarValue::Scalar(ScalarValue::Time32Millisecond(v.map(apply_stride_fn))) + let result = v.and_then(|x| { + match stride_fn(stride, x as i64 * NANOS_PER_MILLI, origin) { + Ok(binned_nanos) => { + 
let nanos = binned_nanos % (NANOSECONDS_IN_DAY); + Some((nanos / NANOS_PER_MILLI) as i32) + } + Err(_) => None, + } + }); + ColumnarValue::Scalar(ScalarValue::Time32Millisecond(result)) } ColumnarValue::Scalar(ScalarValue::Time32Second(v)) => { if !is_time { return exec_err!("DATE_BIN with Time32 source requires Time32 origin"); } - let apply_stride_fn = move |x: i32| { - let binned_nanos = stride_fn(stride, x as i64 * NANOS_PER_SEC, origin); - let nanos = binned_nanos % (NANOSECONDS_IN_DAY); - (nanos / NANOS_PER_SEC) as i32 - }; - ColumnarValue::Scalar(ScalarValue::Time32Second(v.map(apply_stride_fn))) + let result = v.and_then(|x| { + match stride_fn(stride, x as i64 * NANOS_PER_SEC, origin) { + Ok(binned_nanos) => { + let nanos = binned_nanos % (NANOSECONDS_IN_DAY); + Some((nanos / NANOS_PER_SEC) as i32) + } + Err(_) => None, + } + }); + ColumnarValue::Scalar(ScalarValue::Time32Second(result)) } ColumnarValue::Scalar(ScalarValue::Time64Nanosecond(v)) => { if !is_time { return exec_err!("DATE_BIN with Time64 source requires Time64 origin"); } - let apply_stride_fn = move |x: i64| { - let binned_nanos = stride_fn(stride, x, origin); - binned_nanos % (NANOSECONDS_IN_DAY) - }; - ColumnarValue::Scalar(ScalarValue::Time64Nanosecond(v.map(apply_stride_fn))) + let result = v.and_then(|x| match stride_fn(stride, x, origin) { + Ok(binned_nanos) => Some(binned_nanos % (NANOSECONDS_IN_DAY)), + Err(_) => None, + }); + ColumnarValue::Scalar(ScalarValue::Time64Nanosecond(result)) } ColumnarValue::Scalar(ScalarValue::Time64Microsecond(v)) => { if !is_time { return exec_err!("DATE_BIN with Time64 source requires Time64 origin"); } - let apply_stride_fn = move |x: i64| { - let binned_nanos = stride_fn(stride, x * NANOS_PER_MICRO, origin); - let nanos = binned_nanos % (NANOSECONDS_IN_DAY); - nanos / NANOS_PER_MICRO - }; - ColumnarValue::Scalar(ScalarValue::Time64Microsecond(v.map(apply_stride_fn))) + let result = + v.and_then(|x| match stride_fn(stride, x * NANOS_PER_MICRO, 
origin) { + Ok(binned_nanos) => { + let nanos = binned_nanos % (NANOSECONDS_IN_DAY); + Some(nanos / NANOS_PER_MICRO) + } + Err(_) => None, + }); + ColumnarValue::Scalar(ScalarValue::Time64Microsecond(result)) } ColumnarValue::Array(array) => { fn transform_array_with_stride( origin: i64, stride: i64, - stride_fn: fn(i64, i64, i64) -> i64, + stride_fn: BinFunction, array: &ArrayRef, tz_opt: &Option>, ) -> Result @@ -645,11 +694,22 @@ fn date_bin_impl( T: ArrowTimestampType, { let array = as_primitive_array::(array)?; - let apply_stride_fn = stride_map_fn::(origin, stride, stride_fn); - let array: PrimitiveArray = array - .unary(apply_stride_fn) - .with_timezone_opt(tz_opt.clone()); - + let scale = match T::UNIT { + Nanosecond => 1, + Microsecond => NANOS_PER_MICRO, + Millisecond => NANOS_PER_MILLI, + Second => NANOSECONDS, + }; + + let result: PrimitiveArray = array.try_unary(|val| { + stride_fn(stride, val * scale, origin) + .map(|binned| binned / scale) + .map_err(|e| { + arrow::error::ArrowError::ComputeError(e.to_string()) + }) + })?; + + let array = result.with_timezone_opt(tz_opt.clone()); Ok(ColumnarValue::Array(Arc::new(array))) } @@ -681,15 +741,18 @@ fn date_bin_impl( ); } let array = array.as_primitive::(); - let apply_stride_fn = move |x: i32| { - let binned_nanos = - stride_fn(stride, x as i64 * NANOS_PER_MILLI, origin); - let nanos = binned_nanos % (NANOSECONDS_IN_DAY); - (nanos / NANOS_PER_MILLI) as i32 - }; - let array: PrimitiveArray = - array.unary(apply_stride_fn); - ColumnarValue::Array(Arc::new(array)) + let result: PrimitiveArray = + array.try_unary(|x| { + stride_fn(stride, x as i64 * NANOS_PER_MILLI, origin) + .map(|binned_nanos| { + let nanos = binned_nanos % (NANOSECONDS_IN_DAY); + (nanos / NANOS_PER_MILLI) as i32 + }) + .map_err(|e| { + arrow::error::ArrowError::ComputeError(e.to_string()) + }) + })?; + ColumnarValue::Array(Arc::new(result)) } Time32(Second) => { if !is_time { @@ -698,15 +761,18 @@ fn date_bin_impl( ); } let array = 
array.as_primitive::(); - let apply_stride_fn = move |x: i32| { - let binned_nanos = - stride_fn(stride, x as i64 * NANOS_PER_SEC, origin); - let nanos = binned_nanos % (NANOSECONDS_IN_DAY); - (nanos / NANOS_PER_SEC) as i32 - }; - let array: PrimitiveArray = - array.unary(apply_stride_fn); - ColumnarValue::Array(Arc::new(array)) + let result: PrimitiveArray = + array.try_unary(|x| { + stride_fn(stride, x as i64 * NANOS_PER_SEC, origin) + .map(|binned_nanos| { + let nanos = binned_nanos % (NANOSECONDS_IN_DAY); + (nanos / NANOS_PER_SEC) as i32 + }) + .map_err(|e| { + arrow::error::ArrowError::ComputeError(e.to_string()) + }) + })?; + ColumnarValue::Array(Arc::new(result)) } Time64(Microsecond) => { if !is_time { @@ -715,14 +781,18 @@ fn date_bin_impl( ); } let array = array.as_primitive::(); - let apply_stride_fn = move |x: i64| { - let binned_nanos = stride_fn(stride, x * NANOS_PER_MICRO, origin); - let nanos = binned_nanos % (NANOSECONDS_IN_DAY); - nanos / NANOS_PER_MICRO - }; - let array: PrimitiveArray = - array.unary(apply_stride_fn); - ColumnarValue::Array(Arc::new(array)) + let result: PrimitiveArray = + array.try_unary(|x| { + stride_fn(stride, x * NANOS_PER_MICRO, origin) + .map(|binned_nanos| { + let nanos = binned_nanos % (NANOSECONDS_IN_DAY); + nanos / NANOS_PER_MICRO + }) + .map_err(|e| { + arrow::error::ArrowError::ComputeError(e.to_string()) + }) + })?; + ColumnarValue::Array(Arc::new(result)) } Time64(Nanosecond) => { if !is_time { @@ -731,13 +801,15 @@ fn date_bin_impl( ); } let array = array.as_primitive::(); - let apply_stride_fn = move |x: i64| { - let binned_nanos = stride_fn(stride, x, origin); - binned_nanos % (NANOSECONDS_IN_DAY) - }; - let array: PrimitiveArray = - array.unary(apply_stride_fn); - ColumnarValue::Array(Arc::new(array)) + let result: PrimitiveArray = + array.try_unary(|x| { + stride_fn(stride, x, origin) + .map(|binned_nanos| binned_nanos % (NANOSECONDS_IN_DAY)) + .map_err(|e| { + 
arrow::error::ArrowError::ComputeError(e.to_string()) + }) + })?; + ColumnarValue::Array(Arc::new(result)) } _ => { return exec_err!( @@ -1193,7 +1265,7 @@ mod tests { let origin1 = string_to_timestamp_nanos(origin).unwrap(); let expected1 = string_to_timestamp_nanos(expected).unwrap(); - let result = date_bin_nanos_interval(stride1, source1, origin1); + let result = date_bin_nanos_interval(stride1, source1, origin1).unwrap(); assert_eq!(result, expected1, "{source} = {expected}"); }) } @@ -1221,8 +1293,55 @@ mod tests { let source1 = string_to_timestamp_nanos(source).unwrap(); let expected1 = string_to_timestamp_nanos(expected).unwrap(); - let result = date_bin_nanos_interval(stride1, source1, 0); + let result = date_bin_nanos_interval(stride1, source1, 0).unwrap(); assert_eq!(result, expected1, "{source} = {expected}"); }) } + + #[test] + fn test_date_bin_out_of_range() { + let return_field = &Arc::new(Field::new( + "f", + DataType::Timestamp(TimeUnit::Millisecond, None), + true, + )); + let args = vec![ + ColumnarValue::Scalar(ScalarValue::new_interval_mdn(1637426858, 0, 0)), + ColumnarValue::Scalar(ScalarValue::TimestampMillisecond( + Some(1040292460), + None, + )), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + Some(string_to_timestamp_nanos("1984-01-07 00:00:00").unwrap()), + None, + )), + ]; + + let result = invoke_date_bin_with_args(args, 1, return_field); + assert!(result.is_ok()); + if let ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(val, _)) = + result.unwrap() + { + assert!(val.is_none(), "Expected None for out of range operation"); + } + let args = vec![ + ColumnarValue::Scalar(ScalarValue::new_interval_mdn(1637426858, 0, 0)), + ColumnarValue::Scalar(ScalarValue::TimestampMillisecond( + Some(-1040292460), + None, + )), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + Some(string_to_timestamp_nanos("1984-01-07 00:00:00").unwrap()), + None, + )), + ]; + + let result = invoke_date_bin_with_args(args, 1, return_field); + 
assert!(result.is_ok()); + if let ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(val, _)) = + result.unwrap() + { + assert!(val.is_none(), "Expected None for out of range operation"); + } + } } diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index 375200d07280..e3080c9d1a00 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -19,6 +19,7 @@ use std::any::Any; use std::str::FromStr; use std::sync::Arc; +use arrow::array::timezone::Tz; use arrow::array::{Array, ArrayRef, Float64Array, Int32Array}; use arrow::compute::kernels::cast_utils::IntervalUnit; use arrow::compute::{DatePart, binary, date_part}; @@ -26,13 +27,18 @@ use arrow::datatypes::DataType::{ Date32, Date64, Duration, Interval, Time32, Time64, Timestamp, }; use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; -use arrow::datatypes::{DataType, Field, FieldRef, TimeUnit}; +use arrow::datatypes::{ + DataType, Date32Type, Date64Type, Field, FieldRef, IntervalUnit as ArrowIntervalUnit, + TimeUnit, +}; +use chrono::{Datelike, NaiveDate, TimeZone, Utc}; use datafusion_common::types::{NativeType, logical_date}; use datafusion_common::{ Result, ScalarValue, cast::{ - as_date32_array, as_date64_array, as_int32_array, as_time32_millisecond_array, + as_date32_array, as_date64_array, as_int32_array, as_interval_dt_array, + as_interval_mdn_array, as_interval_ym_array, as_time32_millisecond_array, as_time32_second_array, as_time64_microsecond_array, as_time64_nanosecond_array, as_timestamp_microsecond_array, as_timestamp_millisecond_array, as_timestamp_nanosecond_array, as_timestamp_second_array, @@ -41,9 +47,11 @@ use datafusion_common::{ types::logical_string, utils::take_function_args, }; +use datafusion_expr::preimage::PreimageResult; +use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::{ - ColumnarValue, Documentation, ReturnFieldArgs, 
ScalarUDFImpl, Signature, - TypeSignature, Volatility, + ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarUDFImpl, Signature, + TypeSignature, Volatility, interval_arithmetic, }; use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; use datafusion_macros::user_doc; @@ -56,8 +64,9 @@ use datafusion_macros::user_doc; argument( name = "part", description = r#"Part of the date to return. The following date parts are supported: - + - year + - isoyear (ISO 8601 week-numbering year) - quarter (emits value in inclusive range [1, 4] based on which quartile of the year the date is in) - month - week (week of the year) @@ -70,7 +79,7 @@ use datafusion_macros::user_doc; - nanosecond - dow (day of the week where Sunday is 0) - doy (day of the year) - - epoch (seconds since Unix epoch) + - epoch (seconds since Unix epoch for timestamps/dates, total seconds for intervals) - isodow (day of the week where Monday is 0) "# ), @@ -148,6 +157,7 @@ impl ScalarUDFImpl for DatePartFunc { fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { let [field, _] = take_function_args(self.name(), args.scalar_arguments)?; + let nullable = args.arg_fields[1].is_nullable(); field .and_then(|sv| { @@ -156,9 +166,9 @@ impl ScalarUDFImpl for DatePartFunc { .filter(|s| !s.is_empty()) .map(|part| { if is_epoch(part) { - Field::new(self.name(), DataType::Float64, true) + Field::new(self.name(), DataType::Float64, nullable) } else { - Field::new(self.name(), DataType::Int32, true) + Field::new(self.name(), DataType::Int32, nullable) } }) }) @@ -215,6 +225,7 @@ impl ScalarUDFImpl for DatePartFunc { } else { // special cases that can be extracted (in postgres) but are not interval units match part_trim.to_lowercase().as_str() { + "isoyear" => date_part(array.as_ref(), DatePart::YearISO)?, "qtr" | "quarter" => date_part(array.as_ref(), DatePart::Quarter)?, "doy" => date_part(array.as_ref(), DatePart::DayOfYear)?, "dow" => date_part(array.as_ref(), 
DatePart::DayOfWeekSunday0)?, @@ -231,6 +242,71 @@ impl ScalarUDFImpl for DatePartFunc { }) } + // Only casting the year is supported since pruning other IntervalUnit is not possible + // date_part(col, YEAR) = 2024 => col >= '2024-01-01' and col < '2025-01-01' + // But for anything less than YEAR simplifying is not possible without specifying the bigger interval + // date_part(col, MONTH) = 1 => col = '2023-01-01' or col = '2024-01-01' or ... or col = '3000-01-01' + fn preimage( + &self, + args: &[Expr], + lit_expr: &Expr, + info: &SimplifyContext, + ) -> Result { + let [part, col_expr] = take_function_args(self.name(), args)?; + + // Get the interval unit from the part argument + let interval_unit = part + .as_literal() + .and_then(|sv| sv.try_as_str().flatten()) + .map(part_normalization) + .and_then(|s| IntervalUnit::from_str(s).ok()); + + // only support extracting year + match interval_unit { + Some(IntervalUnit::Year) => (), + _ => return Ok(PreimageResult::None), + } + + // Check if the argument is a literal (e.g. date_part(YEAR, col) = 2024) + let Some(argument_literal) = lit_expr.as_literal() else { + return Ok(PreimageResult::None); + }; + + // Extract i32 year from Scalar value + let year = match argument_literal { + ScalarValue::Int32(Some(y)) => *y, + _ => return Ok(PreimageResult::None), + }; + + // Can only extract year from Date32/64 and Timestamp column + let target_type = match info.get_data_type(col_expr)? 
{ + Date32 | Date64 | Timestamp(_, _) => &info.get_data_type(col_expr)?, + _ => return Ok(PreimageResult::None), + }; + + // Compute the Interval bounds + let Some(start_time) = NaiveDate::from_ymd_opt(year, 1, 1) else { + return Ok(PreimageResult::None); + }; + let Some(end_time) = start_time.with_year(year + 1) else { + return Ok(PreimageResult::None); + }; + + // Convert to ScalarValues + let (Some(lower), Some(upper)) = ( + date_to_scalar(start_time, target_type), + date_to_scalar(end_time, target_type), + ) else { + return Ok(PreimageResult::None); + }; + let interval = Box::new(interval_arithmetic::Interval::try_new(lower, upper)?); + + Ok(PreimageResult::Range { + expr: col_expr.clone(), + interval, + }) + } + fn aliases(&self) -> &[String] { &self.aliases } @@ -245,6 +321,52 @@ fn is_epoch(part: &str) -> bool { matches!(part.to_lowercase().as_str(), "epoch") } +fn date_to_scalar(date: NaiveDate, target_type: &DataType) -> Option { + Some(match target_type { + Date32 => ScalarValue::Date32(Some(Date32Type::from_naive_date(date))), + Date64 => ScalarValue::Date64(Some(Date64Type::from_naive_date(date))), + + Timestamp(unit, tz_opt) => { + let naive_midnight = date.and_hms_opt(0, 0, 0)?; + + let utc_dt = if let Some(tz_str) = tz_opt { + let tz: Tz = tz_str.parse().ok()?; + + let local = tz.from_local_datetime(&naive_midnight); + + let local_dt = match local { + chrono::offset::LocalResult::Single(dt) => dt, + chrono::offset::LocalResult::Ambiguous(dt1, _dt2) => dt1, + chrono::offset::LocalResult::None => local.earliest()?, + }; + + local_dt.with_timezone(&Utc) + } else { + Utc.from_utc_datetime(&naive_midnight) + }; + + match unit { + Second => { + ScalarValue::TimestampSecond(Some(utc_dt.timestamp()), tz_opt.clone()) + } + Millisecond => ScalarValue::TimestampMillisecond( + Some(utc_dt.timestamp_millis()), + tz_opt.clone(), + ), + Microsecond => ScalarValue::TimestampMicrosecond( + Some(utc_dt.timestamp_micros()), + tz_opt.clone(), + ), + Nanosecond => 
ScalarValue::TimestampNanosecond( + Some(utc_dt.timestamp_nanos_opt()?), + tz_opt.clone(), + ), + } + } + _ => return None, + }) +} + // Try to remove quote if exist, if the quote is invalid, return original string and let the downstream function handle the error fn part_normalization(part: &str) -> &str { part.strip_prefix(|c| c == '\'' || c == '\"') @@ -349,6 +471,11 @@ fn seconds(array: &dyn Array, unit: TimeUnit) -> Result { fn epoch(array: &dyn Array) -> Result { const SECONDS_IN_A_DAY: f64 = 86400_f64; + // Note: Month-to-second conversion uses 30 days as an approximation. + // This matches PostgreSQL's behavior for interval epoch extraction, + // but does not represent exact calendar months (which vary 28-31 days). + // See: https://doxygen.postgresql.org/datatype_2timestamp_8h.html + const DAYS_PER_MONTH: f64 = 30_f64; let f: Float64Array = match array.data_type() { Timestamp(Second, _) => as_timestamp_second_array(array)?.unary(|x| x as f64), @@ -373,7 +500,19 @@ fn epoch(array: &dyn Array) -> Result { Time64(Nanosecond) => { as_time64_nanosecond_array(array)?.unary(|x| x as f64 / 1_000_000_000_f64) } - Interval(_) | Duration(_) => return seconds(array, Second), + Interval(ArrowIntervalUnit::YearMonth) => as_interval_ym_array(array)? 
+ .unary(|x| x as f64 * DAYS_PER_MONTH * SECONDS_IN_A_DAY), + Interval(ArrowIntervalUnit::DayTime) => as_interval_dt_array(array)?.unary(|x| { + x.days as f64 * SECONDS_IN_A_DAY + x.milliseconds as f64 / 1_000_f64 + }), + Interval(ArrowIntervalUnit::MonthDayNano) => { + as_interval_mdn_array(array)?.unary(|x| { + x.months as f64 * DAYS_PER_MONTH * SECONDS_IN_A_DAY + + x.days as f64 * SECONDS_IN_A_DAY + + x.nanoseconds as f64 / 1_000_000_000_f64 + }) + } + Duration(_) => return seconds(array, Second), d => return exec_err!("Cannot convert {d:?} to epoch"), }; Ok(Arc::new(f)) diff --git a/datafusion/functions/src/datetime/date_trunc.rs b/datafusion/functions/src/datetime/date_trunc.rs index aca1d24c3116..8497e583ba4b 100644 --- a/datafusion/functions/src/datetime/date_trunc.rs +++ b/datafusion/functions/src/datetime/date_trunc.rs @@ -22,25 +22,30 @@ use std::str::FromStr; use std::sync::Arc; use arrow::array::temporal_conversions::{ - as_datetime_with_timezone, timestamp_ns_to_datetime, + MICROSECONDS, MILLISECONDS, NANOSECONDS, as_datetime_with_timezone, + timestamp_ns_to_datetime, }; use arrow::array::timezone::Tz; use arrow::array::types::{ - ArrowTimestampType, TimestampMicrosecondType, TimestampMillisecondType, + ArrowTimestampType, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, + Time64NanosecondType, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }; use arrow::array::{Array, ArrayRef, PrimitiveArray}; -use arrow::datatypes::DataType::{self, Null, Timestamp, Utf8, Utf8View}; +use arrow::datatypes::DataType::{self, Time32, Time64, Timestamp}; use arrow::datatypes::TimeUnit::{self, Microsecond, Millisecond, Nanosecond, Second}; +use arrow::datatypes::{Field, FieldRef}; use datafusion_common::cast::as_primitive_array; +use datafusion_common::types::{NativeType, logical_date, logical_string}; use datafusion_common::{ - DataFusionError, Result, ScalarValue, exec_datafusion_err, exec_err, plan_err, + 
DataFusionError, Result, ScalarValue, exec_datafusion_err, exec_err, internal_err, }; -use datafusion_expr::TypeSignature::Exact; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, TIMEZONE_WILDCARD, Volatility, + ColumnarValue, Documentation, ReturnFieldArgs, ScalarUDFImpl, Signature, + TypeSignature, Volatility, }; +use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; use datafusion_macros::user_doc; use chrono::{ @@ -116,16 +121,30 @@ impl DateTruncGranularity { fn is_fine_granularity_utc(&self) -> bool { self.is_fine_granularity() || matches!(self, Self::Hour | Self::Day) } + + /// Returns true if this granularity is valid for Time types + /// Time types don't have date components, so day/week/month/quarter/year are not valid + fn valid_for_time(&self) -> bool { + matches!( + self, + Self::Hour + | Self::Minute + | Self::Second + | Self::Millisecond + | Self::Microsecond + ) + } } #[user_doc( doc_section(label = "Time and Date Functions"), - description = "Truncates a timestamp value to a specified precision.", + description = "Truncates a timestamp or time value to a specified precision.", syntax_example = "date_trunc(precision, expression)", argument( name = "precision", description = r#"Time precision to truncate to. The following precisions are supported: + For Timestamp types: - year / YEAR - quarter / QUARTER - month / MONTH @@ -136,11 +155,18 @@ impl DateTruncGranularity { - second / SECOND - millisecond / MILLISECOND - microsecond / MICROSECOND + + For Time types (hour, minute, second, millisecond, microsecond only): + - hour / HOUR + - minute / MINUTE + - second / SECOND + - millisecond / MILLISECOND + - microsecond / MICROSECOND "# ), argument( name = "expression", - description = "Time expression to operate on. Can be a constant, column, or function." + description = "Timestamp or time expression to operate on. 
Can be a constant, column, or function." ) )] #[derive(Debug, PartialEq, Eq, Hash)] @@ -160,45 +186,21 @@ impl DateTruncFunc { Self { signature: Signature::one_of( vec![ - Exact(vec![Utf8, Timestamp(Nanosecond, None)]), - Exact(vec![Utf8View, Timestamp(Nanosecond, None)]), - Exact(vec![ - Utf8, - Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())), - ]), - Exact(vec![ - Utf8View, - Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())), - ]), - Exact(vec![Utf8, Timestamp(Microsecond, None)]), - Exact(vec![Utf8View, Timestamp(Microsecond, None)]), - Exact(vec![ - Utf8, - Timestamp(Microsecond, Some(TIMEZONE_WILDCARD.into())), - ]), - Exact(vec![ - Utf8View, - Timestamp(Microsecond, Some(TIMEZONE_WILDCARD.into())), + TypeSignature::Coercible(vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_implicit( + TypeSignatureClass::Timestamp, + // Allow implicit cast from string and date to timestamp for backward compatibility + vec![ + TypeSignatureClass::Native(logical_string()), + TypeSignatureClass::Native(logical_date()), + ], + NativeType::Timestamp(Nanosecond, None), + ), ]), - Exact(vec![Utf8, Timestamp(Millisecond, None)]), - Exact(vec![Utf8View, Timestamp(Millisecond, None)]), - Exact(vec![ - Utf8, - Timestamp(Millisecond, Some(TIMEZONE_WILDCARD.into())), - ]), - Exact(vec![ - Utf8View, - Timestamp(Millisecond, Some(TIMEZONE_WILDCARD.into())), - ]), - Exact(vec![Utf8, Timestamp(Second, None)]), - Exact(vec![Utf8View, Timestamp(Second, None)]), - Exact(vec![ - Utf8, - Timestamp(Second, Some(TIMEZONE_WILDCARD.into())), - ]), - Exact(vec![ - Utf8View, - Timestamp(Second, Some(TIMEZONE_WILDCARD.into())), + TypeSignature::Coercible(vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Time), ]), ], Volatility::Immutable, @@ -221,19 +223,22 @@ impl ScalarUDFImpl for DateTruncFunc { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - match 
&arg_types[1] { - Timestamp(Nanosecond, None) | Utf8 | DataType::Date32 | Null => { - Ok(Timestamp(Nanosecond, None)) - } - Timestamp(Nanosecond, tz_opt) => Ok(Timestamp(Nanosecond, tz_opt.clone())), - Timestamp(Microsecond, tz_opt) => Ok(Timestamp(Microsecond, tz_opt.clone())), - Timestamp(Millisecond, tz_opt) => Ok(Timestamp(Millisecond, tz_opt.clone())), - Timestamp(Second, tz_opt) => Ok(Timestamp(Second, tz_opt.clone())), - _ => plan_err!( - "The date_trunc function can only accept timestamp as the second arg." - ), - } + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be called instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let field = &args.arg_fields[1]; + let return_type = if field.data_type().is_null() { + Timestamp(Nanosecond, None) + } else { + field.data_type().clone() + }; + Ok(Arc::new(Field::new( + self.name(), + return_type, + field.is_nullable(), + ))) } fn invoke_with_args( @@ -248,6 +253,9 @@ impl ScalarUDFImpl for DateTruncFunc { { v.to_lowercase() } else if let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(v))) = granularity + { + v.to_lowercase() + } else if let ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(v))) = granularity { v.to_lowercase() } else { @@ -256,6 +264,15 @@ impl ScalarUDFImpl for DateTruncFunc { let granularity = DateTruncGranularity::from_str(&granularity_str)?; + // Check upfront if granularity is valid for Time types + let is_time_type = matches!(array.data_type(), Time64(_) | Time32(_)); + if is_time_type && !granularity.valid_for_time() { + return exec_err!( + "date_trunc does not support '{}' granularity for Time types. 
Valid values are: hour, minute, second, millisecond, microsecond", + granularity_str + ); + } + fn process_array( array: &dyn Array, granularity: DateTruncGranularity, @@ -303,6 +320,10 @@ impl ScalarUDFImpl for DateTruncFunc { } Ok(match array { + ColumnarValue::Scalar(ScalarValue::Null) => { + // NULL input returns NULL timestamp + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(None, None)) + } ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(v, tz_opt)) => { process_scalar::(v, granularity, tz_opt)? } @@ -315,40 +336,77 @@ impl ScalarUDFImpl for DateTruncFunc { ColumnarValue::Scalar(ScalarValue::TimestampSecond(v, tz_opt)) => { process_scalar::(v, granularity, tz_opt)? } + ColumnarValue::Scalar(ScalarValue::Time64Nanosecond(v)) => { + let truncated = v.map(|val| truncate_time_nanos(val, granularity)); + ColumnarValue::Scalar(ScalarValue::Time64Nanosecond(truncated)) + } + ColumnarValue::Scalar(ScalarValue::Time64Microsecond(v)) => { + let truncated = v.map(|val| truncate_time_micros(val, granularity)); + ColumnarValue::Scalar(ScalarValue::Time64Microsecond(truncated)) + } + ColumnarValue::Scalar(ScalarValue::Time32Millisecond(v)) => { + let truncated = v.map(|val| truncate_time_millis(val, granularity)); + ColumnarValue::Scalar(ScalarValue::Time32Millisecond(truncated)) + } + ColumnarValue::Scalar(ScalarValue::Time32Second(v)) => { + let truncated = v.map(|val| truncate_time_secs(val, granularity)); + ColumnarValue::Scalar(ScalarValue::Time32Second(truncated)) + } ColumnarValue::Array(array) => { let array_type = array.data_type(); - if let Timestamp(unit, tz_opt) = array_type { - match unit { - Second => process_array::( - array, - granularity, - tz_opt, - )?, - Millisecond => process_array::( - array, - granularity, - tz_opt, - )?, - Microsecond => process_array::( - array, - granularity, - tz_opt, - )?, - Nanosecond => process_array::( - array, - granularity, - tz_opt, - )?, + match array_type { + Timestamp(Second, tz_opt) => { + 
process_array::(array, granularity, tz_opt)? + } + Timestamp(Millisecond, tz_opt) => process_array::< + TimestampMillisecondType, + >( + array, granularity, tz_opt + )?, + Timestamp(Microsecond, tz_opt) => process_array::< + TimestampMicrosecondType, + >( + array, granularity, tz_opt + )?, + Timestamp(Nanosecond, tz_opt) => process_array::< + TimestampNanosecondType, + >( + array, granularity, tz_opt + )?, + Time64(Nanosecond) => { + let arr = as_primitive_array::(array)?; + let result: PrimitiveArray = + arr.unary(|v| truncate_time_nanos(v, granularity)); + ColumnarValue::Array(Arc::new(result)) + } + Time64(Microsecond) => { + let arr = as_primitive_array::(array)?; + let result: PrimitiveArray = + arr.unary(|v| truncate_time_micros(v, granularity)); + ColumnarValue::Array(Arc::new(result)) + } + Time32(Millisecond) => { + let arr = as_primitive_array::(array)?; + let result: PrimitiveArray = + arr.unary(|v| truncate_time_millis(v, granularity)); + ColumnarValue::Array(Arc::new(result)) + } + Time32(Second) => { + let arr = as_primitive_array::(array)?; + let result: PrimitiveArray = + arr.unary(|v| truncate_time_secs(v, granularity)); + ColumnarValue::Array(Arc::new(result)) + } + _ => { + return exec_err!( + "second argument of `date_trunc` is an unsupported array type: {array_type}" + ); } - } else { - return exec_err!( - "second argument of `date_trunc` is an unsupported array type: {array_type}" - ); } } _ => { return exec_err!( - "second argument of `date_trunc` must be timestamp scalar or array" + "second argument of `date_trunc` must be timestamp, time scalar or array" ); } }) @@ -374,6 +432,76 @@ impl ScalarUDFImpl for DateTruncFunc { } } +const NANOS_PER_MICROSECOND: i64 = NANOSECONDS / MICROSECONDS; +const NANOS_PER_MILLISECOND: i64 = NANOSECONDS / MILLISECONDS; +const NANOS_PER_SECOND: i64 = NANOSECONDS; +const NANOS_PER_MINUTE: i64 = 60 * NANOS_PER_SECOND; +const NANOS_PER_HOUR: i64 = 60 * NANOS_PER_MINUTE; + +const MICROS_PER_MILLISECOND: i64 = 
MICROSECONDS / MILLISECONDS; +const MICROS_PER_SECOND: i64 = MICROSECONDS; +const MICROS_PER_MINUTE: i64 = 60 * MICROS_PER_SECOND; +const MICROS_PER_HOUR: i64 = 60 * MICROS_PER_MINUTE; + +const MILLIS_PER_SECOND: i32 = MILLISECONDS as i32; +const MILLIS_PER_MINUTE: i32 = 60 * MILLIS_PER_SECOND; +const MILLIS_PER_HOUR: i32 = 60 * MILLIS_PER_MINUTE; + +const SECS_PER_MINUTE: i32 = 60; +const SECS_PER_HOUR: i32 = 60 * SECS_PER_MINUTE; + +/// Truncate time in nanoseconds to the specified granularity +fn truncate_time_nanos(value: i64, granularity: DateTruncGranularity) -> i64 { + match granularity { + DateTruncGranularity::Hour => value - (value % NANOS_PER_HOUR), + DateTruncGranularity::Minute => value - (value % NANOS_PER_MINUTE), + DateTruncGranularity::Second => value - (value % NANOS_PER_SECOND), + DateTruncGranularity::Millisecond => value - (value % NANOS_PER_MILLISECOND), + DateTruncGranularity::Microsecond => value - (value % NANOS_PER_MICROSECOND), + // Other granularities are not valid for time - should be caught earlier + _ => value, + } +} + +/// Truncate time in microseconds to the specified granularity +fn truncate_time_micros(value: i64, granularity: DateTruncGranularity) -> i64 { + match granularity { + DateTruncGranularity::Hour => value - (value % MICROS_PER_HOUR), + DateTruncGranularity::Minute => value - (value % MICROS_PER_MINUTE), + DateTruncGranularity::Second => value - (value % MICROS_PER_SECOND), + DateTruncGranularity::Millisecond => value - (value % MICROS_PER_MILLISECOND), + DateTruncGranularity::Microsecond => value, // Already at microsecond precision + // Other granularities are not valid for time + _ => value, + } +} + +/// Truncate time in milliseconds to the specified granularity +fn truncate_time_millis(value: i32, granularity: DateTruncGranularity) -> i32 { + match granularity { + DateTruncGranularity::Hour => value - (value % MILLIS_PER_HOUR), + DateTruncGranularity::Minute => value - (value % MILLIS_PER_MINUTE), + 
DateTruncGranularity::Second => value - (value % MILLIS_PER_SECOND), + DateTruncGranularity::Millisecond => value, // Already at millisecond precision + DateTruncGranularity::Microsecond => value, // Can't truncate to finer precision + // Other granularities are not valid for time + _ => value, + } +} + +/// Truncate time in seconds to the specified granularity +fn truncate_time_secs(value: i32, granularity: DateTruncGranularity) -> i32 { + match granularity { + DateTruncGranularity::Hour => value - (value % SECS_PER_HOUR), + DateTruncGranularity::Minute => value - (value % SECS_PER_MINUTE), + DateTruncGranularity::Second => value, // Already at second precision + DateTruncGranularity::Millisecond => value, // Can't truncate to finer precision + DateTruncGranularity::Microsecond => value, // Can't truncate to finer precision + // Other granularities are not valid for time + _ => value, + } +} + fn _date_trunc_coarse( granularity: DateTruncGranularity, value: Option, diff --git a/datafusion/functions/src/datetime/now.rs b/datafusion/functions/src/datetime/now.rs index b804efe59106..338a62a118f3 100644 --- a/datafusion/functions/src/datetime/now.rs +++ b/datafusion/functions/src/datetime/now.rs @@ -23,7 +23,7 @@ use std::sync::Arc; use datafusion_common::config::ConfigOptions; use datafusion_common::{Result, ScalarValue, internal_err}; -use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; use datafusion_expr::{ ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility, @@ -121,16 +121,18 @@ impl ScalarUDFImpl for NowFunc { fn simplify( &self, - _args: Vec, - info: &dyn SimplifyInfo, + args: Vec, + info: &SimplifyContext, ) -> Result { - let now_ts = info - .execution_props() - .query_execution_start_time - .timestamp_nanos_opt(); + let Some(now_ts) = info.query_execution_start_time() else { + return Ok(ExprSimplifyResult::Original(args)); + 
}; Ok(ExprSimplifyResult::Simplified(Expr::Literal( - ScalarValue::TimestampNanosecond(now_ts, self.timezone.clone()), + ScalarValue::TimestampNanosecond( + now_ts.timestamp_nanos_opt(), + self.timezone.clone(), + ), None, ))) } diff --git a/datafusion/functions/src/datetime/to_char.rs b/datafusion/functions/src/datetime/to_char.rs index 8d0c47cfe664..4ceaac1cc8af 100644 --- a/datafusion/functions/src/datetime/to_char.rs +++ b/datafusion/functions/src/datetime/to_char.rs @@ -18,15 +18,15 @@ use std::any::Any; use std::sync::Arc; +use arrow::array::builder::StringBuilder; use arrow::array::cast::AsArray; -use arrow::array::{Array, ArrayRef, StringArray, new_null_array}; +use arrow::array::{Array, ArrayRef}; use arrow::compute::cast; use arrow::datatypes::DataType; use arrow::datatypes::DataType::{ Date32, Date64, Duration, Time32, Time64, Timestamp, Utf8, }; use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; -use arrow::error::ArrowError; use arrow::util::display::{ArrayFormatter, DurationFormat, FormatOptions}; use datafusion_common::{Result, ScalarValue, exec_err, utils::take_function_args}; use datafusion_expr::TypeSignature::Exact; @@ -143,20 +143,17 @@ impl ScalarUDFImpl for ToCharFunc { let [date_time, format] = take_function_args(self.name(), &args)?; match format { - ColumnarValue::Scalar(ScalarValue::Utf8(None)) - | ColumnarValue::Scalar(ScalarValue::Null) => to_char_scalar(date_time, None), - // constant format - ColumnarValue::Scalar(ScalarValue::Utf8(Some(format))) => { - // invoke to_char_scalar with the known string, without converting to array - to_char_scalar(date_time, Some(format)) + ColumnarValue::Scalar(ScalarValue::Null | ScalarValue::Utf8(None)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))) } - ColumnarValue::Array(_) => to_char_array(&args), - _ => { - exec_err!( - "Format for `to_char` must be non-null Utf8, received {:?}", - format.data_type() - ) + ColumnarValue::Scalar(ScalarValue::Utf8(Some(fmt))) 
=> { + to_char_scalar(date_time, fmt) } + ColumnarValue::Array(_) => to_char_array(&args), + _ => exec_err!( + "Format for `to_char` must be non-null Utf8, received {}", + format.data_type() + ), } } @@ -171,11 +168,8 @@ impl ScalarUDFImpl for ToCharFunc { fn build_format_options<'a>( data_type: &DataType, - format: Option<&'a str>, -) -> Result, Result> { - let Some(format) = format else { - return Ok(FormatOptions::new()); - }; + format: &'a str, +) -> Result> { let format_options = match data_type { Date32 => FormatOptions::new() .with_date_format(Some(format)) @@ -194,144 +188,114 @@ fn build_format_options<'a>( }, ), other => { - return Err(exec_err!( + return exec_err!( "to_char only supports date, time, timestamp and duration data types, received {other:?}" - )); + ); } }; Ok(format_options) } -/// Special version when arg\[1] is a scalar -fn to_char_scalar( - expression: &ColumnarValue, - format: Option<&str>, -) -> Result { - // it's possible that the expression is a scalar however because - // of the implementation in arrow-rs we need to convert it to an array +/// Formats `expression` using a constant `format` string. +fn to_char_scalar(expression: &ColumnarValue, format: &str) -> Result { + // ArrayFormatter requires an array, so scalar expressions must be + // converted to a 1-element array first. 
let data_type = &expression.data_type(); let is_scalar_expression = matches!(&expression, ColumnarValue::Scalar(_)); - let array = expression.clone().into_array(1)?; + let array = expression.to_array(1)?; - if format.is_none() { - return if is_scalar_expression { - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))) - } else { - Ok(ColumnarValue::Array(new_null_array(&Utf8, array.len()))) - }; - } + let format_options = build_format_options(data_type, format)?; + let formatter = ArrayFormatter::try_new(array.as_ref(), &format_options)?; - let format_options = match build_format_options(data_type, format) { - Ok(value) => value, - Err(value) => return value, - }; + // Pad the preallocated capacity a bit because format specifiers often + // expand the string (e.g., %Y -> "2026") + let fmt_len = format.len() + 10; + let mut builder = StringBuilder::with_capacity(array.len(), array.len() * fmt_len); - let formatter = ArrayFormatter::try_new(array.as_ref(), &format_options)?; - let formatted: Result>, ArrowError> = (0..array.len()) - .map(|i| { - if array.is_null(i) { - Ok(None) - } else { - formatter.value(i).try_to_string().map(Some) - } - }) - .collect(); - - if let Ok(formatted) = formatted { - if is_scalar_expression { - Ok(ColumnarValue::Scalar(ScalarValue::Utf8( - formatted.first().unwrap().clone(), - ))) + for i in 0..array.len() { + if array.is_null(i) { + builder.append_null(); } else { - Ok(ColumnarValue::Array( - Arc::new(StringArray::from(formatted)) as ArrayRef - )) - } - } else { - // if the data type was a Date32, formatting could have failed because the format string - // contained datetime specifiers, so we'll retry by casting the date array as a timestamp array - if data_type == &Date32 { - return to_char_scalar(&expression.cast_to(&Date64, None)?, format); + // Write directly into the builder's internal buffer, then + // commit the value with append_value(""). 
+ match formatter.value(i).write(&mut builder) { + Ok(()) => builder.append_value(""), + // Arrow's Date32 formatter only handles date specifiers + // (%Y, %m, %d, ...). Format strings with time specifiers + // (%H, %M, %S, ...) cause it to fail. When this happens, + // we retry by casting to Date64, whose datetime formatter + // handles both date and time specifiers (with zero for + // the time components). + Err(_) if data_type == &Date32 => { + return to_char_scalar(&expression.cast_to(&Date64, None)?, format); + } + Err(e) => return Err(e.into()), + } } + } - exec_err!("{}", formatted.unwrap_err()) + let result = builder.finish(); + if is_scalar_expression { + let val = result.is_valid(0).then(|| result.value(0).to_string()); + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(val))) + } else { + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) } } fn to_char_array(args: &[ColumnarValue]) -> Result { let arrays = ColumnarValue::values_to_arrays(args)?; - let mut results: Vec> = vec![]; + let data_array = &arrays[0]; let format_array = arrays[1].as_string::(); - let data_type = arrays[0].data_type(); + let data_type = data_array.data_type(); - for idx in 0..arrays[0].len() { - let format = if format_array.is_null(idx) { - None - } else { - Some(format_array.value(idx)) - }; - if format.is_none() { - results.push(None); + // Arbitrary guess for the length of a typical formatted datetime string + let fmt_len = 30; + let mut builder = + StringBuilder::with_capacity(data_array.len(), data_array.len() * fmt_len); + let mut buffer = String::with_capacity(fmt_len); + + for idx in 0..data_array.len() { + if format_array.is_null(idx) || data_array.is_null(idx) { + builder.append_null(); continue; } - let format_options = match build_format_options(data_type, format) { - Ok(value) => value, - Err(value) => return value, - }; - // this isn't ideal but this can't use ValueFormatter as it isn't independent - // from ArrayFormatter - let formatter = 
ArrayFormatter::try_new(arrays[0].as_ref(), &format_options)?; - let result = formatter.value(idx).try_to_string(); - match result { - Ok(value) => results.push(Some(value)), - Err(e) => { - // if the data type was a Date32, formatting could have failed because the format string - // contained datetime specifiers, so we'll treat this specific date element as a timestamp - if data_type == &Date32 { - let failed_date_value = arrays[0].slice(idx, 1); - - match retry_date_as_timestamp(&failed_date_value, &format_options) { - Ok(value) => { - results.push(Some(value)); - continue; - } - Err(e) => { - return exec_err!("{}", e); - } - } - } - return exec_err!("{}", e); + let format = format_array.value(idx); + let format_options = build_format_options(data_type, format)?; + let formatter = ArrayFormatter::try_new(data_array.as_ref(), &format_options)?; + + buffer.clear(); + + // We'd prefer to write directly to the StringBuilder's internal buffer, + // but the write might fail, and there's no easy way to ensure a partial + // write is removed from the buffer. So instead we write to a temporary + // buffer and `append_value` on success. + match formatter.value(idx).write(&mut buffer) { + Ok(()) => builder.append_value(&buffer), + // Retry with Date64 (see comment in to_char_scalar). 
+ Err(_) if data_type == &Date32 => { + buffer.clear(); + let date64_value = cast(&data_array.slice(idx, 1), &Date64)?; + let retry_fmt = + ArrayFormatter::try_new(date64_value.as_ref(), &format_options)?; + retry_fmt.value(0).write(&mut buffer)?; + builder.append_value(&buffer); } + Err(e) => return Err(e.into()), } } + let result = builder.finish(); match args[0] { - ColumnarValue::Array(_) => Ok(ColumnarValue::Array(Arc::new(StringArray::from( - results, - )) as ArrayRef)), - ColumnarValue::Scalar(_) => match results.first().unwrap() { - Some(value) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( - value.to_string(), - )))), - None => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))), - }, + ColumnarValue::Scalar(_) => { + let val = result.is_valid(0).then(|| result.value(0).to_string()); + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(val))) + } + ColumnarValue::Array(_) => Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)), } } -fn retry_date_as_timestamp( - array_ref: &ArrayRef, - format_options: &FormatOptions, -) -> Result { - let target_data_type = Date64; - - let date_value = cast(&array_ref, &target_data_type)?; - let formatter = ArrayFormatter::try_new(date_value.as_ref(), format_options)?; - let result = formatter.value(0).try_to_string()?; - - Ok(result) -} - #[cfg(test)] mod tests { use crate::datetime::to_char::ToCharFunc; @@ -814,7 +778,7 @@ mod tests { let result = ToCharFunc::new().invoke_with_args(args); assert_eq!( result.err().unwrap().strip_backtrace(), - "Execution error: Format for `to_char` must be non-null Utf8, received Timestamp(Nanosecond, None)" + "Execution error: Format for `to_char` must be non-null Utf8, received Timestamp(ns)" ); } } diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs index 86c949711d01..0500497a15fa 100644 --- a/datafusion/functions/src/datetime/to_local_time.rs +++ b/datafusion/functions/src/datetime/to_local_time.rs @@ -324,7 +324,7 @@ fn 
to_local_time(time_value: &ColumnarValue) -> Result { /// ``` /// /// See `test_adjust_to_local_time()` for example -fn adjust_to_local_time(ts: i64, tz: Tz) -> Result { +pub fn adjust_to_local_time(ts: i64, tz: Tz) -> Result { fn convert_timestamp(ts: i64, converter: F) -> Result> where F: Fn(i64) -> MappedLocalTime>, diff --git a/datafusion/functions/src/datetime/to_timestamp.rs b/datafusion/functions/src/datetime/to_timestamp.rs index 58077694b07a..6d40133bd29b 100644 --- a/datafusion/functions/src/datetime/to_timestamp.rs +++ b/datafusion/functions/src/datetime/to_timestamp.rs @@ -19,8 +19,11 @@ use std::any::Any; use std::sync::Arc; use crate::datetime::common::*; -use arrow::array::Float64Array; use arrow::array::timezone::Tz; +use arrow::array::{ + Array, Decimal128Array, Float16Array, Float32Array, Float64Array, + TimestampNanosecondArray, +}; use arrow::datatypes::DataType::*; use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; use arrow::datatypes::{ @@ -28,7 +31,6 @@ use arrow::datatypes::{ TimestampNanosecondType, TimestampSecondType, }; use datafusion_common::config::ConfigOptions; -use datafusion_common::format::DEFAULT_CAST_OPTIONS; use datafusion_common::{Result, ScalarType, ScalarValue, exec_err}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDF, ScalarUDFImpl, Signature, Volatility, @@ -40,7 +42,8 @@ use datafusion_macros::user_doc; description = r#" Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000`) in the session time zone. Supports strings, integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') -if no [Chrono formats] are provided. Strings that parse without a time zone are treated as if they are in the +if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. +Strings that parse without a time zone are treated as if they are in the session time zone, or UTC if no session time zone is set. 
Integers, unsigned integers, and doubles are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). @@ -94,7 +97,8 @@ pub struct ToTimestampFunc { description = r#" Converts a value to a timestamp (`YYYY-MM-DDT00:00:00`) in the session time zone. Supports strings, integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') -if no [Chrono formats] are provided. Strings that parse without a time zone are treated as if they are in the +if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. +Strings that parse without a time zone are treated as if they are in the session time zone, or UTC if no session time zone is set. Integers, unsigned integers, and doubles are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). @@ -143,7 +147,8 @@ pub struct ToTimestampSecondsFunc { description = r#" Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000`) in the session time zone. Supports strings, integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') -if no [Chrono formats] are provided. Strings that parse without a time zone are treated as if they are in the +if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. +Strings that parse without a time zone are treated as if they are in the session time zone, or UTC if no session time zone is set. Integers, unsigned integers, and doubles are interpreted as milliseconds since the unix epoch (`1970-01-01T00:00:00Z`). @@ -192,7 +197,8 @@ pub struct ToTimestampMillisFunc { description = r#" Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000`) in the session time zone. Supports strings, integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') -if no [Chrono formats] are provided. 
Strings that parse without a time zone are treated as if they are in the +if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. +Strings that parse without a time zone are treated as if they are in the session time zone, or UTC if no session time zone is set. Integers, unsigned integers, and doubles are interpreted as microseconds since the unix epoch (`1970-01-01T00:00:00Z`). @@ -241,7 +247,8 @@ pub struct ToTimestampMicrosFunc { description = r#" Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000000`) in the session time zone. Supports strings, integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') -if no [Chrono formats] are provided. Strings that parse without a time zone are treated as if they are in the +if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. +Strings that parse without a time zone are treated as if they are in the session time zone. Integers, unsigned integers, and doubles are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`). The session time zone can be set using the statement `SET TIMEZONE = 'desired time zone'`. 
@@ -325,6 +332,45 @@ impl_to_timestamp_constructors!(ToTimestampMillisFunc); impl_to_timestamp_constructors!(ToTimestampMicrosFunc); impl_to_timestamp_constructors!(ToTimestampNanosFunc); +fn decimal_to_nanoseconds(value: i128, scale: i8) -> i64 { + let nanos_exponent = 9_i16 - scale as i16; + let timestamp_nanos = if nanos_exponent >= 0 { + value * 10_i128.pow(nanos_exponent as u32) + } else { + value / 10_i128.pow(nanos_exponent.unsigned_abs() as u32) + }; + timestamp_nanos as i64 +} + +fn decimal128_to_timestamp_nanos( + arg: &ColumnarValue, + tz: Option>, +) -> Result { + match arg { + ColumnarValue::Scalar(ScalarValue::Decimal128(Some(value), _, scale)) => { + let timestamp_nanos = decimal_to_nanoseconds(*value, *scale); + Ok(ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + Some(timestamp_nanos), + tz, + ))) + } + ColumnarValue::Scalar(ScalarValue::Decimal128(None, _, _)) => Ok( + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(None, tz)), + ), + ColumnarValue::Array(arr) => { + let decimal_arr = downcast_arg!(arr, Decimal128Array); + let scale = decimal_arr.scale(); + let result: TimestampNanosecondArray = decimal_arr + .iter() + .map(|v| v.map(|val| decimal_to_nanoseconds(val, scale))) + .collect(); + let result = result.with_timezone_opt(tz); + Ok(ColumnarValue::Array(Arc::new(result))) + } + _ => exec_err!("Invalid Decimal128 value for to_timestamp"), + } +} + /// to_timestamp SQL function /// /// Note: `to_timestamp` returns `Timestamp(Nanosecond)` though its arguments are interpreted as **seconds**. @@ -380,48 +426,68 @@ impl ScalarUDFImpl for ToTimestampFunc { let tz = self.timezone.clone(); match args[0].data_type() { - Int32 | Int64 => args[0] + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 => args[0] .cast_to(&Timestamp(Second, None), None)? 
.cast_to(&Timestamp(Nanosecond, tz), None), Null | Timestamp(_, _) => args[0].cast_to(&Timestamp(Nanosecond, tz), None), - Float64 => { - let rescaled = arrow::compute::kernels::numeric::mul( - &args[0].to_array(1)?, - &arrow::array::Scalar::new(Float64Array::from(vec![ - 1_000_000_000f64, - ])), - )?; - Ok(ColumnarValue::Array(arrow::compute::cast_with_options( - &rescaled, - &Timestamp(Nanosecond, tz), - &DEFAULT_CAST_OPTIONS, - )?)) + Float16 => match &args[0] { + ColumnarValue::Scalar(ScalarValue::Float16(value)) => { + let timestamp_nanos = + value.map(|v| (v.to_f64() * 1_000_000_000.0) as i64); + Ok(ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + timestamp_nanos, + tz, + ))) + } + ColumnarValue::Array(arr) => { + let f16_arr = downcast_arg!(arr, Float16Array); + let result: TimestampNanosecondArray = + f16_arr.unary(|x| (x.to_f64() * 1_000_000_000.0) as i64); + Ok(ColumnarValue::Array(Arc::new(result.with_timezone_opt(tz)))) + } + _ => exec_err!("Invalid Float16 value for to_timestamp"), + }, + Float32 => match &args[0] { + ColumnarValue::Scalar(ScalarValue::Float32(value)) => { + let timestamp_nanos = + value.map(|v| (v as f64 * 1_000_000_000.0) as i64); + Ok(ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + timestamp_nanos, + tz, + ))) + } + ColumnarValue::Array(arr) => { + let f32_arr = downcast_arg!(arr, Float32Array); + let result: TimestampNanosecondArray = + f32_arr.unary(|x| (x as f64 * 1_000_000_000.0) as i64); + Ok(ColumnarValue::Array(Arc::new(result.with_timezone_opt(tz)))) + } + _ => exec_err!("Invalid Float32 value for to_timestamp"), + }, + Float64 => match &args[0] { + ColumnarValue::Scalar(ScalarValue::Float64(value)) => { + let timestamp_nanos = value.map(|v| (v * 1_000_000_000.0) as i64); + Ok(ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + timestamp_nanos, + tz, + ))) + } + ColumnarValue::Array(arr) => { + let f64_arr = downcast_arg!(arr, Float64Array); + let result: TimestampNanosecondArray = + 
f64_arr.unary(|x| (x * 1_000_000_000.0) as i64); + Ok(ColumnarValue::Array(Arc::new(result.with_timezone_opt(tz)))) + } + _ => exec_err!("Invalid Float64 value for to_timestamp"), + }, + Decimal32(_, _) | Decimal64(_, _) | Decimal256(_, _) => { + let arg = args[0].cast_to(&Decimal128(38, 9), None)?; + decimal128_to_timestamp_nanos(&arg, tz) } + Decimal128(_, _) => decimal128_to_timestamp_nanos(&args[0], tz), Utf8View | LargeUtf8 | Utf8 => { to_timestamp_impl::(&args, "to_timestamp", &tz) } - Decimal128(_, _) => { - match &args[0] { - ColumnarValue::Scalar(ScalarValue::Decimal128( - Some(value), - _, - scale, - )) => { - // Convert decimal to seconds and nanoseconds - let scale_factor = 10_i128.pow(*scale as u32); - let seconds = value / scale_factor; - let fraction = value % scale_factor; - let nanos = (fraction * 1_000_000_000) / scale_factor; - let timestamp_nanos = seconds * 1_000_000_000 + nanos; - - Ok(ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( - Some(timestamp_nanos as i64), - tz, - ))) - } - _ => exec_err!("Invalid decimal value"), - } - } other => { exec_err!("Unsupported data type {other} for function to_timestamp") } @@ -473,9 +539,23 @@ impl ScalarUDFImpl for ToTimestampSecondsFunc { let tz = self.timezone.clone(); match args[0].data_type() { - Null | Int32 | Int64 | Timestamp(_, _) | Decimal128(_, _) => { - args[0].cast_to(&Timestamp(Second, tz), None) - } + Null + | Int8 + | Int16 + | Int32 + | Int64 + | UInt8 + | UInt16 + | UInt32 + | UInt64 + | Timestamp(_, _) + | Decimal32(_, _) + | Decimal64(_, _) + | Decimal128(_, _) + | Decimal256(_, _) => args[0].cast_to(&Timestamp(Second, tz), None), + Float16 | Float32 | Float64 => args[0] + .cast_to(&Int64, None)? 
+ .cast_to(&Timestamp(Second, tz), None), Utf8View | LargeUtf8 | Utf8 => to_timestamp_impl::( &args, "to_timestamp_seconds", @@ -533,9 +613,25 @@ impl ScalarUDFImpl for ToTimestampMillisFunc { } match args[0].data_type() { - Null | Int32 | Int64 | Timestamp(_, _) => { + Null + | Int8 + | Int16 + | Int32 + | Int64 + | UInt8 + | UInt16 + | UInt32 + | UInt64 + | Timestamp(_, _) + | Decimal32(_, _) + | Decimal64(_, _) + | Decimal128(_, _) + | Decimal256(_, _) => { args[0].cast_to(&Timestamp(Millisecond, self.timezone.clone()), None) } + Float16 | Float32 | Float64 => args[0] + .cast_to(&Int64, None)? + .cast_to(&Timestamp(Millisecond, self.timezone.clone()), None), Utf8View | LargeUtf8 | Utf8 => to_timestamp_impl::( &args, "to_timestamp_millis", @@ -593,9 +689,25 @@ impl ScalarUDFImpl for ToTimestampMicrosFunc { } match args[0].data_type() { - Null | Int32 | Int64 | Timestamp(_, _) => { + Null + | Int8 + | Int16 + | Int32 + | Int64 + | UInt8 + | UInt16 + | UInt32 + | UInt64 + | Timestamp(_, _) + | Decimal32(_, _) + | Decimal64(_, _) + | Decimal128(_, _) + | Decimal256(_, _) => { args[0].cast_to(&Timestamp(Microsecond, self.timezone.clone()), None) } + Float16 | Float32 | Float64 => args[0] + .cast_to(&Int64, None)? + .cast_to(&Timestamp(Microsecond, self.timezone.clone()), None), Utf8View | LargeUtf8 | Utf8 => to_timestamp_impl::( &args, "to_timestamp_micros", @@ -653,9 +765,25 @@ impl ScalarUDFImpl for ToTimestampNanosFunc { } match args[0].data_type() { - Null | Int32 | Int64 | Timestamp(_, _) => { + Null + | Int8 + | Int16 + | Int32 + | Int64 + | UInt8 + | UInt16 + | UInt32 + | UInt64 + | Timestamp(_, _) + | Decimal32(_, _) + | Decimal64(_, _) + | Decimal128(_, _) + | Decimal256(_, _) => { args[0].cast_to(&Timestamp(Nanosecond, self.timezone.clone()), None) } + Float16 | Float32 | Float64 => args[0] + .cast_to(&Int64, None)? 
+ .cast_to(&Timestamp(Nanosecond, self.timezone.clone()), None), Utf8View | LargeUtf8 | Utf8 => to_timestamp_impl::( &args, "to_timestamp_nanos", @@ -1735,4 +1863,23 @@ mod tests { assert_contains!(actual, expected); } } + + #[test] + fn test_decimal_to_nanoseconds_negative_scale() { + // scale -2: internal value 5 represents 5 * 10^2 = 500 seconds + let nanos = decimal_to_nanoseconds(5, -2); + assert_eq!(nanos, 500_000_000_000); // 500 seconds in nanoseconds + + // scale -1: internal value 10 represents 10 * 10^1 = 100 seconds + let nanos = decimal_to_nanoseconds(10, -1); + assert_eq!(nanos, 100_000_000_000); + + // scale 0: internal value 5 represents 5 seconds + let nanos = decimal_to_nanoseconds(5, 0); + assert_eq!(nanos, 5_000_000_000); + + // scale 3: internal value 1500 represents 1.5 seconds + let nanos = decimal_to_nanoseconds(1500, 3); + assert_eq!(nanos, 1_500_000_000); + } } diff --git a/datafusion/functions/src/datetime/to_unixtime.rs b/datafusion/functions/src/datetime/to_unixtime.rs index 5ebcce0a7cfc..2dd377282725 100644 --- a/datafusion/functions/src/datetime/to_unixtime.rs +++ b/datafusion/functions/src/datetime/to_unixtime.rs @@ -27,7 +27,12 @@ use std::any::Any; #[user_doc( doc_section(label = "Time and Date Functions"), - description = "Converts a value to seconds since the unix epoch (`1970-01-01T00:00:00`). Supports strings, dates, timestamps, integer, unsigned integer, and float types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. Integers, unsigned integers, and floats are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00`).", + description = r#" +Converts a value to seconds since the unix epoch (`1970-01-01T00:00:00`). +Supports strings, dates, timestamps, integer, unsigned integer, and float types as input. +Strings are parsed as RFC3339 (e.g. 
'2023-07-20T05:44:00') +if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. +Integers, unsigned integers, and floats are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00`)."#, syntax_example = "to_unixtime(expression[, ..., format_n])", sql_example = r#" ```sql diff --git a/datafusion/functions/src/encoding/inner.rs b/datafusion/functions/src/encoding/inner.rs index 7b72c264e555..4ad67b78178f 100644 --- a/datafusion/functions/src/encoding/inner.rs +++ b/datafusion/functions/src/encoding/inner.rs @@ -19,8 +19,8 @@ use arrow::{ array::{ - Array, ArrayRef, AsArray, BinaryArrayType, FixedSizeBinaryArray, - GenericBinaryArray, GenericStringArray, OffsetSizeTrait, + Array, ArrayRef, AsArray, BinaryArrayType, GenericBinaryArray, + GenericStringArray, OffsetSizeTrait, }, datatypes::DataType, }; @@ -52,6 +52,12 @@ const BASE64_ENGINE: GeneralPurpose = GeneralPurpose::new( .with_decode_padding_mode(DecodePaddingMode::Indifferent), ); +// Generate padding characters when encoding +const BASE64_ENGINE_PADDED: GeneralPurpose = GeneralPurpose::new( + &base64::alphabet::STANDARD, + GeneralPurposeConfig::new().with_encode_padding(true), +); + #[user_doc( doc_section(label = "Binary String Functions"), description = "Encode binary data into a textual representation.", @@ -62,7 +68,7 @@ const BASE64_ENGINE: GeneralPurpose = GeneralPurpose::new( ), argument( name = "format", - description = "Supported formats are: `base64`, `hex`" + description = "Supported formats are: `base64`, `base64pad`, `hex`" ), related_udf(name = "decode") )] @@ -239,7 +245,7 @@ fn encode_array(array: &ArrayRef, encoding: Encoding) -> Result { encoding.encode_array::<_, i64>(&array.as_binary::()) } DataType::FixedSizeBinary(_) => { - encoding.encode_fsb_array(array.as_fixed_size_binary()) + encoding.encode_array::<_, i32>(&array.as_fixed_size_binary()) } dt => { internal_err!("Unexpected data type for encode: {dt}") @@ -307,7 +313,7 @@ fn 
decode_array(array: &ArrayRef, encoding: Encoding) -> Result { let array = array.as_fixed_size_binary(); // TODO: could we be more conservative by accounting for nulls? let estimate = array.len().saturating_mul(*size as usize); - encoding.decode_fsb_array(array, estimate) + encoding.decode_array::<_, i32>(&array, estimate) } dt => { internal_err!("Unexpected data type for decode: {dt}") @@ -319,12 +325,18 @@ fn decode_array(array: &ArrayRef, encoding: Encoding) -> Result { #[derive(Debug, Copy, Clone)] enum Encoding { Base64, + Base64Padded, Hex, } impl fmt::Display for Encoding { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", format!("{self:?}").to_lowercase()) + let name = match self { + Self::Base64 => "base64", + Self::Base64Padded => "base64pad", + Self::Hex => "hex", + }; + write!(f, "{name}") } } @@ -345,9 +357,10 @@ impl TryFrom<&ColumnarValue> for Encoding { }; match encoding { "base64" => Ok(Self::Base64), + "base64pad" => Ok(Self::Base64Padded), "hex" => Ok(Self::Hex), _ => { - let options = [Self::Base64, Self::Hex] + let options = [Self::Base64, Self::Base64Padded, Self::Hex] .iter() .map(|i| i.to_string()) .collect::>() @@ -364,15 +377,18 @@ impl Encoding { fn encode_bytes(self, value: &[u8]) -> String { match self { Self::Base64 => BASE64_ENGINE.encode(value), + Self::Base64Padded => BASE64_ENGINE_PADDED.encode(value), Self::Hex => hex::encode(value), } } fn decode_bytes(self, value: &[u8]) -> Result> { match self { - Self::Base64 => BASE64_ENGINE.decode(value).map_err(|e| { - exec_datafusion_err!("Failed to decode value using base64: {e}") - }), + Self::Base64 | Self::Base64Padded => { + BASE64_ENGINE.decode(value).map_err(|e| { + exec_datafusion_err!("Failed to decode value using {self}: {e}") + }) + } Self::Hex => hex::decode(value).map_err(|e| { exec_datafusion_err!("Failed to decode value using hex: {e}") }), @@ -396,26 +412,15 @@ impl Encoding { .collect(); Ok(Arc::new(array)) } - Self::Hex => { - let array: 
GenericStringArray = - array.iter().map(|x| x.map(hex::encode)).collect(); - Ok(Arc::new(array)) - } - } - } - - // TODO: refactor this away once https://github.com/apache/arrow-rs/pull/8993 lands - fn encode_fsb_array(self, array: &FixedSizeBinaryArray) -> Result { - match self { - Self::Base64 => { - let array: GenericStringArray = array + Self::Base64Padded => { + let array: GenericStringArray = array .iter() - .map(|x| x.map(|x| BASE64_ENGINE.encode(x))) + .map(|x| x.map(|x| BASE64_ENGINE_PADDED.encode(x))) .collect(); Ok(Arc::new(array)) } Self::Hex => { - let array: GenericStringArray = + let array: GenericStringArray = array.iter().map(|x| x.map(hex::encode)).collect(); Ok(Arc::new(array)) } @@ -448,7 +453,7 @@ impl Encoding { } match self { - Self::Base64 => { + Self::Base64 | Self::Base64Padded => { let upper_bound = base64::decoded_len_estimate(approx_data_size); delegated_decode::<_, _, OutputOffset>(base64_decode, value, upper_bound) } @@ -461,73 +466,6 @@ impl Encoding { } } } - - // TODO: refactor this away once https://github.com/apache/arrow-rs/pull/8993 lands - fn decode_fsb_array( - self, - value: &FixedSizeBinaryArray, - approx_data_size: usize, - ) -> Result { - fn hex_decode(input: &[u8], buf: &mut [u8]) -> Result { - // only write input / 2 bytes to buf - let out_len = input.len() / 2; - let buf = &mut buf[..out_len]; - hex::decode_to_slice(input, buf) - .map_err(|e| exec_datafusion_err!("Failed to decode from hex: {e}"))?; - Ok(out_len) - } - - fn base64_decode(input: &[u8], buf: &mut [u8]) -> Result { - BASE64_ENGINE - .decode_slice(input, buf) - .map_err(|e| exec_datafusion_err!("Failed to decode from base64: {e}")) - } - - fn delegated_decode( - decode: DecodeFunction, - input: &FixedSizeBinaryArray, - conservative_upper_bound_size: usize, - ) -> Result - where - DecodeFunction: Fn(&[u8], &mut [u8]) -> Result, - { - let mut values = vec![0; conservative_upper_bound_size]; - let mut offsets = OffsetBufferBuilder::new(input.len()); - let mut 
total_bytes_decoded = 0; - for v in input.iter() { - if let Some(v) = v { - let cursor = &mut values[total_bytes_decoded..]; - let decoded = decode(v, cursor)?; - total_bytes_decoded += decoded; - offsets.push_length(decoded); - } else { - offsets.push_length(0); - } - } - // We reserved an upper bound size for the values buffer, but we only use the actual size - values.truncate(total_bytes_decoded); - let binary_array = GenericBinaryArray::::try_new( - offsets.finish(), - Buffer::from_vec(values), - input.nulls().cloned(), - )?; - Ok(Arc::new(binary_array)) - } - - match self { - Self::Base64 => { - let upper_bound = base64::decoded_len_estimate(approx_data_size); - delegated_decode(base64_decode, value, upper_bound) - } - Self::Hex => { - // Calculate the upper bound for decoded byte size - // For hex encoding, each pair of hex characters (2 bytes) represents 1 byte when decoded - // So the upper bound is half the length of the input values. - let upper_bound = approx_data_size / 2; - delegated_decode(hex_decode, value, upper_bound) - } - } - } } fn delegated_decode<'a, DecodeFunction, InputBinaryArray, OutputOffset>( diff --git a/datafusion/functions/src/lib.rs b/datafusion/functions/src/lib.rs index f88304a6a5f8..b9ce113efa62 100644 --- a/datafusion/functions/src/lib.rs +++ b/datafusion/functions/src/lib.rs @@ -24,8 +24,6 @@ // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] -// https://github.com/apache/datafusion/issues/18881 -#![deny(clippy::allow_attributes)] //! Function packages for [DataFusion]. //! diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index 4adc331fef66..380877b59364 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -332,7 +332,8 @@ macro_rules! 
make_math_binary_udf { use arrow::array::{ArrayRef, AsArray}; use arrow::datatypes::{DataType, Float32Type, Float64Type}; - use datafusion_common::{Result, exec_err}; + use datafusion_common::utils::take_function_args; + use datafusion_common::{Result, ScalarValue, internal_err}; use datafusion_expr::TypeSignature; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ @@ -393,37 +394,76 @@ macro_rules! make_math_binary_udf { &self, args: ScalarFunctionArgs, ) -> Result { - let args = ColumnarValue::values_to_arrays(&args.args)?; - let arr: ArrayRef = match args[0].data_type() { - DataType::Float64 => { - let y = args[0].as_primitive::(); - let x = args[1].as_primitive::(); - let result = arrow::compute::binary::<_, _, _, Float64Type>( - y, - x, - |y, x| f64::$BINARY_FUNC(y, x), - )?; - Arc::new(result) as _ - } - DataType::Float32 => { - let y = args[0].as_primitive::(); - let x = args[1].as_primitive::(); - let result = arrow::compute::binary::<_, _, _, Float32Type>( - y, - x, - |y, x| f32::$BINARY_FUNC(y, x), - )?; - Arc::new(result) as _ - } - other => { - return exec_err!( - "Unsupported data type {other:?} for function {}", - self.name() - ); + let ScalarFunctionArgs { + args, return_field, .. 
+ } = args; + let return_type = return_field.data_type(); + let [y, x] = take_function_args(self.name(), args)?; + + match (y, x) { + ( + ColumnarValue::Scalar(y_scalar), + ColumnarValue::Scalar(x_scalar), + ) => match (&y_scalar, &x_scalar) { + (y, x) if y.is_null() || x.is_null() => { + ColumnarValue::Scalar(ScalarValue::Null) + .cast_to(return_type, None) + } + ( + ScalarValue::Float64(Some(yv)), + ScalarValue::Float64(Some(xv)), + ) => Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some( + f64::$BINARY_FUNC(*yv, *xv), + )))), + ( + ScalarValue::Float32(Some(yv)), + ScalarValue::Float32(Some(xv)), + ) => Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some( + f32::$BINARY_FUNC(*yv, *xv), + )))), + _ => internal_err!( + "Unexpected scalar types for function {}: {:?}, {:?}", + self.name(), + y_scalar.data_type(), + x_scalar.data_type() + ), + }, + (y, x) => { + let args = ColumnarValue::values_to_arrays(&[y, x])?; + let arr: ArrayRef = match args[0].data_type() { + DataType::Float64 => { + let y = args[0].as_primitive::(); + let x = args[1].as_primitive::(); + let result = + arrow::compute::binary::<_, _, _, Float64Type>( + y, + x, + |y, x| f64::$BINARY_FUNC(y, x), + )?; + Arc::new(result) as _ + } + DataType::Float32 => { + let y = args[0].as_primitive::(); + let x = args[1].as_primitive::(); + let result = + arrow::compute::binary::<_, _, _, Float32Type>( + y, + x, + |y, x| f32::$BINARY_FUNC(y, x), + )?; + Arc::new(result) as _ + } + other => { + return internal_err!( + "Unsupported data type {other:?} for function {}", + self.name() + ); + } + }; + + Ok(ColumnarValue::Array(arr)) } - }; - - Ok(ColumnarValue::Array(arr)) + } } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions/src/math/abs.rs b/datafusion/functions/src/math/abs.rs index 081668f7669f..1b5aaf7745a8 100644 --- a/datafusion/functions/src/math/abs.rs +++ b/datafusion/functions/src/math/abs.rs @@ -50,6 +50,7 @@ macro_rules! 
make_abs_function { }}; } +#[macro_export] macro_rules! make_try_abs_function { ($ARRAY_TYPE:ident) => {{ |input: &ArrayRef| { @@ -62,7 +63,8 @@ macro_rules! make_try_abs_function { x )) }) - })?; + }) + .and_then(|v| Ok(v.with_data_type(input.data_type().clone())))?; // maintain decimal's precision and scale Ok(Arc::new(res) as ArrayRef) } }}; diff --git a/datafusion/functions/src/math/ceil.rs b/datafusion/functions/src/math/ceil.rs index 501741002f96..5961b3cb27fe 100644 --- a/datafusion/functions/src/math/ceil.rs +++ b/datafusion/functions/src/math/ceil.rs @@ -95,8 +95,35 @@ impl ScalarUDFImpl for CeilFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let args = ColumnarValue::values_to_arrays(&args.args)?; - let value = &args[0]; + let arg = &args.args[0]; + + // Scalar fast path for float types - avoid array conversion overhead entirely + if let ColumnarValue::Scalar(scalar) = arg { + match scalar { + ScalarValue::Float64(v) => { + return Ok(ColumnarValue::Scalar(ScalarValue::Float64( + v.map(f64::ceil), + ))); + } + ScalarValue::Float32(v) => { + return Ok(ColumnarValue::Scalar(ScalarValue::Float32( + v.map(f32::ceil), + ))); + } + ScalarValue::Null => { + return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None))); + } + // For decimals: convert to array of size 1, process, then extract scalar + // This ensures we don't expand the array while reusing overflow validation + _ => {} + } + } + + // Track if input was a scalar to convert back at the end + let is_scalar = matches!(arg, ColumnarValue::Scalar(_)); + + // Array path (also handles decimal scalars converted to size-1 arrays) + let value = arg.to_array(args.number_rows)?; let result: ArrayRef = match value.data_type() { DataType::Float64 => Arc::new( @@ -114,7 +141,7 @@ impl ScalarUDFImpl for CeilFunc { } DataType::Decimal32(precision, scale) => { apply_decimal_op::( - value, + &value, *precision, *scale, self.name(), @@ -123,7 +150,7 @@ impl ScalarUDFImpl for CeilFunc { } 
DataType::Decimal64(precision, scale) => { apply_decimal_op::( - value, + &value, *precision, *scale, self.name(), @@ -132,7 +159,7 @@ impl ScalarUDFImpl for CeilFunc { } DataType::Decimal128(precision, scale) => { apply_decimal_op::( - value, + &value, *precision, *scale, self.name(), @@ -141,7 +168,7 @@ impl ScalarUDFImpl for CeilFunc { } DataType::Decimal256(precision, scale) => { apply_decimal_op::( - value, + &value, *precision, *scale, self.name(), @@ -156,7 +183,12 @@ impl ScalarUDFImpl for CeilFunc { } }; - Ok(ColumnarValue::Array(result)) + // If input was a scalar, convert result back to scalar + if is_scalar { + ScalarValue::try_from_array(&result, 0).map(ColumnarValue::Scalar) + } else { + Ok(ColumnarValue::Array(result)) + } } fn output_ordering(&self, input: &[ExprProperties]) -> Result { diff --git a/datafusion/functions/src/math/cot.rs b/datafusion/functions/src/math/cot.rs index a0d7b02b68e5..1f67ef713833 100644 --- a/datafusion/functions/src/math/cot.rs +++ b/datafusion/functions/src/math/cot.rs @@ -18,12 +18,12 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::{ArrayRef, AsArray}; +use arrow::array::AsArray; use arrow::datatypes::DataType::{Float32, Float64}; use arrow::datatypes::{DataType, Float32Type, Float64Type}; -use crate::utils::make_scalar_function; -use datafusion_common::{Result, exec_err}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, ScalarValue, internal_err}; use datafusion_expr::{ColumnarValue, Documentation, ScalarFunctionArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; @@ -96,24 +96,47 @@ impl ScalarUDFImpl for CotFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(cot, vec![])(&args.args) - } -} + let return_field = args.return_field; + let [arg] = take_function_args(self.name(), args.args)?; + + match arg { + ColumnarValue::Scalar(scalar) => { + if scalar.is_null() { + return 
ColumnarValue::Scalar(ScalarValue::Null) + .cast_to(return_field.data_type(), None); + } -///cot SQL function -fn cot(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - Float64 => Ok(Arc::new( - args[0] - .as_primitive::() - .unary::<_, Float64Type>(|x: f64| compute_cot64(x)), - ) as ArrayRef), - Float32 => Ok(Arc::new( - args[0] - .as_primitive::() - .unary::<_, Float32Type>(|x: f32| compute_cot32(x)), - ) as ArrayRef), - other => exec_err!("Unsupported data type {other:?} for function cot"), + match scalar { + ScalarValue::Float64(Some(v)) => Ok(ColumnarValue::Scalar( + ScalarValue::Float64(Some(compute_cot64(v))), + )), + ScalarValue::Float32(Some(v)) => Ok(ColumnarValue::Scalar( + ScalarValue::Float32(Some(compute_cot32(v))), + )), + _ => { + internal_err!( + "Unexpected scalar type for cot: {:?}", + scalar.data_type() + ) + } + } + } + ColumnarValue::Array(array) => match array.data_type() { + Float64 => Ok(ColumnarValue::Array(Arc::new( + array + .as_primitive::() + .unary::<_, Float64Type>(compute_cot64), + ))), + Float32 => Ok(ColumnarValue::Array(Arc::new( + array + .as_primitive::() + .unary::<_, Float32Type>(compute_cot32), + ))), + other => { + internal_err!("Unexpected data type {other:?} for function cot") + } + }, + } } } @@ -129,54 +152,212 @@ fn compute_cot64(x: f64) -> f64 { #[cfg(test)] mod test { - use crate::math::cot::cot; + use std::sync::Arc; + use arrow::array::{ArrayRef, Float32Array, Float64Array}; + use arrow::datatypes::{DataType, Field}; + use datafusion_common::ScalarValue; use datafusion_common::cast::{as_float32_array, as_float64_array}; - use std::sync::Arc; + use datafusion_common::config::ConfigOptions; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; + + use crate::math::cot::CotFunc; #[test] fn test_cot_f32() { - let args: Vec = - vec![Arc::new(Float32Array::from(vec![12.1, 30.0, 90.0, -30.0]))]; - let result = cot(&args).expect("failed to initialize function cot"); - let floats = - 
as_float32_array(&result).expect("failed to initialize function cot"); - - let expected = Float32Array::from(vec![ - -1.986_460_4, - -0.156_119_96, - -0.501_202_8, - 0.156_119_96, - ]); - - let eps = 1e-6; - assert_eq!(floats.len(), 4); - assert!((floats.value(0) - expected.value(0)).abs() < eps); - assert!((floats.value(1) - expected.value(1)).abs() < eps); - assert!((floats.value(2) - expected.value(2)).abs() < eps); - assert!((floats.value(3) - expected.value(3)).abs() < eps); + let array = Arc::new(Float32Array::from(vec![12.1, 30.0, 90.0, -30.0])); + let arg_fields = vec![Field::new("a", DataType::Float32, false).into()]; + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(Arc::clone(&array) as ArrayRef)], + arg_fields, + number_rows: array.len(), + return_field: Field::new("f", DataType::Float32, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = CotFunc::new() + .invoke_with_args(args) + .expect("failed to initialize function cot"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float32_array(&arr) + .expect("failed to convert result to a Float32Array"); + + let expected = Float32Array::from(vec![ + -1.986_460_4, + -0.156_119_96, + -0.501_202_8, + 0.156_119_96, + ]); + + let eps = 1e-6; + assert_eq!(floats.len(), 4); + assert!((floats.value(0) - expected.value(0)).abs() < eps); + assert!((floats.value(1) - expected.value(1)).abs() < eps); + assert!((floats.value(2) - expected.value(2)).abs() < eps); + assert!((floats.value(3) - expected.value(3)).abs() < eps); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } } #[test] fn test_cot_f64() { - let args: Vec = - vec![Arc::new(Float64Array::from(vec![12.1, 30.0, 90.0, -30.0]))]; - let result = cot(&args).expect("failed to initialize function cot"); - let floats = - as_float64_array(&result).expect("failed to initialize function cot"); - - let expected = Float64Array::from(vec![ - -1.986_458_685_881_4, - 
-0.156_119_952_161_6, - -0.501_202_783_380_1, - 0.156_119_952_161_6, - ]); - - let eps = 1e-12; - assert_eq!(floats.len(), 4); - assert!((floats.value(0) - expected.value(0)).abs() < eps); - assert!((floats.value(1) - expected.value(1)).abs() < eps); - assert!((floats.value(2) - expected.value(2)).abs() < eps); - assert!((floats.value(3) - expected.value(3)).abs() < eps); + let array = Arc::new(Float64Array::from(vec![12.1, 30.0, 90.0, -30.0])); + let arg_fields = vec![Field::new("a", DataType::Float64, false).into()]; + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(Arc::clone(&array) as ArrayRef)], + arg_fields, + number_rows: array.len(), + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = CotFunc::new() + .invoke_with_args(args) + .expect("failed to initialize function cot"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float64_array(&arr) + .expect("failed to convert result to a Float64Array"); + + let expected = Float64Array::from(vec![ + -1.986_458_685_881_4, + -0.156_119_952_161_6, + -0.501_202_783_380_1, + 0.156_119_952_161_6, + ]); + + let eps = 1e-12; + assert_eq!(floats.len(), 4); + assert!((floats.value(0) - expected.value(0)).abs() < eps); + assert!((floats.value(1) - expected.value(1)).abs() < eps); + assert!((floats.value(2) - expected.value(2)).abs() < eps); + assert!((floats.value(3) - expected.value(3)).abs() < eps); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + + #[test] + fn test_cot_scalar_f64() { + let arg_fields = vec![Field::new("a", DataType::Float64, false).into()]; + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(ScalarValue::Float64(Some(1.0)))], + arg_fields, + number_rows: 1, + return_field: Field::new("f", DataType::Float64, false).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = CotFunc::new() + 
.invoke_with_args(args) + .expect("cot scalar should succeed"); + + match result { + ColumnarValue::Scalar(ScalarValue::Float64(Some(v))) => { + // cot(1.0) = 1/tan(1.0) ≈ 0.6420926159343306 + let expected = 1.0_f64 / 1.0_f64.tan(); + assert!((v - expected).abs() < 1e-12); + } + _ => panic!("Expected Float64 scalar"), + } + } + + #[test] + fn test_cot_scalar_f32() { + let arg_fields = vec![Field::new("a", DataType::Float32, false).into()]; + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(ScalarValue::Float32(Some(1.0)))], + arg_fields, + number_rows: 1, + return_field: Field::new("f", DataType::Float32, false).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = CotFunc::new() + .invoke_with_args(args) + .expect("cot scalar should succeed"); + + match result { + ColumnarValue::Scalar(ScalarValue::Float32(Some(v))) => { + let expected = 1.0_f32 / 1.0_f32.tan(); + assert!((v - expected).abs() < 1e-6); + } + _ => panic!("Expected Float32 scalar"), + } + } + + #[test] + fn test_cot_scalar_null() { + let arg_fields = vec![Field::new("a", DataType::Float64, true).into()]; + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(ScalarValue::Float64(None))], + arg_fields, + number_rows: 1, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = CotFunc::new() + .invoke_with_args(args) + .expect("cot null should succeed"); + + match result { + ColumnarValue::Scalar(scalar) => { + assert!(scalar.is_null()); + } + _ => panic!("Expected scalar result"), + } + } + + #[test] + fn test_cot_scalar_zero() { + let arg_fields = vec![Field::new("a", DataType::Float64, false).into()]; + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(ScalarValue::Float64(Some(0.0)))], + arg_fields, + number_rows: 1, + return_field: Field::new("f", DataType::Float64, false).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + 
let result = CotFunc::new() + .invoke_with_args(args) + .expect("cot zero should succeed"); + + match result { + ColumnarValue::Scalar(ScalarValue::Float64(Some(v))) => { + // cot(0) = 1/tan(0) = infinity + assert!(v.is_infinite()); + } + _ => panic!("Expected Float64 scalar"), + } + } + + #[test] + fn test_cot_scalar_pi() { + let arg_fields = vec![Field::new("a", DataType::Float64, false).into()]; + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(ScalarValue::Float64(Some( + std::f64::consts::PI, + )))], + arg_fields, + number_rows: 1, + return_field: Field::new("f", DataType::Float64, false).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = CotFunc::new() + .invoke_with_args(args) + .expect("cot pi should succeed"); + + match result { + ColumnarValue::Scalar(ScalarValue::Float64(Some(v))) => { + // cot(PI) = 1/tan(PI) - very large negative number due to floating point + let expected = 1.0_f64 / std::f64::consts::PI.tan(); + assert!((v - expected).abs() < 1e-6); + } + _ => panic!("Expected Float64 scalar"), + } } } diff --git a/datafusion/functions/src/math/factorial.rs b/datafusion/functions/src/math/factorial.rs index ffe12466dc17..c1dd802140c0 100644 --- a/datafusion/functions/src/math/factorial.rs +++ b/datafusion/functions/src/math/factorial.rs @@ -22,8 +22,9 @@ use std::sync::Arc; use arrow::datatypes::DataType::Int64; use arrow::datatypes::{DataType, Int64Type}; -use crate::utils::make_scalar_function; -use datafusion_common::{Result, exec_err}; +use datafusion_common::{ + Result, ScalarValue, exec_err, internal_err, utils::take_function_args, +}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, @@ -81,7 +82,39 @@ impl ScalarUDFImpl for FactorialFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(factorial, vec![])(&args.args) + let [arg] = take_function_args(self.name(), args.args)?; + + match arg { + 
ColumnarValue::Scalar(scalar) => { + if scalar.is_null() { + return Ok(ColumnarValue::Scalar(ScalarValue::Int64(None))); + } + + match scalar { + ScalarValue::Int64(Some(v)) => { + let result = compute_factorial(v)?; + Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(result)))) + } + _ => { + internal_err!( + "Unexpected data type {:?} for function factorial", + scalar.data_type() + ) + } + } + } + ColumnarValue::Array(array) => match array.data_type() { + Int64 => { + let result: Int64Array = array + .as_primitive::() + .try_unary(compute_factorial)?; + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } + other => { + internal_err!("Unexpected data type {other:?} for function factorial") + } + }, + } } fn documentation(&self) -> Option<&Documentation> { @@ -113,53 +146,12 @@ const FACTORIALS: [i64; 21] = [ 2432902008176640000, ]; // if return type changes, this constant needs to be updated accordingly -/// Factorial SQL function -fn factorial(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - Int64 => { - let result: Int64Array = - args[0].as_primitive::().try_unary(|a| { - if a < 0 { - Ok(1) - } else if a < FACTORIALS.len() as i64 { - Ok(FACTORIALS[a as usize]) - } else { - exec_err!("Overflow happened on FACTORIAL({a})") - } - })?; - Ok(Arc::new(result) as ArrayRef) - } - other => exec_err!("Unsupported data type {other:?} for function factorial."), - } -} - -#[cfg(test)] -mod test { - use super::*; - use datafusion_common::cast::as_int64_array; - - #[test] - fn test_factorial_i64() { - let args: Vec = vec![ - Arc::new(Int64Array::from(vec![0, 1, 2, 4, 20, -1])), // input - ]; - - let result = factorial(&args).expect("failed to initialize function factorial"); - let ints = - as_int64_array(&result).expect("failed to initialize function factorial"); - - let expected = Int64Array::from(vec![1, 1, 2, 24, 2432902008176640000, 1]); - - assert_eq!(ints, &expected); - } - - #[test] - fn test_overflow() { - let args: Vec = vec![ - 
Arc::new(Int64Array::from(vec![21])), // input - ]; - - let result = factorial(&args); - assert!(result.is_err()); +fn compute_factorial(n: i64) -> Result { + if n < 0 { + Ok(1) + } else if n < FACTORIALS.len() as i64 { + Ok(FACTORIALS[n as usize]) + } else { + exec_err!("Overflow happened on FACTORIAL({n})") } } diff --git a/datafusion/functions/src/math/floor.rs b/datafusion/functions/src/math/floor.rs index 221e58e1e7a7..d4f25716ff7e 100644 --- a/datafusion/functions/src/math/floor.rs +++ b/datafusion/functions/src/math/floor.rs @@ -19,18 +19,22 @@ use std::any::Any; use std::sync::Arc; use arrow::array::{ArrayRef, AsArray}; +use arrow::compute::{DecimalCast, rescale_decimal}; use arrow::datatypes::{ - DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Float32Type, - Float64Type, + ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type, + Decimal256Type, DecimalType, Float32Type, Float64Type, }; use datafusion_common::{Result, ScalarValue, exec_err}; use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::preimage::PreimageResult; +use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ - Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, - TypeSignature, TypeSignatureClass, Volatility, + Coercion, ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDFImpl, + Signature, TypeSignature, TypeSignatureClass, Volatility, }; use datafusion_macros::user_doc; +use num_traits::{CheckedAdd, Float, One}; use super::decimal::{apply_decimal_op, floor_decimal_value}; @@ -74,6 +78,42 @@ impl FloorFunc { } } +// ============ Macro for preimage bounds ============ +/// Generates the code to call the appropriate bounds function and wrap results. +macro_rules! 
preimage_bounds { + // Float types: call float_preimage_bounds and wrap in ScalarValue + (float: $variant:ident, $value:expr) => { + float_preimage_bounds($value).map(|(lo, hi)| { + ( + ScalarValue::$variant(Some(lo)), + ScalarValue::$variant(Some(hi)), + ) + }) + }; + + // Integer types: call int_preimage_bounds and wrap in ScalarValue + (int: $variant:ident, $value:expr) => { + int_preimage_bounds($value).map(|(lo, hi)| { + ( + ScalarValue::$variant(Some(lo)), + ScalarValue::$variant(Some(hi)), + ) + }) + }; + + // Decimal types: call decimal_preimage_bounds with precision/scale and wrap in ScalarValue + (decimal: $variant:ident, $decimal_type:ty, $value:expr, $precision:expr, $scale:expr) => { + decimal_preimage_bounds::<$decimal_type>($value, $precision, $scale).map( + |(lo, hi)| { + ( + ScalarValue::$variant(Some(lo), $precision, $scale), + ScalarValue::$variant(Some(hi), $precision, $scale), + ) + }, + ) + }; +} + impl ScalarUDFImpl for FloorFunc { fn as_any(&self) -> &dyn Any { self @@ -95,8 +135,35 @@ impl ScalarUDFImpl for FloorFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let args = ColumnarValue::values_to_arrays(&args.args)?; - let value = &args[0]; + let arg = &args.args[0]; + + // Scalar fast path for float types - avoid array conversion overhead entirely + if let ColumnarValue::Scalar(scalar) = arg { + match scalar { + ScalarValue::Float64(v) => { + return Ok(ColumnarValue::Scalar(ScalarValue::Float64( + v.map(f64::floor), + ))); + } + ScalarValue::Float32(v) => { + return Ok(ColumnarValue::Scalar(ScalarValue::Float32( + v.map(f32::floor), + ))); + } + ScalarValue::Null => { + return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None))); + } + // For decimals: convert to array of size 1, process, then extract scalar + // This ensures we don't expand the array while reusing overflow validation + _ => {} + } + } + + // Track if input was a scalar to convert back at the end + let is_scalar = matches!(arg, 
ColumnarValue::Scalar(_)); + + // Array path (also handles decimal scalars converted to size-1 arrays) + let value = arg.to_array(args.number_rows)?; let result: ArrayRef = match value.data_type() { DataType::Float64 => Arc::new( @@ -114,7 +181,7 @@ impl ScalarUDFImpl for FloorFunc { } DataType::Decimal32(precision, scale) => { apply_decimal_op::( - value, + &value, *precision, *scale, self.name(), @@ -123,7 +190,7 @@ impl ScalarUDFImpl for FloorFunc { } DataType::Decimal64(precision, scale) => { apply_decimal_op::( - value, + &value, *precision, *scale, self.name(), @@ -132,7 +199,7 @@ impl ScalarUDFImpl for FloorFunc { } DataType::Decimal128(precision, scale) => { apply_decimal_op::( - value, + &value, *precision, *scale, self.name(), @@ -141,7 +208,7 @@ impl ScalarUDFImpl for FloorFunc { } DataType::Decimal256(precision, scale) => { apply_decimal_op::( - value, + &value, *precision, *scale, self.name(), @@ -156,7 +223,12 @@ impl ScalarUDFImpl for FloorFunc { } }; - Ok(ColumnarValue::Array(result)) + // If input was a scalar, convert result back to scalar + if is_scalar { + ScalarValue::try_from_array(&result, 0).map(ColumnarValue::Scalar) + } else { + Ok(ColumnarValue::Array(result)) + } } fn output_ordering(&self, input: &[ExprProperties]) -> Result { @@ -168,7 +240,450 @@ impl ScalarUDFImpl for FloorFunc { Interval::make_unbounded(&data_type) } + /// Compute the preimage for floor function. + /// + /// For `floor(x) = N`, the preimage is `x >= N AND x < N + 1` + /// because floor(x) = N for all x in [N, N+1). + /// + /// This enables predicate pushdown optimizations, transforming: + /// `floor(col) = 100` into `col >= 100 AND col < 101` + fn preimage( + &self, + args: &[Expr], + lit_expr: &Expr, + _info: &SimplifyContext, + ) -> Result { + // floor takes exactly one argument and we do not expect to reach here with multiple arguments. 
+ debug_assert!(args.len() == 1, "floor() takes exactly one argument"); + + let arg = args[0].clone(); + + // Extract the literal value being compared to + let Expr::Literal(lit_value, _) = lit_expr else { + return Ok(PreimageResult::None); + }; + + // Compute lower bound (N) and upper bound (N + 1) using helper functions + let Some((lower, upper)) = (match lit_value { + // Floating-point types + ScalarValue::Float64(Some(n)) => preimage_bounds!(float: Float64, *n), + ScalarValue::Float32(Some(n)) => preimage_bounds!(float: Float32, *n), + + // Integer types (not reachable from SQL/SLT: floor() only accepts Float64/Float32/Decimal, + // so the RHS literal is always coerced to one of those before preimage runs; kept for + // programmatic use and unit tests) + ScalarValue::Int8(Some(n)) => preimage_bounds!(int: Int8, *n), + ScalarValue::Int16(Some(n)) => preimage_bounds!(int: Int16, *n), + ScalarValue::Int32(Some(n)) => preimage_bounds!(int: Int32, *n), + ScalarValue::Int64(Some(n)) => preimage_bounds!(int: Int64, *n), + + // Decimal types + // DECIMAL(precision, scale) where precision ≤ 38 -> Decimal128(precision, scale) + // DECIMAL(precision, scale) where precision > 38 -> Decimal256(precision, scale) + // Decimal32 and Decimal64 are unreachable from SQL/SLT. 
+ ScalarValue::Decimal32(Some(n), precision, scale) => { + preimage_bounds!(decimal: Decimal32, Decimal32Type, *n, *precision, *scale) + } + ScalarValue::Decimal64(Some(n), precision, scale) => { + preimage_bounds!(decimal: Decimal64, Decimal64Type, *n, *precision, *scale) + } + ScalarValue::Decimal128(Some(n), precision, scale) => { + preimage_bounds!(decimal: Decimal128, Decimal128Type, *n, *precision, *scale) + } + ScalarValue::Decimal256(Some(n), precision, scale) => { + preimage_bounds!(decimal: Decimal256, Decimal256Type, *n, *precision, *scale) + } + + // Unsupported types + _ => None, + }) else { + return Ok(PreimageResult::None); + }; + + Ok(PreimageResult::Range { + expr: arg, + interval: Box::new(Interval::try_new(lower, upper)?), + }) + } + fn documentation(&self) -> Option<&Documentation> { self.doc() } } + +// ============ Helper functions for preimage bounds ============ + +/// Compute preimage bounds for floor function on floating-point types. +/// For floor(x) = n, the preimage is [n, n+1). +/// Returns None if: +/// - The value is non-finite (infinity, NaN) +/// - The value is not an integer (floor always returns integers, so floor(x) = 1.3 has no solution) +/// - Adding 1 would lose precision at extreme values +fn float_preimage_bounds(n: F) -> Option<(F, F)> { + let one = F::one(); + // Check for non-finite values (infinity, NaN) + if !n.is_finite() { + return None; + } + // floor always returns an integer, so if n has a fractional part, there's no solution + if n.fract() != F::zero() { + return None; + } + // Check for precision loss at extreme values + if n + one <= n { + return None; + } + Some((n, n + one)) +} + +/// Compute preimage bounds for floor function on integer types. +/// For floor(x) = n, the preimage is [n, n+1). +/// Returns None if adding 1 would overflow. 
+fn int_preimage_bounds(n: I) -> Option<(I, I)> { + let upper = n.checked_add(&I::one())?; + Some((n, upper)) +} + +/// Compute preimage bounds for floor function on decimal types. +/// For floor(x) = n, the preimage is [n, n+1). +/// Returns None if: +/// - The value has a fractional part (floor always returns integers) +/// - Adding 1 would overflow +fn decimal_preimage_bounds( + value: D::Native, + precision: u8, + scale: i8, +) -> Option<(D::Native, D::Native)> +where + D::Native: DecimalCast + ArrowNativeTypeOp + std::ops::Rem, +{ + // Use rescale_decimal to compute "1" at target scale (avoids manual pow) + // Convert integer 1 (scale=0) to the target scale + let one_scaled: D::Native = rescale_decimal::( + D::Native::ONE, // value = 1 + 1, // input_precision = 1 + 0, // input_scale = 0 (integer) + precision, // output_precision + scale, // output_scale + )?; + + // floor always returns an integer, so if value has a fractional part, there's no solution + // Check: value % one_scaled != 0 means fractional part exists + if scale > 0 && value % one_scaled != D::Native::ZERO { + return None; + } + + // Compute upper bound using checked addition + // Before preimage stage, the internal i128/i256(value) is validated based on the precision and scale. + // MAX_DECIMAL128_FOR_EACH_PRECISION and MAX_DECIMAL256_FOR_EACH_PRECISION are used to validate the internal i128/i256. + // Any invalid i128/i256 will not reach here. + // Therefore, the add_checked will always succeed if tested via SQL/SLT path. 
+ let upper = value.add_checked(one_scaled).ok()?; + + Some((value, upper)) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_buffer::i256; + use datafusion_expr::col; + + /// Helper to test valid preimage cases that should return a Range + fn assert_preimage_range( + input: ScalarValue, + expected_lower: ScalarValue, + expected_upper: ScalarValue, + ) { + let floor_func = FloorFunc::new(); + let args = vec![col("x")]; + let lit_expr = Expr::Literal(input.clone(), None); + let info = SimplifyContext::default(); + + let result = floor_func.preimage(&args, &lit_expr, &info).unwrap(); + + match result { + PreimageResult::Range { expr, interval } => { + assert_eq!(expr, col("x")); + assert_eq!(interval.lower().clone(), expected_lower); + assert_eq!(interval.upper().clone(), expected_upper); + } + PreimageResult::None => { + panic!("Expected Range, got None for input {input:?}") + } + } + } + + /// Helper to test cases that should return None + fn assert_preimage_none(input: ScalarValue) { + let floor_func = FloorFunc::new(); + let args = vec![col("x")]; + let lit_expr = Expr::Literal(input.clone(), None); + let info = SimplifyContext::default(); + + let result = floor_func.preimage(&args, &lit_expr, &info).unwrap(); + assert!( + matches!(result, PreimageResult::None), + "Expected None for input {input:?}" + ); + } + + #[test] + fn test_floor_preimage_valid_cases() { + // Float64 + assert_preimage_range( + ScalarValue::Float64(Some(100.0)), + ScalarValue::Float64(Some(100.0)), + ScalarValue::Float64(Some(101.0)), + ); + // Float32 + assert_preimage_range( + ScalarValue::Float32(Some(50.0)), + ScalarValue::Float32(Some(50.0)), + ScalarValue::Float32(Some(51.0)), + ); + // Int64 + assert_preimage_range( + ScalarValue::Int64(Some(42)), + ScalarValue::Int64(Some(42)), + ScalarValue::Int64(Some(43)), + ); + // Int32 + assert_preimage_range( + ScalarValue::Int32(Some(100)), + ScalarValue::Int32(Some(100)), + ScalarValue::Int32(Some(101)), + ); + // Negative values + 
assert_preimage_range( + ScalarValue::Float64(Some(-5.0)), + ScalarValue::Float64(Some(-5.0)), + ScalarValue::Float64(Some(-4.0)), + ); + // Zero + assert_preimage_range( + ScalarValue::Float64(Some(0.0)), + ScalarValue::Float64(Some(0.0)), + ScalarValue::Float64(Some(1.0)), + ); + } + + #[test] + fn test_floor_preimage_non_integer_float() { + // floor(x) = 1.3 has NO SOLUTION because floor always returns an integer + // Therefore preimage should return None for non-integer literals + assert_preimage_none(ScalarValue::Float64(Some(1.3))); + assert_preimage_none(ScalarValue::Float64(Some(-2.5))); + assert_preimage_none(ScalarValue::Float32(Some(3.7))); + } + + #[test] + fn test_floor_preimage_integer_overflow() { + // All integer types at MAX value should return None + assert_preimage_none(ScalarValue::Int64(Some(i64::MAX))); + assert_preimage_none(ScalarValue::Int32(Some(i32::MAX))); + assert_preimage_none(ScalarValue::Int16(Some(i16::MAX))); + assert_preimage_none(ScalarValue::Int8(Some(i8::MAX))); + } + + #[test] + fn test_floor_preimage_float_edge_cases() { + // Float64 edge cases + assert_preimage_none(ScalarValue::Float64(Some(f64::INFINITY))); + assert_preimage_none(ScalarValue::Float64(Some(f64::NEG_INFINITY))); + assert_preimage_none(ScalarValue::Float64(Some(f64::NAN))); + assert_preimage_none(ScalarValue::Float64(Some(f64::MAX))); // precision loss + + // Float32 edge cases + assert_preimage_none(ScalarValue::Float32(Some(f32::INFINITY))); + assert_preimage_none(ScalarValue::Float32(Some(f32::NEG_INFINITY))); + assert_preimage_none(ScalarValue::Float32(Some(f32::NAN))); + assert_preimage_none(ScalarValue::Float32(Some(f32::MAX))); // precision loss + } + + #[test] + fn test_floor_preimage_null_values() { + assert_preimage_none(ScalarValue::Float64(None)); + assert_preimage_none(ScalarValue::Float32(None)); + assert_preimage_none(ScalarValue::Int64(None)); + } + + // ============ Decimal32 Tests (mirrors float/int tests) ============ + + #[test] + fn 
test_floor_preimage_decimal_valid_cases() { + // ===== Decimal32 ===== + // Positive integer decimal: 100.00 (scale=2, so raw=10000) + // floor(x) = 100.00 -> x in [100.00, 101.00) + assert_preimage_range( + ScalarValue::Decimal32(Some(10000), 9, 2), + ScalarValue::Decimal32(Some(10000), 9, 2), // 100.00 + ScalarValue::Decimal32(Some(10100), 9, 2), // 101.00 + ); + + // Smaller positive: 50.00 + assert_preimage_range( + ScalarValue::Decimal32(Some(5000), 9, 2), + ScalarValue::Decimal32(Some(5000), 9, 2), // 50.00 + ScalarValue::Decimal32(Some(5100), 9, 2), // 51.00 + ); + + // Negative integer decimal: -5.00 + assert_preimage_range( + ScalarValue::Decimal32(Some(-500), 9, 2), + ScalarValue::Decimal32(Some(-500), 9, 2), // -5.00 + ScalarValue::Decimal32(Some(-400), 9, 2), // -4.00 + ); + + // Zero: 0.00 + assert_preimage_range( + ScalarValue::Decimal32(Some(0), 9, 2), + ScalarValue::Decimal32(Some(0), 9, 2), // 0.00 + ScalarValue::Decimal32(Some(100), 9, 2), // 1.00 + ); + + // Scale 0 (pure integer): 42 + assert_preimage_range( + ScalarValue::Decimal32(Some(42), 9, 0), + ScalarValue::Decimal32(Some(42), 9, 0), + ScalarValue::Decimal32(Some(43), 9, 0), + ); + + // ===== Decimal64 ===== + assert_preimage_range( + ScalarValue::Decimal64(Some(10000), 18, 2), + ScalarValue::Decimal64(Some(10000), 18, 2), // 100.00 + ScalarValue::Decimal64(Some(10100), 18, 2), // 101.00 + ); + + // Negative + assert_preimage_range( + ScalarValue::Decimal64(Some(-500), 18, 2), + ScalarValue::Decimal64(Some(-500), 18, 2), // -5.00 + ScalarValue::Decimal64(Some(-400), 18, 2), // -4.00 + ); + + // Zero + assert_preimage_range( + ScalarValue::Decimal64(Some(0), 18, 2), + ScalarValue::Decimal64(Some(0), 18, 2), + ScalarValue::Decimal64(Some(100), 18, 2), + ); + + // ===== Decimal128 ===== + assert_preimage_range( + ScalarValue::Decimal128(Some(10000), 38, 2), + ScalarValue::Decimal128(Some(10000), 38, 2), // 100.00 + ScalarValue::Decimal128(Some(10100), 38, 2), // 101.00 + ); + + // Negative + 
assert_preimage_range( + ScalarValue::Decimal128(Some(-500), 38, 2), + ScalarValue::Decimal128(Some(-500), 38, 2), // -5.00 + ScalarValue::Decimal128(Some(-400), 38, 2), // -4.00 + ); + + // Zero + assert_preimage_range( + ScalarValue::Decimal128(Some(0), 38, 2), + ScalarValue::Decimal128(Some(0), 38, 2), + ScalarValue::Decimal128(Some(100), 38, 2), + ); + + // ===== Decimal256 ===== + assert_preimage_range( + ScalarValue::Decimal256(Some(i256::from(10000)), 76, 2), + ScalarValue::Decimal256(Some(i256::from(10000)), 76, 2), // 100.00 + ScalarValue::Decimal256(Some(i256::from(10100)), 76, 2), // 101.00 + ); + + // Negative + assert_preimage_range( + ScalarValue::Decimal256(Some(i256::from(-500)), 76, 2), + ScalarValue::Decimal256(Some(i256::from(-500)), 76, 2), // -5.00 + ScalarValue::Decimal256(Some(i256::from(-400)), 76, 2), // -4.00 + ); + + // Zero + assert_preimage_range( + ScalarValue::Decimal256(Some(i256::ZERO), 76, 2), + ScalarValue::Decimal256(Some(i256::ZERO), 76, 2), + ScalarValue::Decimal256(Some(i256::from(100)), 76, 2), + ); + } + + #[test] + fn test_floor_preimage_decimal_non_integer() { + // floor(x) = 1.30 has NO SOLUTION because floor always returns an integer + // Therefore preimage should return None for non-integer decimals + + // Decimal32 + assert_preimage_none(ScalarValue::Decimal32(Some(130), 9, 2)); // 1.30 + assert_preimage_none(ScalarValue::Decimal32(Some(-250), 9, 2)); // -2.50 + assert_preimage_none(ScalarValue::Decimal32(Some(370), 9, 2)); // 3.70 + assert_preimage_none(ScalarValue::Decimal32(Some(1), 9, 2)); // 0.01 + + // Decimal64 + assert_preimage_none(ScalarValue::Decimal64(Some(130), 18, 2)); // 1.30 + assert_preimage_none(ScalarValue::Decimal64(Some(-250), 18, 2)); // -2.50 + + // Decimal128 + assert_preimage_none(ScalarValue::Decimal128(Some(130), 38, 2)); // 1.30 + assert_preimage_none(ScalarValue::Decimal128(Some(-250), 38, 2)); // -2.50 + + // Decimal256 + assert_preimage_none(ScalarValue::Decimal256(Some(i256::from(130)), 
76, 2)); // 1.30 + assert_preimage_none(ScalarValue::Decimal256(Some(i256::from(-250)), 76, 2)); // -2.50 + + // Decimal32: i32::MAX - 50 + // This return None because the value is not an integer, not because it is out of range. + assert_preimage_none(ScalarValue::Decimal32(Some(i32::MAX - 50), 10, 2)); + + // Decimal64: i64::MAX - 50 + // This return None because the value is not an integer, not because it is out of range. + assert_preimage_none(ScalarValue::Decimal64(Some(i64::MAX - 50), 19, 2)); + } + + #[test] + fn test_floor_preimage_decimal_overflow() { + // Test near MAX where adding scale_factor would overflow + + // Decimal32: i32::MAX + assert_preimage_none(ScalarValue::Decimal32(Some(i32::MAX), 10, 0)); + + // Decimal64: i64::MAX + assert_preimage_none(ScalarValue::Decimal64(Some(i64::MAX), 19, 0)); + } + + #[test] + fn test_floor_preimage_decimal_edge_cases() { + // ===== Decimal32 ===== + // Large value that doesn't overflow + // Decimal(9,2) max value is 9,999,999.99 (stored as 999,999,999) + // Use a large value that fits Decimal(9,2) and is divisible by 100 + let safe_max_aligned_32 = 999_999_900; // 9,999,999.00 + assert_preimage_range( + ScalarValue::Decimal32(Some(safe_max_aligned_32), 9, 2), + ScalarValue::Decimal32(Some(safe_max_aligned_32), 9, 2), + ScalarValue::Decimal32(Some(safe_max_aligned_32 + 100), 9, 2), + ); + + // Negative edge: use a large negative value that fits Decimal(9,2) + // Decimal(9,2) min value is -9,999,999.99 (stored as -999,999,999) + let min_aligned_32 = -999_999_900; // -9,999,999.00 + assert_preimage_range( + ScalarValue::Decimal32(Some(min_aligned_32), 9, 2), + ScalarValue::Decimal32(Some(min_aligned_32), 9, 2), + ScalarValue::Decimal32(Some(min_aligned_32 + 100), 9, 2), + ); + } + + #[test] + fn test_floor_preimage_decimal_null() { + assert_preimage_none(ScalarValue::Decimal32(None, 9, 2)); + assert_preimage_none(ScalarValue::Decimal64(None, 18, 2)); + assert_preimage_none(ScalarValue::Decimal128(None, 38, 2)); + 
assert_preimage_none(ScalarValue::Decimal256(None, 76, 2)); + } +} diff --git a/datafusion/functions/src/math/gcd.rs b/datafusion/functions/src/math/gcd.rs index baf52d780683..1f6a353a85ee 100644 --- a/datafusion/functions/src/math/gcd.rs +++ b/datafusion/functions/src/math/gcd.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{ArrayRef, AsArray, Int64Array, PrimitiveArray, new_null_array}; +use arrow::array::{ArrayRef, AsArray, Int64Array, PrimitiveArray}; use arrow::compute::try_binary; use arrow::datatypes::{DataType, Int64Type}; use arrow::error::ArrowError; @@ -144,10 +144,7 @@ fn compute_gcd_with_scalar(arr: &ArrayRef, scalar: Option) -> Result Ok(ColumnarValue::Array(new_null_array( - &DataType::Int64, - arr.len(), - ))), + None => Ok(ColumnarValue::Scalar(ScalarValue::Int64(None))), } } diff --git a/datafusion/functions/src/math/iszero.rs b/datafusion/functions/src/math/iszero.rs index 6349551ca0a4..aa93d797eb7b 100644 --- a/datafusion/functions/src/math/iszero.rs +++ b/datafusion/functions/src/math/iszero.rs @@ -18,20 +18,26 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::{ArrayRef, AsArray, BooleanArray}; -use arrow::datatypes::DataType::{Boolean, Float32, Float64}; -use arrow::datatypes::{DataType, Float32Type, Float64Type}; +use arrow::array::{ArrowNativeTypeOp, AsArray, BooleanArray}; +use arrow::datatypes::DataType::{ + Boolean, Decimal32, Decimal64, Decimal128, Decimal256, Float16, Float32, Float64, + Int8, Int16, Int32, Int64, Null, UInt8, UInt16, UInt32, UInt64, +}; +use arrow::datatypes::{ + DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Float16Type, + Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type, + UInt16Type, UInt32Type, UInt64Type, +}; -use datafusion_common::{Result, exec_err}; -use datafusion_expr::TypeSignature::Exact; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, 
ScalarValue, internal_err}; +use datafusion_expr::{Coercion, TypeSignatureClass}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; -use crate::utils::make_scalar_function; - #[user_doc( doc_section(label = "Math Functions"), description = "Returns true if a given number is +0.0 or -0.0 otherwise returns false.", @@ -59,12 +65,10 @@ impl Default for IsZeroFunc { impl IsZeroFunc { pub fn new() -> Self { - use DataType::*; + // Accept any numeric type (ints, uints, floats, decimals) without implicit casts. + let numeric = Coercion::new_exact(TypeSignatureClass::Numeric); Self { - signature: Signature::one_of( - vec![Exact(vec![Float32]), Exact(vec![Float64])], - Volatility::Immutable, - ), + signature: Signature::coercible(vec![numeric], Volatility::Immutable), } } } @@ -87,70 +91,155 @@ impl ScalarUDFImpl for IsZeroFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(iszero, vec![])(&args.args) + let [arg] = take_function_args(self.name(), args.args)?; + + match arg { + ColumnarValue::Scalar(scalar) => { + if scalar.is_null() { + return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))); + } + + match scalar { + ScalarValue::Float64(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0.0)))) + } + ScalarValue::Float32(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0.0)))) + } + ScalarValue::Float16(Some(v)) => Ok(ColumnarValue::Scalar( + ScalarValue::Boolean(Some(v.is_zero())), + )), + + ScalarValue::Int8(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::Int16(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::Int32(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::Int64(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 
0)))) + } + ScalarValue::UInt8(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::UInt16(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::UInt32(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::UInt64(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + + ScalarValue::Decimal32(Some(v), ..) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::Decimal64(Some(v), ..) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::Decimal128(Some(v), ..) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::Decimal256(Some(v), ..) => Ok(ColumnarValue::Scalar( + ScalarValue::Boolean(Some(v.is_zero())), + )), + + _ => { + internal_err!( + "Unexpected scalar type for iszero: {:?}", + scalar.data_type() + ) + } + } + } + ColumnarValue::Array(array) => match array.data_type() { + Null => Ok(ColumnarValue::Array(Arc::new(BooleanArray::new_null( + array.len(), + )))), + + Float64 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0.0, + )))), + Float32 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0.0, + )))), + Float16 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x.is_zero(), + )))), + + Int8 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))), + Int16 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))), + Int32 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))), + Int64 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))), + UInt8 => 
Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))), + UInt16 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))), + UInt32 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))), + UInt64 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))), + + Decimal32(_, _) => { + Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))) + } + Decimal64(_, _) => { + Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))) + } + Decimal128(_, _) => { + Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))) + } + Decimal256(_, _) => { + Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x.is_zero(), + )))) + } + + other => { + internal_err!("Unexpected data type {other:?} for function iszero") + } + }, + } } fn documentation(&self) -> Option<&Documentation> { self.doc() } } - -/// Iszero SQL function -fn iszero(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - Float64 => Ok(Arc::new(BooleanArray::from_unary( - args[0].as_primitive::(), - |x| x == 0.0, - )) as ArrayRef), - - Float32 => Ok(Arc::new(BooleanArray::from_unary( - args[0].as_primitive::(), - |x| x == 0.0, - )) as ArrayRef), - - other => exec_err!("Unsupported data type {other:?} for function iszero"), - } -} - -#[cfg(test)] -mod test { - use std::sync::Arc; - - use arrow::array::{ArrayRef, Float32Array, Float64Array}; - - use datafusion_common::cast::as_boolean_array; - - use crate::math::iszero::iszero; - - #[test] - fn test_iszero_f64() { - let args: Vec = - vec![Arc::new(Float64Array::from(vec![1.0, 0.0, 3.0, -0.0]))]; - - let result = iszero(&args).expect("failed to initialize function iszero"); - let 
booleans = - as_boolean_array(&result).expect("failed to initialize function iszero"); - - assert_eq!(booleans.len(), 4); - assert!(!booleans.value(0)); - assert!(booleans.value(1)); - assert!(!booleans.value(2)); - assert!(booleans.value(3)); - } - - #[test] - fn test_iszero_f32() { - let args: Vec = - vec![Arc::new(Float32Array::from(vec![1.0, 0.0, 3.0, -0.0]))]; - - let result = iszero(&args).expect("failed to initialize function iszero"); - let booleans = - as_boolean_array(&result).expect("failed to initialize function iszero"); - - assert_eq!(booleans.len(), 4); - assert!(!booleans.value(0)); - assert!(booleans.value(1)); - assert!(!booleans.value(2)); - assert!(booleans.value(3)); - } -} diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index 0c50afa2dffd..d1906a4bf0e0 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -21,9 +21,7 @@ use std::any::Any; use super::power::PowerFunc; -use crate::utils::{ - calculate_binary_math, decimal32_to_i32, decimal64_to_i64, decimal128_to_i128, -}; +use crate::utils::calculate_binary_math; use arrow::array::{Array, ArrayRef}; use arrow::datatypes::{ DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Float16Type, @@ -36,7 +34,7 @@ use datafusion_common::{ Result, ScalarValue, exec_err, internal_err, plan_datafusion_err, plan_err, }; use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ Coercion, ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDF, @@ -44,7 +42,7 @@ use datafusion_expr::{ }; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; -use num_traits::Float; +use num_traits::{Float, ToPrimitive}; #[user_doc( doc_section(label = "Math 
Functions"), @@ -104,109 +102,109 @@ impl LogFunc { } } -/// Binary function to calculate logarithm of Decimal32 `value` using `base` base -/// Returns error if base is invalid -fn log_decimal32(value: i32, scale: i8, base: f64) -> Result { - if !base.is_finite() || base.trunc() != base { - return Err(ArrowError::ComputeError(format!( - "Log cannot use non-integer base: {base}" - ))); - } - if (base as u32) < 2 { - return Err(ArrowError::ComputeError(format!( - "Log base must be greater than 1: {base}" - ))); - } - - // Match f64::log behaviour - if value <= 0 { - return Ok(f64::NAN); - } +/// Checks if the base is valid for the efficient integer logarithm algorithm. +#[inline] +fn is_valid_integer_base(base: f64) -> bool { + base.trunc() == base && base >= 2.0 && base <= u32::MAX as f64 +} - if scale < 0 { - let actual_value = (value as f64) * 10.0_f64.powi(-(scale as i32)); - Ok(actual_value.log(base)) - } else { - let unscaled_value = decimal32_to_i32(value, scale)?; - if unscaled_value <= 0 { - return Ok(f64::NAN); - } - let log_value: u32 = unscaled_value.ilog(base as i32); - Ok(log_value as f64) +/// Calculate logarithm for Decimal32 values. +/// For integer bases >= 2 with non-negative scale, uses the efficient u32 ilog algorithm. +/// Otherwise falls back to f64 computation. +fn log_decimal32(value: i32, scale: i8, base: f64) -> Result { + if is_valid_integer_base(base) + && scale >= 0 + && let Some(unscaled) = unscale_to_u32(value, scale) + { + return if unscaled > 0 { + Ok(unscaled.ilog(base as u32) as f64) + } else { + Ok(f64::NAN) + }; } + decimal_to_f64(value, scale).map(|v| v.log(base)) } -/// Binary function to calculate logarithm of Decimal64 `value` using `base` base -/// Returns error if base is invalid +/// Calculate logarithm for Decimal64 values. +/// For integer bases >= 2 with non-negative scale, uses the efficient u64 ilog algorithm. +/// Otherwise falls back to f64 computation. 
fn log_decimal64(value: i64, scale: i8, base: f64) -> Result { - if !base.is_finite() || base.trunc() != base { - return Err(ArrowError::ComputeError(format!( - "Log cannot use non-integer base: {base}" - ))); - } - if (base as u32) < 2 { - return Err(ArrowError::ComputeError(format!( - "Log base must be greater than 1: {base}" - ))); + if is_valid_integer_base(base) + && scale >= 0 + && let Some(unscaled) = unscale_to_u64(value, scale) + { + return if unscaled > 0 { + Ok(unscaled.ilog(base as u64) as f64) + } else { + Ok(f64::NAN) + }; } + decimal_to_f64(value, scale).map(|v| v.log(base)) +} - if value <= 0 { - return Ok(f64::NAN); +/// Calculate logarithm for Decimal128 values. +/// For integer bases >= 2 with non-negative scale, uses the efficient u128 ilog algorithm. +/// Otherwise falls back to f64 computation. +fn log_decimal128(value: i128, scale: i8, base: f64) -> Result { + if is_valid_integer_base(base) + && scale >= 0 + && let Some(unscaled) = unscale_to_u128(value, scale) + { + return if unscaled > 0 { + Ok(unscaled.ilog(base as u128) as f64) + } else { + Ok(f64::NAN) + }; } + decimal_to_f64(value, scale).map(|v| v.log(base)) +} - if scale < 0 { - let actual_value = (value as f64) * 10.0_f64.powi(-(scale as i32)); - Ok(actual_value.log(base)) - } else { - let unscaled_value = decimal64_to_i64(value, scale)?; - if unscaled_value <= 0 { - return Ok(f64::NAN); - } - let log_value: u32 = unscaled_value.ilog(base as i64); - Ok(log_value as f64) - } +/// Unscale a Decimal32 value to u32. 
+#[inline] +fn unscale_to_u32(value: i32, scale: i8) -> Option { + let value_u32 = u32::try_from(value).ok()?; + let divisor = 10u32.checked_pow(scale as u32)?; + Some(value_u32 / divisor) } -/// Binary function to calculate an integer logarithm of Decimal128 `value` using `base` base -/// Returns error if base is invalid -fn log_decimal128(value: i128, scale: i8, base: f64) -> Result { - if !base.is_finite() || base.trunc() != base { - return Err(ArrowError::ComputeError(format!( - "Log cannot use non-integer base: {base}" - ))); - } - if (base as u32) < 2 { - return Err(ArrowError::ComputeError(format!( - "Log base must be greater than 1: {base}" - ))); - } +/// Unscale a Decimal64 value to u64. +#[inline] +fn unscale_to_u64(value: i64, scale: i8) -> Option { + let value_u64 = u64::try_from(value).ok()?; + let divisor = 10u64.checked_pow(scale as u32)?; + Some(value_u64 / divisor) +} - if value <= 0 { - // Reflect f64::log behaviour - return Ok(f64::NAN); - } +/// Unscale a Decimal128 value to u128. +#[inline] +fn unscale_to_u128(value: i128, scale: i8) -> Option { + let value_u128 = u128::try_from(value).ok()?; + let divisor = 10u128.checked_pow(scale as u32)?; + Some(value_u128 / divisor) +} - if scale < 0 { - let actual_value = (value as f64) * 10.0_f64.powi(-(scale as i32)); - Ok(actual_value.log(base)) - } else { - let unscaled_value = decimal128_to_i128(value, scale)?; - if unscaled_value <= 0 { - return Ok(f64::NAN); - } - let log_value: u32 = unscaled_value.ilog(base as i128); - Ok(log_value as f64) - } +/// Convert a scaled decimal value to f64. 
+#[inline] +fn decimal_to_f64(value: T, scale: i8) -> Result { + let value_f64 = value.to_f64().ok_or_else(|| { + ArrowError::ComputeError("Cannot convert value to f64".to_string()) + })?; + let scale_factor = 10f64.powi(scale as i32); + Ok(value_f64 / scale_factor) } -/// Binary function to calculate an integer logarithm of Decimal128 `value` using `base` base -/// Returns error if base is invalid or if value is out of bounds of Decimal128 fn log_decimal256(value: i256, scale: i8, base: f64) -> Result { + // Try to convert to i128 for the optimized path match value.to_i128() { - Some(value) => log_decimal128(value, scale, base), - None => Err(ArrowError::NotYetImplemented(format!( - "Log of Decimal256 larger than Decimal128 is not yet supported: {value}" - ))), + Some(v) => log_decimal128(v, scale, base), + None => { + // For very large Decimal256 values, use f64 computation + let value_f64 = value.to_f64().ok_or_else(|| { + ArrowError::ComputeError(format!("Cannot convert {value} to f64")) + })?; + let scale_factor = 10f64.powi(scale as i32); + Ok((value_f64 / scale_factor).log(base)) + } } } @@ -343,7 +341,7 @@ impl ScalarUDFImpl for LogFunc { fn simplify( &self, mut args: Vec, - info: &dyn SimplifyInfo, + info: &SimplifyContext, ) -> Result { let mut arg_types = args .iter() @@ -430,7 +428,6 @@ fn is_pow(func: &ScalarUDF) -> bool { #[cfg(test)] mod tests { - use std::collections::HashMap; use std::sync::Arc; use super::*; @@ -440,10 +437,8 @@ mod tests { }; use arrow::compute::SortOptions; use arrow::datatypes::{DECIMAL256_MAX_PRECISION, Field}; - use datafusion_common::DFSchema; use datafusion_common::cast::{as_float32_array, as_float64_array}; use datafusion_common::config::ConfigOptions; - use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::simplify::SimplifyContext; #[test] @@ -784,10 +779,7 @@ mod tests { #[test] // Test log() simplification errors fn test_log_simplify_errors() { - let props = ExecutionProps::new(); - let schema = - 
Arc::new(DFSchema::new_with_metadata(vec![], HashMap::new()).unwrap()); - let context = SimplifyContext::new(&props).with_schema(schema); + let context = SimplifyContext::default(); // Expect 0 args to error let _ = LogFunc::new().simplify(vec![], &context).unwrap_err(); // Expect 3 args to error @@ -799,10 +791,7 @@ mod tests { #[test] // Test that non-simplifiable log() expressions are unchanged after simplification fn test_log_simplify_original() { - let props = ExecutionProps::new(); - let schema = - Arc::new(DFSchema::new_with_metadata(vec![], HashMap::new()).unwrap()); - let context = SimplifyContext::new(&props).with_schema(schema); + let context = SimplifyContext::default(); // One argument with no simplifications let result = LogFunc::new().simplify(vec![lit(2)], &context).unwrap(); let ExprSimplifyResult::Original(args) = result else { @@ -1169,7 +1158,8 @@ mod tests { } #[test] - fn test_log_decimal128_wrong_base() { + fn test_log_decimal128_invalid_base() { + // Invalid base (-2.0) should return NaN, matching f64::log behavior let arg_fields = vec![ Field::new("b", DataType::Float64, false).into(), Field::new("x", DataType::Decimal128(38, 0), false).into(), @@ -1184,16 +1174,26 @@ mod tests { return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), }; - let result = LogFunc::new().invoke_with_args(args); - assert!(result.is_err()); - assert_eq!( - "Arrow error: Compute error: Log base must be greater than 1: -2", - result.unwrap_err().to_string().lines().next().unwrap() - ); + let result = LogFunc::new() + .invoke_with_args(args) + .expect("should not error on invalid base"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float64_array(&arr) + .expect("failed to convert result to a Float64Array"); + assert_eq!(floats.len(), 1); + assert!(floats.value(0).is_nan()); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } } #[test] - fn 
test_log_decimal256_error() { + fn test_log_decimal256_large() { + // Large Decimal256 values that don't fit in i128 now use f64 fallback let arg_field = Field::new("a", DataType::Decimal256(38, 0), false).into(); let args = ScalarFunctionArgs { args: vec![ @@ -1207,11 +1207,26 @@ mod tests { return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), }; - let result = LogFunc::new().invoke_with_args(args); - assert!(result.is_err()); - assert_eq!( - result.unwrap_err().to_string().lines().next().unwrap(), - "Arrow error: Not yet implemented: Log of Decimal256 larger than Decimal128 is not yet supported: 170141183460469231731687303715884106727" - ); + let result = LogFunc::new() + .invoke_with_args(args) + .expect("should handle large Decimal256 via f64 fallback"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float64_array(&arr) + .expect("failed to convert result to a Float64Array"); + assert_eq!(floats.len(), 1); + // The f64 fallback may lose some precision for very large numbers, + // but we verify we get a reasonable positive result (not NaN/infinity) + let log_result = floats.value(0); + assert!( + log_result.is_finite() && log_result > 0.0, + "Expected positive finite log result, got {log_result}" + ); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } } } diff --git a/datafusion/functions/src/math/nans.rs b/datafusion/functions/src/math/nans.rs index be21cfde0aa6..632eafe1e009 100644 --- a/datafusion/functions/src/math/nans.rs +++ b/datafusion/functions/src/math/nans.rs @@ -17,12 +17,21 @@ //! Math function: `isnan()`. 
-use arrow::datatypes::{DataType, Float32Type, Float64Type}; -use datafusion_common::{Result, exec_err}; -use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, TypeSignature}; - use arrow::array::{ArrayRef, AsArray, BooleanArray}; -use datafusion_expr::{Documentation, ScalarUDFImpl, Signature, Volatility}; +use arrow::datatypes::DataType::{ + Decimal32, Decimal64, Decimal128, Decimal256, Float16, Float32, Float64, Int8, Int16, + Int32, Int64, Null, UInt8, UInt16, UInt32, UInt64, +}; +use arrow::datatypes::{ + DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Float16Type, + Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type, + UInt16Type, UInt32Type, UInt64Type, +}; +use datafusion_common::{Result, ScalarValue, exec_err, utils::take_function_args}; +use datafusion_expr::{ + Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignatureClass, Volatility, +}; use datafusion_macros::user_doc; use std::any::Any; use std::sync::Arc; @@ -54,15 +63,10 @@ impl Default for IsNanFunc { impl IsNanFunc { pub fn new() -> Self { - use DataType::*; + // Accept any numeric type (ints, uints, floats, decimals) without implicit casts. 
+ let numeric = Coercion::new_exact(TypeSignatureClass::Numeric); Self { - signature: Signature::one_of( - vec![ - TypeSignature::Exact(vec![Float32]), - TypeSignature::Exact(vec![Float64]), - ], - Volatility::Immutable, - ), + signature: Signature::coercible(vec![numeric], Volatility::Immutable), } } } @@ -84,26 +88,123 @@ impl ScalarUDFImpl for IsNanFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let args = ColumnarValue::values_to_arrays(&args.args)?; - - let arr: ArrayRef = match args[0].data_type() { - DataType::Float64 => Arc::new(BooleanArray::from_unary( - args[0].as_primitive::(), - f64::is_nan, - )) as ArrayRef, - - DataType::Float32 => Arc::new(BooleanArray::from_unary( - args[0].as_primitive::(), - f32::is_nan, - )) as ArrayRef, - other => { - return exec_err!( - "Unsupported data type {other:?} for function {}", - self.name() - ); + let [arg] = take_function_args(self.name(), args.args)?; + + match arg { + ColumnarValue::Scalar(scalar) => { + if scalar.is_null() { + return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))); + } + + let result = match scalar { + ScalarValue::Float64(Some(v)) => Some(v.is_nan()), + ScalarValue::Float32(Some(v)) => Some(v.is_nan()), + ScalarValue::Float16(Some(v)) => Some(v.is_nan()), + + // Non-float numeric inputs are never NaN + ScalarValue::Int8(_) + | ScalarValue::Int16(_) + | ScalarValue::Int32(_) + | ScalarValue::Int64(_) + | ScalarValue::UInt8(_) + | ScalarValue::UInt16(_) + | ScalarValue::UInt32(_) + | ScalarValue::UInt64(_) + | ScalarValue::Decimal32(_, _, _) + | ScalarValue::Decimal64(_, _, _) + | ScalarValue::Decimal128(_, _, _) + | ScalarValue::Decimal256(_, _, _) => Some(false), + + other => { + return exec_err!( + "Unsupported data type {other:?} for function {}", + self.name() + ); + } + }; + + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(result))) } - }; - Ok(ColumnarValue::Array(arr)) + ColumnarValue::Array(array) => { + // NOTE: BooleanArray::from_unary preserves nulls. 
+ let arr: ArrayRef = match array.data_type() { + Null => Arc::new(BooleanArray::new_null(array.len())) as ArrayRef, + + Float64 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + f64::is_nan, + )) as ArrayRef, + Float32 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + f32::is_nan, + )) as ArrayRef, + Float16 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x.is_nan(), + )) as ArrayRef, + + // Non-float numeric arrays are never NaN + Decimal32(_, _) => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + Decimal64(_, _) => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + Decimal128(_, _) => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + Decimal256(_, _) => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + + Int8 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + Int16 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + Int32 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + Int64 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + UInt8 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + UInt16 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + UInt32 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + UInt64 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + + other => { + return exec_err!( + "Unsupported data type {other:?} for function {}", + self.name() + ); + } + }; + + Ok(ColumnarValue::Array(arr)) + } + } } fn documentation(&self) -> Option<&Documentation> { diff --git 
a/datafusion/functions/src/math/nanvl.rs b/datafusion/functions/src/math/nanvl.rs index 345b1a5b71ae..2bdc3fbbc64a 100644 --- a/datafusion/functions/src/math/nanvl.rs +++ b/datafusion/functions/src/math/nanvl.rs @@ -18,12 +18,10 @@ use std::any::Any; use std::sync::Arc; -use crate::utils::make_scalar_function; - -use arrow::array::{ArrayRef, AsArray, Float32Array, Float64Array}; -use arrow::datatypes::DataType::{Float32, Float64}; -use arrow::datatypes::{DataType, Float32Type, Float64Type}; -use datafusion_common::{DataFusionError, Result, exec_err}; +use arrow::array::{ArrayRef, AsArray, Float16Array, Float32Array, Float64Array}; +use arrow::datatypes::DataType::{Float16, Float32, Float64}; +use arrow::datatypes::{DataType, Float16Type, Float32Type, Float64Type}; +use datafusion_common::{Result, ScalarValue, exec_err, utils::take_function_args}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, @@ -66,10 +64,13 @@ impl Default for NanvlFunc { impl NanvlFunc { pub fn new() -> Self { - use DataType::*; Self { signature: Signature::one_of( - vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])], + vec![ + Exact(vec![Float16, Float16]), + Exact(vec![Float32, Float32]), + Exact(vec![Float64, Float64]), + ], Volatility::Immutable, ), } @@ -91,13 +92,31 @@ impl ScalarUDFImpl for NanvlFunc { fn return_type(&self, arg_types: &[DataType]) -> Result { match &arg_types[0] { + Float16 => Ok(Float16), Float32 => Ok(Float32), _ => Ok(Float64), } } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(nanvl, vec![])(&args.args) + let [x, y] = take_function_args(self.name(), args.args)?; + + match (x, y) { + (ColumnarValue::Scalar(ScalarValue::Float16(Some(v))), y) if v.is_nan() => { + Ok(y) + } + (ColumnarValue::Scalar(ScalarValue::Float32(Some(v))), y) if v.is_nan() => { + Ok(y) + } + (ColumnarValue::Scalar(ScalarValue::Float64(Some(v))), y) 
if v.is_nan() => { + Ok(y) + } + (x @ ColumnarValue::Scalar(_), _) => Ok(x), + (x, y) => { + let args = ColumnarValue::values_to_arrays(&[x, y])?; + Ok(ColumnarValue::Array(nanvl(&args)?)) + } + } } fn documentation(&self) -> Option<&Documentation> { @@ -106,29 +125,49 @@ impl ScalarUDFImpl for NanvlFunc { } /// Nanvl SQL function +/// +/// - x is NaN -> output is y (which may itself be NULL) +/// - otherwise -> output is x (which may itself be NULL) fn nanvl(args: &[ArrayRef]) -> Result { match args[0].data_type() { Float64 => { - let compute_nanvl = |x: f64, y: f64| { - if x.is_nan() { y } else { x } - }; - - let x = args[0].as_primitive() as &Float64Array; - let y = args[1].as_primitive() as &Float64Array; - arrow::compute::binary::<_, _, _, Float64Type>(x, y, compute_nanvl) - .map(|res| Arc::new(res) as _) - .map_err(DataFusionError::from) + let x = args[0].as_primitive::(); + let y = args[1].as_primitive::(); + let result: Float64Array = x + .iter() + .zip(y.iter()) + .map(|(x_value, y_value)| match x_value { + Some(x_value) if x_value.is_nan() => y_value, + _ => x_value, + }) + .collect(); + Ok(Arc::new(result) as ArrayRef) } Float32 => { - let compute_nanvl = |x: f32, y: f32| { - if x.is_nan() { y } else { x } - }; - - let x = args[0].as_primitive() as &Float32Array; - let y = args[1].as_primitive() as &Float32Array; - arrow::compute::binary::<_, _, _, Float32Type>(x, y, compute_nanvl) - .map(|res| Arc::new(res) as _) - .map_err(DataFusionError::from) + let x = args[0].as_primitive::(); + let y = args[1].as_primitive::(); + let result: Float32Array = x + .iter() + .zip(y.iter()) + .map(|(x_value, y_value)| match x_value { + Some(x_value) if x_value.is_nan() => y_value, + _ => x_value, + }) + .collect(); + Ok(Arc::new(result) as ArrayRef) + } + Float16 => { + let x = args[0].as_primitive::(); + let y = args[1].as_primitive::(); + let result: Float16Array = x + .iter() + .zip(y.iter()) + .map(|(x_value, y_value)| match x_value { + Some(x_value) if 
x_value.is_nan() => y_value, + _ => x_value, + }) + .collect(); + Ok(Arc::new(result) as ArrayRef) } other => exec_err!("Unsupported data type {other:?} for function nanvl"), } @@ -146,8 +185,8 @@ mod test { #[test] fn test_nanvl_f64() { let args: Vec = vec![ - Arc::new(Float64Array::from(vec![1.0, f64::NAN, 3.0, f64::NAN])), // y - Arc::new(Float64Array::from(vec![5.0, 6.0, f64::NAN, f64::NAN])), // x + Arc::new(Float64Array::from(vec![1.0, f64::NAN, 3.0, f64::NAN])), // x + Arc::new(Float64Array::from(vec![5.0, 6.0, f64::NAN, f64::NAN])), // y ]; let result = nanvl(&args).expect("failed to initialize function nanvl"); @@ -164,8 +203,8 @@ mod test { #[test] fn test_nanvl_f32() { let args: Vec = vec![ - Arc::new(Float32Array::from(vec![1.0, f32::NAN, 3.0, f32::NAN])), // y - Arc::new(Float32Array::from(vec![5.0, 6.0, f32::NAN, f32::NAN])), // x + Arc::new(Float32Array::from(vec![1.0, f32::NAN, 3.0, f32::NAN])), // x + Arc::new(Float32Array::from(vec![5.0, 6.0, f32::NAN, f32::NAN])), // y ]; let result = nanvl(&args).expect("failed to initialize function nanvl"); diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index 33166f6444f2..489c59aa3d6f 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -22,21 +22,23 @@ use super::log::LogFunc; use crate::utils::{calculate_binary_decimal_math, calculate_binary_math}; use arrow::array::{Array, ArrayRef}; +use arrow::datatypes::i256; use arrow::datatypes::{ - ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type, - Decimal256Type, Float64Type, Int64Type, + ArrowNativeType, ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, + Decimal128Type, Decimal256Type, Float64Type, Int64Type, }; use arrow::error::ArrowError; use datafusion_common::types::{NativeType, logical_float64, logical_int64}; use datafusion_common::utils::take_function_args; use datafusion_common::{Result, ScalarValue, internal_err}; use 
datafusion_expr::expr::ScalarFunction; -use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; use datafusion_expr::{ Coercion, ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility, lit, }; use datafusion_macros::user_doc; +use num_traits::{NumCast, ToPrimitive}; #[user_doc( doc_section(label = "Math Functions"), @@ -112,12 +114,15 @@ impl PowerFunc { /// 2.5 is represented as 25 with scale 1 /// The unscaled result is 25^4 = 390625 /// Scale it back to 1: 390625 / 10^4 = 39 -/// -/// Returns error if base is invalid fn pow_decimal_int(base: T, scale: i8, exp: i64) -> Result where - T: From + ArrowNativeTypeOp, + T: ArrowNativeType + ArrowNativeTypeOp + ToPrimitive + NumCast + Copy, { + // Negative exponent: fall back to float computation + if exp < 0 { + return pow_decimal_float(base, scale, exp as f64); + } + let exp: u32 = exp.try_into().map_err(|_| { ArrowError::ArithmeticOverflow(format!("Unsupported exp value: {exp}")) })?; @@ -125,13 +130,13 @@ where // If scale < 0, 10^scale (e.g., 10^-2 = 0.01) becomes 0 in integer arithmetic. if exp == 0 { return if scale >= 0 { - T::from(10).pow_checked(scale as u32).map_err(|_| { + T::usize_as(10).pow_checked(scale as u32).map_err(|_| { ArrowError::ArithmeticOverflow(format!( "Cannot make unscale factor for {scale} and {exp}" )) }) } else { - Ok(T::from(0)) + Ok(T::ZERO) }; } let powered: T = base.pow_checked(exp).map_err(|_| { @@ -149,11 +154,12 @@ where // If mul_exp is positive, we divide (standard case). // If mul_exp is negative, we multiply (negative scale case). 
if mul_exp > 0 { - let div_factor: T = T::from(10).pow_checked(mul_exp as u32).map_err(|_| { - ArrowError::ArithmeticOverflow(format!( - "Cannot make div factor for {scale} and {exp}" - )) - })?; + let div_factor: T = + T::usize_as(10).pow_checked(mul_exp as u32).map_err(|_| { + ArrowError::ArithmeticOverflow(format!( + "Cannot make div factor for {scale} and {exp}" + )) + })?; powered.div_checked(div_factor) } else { // mul_exp is negative, so we multiply by 10^(-mul_exp) @@ -162,33 +168,227 @@ where "Overflow while negating scale exponent".to_string(), ) })?; - let mul_factor: T = T::from(10).pow_checked(abs_exp as u32).map_err(|_| { - ArrowError::ArithmeticOverflow(format!( - "Cannot make mul factor for {scale} and {exp}" - )) - })?; + let mul_factor: T = + T::usize_as(10).pow_checked(abs_exp as u32).map_err(|_| { + ArrowError::ArithmeticOverflow(format!( + "Cannot make mul factor for {scale} and {exp}" + )) + })?; powered.mul_checked(mul_factor) } } /// Binary function to calculate a math power to float exponent /// for scaled integer types. -/// Returns error if exponent is negative or non-integer, or base invalid fn pow_decimal_float(base: T, scale: i8, exp: f64) -> Result where - T: From + ArrowNativeTypeOp, + T: ArrowNativeType + ArrowNativeTypeOp + ToPrimitive + NumCast + Copy, { - if !exp.is_finite() || exp.trunc() != exp { + if exp.is_finite() && exp.trunc() == exp && exp >= 0f64 && exp < u32::MAX as f64 { + return pow_decimal_int(base, scale, exp as i64); + } + + if !exp.is_finite() { return Err(ArrowError::ComputeError(format!( - "Cannot use non-integer exp: {exp}" + "Cannot use non-finite exp: {exp}" ))); } - if exp < 0f64 || exp >= u32::MAX as f64 { + + pow_decimal_float_fallback(base, scale, exp) +} + +/// Compute the f64 power result and scale it back. +/// Returns the rounded i128 result for conversion to target type. 
+#[inline] +fn compute_pow_f64_result( + base_f64: f64, + scale: i8, + exp: f64, +) -> Result { + let result_f64 = base_f64.powf(exp); + + if !result_f64.is_finite() { return Err(ArrowError::ArithmeticOverflow(format!( - "Unsupported exp value: {exp}" + "Result of {base_f64}^{exp} is not finite" + ))); + } + + let scale_factor = 10f64.powi(scale as i32); + let result_scaled = result_f64 * scale_factor; + let result_rounded = result_scaled.round(); + + if result_rounded.abs() > i128::MAX as f64 { + return Err(ArrowError::ArithmeticOverflow(format!( + "Result {result_rounded} is too large for the target decimal type" + ))); + } + + Ok(result_rounded as i128) +} + +/// Convert i128 result to target decimal native type using NumCast. +/// Returns error if value overflows the target type. +#[inline] +fn decimal_from_i128(value: i128) -> Result +where + T: NumCast, +{ + NumCast::from(value).ok_or_else(|| { + ArrowError::ArithmeticOverflow(format!( + "Value {value} is too large for the target decimal type" + )) + }) +} + +/// Fallback implementation using f64 for negative or non-integer exponents. +/// This handles cases that cannot be computed using integer arithmetic. +fn pow_decimal_float_fallback(base: T, scale: i8, exp: f64) -> Result +where + T: ToPrimitive + NumCast + Copy, +{ + if scale < 0 { + return Err(ArrowError::NotYetImplemented(format!( + "Negative scale is not yet supported: {scale}" ))); } - pow_decimal_int(base, scale, exp as i64) + + let scale_factor = 10f64.powi(scale as i32); + let base_f64 = base.to_f64().ok_or_else(|| { + ArrowError::ComputeError("Cannot convert base to f64".to_string()) + })? / scale_factor; + + let result_i128 = compute_pow_f64_result(base_f64, scale, exp)?; + + decimal_from_i128(result_i128) +} + +/// Decimal256 specialized float exponent version. 
+fn pow_decimal256_float(base: i256, scale: i8, exp: f64) -> Result { + if exp.is_finite() && exp.trunc() == exp && exp >= 0f64 && exp < u32::MAX as f64 { + return pow_decimal256_int(base, scale, exp as i64); + } + + if !exp.is_finite() { + return Err(ArrowError::ComputeError(format!( + "Cannot use non-finite exp: {exp}" + ))); + } + + pow_decimal256_float_fallback(base, scale, exp) +} + +/// Decimal256 specialized integer exponent version. +fn pow_decimal256_int(base: i256, scale: i8, exp: i64) -> Result { + if exp < 0 { + return pow_decimal256_float(base, scale, exp as f64); + } + + let exp: u32 = exp.try_into().map_err(|_| { + ArrowError::ArithmeticOverflow(format!("Unsupported exp value: {exp}")) + })?; + + if exp == 0 { + return if scale >= 0 { + i256::from_i128(10).pow_checked(scale as u32).map_err(|_| { + ArrowError::ArithmeticOverflow(format!( + "Cannot make unscale factor for {scale} and {exp}" + )) + }) + } else { + Ok(i256::from_i128(0)) + }; + } + + let powered: i256 = base.pow_checked(exp).map_err(|_| { + ArrowError::ArithmeticOverflow(format!("Cannot raise base {base:?} to exp {exp}")) + })?; + + let mul_exp = (scale as i64).wrapping_mul(exp as i64 - 1); + + if mul_exp == 0 { + return Ok(powered); + } + + if mul_exp > 0 { + let div_factor: i256 = + i256::from_i128(10) + .pow_checked(mul_exp as u32) + .map_err(|_| { + ArrowError::ArithmeticOverflow(format!( + "Cannot make div factor for {scale} and {exp}" + )) + })?; + powered.div_checked(div_factor) + } else { + let abs_exp = mul_exp.checked_neg().ok_or_else(|| { + ArrowError::ArithmeticOverflow( + "Overflow while negating scale exponent".to_string(), + ) + })?; + let mul_factor: i256 = + i256::from_i128(10) + .pow_checked(abs_exp as u32) + .map_err(|_| { + ArrowError::ArithmeticOverflow(format!( + "Cannot make mul factor for {scale} and {exp}" + )) + })?; + powered.mul_checked(mul_factor) + } +} + +/// Fallback implementation for Decimal256. 
+fn pow_decimal256_float_fallback( + base: i256, + scale: i8, + exp: f64, +) -> Result { + if scale < 0 { + return Err(ArrowError::NotYetImplemented(format!( + "Negative scale is not yet supported: {scale}" + ))); + } + + let scale_factor = 10f64.powi(scale as i32); + let base_f64 = base.to_f64().ok_or_else(|| { + ArrowError::ComputeError("Cannot convert base to f64".to_string()) + })? / scale_factor; + + let result_i128 = compute_pow_f64_result(base_f64, scale, exp)?; + + // i256 can be constructed from i128 directly + Ok(i256::from_i128(result_i128)) +} + +/// Fallback implementation for decimal power when exponent is an array. +/// Casts decimal to float64, computes power, and casts back to original decimal type. +/// This is used for performance when exponent varies per-row. +fn pow_decimal_with_float_fallback( + base: &ArrayRef, + exponent: &ColumnarValue, + num_rows: usize, +) -> Result { + use arrow::compute::cast; + + let original_type = base.data_type().clone(); + let base_f64 = cast(base.as_ref(), &DataType::Float64)?; + + let exp_f64 = match exponent { + ColumnarValue::Array(arr) => cast(arr.as_ref(), &DataType::Float64)?, + ColumnarValue::Scalar(scalar) => { + let scalar_f64 = scalar.cast_to(&DataType::Float64)?; + scalar_f64.to_array_of_size(num_rows)? + } + }; + + let result_f64 = calculate_binary_math::( + &base_f64, + &ColumnarValue::Array(exp_f64), + |b, e| Ok(f64::powf(b, e)), + )?; + + let result = cast(result_f64.as_ref(), &original_type)?; + Ok(ColumnarValue::Array(result)) } impl ScalarUDFImpl for PowerFunc { @@ -218,8 +418,25 @@ impl ScalarUDFImpl for PowerFunc { fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let [base, exponent] = take_function_args(self.name(), &args.args)?; + + // For decimal types, only use native decimal + // operations when we have a scalar exponent. When the exponent is an array, + // fall back to float computation for better performance. 
+ let use_float_fallback = matches!( + base.data_type(), + DataType::Decimal32(_, _) + | DataType::Decimal64(_, _) + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) + ) && matches!(exponent, ColumnarValue::Array(_)); + let base = base.to_array(args.number_rows)?; + // If decimal with array exponent, cast to float and compute + if use_float_fallback { + return pow_decimal_with_float_fallback(&base, exponent, args.number_rows); + } + let arr: ArrayRef = match (base.data_type(), exponent.data_type()) { (DataType::Float64, DataType::Float64) => { calculate_binary_math::( @@ -311,7 +528,7 @@ impl ScalarUDFImpl for PowerFunc { >( &base, exponent, - |b, e| pow_decimal_int(b, *scale, e), + |b, e| pow_decimal256_int(b, *scale, e), *precision, *scale, )? @@ -325,7 +542,7 @@ impl ScalarUDFImpl for PowerFunc { >( &base, exponent, - |b, e| pow_decimal_float(b, *scale, e), + |b, e| pow_decimal256_float(b, *scale, e), *precision, *scale, )? @@ -346,7 +563,7 @@ impl ScalarUDFImpl for PowerFunc { fn simplify( &self, args: Vec, - info: &dyn SimplifyInfo, + info: &SimplifyContext, ) -> Result { let [base, exponent] = take_function_args("power", args)?; let base_type = info.get_data_type(&base)?; @@ -398,19 +615,53 @@ mod tests { #[test] fn test_pow_decimal128_helper() { // Expression: 2.5 ^ 4 = 39.0625 - assert_eq!(pow_decimal_int(25, 1, 4).unwrap(), i128::from(390)); - assert_eq!(pow_decimal_int(2500, 3, 4).unwrap(), i128::from(39062)); - assert_eq!(pow_decimal_int(25000, 4, 4).unwrap(), i128::from(390625)); + assert_eq!(pow_decimal_int(25i128, 1, 4).unwrap(), 390i128); + assert_eq!(pow_decimal_int(2500i128, 3, 4).unwrap(), 39062i128); + assert_eq!(pow_decimal_int(25000i128, 4, 4).unwrap(), 390625i128); // Expression: 25 ^ 4 = 390625 - assert_eq!(pow_decimal_int(25, 0, 4).unwrap(), i128::from(390625)); + assert_eq!(pow_decimal_int(25i128, 0, 4).unwrap(), 390625i128); // Expressions for edge cases - assert_eq!(pow_decimal_int(25, 1, 1).unwrap(), i128::from(25)); - 
assert_eq!(pow_decimal_int(25, 0, 1).unwrap(), i128::from(25)); - assert_eq!(pow_decimal_int(25, 0, 0).unwrap(), i128::from(1)); - assert_eq!(pow_decimal_int(25, 1, 0).unwrap(), i128::from(10)); + assert_eq!(pow_decimal_int(25i128, 1, 1).unwrap(), 25i128); + assert_eq!(pow_decimal_int(25i128, 0, 1).unwrap(), 25i128); + assert_eq!(pow_decimal_int(25i128, 0, 0).unwrap(), 1i128); + assert_eq!(pow_decimal_int(25i128, 1, 0).unwrap(), 10i128); + + assert_eq!(pow_decimal_int(25i128, -1, 4).unwrap(), 390625000i128); + } + + #[test] + fn test_pow_decimal_float_fallback() { + // Test negative exponent: 4^(-1) = 0.25 + // 4 with scale 2 = 400, result should be 25 (0.25 with scale 2) + let result: i128 = pow_decimal_float(400i128, 2, -1.0).unwrap(); + assert_eq!(result, 25); + + // Test non-integer exponent: 4^0.5 = 2 + // 4 with scale 2 = 400, result should be 200 (2.0 with scale 2) + let result: i128 = pow_decimal_float(400i128, 2, 0.5).unwrap(); + assert_eq!(result, 200); + + // Test 8^(1/3) = 2 (cube root) + // 8 with scale 1 = 80, result should be 20 (2.0 with scale 1) + let result: i128 = pow_decimal_float(80i128, 1, 1.0 / 3.0).unwrap(); + assert_eq!(result, 20); + + // Test negative base with integer exponent still works + // (-2)^3 = -8 + // -2 with scale 1 = -20, result should be -80 (-8.0 with scale 1) + let result: i128 = pow_decimal_float(-20i128, 1, 3.0).unwrap(); + assert_eq!(result, -80); + + // Test positive integer exponent goes through fast path + // 2.5^4 = 39.0625 + // 25 with scale 1, result should be 390 (39.0 with scale 1) - truncated + let result: i128 = pow_decimal_float(25i128, 1, 4.0).unwrap(); + assert_eq!(result, 390); // Uses integer path - assert_eq!(pow_decimal_int(25, -1, 4).unwrap(), i128::from(390625000)); + // Test non-finite exponent returns error + assert!(pow_decimal_float(100i128, 2, f64::NAN).is_err()); + assert!(pow_decimal_float(100i128, 2, f64::INFINITY).is_err()); } } diff --git a/datafusion/functions/src/math/round.rs 
b/datafusion/functions/src/math/round.rs index de70788128b8..07cddf9341f2 100644 --- a/datafusion/functions/src/math/round.rs +++ b/datafusion/functions/src/math/round.rs @@ -25,19 +25,130 @@ use arrow::datatypes::DataType::{ }; use arrow::datatypes::{ ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type, - Decimal256Type, Float32Type, Float64Type, Int32Type, + Decimal256Type, DecimalType, Float32Type, Float64Type, Int32Type, }; +use arrow::datatypes::{Field, FieldRef}; use arrow::error::ArrowError; use datafusion_common::types::{ NativeType, logical_float32, logical_float64, logical_int32, }; -use datafusion_common::{Result, ScalarValue, exec_err}; +use datafusion_common::{Result, ScalarValue, exec_err, internal_err}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ - Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, - TypeSignature, TypeSignatureClass, Volatility, + Coercion, ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs, + ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility, }; use datafusion_macros::user_doc; +use std::sync::Arc; + +fn output_scale_for_decimal(precision: u8, input_scale: i8, decimal_places: i32) -> i8 { + // `decimal_places` controls the maximum output scale, but scale cannot exceed the input scale. + // + // For negative-scale decimals, allow further scale reduction to match negative `decimal_places` + // (e.g. scale -2 rounded to -3 becomes scale -3). This preserves fixed precision by + // representing the rounded result at a coarser scale. + if input_scale < 0 { + // Decimal scales must be within [-precision, precision] and fit in i8. For negative-scale + // decimals, allow rounding to move the output scale further negative, but cap it at + // `-precision` (beyond that, the rounded result is always 0). 
+ let min_scale = -i32::from(precision); + let new_scale = i32::from(input_scale).min(decimal_places).max(min_scale); + return new_scale as i8; + } + + // The `min` ensures the result is always within i8 range because `input_scale` is i8. + let decimal_places = decimal_places.max(0); + i32::from(input_scale).min(decimal_places) as i8 +} + +fn normalize_decimal_places_for_decimal( + decimal_places: i32, + precision: u8, + scale: i8, +) -> Option { + if decimal_places >= 0 { + return Some(decimal_places); + } + + // For fixed precision decimals, the absolute value is strictly less than 10^(precision - scale). + // If the rounding position is beyond that (abs(decimal_places) > precision - scale), the + // rounded result is always 0, and we can avoid overflow in intermediate 10^n computations. + let max_rounding_pow10 = i64::from(precision) - i64::from(scale); + if max_rounding_pow10 <= 0 { + return None; + } + + let abs_decimal_places = i64::from(decimal_places.unsigned_abs()); + (abs_decimal_places <= max_rounding_pow10).then_some(decimal_places) +} + +fn validate_decimal_precision( + value: T::Native, + precision: u8, + scale: i8, +) -> Result { + T::validate_decimal_precision(value, precision, scale).map_err(|e| { + ArrowError::ComputeError(format!( + "Decimal overflow: rounded value exceeds precision {precision}: {e}" + )) + })?; + Ok(value) +} + +fn calculate_new_precision_scale( + precision: u8, + scale: i8, + decimal_places: Option, +) -> Result { + if let Some(decimal_places) = decimal_places { + let new_scale = output_scale_for_decimal(precision, scale, decimal_places); + + // When rounding an integer decimal (scale == 0) to a negative `decimal_places`, a carry can + // add an extra digit to the integer part (e.g. 99 -> 100 when rounding to -1). This can + // only happen when the rounding position is within the existing precision. 
+ let abs_decimal_places = decimal_places.unsigned_abs(); + let new_precision = if scale == 0 + && decimal_places < 0 + && abs_decimal_places <= u32::from(precision) + { + precision.saturating_add(1).min(T::MAX_PRECISION) + } else { + precision + }; + Ok(T::TYPE_CONSTRUCTOR(new_precision, new_scale)) + } else { + let new_precision = precision.saturating_add(1).min(T::MAX_PRECISION); + Ok(T::TYPE_CONSTRUCTOR(new_precision, scale)) + } +} + +fn decimal_places_from_scalar(scalar: &ScalarValue) -> Result { + let out_of_range = |value: String| { + datafusion_common::DataFusionError::Execution(format!( + "round decimal_places {value} is out of supported i32 range" + )) + }; + match scalar { + ScalarValue::Int8(Some(v)) => Ok(i32::from(*v)), + ScalarValue::Int16(Some(v)) => Ok(i32::from(*v)), + ScalarValue::Int32(Some(v)) => Ok(*v), + ScalarValue::Int64(Some(v)) => { + i32::try_from(*v).map_err(|_| out_of_range(v.to_string())) + } + ScalarValue::UInt8(Some(v)) => Ok(i32::from(*v)), + ScalarValue::UInt16(Some(v)) => Ok(i32::from(*v)), + ScalarValue::UInt32(Some(v)) => { + i32::try_from(*v).map_err(|_| out_of_range(v.to_string())) + } + ScalarValue::UInt64(Some(v)) => { + i32::try_from(*v).map_err(|_| out_of_range(v.to_string())) + } + other => exec_err!( + "Unexpected datatype for decimal_places: {}", + other.data_type() + ), + } +} #[user_doc( doc_section(label = "Math Functions"), @@ -117,15 +228,59 @@ impl ScalarUDFImpl for RoundFunc { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(match arg_types[0].clone() { - Float32 => Float32, - dt @ Decimal128(_, _) - | dt @ Decimal256(_, _) - | dt @ Decimal32(_, _) - | dt @ Decimal64(_, _) => dt, - _ => Float64, - }) + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let input_field = &args.arg_fields[0]; + let input_type = input_field.data_type(); + + // If decimal_places is a scalar literal, we can incorporate it into the output type + // (scale reduction). 
Otherwise, keep the input scale as we can't pick a per-row scale. + // + // Note: `scalar_arguments` contains the original literal values (pre-coercion), so + // integer literals may appear as Int64 even though the signature coerces them to Int32. + let decimal_places: Option = match args.scalar_arguments.get(1) { + None => Some(0), // No dp argument means default to 0 + Some(None) => None, // dp is not a literal (e.g. column) + Some(Some(scalar)) if scalar.is_null() => Some(0), // null dp => default to 0 + Some(Some(scalar)) => Some(decimal_places_from_scalar(scalar)?), + }; + + // Calculate return type based on input type + // For decimals: reduce scale to decimal_places (reclaims precision for integer part) + // This matches Spark/DuckDB behavior where ROUND adjusts the scale + // BUT only if dp is a scalar literal - otherwise keep original scale and add + // extra precision to accommodate potential carry-over. + let return_type = + match input_type { + Float32 => Float32, + Decimal32(precision, scale) => calculate_new_precision_scale::< + Decimal32Type, + >( + *precision, *scale, decimal_places + )?, + Decimal64(precision, scale) => calculate_new_precision_scale::< + Decimal64Type, + >( + *precision, *scale, decimal_places + )?, + Decimal128(precision, scale) => calculate_new_precision_scale::< + Decimal128Type, + >( + *precision, *scale, decimal_places + )?, + Decimal256(precision, scale) => calculate_new_precision_scale::< + Decimal256Type, + >( + *precision, *scale, decimal_places + )?, + _ => Float64, + }; + + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + Ok(Arc::new(Field::new(self.name(), return_type, nullable))) + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("use return_field_from_args instead") } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -141,7 +296,150 @@ impl ScalarUDFImpl for RoundFunc { &default_decimal_places }; - round_columnar(&args.args[0], decimal_places, 
args.number_rows) + if let (ColumnarValue::Scalar(value_scalar), ColumnarValue::Scalar(dp_scalar)) = + (&args.args[0], decimal_places) + { + if value_scalar.is_null() || dp_scalar.is_null() { + return ColumnarValue::Scalar(ScalarValue::Null) + .cast_to(args.return_type(), None); + } + + let dp = if let ScalarValue::Int32(Some(dp)) = dp_scalar { + *dp + } else { + return internal_err!( + "Unexpected datatype for decimal_places: {}", + dp_scalar.data_type() + ); + }; + + match (value_scalar, args.return_type()) { + (ScalarValue::Float32(Some(v)), _) => { + let rounded = round_float(*v, dp)?; + Ok(ColumnarValue::Scalar(ScalarValue::from(rounded))) + } + (ScalarValue::Float64(Some(v)), _) => { + let rounded = round_float(*v, dp)?; + Ok(ColumnarValue::Scalar(ScalarValue::from(rounded))) + } + ( + ScalarValue::Decimal32(Some(v), in_precision, scale), + Decimal32(out_precision, out_scale), + ) => { + let rounded = + round_decimal_or_zero(*v, *in_precision, *scale, *out_scale, dp)?; + let rounded = if *out_precision == Decimal32Type::MAX_PRECISION + && *scale == 0 + && dp < 0 + { + // With scale == 0 and negative dp, rounding can carry into an additional + // digit (e.g. 99 -> 100). If we're already at max precision we can't widen + // the type, so validate and error rather than producing an invalid decimal. + validate_decimal_precision::( + rounded, + *out_precision, + *out_scale, + ) + } else { + Ok(rounded) + }?; + let scalar = + ScalarValue::Decimal32(Some(rounded), *out_precision, *out_scale); + Ok(ColumnarValue::Scalar(scalar)) + } + ( + ScalarValue::Decimal64(Some(v), in_precision, scale), + Decimal64(out_precision, out_scale), + ) => { + let rounded = + round_decimal_or_zero(*v, *in_precision, *scale, *out_scale, dp)?; + let rounded = if *out_precision == Decimal64Type::MAX_PRECISION + && *scale == 0 + && dp < 0 + { + // See Decimal32 branch for details. 
+ validate_decimal_precision::( + rounded, + *out_precision, + *out_scale, + ) + } else { + Ok(rounded) + }?; + let scalar = + ScalarValue::Decimal64(Some(rounded), *out_precision, *out_scale); + Ok(ColumnarValue::Scalar(scalar)) + } + ( + ScalarValue::Decimal128(Some(v), in_precision, scale), + Decimal128(out_precision, out_scale), + ) => { + let rounded = + round_decimal_or_zero(*v, *in_precision, *scale, *out_scale, dp)?; + let rounded = if *out_precision == Decimal128Type::MAX_PRECISION + && *scale == 0 + && dp < 0 + { + // See Decimal32 branch for details. + validate_decimal_precision::( + rounded, + *out_precision, + *out_scale, + ) + } else { + Ok(rounded) + }?; + let scalar = ScalarValue::Decimal128( + Some(rounded), + *out_precision, + *out_scale, + ); + Ok(ColumnarValue::Scalar(scalar)) + } + ( + ScalarValue::Decimal256(Some(v), in_precision, scale), + Decimal256(out_precision, out_scale), + ) => { + let rounded = + round_decimal_or_zero(*v, *in_precision, *scale, *out_scale, dp)?; + let rounded = if *out_precision == Decimal256Type::MAX_PRECISION + && *scale == 0 + && dp < 0 + { + // See Decimal32 branch for details. 
+ validate_decimal_precision::( + rounded, + *out_precision, + *out_scale, + ) + } else { + Ok(rounded) + }?; + let scalar = ScalarValue::Decimal256( + Some(rounded), + *out_precision, + *out_scale, + ); + Ok(ColumnarValue::Scalar(scalar)) + } + (ScalarValue::Null, _) => ColumnarValue::Scalar(ScalarValue::Null) + .cast_to(args.return_type(), None), + (value_scalar, return_type) => { + internal_err!( + "Unexpected datatype for round(value, decimal_places): value {}, return type {}", + value_scalar.data_type(), + return_type + ) + } + } + } else { + round_columnar( + &args.args[0], + decimal_places, + args.number_rows, + args.return_type(), + ) + } } fn output_ordering(&self, input: &[ExprProperties]) -> Result { @@ -168,13 +466,15 @@ fn round_columnar( value: &ColumnarValue, decimal_places: &ColumnarValue, number_rows: usize, + return_type: &DataType, ) -> Result { let value_array = value.to_array(number_rows)?; let both_scalars = matches!(value, ColumnarValue::Scalar(_)) && matches!(decimal_places, ColumnarValue::Scalar(_)); + let decimal_places_is_array = matches!(decimal_places, ColumnarValue::Array(_)); - let arr: ArrayRef = match value_array.data_type() { - Float64 => { + let arr: ArrayRef = match (value_array.data_type(), return_type) { + (Float64, _) => { let result = calculate_binary_math::( value_array.as_ref(), decimal_places, @@ -182,7 +482,7 @@ fn round_columnar( )?; result as _ } - Float32 => { + (Float32, _) => { let result = calculate_binary_math::( value_array.as_ref(), decimal_places, @@ -190,7 +490,8 @@ fn round_columnar( )?; result as _ } - Decimal32(precision, scale) => { + (Decimal32(input_precision, scale), Decimal32(precision, new_scale)) => { + // reduce scale to reclaim integer precision let result = calculate_binary_decimal_math::< Decimal32Type, Int32Type, @@ -199,13 +500,34 @@ fn round_columnar( >( value_array.as_ref(), decimal_places, - |v, dp| round_decimal(v, *scale, dp), + |v, dp| { + let rounded = round_decimal_or_zero( + v, + 
*input_precision, + *scale, + *new_scale, + dp, + )?; + if *precision == Decimal32Type::MAX_PRECISION + && (decimal_places_is_array || (*scale == 0 && dp < 0)) + { + // If we're already at max precision, we can't widen the result type. For + // dp arrays, or for scale == 0 with negative dp, rounding can overflow the + // fixed-precision type. Validate per-row and return an error instead of + // producing an invalid decimal that Arrow may display incorrectly. + validate_decimal_precision::( + rounded, *precision, *new_scale, + ) + } else { + Ok(rounded) + } + }, *precision, - *scale, + *new_scale, )?; result as _ } - Decimal64(precision, scale) => { + (Decimal64(input_precision, scale), Decimal64(precision, new_scale)) => { let result = calculate_binary_decimal_math::< Decimal64Type, Int32Type, @@ -214,13 +536,31 @@ fn round_columnar( >( value_array.as_ref(), decimal_places, - |v, dp| round_decimal(v, *scale, dp), + |v, dp| { + let rounded = round_decimal_or_zero( + v, + *input_precision, + *scale, + *new_scale, + dp, + )?; + if *precision == Decimal64Type::MAX_PRECISION + && (decimal_places_is_array || (*scale == 0 && dp < 0)) + { + // See Decimal32 branch for details. + validate_decimal_precision::( + rounded, *precision, *new_scale, + ) + } else { + Ok(rounded) + } + }, *precision, - *scale, + *new_scale, )?; result as _ } - Decimal128(precision, scale) => { + (Decimal128(input_precision, scale), Decimal128(precision, new_scale)) => { let result = calculate_binary_decimal_math::< Decimal128Type, Int32Type, @@ -229,13 +569,31 @@ fn round_columnar( >( value_array.as_ref(), decimal_places, - |v, dp| round_decimal(v, *scale, dp), + |v, dp| { + let rounded = round_decimal_or_zero( + v, + *input_precision, + *scale, + *new_scale, + dp, + )?; + if *precision == Decimal128Type::MAX_PRECISION + && (decimal_places_is_array || (*scale == 0 && dp < 0)) + { + // See Decimal32 branch for details. 
+ validate_decimal_precision::( + rounded, *precision, *new_scale, + ) + } else { + Ok(rounded) + } + }, *precision, - *scale, + *new_scale, )?; result as _ } - Decimal256(precision, scale) => { + (Decimal256(input_precision, scale), Decimal256(precision, new_scale)) => { let result = calculate_binary_decimal_math::< Decimal256Type, Int32Type, @@ -244,13 +602,31 @@ fn round_columnar( >( value_array.as_ref(), decimal_places, - |v, dp| round_decimal(v, *scale, dp), + |v, dp| { + let rounded = round_decimal_or_zero( + v, + *input_precision, + *scale, + *new_scale, + dp, + )?; + if *precision == Decimal256Type::MAX_PRECISION + && (decimal_places_is_array || (*scale == 0 && dp < 0)) + { + // See Decimal32 branch for details. + validate_decimal_precision::( + rounded, *precision, *new_scale, + ) + } else { + Ok(rounded) + } + }, *precision, - *scale, + *new_scale, )?; result as _ } - other => exec_err!("Unsupported data type {other:?} for function round")?, + (other, _) => exec_err!("Unsupported data type {other:?} for function round")?, }; if both_scalars { @@ -274,19 +650,17 @@ where fn round_decimal( value: V, - scale: i8, + input_scale: i8, + output_scale: i8, decimal_places: i32, ) -> Result { - let diff = i64::from(scale) - i64::from(decimal_places); + let diff = i64::from(input_scale) - i64::from(decimal_places); if diff <= 0 { return Ok(value); } - let diff: u32 = diff.try_into().map_err(|e| { - ArrowError::ComputeError(format!( - "Invalid value for decimal places: {decimal_places}: {e}" - )) - })?; + debug_assert!(diff <= i64::from(u32::MAX)); + let diff = diff as u32; let one = V::ONE; let two = V::from_usize(2).ok_or_else(|| { @@ -298,7 +672,7 @@ fn round_decimal( let factor = ten.pow_checked(diff).map_err(|_| { ArrowError::ComputeError(format!( - "Overflow while rounding decimal with scale {scale} and decimal places {decimal_places}" + "Overflow while rounding decimal with scale {input_scale} and decimal places {decimal_places}" )) })?; @@ -317,11 +691,44 @@ 
fn round_decimal( })?; } + // `quotient` is the rounded value at scale `decimal_places`. Rescale to the desired + // `output_scale` (which is always >= `decimal_places` in cases where diff > 0). + let scale_shift = i64::from(output_scale) - i64::from(decimal_places); + if scale_shift == 0 { + return Ok(quotient); + } + + debug_assert!(scale_shift > 0); + debug_assert!(scale_shift <= i64::from(u32::MAX)); + let scale_shift = scale_shift as u32; + let shift_factor = ten.pow_checked(scale_shift).map_err(|_| { + ArrowError::ComputeError(format!( + "Overflow while rounding decimal with scale {input_scale} and decimal places {decimal_places}" + )) + })?; quotient - .mul_checked(factor) + .mul_checked(shift_factor) .map_err(|_| ArrowError::ComputeError("Overflow while rounding decimal".into())) } +fn round_decimal_or_zero( + value: V, + precision: u8, + input_scale: i8, + output_scale: i8, + decimal_places: i32, +) -> Result { + if let Some(dp) = + normalize_decimal_places_for_decimal(decimal_places, precision, input_scale) + { + round_decimal(value, input_scale, output_scale, dp) + } else { + V::from_usize(0).ok_or_else(|| { + ArrowError::ComputeError("Internal error: could not create constant 0".into()) + }) + } +} + #[cfg(test)] mod test { use std::sync::Arc; @@ -337,12 +744,17 @@ mod test { decimal_places: Option, ) -> Result { let number_rows = value.len(); + // NOTE: For decimal inputs, the actual ROUND return type can differ from the + // input type (scale reduction for literal `decimal_places`). These unit tests + // only exercise Float32/Float64 behavior. 
+ let return_type = value.data_type().clone(); let value = ColumnarValue::Array(value); let decimal_places = decimal_places .map(ColumnarValue::Array) .unwrap_or_else(|| ColumnarValue::Scalar(ScalarValue::Int32(Some(0)))); - let result = super::round_columnar(&value, &decimal_places, number_rows)?; + let result = + super::round_columnar(&value, &decimal_places, number_rows, &return_type)?; match result { ColumnarValue::Array(array) => Ok(array), ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(1), diff --git a/datafusion/functions/src/math/signum.rs b/datafusion/functions/src/math/signum.rs index e217088c64c2..8a3769a12f29 100644 --- a/datafusion/functions/src/math/signum.rs +++ b/datafusion/functions/src/math/signum.rs @@ -18,11 +18,12 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::{ArrayRef, AsArray}; +use arrow::array::AsArray; use arrow::datatypes::DataType::{Float32, Float64}; use arrow::datatypes::{DataType, Float32Type, Float64Type}; -use datafusion_common::{Result, exec_err}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, ScalarValue, internal_err}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, @@ -30,8 +31,6 @@ use datafusion_expr::{ }; use datafusion_macros::user_doc; -use crate::utils::make_scalar_function; - #[user_doc( doc_section(label = "Math Functions"), description = r#"Returns the sign of a number. 
@@ -98,7 +97,53 @@ impl ScalarUDFImpl for SignumFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(signum, vec![])(&args.args) + let return_type = args.return_type().clone(); + let [arg] = take_function_args(self.name(), args.args)?; + + match arg { + ColumnarValue::Scalar(scalar) => { + if scalar.is_null() { + return ColumnarValue::Scalar(ScalarValue::Null) + .cast_to(&return_type, None); + } + + match scalar { + ScalarValue::Float64(Some(v)) => { + let result = if v == 0.0 { 0.0 } else { v.signum() }; + Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(result)))) + } + ScalarValue::Float32(Some(v)) => { + let result = if v == 0.0 { 0.0 } else { v.signum() }; + Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(result)))) + } + _ => { + internal_err!( + "Unexpected scalar type for signum: {:?}", + scalar.data_type() + ) + } + } + } + ColumnarValue::Array(array) => match array.data_type() { + Float64 => Ok(ColumnarValue::Array(Arc::new( + array.as_primitive::().unary::<_, Float64Type>( + |x: f64| { + if x == 0.0 { 0.0 } else { x.signum() } + }, + ), + ))), + Float32 => Ok(ColumnarValue::Array(Arc::new( + array.as_primitive::().unary::<_, Float32Type>( + |x: f32| { + if x == 0.0 { 0.0 } else { x.signum() } + }, + ), + ))), + other => { + internal_err!("Unsupported data type {other:?} for function signum") + } + }, + } } fn documentation(&self) -> Option<&Documentation> { @@ -106,33 +151,6 @@ impl ScalarUDFImpl for SignumFunc { } } -/// signum SQL function -fn signum(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - Float64 => Ok(Arc::new( - args[0] - .as_primitive::() - .unary::<_, Float64Type>( - |x: f64| { - if x == 0_f64 { 0_f64 } else { x.signum() } - }, - ), - ) as ArrayRef), - - Float32 => Ok(Arc::new( - args[0] - .as_primitive::() - .unary::<_, Float32Type>( - |x: f32| { - if x == 0_f32 { 0_f32 } else { x.signum() } - }, - ), - ) as ArrayRef), - - other => exec_err!("Unsupported data type {other:?} 
for function signum"), - } -} - #[cfg(test)] mod test { use std::sync::Arc; diff --git a/datafusion/functions/src/math/trunc.rs b/datafusion/functions/src/math/trunc.rs index 6727ba8fbdf0..ecdad22e8af1 100644 --- a/datafusion/functions/src/math/trunc.rs +++ b/datafusion/functions/src/math/trunc.rs @@ -24,7 +24,7 @@ use arrow::array::{ArrayRef, AsArray, PrimitiveArray}; use arrow::datatypes::DataType::{Float32, Float64}; use arrow::datatypes::{DataType, Float32Type, Float64Type, Int64Type}; use datafusion_common::ScalarValue::Int64; -use datafusion_common::{Result, exec_err}; +use datafusion_common::{Result, ScalarValue, exec_err}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ @@ -110,7 +110,50 @@ impl ScalarUDFImpl for TruncFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(trunc, vec![])(&args.args) + // Extract precision from second argument (default 0) + let precision = match args.args.get(1) { + Some(ColumnarValue::Scalar(Int64(Some(p)))) => Some(*p), + Some(ColumnarValue::Scalar(Int64(None))) => None, // null precision + Some(ColumnarValue::Array(_)) => { + // Precision is an array - use array path + return make_scalar_function(trunc, vec![])(&args.args); + } + None => Some(0), // default precision + Some(cv) => { + return exec_err!( + "trunc function requires precision to be Int64, got {:?}", + cv.data_type() + ); + } + }; + + // Scalar fast path using tuple matching for (value, precision) + match (&args.args[0], precision) { + // Null cases + (ColumnarValue::Scalar(sv), _) if sv.is_null() => { + ColumnarValue::Scalar(ScalarValue::Null).cast_to(args.return_type(), None) + } + (_, None) => { + ColumnarValue::Scalar(ScalarValue::Null).cast_to(args.return_type(), None) + } + // Scalar cases + (ColumnarValue::Scalar(ScalarValue::Float64(Some(v))), Some(p)) => Ok( + ColumnarValue::Scalar(ScalarValue::Float64(Some(if p == 0 { + 
v.trunc() + } else { + compute_truncate64(*v, p) + }))), + ), + (ColumnarValue::Scalar(ScalarValue::Float32(Some(v))), Some(p)) => Ok( + ColumnarValue::Scalar(ScalarValue::Float32(Some(if p == 0 { + v.trunc() + } else { + compute_truncate32(*v, p) + }))), + ), + // Array path for everything else + _ => make_scalar_function(trunc, vec![])(&args.args), + } } fn output_ordering(&self, input: &[ExprProperties]) -> Result { @@ -202,12 +245,12 @@ fn trunc(args: &[ArrayRef]) -> Result { fn compute_truncate32(x: f32, y: i64) -> f32 { let factor = 10.0_f32.powi(y as i32); - (x * factor).round() / factor + (x * factor).trunc() / factor } fn compute_truncate64(x: f64, y: i64) -> f64 { let factor = 10.0_f64.powi(y as i32); - (x * factor).round() / factor + (x * factor).trunc() / factor } #[cfg(test)] @@ -238,9 +281,9 @@ mod test { assert_eq!(floats.len(), 5); assert_eq!(floats.value(0), 15.0); - assert_eq!(floats.value(1), 1_234.268); + assert_eq!(floats.value(1), 1_234.267); assert_eq!(floats.value(2), 1_233.12); - assert_eq!(floats.value(3), 3.312_98); + assert_eq!(floats.value(3), 3.312_97); assert_eq!(floats.value(4), -21.123_4); } @@ -263,9 +306,9 @@ mod test { assert_eq!(floats.len(), 5); assert_eq!(floats.value(0), 5.0); - assert_eq!(floats.value(1), 234.268); + assert_eq!(floats.value(1), 234.267); assert_eq!(floats.value(2), 123.12); - assert_eq!(floats.value(3), 123.312_98); + assert_eq!(floats.value(3), 123.312_97); assert_eq!(floats.value(4), -321.123_1); } diff --git a/datafusion/functions/src/regex/regexplike.rs b/datafusion/functions/src/regex/regexplike.rs index f707c8e0d8c7..b2df38a679ae 100644 --- a/datafusion/functions/src/regex/regexplike.rs +++ b/datafusion/functions/src/regex/regexplike.rs @@ -17,7 +17,7 @@ //! 
Regex expressions -use arrow::array::{Array, ArrayRef, AsArray, GenericStringArray}; +use arrow::array::{Array, ArrayRef, AsArray, BooleanArray, GenericStringArray}; use arrow::compute::kernels::regexp; use arrow::datatypes::DataType; use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View}; @@ -31,9 +31,10 @@ use datafusion_expr::{ }; use datafusion_macros::user_doc; -use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; use datafusion_expr_common::operator::Operator; use datafusion_expr_common::type_coercion::binary::BinaryTypeCoercer; +use regex::Regex; use std::any::Any; use std::sync::Arc; @@ -130,35 +131,52 @@ impl ScalarUDFImpl for RegexpLikeFunc { args: datafusion_expr::ScalarFunctionArgs, ) -> Result { let args = &args.args; - - let len = args - .iter() - .fold(Option::::None, |acc, arg| match arg { - ColumnarValue::Scalar(_) => acc, - ColumnarValue::Array(a) => Some(a.len()), - }); - - let is_scalar = len.is_none(); - let inferred_length = len.unwrap_or(1); - let args = args - .iter() - .map(|arg| arg.to_array(inferred_length)) - .collect::>>()?; - - let result = regexp_like(&args); - if is_scalar { - // If all inputs are scalar, keeps output as scalar - let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); - result.map(ColumnarValue::Scalar) - } else { - result.map(ColumnarValue::Array) + match args.as_slice() { + [ColumnarValue::Scalar(value), ColumnarValue::Scalar(pattern)] => { + let value = scalar_string(value)?; + let pattern = scalar_string(pattern)?; + regexp_like_scalar(value, pattern, None) + } + [ + ColumnarValue::Scalar(value), + ColumnarValue::Scalar(pattern), + ColumnarValue::Scalar(flags), + ] => { + let value = scalar_string(value)?; + let pattern = scalar_string(pattern)?; + let flags = scalar_string(flags)?; + regexp_like_scalar(value, pattern, flags) + } + [ColumnarValue::Array(values), ColumnarValue::Scalar(pattern)] => { + let 
pattern = scalar_string(pattern)?; + let array = regexp_like_array_scalar(values, pattern, None)?; + Ok(ColumnarValue::Array(array)) + } + [ + ColumnarValue::Array(values), + ColumnarValue::Scalar(pattern), + ColumnarValue::Scalar(flags), + ] => { + let flags = scalar_string(flags)?; + if flags.is_some_and(|flagz| flagz.contains('g')) { + plan_err!("regexp_like() does not support the \"global\" option") + } else { + let pattern = scalar_string(pattern)?; + let array = regexp_like_array_scalar(values, pattern, flags)?; + Ok(ColumnarValue::Array(array)) + } + } + _ => { + let args = ColumnarValue::values_to_arrays(args)?; + regexp_like(&args).map(ColumnarValue::Array) + } } } fn simplify( &self, mut args: Vec, - info: &dyn SimplifyInfo, + info: &SimplifyContext, ) -> Result { // Try to simplify regexp_like usage to one of the builtin operators since those have // optimized code paths for the case where the regular expression pattern is a scalar. @@ -302,7 +320,10 @@ pub fn regexp_like(args: &[ArrayRef]) -> Result { } }; - if flags.iter().any(|s| s == Some("g")) { + if flags + .iter() + .any(|s| s.is_some_and(|flagz| flagz.contains('g'))) + { return plan_err!("regexp_like() does not support the \"global\" option"); } @@ -314,6 +335,83 @@ pub fn regexp_like(args: &[ArrayRef]) -> Result { } } +fn scalar_string(value: &ScalarValue) -> Result> { + match value.try_as_str() { + Some(v) => Ok(v), + None => internal_err!( + "Unsupported data type {:?} for function `regexp_like`", + value.data_type() + ), + } +} + +fn regexp_like_array_scalar( + values: &ArrayRef, + pattern: Option<&str>, + flags: Option<&str>, +) -> Result { + use DataType::*; + + let Some(pattern) = pattern else { + return Ok(Arc::new(BooleanArray::new_null(values.len()))); + }; + let array = match values.data_type() { + Utf8 => { + let array = values.as_string::(); + regexp::regexp_is_match_scalar(array, pattern, flags)? 
+ } + Utf8View => { + let array = values.as_string_view(); + regexp::regexp_is_match_scalar(array, pattern, flags)? + } + LargeUtf8 => { + let array = values.as_string::(); + regexp::regexp_is_match_scalar(array, pattern, flags)? + } + other => { + return internal_err!( + "Unsupported data type {other:?} for function `regexp_like`" + ); + } + }; + + Ok(Arc::new(array)) +} + +fn regexp_like_scalar( + value: Option<&str>, + pattern: Option<&str>, + flags: Option<&str>, +) -> Result { + if flags.is_some_and(|flagz| flagz.contains('g')) { + return plan_err!("regexp_like() does not support the \"global\" option"); + } + + if value.is_none() || pattern.is_none() { + return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))); + } + + let value = value.unwrap(); + let pattern = pattern.unwrap(); + let pattern = match flags { + Some(flagz) => format!("(?{flagz}){pattern}"), + None => pattern.to_string(), + }; + + let result = if pattern.is_empty() { + true + } else { + let re = Regex::new(pattern.as_str()).map_err(|e| { + datafusion_common::DataFusionError::Execution(format!( + "Regular expression did not compile: {e:?}" + )) + })?; + re.is_match(value) + }; + + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(result)))) +} + fn handle_regexp_like( values: &ArrayRef, patterns: &ArrayRef, @@ -356,7 +454,7 @@ fn handle_regexp_like( .map_err(|e| arrow_datafusion_err!(e))? 
} (Utf8, LargeUtf8) => { - let value = values.as_string_view(); + let value = values.as_string::(); let pattern = patterns.as_string::(); regexp::regexp_is_match(value, pattern, flags) @@ -399,8 +497,37 @@ mod tests { use arrow::array::StringArray; use arrow::array::{BooleanBuilder, StringViewArray}; + use arrow::datatypes::{DataType, Field}; + use datafusion_common::config::ConfigOptions; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; - use crate::regex::regexplike::regexp_like; + use crate::regex::regexplike::{RegexpLikeFunc, regexp_like}; + + fn invoke_regexp_like(args: Vec) -> Result { + let number_rows = args + .iter() + .find_map(|arg| match arg { + ColumnarValue::Array(array) => Some(array.len()), + _ => None, + }) + .unwrap_or(1); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Arc::new(Field::new(format!("arg_{idx}"), arg.data_type(), true)) + }) + .collect::>(); + + RegexpLikeFunc::new().invoke_with_args(ScalarFunctionArgs { + args, + arg_fields, + number_rows, + return_field: Arc::new(Field::new("f", DataType::Boolean, true)), + config_options: Arc::new(ConfigOptions::default()), + }) + } #[test] fn test_case_sensitive_regexp_like_utf8() { @@ -499,4 +626,66 @@ mod tests { "Error during planning: regexp_like() does not support the \"global\" option" ); } + + #[test] + fn test_regexp_like_scalar_invoke() { + let args = vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("foobarbequebaz".to_string()))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("(bar)(beque)".to_string()))), + ]; + let result = invoke_regexp_like(args).unwrap(); + match result { + ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))) => {} + other => panic!("Unexpected result {other:?}"), + } + } + + #[test] + fn test_regexp_like_array_scalar_invoke() { + let values = Arc::new(StringArray::from(vec!["abc", "xyz"])); + let args = vec![ + ColumnarValue::Array(values), + 
ColumnarValue::Scalar(ScalarValue::Utf8(Some("^(a)".to_string()))), + ]; + let result = invoke_regexp_like(args).unwrap(); + let mut expected_builder = BooleanBuilder::new(); + expected_builder.append_value(true); + expected_builder.append_value(false); + let expected = expected_builder.finish(); + match result { + ColumnarValue::Array(array) => { + assert_eq!(array.as_ref(), &expected); + } + other => panic!("Unexpected result {other:?}"), + } + } + + #[test] + fn test_regexp_like_scalar_flags_with_global() { + let args = vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("abc".to_string()))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("^(a)".to_string()))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("ig".to_string()))), + ]; + let err = invoke_regexp_like(args).expect_err("global flag should be rejected"); + assert_eq!( + err.strip_backtrace(), + "Error during planning: regexp_like() does not support the \"global\" option" + ); + } + + #[test] + fn test_regexp_like_array_scalar_flags_with_global() { + let values = Arc::new(StringArray::from(vec!["abc", "xyz"])); + let args = vec![ + ColumnarValue::Array(values), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("^(a)".to_string()))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("ig".to_string()))), + ]; + let err = invoke_regexp_like(args).expect_err("global flag should be rejected"); + assert_eq!( + err.strip_backtrace(), + "Error during planning: regexp_like() does not support the \"global\" option" + ); + } } diff --git a/datafusion/functions/src/regex/regexpreplace.rs b/datafusion/functions/src/regex/regexpreplace.rs index b5ab46f0ec71..68e324e21c89 100644 --- a/datafusion/functions/src/regex/regexpreplace.rs +++ b/datafusion/functions/src/regex/regexpreplace.rs @@ -189,13 +189,19 @@ fn regexp_replace_func(args: &[ColumnarValue]) -> Result { } } -/// replace POSIX capture groups (like \1) with Rust Regex group (like ${1}) +/// replace POSIX capture groups (like \1 or \\1) with Rust Regex group (like 
${1}) /// used by regexp_replace +/// Handles both single backslash (\1) and double backslash (\\1) which can occur +/// when SQL strings with escaped backslashes are passed through +/// +/// Note: \0 is converted to ${0}, which in Rust's regex replacement syntax +/// substitutes the entire match. This is consistent with POSIX behavior where +/// \0 (or &) refers to the entire matched string. fn regex_replace_posix_groups(replacement: &str) -> String { static CAPTURE_GROUPS_RE_LOCK: LazyLock = - LazyLock::new(|| Regex::new(r"(\\)(\d*)").unwrap()); + LazyLock::new(|| Regex::new(r"\\{1,2}(\d+)").unwrap()); CAPTURE_GROUPS_RE_LOCK - .replace_all(replacement, "$${$2}") + .replace_all(replacement, "$${$1}") .into_owned() } @@ -659,6 +665,42 @@ mod tests { use super::*; + #[test] + fn test_regex_replace_posix_groups() { + // Test that \1, \2, etc. are replaced with ${1}, ${2}, etc. + assert_eq!(regex_replace_posix_groups(r"\1"), "${1}"); + assert_eq!(regex_replace_posix_groups(r"\12"), "${12}"); + assert_eq!(regex_replace_posix_groups(r"X\1Y"), "X${1}Y"); + assert_eq!(regex_replace_posix_groups(r"\1\2"), "${1}${2}"); + + // Test double backslash (from SQL escaped strings like '\\1') + assert_eq!(regex_replace_posix_groups(r"\\1"), "${1}"); + assert_eq!(regex_replace_posix_groups(r"X\\1Y"), "X${1}Y"); + assert_eq!(regex_replace_posix_groups(r"\\1\\2"), "${1}${2}"); + + // Test 3 or 4 backslashes before digits to document expected behavior + assert_eq!(regex_replace_posix_groups(r"\\\1"), r"\${1}"); + assert_eq!(regex_replace_posix_groups(r"\\\\1"), r"\\${1}"); + assert_eq!(regex_replace_posix_groups(r"\\\1\\\\2"), r"\${1}\\${2}"); + + // Test that a lone backslash is NOT replaced (requires at least one digit) + assert_eq!(regex_replace_posix_groups(r"\"), r"\"); + assert_eq!(regex_replace_posix_groups(r"foo\bar"), r"foo\bar"); + + // Test that backslash followed by non-digit is preserved + assert_eq!(regex_replace_posix_groups(r"\n"), r"\n"); + 
assert_eq!(regex_replace_posix_groups(r"\t"), r"\t"); + + // Test \0 behavior: \0 is converted to ${0}, which in Rust's regex + // replacement syntax substitutes the entire match. This is consistent + // with POSIX behavior where \0 (or &) refers to the entire matched string. + assert_eq!(regex_replace_posix_groups(r"\0"), "${0}"); + assert_eq!( + regex_replace_posix_groups(r"prefix\0suffix"), + "prefix${0}suffix" + ); + } + macro_rules! static_pattern_regexp_replace { ($name:ident, $T:ty, $O:ty) => { #[test] diff --git a/datafusion/functions/src/string/ascii.rs b/datafusion/functions/src/string/ascii.rs index fe3c508edea0..bfd035ed3c0d 100644 --- a/datafusion/functions/src/string/ascii.rs +++ b/datafusion/functions/src/string/ascii.rs @@ -15,12 +15,12 @@ // specific language governing permissions and limitations // under the License. -use crate::utils::make_scalar_function; use arrow::array::{ArrayRef, AsArray, Int32Array, StringArrayType}; use arrow::datatypes::DataType; use arrow::error::ArrowError; use datafusion_common::types::logical_string; -use datafusion_common::{Result, internal_err}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, ScalarValue, internal_err}; use datafusion_expr::{ColumnarValue, Documentation, TypeSignatureClass}; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use datafusion_expr_common::signature::Coercion; @@ -91,7 +91,31 @@ impl ScalarUDFImpl for AsciiFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(ascii, vec![])(&args.args) + let [arg] = take_function_args(self.name(), args.args)?; + + match arg { + ColumnarValue::Scalar(scalar) => { + if scalar.is_null() { + return Ok(ColumnarValue::Scalar(ScalarValue::Int32(None))); + } + + match scalar { + ScalarValue::Utf8(Some(s)) + | ScalarValue::LargeUtf8(Some(s)) + | ScalarValue::Utf8View(Some(s)) => { + let result = s.chars().next().map_or(0, |c| c as i32); + 
Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(result)))) + } + _ => { + internal_err!( + "Unexpected data type {:?} for function ascii", + scalar.data_type() + ) + } + } + } + ColumnarValue::Array(array) => Ok(ColumnarValue::Array(ascii(&[array])?)), + } } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions/src/string/btrim.rs b/datafusion/functions/src/string/btrim.rs index 3ca5db3c49a8..beea527f6d0b 100644 --- a/datafusion/functions/src/string/btrim.rs +++ b/datafusion/functions/src/string/btrim.rs @@ -30,7 +30,7 @@ use datafusion_macros::user_doc; use std::any::Any; use std::sync::Arc; -/// Returns the longest string with leading and trailing characters removed. If the characters are not specified, whitespace is removed. +/// Returns the longest string with leading and trailing characters removed. If the characters are not specified, spaces are removed. /// btrim('xyxtrimyyx', 'xyz') = 'trim' fn btrim(args: &[ArrayRef]) -> Result { let use_string_view = args[0].data_type() == &DataType::Utf8View; @@ -45,7 +45,7 @@ fn btrim(args: &[ArrayRef]) -> Result { #[user_doc( doc_section(label = "String Functions"), - description = "Trims the specified trim string from the start and end of a string. If no trim string is provided, all whitespace is removed from the start and end of the input string.", + description = "Trims the specified trim string from the start and end of a string. If no trim string is provided, all spaces are removed from the start and end of the input string.", syntax_example = "btrim(str[, trim_str])", sql_example = r#"```sql > select btrim('__datafusion____', '_'); @@ -58,7 +58,7 @@ fn btrim(args: &[ArrayRef]) -> Result { standard_argument(name = "str", prefix = "String"), argument( name = "trim_str", - description = r"String expression to operate on. Can be a constant, column, or function, and any combination of operators. _Default is whitespace characters._" + description = r"String expression to operate on. 
Can be a constant, column, or function, and any combination of operators. _Default is a space._" ), alternative_syntax = "trim(BOTH trim_str FROM str)", alternative_syntax = "trim(trim_str FROM str)", diff --git a/datafusion/functions/src/string/chr.rs b/datafusion/functions/src/string/chr.rs index ba011b94367e..2f432c838e01 100644 --- a/datafusion/functions/src/string/chr.rs +++ b/datafusion/functions/src/string/chr.rs @@ -18,24 +18,21 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::ArrayRef; -use arrow::array::GenericStringBuilder; +use arrow::array::{ArrayRef, GenericStringBuilder, Int64Array}; use arrow::datatypes::DataType; use arrow::datatypes::DataType::Int64; use arrow::datatypes::DataType::Utf8; -use crate::utils::make_scalar_function; use datafusion_common::cast::as_int64_array; -use datafusion_common::{Result, exec_err}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, ScalarValue, exec_err, internal_err}; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; /// Returns the character with the given code. 
/// chr(65) = 'A' -fn chr(args: &[ArrayRef]) -> Result { - let integer_array = as_int64_array(&args[0])?; - +fn chr_array(integer_array: &Int64Array) -> Result { let mut builder = GenericStringBuilder::::with_capacity( integer_array.len(), // 1 byte per character, assuming that is the common case @@ -56,15 +53,11 @@ fn chr(args: &[ArrayRef]) -> Result { return exec_err!("invalid Unicode scalar value: {integer}"); } - None => { - builder.append_null(); - } + None => builder.append_null(), } } - let result = builder.finish(); - - Ok(Arc::new(result) as ArrayRef) + Ok(Arc::new(builder.finish()) as ArrayRef) } #[user_doc( @@ -119,7 +112,32 @@ impl ScalarUDFImpl for ChrFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(chr, vec![])(&args.args) + let [arg] = take_function_args(self.name(), args.args)?; + + match arg { + ColumnarValue::Scalar(ScalarValue::Int64(Some(code_point))) => { + if let Ok(u) = u32::try_from(code_point) + && let Some(c) = core::char::from_u32(u) + { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( + c.to_string(), + )))) + } else { + exec_err!("invalid Unicode scalar value: {code_point}") + } + } + ColumnarValue::Scalar(ScalarValue::Int64(None)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))) + } + ColumnarValue::Array(array) => { + let integer_array = as_int64_array(&array)?; + Ok(ColumnarValue::Array(chr_array(integer_array)?)) + } + other => internal_err!( + "Unexpected data type {:?} for function chr", + other.data_type() + ), + } } fn documentation(&self) -> Option<&Documentation> { @@ -130,13 +148,27 @@ impl ScalarUDFImpl for ChrFunc { #[cfg(test)] mod tests { use super::*; + use arrow::array::{Array, Int64Array, StringArray}; + use arrow::datatypes::Field; use datafusion_common::assert_contains; + use datafusion_common::config::ConfigOptions; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; + + fn invoke_chr(arg: ColumnarValue, number_rows: usize) -> Result { 
+ ChrFunc::new().invoke_with_args(ScalarFunctionArgs { + args: vec![arg], + arg_fields: vec![Field::new("a", Int64, true).into()], + number_rows, + return_field: Field::new("f", Utf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }) + } #[test] fn test_chr_normal() { let input = Arc::new(Int64Array::from(vec![ - Some(0), // null + Some(0), // \u{0000} Some(65), // A Some(66), // B Some(67), // C @@ -149,8 +181,13 @@ mod tests { Some(9), // tab Some(0x10FFFF), // 0x10FFFF, the largest Unicode code point ])); - let result = chr(&[input]).unwrap(); - let string_array = result.as_any().downcast_ref::().unwrap(); + + let result = invoke_chr(ColumnarValue::Array(input), 12).unwrap(); + let ColumnarValue::Array(arr) = result else { + panic!("Expected array"); + }; + let string_array = arr.as_any().downcast_ref::().unwrap(); + let expected = [ "\u{0000}", "A", @@ -174,55 +211,48 @@ mod tests { #[test] fn test_chr_error() { - // invalid Unicode code points (too large) let input = Arc::new(Int64Array::from(vec![i64::MAX])); - let result = chr(&[input]); + let result = invoke_chr(ColumnarValue::Array(input), 1); assert!(result.is_err()); assert_contains!( result.err().unwrap().to_string(), "invalid Unicode scalar value: 9223372036854775807" ); - // invalid Unicode code points (too large) case 2 let input = Arc::new(Int64Array::from(vec![0x10FFFF + 1])); - let result = chr(&[input]); + let result = invoke_chr(ColumnarValue::Array(input), 1); assert!(result.is_err()); assert_contains!( result.err().unwrap().to_string(), "invalid Unicode scalar value: 1114112" ); - // invalid Unicode code points (surrogate code point) - // link: let input = Arc::new(Int64Array::from(vec![0xD800 + 1])); - let result = chr(&[input]); + let result = invoke_chr(ColumnarValue::Array(input), 1); assert!(result.is_err()); assert_contains!( result.err().unwrap().to_string(), "invalid Unicode scalar value: 55297" ); - // negative input - let input = 
Arc::new(Int64Array::from(vec![i64::MIN + 2i64])); // will be 2 if cast to u32 - let result = chr(&[input]); + let input = Arc::new(Int64Array::from(vec![i64::MIN + 2i64])); + let result = invoke_chr(ColumnarValue::Array(input), 1); assert!(result.is_err()); assert_contains!( result.err().unwrap().to_string(), "invalid Unicode scalar value: -9223372036854775806" ); - // negative input case 2 let input = Arc::new(Int64Array::from(vec![-1])); - let result = chr(&[input]); + let result = invoke_chr(ColumnarValue::Array(input), 1); assert!(result.is_err()); assert_contains!( result.err().unwrap().to_string(), "invalid Unicode scalar value: -1" ); - // one error with valid values after - let input = Arc::new(Int64Array::from(vec![65, -1, 66])); // A, -1, B - let result = chr(&[input]); + let input = Arc::new(Int64Array::from(vec![65, -1, 66])); + let result = invoke_chr(ColumnarValue::Array(input), 3); assert!(result.is_err()); assert_contains!( result.err().unwrap().to_string(), @@ -232,10 +262,36 @@ mod tests { #[test] fn test_chr_empty() { - // empty input array let input = Arc::new(Int64Array::from(Vec::::new())); - let result = chr(&[input]).unwrap(); - let string_array = result.as_any().downcast_ref::().unwrap(); + let result = invoke_chr(ColumnarValue::Array(input), 0).unwrap(); + let ColumnarValue::Array(arr) = result else { + panic!("Expected array"); + }; + let string_array = arr.as_any().downcast_ref::().unwrap(); assert_eq!(string_array.len(), 0); } + + #[test] + fn test_chr_scalar() { + let result = + invoke_chr(ColumnarValue::Scalar(ScalarValue::Int64(Some(65))), 1).unwrap(); + + match result { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => { + assert_eq!(s, "A"); + } + other => panic!("Unexpected result: {other:?}"), + } + } + + #[test] + fn test_chr_scalar_null() { + let result = + invoke_chr(ColumnarValue::Scalar(ScalarValue::Int64(None)), 1).unwrap(); + + match result { + ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {} + other => 
panic!("Unexpected result: {other:?}"), + } + } } diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs index 4a775c2744ea..77af82e25c48 100644 --- a/datafusion/functions/src/string/common.rs +++ b/datafusion/functions/src/string/common.rs @@ -38,6 +38,22 @@ use datafusion_expr::ColumnarValue; /// from the beginning of the input string where the trimmed result starts. pub(crate) trait Trimmer { fn trim<'a>(input: &'a str, pattern: &[char]) -> (&'a str, u32); + + /// Optimized trim for a single ASCII byte. + /// Uses byte-level scanning instead of char-level iteration. + fn trim_ascii_char(input: &str, byte: u8) -> (&str, u32); +} + +/// Returns the number of leading bytes matching `byte` +#[inline] +fn leading_bytes(bytes: &[u8], byte: u8) -> usize { + bytes.iter().take_while(|&&b| b == byte).count() +} + +/// Returns the number of trailing bytes matching `byte` +#[inline] +fn trailing_bytes(bytes: &[u8], byte: u8) -> usize { + bytes.iter().rev().take_while(|&&b| b == byte).count() } /// Left trim - removes leading characters @@ -46,10 +62,19 @@ pub(crate) struct TrimLeft; impl Trimmer for TrimLeft { #[inline] fn trim<'a>(input: &'a str, pattern: &[char]) -> (&'a str, u32) { + if pattern.len() == 1 && pattern[0].is_ascii() { + return Self::trim_ascii_char(input, pattern[0] as u8); + } let trimmed = input.trim_start_matches(pattern); let offset = (input.len() - trimmed.len()) as u32; (trimmed, offset) } + + #[inline] + fn trim_ascii_char(input: &str, byte: u8) -> (&str, u32) { + let start = leading_bytes(input.as_bytes(), byte); + (&input[start..], start as u32) + } } /// Right trim - removes trailing characters @@ -58,9 +83,19 @@ pub(crate) struct TrimRight; impl Trimmer for TrimRight { #[inline] fn trim<'a>(input: &'a str, pattern: &[char]) -> (&'a str, u32) { + if pattern.len() == 1 && pattern[0].is_ascii() { + return Self::trim_ascii_char(input, pattern[0] as u8); + } let trimmed = input.trim_end_matches(pattern); 
(trimmed, 0) } + + #[inline] + fn trim_ascii_char(input: &str, byte: u8) -> (&str, u32) { + let bytes = input.as_bytes(); + let end = bytes.len() - trailing_bytes(bytes, byte); + (&input[..end], 0) + } } /// Both trim - removes both leading and trailing characters @@ -69,11 +104,22 @@ pub(crate) struct TrimBoth; impl Trimmer for TrimBoth { #[inline] fn trim<'a>(input: &'a str, pattern: &[char]) -> (&'a str, u32) { + if pattern.len() == 1 && pattern[0].is_ascii() { + return Self::trim_ascii_char(input, pattern[0] as u8); + } let left_trimmed = input.trim_start_matches(pattern); let offset = (input.len() - left_trimmed.len()) as u32; let trimmed = left_trimmed.trim_end_matches(pattern); (trimmed, offset) } + + #[inline] + fn trim_ascii_char(input: &str, byte: u8) -> (&str, u32) { + let bytes = input.as_bytes(); + let start = leading_bytes(bytes, byte); + let end = bytes.len() - trailing_bytes(&bytes[start..], byte); + (&input[start..end], start as u32) + } } pub(crate) fn general_trim( @@ -99,19 +145,24 @@ fn string_view_trim(args: &[ArrayRef]) -> Result { match args.len() { 1 => { - // Default whitespace trim - pattern is just space - let pattern = [' ']; + // Trim spaces by default for (src_str_opt, raw_view) in string_view_array .iter() .zip(string_view_array.views().iter()) { - trim_and_append_view::( - src_str_opt, - &pattern, - &mut views_buf, - &mut null_builder, - raw_view, - ); + if let Some(src_str) = src_str_opt { + let (trimmed, offset) = Tr::trim_ascii_char(src_str, b' '); + make_and_append_view( + &mut views_buf, + &mut null_builder, + raw_view, + trimmed, + offset, + ); + } else { + null_builder.append_null(); + views_buf.push(0); + } } } 2 => { @@ -141,6 +192,7 @@ fn string_view_trim(args: &[ArrayRef]) -> Result { } } else { // Per-row pattern - must compute pattern chars for each row + let mut pattern: Vec = Vec::new(); for ((src_str_opt, raw_view), characters_opt) in string_view_array .iter() .zip(string_view_array.views().iter()) @@ -149,7 +201,8 
@@ fn string_view_trim(args: &[ArrayRef]) -> Result { if let (Some(src_str), Some(characters)) = (src_str_opt, characters_opt) { - let pattern: Vec = characters.chars().collect(); + pattern.clear(); + pattern.extend(characters.chars()); let (trimmed, offset) = Tr::trim(src_str, &pattern); make_and_append_view( &mut views_buf, @@ -225,11 +278,10 @@ fn string_trim(args: &[ArrayRef]) -> Result { - // Default whitespace trim - pattern is just space - let pattern = [' ']; + // Trim spaces by default let result = string_array .iter() - .map(|string| string.map(|s| Tr::trim(s, &pattern).0)) + .map(|string| string.map(|s| Tr::trim_ascii_char(s, b' ').0)) .collect::>(); Ok(Arc::new(result) as ArrayRef) @@ -255,12 +307,14 @@ fn string_trim(args: &[ArrayRef]) -> Result = Vec::new(); let result = string_array .iter() .zip(characters_array.iter()) .map(|(string, characters)| match (string, characters) { (Some(s), Some(c)) => { - let pattern: Vec = c.chars().collect(); + pattern.clear(); + pattern.extend(c.chars()); Some(Tr::trim(s, &pattern).0) } _ => None, diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index 42d455a05760..e67454125328 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -28,7 +28,7 @@ use crate::strings::{ use datafusion_common::cast::{as_string_array, as_string_view_array}; use datafusion_common::{Result, ScalarValue, internal_err, plan_err}; use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; use datafusion_expr::{ColumnarValue, Documentation, Expr, Volatility, lit}; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; @@ -120,24 +120,21 @@ impl ScalarUDFImpl for ConcatFunc { } }); - let array_len = args - .iter() - .filter_map(|x| match x { - ColumnarValue::Array(array) => 
Some(array.len()), - _ => None, - }) - .next(); + let array_len = args.iter().find_map(|x| match x { + ColumnarValue::Array(array) => Some(array.len()), + _ => None, + }); // Scalar if array_len.is_none() { - let mut result = String::new(); - for arg in args { + let mut values = Vec::with_capacity(args.len()); + for arg in &args { let ColumnarValue::Scalar(scalar) = arg else { return internal_err!("concat expected scalar value, got {arg:?}"); }; match scalar.try_as_str() { - Some(Some(v)) => result.push_str(v), + Some(Some(v)) => values.push(v), Some(None) => {} // null literal None => plan_err!( "Concat function does not support scalar type {}", @@ -145,6 +142,7 @@ impl ScalarUDFImpl for ConcatFunc { )?, } } + let result = values.concat(); return match return_datatype { DataType::Utf8View => { @@ -206,7 +204,9 @@ impl ScalarUDFImpl for ConcatFunc { DataType::Utf8View => { let string_array = as_string_view_array(array)?; - data_size += string_array.len(); + // This is an estimate; in particular, it will + // undercount arrays of short strings (<= 12 bytes). + data_size += string_array.total_buffer_bytes_used(); let column = if array.is_nullable() { ColumnarValueRef::NullableStringViewArray(string_array) } else { @@ -277,7 +277,7 @@ impl ScalarUDFImpl for ConcatFunc { fn simplify( &self, args: Vec, - _info: &dyn SimplifyInfo, + _info: &SimplifyContext, ) -> Result { simplify_concat(args) } diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index 8fe095c5ce2b..9d3b32eedf8f 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-use arrow::array::{Array, StringArray, as_largestring_array}; +use arrow::array::Array; use std::any::Any; use std::sync::Arc; @@ -25,10 +25,12 @@ use crate::string::concat; use crate::string::concat::simplify_concat; use crate::string::concat_ws; use crate::strings::{ColumnarValueRef, StringArrayBuilder}; -use datafusion_common::cast::{as_string_array, as_string_view_array}; +use datafusion_common::cast::{ + as_large_string_array, as_string_array, as_string_view_array, +}; use datafusion_common::{Result, ScalarValue, exec_err, internal_err, plan_err}; use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; use datafusion_expr::{ColumnarValue, Documentation, Expr, Volatility, lit}; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; @@ -105,7 +107,6 @@ impl ScalarUDFImpl for ConcatWsFunc { fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let ScalarFunctionArgs { args, .. } = args; - // do not accept 0 arguments. if args.len() < 2 { return exec_err!( "concat_ws was called with {} arguments. 
It requires at least 2.", @@ -113,18 +114,14 @@ impl ScalarUDFImpl for ConcatWsFunc { ); } - let array_len = args - .iter() - .filter_map(|x| match x { - ColumnarValue::Array(array) => Some(array.len()), - _ => None, - }) - .next(); + let array_len = args.iter().find_map(|x| match x { + ColumnarValue::Array(array) => Some(array.len()), + _ => None, + }); // Scalar if array_len.is_none() { let ColumnarValue::Scalar(scalar) = &args[0] else { - // loop above checks for all args being scalar unreachable!() }; let sep = match scalar.try_as_str() { @@ -136,43 +133,21 @@ impl ScalarUDFImpl for ConcatWsFunc { None => return internal_err!("Expected string literal, got {scalar:?}"), }; - let mut result = String::new(); - // iterator over Option - let iter = &mut args[1..].iter().map(|arg| { + let mut values = Vec::with_capacity(args.len() - 1); + for arg in &args[1..] { let ColumnarValue::Scalar(scalar) = arg else { - // loop above checks for all args being scalar unreachable!() }; - scalar.try_as_str() - }); - - // append first non null arg - for scalar in iter.by_ref() { - match scalar { - Some(Some(s)) => { - result.push_str(s); - break; - } - Some(None) => {} // null literal string - None => { - return internal_err!("Expected string literal, got {scalar:?}"); - } - } - } - // handle subsequent non null args - for scalar in iter.by_ref() { - match scalar { - Some(Some(s)) => { - result.push_str(sep); - result.push_str(s); - } + match scalar.try_as_str() { + Some(Some(v)) => values.push(v), Some(None) => {} // null literal string None => { return internal_err!("Expected string literal, got {scalar:?}"); } } } + let result = values.join(sep); return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result)))); } @@ -183,23 +158,53 @@ impl ScalarUDFImpl for ConcatWsFunc { // parse sep let sep = match &args[0] { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => { - data_size += s.len() * len * (args.len() - 2); // estimate - ColumnarValueRef::Scalar(s.as_bytes()) - } - 
ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { - return Ok(ColumnarValue::Array(Arc::new(StringArray::new_null(len)))); - } - ColumnarValue::Array(array) => { - let string_array = as_string_array(array)?; - data_size += string_array.values().len() * (args.len() - 2); // estimate - if array.is_nullable() { - ColumnarValueRef::NullableArray(string_array) - } else { - ColumnarValueRef::NonNullableArray(string_array) + ColumnarValue::Scalar(scalar) => match scalar.try_as_str() { + Some(Some(s)) => { + data_size += s.len() * len * (args.len() - 2); // estimate + ColumnarValueRef::Scalar(s.as_bytes()) } - } - _ => unreachable!("concat ws"), + Some(None) => { + return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); + } + None => { + return internal_err!("Expected string separator, got {scalar:?}"); + } + }, + ColumnarValue::Array(array) => match array.data_type() { + DataType::Utf8 => { + let string_array = as_string_array(array)?; + data_size += string_array.values().len() * (args.len() - 2); + if array.is_nullable() { + ColumnarValueRef::NullableArray(string_array) + } else { + ColumnarValueRef::NonNullableArray(string_array) + } + } + DataType::LargeUtf8 => { + let string_array = as_large_string_array(array)?; + data_size += string_array.values().len() * (args.len() - 2); + if array.is_nullable() { + ColumnarValueRef::NullableLargeStringArray(string_array) + } else { + ColumnarValueRef::NonNullableLargeStringArray(string_array) + } + } + DataType::Utf8View => { + let string_array = as_string_view_array(array)?; + data_size += + string_array.total_buffer_bytes_used() * (args.len() - 2); + if array.is_nullable() { + ColumnarValueRef::NullableStringViewArray(string_array) + } else { + ColumnarValueRef::NonNullableStringViewArray(string_array) + } + } + other => { + return plan_err!( + "Input was {other} which is not a supported datatype for concat_ws separator" + ); + } + }, }; let mut columns = Vec::with_capacity(args.len() - 1); @@ -227,7 +232,7 @@ impl 
ScalarUDFImpl for ConcatWsFunc { columns.push(column); } DataType::LargeUtf8 => { - let string_array = as_largestring_array(array); + let string_array = as_large_string_array(array)?; data_size += string_array.values().len(); let column = if array.is_nullable() { @@ -242,11 +247,9 @@ impl ScalarUDFImpl for ConcatWsFunc { DataType::Utf8View => { let string_array = as_string_view_array(array)?; - data_size += string_array - .data_buffers() - .iter() - .map(|buf| buf.len()) - .sum::(); + // This is an estimate; in particular, it will + // undercount arrays of short strings (<= 12 bytes). + data_size += string_array.total_buffer_bytes_used(); let column = if array.is_nullable() { ColumnarValueRef::NullableStringViewArray(string_array) } else { @@ -272,18 +275,14 @@ impl ScalarUDFImpl for ConcatWsFunc { continue; } - let mut iter = columns.iter(); - for column in iter.by_ref() { + let mut first = true; + for column in &columns { if column.is_valid(i) { + if !first { + builder.write::(&sep, i); + } builder.write::(column, i); - break; - } - } - - for column in iter { - if column.is_valid(i) { - builder.write::(&sep, i); - builder.write::(column, i); + first = false; } } @@ -301,7 +300,7 @@ impl ScalarUDFImpl for ConcatWsFunc { fn simplify( &self, args: Vec, - _info: &dyn SimplifyInfo, + _info: &SimplifyContext, ) -> Result { match &args[..] { [delimiter, vals @ ..] 
=> simplify_concat_ws(delimiter, vals), @@ -567,4 +566,78 @@ mod tests { Ok(()) } + + #[test] + fn concat_ws_utf8view_scalar_separator() -> Result<()> { + let c0 = ColumnarValue::Scalar(ScalarValue::Utf8View(Some(",".to_string()))); + let c1 = + ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"]))); + let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("x"), + None, + Some("z"), + ]))); + + let arg_fields = vec![ + Field::new("a", Utf8, true).into(), + Field::new("a", Utf8, true).into(), + Field::new("a", Utf8, true).into(), + ]; + let args = ScalarFunctionArgs { + args: vec![c0, c1, c2], + arg_fields, + number_rows: 3, + return_field: Field::new("f", Utf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + + let result = ConcatWsFunc::new().invoke_with_args(args)?; + let expected = + Arc::new(StringArray::from(vec!["foo,x", "bar", "baz,z"])) as ArrayRef; + match &result { + ColumnarValue::Array(array) => { + assert_eq!(&expected, array); + } + _ => panic!("Expected array result"), + } + + Ok(()) + } + + #[test] + fn concat_ws_largeutf8_scalar_separator() -> Result<()> { + let c0 = ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(",".to_string()))); + let c1 = + ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"]))); + let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("x"), + None, + Some("z"), + ]))); + + let arg_fields = vec![ + Field::new("a", Utf8, true).into(), + Field::new("a", Utf8, true).into(), + Field::new("a", Utf8, true).into(), + ]; + let args = ScalarFunctionArgs { + args: vec![c0, c1, c2], + arg_fields, + number_rows: 3, + return_field: Field::new("f", Utf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + + let result = ConcatWsFunc::new().invoke_with_args(args)?; + let expected = + Arc::new(StringArray::from(vec!["foo,x", "bar", "baz,z"])) as ArrayRef; + match &result { + ColumnarValue::Array(array) => { + 
assert_eq!(&expected, array); + } + _ => panic!("Expected array result"), + } + + Ok(()) + } } diff --git a/datafusion/functions/src/string/ltrim.rs b/datafusion/functions/src/string/ltrim.rs index abdf83e2d781..f84b273b8d6b 100644 --- a/datafusion/functions/src/string/ltrim.rs +++ b/datafusion/functions/src/string/ltrim.rs @@ -31,7 +31,7 @@ use datafusion_expr::{ }; use datafusion_macros::user_doc; -/// Returns the longest string with leading characters removed. If the characters are not specified, whitespace is removed. +/// Returns the longest string with leading characters removed. If the characters are not specified, spaces are removed. /// ltrim('zzzytest', 'xyz') = 'test' fn ltrim(args: &[ArrayRef]) -> Result { let use_string_view = args[0].data_type() == &DataType::Utf8View; @@ -46,7 +46,7 @@ fn ltrim(args: &[ArrayRef]) -> Result { #[user_doc( doc_section(label = "String Functions"), - description = "Trims the specified trim string from the beginning of a string. If no trim string is provided, all whitespace is removed from the start of the input string.", + description = "Trims the specified trim string from the beginning of a string. If no trim string is provided, spaces are removed from the start of the input string.", syntax_example = "ltrim(str[, trim_str])", sql_example = r#"```sql > select ltrim(' datafusion '); @@ -65,7 +65,7 @@ fn ltrim(args: &[ArrayRef]) -> Result { standard_argument(name = "str", prefix = "String"), argument( name = "trim_str", - description = r"String expression to trim from the beginning of the input string. Can be a constant, column, or function, and any combination of arithmetic operators. _Default is whitespace characters._" + description = r"String expression to trim from the beginning of the input string. Can be a constant, column, or function, and any combination of arithmetic operators. 
_Default is a space._" ), alternative_syntax = "trim(LEADING trim_str FROM str)", related_udf(name = "btrim"), diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs index 2ca5e190c6e0..65f320c4f9f1 100644 --- a/datafusion/functions/src/string/repeat.rs +++ b/datafusion/functions/src/string/repeat.rs @@ -18,16 +18,17 @@ use std::any::Any; use std::sync::Arc; -use crate::utils::{make_scalar_function, utf8_to_str_type}; +use crate::utils::utf8_to_str_type; use arrow::array::{ - ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array, + Array, ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array, OffsetSizeTrait, StringArrayType, StringViewArray, }; use arrow::datatypes::DataType; use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View}; use datafusion_common::cast::as_int64_array; use datafusion_common::types::{NativeType, logical_int64, logical_string}; -use datafusion_common::{DataFusionError, Result, exec_err}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{DataFusionError, Result, ScalarValue, exec_err, internal_err}; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; @@ -99,7 +100,63 @@ impl ScalarUDFImpl for RepeatFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(repeat, vec![])(&args.args) + let return_type = args.return_field.data_type().clone(); + let [string_arg, count_arg] = take_function_args(self.name(), args.args)?; + + // Early return if either argument is a scalar null + if let ColumnarValue::Scalar(s) = &string_arg + && s.is_null() + { + return Ok(ColumnarValue::Scalar(ScalarValue::try_from(&return_type)?)); + } + if let ColumnarValue::Scalar(c) = &count_arg + && c.is_null() + { + return 
Ok(ColumnarValue::Scalar(ScalarValue::try_from(&return_type)?)); + } + + match (&string_arg, &count_arg) { + ( + ColumnarValue::Scalar(string_scalar), + ColumnarValue::Scalar(count_scalar), + ) => { + let count = match count_scalar { + ScalarValue::Int64(Some(n)) => *n, + _ => { + return internal_err!( + "Unexpected data type {:?} for repeat count", + count_scalar.data_type() + ); + } + }; + + let result = match string_scalar { + ScalarValue::Utf8(Some(s)) | ScalarValue::Utf8View(Some(s)) => { + ScalarValue::Utf8(Some(compute_repeat( + s, + count, + i32::MAX as usize, + )?)) + } + ScalarValue::LargeUtf8(Some(s)) => ScalarValue::LargeUtf8(Some( + compute_repeat(s, count, i64::MAX as usize)?, + )), + _ => { + return internal_err!( + "Unexpected data type {:?} for function repeat", + string_scalar.data_type() + ); + } + }; + + Ok(ColumnarValue::Scalar(result)) + } + _ => { + let string_array = string_arg.to_array(args.number_rows)?; + let count_array = count_arg.to_array(args.number_rows)?; + Ok(ColumnarValue::Array(repeat(&string_array, &count_array)?)) + } + } } fn documentation(&self) -> Option<&Documentation> { @@ -107,13 +164,30 @@ impl ScalarUDFImpl for RepeatFunc { } } +/// Computes repeat for a single string value with max size check +#[inline] +fn compute_repeat(s: &str, count: i64, max_size: usize) -> Result { + if count <= 0 { + return Ok(String::new()); + } + let result_len = s.len().saturating_mul(count as usize); + if result_len > max_size { + return exec_err!( + "string size overflow on repeat, max size is {}, but got {}", + max_size, + result_len + ); + } + Ok(s.repeat(count as usize)) +} + /// Repeats string the specified number of times. 
/// repeat('Pg', 4) = 'PgPgPgPg' -fn repeat(args: &[ArrayRef]) -> Result { - let number_array = as_int64_array(&args[1])?; - match args[0].data_type() { +fn repeat(string_array: &ArrayRef, count_array: &ArrayRef) -> Result { + let number_array = as_int64_array(count_array)?; + match string_array.data_type() { Utf8View => { - let string_view_array = args[0].as_string_view(); + let string_view_array = string_array.as_string_view(); repeat_impl::( &string_view_array, number_array, @@ -121,17 +195,17 @@ fn repeat(args: &[ArrayRef]) -> Result { ) } Utf8 => { - let string_array = args[0].as_string::(); + let string_arr = string_array.as_string::(); repeat_impl::>( - &string_array, + &string_arr, number_array, i32::MAX as usize, ) } LargeUtf8 => { - let string_array = args[0].as_string::(); + let string_arr = string_array.as_string::(); repeat_impl::>( - &string_array, + &string_arr, number_array, i64::MAX as usize, ) @@ -150,7 +224,7 @@ fn repeat_impl<'a, T, S>( ) -> Result where T: OffsetSizeTrait, - S: StringArrayType<'a>, + S: StringArrayType<'a> + 'a, { let mut total_capacity = 0; let mut max_item_capacity = 0; @@ -181,37 +255,55 @@ where // Reusable buffer to avoid allocations in string.repeat() let mut buffer = Vec::::with_capacity(max_item_capacity); - string_array - .iter() - .zip(number_array.iter()) - .for_each(|(string, number)| { + // Helper function to repeat a string into a buffer using doubling strategy + // count must be > 0 + #[inline] + fn repeat_to_buffer(buffer: &mut Vec, string: &str, count: usize) { + buffer.clear(); + if !string.is_empty() { + let src = string.as_bytes(); + // Initial copy + buffer.extend_from_slice(src); + // Doubling strategy: copy what we have so far until we reach the target + while buffer.len() < src.len() * count { + let copy_len = buffer.len().min(src.len() * count - buffer.len()); + // SAFETY: we're copying valid UTF-8 bytes that we already verified + buffer.extend_from_within(..copy_len); + } + } + } + + // Fast path: no 
nulls in either array + if string_array.null_count() == 0 && number_array.null_count() == 0 { + for i in 0..string_array.len() { + // SAFETY: i is within bounds (0..len) and null_count() == 0 guarantees valid value + let string = unsafe { string_array.value_unchecked(i) }; + let count = number_array.value(i); + if count > 0 { + repeat_to_buffer(&mut buffer, string, count as usize); + // SAFETY: buffer contains valid UTF-8 since we only copy from a valid &str + builder.append_value(unsafe { std::str::from_utf8_unchecked(&buffer) }); + } else { + builder.append_value(""); + } + } + } else { + // Slow path: handle nulls + for (string, number) in string_array.iter().zip(number_array.iter()) { match (string, number) { - (Some(string), Some(number)) if number >= 0 => { - buffer.clear(); - let count = number as usize; - if count > 0 && !string.is_empty() { - let src = string.as_bytes(); - // Initial copy - buffer.extend_from_slice(src); - // Doubling strategy: copy what we have so far until we reach the target - while buffer.len() < src.len() * count { - let copy_len = - buffer.len().min(src.len() * count - buffer.len()); - // SAFETY: we're copying valid UTF-8 bytes that we already verified - buffer.extend_from_within(..copy_len); - } - } - // SAFETY: buffer contains valid UTF-8 since we only ever copy from a valid &str + (Some(string), Some(count)) if count > 0 => { + repeat_to_buffer(&mut buffer, string, count as usize); + // SAFETY: buffer contains valid UTF-8 since we only copy from a valid &str builder .append_value(unsafe { std::str::from_utf8_unchecked(&buffer) }); } (Some(_), Some(_)) => builder.append_value(""), _ => builder.append_null(), } - }); - let array = builder.finish(); + } + } - Ok(Arc::new(array) as ArrayRef) + Ok(Arc::new(builder.finish()) as ArrayRef) } #[cfg(test)] diff --git a/datafusion/functions/src/string/replace.rs b/datafusion/functions/src/string/replace.rs index 165e0634a6b8..458b86d0c6fb 100644 --- 
a/datafusion/functions/src/string/replace.rs +++ b/datafusion/functions/src/string/replace.rs @@ -228,19 +228,21 @@ fn replace_into_string(buffer: &mut String, string: &str, from: &str, to: &str) return; } - // Fast path for replacing a single ASCII character with another single ASCII character - // This matches Rust's str::replace() optimization and enables vectorization + // Fast path for replacing a single ASCII character with another single ASCII character. + // Extends the buffer's underlying Vec directly, for performance. if let ([from_byte], [to_byte]) = (from.as_bytes(), to.as_bytes()) && from_byte.is_ascii() && to_byte.is_ascii() { - // SAFETY: We're replacing ASCII with ASCII, which preserves UTF-8 validity - let replaced: Vec = string - .as_bytes() - .iter() - .map(|b| if *b == *from_byte { *to_byte } else { *b }) - .collect(); - buffer.push_str(unsafe { std::str::from_utf8_unchecked(&replaced) }); + // SAFETY: Replacing an ASCII byte with another ASCII byte preserves UTF-8 validity. + unsafe { + buffer.as_mut_vec().extend( + string + .as_bytes() + .iter() + .map(|&b| if b == *from_byte { *to_byte } else { b }), + ); + } return; } diff --git a/datafusion/functions/src/string/rtrim.rs b/datafusion/functions/src/string/rtrim.rs index 0916c514798c..5659d0acfd97 100644 --- a/datafusion/functions/src/string/rtrim.rs +++ b/datafusion/functions/src/string/rtrim.rs @@ -31,7 +31,7 @@ use datafusion_expr::{ }; use datafusion_macros::user_doc; -/// Returns the longest string with trailing characters removed. If the characters are not specified, whitespace is removed. +/// Returns the longest string with trailing characters removed. If the characters are not specified, spaces are removed. 
/// rtrim('testxxzx', 'xyz') = 'test' fn rtrim(args: &[ArrayRef]) -> Result { let use_string_view = args[0].data_type() == &DataType::Utf8View; @@ -46,7 +46,7 @@ fn rtrim(args: &[ArrayRef]) -> Result { #[user_doc( doc_section(label = "String Functions"), - description = "Trims the specified trim string from the end of a string. If no trim string is provided, all whitespace is removed from the end of the input string.", + description = "Trims the specified trim string from the end of a string. If no trim string is provided, all spaces are removed from the end of the input string.", syntax_example = "rtrim(str[, trim_str])", alternative_syntax = "trim(TRAILING trim_str FROM str)", sql_example = r#"```sql @@ -66,7 +66,7 @@ fn rtrim(args: &[ArrayRef]) -> Result { standard_argument(name = "str", prefix = "String"), argument( name = "trim_str", - description = "String expression to trim from the end of the input string. Can be a constant, column, or function, and any combination of arithmetic operators. _Default is whitespace characters._" + description = "String expression to trim from the end of the input string. Can be a constant, column, or function, and any combination of arithmetic operators. 
_Default is a space._" ), related_udf(name = "btrim"), related_udf(name = "ltrim") diff --git a/datafusion/functions/src/string/split_part.rs b/datafusion/functions/src/string/split_part.rs index d29d33a154d7..0bd197818e4e 100644 --- a/datafusion/functions/src/string/split_part.rs +++ b/datafusion/functions/src/string/split_part.rs @@ -25,7 +25,7 @@ use arrow::datatypes::DataType; use datafusion_common::ScalarValue; use datafusion_common::cast::as_int64_array; use datafusion_common::types::{NativeType, logical_int64, logical_string}; -use datafusion_common::{DataFusionError, Result, exec_err}; +use datafusion_common::{DataFusionError, Result, exec_datafusion_err, exec_err}; use datafusion_expr::{ Coercion, ColumnarValue, Documentation, TypeSignatureClass, Volatility, }; @@ -48,7 +48,10 @@ use std::sync::Arc; ```"#, standard_argument(name = "str", prefix = "String"), argument(name = "delimiter", description = "String or character to split on."), - argument(name = "pos", description = "Position of the part to return.") + argument( + name = "pos", + description = "Position of the part to return (counting from 1). Negative values count backward from the end of the string." 
+ ) )] #[derive(Debug, PartialEq, Eq, Hash)] pub struct SplitPartFunc { @@ -219,22 +222,47 @@ where .try_for_each(|((string, delimiter), n)| -> Result<(), DataFusionError> { match (string, delimiter, n) { (Some(string), Some(delimiter), Some(n)) => { - let split_string: Vec<&str> = string.split(delimiter).collect(); - let len = split_string.len(); + let result = match n.cmp(&0) { + std::cmp::Ordering::Greater => { + // Positive index: use nth() to avoid collecting all parts + // This stops iteration as soon as we find the nth element + let idx: usize = (n - 1).try_into().map_err(|_| { + exec_datafusion_err!( + "split_part index {n} exceeds maximum supported value" + ) + })?; - let index = match n.cmp(&0) { - std::cmp::Ordering::Less => len as i64 + n, + if delimiter.is_empty() { + // Match PostgreSQL split_part behavior for empty delimiter: + // treat the input as a single field ("ab" -> ["ab"]), + // rather than Rust's split("") result (["", "a", "b", ""]). + (n == 1).then_some(string) + } else { + string.split(delimiter).nth(idx) + } + } + std::cmp::Ordering::Less => { + // Negative index: use rsplit().nth() to efficiently get from the end + // rsplit iterates in reverse, so -1 means first from rsplit (index 0) + let idx: usize = (n.unsigned_abs() - 1).try_into().map_err(|_| { + exec_datafusion_err!( + "split_part index {n} exceeds minimum supported value" + ) + })?; + if delimiter.is_empty() { + // Match PostgreSQL split_part behavior for empty delimiter: + // treat the input as a single field ("ab" -> ["ab"]), + // rather than Rust's split("") result (["", "a", "b", ""]). 
+ (n == -1).then_some(string) + } else { + string.rsplit(delimiter).nth(idx) + } + } std::cmp::Ordering::Equal => { return exec_err!("field position must not be zero"); } - std::cmp::Ordering::Greater => n - 1, - } as usize; - - if index < len { - builder.append_value(split_string[index]); - } else { - builder.append_value(""); - } + }; + builder.append_value(result.unwrap_or("")); } _ => builder.append_null(), } @@ -314,6 +342,131 @@ mod tests { Utf8, StringArray ); + test_function!( + SplitPartFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from( + "abc~@~def~@~ghi" + )))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("~@~")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(i64::MIN))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + // Edge cases with delimiters + test_function!( + SplitPartFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(",")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), + ], + Ok(Some("a")), + &str, + Utf8, + StringArray + ); + test_function!( + SplitPartFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(",")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + SplitPartFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), + ], + Ok(Some("a,b")), + &str, + Utf8, + StringArray + ); + test_function!( + SplitPartFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), + ], + Ok(Some("")), + 
&str, + Utf8, + StringArray + ); + test_function!( + SplitPartFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(" ")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), + ], + Ok(Some("a,b")), + &str, + Utf8, + StringArray + ); + test_function!( + SplitPartFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(" ")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + + // Edge cases with delimiters with negative n + test_function!( + SplitPartFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(-1))), + ], + Ok(Some("a,b")), + &str, + Utf8, + StringArray + ); + test_function!( + SplitPartFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(" ")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(-1))), + ], + Ok(Some("a,b")), + &str, + Utf8, + StringArray + ); + test_function!( + SplitPartFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(-2))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); Ok(()) } diff --git a/datafusion/functions/src/string/starts_with.rs b/datafusion/functions/src/string/starts_with.rs index c38a5bffcb2b..e50bd9f65766 100644 --- a/datafusion/functions/src/string/starts_with.rs +++ b/datafusion/functions/src/string/starts_with.rs @@ -21,14 +21,13 @@ use std::sync::Arc; use arrow::array::{ArrayRef, Scalar}; use arrow::compute::kernels::comparison::starts_with as 
arrow_starts_with; use arrow::datatypes::DataType; +use datafusion_common::types::logical_string; use datafusion_common::utils::take_function_args; -use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_common::{Result, ScalarValue, exec_err}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; use datafusion_expr::type_coercion::binary::{ binary_to_string_coercion, string_coercion, }; - -use datafusion_common::types::logical_string; -use datafusion_common::{Result, ScalarValue, exec_err}; use datafusion_expr::{ Coercion, ColumnarValue, Documentation, Expr, Like, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, cast, @@ -164,7 +163,7 @@ impl ScalarUDFImpl for StartsWithFunc { fn simplify( &self, args: Vec, - info: &dyn SimplifyInfo, + info: &SimplifyContext, ) -> Result { if let Expr::Literal(scalar_value, _) = &args[1] { // Convert starts_with(col, 'prefix') to col LIKE 'prefix%' with proper escaping diff --git a/datafusion/functions/src/string/to_hex.rs b/datafusion/functions/src/string/to_hex.rs index 891cbe254957..ed8ce07b876d 100644 --- a/datafusion/functions/src/string/to_hex.rs +++ b/datafusion/functions/src/string/to_hex.rs @@ -18,7 +18,6 @@ use std::any::Any; use std::sync::Arc; -use crate::utils::make_scalar_function; use arrow::array::{Array, ArrayRef, StringArray}; use arrow::buffer::{Buffer, OffsetBuffer}; use arrow::datatypes::{ @@ -26,7 +25,7 @@ use arrow::datatypes::{ Int64Type, UInt8Type, UInt16Type, UInt32Type, UInt64Type, }; use datafusion_common::cast::as_primitive_array; -use datafusion_common::{Result, ScalarValue, exec_err}; +use datafusion_common::{Result, ScalarValue, exec_err, internal_err}; use datafusion_expr::{ Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, @@ -38,11 +37,11 @@ const HEX_CHARS: &[u8; 16] = b"0123456789abcdef"; /// Converts the number to its equivalent hexadecimal 
representation. /// to_hex(2147483647) = '7fffffff' -fn to_hex(args: &[ArrayRef]) -> Result +fn to_hex_array(array: &ArrayRef) -> Result where T::Native: ToHex, { - let integer_array = as_primitive_array::(&args[0])?; + let integer_array = as_primitive_array::(array)?; let len = integer_array.len(); // Max hex string length: 16 chars for u64/i64 @@ -78,6 +77,14 @@ where Ok(Arc::new(result) as ArrayRef) } +#[inline] +fn to_hex_scalar(value: T) -> String { + let mut hex_buffer = [0u8; 16]; + let hex_len = value.write_hex_to_buffer(&mut hex_buffer); + // SAFETY: hex_buffer is ASCII hex digits + unsafe { std::str::from_utf8_unchecked(&hex_buffer[16 - hex_len..]).to_string() } +} + /// Trait for converting integer types to hexadecimal in a buffer trait ToHex: ArrowNativeType { /// Write hex representation to buffer and return the number of hex digits written. @@ -223,33 +230,71 @@ impl ScalarUDFImpl for ToHexFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - match args.args[0].data_type() { - DataType::Null => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))), - DataType::Int64 => { - make_scalar_function(to_hex::, vec![])(&args.args) - } - DataType::UInt64 => { - make_scalar_function(to_hex::, vec![])(&args.args) - } - DataType::Int32 => { - make_scalar_function(to_hex::, vec![])(&args.args) - } - DataType::UInt32 => { - make_scalar_function(to_hex::, vec![])(&args.args) - } - DataType::Int16 => { - make_scalar_function(to_hex::, vec![])(&args.args) - } - DataType::UInt16 => { - make_scalar_function(to_hex::, vec![])(&args.args) - } - DataType::Int8 => { - make_scalar_function(to_hex::, vec![])(&args.args) - } - DataType::UInt8 => { - make_scalar_function(to_hex::, vec![])(&args.args) + let arg = &args.args[0]; + + match arg { + ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) => Ok( + ColumnarValue::Scalar(ScalarValue::Utf8(Some(to_hex_scalar(*v)))), + ), + ColumnarValue::Scalar(ScalarValue::UInt64(Some(v))) => Ok( + 
ColumnarValue::Scalar(ScalarValue::Utf8(Some(to_hex_scalar(*v)))), + ), + ColumnarValue::Scalar(ScalarValue::Int32(Some(v))) => Ok( + ColumnarValue::Scalar(ScalarValue::Utf8(Some(to_hex_scalar(*v)))), + ), + ColumnarValue::Scalar(ScalarValue::UInt32(Some(v))) => Ok( + ColumnarValue::Scalar(ScalarValue::Utf8(Some(to_hex_scalar(*v)))), + ), + ColumnarValue::Scalar(ScalarValue::Int16(Some(v))) => Ok( + ColumnarValue::Scalar(ScalarValue::Utf8(Some(to_hex_scalar(*v)))), + ), + ColumnarValue::Scalar(ScalarValue::UInt16(Some(v))) => Ok( + ColumnarValue::Scalar(ScalarValue::Utf8(Some(to_hex_scalar(*v)))), + ), + ColumnarValue::Scalar(ScalarValue::Int8(Some(v))) => Ok( + ColumnarValue::Scalar(ScalarValue::Utf8(Some(to_hex_scalar(*v)))), + ), + ColumnarValue::Scalar(ScalarValue::UInt8(Some(v))) => Ok( + ColumnarValue::Scalar(ScalarValue::Utf8(Some(to_hex_scalar(*v)))), + ), + + // NULL scalars + ColumnarValue::Scalar(s) if s.is_null() => { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))) } - other => exec_err!("Unsupported data type {other:?} for function to_hex"), + + ColumnarValue::Array(array) => match array.data_type() { + DataType::Int64 => { + Ok(ColumnarValue::Array(to_hex_array::(array)?)) + } + DataType::UInt64 => { + Ok(ColumnarValue::Array(to_hex_array::(array)?)) + } + DataType::Int32 => { + Ok(ColumnarValue::Array(to_hex_array::(array)?)) + } + DataType::UInt32 => { + Ok(ColumnarValue::Array(to_hex_array::(array)?)) + } + DataType::Int16 => { + Ok(ColumnarValue::Array(to_hex_array::(array)?)) + } + DataType::UInt16 => { + Ok(ColumnarValue::Array(to_hex_array::(array)?)) + } + DataType::Int8 => { + Ok(ColumnarValue::Array(to_hex_array::(array)?)) + } + DataType::UInt8 => { + Ok(ColumnarValue::Array(to_hex_array::(array)?)) + } + other => exec_err!("Unsupported data type {other:?} for function to_hex"), + }, + + other => internal_err!( + "Unexpected argument type {:?} for function to_hex", + other.data_type() + ), } } @@ -288,8 +333,8 @@ mod tests { let 
expected = $expected; let array = <$array_type>::from(input); - let array_ref = Arc::new(array); - let hex_result = to_hex::<$arrow_type>(&[array_ref])?; + let array_ref: ArrayRef = Arc::new(array); + let hex_result = to_hex_array::<$arrow_type>(&array_ref)?; let hex_array = as_string_array(&hex_result)?; let expected_array = StringArray::from(expected); diff --git a/datafusion/functions/src/strings.rs b/datafusion/functions/src/strings.rs index a7be3ef79299..cfddf57b094b 100644 --- a/datafusion/functions/src/strings.rs +++ b/datafusion/functions/src/strings.rs @@ -152,43 +152,34 @@ impl StringViewArrayBuilder { } ColumnarValueRef::NullableArray(array) => { if !CHECK_VALID || array.is_valid(i) { - self.block.push_str( - std::str::from_utf8(array.value(i).as_bytes()).unwrap(), - ); + self.block.push_str(array.value(i)); } } ColumnarValueRef::NullableLargeStringArray(array) => { if !CHECK_VALID || array.is_valid(i) { - self.block.push_str( - std::str::from_utf8(array.value(i).as_bytes()).unwrap(), - ); + self.block.push_str(array.value(i)); } } ColumnarValueRef::NullableStringViewArray(array) => { if !CHECK_VALID || array.is_valid(i) { - self.block.push_str( - std::str::from_utf8(array.value(i).as_bytes()).unwrap(), - ); + self.block.push_str(array.value(i)); } } ColumnarValueRef::NonNullableArray(array) => { - self.block - .push_str(std::str::from_utf8(array.value(i).as_bytes()).unwrap()); + self.block.push_str(array.value(i)); } ColumnarValueRef::NonNullableLargeStringArray(array) => { - self.block - .push_str(std::str::from_utf8(array.value(i).as_bytes()).unwrap()); + self.block.push_str(array.value(i)); } ColumnarValueRef::NonNullableStringViewArray(array) => { - self.block - .push_str(std::str::from_utf8(array.value(i).as_bytes()).unwrap()); + self.block.push_str(array.value(i)); } } } pub fn append_offset(&mut self) { self.builder.append_value(&self.block); - self.block = String::new(); + self.block.clear(); } pub fn finish(mut self) -> StringViewArray { diff 
--git a/datafusion/functions/src/unicode/common.rs b/datafusion/functions/src/unicode/common.rs new file mode 100644 index 000000000000..93f0c7900961 --- /dev/null +++ b/datafusion/functions/src/unicode/common.rs @@ -0,0 +1,183 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Common utilities for implementing unicode functions + +use arrow::array::{ + Array, ArrayAccessor, ArrayIter, ArrayRef, ByteView, GenericStringArray, Int64Array, + OffsetSizeTrait, StringViewArray, make_view, +}; +use arrow::datatypes::DataType; +use arrow_buffer::{NullBuffer, ScalarBuffer}; +use datafusion_common::cast::{ + as_generic_string_array, as_int64_array, as_string_view_array, +}; +use datafusion_common::exec_err; +use std::cmp::Ordering; +use std::ops::Range; +use std::sync::Arc; + +/// A trait for `left` and `right` byte slicing operations +pub(crate) trait LeftRightSlicer { + fn slice(string: &str, n: i64) -> Range; +} + +pub(crate) struct LeftSlicer {} + +impl LeftRightSlicer for LeftSlicer { + fn slice(string: &str, n: i64) -> Range { + 0..left_right_byte_length(string, n) + } +} + +pub(crate) struct RightSlicer {} + +impl LeftRightSlicer for RightSlicer { + fn slice(string: &str, n: i64) -> Range { + if n == 0 { + // Return nothing for `n=0` + 0..0 + } else if n == i64::MIN { + // Special case for i64::MIN overflow + 0..0 + } else { + left_right_byte_length(string, -n)..string.len() + } + } +} + +/// Calculate the byte length of the substring of `n` chars from string `string` +#[inline] +fn left_right_byte_length(string: &str, n: i64) -> usize { + match n.cmp(&0) { + Ordering::Less => string + .char_indices() + .nth_back((n.unsigned_abs().min(usize::MAX as u64) - 1) as usize) + .map(|(index, _)| index) + .unwrap_or(0), + Ordering::Equal => 0, + Ordering::Greater => string + .char_indices() + .nth(n.unsigned_abs().min(usize::MAX as u64) as usize) + .map(|(index, _)| index) + .unwrap_or(string.len()), + } +} + +/// General implementation for `left` and `right` functions +pub(crate) fn general_left_right( + args: &[ArrayRef], +) -> datafusion_common::Result { + let n_array = as_int64_array(&args[1])?; + + match args[0].data_type() { + DataType::Utf8 => { + let string_array = as_generic_string_array::(&args[0])?; + 
general_left_right_array::(string_array, n_array) + } + DataType::LargeUtf8 => { + let string_array = as_generic_string_array::(&args[0])?; + general_left_right_array::(string_array, n_array) + } + DataType::Utf8View => { + let string_view_array = as_string_view_array(&args[0])?; + general_left_right_view::(string_view_array, n_array) + } + _ => exec_err!("Not supported"), + } +} + +/// `general_left_right` implementation for strings +fn general_left_right_array< + 'a, + T: OffsetSizeTrait, + V: ArrayAccessor, + F: LeftRightSlicer, +>( + string_array: V, + n_array: &Int64Array, +) -> datafusion_common::Result { + let iter = ArrayIter::new(string_array); + let result = iter + .zip(n_array.iter()) + .map(|(string, n)| match (string, n) { + (Some(string), Some(n)) => { + let range = F::slice(string, n); + // Extract a given range from a byte-indexed slice + Some(&string[range]) + } + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +/// `general_left_right` implementation for StringViewArray +fn general_left_right_view( + string_view_array: &StringViewArray, + n_array: &Int64Array, +) -> datafusion_common::Result { + let len = n_array.len(); + + let views = string_view_array.views(); + // Every string in StringViewArray has one corresponding view in `views` + debug_assert!(views.len() == string_view_array.len()); + + // Compose null buffer at once + let string_nulls = string_view_array.nulls(); + let n_nulls = n_array.nulls(); + let new_nulls = NullBuffer::union(string_nulls, n_nulls); + + let new_views = (0..len) + .map(|idx| { + let view = views[idx]; + + let is_valid = match &new_nulls { + Some(nulls_buf) => nulls_buf.is_valid(idx), + None => true, + }; + + if is_valid { + let string: &str = string_view_array.value(idx); + let n = n_array.value(idx); + + // Input string comes from StringViewArray, so it should fit in 32-bit length + let range = F::slice(string, n); + let result_bytes = &string.as_bytes()[range.clone()]; + + let byte_view = 
ByteView::from(view); + // New offset starts at 0 for left, and at `range.start` for right, + // which is encoded in the given range + let new_offset = byte_view.offset + (range.start as u32); + // Reuse buffer + make_view(result_bytes, byte_view.buffer_index, new_offset) + } else { + // For nulls, keep the original view + view + } + }) + .collect::>(); + + // Buffers are unchanged + let result = StringViewArray::try_new( + ScalarBuffer::from(new_views), + Vec::from(string_view_array.data_buffers()), + new_nulls, + )?; + Ok(Arc::new(result) as ArrayRef) +} diff --git a/datafusion/functions/src/unicode/find_in_set.rs b/datafusion/functions/src/unicode/find_in_set.rs index a25c37266c2c..0cf20584a6bc 100644 --- a/datafusion/functions/src/unicode/find_in_set.rs +++ b/datafusion/functions/src/unicode/find_in_set.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use arrow::array::{ ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait, - PrimitiveArray, new_null_array, + PrimitiveArray, }; use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; @@ -98,9 +98,8 @@ impl ScalarUDFImpl for FindInSetFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let ScalarFunctionArgs { args, .. 
} = args; - - let [string, str_list] = take_function_args(self.name(), args)?; + let return_field = args.return_field; + let [string, str_list] = take_function_args(self.name(), args.args)?; match (string, str_list) { // both inputs are scalars @@ -139,9 +138,11 @@ impl ScalarUDFImpl for FindInSetFunc { | ScalarValue::LargeUtf8(str_list_literal), ), ) => { - let result_array = match str_list_literal { + match str_list_literal { // find_in_set(column_a, null) = null - None => new_null_array(str_array.data_type(), str_array.len()), + None => Ok(ColumnarValue::Scalar(ScalarValue::try_new_null( + return_field.data_type(), + )?)), Some(str_list_literal) => { let str_list = str_list_literal.split(',').collect::>(); let result = match str_array.data_type() { @@ -172,10 +173,9 @@ impl ScalarUDFImpl for FindInSetFunc { ) } }; - Arc::new(result?) + Ok(ColumnarValue::Array(Arc::new(result?))) } - }; - Ok(ColumnarValue::Array(result_array)) + } } // `string` is scalar, `str_list` is an array @@ -187,11 +187,11 @@ impl ScalarUDFImpl for FindInSetFunc { ), ColumnarValue::Array(str_list_array), ) => { - let res = match string_literal { + match string_literal { // find_in_set(null, column_b) = null - None => { - new_null_array(str_list_array.data_type(), str_list_array.len()) - } + None => Ok(ColumnarValue::Scalar(ScalarValue::try_new_null( + return_field.data_type(), + )?)), Some(string) => { let result = match str_list_array.data_type() { DataType::Utf8 => { @@ -218,10 +218,9 @@ impl ScalarUDFImpl for FindInSetFunc { ) } }; - Arc::new(result?) 
+ Ok(ColumnarValue::Array(Arc::new(result?))) } - }; - Ok(ColumnarValue::Array(res)) + } } // both inputs are arrays diff --git a/datafusion/functions/src/unicode/initcap.rs b/datafusion/functions/src/unicode/initcap.rs index 929b0c316951..a0cae69c5201 100644 --- a/datafusion/functions/src/unicode/initcap.rs +++ b/datafusion/functions/src/unicode/initcap.rs @@ -19,14 +19,16 @@ use std::any::Any; use std::sync::Arc; use arrow::array::{ - Array, ArrayRef, GenericStringBuilder, OffsetSizeTrait, StringViewBuilder, + Array, ArrayRef, GenericStringArray, GenericStringBuilder, OffsetSizeTrait, + StringViewBuilder, }; +use arrow::buffer::{Buffer, OffsetBuffer}; use arrow::datatypes::DataType; use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; use datafusion_common::types::logical_string; -use datafusion_common::{Result, exec_err}; +use datafusion_common::{Result, ScalarValue, exec_err}; use datafusion_expr::{ Coercion, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, @@ -99,6 +101,39 @@ impl ScalarUDFImpl for InitcapFunc { &self, args: datafusion_expr::ScalarFunctionArgs, ) -> Result { + let arg = &args.args[0]; + + // Scalar fast path - handle directly without array conversion + if let ColumnarValue::Scalar(scalar) = arg { + return match scalar { + ScalarValue::Utf8(None) + | ScalarValue::LargeUtf8(None) + | ScalarValue::Utf8View(None) => Ok(arg.clone()), + ScalarValue::Utf8(Some(s)) => { + let mut result = String::new(); + initcap_string(s, &mut result); + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result)))) + } + ScalarValue::LargeUtf8(Some(s)) => { + let mut result = String::new(); + initcap_string(s, &mut result); + Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(result)))) + } + ScalarValue::Utf8View(Some(s)) => { + let mut result = String::new(); + initcap_string(s, &mut result); + 
Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(result)))) + } + other => { + exec_err!( + "Unsupported data type {:?} for function `initcap`", + other.data_type() + ) + } + }; + } + + // Array path let args = &args.args; match args[0].data_type() { DataType::Utf8 => make_scalar_function(initcap::, vec![])(args), @@ -115,8 +150,8 @@ impl ScalarUDFImpl for InitcapFunc { } } -/// Converts the first letter of each word to upper case and the rest to lower -/// case. Words are sequences of alphanumeric characters separated by +/// Converts the first letter of each word to uppercase and the rest to +/// lowercase. Words are sequences of alphanumeric characters separated by /// non-alphanumeric characters. /// /// Example: @@ -126,6 +161,10 @@ impl ScalarUDFImpl for InitcapFunc { fn initcap(args: &[ArrayRef]) -> Result { let string_array = as_generic_string_array::(&args[0])?; + if string_array.is_ascii() { + return Ok(initcap_ascii_array(string_array)); + } + let mut builder = GenericStringBuilder::::with_capacity( string_array.len(), string_array.value_data().len(), @@ -143,12 +182,67 @@ fn initcap(args: &[ArrayRef]) -> Result { Ok(Arc::new(builder.finish()) as ArrayRef) } +/// Fast path for `Utf8` or `LargeUtf8` arrays that are ASCII-only. We can use a +/// single pass over the buffer and operate directly on bytes. +fn initcap_ascii_array( + string_array: &GenericStringArray, +) -> ArrayRef { + let offsets = string_array.offsets(); + let src = string_array.value_data(); + let first_offset = offsets.first().unwrap().as_usize(); + let last_offset = offsets.last().unwrap().as_usize(); + + // For sliced arrays, only convert the visible bytes, not the entire input + // buffer. 
+ let mut out = Vec::with_capacity(last_offset - first_offset); + + for window in offsets.windows(2) { + let start = window[0].as_usize(); + let end = window[1].as_usize(); + + let mut prev_is_alnum = false; + for &b in &src[start..end] { + let converted = if prev_is_alnum { + b.to_ascii_lowercase() + } else { + b.to_ascii_uppercase() + }; + out.push(converted); + prev_is_alnum = b.is_ascii_alphanumeric(); + } + } + + let values = Buffer::from_vec(out); + let out_offsets = if first_offset == 0 { + offsets.clone() + } else { + // For sliced arrays, we need to rebase the offsets to reflect that the + // output only contains the bytes in the visible slice. + let rebased_offsets = offsets + .iter() + .map(|offset| T::usize_as(offset.as_usize() - first_offset)) + .collect::>(); + OffsetBuffer::::new(rebased_offsets.into()) + }; + + // SAFETY: ASCII case conversion preserves byte length, so the original + // string boundaries are preserved. `out_offsets` is either identical to + // the input offsets or a rebased version relative to the compacted values + // buffer. + Arc::new(unsafe { + GenericStringArray::::new_unchecked( + out_offsets, + values, + string_array.nulls().cloned(), + ) + }) +} + fn initcap_utf8view(args: &[ArrayRef]) -> Result { let string_view_array = as_string_view_array(&args[0])?; - let mut builder = StringViewBuilder::with_capacity(string_view_array.len()); - let mut container = String::new(); + string_view_array.iter().for_each(|str| match str { Some(s) => { initcap_string(s, &mut container); @@ -165,13 +259,16 @@ fn initcap_string(input: &str, container: &mut String) { let mut prev_is_alphanumeric = false; if input.is_ascii() { - for c in input.chars() { + container.reserve(input.len()); + // SAFETY: each byte is ASCII, so the result is valid UTF-8. 
+ let out = unsafe { container.as_mut_vec() }; + for &b in input.as_bytes() { if prev_is_alphanumeric { - container.push(c.to_ascii_lowercase()); + out.push(b.to_ascii_lowercase()); } else { - container.push(c.to_ascii_uppercase()); - }; - prev_is_alphanumeric = c.is_ascii_alphanumeric(); + out.push(b.to_ascii_uppercase()); + } + prev_is_alphanumeric = b.is_ascii_alphanumeric(); } } else { for c in input.chars() { @@ -189,10 +286,11 @@ fn initcap_string(input: &str, container: &mut String) { mod tests { use crate::unicode::initcap::InitcapFunc; use crate::utils::test::test_function; - use arrow::array::{Array, StringArray, StringViewArray}; + use arrow::array::{Array, ArrayRef, LargeStringArray, StringArray, StringViewArray}; use arrow::datatypes::DataType::{Utf8, Utf8View}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + use std::sync::Arc; #[test] fn test_functions() -> Result<()> { @@ -296,4 +394,114 @@ mod tests { Ok(()) } + + #[test] + fn test_initcap_ascii_array() -> Result<()> { + let array = StringArray::from(vec![ + Some("hello world"), + None, + Some("foo-bar_baz/baX"), + Some(""), + Some("123 abc 456DEF"), + Some("ALL CAPS"), + Some("already correct"), + ]); + let args: Vec = vec![Arc::new(array)]; + let result = super::initcap::(&args)?; + let result = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result.len(), 7); + assert_eq!(result.value(0), "Hello World"); + assert!(result.is_null(1)); + assert_eq!(result.value(2), "Foo-Bar_Baz/Bax"); + assert_eq!(result.value(3), ""); + assert_eq!(result.value(4), "123 Abc 456def"); + assert_eq!(result.value(5), "All Caps"); + assert_eq!(result.value(6), "Already Correct"); + Ok(()) + } + + #[test] + fn test_initcap_ascii_large_array() -> Result<()> { + let array = LargeStringArray::from(vec![ + Some("hello world"), + None, + Some("foo-bar_baz/baX"), + Some(""), + Some("123 abc 456DEF"), + Some("ALL CAPS"), + Some("already correct"), + ]); + let args: 
Vec = vec![Arc::new(array)]; + let result = super::initcap::(&args)?; + let result = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result.len(), 7); + assert_eq!(result.value(0), "Hello World"); + assert!(result.is_null(1)); + assert_eq!(result.value(2), "Foo-Bar_Baz/Bax"); + assert_eq!(result.value(3), ""); + assert_eq!(result.value(4), "123 Abc 456def"); + assert_eq!(result.value(5), "All Caps"); + assert_eq!(result.value(6), "Already Correct"); + Ok(()) + } + + /// Test that initcap works correctly on a sliced ASCII StringArray. + #[test] + fn test_initcap_sliced_ascii_array() -> Result<()> { + let array = StringArray::from(vec![ + Some("hello world"), + Some("foo bar"), + Some("baz qux"), + ]); + // Slice to get only the last two elements. The resulting array's + // offsets are [11, 18, 25] (non-zero start), but value_data still + // contains the full original buffer. + let sliced = array.slice(1, 2); + let args: Vec = vec![Arc::new(sliced)]; + let result = super::initcap::(&args)?; + let result = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result.len(), 2); + assert_eq!(result.value(0), "Foo Bar"); + assert_eq!(result.value(1), "Baz Qux"); + + // The output values buffer should be compact + assert_eq!(*result.offsets().first().unwrap(), 0); + assert_eq!( + result.value_data().len(), + *result.offsets().last().unwrap() as usize + ); + Ok(()) + } + + /// Test that initcap works correctly on a sliced ASCII LargeStringArray. + #[test] + fn test_initcap_sliced_ascii_large_array() -> Result<()> { + let array = LargeStringArray::from(vec![ + Some("hello world"), + Some("foo bar"), + Some("baz qux"), + ]); + // Slice to get only the last two elements. The resulting array's + // offsets are [11, 18, 25] (non-zero start), but value_data still + // contains the full original buffer. 
+ let sliced = array.slice(1, 2); + let args: Vec = vec![Arc::new(sliced)]; + let result = super::initcap::(&args)?; + let result = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result.len(), 2); + assert_eq!(result.value(0), "Foo Bar"); + assert_eq!(result.value(1), "Baz Qux"); + + // The output values buffer should be compact + assert_eq!(*result.offsets().first().unwrap(), 0); + assert_eq!( + result.value_data().len(), + *result.offsets().last().unwrap() as usize + ); + Ok(()) + } } diff --git a/datafusion/functions/src/unicode/left.rs b/datafusion/functions/src/unicode/left.rs index ecff8f869950..76873e7f5d3e 100644 --- a/datafusion/functions/src/unicode/left.rs +++ b/datafusion/functions/src/unicode/left.rs @@ -16,20 +16,11 @@ // under the License. use std::any::Any; -use std::cmp::Ordering; -use std::sync::Arc; -use arrow::array::{ - Array, ArrayAccessor, ArrayIter, ArrayRef, GenericStringArray, Int64Array, - OffsetSizeTrait, -}; +use crate::unicode::common::{LeftSlicer, general_left_right}; +use crate::utils::make_scalar_function; use arrow::datatypes::DataType; - -use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::Result; -use datafusion_common::cast::{ - as_generic_string_array, as_int64_array, as_string_view_array, -}; use datafusion_common::exec_err; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ @@ -94,22 +85,26 @@ impl ScalarUDFImpl for LeftFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - utf8_to_str_type(&arg_types[0], "left") + Ok(arg_types[0].clone()) } + /// Returns first n characters in the string, or when n is negative, returns all but last |n| characters. 
+ /// left('abcde', 2) = 'ab' + /// left('abcde', -2) = 'abc' + /// The implementation uses UTF-8 code points as characters fn invoke_with_args( &self, args: datafusion_expr::ScalarFunctionArgs, ) -> Result { let args = &args.args; match args[0].data_type() { - DataType::Utf8 | DataType::Utf8View => { - make_scalar_function(left::, vec![])(args) + DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8 => { + make_scalar_function(general_left_right::, vec![])(args) } - DataType::LargeUtf8 => make_scalar_function(left::, vec![])(args), other => exec_err!( - "Unsupported data type {other:?} for function left,\ - expected Utf8View, Utf8 or LargeUtf8." + "Unsupported data type {other:?} for function {},\ + expected Utf8View, Utf8 or LargeUtf8.", + self.name() ), } } @@ -119,54 +114,10 @@ impl ScalarUDFImpl for LeftFunc { } } -/// Returns first n characters in the string, or when n is negative, returns all but last |n| characters. -/// left('abcde', 2) = 'ab' -/// The implementation uses UTF-8 code points as characters -fn left(args: &[ArrayRef]) -> Result { - let n_array = as_int64_array(&args[1])?; - - if args[0].data_type() == &DataType::Utf8View { - let string_array = as_string_view_array(&args[0])?; - left_impl::(string_array, n_array) - } else { - let string_array = as_generic_string_array::(&args[0])?; - left_impl::(string_array, n_array) - } -} - -fn left_impl<'a, T: OffsetSizeTrait, V: ArrayAccessor>( - string_array: V, - n_array: &Int64Array, -) -> Result { - let iter = ArrayIter::new(string_array); - let result = iter - .zip(n_array.iter()) - .map(|(string, n)| match (string, n) { - (Some(string), Some(n)) => match n.cmp(&0) { - Ordering::Less => { - let len = string.chars().count() as i64; - Some(if n.abs() < len { - string.chars().take((len + n) as usize).collect::() - } else { - "".to_string() - }) - } - Ordering::Equal => Some("".to_string()), - Ordering::Greater => { - Some(string.chars().take(n as usize).collect::()) - } - }, - _ => None, - }) - 
.collect::>(); - - Ok(Arc::new(result) as ArrayRef) -} - #[cfg(test)] mod tests { - use arrow::array::{Array, StringArray}; - use arrow::datatypes::DataType::Utf8; + use arrow::array::{Array, StringArray, StringViewArray}; + use arrow::datatypes::DataType::{Utf8, Utf8View}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -209,6 +160,17 @@ mod tests { Utf8, StringArray ); + test_function!( + LeftFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::from("abcde")), + ColumnarValue::Scalar(ScalarValue::from(i64::MIN)), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); test_function!( LeftFunc::new(), vec![ @@ -290,6 +252,74 @@ mod tests { StringArray ); + // StringView cases + test_function!( + LeftFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("abcde".to_string()))), + ColumnarValue::Scalar(ScalarValue::from(2i64)), + ], + Ok(Some("ab")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + LeftFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("abcde".to_string()))), + ColumnarValue::Scalar(ScalarValue::from(200i64)), + ], + Ok(Some("abcde")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + LeftFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("".to_string()))), + ColumnarValue::Scalar(ScalarValue::from(200i64)), + ], + Ok(Some("")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + LeftFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + "joséésoj".to_string() + ))), + ColumnarValue::Scalar(ScalarValue::from(-3i64)), + ], + Ok(Some("joséé")), + &str, + Utf8View, + StringViewArray + ); + + // Unicode indexing case + let input = "joé楽s𐀀so↓j"; + for n in 1..=input.chars().count() { + let expected = input + .chars() + .take(input.chars().count() - n) + .collect::(); + test_function!( + LeftFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::from(input)), + 
ColumnarValue::Scalar(ScalarValue::from(-(n as i64))), + ], + Ok(Some(expected.as_str())), + &str, + Utf8, + StringArray + ); + } + Ok(()) } } diff --git a/datafusion/functions/src/unicode/lpad.rs b/datafusion/functions/src/unicode/lpad.rs index a892c0adf58d..50d15c7d62a6 100644 --- a/datafusion/functions/src/unicode/lpad.rs +++ b/datafusion/functions/src/unicode/lpad.rs @@ -49,7 +49,10 @@ use datafusion_macros::user_doc; +---------------------------------------------+ ```"#, standard_argument(name = "str", prefix = "String"), - argument(name = "n", description = "String length to pad to."), + argument( + name = "n", + description = "String length to pad to. If the input string is longer than this length, it is truncated (on the right)." + ), argument( name = "padding_str", description = "Optional string expression to pad with. Can be a constant, column, or function, and any combination of string operators. _Default is a space._" @@ -225,24 +228,47 @@ where continue; } - // Reuse buffers by clearing and refilling - graphemes_buf.clear(); - graphemes_buf.extend(string.graphemes(true)); - - fill_chars_buf.clear(); - fill_chars_buf.extend(fill.chars()); - - if length < graphemes_buf.len() { - builder.append_value(graphemes_buf[..length].concat()); - } else if fill_chars_buf.is_empty() { - builder.append_value(string); + if string.is_ascii() && fill.is_ascii() { + // ASCII fast path: byte length == character length, + // so we skip expensive grapheme segmentation. 
+ let str_len = string.len(); + if length < str_len { + builder.append_value(&string[..length]); + } else if fill.is_empty() { + builder.append_value(string); + } else { + let pad_len = length - str_len; + let fill_len = fill.len(); + let full_reps = pad_len / fill_len; + let remainder = pad_len % fill_len; + for _ in 0..full_reps { + builder.write_str(fill)?; + } + if remainder > 0 { + builder.write_str(&fill[..remainder])?; + } + builder.append_value(string); + } } else { - for l in 0..length - graphemes_buf.len() { - let c = *fill_chars_buf.get(l % fill_chars_buf.len()).unwrap(); - builder.write_char(c)?; + // Reuse buffers by clearing and refilling + graphemes_buf.clear(); + graphemes_buf.extend(string.graphemes(true)); + + fill_chars_buf.clear(); + fill_chars_buf.extend(fill.chars()); + + if length < graphemes_buf.len() { + builder.append_value(graphemes_buf[..length].concat()); + } else if fill_chars_buf.is_empty() { + builder.append_value(string); + } else { + for l in 0..length - graphemes_buf.len() { + let c = + *fill_chars_buf.get(l % fill_chars_buf.len()).unwrap(); + builder.write_char(c)?; + } + builder.append_value(string); } - builder.write_str(string)?; - builder.append_value(""); } } else { builder.append_null(); @@ -266,17 +292,30 @@ where continue; } - // Reuse buffer by clearing and refilling - graphemes_buf.clear(); - graphemes_buf.extend(string.graphemes(true)); - - if length < graphemes_buf.len() { - builder.append_value(graphemes_buf[..length].concat()); + if string.is_ascii() { + // ASCII fast path: byte length == character length + let str_len = string.len(); + if length < str_len { + builder.append_value(&string[..length]); + } else { + for _ in 0..(length - str_len) { + builder.write_str(" ")?; + } + builder.append_value(string); + } } else { - builder - .write_str(" ".repeat(length - graphemes_buf.len()).as_str())?; - builder.write_str(string)?; - builder.append_value(""); + // Reuse buffer by clearing and refilling + 
graphemes_buf.clear(); + graphemes_buf.extend(string.graphemes(true)); + + if length < graphemes_buf.len() { + builder.append_value(graphemes_buf[..length].concat()); + } else { + for _ in 0..(length - graphemes_buf.len()) { + builder.write_str(" ")?; + } + builder.append_value(string); + } } } else { builder.append_null(); @@ -523,6 +562,17 @@ mod tests { None, Ok(None) ); + test_lpad!( + Some("hello".into()), + ScalarValue::Int64(Some(2i64)), + Ok(Some("he")) + ); + test_lpad!( + Some("hi".into()), + ScalarValue::Int64(Some(6i64)), + Some("xy".into()), + Ok(Some("xyxyhi")) + ); test_lpad!( Some("josé".into()), ScalarValue::Int64(Some(10i64)), diff --git a/datafusion/functions/src/unicode/mod.rs b/datafusion/functions/src/unicode/mod.rs index 4a0dd21d749a..7250b3915fb5 100644 --- a/datafusion/functions/src/unicode/mod.rs +++ b/datafusion/functions/src/unicode/mod.rs @@ -22,6 +22,7 @@ use std::sync::Arc; use datafusion_expr::ScalarUDF; pub mod character_length; +pub mod common; pub mod find_in_set; pub mod initcap; pub mod left; diff --git a/datafusion/functions/src/unicode/right.rs b/datafusion/functions/src/unicode/right.rs index ac98a3f202a5..a97e242b73f9 100644 --- a/datafusion/functions/src/unicode/right.rs +++ b/datafusion/functions/src/unicode/right.rs @@ -16,20 +16,11 @@ // under the License. 
use std::any::Any; -use std::cmp::{Ordering, max}; -use std::sync::Arc; -use arrow::array::{ - Array, ArrayAccessor, ArrayIter, ArrayRef, GenericStringArray, Int64Array, - OffsetSizeTrait, -}; +use crate::unicode::common::{RightSlicer, general_left_right}; +use crate::utils::make_scalar_function; use arrow::datatypes::DataType; - -use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::Result; -use datafusion_common::cast::{ - as_generic_string_array, as_int64_array, as_string_view_array, -}; use datafusion_common::exec_err; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ @@ -94,22 +85,26 @@ impl ScalarUDFImpl for RightFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - utf8_to_str_type(&arg_types[0], "right") + Ok(arg_types[0].clone()) } + /// Returns right n characters in the string, or when n is negative, returns all but first |n| characters. + /// right('abcde', 2) = 'de' + /// right('abcde', -2) = 'cde' + /// The implementation uses UTF-8 code points as characters fn invoke_with_args( &self, args: datafusion_expr::ScalarFunctionArgs, ) -> Result { let args = &args.args; match args[0].data_type() { - DataType::Utf8 | DataType::Utf8View => { - make_scalar_function(right::, vec![])(args) + DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8 => { + make_scalar_function(general_left_right::, vec![])(args) } - DataType::LargeUtf8 => make_scalar_function(right::, vec![])(args), other => exec_err!( - "Unsupported data type {other:?} for function right,\ - expected Utf8View, Utf8 or LargeUtf8." + "Unsupported data type {other:?} for function {},\ + expected Utf8View, Utf8 or LargeUtf8.", + self.name() ), } } @@ -119,58 +114,10 @@ impl ScalarUDFImpl for RightFunc { } } -/// Returns last n characters in the string, or when n is negative, returns all but first |n| characters. 
-/// right('abcde', 2) = 'de' -/// The implementation uses UTF-8 code points as characters -fn right(args: &[ArrayRef]) -> Result { - let n_array = as_int64_array(&args[1])?; - if args[0].data_type() == &DataType::Utf8View { - // string_view_right(args) - let string_array = as_string_view_array(&args[0])?; - right_impl::(&mut string_array.iter(), n_array) - } else { - // string_right::(args) - let string_array = &as_generic_string_array::(&args[0])?; - right_impl::(&mut string_array.iter(), n_array) - } -} - -// Currently the return type can only be Utf8 or LargeUtf8, to reach fully support, we need -// to edit the `get_optimal_return_type` in utils.rs to make the udfs be able to return Utf8View -// See https://github.com/apache/datafusion/issues/11790#issuecomment-2283777166 -fn right_impl<'a, T: OffsetSizeTrait, V: ArrayAccessor>( - string_array_iter: &mut ArrayIter, - n_array: &Int64Array, -) -> Result { - let result = string_array_iter - .zip(n_array.iter()) - .map(|(string, n)| match (string, n) { - (Some(string), Some(n)) => match n.cmp(&0) { - Ordering::Less => Some( - string - .chars() - .skip(n.unsigned_abs() as usize) - .collect::(), - ), - Ordering::Equal => Some("".to_string()), - Ordering::Greater => Some( - string - .chars() - .skip(max(string.chars().count() as i64 - n, 0) as usize) - .collect::(), - ), - }, - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) -} - #[cfg(test)] mod tests { - use arrow::array::{Array, StringArray}; - use arrow::datatypes::DataType::Utf8; + use arrow::array::{Array, StringArray, StringViewArray}; + use arrow::datatypes::DataType::{Utf8, Utf8View}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -213,6 +160,17 @@ mod tests { Utf8, StringArray ); + test_function!( + RightFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::from("abcde")), + ColumnarValue::Scalar(ScalarValue::from(i64::MIN)), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + 
); test_function!( RightFunc::new(), vec![ @@ -260,10 +218,10 @@ mod tests { test_function!( RightFunc::new(), vec![ - ColumnarValue::Scalar(ScalarValue::from("joséésoj")), + ColumnarValue::Scalar(ScalarValue::from("joséérend")), ColumnarValue::Scalar(ScalarValue::from(5i64)), ], - Ok(Some("éésoj")), + Ok(Some("érend")), &str, Utf8, StringArray @@ -271,10 +229,10 @@ mod tests { test_function!( RightFunc::new(), vec![ - ColumnarValue::Scalar(ScalarValue::from("joséésoj")), + ColumnarValue::Scalar(ScalarValue::from("joséérend")), ColumnarValue::Scalar(ScalarValue::from(-3i64)), ], - Ok(Some("éésoj")), + Ok(Some("éérend")), &str, Utf8, StringArray @@ -294,6 +252,71 @@ mod tests { StringArray ); + // StringView cases + test_function!( + RightFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("abcde".to_string()))), + ColumnarValue::Scalar(ScalarValue::from(2i64)), + ], + Ok(Some("de")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + RightFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("abcde".to_string()))), + ColumnarValue::Scalar(ScalarValue::from(200i64)), + ], + Ok(Some("abcde")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + RightFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("".to_string()))), + ColumnarValue::Scalar(ScalarValue::from(200i64)), + ], + Ok(Some("")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + RightFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + "joséérend".to_string() + ))), + ColumnarValue::Scalar(ScalarValue::from(-3i64)), + ], + Ok(Some("éérend")), + &str, + Utf8View, + StringViewArray + ); + + // Unicode indexing case + let input = "joé楽s𐀀so↓j"; + for n in 1..=input.chars().count() { + let expected = input.chars().skip(n).collect::(); + test_function!( + RightFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::from(input)), + ColumnarValue::Scalar(ScalarValue::from(-(n as i64))), + ], + 
Ok(Some(expected.as_str())), + &str, + Utf8, + StringArray + ); + } + Ok(()) } } diff --git a/datafusion/functions/src/unicode/rpad.rs b/datafusion/functions/src/unicode/rpad.rs index 14f517faf8cf..95d9492bc246 100644 --- a/datafusion/functions/src/unicode/rpad.rs +++ b/datafusion/functions/src/unicode/rpad.rs @@ -48,7 +48,10 @@ use unicode_segmentation::UnicodeSegmentation; +-----------------------------------------------+ ```"#, standard_argument(name = "str", prefix = "String"), - argument(name = "n", description = "String length to pad to."), + argument( + name = "n", + description = "String length to pad to. If the input string is longer than this length, it is truncated." + ), argument( name = "padding_str", description = "String expression to pad with. Can be a constant, column, or function, and any combination of string operators. _Default is a space._" @@ -203,7 +206,8 @@ fn rpad( } } -/// Extends the string to length 'length' by appending the characters fill (a space by default). If the string is already longer than length then it is truncated. +/// Extends the string to length 'length' by appending the characters fill (a space by default). +/// If the string is already longer than length then it is truncated (on the right). 
/// rpad('hi', 5, 'xy') = 'hixyx' fn rpad_impl<'a, StringArrType, FillArrType, StringArrayLen>( string_array: &StringArrType, @@ -234,6 +238,18 @@ where let length = if length < 0 { 0 } else { length as usize }; if length == 0 { builder.append_value(""); + } else if string.is_ascii() { + // ASCII fast path: byte length == character length + let str_len = string.len(); + if length < str_len { + builder.append_value(&string[..length]); + } else { + builder.write_str(string)?; + for _ in 0..(length - str_len) { + builder.write_str(" ")?; + } + builder.append_value(""); + } } else { // Reuse buffer by clearing and refilling graphemes_buf.clear(); @@ -244,9 +260,9 @@ where .append_value(graphemes_buf[..length].concat()); } else { builder.write_str(string)?; - builder.write_str( - &" ".repeat(length - graphemes_buf.len()), - )?; + for _ in 0..(length - graphemes_buf.len()) { + builder.write_str(" ")?; + } builder.append_value(""); } } @@ -273,27 +289,52 @@ where ); } let length = if length < 0 { 0 } else { length as usize }; - // Reuse buffer by clearing and refilling - graphemes_buf.clear(); - graphemes_buf.extend(string.graphemes(true)); - - if length < graphemes_buf.len() { - builder - .append_value(graphemes_buf[..length].concat()); - } else if fill.is_empty() { - builder.append_value(string); + if string.is_ascii() && fill.is_ascii() { + // ASCII fast path: byte length == character length, + // so we skip expensive grapheme segmentation. 
+ let str_len = string.len(); + if length < str_len { + builder.append_value(&string[..length]); + } else if fill.is_empty() { + builder.append_value(string); + } else { + let pad_len = length - str_len; + let fill_len = fill.len(); + let full_reps = pad_len / fill_len; + let remainder = pad_len % fill_len; + builder.write_str(string)?; + for _ in 0..full_reps { + builder.write_str(fill)?; + } + if remainder > 0 { + builder.write_str(&fill[..remainder])?; + } + builder.append_value(""); + } } else { - builder.write_str(string)?; - // Reuse fill_chars_buf by clearing and refilling - fill_chars_buf.clear(); - fill_chars_buf.extend(fill.chars()); - for l in 0..length - graphemes_buf.len() { - let c = *fill_chars_buf - .get(l % fill_chars_buf.len()) - .unwrap(); - builder.write_char(c)?; + // Reuse buffer by clearing and refilling + graphemes_buf.clear(); + graphemes_buf.extend(string.graphemes(true)); + + if length < graphemes_buf.len() { + builder.append_value( + graphemes_buf[..length].concat(), + ); + } else if fill.is_empty() { + builder.append_value(string); + } else { + builder.write_str(string)?; + // Reuse fill_chars_buf by clearing and refilling + fill_chars_buf.clear(); + fill_chars_buf.extend(fill.chars()); + for l in 0..length - graphemes_buf.len() { + let c = *fill_chars_buf + .get(l % fill_chars_buf.len()) + .unwrap(); + builder.write_char(c)?; + } + builder.append_value(""); } - builder.append_value(""); } } _ => builder.append_null(), @@ -459,6 +500,29 @@ mod tests { Utf8, StringArray ); + test_function!( + RPadFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::from("hello")), + ColumnarValue::Scalar(ScalarValue::from(2i64)), + ], + Ok(Some("he")), + &str, + Utf8, + StringArray + ); + test_function!( + RPadFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::from("hi")), + ColumnarValue::Scalar(ScalarValue::from(6i64)), + ColumnarValue::Scalar(ScalarValue::from("xy")), + ], + Ok(Some("hixyxy")), + &str, + Utf8, + StringArray + ); 
test_function!( RPadFunc::new(), vec![ diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs index 9be086c4cf5f..c1d6ecffe551 100644 --- a/datafusion/functions/src/unicode/strpos.rs +++ b/datafusion/functions/src/unicode/strpos.rs @@ -32,6 +32,7 @@ use datafusion_expr::{ Volatility, }; use datafusion_macros::user_doc; +use memchr::memchr; #[user_doc( doc_section(label = "String Functions"), @@ -179,6 +180,31 @@ fn strpos(args: &[ArrayRef]) -> Result { } } +/// Find `needle` in `haystack` using `memchr` to quickly skip to positions +/// where the first byte matches, then verify the remaining bytes. Using +/// string::find is slower because it has significant per-call overhead that +/// `memchr` does not, and strpos is often invoked many times on short inputs. +/// Returns a 1-based position, or 0 if not found. +/// Both inputs must be ASCII-only. +fn find_ascii_substring(haystack: &[u8], needle: &[u8]) -> usize { + let needle_len = needle.len(); + let first_byte = needle[0]; + let mut offset = 0; + + while let Some(pos) = memchr(first_byte, &haystack[offset..]) { + let start = offset + pos; + if start + needle_len > haystack.len() { + return 0; + } + if haystack[start..start + needle_len] == *needle { + return start + 1; + } + offset = start + 1; + } + + 0 +} + /// Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.) /// strpos('high', 'ig') = 2 /// The implementation uses UTF-8 code points as characters @@ -198,37 +224,25 @@ where .zip(substring_iter) .map(|(string, substring)| match (string, substring) { (Some(string), Some(substring)) => { - // If only ASCII characters are present, we can use the slide window method to find - // the sub vector in the main vector. This is faster than string.find() method. 
+ if substring.is_empty() { + return T::Native::from_usize(1); + } + + let substring_bytes = substring.as_bytes(); + let string_bytes = string.as_bytes(); + + if substring_bytes.len() > string_bytes.len() { + return T::Native::from_usize(0); + } + if ascii_only { - // If the substring is empty, the result is 1. - if substring.is_empty() { - T::Native::from_usize(1) - } else { - T::Native::from_usize( - string - .as_bytes() - .windows(substring.len()) - .position(|w| w == substring.as_bytes()) - .map(|x| x + 1) - .unwrap_or(0), - ) - } + T::Native::from_usize(find_ascii_substring( + string_bytes, + substring_bytes, + )) } else { // For non-ASCII, use a single-pass search that tracks both // byte position and character position simultaneously - if substring.is_empty() { - return T::Native::from_usize(1); - } - - let substring_bytes = substring.as_bytes(); - let string_bytes = string.as_bytes(); - - if substring_bytes.len() > string_bytes.len() { - return T::Native::from_usize(0); - } - - // Single pass: find substring while counting characters let mut char_pos = 0; for (byte_idx, _) in string.char_indices() { char_pos += 1; diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs index cc1d53b3aad6..505388089f19 100644 --- a/datafusion/functions/src/unicode/substr.rs +++ b/datafusion/functions/src/unicode/substr.rs @@ -176,7 +176,7 @@ fn substr(args: &[ArrayRef]) -> Result { // `get_true_start_end('Hi🌏', 1, None) -> (0, 6)` // `get_true_start_end('Hi🌏', 1, 1) -> (0, 1)` // `get_true_start_end('Hi🌏', -10, 2) -> (0, 0)` -fn get_true_start_end( +pub fn get_true_start_end( input: &str, start: i64, count: Option, @@ -185,7 +185,10 @@ fn get_true_start_end( let start = start.checked_sub(1).unwrap_or(start); let end = match count { - Some(count) => start + count as i64, + Some(count) => { + let count_i64 = i64::try_from(count).unwrap_or(i64::MAX); + start.saturating_add(count_i64) + } None => input.len() as i64, }; let 
count_to_end = count.is_some(); @@ -235,7 +238,7 @@ fn get_true_start_end( // string, such as `substr(long_str_with_1k_chars, 1, 32)`. // In such case the overhead of ASCII-validation may not be worth it, so // skip the validation for short prefix for now. -fn enable_ascii_fast_path<'a, V: StringArrayType<'a>>( +pub fn enable_ascii_fast_path<'a, V: StringArrayType<'a>>( string_array: &V, start: &Int64Array, count: Option<&Int64Array>, @@ -247,7 +250,7 @@ fn enable_ascii_fast_path<'a, V: StringArrayType<'a>>( // HACK: can be simplified if function has specialized // implementation for `ScalarValue` (implement without `make_scalar_function()`) - let avg_prefix_len = start + let total_prefix_len = start .iter() .zip(count.iter()) .take(n_sample) @@ -255,11 +258,11 @@ fn enable_ascii_fast_path<'a, V: StringArrayType<'a>>( let start = start.unwrap_or(0); let count = count.unwrap_or(0); // To get substring, need to decode from 0 to start+count instead of start to start+count - start + count + start.saturating_add(count) }) - .sum::(); + .fold(0i64, |acc, val| acc.saturating_add(val)); - avg_prefix_len as f64 / n_sample as f64 <= short_prefix_threshold + (total_prefix_len as f64 / n_sample as f64) <= short_prefix_threshold } None => false, }; @@ -810,7 +813,7 @@ mod tests { SubstrFunc::new(), vec![ ColumnarValue::Scalar(ScalarValue::from("abc")), - ColumnarValue::Scalar(ScalarValue::from(-9223372036854775808i64)), + ColumnarValue::Scalar(ScalarValue::from(i64::MIN)), ], Ok(Some("abc")), &str, @@ -821,7 +824,7 @@ mod tests { SubstrFunc::new(), vec![ ColumnarValue::Scalar(ScalarValue::from("overflow")), - ColumnarValue::Scalar(ScalarValue::from(-9223372036854775808i64)), + ColumnarValue::Scalar(ScalarValue::from(i64::MIN)), ColumnarValue::Scalar(ScalarValue::from(1i64)), ], exec_err!("negative overflow when calculating skip value"), @@ -829,6 +832,18 @@ mod tests { Utf8View, StringViewArray ); + test_function!( + SubstrFunc::new(), + vec![ + 
ColumnarValue::Scalar(ScalarValue::from("large count")), + ColumnarValue::Scalar(ScalarValue::from(2i64)), + ColumnarValue::Scalar(ScalarValue::from(i64::MAX)), + ], + Ok(Some("arge count")), + &str, + Utf8View, + StringViewArray + ); Ok(()) } diff --git a/datafusion/functions/src/unicode/substrindex.rs b/datafusion/functions/src/unicode/substrindex.rs index cd9d0702b497..6389dc92c238 100644 --- a/datafusion/functions/src/unicode/substrindex.rs +++ b/datafusion/functions/src/unicode/substrindex.rs @@ -19,8 +19,8 @@ use std::any::Any; use std::sync::Arc; use arrow::array::{ - ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait, - PrimitiveArray, StringBuilder, + ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, + GenericStringBuilder, OffsetSizeTrait, PrimitiveArray, }; use arrow::datatypes::{DataType, Int32Type, Int64Type}; @@ -182,7 +182,8 @@ fn substr_index_general< where T::Native: OffsetSizeTrait, { - let mut builder = StringBuilder::new(); + let num_rows = string_array.len(); + let mut builder = GenericStringBuilder::::with_capacity(num_rows, 0); let string_iter = ArrayIter::new(string_array); let delimiter_array_iter = ArrayIter::new(delimiter_array); let count_array_iter = ArrayIter::new(count_array); @@ -198,31 +199,49 @@ where } let occurrences = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX); - let length = if n > 0 { - let split = string.split(delimiter); - split - .take(occurrences) - .map(|s| s.len() + delimiter.len()) - .sum::() - - delimiter.len() - } else { - let split = string.rsplit(delimiter); - split - .take(occurrences) - .map(|s| s.len() + delimiter.len()) - .sum::() - - delimiter.len() - }; - if n > 0 { - match string.get(..length) { - Some(substring) => builder.append_value(substring), - None => builder.append_null(), + let result_idx = if delimiter.len() == 1 { + // Fast path: use byte-level search for single-character delimiters + let d_byte = delimiter.as_bytes()[0]; + let bytes = 
string.as_bytes(); + + if n > 0 { + bytes + .iter() + .enumerate() + .filter(|&(_, &b)| b == d_byte) + .nth(occurrences - 1) + .map(|(idx, _)| idx) + } else { + bytes + .iter() + .enumerate() + .rev() + .filter(|&(_, &b)| b == d_byte) + .nth(occurrences - 1) + .map(|(idx, _)| idx + 1) } + } else if n > 0 { + // Multi-byte path: forward search for n-th occurrence + string + .match_indices(delimiter) + .nth(occurrences - 1) + .map(|(idx, _)| idx) } else { - match string.get(string.len().saturating_sub(length)..) { - Some(substring) => builder.append_value(substring), - None => builder.append_null(), + // Multi-byte path: backward search for n-th occurrence from the right + string + .rmatch_indices(delimiter) + .nth(occurrences - 1) + .map(|(idx, _)| idx + delimiter.len()) + }; + match result_idx { + Some(idx) => { + if n > 0 { + builder.append_value(&string[..idx]); + } else { + builder.append_value(&string[idx..]); + } } + None => builder.append_value(string), } } _ => builder.append_null(), @@ -328,7 +347,6 @@ mod tests { Utf8, StringArray ); - Ok(()) } } diff --git a/datafusion/functions/src/unicode/translate.rs b/datafusion/functions/src/unicode/translate.rs index f97c0ed5c299..e86eaf8111b1 100644 --- a/datafusion/functions/src/unicode/translate.rs +++ b/datafusion/functions/src/unicode/translate.rs @@ -35,8 +35,8 @@ use datafusion_macros::user_doc; #[user_doc( doc_section(label = "String Functions"), - description = "Translates characters in a string to specified translation characters.", - syntax_example = "translate(str, chars, translation)", + description = "Performs character-wise substitution based on a mapping.", + syntax_example = "translate(str, from, to)", sql_example = r#"```sql > select translate('twice', 'wic', 'her'); +--------------------------------------------------+ @@ -46,10 +46,10 @@ use datafusion_macros::user_doc; +--------------------------------------------------+ ```"#, standard_argument(name = "str", prefix = "String"), - argument(name = 
"chars", description = "Characters to translate."), + argument(name = "from", description = "The characters to be replaced."), argument( - name = "translation", - description = "Translation characters. Translation characters replace only characters at the same position in the **chars** string." + name = "to", + description = "The characters to replace them with. Each character in **from** that is found in **str** is replaced by the character at the same index in **to**. Any characters in **from** that don't have a corresponding character in **to** are removed. If a character appears more than once in **from**, the first occurrence determines the mapping." ) )] #[derive(Debug, PartialEq, Eq, Hash)] @@ -71,6 +71,7 @@ impl TranslateFunc { vec![ Exact(vec![Utf8View, Utf8, Utf8]), Exact(vec![Utf8, Utf8, Utf8]), + Exact(vec![LargeUtf8, Utf8, Utf8]), ], Volatility::Immutable, ), @@ -99,6 +100,61 @@ impl ScalarUDFImpl for TranslateFunc { &self, args: datafusion_expr::ScalarFunctionArgs, ) -> Result { + // When from and to are scalars, pre-build the translation map once + if let (Some(from_str), Some(to_str)) = ( + try_as_scalar_str(&args.args[1]), + try_as_scalar_str(&args.args[2]), + ) { + let to_graphemes: Vec<&str> = to_str.graphemes(true).collect(); + + let mut from_map: HashMap<&str, usize> = HashMap::new(); + for (index, c) in from_str.graphemes(true).enumerate() { + // Ignore characters that already exist in from_map + from_map.entry(c).or_insert(index); + } + + let ascii_table = build_ascii_translate_table(from_str, to_str); + + let string_array = args.args[0].to_array_of_size(args.number_rows)?; + + let result = match string_array.data_type() { + DataType::Utf8View => { + let arr = string_array.as_string_view(); + translate_with_map::( + arr, + &from_map, + &to_graphemes, + ascii_table.as_ref(), + ) + } + DataType::Utf8 => { + let arr = string_array.as_string::(); + translate_with_map::( + arr, + &from_map, + &to_graphemes, + ascii_table.as_ref(), + ) + } + 
DataType::LargeUtf8 => { + let arr = string_array.as_string::(); + translate_with_map::( + arr, + &from_map, + &to_graphemes, + ascii_table.as_ref(), + ) + } + other => { + return exec_err!( + "Unsupported data type {other:?} for function translate" + ); + } + }?; + + return Ok(ColumnarValue::Array(result)); + } + make_scalar_function(invoke_translate, vec![])(&args.args) } @@ -107,6 +163,14 @@ impl ScalarUDFImpl for TranslateFunc { } } +/// If `cv` is a non-null scalar string, return its value. +fn try_as_scalar_str(cv: &ColumnarValue) -> Option<&str> { + match cv { + ColumnarValue::Scalar(s) => s.try_as_str().flatten(), + _ => None, + } +} + fn invoke_translate(args: &[ArrayRef]) -> Result { match args[0].data_type() { DataType::Utf8View => { @@ -123,8 +187,8 @@ fn invoke_translate(args: &[ArrayRef]) -> Result { } DataType::LargeUtf8 => { let string_array = args[0].as_string::(); - let from_array = args[1].as_string::(); - let to_array = args[2].as_string::(); + let from_array = args[1].as_string::(); + let to_array = args[2].as_string::(); translate::(string_array, from_array, to_array) } other => { @@ -170,7 +234,7 @@ where // Build from_map using reusable buffer from_graphemes.extend(from.graphemes(true)); for (index, c) in from_graphemes.iter().enumerate() { - // Ignore characters that already exist in from_map, else insert + // Ignore characters that already exist in from_map from_map.entry(*c).or_insert(index); } @@ -199,6 +263,97 @@ where Ok(Arc::new(result) as ArrayRef) } +/// Sentinel value in the ASCII translate table indicating the character should +/// be deleted (the `from` character has no corresponding `to` character). Any +/// value > 127 works since valid ASCII is 0–127. +const ASCII_DELETE: u8 = 0xFF; + +/// If `from` and `to` are both ASCII, build a fixed-size lookup table for +/// translation. Each entry maps an input byte to its replacement byte, or to +/// [`ASCII_DELETE`] if the character should be removed. 
Returns `None` if +/// either string contains non-ASCII characters. +fn build_ascii_translate_table(from: &str, to: &str) -> Option<[u8; 128]> { + if !from.is_ascii() || !to.is_ascii() { + return None; + } + let mut table = [0u8; 128]; + for i in 0..128u8 { + table[i as usize] = i; + } + let to_bytes = to.as_bytes(); + let mut seen = [false; 128]; + for (i, from_byte) in from.bytes().enumerate() { + let idx = from_byte as usize; + if !seen[idx] { + seen[idx] = true; + if i < to_bytes.len() { + table[idx] = to_bytes[i]; + } else { + table[idx] = ASCII_DELETE; + } + } + } + Some(table) +} + +/// Optimized translate for constant `from` and `to` arguments: uses a pre-built +/// translation map instead of rebuilding it for every row. When an ASCII byte +/// lookup table is provided, ASCII input rows use the lookup table; non-ASCII +/// inputs fallback to using the map. +fn translate_with_map<'a, T: OffsetSizeTrait, V>( + string_array: V, + from_map: &HashMap<&str, usize>, + to_graphemes: &[&str], + ascii_table: Option<&[u8; 128]>, +) -> Result +where + V: ArrayAccessor, +{ + let mut result_graphemes: Vec<&str> = Vec::new(); + let mut ascii_buf: Vec = Vec::new(); + + let result = ArrayIter::new(string_array) + .map(|string| { + string.map(|s| { + // Fast path: byte-level table lookup for ASCII strings + if let Some(table) = ascii_table + && s.is_ascii() + { + ascii_buf.clear(); + for &b in s.as_bytes() { + let mapped = table[b as usize]; + if mapped != ASCII_DELETE { + ascii_buf.push(mapped); + } + } + // SAFETY: all bytes are ASCII, hence valid UTF-8. 
+ return unsafe { + std::str::from_utf8_unchecked(&ascii_buf).to_owned() + }; + } + + // Slow path: grapheme-based translation + result_graphemes.clear(); + + for c in s.graphemes(true) { + match from_map.get(c) { + Some(n) => { + if let Some(replacement) = to_graphemes.get(*n) { + result_graphemes.push(*replacement); + } + } + None => result_graphemes.push(c), + } + } + + result_graphemes.concat() + }) + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + #[cfg(test)] mod tests { use arrow::array::{Array, StringArray}; @@ -284,6 +439,21 @@ mod tests { Utf8, StringArray ); + // Non-ASCII input with ASCII scalar from/to: exercises the + // grapheme fallback within translate_with_map. + test_function!( + TranslateFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::from("café")), + ColumnarValue::Scalar(ScalarValue::from("ae")), + ColumnarValue::Scalar(ScalarValue::from("AE")) + ], + Ok(Some("cAfé")), + &str, + Utf8, + StringArray + ); + #[cfg(not(feature = "unicode_expressions"))] test_function!( TranslateFunc::new(), diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index e4980728b18a..b9bde1454994 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -147,7 +147,7 @@ where if scalar.is_null() { // Null scalar is castable to any numeric, creating a non-null expression. 
// Provide null array explicitly to make result null - PrimitiveArray::::new_null(1) + PrimitiveArray::::new_null(left.len()) } else { let right = R::Native::try_from(scalar.clone()).map_err(|_| { DataFusionError::NotImplemented(format!( @@ -363,12 +363,30 @@ pub mod test { }; } - use arrow::datatypes::DataType; + use arrow::{ + array::Int32Array, + datatypes::{DataType, Int32Type}, + }; use itertools::Either; pub(crate) use test_function; use super::*; + #[test] + fn test_calculate_binary_math_scalar_null() { + let left = Int32Array::from(vec![1, 2]); + let right = ColumnarValue::Scalar(ScalarValue::Int32(None)); + let result = calculate_binary_math::( + &left, + &right, + |x, y| Ok(x + y), + ) + .unwrap(); + + assert_eq!(result.len(), 2); + assert_eq!(result.null_count(), 2); + } + #[test] fn string_to_int_type() { let v = utf8_to_int_type(&DataType::Utf8, "test").unwrap(); diff --git a/datafusion/macros/Cargo.toml b/datafusion/macros/Cargo.toml index 85833bf11649..91f1dde62aaa 100644 --- a/datafusion/macros/Cargo.toml +++ b/datafusion/macros/Cargo.toml @@ -45,5 +45,5 @@ proc-macro = true [dependencies] datafusion-doc = { workspace = true } -quote = "1.0.41" -syn = { version = "2.0.113", features = ["full"] } +quote = "1.0.44" +syn = { version = "2.0.117", features = ["full"] } diff --git a/datafusion/macros/src/user_doc.rs b/datafusion/macros/src/user_doc.rs index 27f73fd95538..ce9e7d55ef10 100644 --- a/datafusion/macros/src/user_doc.rs +++ b/datafusion/macros/src/user_doc.rs @@ -20,7 +20,6 @@ html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] #![cfg_attr(docsrs, feature(doc_cfg))] -#![deny(clippy::allow_attributes)] extern crate proc_macro; use datafusion_doc::scalar_doc_sections::doc_sections_const; diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index 15d3261ca513..76d3f73f6876 100644 --- 
a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -55,7 +55,7 @@ itertools = { workspace = true } log = { workspace = true } recursive = { workspace = true, optional = true } regex = { workspace = true } -regex-syntax = "0.8.6" +regex-syntax = "0.8.9" [dev-dependencies] async-trait = { workspace = true } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 02395c76bdd9..ed04aa4285d1 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -36,22 +36,22 @@ use datafusion_common::{ }; use datafusion_expr::expr::{ self, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Exists, InList, - InSubquery, Like, ScalarFunction, Sort, WindowFunction, + InSubquery, Like, ScalarFunction, SetComparison, Sort, WindowFunction, }; use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; use datafusion_expr::expr_schema::cast_subquery; use datafusion_expr::logical_plan::Subquery; use datafusion_expr::type_coercion::binary::{comparison_coercion, like_coercion}; -use datafusion_expr::type_coercion::functions::fields_with_udf; +use datafusion_expr::type_coercion::functions::{UDFCoercionExt, fields_with_udf}; +use datafusion_expr::type_coercion::is_datetime; use datafusion_expr::type_coercion::other::{ get_coerce_type_for_case_expression, get_coerce_type_for_list, }; -use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_utf8view_or_large_utf8}; use datafusion_expr::utils::merge_schema; use datafusion_expr::{ - AggregateUDF, Cast, Expr, ExprSchemable, Join, Limit, LogicalPlan, Operator, - Projection, ScalarUDF, Union, WindowFrame, WindowFrameBound, WindowFrameUnits, - is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, lit, not, + Cast, Expr, ExprSchemable, Join, Limit, LogicalPlan, Operator, Projection, Union, + WindowFrame, WindowFrameBound, WindowFrameUnits, is_false, 
is_not_false, is_not_true, + is_not_unknown, is_true, is_unknown, lit, not, }; /// Performs type coercion by determining the schema @@ -500,6 +500,43 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { negated, )))) } + Expr::SetComparison(SetComparison { + expr, + subquery, + op, + quantifier, + }) => { + let new_plan = analyze_internal( + self.schema, + Arc::unwrap_or_clone(subquery.subquery), + )? + .data; + let expr_type = expr.get_type(self.schema)?; + let subquery_type = new_plan.schema().field(0).data_type(); + if (expr_type.is_numeric() && subquery_type.is_string()) + || (subquery_type.is_numeric() && expr_type.is_string()) + { + return plan_err!( + "expr type {expr_type} can't cast to {subquery_type} in SetComparison" + ); + } + let common_type = comparison_coercion(&expr_type, subquery_type).ok_or( + plan_datafusion_err!( + "expr type {expr_type} can't cast to {subquery_type} in SetComparison" + ), + )?; + let new_subquery = Subquery { + subquery: Arc::new(new_plan), + outer_ref_columns: subquery.outer_ref_columns, + spans: subquery.spans, + }; + Ok(Transformed::yes(Expr::SetComparison(SetComparison::new( + Box::new(expr.cast_to(&common_type, self.schema)?), + cast_subquery(new_subquery, &common_type)?, + op, + quantifier, + )))) + } Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op( *expr, self.schema, @@ -637,11 +674,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { Ok(Transformed::yes(Expr::Case(case))) } Expr::ScalarFunction(ScalarFunction { func, args }) => { - let new_expr = coerce_arguments_for_signature_with_scalar_udf( - args, - self.schema, - &func, - )?; + let new_expr = + coerce_arguments_for_signature(args, self.schema, func.as_ref())?; Ok(Transformed::yes(Expr::ScalarFunction( ScalarFunction::new_udf(func, new_expr), ))) @@ -657,11 +691,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { null_treatment, }, }) => { - let new_expr = coerce_arguments_for_signature_with_aggregate_udf( - args, - self.schema, 
- &func, - )?; + let new_expr = + coerce_arguments_for_signature(args, self.schema, func.as_ref())?; Ok(Transformed::yes(Expr::AggregateFunction( expr::AggregateFunction::new_udf( func, @@ -692,13 +723,11 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { let args = match &fun { expr::WindowFunctionDefinition::AggregateUDF(udf) => { - coerce_arguments_for_signature_with_aggregate_udf( - args, - self.schema, - udf, - )? + coerce_arguments_for_signature(args, self.schema, udf.as_ref())? + } + expr::WindowFunctionDefinition::WindowUDF(udf) => { + coerce_arguments_for_signature(args, self.schema, udf.as_ref())? } - _ => args, }; let new_expr = Expr::from(WindowFunction { @@ -859,12 +888,15 @@ fn coerce_frame_bound( fn extract_window_frame_target_type(col_type: &DataType) -> Result { if col_type.is_numeric() - || is_utf8_or_utf8view_or_large_utf8(col_type) - || matches!(col_type, DataType::List(_)) - || matches!(col_type, DataType::LargeList(_)) - || matches!(col_type, DataType::FixedSizeList(_, _)) - || matches!(col_type, DataType::Null) - || matches!(col_type, DataType::Boolean) + || col_type.is_string() + || col_type.is_null() + || matches!( + col_type, + DataType::List(_) + | DataType::LargeList(_) + | DataType::FixedSizeList(_, _) + | DataType::Boolean + ) { Ok(col_type.clone()) } else if is_datetime(col_type) { @@ -917,45 +949,11 @@ fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchema) -> Result { /// `signature`, if possible. /// /// See the module level documentation for more detail on coercion. -fn coerce_arguments_for_signature_with_scalar_udf( - expressions: Vec, - schema: &DFSchema, - func: &ScalarUDF, -) -> Result> { - if expressions.is_empty() { - return Ok(expressions); - } - - let current_fields = expressions - .iter() - .map(|e| e.to_field(schema).map(|(_, f)| f)) - .collect::>>()?; - - let coerced_types = fields_with_udf(¤t_fields, func)? 
- .into_iter() - .map(|f| f.data_type().clone()) - .collect::>(); - - expressions - .into_iter() - .enumerate() - .map(|(i, expr)| expr.cast_to(&coerced_types[i], schema)) - .collect() -} - -/// Returns `expressions` coerced to types compatible with -/// `signature`, if possible. -/// -/// See the module level documentation for more detail on coercion. -fn coerce_arguments_for_signature_with_aggregate_udf( +fn coerce_arguments_for_signature( expressions: Vec, schema: &DFSchema, - func: &AggregateUDF, + func: &F, ) -> Result> { - if expressions.is_empty() { - return Ok(expressions); - } - let current_fields = expressions .iter() .map(|e| e.to_field(schema).map(|(_, f)| f)) @@ -1890,7 +1888,7 @@ mod test { .err() .unwrap() .strip_backtrace(); - assert!(err.starts_with("Error during planning: Failed to coerce arguments to satisfy a call to 'avg' function: coercion from Utf8 to the signature Uniform(1, [Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64]) failed")); + assert!(err.starts_with("Error during planning: Failed to coerce arguments to satisfy a call to 'avg' function: coercion from Utf8 to the signature Uniform(1, [Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float16, Float32, Float64]) failed")); Ok(()) } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index d9273a8f60fb..2096c4277031 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -34,7 +34,9 @@ use datafusion_expr::expr::{Alias, ScalarFunction}; use datafusion_expr::logical_plan::{ Aggregate, Filter, LogicalPlan, Projection, Sort, Window, }; -use datafusion_expr::{BinaryExpr, Case, Expr, Operator, SortExpr, col}; +use datafusion_expr::{ + BinaryExpr, Case, Expr, ExpressionPlacement, Operator, SortExpr, col, +}; const CSE_PREFIX: &str = "__common_expr"; @@ -698,10 +700,27 @@ impl CSEController for ExprCSEController<'_> { 
} fn is_ignored(&self, node: &Expr) -> bool { + // MoveTowardsLeafNodes expressions (e.g. get_field) are cheap struct + // field accesses that the ExtractLeafExpressions / PushDownLeafProjections + // rules deliberately duplicate when needed (one copy for a filter + // predicate, another for an output column). CSE deduplicating them + // creates intermediate projections that fight with those rules, + // causing optimizer instability — ExtractLeafExpressions will undo + // the dedup, creating an infinite loop that runs until the iteration + // limit is hit. Skip them. + if node.placement() == ExpressionPlacement::MoveTowardsLeafNodes { + return true; + } + // TODO: remove the next line after `Expr::Wildcard` is removed #[expect(deprecated)] let is_normal_minus_aggregates = matches!( node, + // TODO: there's an argument for removing `Literal` from here, + // maybe using `Expr::placemement().should_push_to_leaves()` instead + // so that we extract common literals and don't broadcast them to num_batch_rows multiple times. + // However that currently breaks things like `percentile_cont()` which expect literal arguments + // (and would instead be getting `col(__common_expr_n)`). Expr::Literal(..) | Expr::Column(..) | Expr::ScalarVariable(..) @@ -825,6 +844,7 @@ mod test { use super::*; use crate::assert_optimized_plan_eq_snapshot; use crate::optimizer::OptimizerContext; + use crate::test::udfs::leaf_udf_expr; use crate::test::*; use datafusion_expr::test::function_stub::{avg, sum}; @@ -1826,4 +1846,56 @@ mod test { panic!("dummy - not implemented") } } + + /// Identical MoveTowardsLeafNodes expressions should NOT be deduplicated + /// by CSE — they are cheap (e.g. struct field access) and the extraction + /// rules deliberately duplicate them. Deduplicating causes optimizer + /// instability where one optimizer rule will undo the work of another, + /// resulting in an infinite optimization loop until the + /// we hit the max iteration limit and then give up. 
+ #[test] + fn test_leaf_expression_not_extracted() -> Result<()> { + let table_scan = test_table_scan()?; + + let leaf = leaf_udf_expr(col("a")); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![leaf.clone().alias("c1"), leaf.alias("c2")])? + .build()?; + + // Plan should be unchanged — no __common_expr introduced + assert_optimized_plan_equal!( + plan, + @r" + Projection: leaf_udf(test.a) AS c1, leaf_udf(test.a) AS c2 + TableScan: test + " + ) + } + + /// When a MoveTowardsLeafNodes expression appears as a sub-expression of + /// a larger expression that IS duplicated, only the outer expression gets + /// deduplicated; the leaf sub-expression stays inline. + #[test] + fn test_leaf_subexpression_not_extracted() -> Result<()> { + let table_scan = test_table_scan()?; + + // leaf_udf(a) + b appears twice — the outer `+` is a common + // sub-expression, but leaf_udf(a) by itself is MoveTowardsLeafNodes + // and should not be extracted separately. + let common = leaf_udf_expr(col("a")) + col("b"); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![common.clone().alias("c1"), common.alias("c2")])? + .build()?; + + // The whole `leaf_udf(a) + b` gets deduplicated as __common_expr_1, + // but leaf_udf(a) alone is NOT pulled out. 
+ assert_optimized_plan_equal!( + plan, + @r" + Projection: __common_expr_1 AS c1, __common_expr_1 AS c2 + Projection: leaf_udf(test.a) + test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) + } } diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index e8a9c8c83ae9..52d777f874fa 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -36,7 +36,6 @@ use datafusion_expr::{ BinaryExpr, Cast, EmptyRelation, Expr, FetchType, LogicalPlan, LogicalPlanBuilder, Operator, expr, lit, }; -use datafusion_physical_expr::execution_props::ExecutionProps; /// This struct rewrite the sub query plan by pull up the correlated /// expressions(contains outer reference columns) from the inner subquery's @@ -509,8 +508,7 @@ fn agg_exprs_evaluation_result_on_empty_batch( .data()?; let result_expr = result_expr.unalias(); - let props = ExecutionProps::new(); - let info = SimplifyContext::new(&props).with_schema(Arc::clone(schema)); + let info = SimplifyContext::default().with_schema(Arc::clone(schema)); let simplifier = ExprSimplifier::new(info); let result_expr = simplifier.simplify(result_expr)?; expr_result_map_for_count_bug.insert(e.schema_name().to_string(), result_expr); @@ -543,8 +541,7 @@ fn proj_exprs_evaluation_result_on_empty_batch( .data()?; if result_expr.ne(expr) { - let props = ExecutionProps::new(); - let info = SimplifyContext::new(&props).with_schema(Arc::clone(schema)); + let info = SimplifyContext::default().with_schema(Arc::clone(schema)); let simplifier = ExprSimplifier::new(info); let result_expr = simplifier.simplify(result_expr)?; let expr_name = match expr { @@ -584,8 +581,7 @@ fn filter_exprs_evaluation_result_on_empty_batch( .data()?; let pull_up_expr = if result_expr.ne(filter_expr) { - let props = ExecutionProps::new(); - let info = SimplifyContext::new(&props).with_schema(schema); + let info = SimplifyContext::default().with_schema(schema); let 
simplifier = ExprSimplifier::new(info); let result_expr = simplifier.simplify(result_expr)?; match &result_expr { diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index c8acb044876c..281d2d73481d 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -27,7 +27,10 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{Column, Result, assert_or_internal_err, plan_err}; +use datafusion_common::{ + Column, DFSchemaRef, ExprSchema, NullEquality, Result, assert_or_internal_err, + plan_err, +}; use datafusion_expr::expr::{Exists, InSubquery}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; use datafusion_expr::logical_plan::{JoinType, Subquery}; @@ -310,6 +313,39 @@ fn mark_join( ) } +/// Check if join keys in the join filter may contain NULL values +/// +/// Returns true if any join key column is nullable on either side. +/// This is used to optimize null-aware anti joins: if all join keys are non-nullable, +/// we can use a regular anti join instead of the more expensive null-aware variant. 
+fn join_keys_may_be_null( + join_filter: &Expr, + left_schema: &DFSchemaRef, + right_schema: &DFSchemaRef, +) -> Result { + // Extract columns from the join filter + let mut columns = std::collections::HashSet::new(); + expr_to_columns(join_filter, &mut columns)?; + + // Check if any column is nullable + for col in columns { + // Check in left schema + if let Ok(field) = left_schema.field_from_column(&col) + && field.as_ref().is_nullable() + { + return Ok(true); + } + // Check in right schema + if let Ok(field) = right_schema.field_from_column(&col) + && field.as_ref().is_nullable() + { + return Ok(true); + } + } + + Ok(false) +} + fn build_join( left: &LogicalPlan, subquery: &LogicalPlan, @@ -403,6 +439,8 @@ fn build_join( // Degenerate case: no right columns referenced by the predicate(s) sub_query_alias.clone() }; + + // Mark joins don't use null-aware semantics (they use three-valued logic with mark column) let new_plan = LogicalPlanBuilder::from(left.clone()) .join_on(right_projected, join_type, Some(join_filter))? .build()?; @@ -415,10 +453,36 @@ fn build_join( return Ok(Some(new_plan)); } + // Determine if this should be a null-aware anti join + // Null-aware semantics are only needed for NOT IN subqueries, not NOT EXISTS: + // - NOT IN: Uses three-valued logic, requires null-aware handling + // - NOT EXISTS: Uses two-valued logic, regular anti join is correct + // We can distinguish them: NOT IN has in_predicate_opt, NOT EXISTS does not + // + // Additionally, if the join keys are non-nullable on both sides, we don't need + // null-aware semantics because NULLs cannot exist in the data. + let null_aware = join_type == JoinType::LeftAnti + && in_predicate_opt.is_some() + && join_keys_may_be_null(&join_filter, left.schema(), sub_query_alias.schema())?; + // join our sub query into the main plan - let new_plan = LogicalPlanBuilder::from(left.clone()) - .join_on(sub_query_alias, join_type, Some(join_filter))? 
- .build()?; + let new_plan = if null_aware { + // Use join_detailed_with_options to set null_aware flag + LogicalPlanBuilder::from(left.clone()) + .join_detailed_with_options( + sub_query_alias, + join_type, + (Vec::::new(), Vec::::new()), // No equijoin keys, filter-based join + Some(join_filter), + NullEquality::NullEqualsNothing, + true, // null_aware + )? + .build()? + } else { + LogicalPlanBuilder::from(left.clone()) + .join_on(sub_query_alias, join_type, Some(join_filter))? + .build()? + }; debug!( "predicate subquery optimized:\n{}", new_plan.display_indent() @@ -1977,7 +2041,7 @@ mod tests { TableScan: test [a:UInt32, b:UInt32, c:UInt32] SubqueryAlias: __correlated_sq_1 [arr:Int32;N] Unnest: lists[sq.arr|depth=1] structs[] [arr:Int32;N] - TableScan: sq [arr:List(Field { data_type: Int32, nullable: true });N] + TableScan: sq [arr:List(Int32);N] " ) } @@ -2012,7 +2076,7 @@ mod tests { TableScan: test [a:UInt32, b:UInt32, c:UInt32] SubqueryAlias: __correlated_sq_1 [a:UInt32;N] Unnest: lists[sq.a|depth=1] structs[] [a:UInt32;N] - TableScan: sq [a:List(Field { data_type: UInt32, nullable: true });N] + TableScan: sq [a:List(UInt32);N] " ) } diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 770291566346..3cb0516a6d29 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -341,6 +341,7 @@ fn find_inner_join( filter: None, schema: join_schema, null_equality, + null_aware: false, })); } } @@ -363,6 +364,7 @@ fn find_inner_join( join_type: JoinType::Inner, join_constraint: JoinConstraint::On, null_equality, + null_aware: false, })) } @@ -522,7 +524,7 @@ mod tests { plan, @ r" Filter: t1.a = t2.a OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 
[a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] " @@ -608,7 +610,7 @@ mod tests { plan, @ r" Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.b = t2.b AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] " @@ -634,7 +636,7 @@ mod tests { plan, @ r" Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] " @@ -856,7 +858,7 @@ mod tests { plan, @ r" Filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] @@ -936,7 +938,7 @@ mod tests { TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] Filter: t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Cross Join: [a:UInt32, 
b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] TableScan: t4 [a:UInt32, b:UInt32, c:UInt32] " @@ -1010,7 +1012,7 @@ mod tests { Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] Filter: t1.a = t2.a OR t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] @@ -1246,7 +1248,7 @@ mod tests { plan, @ r" Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] " @@ -1367,6 +1369,7 @@ mod tests { filter: None, schema: join_schema, null_equality: NullEquality::NullEqualsNull, // Test preservation + null_aware: false, }); // Apply filter that can create join conditions diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index 2c78051c1413..58abe38d04bc 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -119,6 +119,7 @@ impl OptimizerRule for 
EliminateOuterJoin { filter: join.filter.clone(), schema: Arc::clone(&join.schema), null_equality: join.null_equality, + null_aware: join.null_aware, })); Filter::try_new(filter.predicate, new_join) .map(|f| Transformed::yes(LogicalPlan::Filter(f))) diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs index a623faf8a2ff..0a50761e8a9f 100644 --- a/datafusion/optimizer/src/extract_equijoin_predicate.rs +++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs @@ -76,6 +76,7 @@ impl OptimizerRule for ExtractEquijoinPredicate { join_constraint, schema, null_equality, + null_aware, }) => { let left_schema = left.schema(); let right_schema = right.schema(); @@ -117,6 +118,7 @@ impl OptimizerRule for ExtractEquijoinPredicate { // According to `is not distinct from`'s semantics, it's // safe to override it null_equality: NullEquality::NullEqualsNull, + null_aware, }))); } } @@ -132,6 +134,7 @@ impl OptimizerRule for ExtractEquijoinPredicate { join_constraint, schema, null_equality, + null_aware, }))) } else { Ok(Transformed::no(LogicalPlan::Join(Join { @@ -143,6 +146,7 @@ impl OptimizerRule for ExtractEquijoinPredicate { join_constraint, schema, null_equality, + null_aware, }))) } } diff --git a/datafusion/optimizer/src/extract_leaf_expressions.rs b/datafusion/optimizer/src/extract_leaf_expressions.rs new file mode 100644 index 000000000000..922ea7933781 --- /dev/null +++ b/datafusion/optimizer/src/extract_leaf_expressions.rs @@ -0,0 +1,3053 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Two-pass optimizer pipeline that pushes cheap expressions (like struct field +//! access `user['status']`) closer to data sources, enabling early data reduction +//! and source-level optimizations (e.g., Parquet column pruning). See +//! [`ExtractLeafExpressions`] (pass 1) and [`PushDownLeafProjections`] (pass 2). + +use indexmap::{IndexMap, IndexSet}; +use std::collections::HashMap; +use std::sync::Arc; + +use datafusion_common::alias::AliasGenerator; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; +use datafusion_common::{Column, DFSchema, Result, qualified_name}; +use datafusion_expr::logical_plan::LogicalPlan; +use datafusion_expr::{Expr, ExpressionPlacement, Projection}; + +use crate::optimizer::ApplyOrder; +use crate::push_down_filter::replace_cols_by_name; +use crate::utils::has_all_column_refs; +use crate::{OptimizerConfig, OptimizerRule}; + +/// Prefix for aliases generated by the extraction optimizer passes. +/// +/// This prefix is **reserved for internal optimizer use**. User-defined aliases +/// starting with this prefix may be misidentified as optimizer-generated +/// extraction aliases, leading to unexpected behavior. Do not use this prefix +/// in user queries. +const EXTRACTED_EXPR_PREFIX: &str = "__datafusion_extracted"; + +/// Returns `true` if any sub-expression in `exprs` has +/// [`ExpressionPlacement::MoveTowardsLeafNodes`] placement. 
+/// +/// This is a lightweight pre-check that short-circuits as soon as one +/// extractable expression is found, avoiding the expensive allocations +/// (column HashSets, extractors, expression rewrites) that the full +/// extraction pipeline requires. +fn has_extractable_expr(exprs: &[Expr]) -> bool { + exprs.iter().any(|expr| { + expr.exists(|e| Ok(e.placement() == ExpressionPlacement::MoveTowardsLeafNodes)) + .unwrap_or(false) + }) +} + +/// Extracts `MoveTowardsLeafNodes` sub-expressions from non-projection nodes +/// into **extraction projections** (pass 1 of 2). +/// +/// This handles Filter, Sort, Limit, Aggregate, and Join nodes. For Projection +/// nodes, extraction and pushdown are handled by [`PushDownLeafProjections`]. +/// +/// # Key Concepts +/// +/// **Extraction projection**: a projection inserted *below* a node that +/// pre-computes a cheap expression and exposes it under an alias +/// (`__datafusion_extracted_N`). The parent node then references the alias +/// instead of the original expression. +/// +/// **Recovery projection**: a projection inserted *above* a node to restore +/// the original output schema when extraction changes it. +/// Schema-preserving nodes (Filter, Sort, Limit) gain extra columns from +/// the extraction projection that bubble up; the recovery projection selects +/// only the original columns to hide the extras. +/// +/// # Example +/// +/// Given a filter with a struct field access: +/// +/// ```text +/// Filter: user['status'] = 'active' +/// TableScan: t [id, user] +/// ``` +/// +/// This rule: +/// 1. Inserts an **extraction projection** below the filter: +/// 2. 
Adds a **recovery projection** above to hide the extra column: +/// +/// ```text +/// Projection: id, user <-- recovery projection +/// Filter: __datafusion_extracted_1 = 'active' +/// Projection: user['status'] AS __datafusion_extracted_1, id, user <-- extraction projection +/// TableScan: t [id, user] +/// ``` +/// +/// **Important:** The `PushDownFilter` rule is aware of projections created by this rule +/// and will not push filters through them. It uses `ExpressionPlacement` to detect +/// `MoveTowardsLeafNodes` expressions and skip filter pushdown past them. +#[derive(Default, Debug)] +pub struct ExtractLeafExpressions {} + +impl ExtractLeafExpressions { + /// Create a new [`ExtractLeafExpressions`] + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for ExtractLeafExpressions { + fn name(&self) -> &str { + "extract_leaf_expressions" + } + + fn rewrite( + &self, + plan: LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result> { + if !config.options().optimizer.enable_leaf_expression_pushdown { + return Ok(Transformed::no(plan)); + } + let alias_generator = config.alias_generator(); + + // Advance the alias generator past any user-provided __datafusion_extracted_N + // aliases to prevent collisions when generating new extraction aliases. + advance_generator_past_existing(&plan, alias_generator)?; + + plan.transform_down_with_subqueries(|plan| { + extract_from_plan(plan, alias_generator) + }) + } +} + +/// Scans the current plan node's expressions for pre-existing +/// `__datafusion_extracted_N` aliases and advances the generator +/// counter past them to avoid collisions with user-provided aliases. 
+fn advance_generator_past_existing( + plan: &LogicalPlan, + alias_generator: &AliasGenerator, +) -> Result<()> { + plan.apply(|plan| { + plan.expressions().iter().try_for_each(|expr| { + expr.apply(|e| { + if let Expr::Alias(alias) = e + && let Some(id) = alias + .name + .strip_prefix(EXTRACTED_EXPR_PREFIX) + .and_then(|s| s.strip_prefix('_')) + .and_then(|s| s.parse().ok()) + { + alias_generator.update_min_id(id); + } + Ok(TreeNodeRecursion::Continue) + })?; + Ok::<(), datafusion_common::error::DataFusionError>(()) + })?; + Ok(TreeNodeRecursion::Continue) + }) + .map(|_| ()) +} + +/// Extracts `MoveTowardsLeafNodes` sub-expressions from a plan node. +/// +/// Works for any number of inputs (0, 1, 2, …N). For multi-input nodes +/// like Join, each extracted sub-expression is routed to the correct input +/// by checking which input's schema contains all of the expression's column +/// references. +fn extract_from_plan( + plan: LogicalPlan, + alias_generator: &Arc, +) -> Result> { + // Only extract from plan types whose output schema is predictable after + // expression rewriting. Nodes like Window derive column names from + // their expressions, so rewriting `get_field` inside a window function + // changes the output schema and breaks the recovery projection. + if !matches!( + &plan, + LogicalPlan::Aggregate(_) + | LogicalPlan::Filter(_) + | LogicalPlan::Sort(_) + | LogicalPlan::Limit(_) + | LogicalPlan::Join(_) + ) { + return Ok(Transformed::no(plan)); + } + + let inputs = plan.inputs(); + if inputs.is_empty() { + return Ok(Transformed::no(plan)); + } + + // Fast pre-check: skip all allocations if no extractable expressions exist + if !has_extractable_expr(&plan.expressions()) { + return Ok(Transformed::no(plan)); + } + + // Save original output schema before any transformation + let original_schema = Arc::clone(plan.schema()); + + // Build per-input schemas from borrowed inputs (before plan is consumed + // by map_expressions). 
We only need schemas and column sets for routing; + // the actual inputs are cloned later only if extraction succeeds. + let input_schemas: Vec> = + inputs.iter().map(|i| Arc::clone(i.schema())).collect(); + + // Build per-input extractors + let mut extractors: Vec = input_schemas + .iter() + .map(|schema| LeafExpressionExtractor::new(schema.as_ref(), alias_generator)) + .collect(); + + // Build per-input column sets for routing expressions to the correct input + let input_column_sets: Vec> = input_schemas + .iter() + .map(|schema| schema_columns(schema.as_ref())) + .collect(); + + // Transform expressions via map_expressions with routing + let transformed = plan.map_expressions(|expr| { + routing_extract(expr, &mut extractors, &input_column_sets) + })?; + + // If no expressions were rewritten, nothing was extracted + if !transformed.transformed { + return Ok(transformed); + } + + // Clone inputs now that we know extraction succeeded. Wrap in Arc + // upfront since build_extraction_projection expects &Arc. + let owned_inputs: Vec> = transformed + .data + .inputs() + .into_iter() + .map(|i| Arc::new(i.clone())) + .collect(); + + // Build per-input extraction projections (None means no extractions for that input) + let new_inputs: Vec = owned_inputs + .into_iter() + .zip(extractors.iter()) + .map(|(input_arc, extractor)| { + match extractor.build_extraction_projection(&input_arc)? { + Some(plan) => Ok(plan), + // No extractions for this input — recover the LogicalPlan + // without cloning (refcount is 1 since build returned None). + None => { + Ok(Arc::try_unwrap(input_arc).unwrap_or_else(|arc| (*arc).clone())) + } + } + }) + .collect::>>()?; + + // Rebuild the plan keeping its rewritten expressions but replacing + // inputs with the new extraction projections. 
+ let new_plan = transformed + .data + .with_new_exprs(transformed.data.expressions(), new_inputs)?; + + // Add recovery projection if the output schema changed + let recovered = build_recovery_projection(original_schema.as_ref(), new_plan)?; + + Ok(Transformed::yes(recovered)) +} + +/// Given an expression, returns the index of the input whose columns fully +/// cover the expression's column references. +/// Returns `None` if the expression references columns from multiple inputs +/// or if multiple inputs match (ambiguous, e.g. unqualified columns present +/// in both sides of a join). +fn find_owning_input( + expr: &Expr, + input_column_sets: &[std::collections::HashSet], +) -> Option { + let mut found = None; + for (idx, cols) in input_column_sets.iter().enumerate() { + if has_all_column_refs(expr, cols) { + if found.is_some() { + // Ambiguous — multiple inputs match + return None; + } + found = Some(idx); + } + } + found +} + +/// Walks an expression tree top-down, extracting `MoveTowardsLeafNodes` +/// sub-expressions and routing each to the correct per-input extractor. 
+fn routing_extract( + expr: Expr, + extractors: &mut [LeafExpressionExtractor], + input_column_sets: &[std::collections::HashSet], +) -> Result> { + expr.transform_down(|e| { + // Skip expressions already aliased with extracted expression pattern + if let Expr::Alias(alias) = &e + && alias.name.starts_with(EXTRACTED_EXPR_PREFIX) + { + return Ok(Transformed { + data: e, + transformed: false, + tnr: TreeNodeRecursion::Jump, + }); + } + + // Don't extract Alias nodes directly — preserve the alias and let + // transform_down recurse into the inner expression + if matches!(&e, Expr::Alias(_)) { + return Ok(Transformed::no(e)); + } + + match e.placement() { + ExpressionPlacement::MoveTowardsLeafNodes => { + if let Some(idx) = find_owning_input(&e, input_column_sets) { + let col_ref = extractors[idx].add_extracted(e)?; + Ok(Transformed::yes(col_ref)) + } else { + // References columns from multiple inputs — cannot extract + Ok(Transformed::no(e)) + } + } + ExpressionPlacement::Column => { + // Track columns that the parent node references so the + // extraction projection includes them as pass-through. + // Without this, the extraction projection would only + // contain __datafusion_extracted_N aliases, and the parent couldn't + // resolve its other column references. + if let Expr::Column(col) = &e + && let Some(idx) = find_owning_input(&e, input_column_sets) + { + extractors[idx].columns_needed.insert(col.clone()); + } + Ok(Transformed::no(e)) + } + _ => Ok(Transformed::no(e)), + } + }) +} + +/// Returns all columns in the schema (both qualified and unqualified forms) +fn schema_columns(schema: &DFSchema) -> std::collections::HashSet { + schema + .iter() + .flat_map(|(qualifier, field)| { + [ + Column::new(qualifier.cloned(), field.name()), + Column::new_unqualified(field.name()), + ] + }) + .collect() +} + +/// Rewrites extraction pairs and column references from one qualifier +/// space to another. 
+/// +/// Builds a replacement map by zipping `from_schema` (whose qualifiers +/// currently appear in `pairs` / `columns`) with `to_schema` (the +/// qualifiers we want), then applies `replace_cols_by_name`. +/// +/// Used for SubqueryAlias (alias-space -> input-space) and Union +/// (union output-space -> per-branch input-space). +fn remap_pairs_and_columns( + pairs: &[(Expr, String)], + columns: &IndexSet, + from_schema: &DFSchema, + to_schema: &DFSchema, +) -> Result { + let mut replace_map = HashMap::new(); + for ((from_q, from_f), (to_q, to_f)) in from_schema.iter().zip(to_schema.iter()) { + replace_map.insert( + qualified_name(from_q, from_f.name()), + Expr::Column(Column::new(to_q.cloned(), to_f.name())), + ); + } + let remapped_pairs: Vec<(Expr, String)> = pairs + .iter() + .map(|(expr, alias)| { + Ok(( + replace_cols_by_name(expr.clone(), &replace_map)?, + alias.clone(), + )) + }) + .collect::>()?; + let remapped_columns: IndexSet = columns + .iter() + .filter_map(|col| { + let rewritten = + replace_cols_by_name(Expr::Column(col.clone()), &replace_map).ok()?; + if let Expr::Column(c) = rewritten { + Some(c) + } else { + Some(col.clone()) + } + }) + .collect(); + Ok(ExtractionTarget { + pairs: remapped_pairs, + columns: remapped_columns, + }) +} + +// ============================================================================= +// Helper Types & Functions for Extraction Targeting +// ============================================================================= + +/// A bundle of extraction pairs (expression + alias) and standalone columns +/// that need to be pushed through a plan node. +struct ExtractionTarget { + /// Extracted expressions paired with their generated aliases. + pairs: Vec<(Expr, String)>, + /// Standalone column references needed by the parent node. + columns: IndexSet, +} + +/// Build a replacement map from a projection: output_column_name -> underlying_expr. 
+/// +/// This is used to resolve column references through a renaming projection. +/// For example, if a projection has `user AS x`, this maps `x` -> `col("user")`. +fn build_projection_replace_map(projection: &Projection) -> HashMap { + projection + .schema + .iter() + .zip(projection.expr.iter()) + .map(|((qualifier, field), expr)| { + let key = Column::from((qualifier, field)).flat_name(); + (key, expr.clone().unalias()) + }) + .collect() +} + +/// Build a recovery projection to restore the original output schema. +/// +/// After extraction, a node's output schema may differ from the original: +/// +/// - **Schema-preserving nodes** (Filter/Sort/Limit): the extraction projection +/// below adds extra `__datafusion_extracted_N` columns that bubble up through +/// the node. Recovery selects only the original columns to hide the extras. +/// ```text +/// Original schema: [id, user] +/// After extraction: [__datafusion_extracted_1, id, user] ← extra column leaked through +/// Recovery: SELECT id, user FROM ... ← hides __datafusion_extracted_1 +/// ``` +/// +/// - **Schema-defining nodes** (Aggregate): same number of columns but names +/// may differ because extracted aliases replaced the original expressions. +/// Recovery maps positionally, aliasing where names changed. +/// ```text +/// Original: [SUM(user['balance'])] +/// After: [SUM(__datafusion_extracted_1)] ← name changed +/// Recovery: SUM(__datafusion_extracted_1) AS "SUM(user['balance'])" +/// ``` +/// +/// - **Schemas identical** → no recovery projection needed. 
+fn build_recovery_projection( + original_schema: &DFSchema, + input: LogicalPlan, +) -> Result { + let new_schema = input.schema(); + let orig_len = original_schema.fields().len(); + let new_len = new_schema.fields().len(); + + if orig_len == new_len { + // Same number of fields — check if schemas are identical + let schemas_match = original_schema.iter().zip(new_schema.iter()).all( + |((orig_q, orig_f), (new_q, new_f))| { + orig_f.name() == new_f.name() && orig_q == new_q + }, + ); + if schemas_match { + return Ok(input); + } + + // Schema-defining nodes (Aggregate, Join): names may differ at some + // positions because extracted aliases replaced the original expressions. + // Map positionally, aliasing where the name changed. + // + // Invariant: `with_new_exprs` on all supported node types (Aggregate, + // Filter, Sort, Limit, Join) preserves column order, so positional + // mapping is safe here. + debug_assert!( + orig_len == new_len, + "build_recovery_projection: positional mapping requires same field count, \ + got original={orig_len} vs new={new_len}" + ); + let mut proj_exprs = Vec::with_capacity(orig_len); + for (i, (orig_qualifier, orig_field)) in original_schema.iter().enumerate() { + let (new_qualifier, new_field) = new_schema.qualified_field(i); + if orig_field.name() == new_field.name() && orig_qualifier == new_qualifier { + proj_exprs.push(Expr::from((orig_qualifier, orig_field))); + } else { + let new_col = Expr::Column(Column::from((new_qualifier, new_field))); + proj_exprs.push( + new_col.alias_qualified(orig_qualifier.cloned(), orig_field.name()), + ); + } + } + let projection = Projection::try_new(proj_exprs, Arc::new(input))?; + Ok(LogicalPlan::Projection(projection)) + } else { + // Schema-preserving nodes: new schema has extra extraction columns. + // Original columns still exist by name; select them to hide extras. 
+ let col_exprs: Vec = original_schema.iter().map(Expr::from).collect(); + let projection = Projection::try_new(col_exprs, Arc::new(input))?; + Ok(LogicalPlan::Projection(projection)) + } +} + +/// Collects `MoveTowardsLeafNodes` sub-expressions found during expression +/// tree traversal and can build an extraction projection from them. +/// +/// # Example +/// +/// Given `Filter: user['status'] = 'active' AND user['name'] IS NOT NULL`: +/// - `add_extracted(user['status'])` → stores it, returns `col("__datafusion_extracted_1")` +/// - `add_extracted(user['name'])` → stores it, returns `col("__datafusion_extracted_2")` +/// - `build_extraction_projection()` produces: +/// `Projection: user['status'] AS __datafusion_extracted_1, user['name'] AS __datafusion_extracted_2, ` +struct LeafExpressionExtractor<'a> { + /// Extracted expressions: maps expression -> alias + extracted: IndexMap, + /// Columns referenced by extracted expressions or the parent node, + /// included as pass-through in the extraction projection. + columns_needed: IndexSet, + /// Input schema + input_schema: &'a DFSchema, + /// Alias generator + alias_generator: &'a Arc, +} + +impl<'a> LeafExpressionExtractor<'a> { + fn new(input_schema: &'a DFSchema, alias_generator: &'a Arc) -> Self { + Self { + extracted: IndexMap::new(), + columns_needed: IndexSet::new(), + input_schema, + alias_generator, + } + } + + /// Adds an expression to extracted set, returns column reference. 
+ fn add_extracted(&mut self, expr: Expr) -> Result { + // Deduplication: reuse existing alias if same expression + if let Some(alias) = self.extracted.get(&expr) { + return Ok(Expr::Column(Column::new_unqualified(alias))); + } + + // Track columns referenced by this expression + for col in expr.column_refs() { + self.columns_needed.insert(col.clone()); + } + + // Generate unique alias + let alias = self.alias_generator.next(EXTRACTED_EXPR_PREFIX); + self.extracted.insert(expr, alias.clone()); + + Ok(Expr::Column(Column::new_unqualified(&alias))) + } + + /// Builds an extraction projection above the given input, or merges into + /// it if the input is already a projection. Delegates to + /// [`build_extraction_projection_impl`]. + /// + /// Returns `None` if there are no extractions. + fn build_extraction_projection( + &self, + input: &Arc, + ) -> Result> { + if self.extracted.is_empty() { + return Ok(None); + } + let pairs: Vec<(Expr, String)> = self + .extracted + .iter() + .map(|(e, a)| (e.clone(), a.clone())) + .collect(); + let proj = build_extraction_projection_impl( + &pairs, + &self.columns_needed, + input, + self.input_schema, + )?; + Ok(Some(LogicalPlan::Projection(proj))) + } +} + +/// Build an extraction projection above the target node (shared by both passes). +/// +/// If the target is an existing projection, merges into it. This requires +/// resolving column references through the projection's rename mapping: +/// if the projection has `user AS u`, and an extracted expression references +/// `u['name']`, we must rewrite it to `user['name']` since the merged +/// projection reads from the same input as the original. +/// +/// Deduplicates by resolved expression equality and adds pass-through +/// columns as needed. Otherwise builds a fresh projection with extracted +/// expressions + ALL input schema columns. 
+fn build_extraction_projection_impl( + extracted_exprs: &[(Expr, String)], + columns_needed: &IndexSet, + target: &Arc, + target_schema: &DFSchema, +) -> Result { + if let LogicalPlan::Projection(existing) = target.as_ref() { + // Merge into existing projection + let mut proj_exprs = existing.expr.clone(); + + // Build a map of existing expressions (by Expr equality) to their aliases + let existing_extractions: IndexMap = existing + .expr + .iter() + .filter_map(|e| { + if let Expr::Alias(alias) = e + && alias.name.starts_with(EXTRACTED_EXPR_PREFIX) + { + return Some((*alias.expr.clone(), alias.name.clone())); + } + None + }) + .collect(); + + // Resolve column references through the projection's rename mapping + let replace_map = build_projection_replace_map(existing); + + // Add new extracted expressions, resolving column refs through the projection + for (expr, alias) in extracted_exprs { + let resolved = replace_cols_by_name(expr.clone().alias(alias), &replace_map)?; + let resolved_inner = if let Expr::Alias(a) = &resolved { + a.expr.as_ref() + } else { + &resolved + }; + if let Some(existing_alias) = existing_extractions.get(resolved_inner) { + // Same expression already extracted under a different alias — + // add the expression with the new alias so both names are + // available in the output. We can't reference the existing alias + // as a column within the same projection, so we duplicate the + // computation. + if existing_alias != alias { + proj_exprs.push(resolved); + } + } else { + proj_exprs.push(resolved); + } + } + + // Add any new pass-through columns that aren't already in the projection. + // We check against existing.input.schema() (the projection's source) rather + // than target_schema (the projection's output) because columns produced + // by alias expressions (e.g., CSE's __common_expr_N) exist in the output but + // not the input, and cannot be added as pass-through Column references. 
+ let existing_cols: IndexSet = existing + .expr + .iter() + .filter_map(|e| { + if let Expr::Column(c) = e { + Some(c.clone()) + } else { + None + } + }) + .collect(); + + let input_schema = existing.input.schema(); + for col in columns_needed { + let col_expr = Expr::Column(col.clone()); + let resolved = replace_cols_by_name(col_expr, &replace_map)?; + if let Expr::Column(resolved_col) = &resolved + && !existing_cols.contains(resolved_col) + && input_schema.has_column(resolved_col) + { + proj_exprs.push(Expr::Column(resolved_col.clone())); + } + // If resolved to non-column expr, it's already computed by existing projection + } + + Projection::try_new(proj_exprs, Arc::clone(&existing.input)) + } else { + // Build new projection with extracted expressions + all input columns + let mut proj_exprs = Vec::new(); + for (expr, alias) in extracted_exprs { + proj_exprs.push(expr.clone().alias(alias)); + } + for (qualifier, field) in target_schema.iter() { + proj_exprs.push(Expr::from((qualifier, field))); + } + Projection::try_new(proj_exprs, Arc::clone(target)) + } +} + +// ============================================================================= +// Pass 2: PushDownLeafProjections +// ============================================================================= + +/// Pushes extraction projections down through schema-preserving nodes towards +/// leaf nodes (pass 2 of 2, after [`ExtractLeafExpressions`]). +/// +/// Handles two types of projections: +/// - **Pure extraction projections** (all `__datafusion_extracted` aliases + columns): +/// pushes through Filter/Sort/Limit, merges into existing projections, or routes +/// into multi-input node inputs (Join, SubqueryAlias, etc.) +/// - **Mixed projections** (user projections containing `MoveTowardsLeafNodes` +/// sub-expressions): splits into a recovery projection + extraction projection, +/// then pushes the extraction projection down. 
+/// +/// # Example: Pushing through a Filter +/// +/// After pass 1, the extraction projection sits directly below the filter: +/// ```text +/// Projection: id, user <-- recovery +/// Filter: __datafusion_extracted_1 = 'active' +/// Projection: user['status'] AS __datafusion_extracted_1, id, user <-- extraction +/// TableScan: t [id, user] +/// ``` +/// +/// Pass 2 pushes the extraction projection through the recovery and filter, +/// and a subsequent `OptimizeProjections` pass removes the (now-redundant) +/// recovery projection: +/// ```text +/// Filter: __datafusion_extracted_1 = 'active' +/// Projection: user['status'] AS __datafusion_extracted_1, id, user <-- extraction (pushed down) +/// TableScan: t [id, user] +/// ``` +#[derive(Default, Debug)] +pub struct PushDownLeafProjections {} + +impl PushDownLeafProjections { + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for PushDownLeafProjections { + fn name(&self) -> &str { + "push_down_leaf_projections" + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::TopDown) + } + + fn rewrite( + &self, + plan: LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result> { + if !config.options().optimizer.enable_leaf_expression_pushdown { + return Ok(Transformed::no(plan)); + } + let alias_generator = config.alias_generator(); + match try_push_input(&plan, alias_generator)? { + Some(new_plan) => Ok(Transformed::yes(new_plan)), + None => Ok(Transformed::no(plan)), + } + } +} + +/// Attempts to push a projection's extractable expressions further down. +/// +/// Returns `Some(new_subtree)` if the projection was pushed down or merged, +/// `None` if there is nothing to push or the projection sits above a barrier. 
+fn try_push_input( + input: &LogicalPlan, + alias_generator: &Arc, +) -> Result> { + let LogicalPlan::Projection(proj) = input else { + return Ok(None); + }; + split_and_push_projection(proj, alias_generator) +} + +/// Splits a projection into extractable pieces, pushes them towards leaf +/// nodes, and adds a recovery projection if needed. +/// +/// Handles both: +/// - **Pure extraction projections** (all `__datafusion_extracted` aliases + columns) +/// - **Mixed projections** (containing `MoveTowardsLeafNodes` sub-expressions) +/// +/// Returns `Some(new_subtree)` if extractions were pushed down, +/// `None` if there is nothing to extract or push. +/// +/// # Example: Mixed Projection +/// +/// ```text +/// Input plan: +/// Projection: user['name'] IS NOT NULL AS has_name, id +/// Filter: ... +/// TableScan +/// +/// Phase 1 (Split): +/// extraction_pairs: [(user['name'], "__datafusion_extracted_1")] +/// recovery_exprs: [__datafusion_extracted_1 IS NOT NULL AS has_name, id] +/// +/// Phase 2 (Push): +/// Push extraction projection through Filter toward TableScan +/// +/// Phase 3 (Recovery): +/// Projection: __datafusion_extracted_1 IS NOT NULL AS has_name, id <-- recovery +/// Filter: ... +/// Projection: user['name'] AS __datafusion_extracted_1, id <-- extraction (pushed) +/// TableScan +/// ``` +fn split_and_push_projection( + proj: &Projection, + alias_generator: &Arc, +) -> Result> { + // Fast pre-check: skip if there are no pre-existing extracted aliases + // and no new extractable expressions. 
+ let has_existing_extracted = proj.expr.iter().any(|e| { + matches!(e, Expr::Alias(alias) if alias.name.starts_with(EXTRACTED_EXPR_PREFIX)) + }); + if !has_existing_extracted && !has_extractable_expr(&proj.expr) { + return Ok(None); + } + + let input = &proj.input; + let input_schema = input.schema(); + + // ── Phase 1: Split ────────────────────────────────────────────────── + // For each projection expression, collect extraction pairs and build + // recovery expressions. + // + // Pre-existing `__datafusion_extracted` aliases are inserted into the + // extractor's `IndexMap` with the **full** `Expr::Alias(…)` as the key, + // so the alias name participates in equality. This prevents collisions + // when CSE rewrites produce the same inner expression under different + // alias names (e.g. `__common_expr_4 AS __datafusion_extracted_1` and + // `__common_expr_4 AS __datafusion_extracted_3`). New extractions from + // `routing_extract` use bare (non-Alias) keys and get normal dedup. + // + // When building the final `extraction_pairs`, the Alias wrapper is + // stripped so consumers see the usual `(inner_expr, alias_name)` tuples. + + let mut extractors = vec![LeafExpressionExtractor::new( + input_schema.as_ref(), + alias_generator, + )]; + let input_column_sets = vec![schema_columns(input_schema.as_ref())]; + + let original_schema = proj.schema.as_ref(); + let mut recovery_exprs: Vec = Vec::with_capacity(proj.expr.len()); + let mut needs_recovery = false; + let mut has_new_extractions = false; + let mut proj_exprs_captured: usize = 0; + // Track standalone column expressions (Case B) to detect column refs + // from extracted aliases (Case A) that aren't also standalone expressions. 
+ let mut standalone_columns: IndexSet = IndexSet::new(); + + for (expr, (qualifier, field)) in proj.expr.iter().zip(original_schema.iter()) { + if let Expr::Alias(alias) = expr + && alias.name.starts_with(EXTRACTED_EXPR_PREFIX) + { + // Insert the full Alias expression as the key so that + // distinct alias names don't collide in the IndexMap. + let alias_name = alias.name.clone(); + + for col_ref in alias.expr.column_refs() { + extractors[0].columns_needed.insert(col_ref.clone()); + } + + extractors[0] + .extracted + .insert(expr.clone(), alias_name.clone()); + recovery_exprs.push(Expr::Column(Column::new_unqualified(&alias_name))); + proj_exprs_captured += 1; + } else if let Expr::Column(col) = expr { + // Plain column pass-through — track it in the extractor + extractors[0].columns_needed.insert(col.clone()); + standalone_columns.insert(col.clone()); + recovery_exprs.push(expr.clone()); + proj_exprs_captured += 1; + } else { + // Everything else: run through routing_extract + let transformed = + routing_extract(expr.clone(), &mut extractors, &input_column_sets)?; + if transformed.transformed { + has_new_extractions = true; + } + let transformed_expr = transformed.data; + + // Build recovery expression, aliasing back to original name if needed + let original_name = field.name(); + let needs_alias = if let Expr::Column(col) = &transformed_expr { + col.name.as_str() != original_name + } else { + let expr_name = transformed_expr.schema_name().to_string(); + original_name != &expr_name + }; + let recovery_expr = if needs_alias { + needs_recovery = true; + transformed_expr + .clone() + .alias_qualified(qualifier.cloned(), original_name) + } else { + transformed_expr.clone() + }; + + // If the expression was transformed (i.e., has extracted sub-parts), + // it differs from what the pushed projection outputs → needs recovery. 
+ // Also, any non-column, non-__datafusion_extracted expression needs recovery + // because the pushed extraction projection won't output it directly. + if transformed.transformed || !matches!(expr, Expr::Column(_)) { + needs_recovery = true; + } + + recovery_exprs.push(recovery_expr); + } + } + + // Build extraction_pairs, stripping the Alias wrapper from pre-existing + // entries (they used the full Alias as the map key to avoid dedup). + let extractor = &extractors[0]; + let extraction_pairs: Vec<(Expr, String)> = extractor + .extracted + .iter() + .map(|(e, a)| match e { + Expr::Alias(alias) => (*alias.expr.clone(), a.clone()), + _ => (e.clone(), a.clone()), + }) + .collect(); + let columns_needed = &extractor.columns_needed; + + // If no extractions found, nothing to do + if extraction_pairs.is_empty() { + return Ok(None); + } + + // If columns_needed has entries that aren't standalone projection columns + // (i.e., they came from column refs inside extracted aliases), a merge + // into an inner projection will widen the schema with those extra columns, + // requiring a recovery projection to restore the original schema. + if columns_needed + .iter() + .any(|c| !standalone_columns.contains(c)) + { + needs_recovery = true; + } + + // ── Phase 2: Push down ────────────────────────────────────────────── + let proj_input = Arc::clone(&proj.input); + let pushed = push_extraction_pairs( + &extraction_pairs, + columns_needed, + proj, + &proj_input, + alias_generator, + proj_exprs_captured, + )?; + + // ── Phase 3: Recovery ─────────────────────────────────────────────── + // Determine the base plan: either the pushed result or an in-place extraction. + let base_plan = match pushed { + Some(plan) => plan, + None => { + if !has_new_extractions { + // Only pre-existing __datafusion_extracted aliases and columns, no new + // extractions from routing_extract. The original projection is + // already an extraction projection that couldn't be pushed + // further. 
Return None. + return Ok(None); + } + // Build extraction projection in-place (couldn't push down) + let input_arc = Arc::clone(input); + let extraction = build_extraction_projection_impl( + &extraction_pairs, + columns_needed, + &input_arc, + input_schema.as_ref(), + )?; + LogicalPlan::Projection(extraction) + } + }; + + // Wrap with recovery projection if the output schema changed + if needs_recovery { + let recovery = LogicalPlan::Projection(Projection::try_new( + recovery_exprs, + Arc::new(base_plan), + )?); + Ok(Some(recovery)) + } else { + Ok(Some(base_plan)) + } +} + +/// Returns true if the plan is a Projection where ALL expressions are either +/// `Alias(EXTRACTED_EXPR_PREFIX, ...)` or `Column`, with at least one extraction. +/// Such projections can safely be pushed further without re-extraction. +fn is_pure_extraction_projection(plan: &LogicalPlan) -> bool { + let LogicalPlan::Projection(proj) = plan else { + return false; + }; + let mut has_extraction = false; + for expr in &proj.expr { + match expr { + Expr::Alias(alias) if alias.name.starts_with(EXTRACTED_EXPR_PREFIX) => { + has_extraction = true; + } + Expr::Column(_) => {} + _ => return false, + } + } + has_extraction +} + +/// Pushes extraction pairs down through the projection's input node, +/// dispatching to the appropriate handler based on the input node type. +fn push_extraction_pairs( + pairs: &[(Expr, String)], + columns_needed: &IndexSet, + proj: &Projection, + proj_input: &Arc, + alias_generator: &Arc, + proj_exprs_captured: usize, +) -> Result> { + match proj_input.as_ref() { + // Merge into existing projection, then try to push the result further down. + // Only merge when every expression in the outer projection is fully + // captured as either an extraction pair (Case A: __datafusion_extracted + // alias) or a plain column (Case B). Uncaptured expressions (e.g. + // `col AS __common_expr_1` from CSE, or complex expressions with + // extracted sub-parts) would be lost during the merge. 
+ LogicalPlan::Projection(_) if proj_exprs_captured == proj.expr.len() => { + let target_schema = Arc::clone(proj_input.schema()); + let merged = build_extraction_projection_impl( + pairs, + columns_needed, + proj_input, + target_schema.as_ref(), + )?; + let merged_plan = LogicalPlan::Projection(merged); + + // After merging, try to push the result further down, but ONLY + // if the merged result is still a pure extraction projection + // (all __datafusion_extracted aliases + columns). If the merge inherited + // bare MoveTowardsLeafNodes expressions from the inner projection, + // pushing would re-extract them into new aliases and fail when + // the (None, true) fallback can't find the original aliases. + // This handles: Extraction → Recovery(cols) → Filter → ... → TableScan + // by pushing through the recovery projection AND the filter in one pass. + if is_pure_extraction_projection(&merged_plan) + && let Some(pushed) = try_push_input(&merged_plan, alias_generator)? + { + return Ok(Some(pushed)); + } + Ok(Some(merged_plan)) + } + // Generic: handles Filter/Sort/Limit (via recursion), + // SubqueryAlias (with qualifier remap in try_push_into_inputs), + // Join, and anything else. + // Safely bails out for nodes that don't pass through extracted + // columns (Aggregate, Window) via the output schema check. + _ => try_push_into_inputs( + pairs, + columns_needed, + proj_input.as_ref(), + alias_generator, + ), + } +} + +/// Routes extraction pairs and columns to the appropriate inputs. +/// +/// - **Union**: broadcasts to every input via [`remap_pairs_and_columns`]. +/// - **Other nodes**: routes each expression to the one input that owns +/// all of its column references (via [`find_owning_input`]). +/// +/// Returns `None` if any expression can't be routed or no input has pairs. 
+fn route_to_inputs( + pairs: &[(Expr, String)], + columns: &IndexSet, + node: &LogicalPlan, + input_column_sets: &[std::collections::HashSet], + input_schemas: &[Arc], +) -> Result>> { + let num_inputs = input_schemas.len(); + let mut per_input: Vec = (0..num_inputs) + .map(|_| ExtractionTarget { + pairs: vec![], + columns: IndexSet::new(), + }) + .collect(); + + if matches!(node, LogicalPlan::Union(_)) { + // Union output schema and each input schema have the same fields by + // index but may differ in qualifiers (e.g. output `s` vs input + // `simple_struct.s`). Remap pairs/columns to each input's space. + let union_schema = node.schema(); + for (idx, input_schema) in input_schemas.iter().enumerate() { + per_input[idx] = + remap_pairs_and_columns(pairs, columns, union_schema, input_schema)?; + } + } else { + for (expr, alias) in pairs { + match find_owning_input(expr, input_column_sets) { + Some(idx) => per_input[idx].pairs.push((expr.clone(), alias.clone())), + None => return Ok(None), // Cross-input expression — bail out + } + } + for col in columns { + let col_expr = Expr::Column(col.clone()); + match find_owning_input(&col_expr, input_column_sets) { + Some(idx) => { + per_input[idx].columns.insert(col.clone()); + } + None => return Ok(None), // Ambiguous column — bail out + } + } + } + + // Check at least one input has extractions to push + if per_input.iter().all(|t| t.pairs.is_empty()) { + return Ok(None); + } + + Ok(Some(per_input)) +} + +/// Pushes extraction expressions into a node's inputs by routing each +/// expression to the input that owns all of its column references. +/// +/// Works for any number of inputs (1, 2, …N). For single-input nodes, +/// all expressions trivially route to that input. For multi-input nodes +/// (Join, etc.), each expression is routed to the side that owns its columns. +/// +/// Returns `Some(new_node)` if all expressions could be routed AND the +/// rebuilt node's output schema contains all extracted aliases. 
+/// Returns `None` if any expression references columns from multiple inputs +/// or the node doesn't pass through the extracted columns. +/// +/// # Example: Join with expressions from both sides +/// +/// ```text +/// Extraction projection above a Join: +/// Projection: left.user['name'] AS __datafusion_extracted_1, right.order['total'] AS __datafusion_extracted_2, ... +/// Join: left.id = right.user_id +/// TableScan: left [id, user] +/// TableScan: right [user_id, order] +/// +/// After routing each expression to its owning input: +/// Join: left.id = right.user_id +/// Projection: user['name'] AS __datafusion_extracted_1, id, user <-- left-side extraction +/// TableScan: left [id, user] +/// Projection: order['total'] AS __datafusion_extracted_2, user_id, order <-- right-side extraction +/// TableScan: right [user_id, order] +/// ``` +fn try_push_into_inputs( + pairs: &[(Expr, String)], + columns_needed: &IndexSet, + node: &LogicalPlan, + alias_generator: &Arc, +) -> Result> { + let inputs = node.inputs(); + if inputs.is_empty() { + return Ok(None); + } + + // SubqueryAlias remaps qualifiers between input and output. + // Rewrite pairs/columns from alias-space to input-space before routing. + let remapped = if let LogicalPlan::SubqueryAlias(sa) = node { + remap_pairs_and_columns(pairs, columns_needed, &sa.schema, sa.input.schema())? + } else { + ExtractionTarget { + pairs: pairs.to_vec(), + columns: columns_needed.clone(), + } + }; + let pairs = &remapped.pairs[..]; + let columns_needed = &remapped.columns; + + // Build per-input schemas and column sets for routing + let input_schemas: Vec> = + inputs.iter().map(|i| Arc::clone(i.schema())).collect(); + let input_column_sets: Vec> = + input_schemas.iter().map(|s| schema_columns(s)).collect(); + + // Route pairs and columns to the appropriate inputs + let per_input = match route_to_inputs( + pairs, + columns_needed, + node, + &input_column_sets, + &input_schemas, + )? 
{ + Some(routed) => routed, + None => return Ok(None), + }; + + let num_inputs = inputs.len(); + + // Build per-input extraction projections and push them as far as possible + // immediately. This is critical because map_children preserves cached schemas, + // so if the TopDown pass later pushes a child further (changing its output + // schema), the parent node's schema becomes stale. + let mut new_inputs: Vec = Vec::with_capacity(num_inputs); + for (idx, input) in inputs.into_iter().enumerate() { + if per_input[idx].pairs.is_empty() { + new_inputs.push(input.clone()); + } else { + let input_arc = Arc::new(input.clone()); + let target_schema = Arc::clone(input.schema()); + let proj = build_extraction_projection_impl( + &per_input[idx].pairs, + &per_input[idx].columns, + &input_arc, + target_schema.as_ref(), + )?; + // Verify all requested aliases appear in the projection's output. + // A merge may deduplicate if the same expression already exists + // under a different alias, leaving the requested alias missing. + let proj_schema = proj.schema.as_ref(); + for (_expr, alias) in &per_input[idx].pairs { + if !proj_schema.fields().iter().any(|f| f.name() == alias) { + return Ok(None); + } + } + let proj_plan = LogicalPlan::Projection(proj); + // Try to push the extraction projection further down within + // this input (e.g., through Filter → existing extraction projection). + // This ensures the input's output schema is stable and won't change + // when the TopDown pass later visits children. + match try_push_input(&proj_plan, alias_generator)? { + Some(pushed) => new_inputs.push(pushed), + None => new_inputs.push(proj_plan), + } + } + } + + // Rebuild the node with new inputs + let new_node = node.with_new_exprs(node.expressions(), new_inputs)?; + + // Safety check: verify all extracted aliases appear in the rebuilt + // node's output schema. Nodes like Aggregate define their own output + // and won't pass through extracted columns — bail out for those. 
+ let output_schema = new_node.schema(); + for (_expr, alias) in pairs { + if !output_schema.fields().iter().any(|f| f.name() == alias) { + return Ok(None); + } + } + + Ok(Some(new_node)) +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use super::*; + use crate::optimize_projections::OptimizeProjections; + use crate::test::udfs::PlacementTestUDF; + use crate::test::*; + use crate::{Optimizer, OptimizerContext}; + use datafusion_common::Result; + use datafusion_expr::expr::ScalarFunction; + use datafusion_expr::{Expr, ExpressionPlacement}; + use datafusion_expr::{ + ScalarUDF, col, lit, logical_plan::builder::LogicalPlanBuilder, + }; + + fn leaf_udf(expr: Expr, name: &str) -> Expr { + Expr::ScalarFunction(ScalarFunction::new_udf( + Arc::new(ScalarUDF::new_from_impl( + PlacementTestUDF::new() + .with_placement(ExpressionPlacement::MoveTowardsLeafNodes), + )), + vec![expr, lit(name)], + )) + } + + // ========================================================================= + // Combined optimization stage formatter + // ========================================================================= + + /// Runs all 4 optimization stages and returns a single formatted string. + /// Stages that produce the same plan as the previous stage show + /// "(same as )" to reduce noise. + /// + /// Stages: + /// 1. **Original** - OptimizeProjections only (baseline) + /// 2. **After Extraction** - + ExtractLeafExpressions + /// 3. **After Pushdown** - + PushDownLeafProjections + /// 4. 
**Optimized** - + final OptimizeProjections + fn format_optimization_stages(plan: &LogicalPlan) -> Result { + let run = |rules: Vec>| -> Result { + let ctx = OptimizerContext::new().with_max_passes(1); + let optimizer = Optimizer::with_rules(rules); + let optimized = optimizer.optimize(plan.clone(), &ctx, |_, _| {})?; + Ok(format!("{optimized}")) + }; + + let original = run(vec![Arc::new(OptimizeProjections::new())])?; + + let after_extract = run(vec![ + Arc::new(OptimizeProjections::new()), + Arc::new(ExtractLeafExpressions::new()), + ])?; + + let after_pushdown = run(vec![ + Arc::new(OptimizeProjections::new()), + Arc::new(ExtractLeafExpressions::new()), + Arc::new(PushDownLeafProjections::new()), + ])?; + + let optimized = run(vec![ + Arc::new(OptimizeProjections::new()), + Arc::new(ExtractLeafExpressions::new()), + Arc::new(PushDownLeafProjections::new()), + Arc::new(OptimizeProjections::new()), + ])?; + + let mut out = format!("## Original Plan\n{original}"); + + out.push_str("\n\n## After Extraction\n"); + if after_extract == original { + out.push_str("(same as original)"); + } else { + out.push_str(&after_extract); + } + + out.push_str("\n\n## After Pushdown\n"); + if after_pushdown == after_extract { + out.push_str("(same as after extraction)"); + } else { + out.push_str(&after_pushdown); + } + + out.push_str("\n\n## Optimized\n"); + if optimized == after_pushdown { + out.push_str("(same as after pushdown)"); + } else { + out.push_str(&optimized); + } + + Ok(out) + } + + /// Assert all optimization stages for a plan in a single insta snapshot. + macro_rules! assert_stages { + ($plan:expr, @ $expected:literal $(,)?) 
=> {{ + let result = format_optimization_stages(&$plan)?; + insta::assert_snapshot!(result, @ $expected); + Ok::<(), datafusion_common::DataFusionError>(()) + }}; + } + + #[test] + fn test_extract_from_filter() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan.clone()) + .filter(leaf_udf(col("user"), "status").eq(lit("active")))? + .select(vec![ + table_scan + .schema() + .index_of_column_by_name(None, "id") + .unwrap(), + ])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: test.id + Filter: leaf_udf(test.user, Utf8("status")) = Utf8("active") + TableScan: test projection=[id, user] + + ## After Extraction + Projection: test.id + Projection: test.id, test.user + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, test.user + TableScan: test projection=[id, user] + + ## After Pushdown + (same as after extraction) + + ## Optimized + Projection: test.id + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id + TableScan: test projection=[id, user] + "#) + } + + #[test] + fn test_no_extraction_for_column() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(col("a").eq(lit(1)))? + .build()?; + + assert_stages!(plan, @" + ## Original Plan + Filter: test.a = Int32(1) + TableScan: test projection=[a, b, c] + + ## After Extraction + (same as original) + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + ") + } + + #[test] + fn test_extract_from_projection() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![leaf_udf(col("user"), "name")])? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(test.user, Utf8("name")) + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("name")) + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## Optimized + Projection: leaf_udf(test.user, Utf8("name")) + TableScan: test projection=[user] + "#) + } + + #[test] + fn test_extract_from_projection_with_subexpression() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![ + leaf_udf(col("user"), "name") + .is_not_null() + .alias("has_name"), + ])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(test.user, Utf8("name")) IS NOT NULL AS has_name + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 IS NOT NULL AS has_name + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## Optimized + Projection: leaf_udf(test.user, Utf8("name")) IS NOT NULL AS has_name + TableScan: test projection=[user] + "#) + } + + #[test] + fn test_projection_no_extraction_for_column() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), col("b")])? 
+ .build()?; + + assert_stages!(plan, @" + ## Original Plan + TableScan: test projection=[a, b] + + ## After Extraction + (same as original) + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + ") + } + + #[test] + fn test_filter_with_deduplication() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let field_access = leaf_udf(col("user"), "name"); + // Filter with the same expression used twice + let plan = LogicalPlanBuilder::from(table_scan) + .filter( + field_access + .clone() + .is_not_null() + .and(field_access.is_null()), + )? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Filter: leaf_udf(test.user, Utf8("name")) IS NOT NULL AND leaf_udf(test.user, Utf8("name")) IS NULL + TableScan: test projection=[id, user] + + ## After Extraction + Projection: test.id, test.user + Filter: __datafusion_extracted_1 IS NOT NULL AND __datafusion_extracted_1 IS NULL + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.id, test.user + TableScan: test projection=[id, user] + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + "#) + } + + #[test] + fn test_already_leaf_expression_in_filter() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(leaf_udf(col("user"), "name").eq(lit("test")))? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Filter: leaf_udf(test.user, Utf8("name")) = Utf8("test") + TableScan: test projection=[id, user] + + ## After Extraction + Projection: test.id, test.user + Filter: __datafusion_extracted_1 = Utf8("test") + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.id, test.user + TableScan: test projection=[id, user] + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + "#) + } + + #[test] + fn test_extract_from_aggregate_group_by() -> Result<()> { + use datafusion_expr::test::function_stub::count; + + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![leaf_udf(col("user"), "status")], vec![count(lit(1))])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Aggregate: groupBy=[[leaf_udf(test.user, Utf8("status"))]], aggr=[[COUNT(Int32(1))]] + TableScan: test projection=[user] + + ## After Extraction + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("status")), COUNT(Int32(1)) + Aggregate: groupBy=[[__datafusion_extracted_1]], aggr=[[COUNT(Int32(1))]] + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## After Pushdown + (same as after extraction) + + ## Optimized + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("status")), COUNT(Int32(1)) + Aggregate: groupBy=[[__datafusion_extracted_1]], aggr=[[COUNT(Int32(1))]] + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1 + TableScan: test projection=[user] + "#) + } + + #[test] + fn test_extract_from_aggregate_args() -> Result<()> { + use datafusion_expr::test::function_stub::count; + + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("user")], + vec![count(leaf_udf(col("user"), "value"))], + )? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Aggregate: groupBy=[[test.user]], aggr=[[COUNT(leaf_udf(test.user, Utf8("value")))]] + TableScan: test projection=[user] + + ## After Extraction + Projection: test.user, COUNT(__datafusion_extracted_1) AS COUNT(leaf_udf(test.user,Utf8("value"))) + Aggregate: groupBy=[[test.user]], aggr=[[COUNT(__datafusion_extracted_1)]] + Projection: leaf_udf(test.user, Utf8("value")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + "#) + } + + #[test] + fn test_projection_with_filter_combined() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(leaf_udf(col("user"), "status").eq(lit("active")))? + .project(vec![leaf_udf(col("user"), "name")])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(test.user, Utf8("name")) + Filter: leaf_udf(test.user, Utf8("status")) = Utf8("active") + TableScan: test projection=[user] + + ## After Extraction + Projection: leaf_udf(test.user, Utf8("name")) + Projection: test.user + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## After Pushdown + Projection: __datafusion_extracted_2 AS leaf_udf(test.user,Utf8("name")) + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.user, leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_2 + TableScan: test projection=[user] + + ## Optimized + Projection: __datafusion_extracted_2 AS leaf_udf(test.user,Utf8("name")) + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, leaf_udf(test.user, Utf8("name")) AS 
__datafusion_extracted_2 + TableScan: test projection=[user] + "#) + } + + #[test] + fn test_projection_preserves_alias() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![leaf_udf(col("user"), "name").alias("username")])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(test.user, Utf8("name")) AS username + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 AS username + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## Optimized + Projection: leaf_udf(test.user, Utf8("name")) AS username + TableScan: test projection=[user] + "#) + } + + /// Test: Projection with different field than Filter + /// SELECT id, s['label'] FROM t WHERE s['value'] > 150 + /// Both s['label'] and s['value'] should be in a single extraction projection. + #[test] + fn test_projection_different_field_from_filter() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(leaf_udf(col("user"), "value").gt(lit(150)))? + .project(vec![col("user"), leaf_udf(col("user"), "label")])? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: test.user, leaf_udf(test.user, Utf8("label")) + Filter: leaf_udf(test.user, Utf8("value")) > Int32(150) + TableScan: test projection=[user] + + ## After Extraction + Projection: test.user, leaf_udf(test.user, Utf8("label")) + Projection: test.user + Filter: __datafusion_extracted_1 > Int32(150) + Projection: leaf_udf(test.user, Utf8("value")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## After Pushdown + Projection: test.user, __datafusion_extracted_2 AS leaf_udf(test.user,Utf8("label")) + Filter: __datafusion_extracted_1 > Int32(150) + Projection: leaf_udf(test.user, Utf8("value")) AS __datafusion_extracted_1, test.user, leaf_udf(test.user, Utf8("label")) AS __datafusion_extracted_2 + TableScan: test projection=[user] + + ## Optimized + (same as after pushdown) + "#) + } + + #[test] + fn test_projection_deduplication() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let field = leaf_udf(col("user"), "name"); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![field.clone(), field.clone().alias("name2")])? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(test.user, Utf8("name")), leaf_udf(test.user, Utf8("name")) AS name2 + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("name")), __datafusion_extracted_1 AS name2 + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## Optimized + Projection: leaf_udf(test.user, Utf8("name")), leaf_udf(test.user, Utf8("name")) AS name2 + TableScan: test projection=[user] + "#) + } + + // ========================================================================= + // Additional tests for code coverage + // ========================================================================= + + /// Extractions push through Sort nodes to reach the TableScan. + #[test] + fn test_extract_through_sort() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .sort(vec![col("user").sort(true, true)])? + .project(vec![leaf_udf(col("user"), "name")])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(test.user, Utf8("name")) + Sort: test.user ASC NULLS FIRST + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("name")) + Sort: test.user ASC NULLS FIRST + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## Optimized + (same as after pushdown) + "#) + } + + /// Extractions push through Limit nodes to reach the TableScan. + #[test] + fn test_extract_through_limit() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .limit(0, Some(10))? 
+ .project(vec![leaf_udf(col("user"), "name")])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(test.user, Utf8("name")) + Limit: skip=0, fetch=10 + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("name")) + Limit: skip=0, fetch=10 + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## Optimized + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("name")) + Limit: skip=0, fetch=10 + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1 + TableScan: test projection=[user] + "#) + } + + /// Aliased aggregate functions like count(...).alias("cnt") are handled. + #[test] + fn test_extract_from_aliased_aggregate() -> Result<()> { + use datafusion_expr::test::function_stub::count; + + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("user")], + vec![count(leaf_udf(col("user"), "value")).alias("cnt")], + )? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Aggregate: groupBy=[[test.user]], aggr=[[COUNT(leaf_udf(test.user, Utf8("value"))) AS cnt]] + TableScan: test projection=[user] + + ## After Extraction + Aggregate: groupBy=[[test.user]], aggr=[[COUNT(__datafusion_extracted_1) AS cnt]] + Projection: leaf_udf(test.user, Utf8("value")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + "#) + } + + /// Aggregates with no MoveTowardsLeafNodes expressions return unchanged. 
+ #[test] + fn test_aggregate_no_extraction() -> Result<()> { + use datafusion_expr::test::function_stub::count; + + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("a")], vec![count(col("b"))])? + .build()?; + + assert_stages!(plan, @" + ## Original Plan + Aggregate: groupBy=[[test.a]], aggr=[[COUNT(test.b)]] + TableScan: test projection=[a, b] + + ## After Extraction + (same as original) + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + ") + } + + /// Projections containing extracted expression aliases are skipped (already extracted). + #[test] + fn test_skip_extracted_projection() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![ + leaf_udf(col("user"), "name").alias("__datafusion_extracted_manual"), + col("user"), + ])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_manual, test.user + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + "#) + } + + /// Multiple extractions merge into a single extracted expression projection. + #[test] + fn test_merge_into_existing_extracted_projection() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(leaf_udf(col("user"), "status").eq(lit("active")))? + .filter(leaf_udf(col("user"), "name").is_not_null())? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Filter: leaf_udf(test.user, Utf8("name")) IS NOT NULL + Filter: leaf_udf(test.user, Utf8("status")) = Utf8("active") + TableScan: test projection=[id, user] + + ## After Extraction + Projection: test.id, test.user + Filter: __datafusion_extracted_1 IS NOT NULL + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.id, test.user + Projection: test.id, test.user + Filter: __datafusion_extracted_2 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_2, test.id, test.user + TableScan: test projection=[id, user] + + ## After Pushdown + Projection: test.id, test.user + Filter: __datafusion_extracted_1 IS NOT NULL + Filter: __datafusion_extracted_2 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_2, test.id, test.user, leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1 + TableScan: test projection=[id, user] + + ## Optimized + Projection: test.id, test.user + Filter: __datafusion_extracted_1 IS NOT NULL + Projection: test.id, test.user, __datafusion_extracted_1 + Filter: __datafusion_extracted_2 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_2, test.id, test.user, leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1 + TableScan: test projection=[id, user] + "#) + } + + /// Extractions push through passthrough projections (columns only). + #[test] + fn test_extract_through_passthrough_projection() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("user")])? + .project(vec![leaf_udf(col("user"), "name")])? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(test.user, Utf8("name")) + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("name")) + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## Optimized + Projection: leaf_udf(test.user, Utf8("name")) + TableScan: test projection=[user] + "#) + } + + /// Projections with aliased columns (nothing to extract) return unchanged. + #[test] + fn test_projection_early_return_no_extraction() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").alias("x"), col("b")])? + .build()?; + + assert_stages!(plan, @" + ## Original Plan + Projection: test.a AS x, test.b + TableScan: test projection=[a, b] + + ## After Extraction + (same as original) + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + ") + } + + /// Projections with arithmetic expressions but no MoveTowardsLeafNodes return unchanged. + #[test] + fn test_projection_with_arithmetic_no_extraction() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![(col("a") + col("b")).alias("sum")])? + .build()?; + + assert_stages!(plan, @" + ## Original Plan + Projection: test.a + test.b AS sum + TableScan: test projection=[a, b] + + ## After Extraction + (same as original) + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + ") + } + + /// Aggregate extractions merge into existing extracted projection created by Filter. 
+ #[test] + fn test_aggregate_merge_into_extracted_projection() -> Result<()> { + use datafusion_expr::test::function_stub::count; + + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(leaf_udf(col("user"), "status").eq(lit("active")))? + .aggregate(vec![leaf_udf(col("user"), "name")], vec![count(lit(1))])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Aggregate: groupBy=[[leaf_udf(test.user, Utf8("name"))]], aggr=[[COUNT(Int32(1))]] + Filter: leaf_udf(test.user, Utf8("status")) = Utf8("active") + TableScan: test projection=[user] + + ## After Extraction + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("name")), COUNT(Int32(1)) + Aggregate: groupBy=[[__datafusion_extracted_1]], aggr=[[COUNT(Int32(1))]] + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.user + Projection: test.user + Filter: __datafusion_extracted_2 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_2, test.user + TableScan: test projection=[user] + + ## After Pushdown + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("name")), COUNT(Int32(1)) + Aggregate: groupBy=[[__datafusion_extracted_1]], aggr=[[COUNT(Int32(1))]] + Filter: __datafusion_extracted_2 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_2, test.user, leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1 + TableScan: test projection=[user] + + ## Optimized + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("name")), COUNT(Int32(1)) + Aggregate: groupBy=[[__datafusion_extracted_1]], aggr=[[COUNT(Int32(1))]] + Projection: __datafusion_extracted_1 + Filter: __datafusion_extracted_2 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_2, leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1 + TableScan: test projection=[user] + "#) + } + + /// 
Projection containing a MoveTowardsLeafNodes sub-expression above an + /// Aggregate. Aggregate blocks pushdown, so the (None, true) recovery + /// fallback path fires: in-place extraction + recovery projection. + #[test] + fn test_projection_with_leaf_expr_above_aggregate() -> Result<()> { + use datafusion_expr::test::function_stub::count; + + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("user")], vec![count(lit(1))])? + .project(vec![ + leaf_udf(col("user"), "name") + .is_not_null() + .alias("has_name"), + col("COUNT(Int32(1))"), + ])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(test.user, Utf8("name")) IS NOT NULL AS has_name, COUNT(Int32(1)) + Aggregate: groupBy=[[test.user]], aggr=[[COUNT(Int32(1))]] + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 IS NOT NULL AS has_name, COUNT(Int32(1)) + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.user, COUNT(Int32(1)) + Aggregate: groupBy=[[test.user]], aggr=[[COUNT(Int32(1))]] + TableScan: test projection=[user] + + ## Optimized + Projection: leaf_udf(test.user, Utf8("name")) IS NOT NULL AS has_name, COUNT(Int32(1)) + Aggregate: groupBy=[[test.user]], aggr=[[COUNT(Int32(1))]] + TableScan: test projection=[user] + "#) + } + + /// Merging adds new pass-through columns not in the existing extracted projection. + #[test] + fn test_merge_with_new_columns() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(leaf_udf(col("a"), "x").eq(lit(1)))? + .filter(leaf_udf(col("b"), "y").eq(lit(2)))? 
+ .build()?;
+
+ assert_stages!(plan, @r#"
+ ## Original Plan
+ Filter: leaf_udf(test.b, Utf8("y")) = Int32(2)
+ Filter: leaf_udf(test.a, Utf8("x")) = Int32(1)
+ TableScan: test projection=[a, b, c]
+
+ ## After Extraction
+ Projection: test.a, test.b, test.c
+ Filter: __datafusion_extracted_1 = Int32(2)
+ Projection: leaf_udf(test.b, Utf8("y")) AS __datafusion_extracted_1, test.a, test.b, test.c
+ Projection: test.a, test.b, test.c
+ Filter: __datafusion_extracted_2 = Int32(1)
+ Projection: leaf_udf(test.a, Utf8("x")) AS __datafusion_extracted_2, test.a, test.b, test.c
+ TableScan: test projection=[a, b, c]
+
+ ## After Pushdown
+ Projection: test.a, test.b, test.c
+ Filter: __datafusion_extracted_1 = Int32(2)
+ Filter: __datafusion_extracted_2 = Int32(1)
+ Projection: leaf_udf(test.a, Utf8("x")) AS __datafusion_extracted_2, test.a, test.b, test.c, leaf_udf(test.b, Utf8("y")) AS __datafusion_extracted_1
+ TableScan: test projection=[a, b, c]
+
+ ## Optimized
+ Projection: test.a, test.b, test.c
+ Filter: __datafusion_extracted_1 = Int32(2)
+ Projection: test.a, test.b, test.c, __datafusion_extracted_1
+ Filter: __datafusion_extracted_2 = Int32(1)
+ Projection: leaf_udf(test.a, Utf8("x")) AS __datafusion_extracted_2, test.a, test.b, test.c, leaf_udf(test.b, Utf8("y")) AS __datafusion_extracted_1
+ TableScan: test projection=[a, b, c]
+ "#)
+ }
+
+ // =========================================================================
+ // Join extraction tests
+ // =========================================================================
+
+ /// Create a second table scan with struct field for join tests
+ fn test_table_scan_with_struct_named(name: &str) -> Result<LogicalPlan> {
+ use arrow::datatypes::Schema;
+ let schema = Schema::new(test_table_scan_with_struct_fields());
+ datafusion_expr::logical_plan::table_scan(Some(name), &schema, None)?.build()
+ }
+
+ /// Extraction from equijoin keys (`on` expressions).
+ #[test] + fn test_extract_from_join_on() -> Result<()> { + use datafusion_expr::JoinType; + + let left = test_table_scan_with_struct()?; + let right = test_table_scan_with_struct_named("right")?; + + let plan = LogicalPlanBuilder::from(left) + .join_with_expr_keys( + right, + JoinType::Inner, + ( + vec![leaf_udf(col("user"), "id")], + vec![leaf_udf(col("user"), "id")], + ), + None, + )? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Inner Join: leaf_udf(test.user, Utf8("id")) = leaf_udf(right.user, Utf8("id")) + TableScan: test projection=[id, user] + TableScan: right projection=[id, user] + + ## After Extraction + Projection: test.id, test.user, right.id, right.user + Inner Join: __datafusion_extracted_1 = __datafusion_extracted_2 + Projection: leaf_udf(test.user, Utf8("id")) AS __datafusion_extracted_1, test.id, test.user + TableScan: test projection=[id, user] + Projection: leaf_udf(right.user, Utf8("id")) AS __datafusion_extracted_2, right.id, right.user + TableScan: right projection=[id, user] + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + "#) + } + + /// Extraction from non-equi join filter. + #[test] + fn test_extract_from_join_filter() -> Result<()> { + use datafusion_expr::JoinType; + + let left = test_table_scan_with_struct()?; + let right = test_table_scan_with_struct_named("right")?; + + let plan = LogicalPlanBuilder::from(left) + .join_on( + right, + JoinType::Inner, + vec![ + col("test.user").eq(col("right.user")), + leaf_udf(col("test.user"), "status").eq(lit("active")), + ], + )? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Inner Join: Filter: test.user = right.user AND leaf_udf(test.user, Utf8("status")) = Utf8("active") + TableScan: test projection=[id, user] + TableScan: right projection=[id, user] + + ## After Extraction + Projection: test.id, test.user, right.id, right.user + Inner Join: Filter: test.user = right.user AND __datafusion_extracted_1 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, test.user + TableScan: test projection=[id, user] + TableScan: right projection=[id, user] + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + "#) + } + + /// Extraction from both left and right sides of a join. + #[test] + fn test_extract_from_join_both_sides() -> Result<()> { + use datafusion_expr::JoinType; + + let left = test_table_scan_with_struct()?; + let right = test_table_scan_with_struct_named("right")?; + + let plan = LogicalPlanBuilder::from(left) + .join_on( + right, + JoinType::Inner, + vec![ + col("test.user").eq(col("right.user")), + leaf_udf(col("test.user"), "status").eq(lit("active")), + leaf_udf(col("right.user"), "role").eq(lit("admin")), + ], + )? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Inner Join: Filter: test.user = right.user AND leaf_udf(test.user, Utf8("status")) = Utf8("active") AND leaf_udf(right.user, Utf8("role")) = Utf8("admin") + TableScan: test projection=[id, user] + TableScan: right projection=[id, user] + + ## After Extraction + Projection: test.id, test.user, right.id, right.user + Inner Join: Filter: test.user = right.user AND __datafusion_extracted_1 = Utf8("active") AND __datafusion_extracted_2 = Utf8("admin") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, test.user + TableScan: test projection=[id, user] + Projection: leaf_udf(right.user, Utf8("role")) AS __datafusion_extracted_2, right.id, right.user + TableScan: right projection=[id, user] + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + "#) + } + + /// Join with no MoveTowardsLeafNodes expressions returns unchanged. + #[test] + fn test_extract_from_join_no_extraction() -> Result<()> { + use datafusion_expr::JoinType; + + let left = test_table_scan()?; + let right = test_table_scan_with_name("right")?; + + let plan = LogicalPlanBuilder::from(left) + .join(right, JoinType::Inner, (vec!["a"], vec!["a"]), None)? + .build()?; + + assert_stages!(plan, @" + ## Original Plan + Inner Join: test.a = right.a + TableScan: test projection=[a, b, c] + TableScan: right projection=[a, b, c] + + ## After Extraction + (same as original) + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + ") + } + + /// Join followed by filter with extraction. 
+ #[test] + fn test_extract_from_filter_above_join() -> Result<()> { + use datafusion_expr::JoinType; + + let left = test_table_scan_with_struct()?; + let right = test_table_scan_with_struct_named("right")?; + + let plan = LogicalPlanBuilder::from(left) + .join_with_expr_keys( + right, + JoinType::Inner, + ( + vec![leaf_udf(col("user"), "id")], + vec![leaf_udf(col("user"), "id")], + ), + None, + )? + .filter(leaf_udf(col("test.user"), "status").eq(lit("active")))? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Filter: leaf_udf(test.user, Utf8("status")) = Utf8("active") + Inner Join: leaf_udf(test.user, Utf8("id")) = leaf_udf(right.user, Utf8("id")) + TableScan: test projection=[id, user] + TableScan: right projection=[id, user] + + ## After Extraction + Projection: test.id, test.user, right.id, right.user + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, test.user, right.id, right.user + Projection: test.id, test.user, right.id, right.user + Inner Join: __datafusion_extracted_2 = __datafusion_extracted_3 + Projection: leaf_udf(test.user, Utf8("id")) AS __datafusion_extracted_2, test.id, test.user + TableScan: test projection=[id, user] + Projection: leaf_udf(right.user, Utf8("id")) AS __datafusion_extracted_3, right.id, right.user + TableScan: right projection=[id, user] + + ## After Pushdown + Projection: test.id, test.user, right.id, right.user + Filter: __datafusion_extracted_1 = Utf8("active") + Inner Join: __datafusion_extracted_2 = __datafusion_extracted_3 + Projection: leaf_udf(test.user, Utf8("id")) AS __datafusion_extracted_2, test.id, test.user, leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1 + TableScan: test projection=[id, user] + Projection: leaf_udf(right.user, Utf8("id")) AS __datafusion_extracted_3, right.id, right.user + TableScan: right projection=[id, user] + + ## Optimized + Projection: test.id, test.user, right.id, 
right.user + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: test.id, test.user, __datafusion_extracted_1, right.id, right.user + Inner Join: __datafusion_extracted_2 = __datafusion_extracted_3 + Projection: leaf_udf(test.user, Utf8("id")) AS __datafusion_extracted_2, test.id, test.user, leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1 + TableScan: test projection=[id, user] + Projection: leaf_udf(right.user, Utf8("id")) AS __datafusion_extracted_3, right.id, right.user + TableScan: right projection=[id, user] + "#) + } + + /// Extraction projection (get_field in SELECT) above a Join pushes into + /// the correct input side. + #[test] + fn test_extract_projection_above_join() -> Result<()> { + use datafusion_expr::JoinType; + + let left = test_table_scan_with_struct()?; + let right = test_table_scan_with_struct_named("right")?; + + let plan = LogicalPlanBuilder::from(left) + .join(right, JoinType::Inner, (vec!["id"], vec!["id"]), None)? + .project(vec![ + leaf_udf(col("test.user"), "status"), + leaf_udf(col("right.user"), "role"), + ])? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(test.user, Utf8("status")), leaf_udf(right.user, Utf8("role")) + Inner Join: test.id = right.id + TableScan: test projection=[id, user] + TableScan: right projection=[id, user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("status")), __datafusion_extracted_2 AS leaf_udf(right.user,Utf8("role")) + Inner Join: test.id = right.id + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, test.user + TableScan: test projection=[id, user] + Projection: leaf_udf(right.user, Utf8("role")) AS __datafusion_extracted_2, right.id, right.user + TableScan: right projection=[id, user] + + ## Optimized + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("status")), __datafusion_extracted_2 AS leaf_udf(right.user,Utf8("role")) + Inner Join: test.id = right.id + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id + TableScan: test projection=[id, user] + Projection: leaf_udf(right.user, Utf8("role")) AS __datafusion_extracted_2, right.id + TableScan: right projection=[id, user] + "#) + } + + /// Join where both sides have same-named columns: a qualified reference + /// to the right side must be routed to the right input, not the left. + #[test] + fn test_extract_from_join_qualified_right_side() -> Result<()> { + use datafusion_expr::JoinType; + + let left = test_table_scan_with_struct()?; + let right = test_table_scan_with_struct_named("right")?; + + // Filter references right.user explicitly — must route to right side + let plan = LogicalPlanBuilder::from(left) + .join_on( + right, + JoinType::Inner, + vec![ + col("test.id").eq(col("right.id")), + leaf_udf(col("right.user"), "status").eq(lit("active")), + ], + )? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Inner Join: Filter: test.id = right.id AND leaf_udf(right.user, Utf8("status")) = Utf8("active") + TableScan: test projection=[id, user] + TableScan: right projection=[id, user] + + ## After Extraction + Projection: test.id, test.user, right.id, right.user + Inner Join: Filter: test.id = right.id AND __datafusion_extracted_1 = Utf8("active") + TableScan: test projection=[id, user] + Projection: leaf_udf(right.user, Utf8("status")) AS __datafusion_extracted_1, right.id, right.user + TableScan: right projection=[id, user] + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + "#) + } + + /// When both inputs contain the same unqualified column, an unqualified + /// column reference is ambiguous and `find_owning_input` must return + /// `None` rather than always returning 0 (the left side). + #[test] + fn test_find_owning_input_ambiguous_unqualified_column() { + use std::collections::HashSet; + + // Simulate schema_columns output for two sides of a join where both + // have a "user" column — each set contains the qualified and + // unqualified form. 
+ let left_cols: HashSet<Column> = [
+ Column::new(Some("test"), "user"),
+ Column::new_unqualified("user"),
+ ]
+ .into_iter()
+ .collect();
+
+ let right_cols: HashSet<Column> = [
+ Column::new(Some("right"), "user"),
+ Column::new_unqualified("user"),
+ ]
+ .into_iter()
+ .collect();
+
+ let input_column_sets = vec![left_cols, right_cols];
+
+ // Unqualified "user" matches both sets — must return None (ambiguous)
+ let unqualified = Expr::Column(Column::new_unqualified("user"));
+ assert_eq!(find_owning_input(&unqualified, &input_column_sets), None);
+
+ // Qualified "right.user" matches only the right set — must return Some(1)
+ let qualified_right = Expr::Column(Column::new(Some("right"), "user"));
+ assert_eq!(
+ find_owning_input(&qualified_right, &input_column_sets),
+ Some(1)
+ );
+
+ // Qualified "test.user" matches only the left set — must return Some(0)
+ let qualified_left = Expr::Column(Column::new(Some("test"), "user"));
+ assert_eq!(
+ find_owning_input(&qualified_left, &input_column_sets),
+ Some(0)
+ );
+ }
+
+ /// Two leaf_udf expressions from different sides of a Join in a Filter.
+ /// Each is routed to its respective input side independently.
+ #[test]
+ fn test_extract_from_join_cross_input_expression() -> Result<()> {
+ let left = test_table_scan_with_struct()?;
+ let right = test_table_scan_with_struct_named("right")?;
+
+ let plan = LogicalPlanBuilder::from(left)
+ .join_on(
+ right,
+ datafusion_expr::JoinType::Inner,
+ vec![col("test.id").eq(col("right.id"))],
+ )?
+ .filter(
+ leaf_udf(col("test.user"), "status")
+ .eq(leaf_udf(col("right.user"), "status")),
+ )?
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Filter: leaf_udf(test.user, Utf8("status")) = leaf_udf(right.user, Utf8("status")) + Inner Join: Filter: test.id = right.id + TableScan: test projection=[id, user] + TableScan: right projection=[id, user] + + ## After Extraction + Projection: test.id, test.user, right.id, right.user + Filter: __datafusion_extracted_1 = __datafusion_extracted_2 + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, leaf_udf(right.user, Utf8("status")) AS __datafusion_extracted_2, test.id, test.user, right.id, right.user + Inner Join: Filter: test.id = right.id + TableScan: test projection=[id, user] + TableScan: right projection=[id, user] + + ## After Pushdown + Projection: test.id, test.user, right.id, right.user + Filter: __datafusion_extracted_1 = __datafusion_extracted_2 + Inner Join: Filter: test.id = right.id + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, test.user + TableScan: test projection=[id, user] + Projection: leaf_udf(right.user, Utf8("status")) AS __datafusion_extracted_2, right.id, right.user + TableScan: right projection=[id, user] + + ## Optimized + (same as after pushdown) + "#) + } + + // ========================================================================= + // Column-rename through intermediate node tests + // ========================================================================= + + /// Projection with leaf expr above Filter above renaming Projection. + #[test] + fn test_extract_through_filter_with_column_rename() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("user").alias("x")])? + .filter(col("x").is_not_null())? + .project(vec![leaf_udf(col("x"), "a")])? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(x, Utf8("a")) + Filter: x IS NOT NULL + Projection: test.user AS x + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 AS leaf_udf(x,Utf8("a")) + Filter: x IS NOT NULL + Projection: test.user AS x, leaf_udf(test.user, Utf8("a")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## Optimized + Projection: __datafusion_extracted_1 AS leaf_udf(x,Utf8("a")) + Filter: x IS NOT NULL + Projection: test.user AS x, leaf_udf(test.user, Utf8("a")) AS __datafusion_extracted_1 + TableScan: test projection=[user] + "#) + } + + /// Same as above but with a partial extraction (leaf + arithmetic). + #[test] + fn test_extract_partial_through_filter_with_column_rename() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("user").alias("x")])? + .filter(col("x").is_not_null())? + .project(vec![leaf_udf(col("x"), "a").is_not_null()])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(x, Utf8("a")) IS NOT NULL + Filter: x IS NOT NULL + Projection: test.user AS x + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 IS NOT NULL AS leaf_udf(x,Utf8("a")) IS NOT NULL + Filter: x IS NOT NULL + Projection: test.user AS x, leaf_udf(test.user, Utf8("a")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## Optimized + Projection: __datafusion_extracted_1 IS NOT NULL AS leaf_udf(x,Utf8("a")) IS NOT NULL + Filter: x IS NOT NULL + Projection: test.user AS x, leaf_udf(test.user, Utf8("a")) AS __datafusion_extracted_1 + TableScan: test projection=[user] + "#) + } + + /// Tests merge_into_extracted_projection path through a renaming projection. 
+ #[test] + fn test_extract_from_filter_above_renaming_projection() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("user").alias("x")])? + .filter(leaf_udf(col("x"), "a").eq(lit("active")))? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Filter: leaf_udf(x, Utf8("a")) = Utf8("active") + Projection: test.user AS x + TableScan: test projection=[user] + + ## After Extraction + Projection: x + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: test.user AS x, leaf_udf(test.user, Utf8("a")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## After Pushdown + (same as after extraction) + + ## Optimized + Projection: x + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: test.user AS x, leaf_udf(test.user, Utf8("a")) AS __datafusion_extracted_1 + TableScan: test projection=[user] + "#) + } + + // ========================================================================= + // SubqueryAlias extraction tests + // ========================================================================= + + /// Extraction projection pushes through SubqueryAlias. + #[test] + fn test_extract_through_subquery_alias() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .alias("sub")? + .project(vec![leaf_udf(col("sub.user"), "name")])? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(sub.user, Utf8("name")) + SubqueryAlias: sub + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 AS leaf_udf(sub.user,Utf8("name")) + SubqueryAlias: sub + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## Optimized + Projection: __datafusion_extracted_1 AS leaf_udf(sub.user,Utf8("name")) + SubqueryAlias: sub + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1 + TableScan: test projection=[user] + "#) + } + + /// Extraction projection pushes through SubqueryAlias + Filter. + #[test] + fn test_extract_through_subquery_alias_with_filter() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .alias("sub")? + .filter(leaf_udf(col("sub.user"), "status").eq(lit("active")))? + .project(vec![leaf_udf(col("sub.user"), "name")])? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(sub.user, Utf8("name")) + Filter: leaf_udf(sub.user, Utf8("status")) = Utf8("active") + SubqueryAlias: sub + TableScan: test projection=[user] + + ## After Extraction + Projection: leaf_udf(sub.user, Utf8("name")) + Projection: sub.user + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: leaf_udf(sub.user, Utf8("status")) AS __datafusion_extracted_1, sub.user + SubqueryAlias: sub + TableScan: test projection=[user] + + ## After Pushdown + Projection: __datafusion_extracted_2 AS leaf_udf(sub.user,Utf8("name")) + Filter: __datafusion_extracted_1 = Utf8("active") + SubqueryAlias: sub + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_2, test.user + TableScan: test projection=[user] + + ## Optimized + Projection: __datafusion_extracted_2 AS leaf_udf(sub.user,Utf8("name")) + Filter: __datafusion_extracted_1 = Utf8("active") + SubqueryAlias: sub + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_2 + TableScan: test projection=[user] + "#) + } + + /// Two layers of SubqueryAlias: extraction pushes through both. + #[test] + fn test_extract_through_nested_subquery_alias() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .alias("inner_sub")? + .alias("outer_sub")? + .project(vec![leaf_udf(col("outer_sub.user"), "name")])? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(outer_sub.user, Utf8("name")) + SubqueryAlias: outer_sub + SubqueryAlias: inner_sub + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 AS leaf_udf(outer_sub.user,Utf8("name")) + SubqueryAlias: outer_sub + SubqueryAlias: inner_sub + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## Optimized + Projection: __datafusion_extracted_1 AS leaf_udf(outer_sub.user,Utf8("name")) + SubqueryAlias: outer_sub + SubqueryAlias: inner_sub + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1 + TableScan: test projection=[user] + "#) + } + + /// Plain columns through SubqueryAlias -- no extraction needed. + #[test] + fn test_subquery_alias_no_extraction() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .alias("sub")? + .project(vec![col("sub.a"), col("sub.b")])? + .build()?; + + assert_stages!(plan, @" + ## Original Plan + SubqueryAlias: sub + TableScan: test projection=[a, b] + + ## After Extraction + (same as original) + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + ") + } + + /// Two UDFs with the same `name()` but different concrete types should NOT be + /// deduplicated -- they are semantically different expressions that happen to + /// collide on `schema_name()`. 
+ #[test] + fn test_different_udfs_same_schema_name_not_deduplicated() -> Result<()> { + let udf_a = Arc::new(ScalarUDF::new_from_impl( + PlacementTestUDF::new() + .with_placement(ExpressionPlacement::MoveTowardsLeafNodes) + .with_id(1), + )); + let udf_b = Arc::new(ScalarUDF::new_from_impl( + PlacementTestUDF::new() + .with_placement(ExpressionPlacement::MoveTowardsLeafNodes) + .with_id(2), + )); + + let expr_a = Expr::ScalarFunction(ScalarFunction::new_udf( + udf_a, + vec![col("user"), lit("field")], + )); + let expr_b = Expr::ScalarFunction(ScalarFunction::new_udf( + udf_b, + vec![col("user"), lit("field")], + )); + + // Verify preconditions: same schema_name but different Expr + assert_eq!( + expr_a.schema_name().to_string(), + expr_b.schema_name().to_string(), + "Both expressions should have the same schema_name" + ); + assert_ne!( + expr_a, expr_b, + "Expressions should NOT be equal (different UDF instances)" + ); + + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan.clone()) + .filter(expr_a.clone().eq(lit("a")).and(expr_b.clone().eq(lit("b"))))? + .select(vec![ + table_scan + .schema() + .index_of_column_by_name(None, "id") + .unwrap(), + ])? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: test.id + Filter: leaf_udf(test.user, Utf8("field")) = Utf8("a") AND leaf_udf(test.user, Utf8("field")) = Utf8("b") + TableScan: test projection=[id, user] + + ## After Extraction + Projection: test.id + Projection: test.id, test.user + Filter: __datafusion_extracted_1 = Utf8("a") AND __datafusion_extracted_2 = Utf8("b") + Projection: leaf_udf(test.user, Utf8("field")) AS __datafusion_extracted_1, leaf_udf(test.user, Utf8("field")) AS __datafusion_extracted_2, test.id, test.user + TableScan: test projection=[id, user] + + ## After Pushdown + (same as after extraction) + + ## Optimized + Projection: test.id + Filter: __datafusion_extracted_1 = Utf8("a") AND __datafusion_extracted_2 = Utf8("b") + Projection: leaf_udf(test.user, Utf8("field")) AS __datafusion_extracted_1, leaf_udf(test.user, Utf8("field")) AS __datafusion_extracted_2, test.id + TableScan: test projection=[id, user] + "#) + } + + // ========================================================================= + // Filter pushdown interaction tests + // ========================================================================= + + /// Extraction pushdown through a filter that already had its own + /// `leaf_udf` extracted. + #[test] + fn test_extraction_pushdown_through_filter_with_extracted_predicate() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(leaf_udf(col("user"), "status").eq(lit("active")))? + .project(vec![col("id"), leaf_udf(col("user"), "name")])? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: test.id, leaf_udf(test.user, Utf8("name")) + Filter: leaf_udf(test.user, Utf8("status")) = Utf8("active") + TableScan: test projection=[id, user] + + ## After Extraction + Projection: test.id, leaf_udf(test.user, Utf8("name")) + Projection: test.id, test.user + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, test.user + TableScan: test projection=[id, user] + + ## After Pushdown + Projection: test.id, __datafusion_extracted_2 AS leaf_udf(test.user,Utf8("name")) + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, test.user, leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_2 + TableScan: test projection=[id, user] + + ## Optimized + Projection: test.id, __datafusion_extracted_2 AS leaf_udf(test.user,Utf8("name")) + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_2 + TableScan: test projection=[id, user] + "#) + } + + /// Same expression in filter predicate and projection output. + #[test] + fn test_extraction_pushdown_same_expr_in_filter_and_projection() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let field_expr = leaf_udf(col("user"), "status"); + let plan = LogicalPlanBuilder::from(table_scan) + .filter(field_expr.clone().gt(lit(5)))? + .project(vec![col("id"), field_expr])? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: test.id, leaf_udf(test.user, Utf8("status")) + Filter: leaf_udf(test.user, Utf8("status")) > Int32(5) + TableScan: test projection=[id, user] + + ## After Extraction + Projection: test.id, leaf_udf(test.user, Utf8("status")) + Projection: test.id, test.user + Filter: __datafusion_extracted_1 > Int32(5) + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, test.user + TableScan: test projection=[id, user] + + ## After Pushdown + Projection: test.id, __datafusion_extracted_2 AS leaf_udf(test.user,Utf8("status")) + Filter: __datafusion_extracted_1 > Int32(5) + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, test.user, leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_2 + TableScan: test projection=[id, user] + + ## Optimized + Projection: test.id, __datafusion_extracted_2 AS leaf_udf(test.user,Utf8("status")) + Filter: __datafusion_extracted_1 > Int32(5) + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_2 + TableScan: test projection=[id, user] + "#) + } + + /// Left join with a `leaf_udf` filter on the right side AND + /// the projection also selects `leaf_udf` from the right side. + #[test] + fn test_left_join_with_filter_and_projection_extraction() -> Result<()> { + use datafusion_expr::JoinType; + + let left = test_table_scan_with_struct()?; + let right = test_table_scan_with_struct_named("right")?; + + let plan = LogicalPlanBuilder::from(left) + .join_on( + right, + JoinType::Left, + vec![ + col("test.id").eq(col("right.id")), + leaf_udf(col("right.user"), "status").gt(lit(5)), + ], + )? + .project(vec![ + col("test.id"), + leaf_udf(col("test.user"), "name"), + leaf_udf(col("right.user"), "status"), + ])? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: test.id, leaf_udf(test.user, Utf8("name")), leaf_udf(right.user, Utf8("status")) + Left Join: Filter: test.id = right.id AND leaf_udf(right.user, Utf8("status")) > Int32(5) + TableScan: test projection=[id, user] + TableScan: right projection=[id, user] + + ## After Extraction + Projection: test.id, leaf_udf(test.user, Utf8("name")), leaf_udf(right.user, Utf8("status")) + Projection: test.id, test.user, right.id, right.user + Left Join: Filter: test.id = right.id AND __datafusion_extracted_1 > Int32(5) + TableScan: test projection=[id, user] + Projection: leaf_udf(right.user, Utf8("status")) AS __datafusion_extracted_1, right.id, right.user + TableScan: right projection=[id, user] + + ## After Pushdown + Projection: test.id, __datafusion_extracted_2 AS leaf_udf(test.user,Utf8("name")), __datafusion_extracted_3 AS leaf_udf(right.user,Utf8("status")) + Left Join: Filter: test.id = right.id AND __datafusion_extracted_1 > Int32(5) + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_2, test.id, test.user + TableScan: test projection=[id, user] + Projection: leaf_udf(right.user, Utf8("status")) AS __datafusion_extracted_1, right.id, right.user, leaf_udf(right.user, Utf8("status")) AS __datafusion_extracted_3 + TableScan: right projection=[id, user] + + ## Optimized + Projection: test.id, __datafusion_extracted_2 AS leaf_udf(test.user,Utf8("name")), __datafusion_extracted_3 AS leaf_udf(right.user,Utf8("status")) + Left Join: Filter: test.id = right.id AND __datafusion_extracted_1 > Int32(5) + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_2, test.id + TableScan: test projection=[id, user] + Projection: leaf_udf(right.user, Utf8("status")) AS __datafusion_extracted_1, right.id, leaf_udf(right.user, Utf8("status")) AS __datafusion_extracted_3 + TableScan: right projection=[id, user] + "#) + } + + /// Extraction projection pushed through a filter whose 
predicate + /// references a different extracted expression. + #[test] + fn test_pure_extraction_proj_push_through_filter() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(leaf_udf(col("user"), "status").gt(lit(5)))? + .project(vec![ + col("id"), + leaf_udf(col("user"), "name"), + leaf_udf(col("user"), "status"), + ])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: test.id, leaf_udf(test.user, Utf8("name")), leaf_udf(test.user, Utf8("status")) + Filter: leaf_udf(test.user, Utf8("status")) > Int32(5) + TableScan: test projection=[id, user] + + ## After Extraction + Projection: test.id, leaf_udf(test.user, Utf8("name")), leaf_udf(test.user, Utf8("status")) + Projection: test.id, test.user + Filter: __datafusion_extracted_1 > Int32(5) + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, test.user + TableScan: test projection=[id, user] + + ## After Pushdown + Projection: test.id, __datafusion_extracted_2 AS leaf_udf(test.user,Utf8("name")), __datafusion_extracted_3 AS leaf_udf(test.user,Utf8("status")) + Filter: __datafusion_extracted_1 > Int32(5) + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, test.user, leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_2, leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_3 + TableScan: test projection=[id, user] + + ## Optimized + Projection: test.id, __datafusion_extracted_2 AS leaf_udf(test.user,Utf8("name")), __datafusion_extracted_3 AS leaf_udf(test.user,Utf8("status")) + Filter: __datafusion_extracted_1 > Int32(5) + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_2, leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_3 + TableScan: test projection=[id, user] + "#) + } + + /// When an extraction projection's __extracted 
alias references a column + /// (e.g. `user`) that is NOT a standalone expression in the projection, + /// the merge into the inner projection should still succeed. + #[test] + fn test_merge_extraction_into_projection_with_column_ref_inflation() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + + // Inner projection (simulates a trimmed projection) + let inner = LogicalPlanBuilder::from(table_scan) + .project(vec![col("user"), col("id")])? + .build()?; + + // Outer projection: __extracted alias + id (but NOT user as standalone). + // The alias references `user` internally, inflating columns_needed. + let plan = LogicalPlanBuilder::from(inner) + .project(vec![ + leaf_udf(col("user"), "status") + .alias(format!("{EXTRACTED_EXPR_PREFIX}_1")), + col("id"), + ])? + .build()?; + + // Run only PushDownLeafProjections + let ctx = OptimizerContext::new().with_max_passes(1); + let optimizer = + Optimizer::with_rules(vec![Arc::new(PushDownLeafProjections::new())]); + let result = optimizer.optimize(plan, &ctx, |_, _| {})?; + + // With the fix: merge succeeds → extraction merged into inner projection. + // Without the fix: merge rejected → two separate projections remain. + insta::assert_snapshot!(format!("{result}"), @r#" + Projection: __datafusion_extracted_1, test.id + Projection: test.user, test.id, leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1 + TableScan: test + "#); + + Ok(()) + } +} diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index a1a59cb34887..e61009182409 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -23,7 +23,6 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] -#![deny(clippy::allow_attributes)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] //! 
# DataFusion Optimizer @@ -58,6 +57,7 @@ pub mod eliminate_nested_union { } pub mod eliminate_outer_join; pub mod extract_equijoin_predicate; +pub mod extract_leaf_expressions; pub mod filter_null_join_keys; pub mod optimize_projections; pub mod optimize_unions; @@ -66,6 +66,7 @@ pub mod propagate_empty_relation; pub mod push_down_filter; pub mod push_down_limit; pub mod replace_distinct_aggregate; +pub mod rewrite_set_comparison; pub mod scalar_subquery_to_join; pub mod simplify_expressions; pub mod single_distinct_to_groupby; diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 548eadffa242..93df300bb50b 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -268,15 +268,10 @@ fn optimize_projections( Some(projection) => indices.into_mapped_indices(|idx| projection[idx]), None => indices.into_inner(), }; - return TableScan::try_new( - table_name, - source, - Some(projection), - filters, - fetch, - ) - .map(LogicalPlan::TableScan) - .map(Transformed::yes); + let new_scan = + TableScan::try_new(table_name, source, Some(projection), filters, fetch)?; + + return Ok(Transformed::yes(LogicalPlan::TableScan(new_scan))); } // Other node types are handled below _ => {} @@ -530,15 +525,14 @@ fn merge_consecutive_projections(proj: Projection) -> Result 1 - && !is_expr_trivial( - &prev_projection.expr - [prev_projection.schema.index_of_column(col).unwrap()], - ) + && !prev_projection.expr[prev_projection.schema.index_of_column(col).unwrap()] + .placement() + .should_push_to_leaves() }) { // no change return Projection::try_new_with_schema(expr, input, schema).map(Transformed::no); @@ -565,7 +559,19 @@ fn merge_consecutive_projections(proj: Projection) -> Result rewrite_expr(*expr, &prev_projection).map(|result| { result.update_data(|expr| { - Expr::Alias(Alias::new(expr, relation, name).with_metadata(metadata)) + // After 
substitution, the inner expression may now have the + // same schema_name as the alias (e.g. when an extraction + // alias like `__extracted_1 AS f(x)` is resolved back to + // `f(x)`). Wrapping in a redundant self-alias causes a + // cosmetic `f(x) AS f(x)` due to Display vs schema_name + // formatting differences. Drop the alias when it matches. + if metadata.is_none() && expr.schema_name().to_string() == name { + expr + } else { + Expr::Alias( + Alias::new(expr, relation, name).with_metadata(metadata), + ) + } }) }), e => rewrite_expr(e, &prev_projection), @@ -591,11 +597,6 @@ fn merge_consecutive_projections(proj: Projection) -> Result bool { - matches!(expr, Expr::Column(_) | Expr::Literal(_, _)) -} - /// Rewrites a projection expression using the projection before it (i.e. its input) /// This is a subroutine to the `merge_consecutive_projections` function. /// diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index ededcec0a47c..bdea6a83072c 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -43,6 +43,7 @@ use crate::eliminate_join::EliminateJoin; use crate::eliminate_limit::EliminateLimit; use crate::eliminate_outer_join::EliminateOuterJoin; use crate::extract_equijoin_predicate::ExtractEquijoinPredicate; +use crate::extract_leaf_expressions::{ExtractLeafExpressions, PushDownLeafProjections}; use crate::filter_null_join_keys::FilterNullJoinKeys; use crate::optimize_projections::OptimizeProjections; use crate::optimize_unions::OptimizeUnions; @@ -51,17 +52,18 @@ use crate::propagate_empty_relation::PropagateEmptyRelation; use crate::push_down_filter::PushDownFilter; use crate::push_down_limit::PushDownLimit; use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate; +use crate::rewrite_set_comparison::RewriteSetComparison; use crate::scalar_subquery_to_join::ScalarSubqueryToJoin; use crate::simplify_expressions::SimplifyExpressions; use 
crate::single_distinct_to_groupby::SingleDistinctToGroupBy; use crate::utils::log_plan; -/// `OptimizerRule`s transforms one [`LogicalPlan`] into another which -/// computes the same results, but in a potentially more efficient -/// way. If there are no suitable transformations for the input plan, -/// the optimizer should simply return it unmodified. +/// Transforms one [`LogicalPlan`] into another which computes the same results, +/// but in a potentially more efficient way. /// -/// To change the semantics of a `LogicalPlan`, see [`AnalyzerRule`] +/// See notes on [`Self::rewrite`] for details on how to implement an `OptimizerRule`. +/// +/// To change the semantics of a `LogicalPlan`, see [`AnalyzerRule`]. /// /// Use [`SessionState::add_optimizer_rule`] to register additional /// `OptimizerRule`s. @@ -86,8 +88,40 @@ pub trait OptimizerRule: Debug { true } - /// Try to rewrite `plan` to an optimized form, returning `Transformed::yes` - /// if the plan was rewritten and `Transformed::no` if it was not. + /// Try to rewrite `plan` to an optimized form, returning [`Transformed::yes`] + /// if the plan was rewritten and [`Transformed::no`] if it was not. + /// + /// # Notes for implementations: + /// + /// ## Return the same plan if no changes were made + /// + /// If there are no suitable transformations for the input plan, + /// the optimizer should simply return it unmodified. + /// + /// The optimizer will call `rewrite` several times until a fixed point is + /// reached, so it is important that `rewrite` return [`Transformed::no`] if + /// the output is the same. + /// + /// ## Matching on functions + /// + /// The rule should avoid function-specific transformations, and instead use + /// methods on [`ScalarUDFImpl`] and [`AggregateUDFImpl`]. Specifically, the + /// rule should not check function names as functions can be overridden, and + /// may not have the same semantics as the functions provided with + /// DataFusion. 
+ /// + /// For example, if a rule rewrites a function based on the check + /// `func.name() == "sum"`, it may rewrite the plan incorrectly if the + /// registered `sum` function has different semantics (for example, the + /// `sum` function from the `datafusion-spark` crate). + /// + /// There are still several cases that rely on function name checking in + /// the rules included with DataFusion. Please see [#18643] for more details + /// and to help remove these cases. + /// + /// [`ScalarUDFImpl`]: datafusion_expr::ScalarUDFImpl + /// [`AggregateUDFImpl`]: datafusion_expr::AggregateUDFImpl + /// [#18643]: https://github.com/apache/datafusion/issues/18643 fn rewrite( &self, _plan: LogicalPlan, @@ -100,8 +134,9 @@ pub trait OptimizerRule: Debug { /// Options to control the DataFusion Optimizer. pub trait OptimizerConfig { /// Return the time at which the query execution started. This - /// time is used as the value for now() - fn query_execution_start_time(&self) -> DateTime; + /// time is used as the value for `now()`. If `None`, time-dependent + /// functions like `now()` will not be simplified during optimization. fn query_execution_start_time(&self) -> Option>; /// Return alias generator used to generate unique aliases for subqueries fn alias_generator(&self) -> &Arc; @@ -118,8 +153,9 @@ pub trait OptimizerConfig { #[derive(Debug)] pub struct OptimizerContext { /// Query execution start time that can be used to rewrite - /// expressions such as `now()` to use a literal value instead - query_execution_start_time: DateTime, + /// expressions such as `now()` to use a literal value instead. + /// If `None`, time-dependent functions will not be simplified. + query_execution_start_time: Option>, /// Alias generator used to generate unique aliases for subqueries alias_generator: Arc, @@ -139,7 +175,7 @@ impl OptimizerContext { /// Create an optimizer config with provided [ConfigOptions]. 
pub fn new_with_config_options(options: Arc) -> Self { Self { - query_execution_start_time: Utc::now(), + query_execution_start_time: Some(Utc::now()), alias_generator: Arc::new(AliasGenerator::new()), options, } @@ -153,13 +189,19 @@ impl OptimizerContext { self } - /// Specify whether the optimizer should skip rules that produce - /// errors, or fail the query + /// Set the query execution start time pub fn with_query_execution_start_time( mut self, - query_execution_tart_time: DateTime, + query_execution_start_time: DateTime, ) -> Self { - self.query_execution_start_time = query_execution_tart_time; + self.query_execution_start_time = Some(query_execution_start_time); + self + } + + /// Clear the query execution start time. When `None`, time-dependent + /// functions like `now()` will not be simplified during optimization. + pub fn without_query_execution_start_time(mut self) -> Self { + self.query_execution_start_time = None; self } @@ -185,7 +227,7 @@ impl Default for OptimizerContext { } impl OptimizerConfig for OptimizerContext { - fn query_execution_start_time(&self) -> DateTime { + fn query_execution_start_time(&self) -> Option> { self.query_execution_start_time } @@ -226,7 +268,17 @@ impl Default for Optimizer { impl Optimizer { /// Create a new optimizer using the recommended list of rules pub fn new() -> Self { + // NOTEs: + // - The order of rules in this list is important, as it determines the + // order in which they are applied. + // - Adding a new rule here is expensive as it will be applied to all + // queries, and will likely increase the optimization time. Please extend + // existing rules when possible, rather than adding a new rule. + // If you do add a new rule, consider having aggressive no-op paths + // (e.g. if the plan doesn't contain any of the nodes you are looking for + // return `Transformed::no`; only works if you control the traversal). 
let rules: Vec> = vec![ + Arc::new(RewriteSetComparison::new()), Arc::new(OptimizeUnions::new()), Arc::new(SimplifyExpressions::new()), Arc::new(ReplaceDistinctWithAggregate::new()), @@ -250,6 +302,8 @@ impl Optimizer { // that might benefit from the following rules Arc::new(EliminateGroupByConstant::new()), Arc::new(CommonSubexprEliminate::new()), + Arc::new(ExtractLeafExpressions::new()), + Arc::new(PushDownLeafProjections::new()), Arc::new(OptimizeProjections::new()), ]; diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 755ffdbafc86..d9cbe7cea4cd 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -45,6 +45,7 @@ use crate::optimizer::ApplyOrder; use crate::simplify_expressions::simplify_predicates; use crate::utils::{has_all_column_refs, is_restrict_null_predicate}; use crate::{OptimizerConfig, OptimizerRule}; +use datafusion_expr::ExpressionPlacement; /// Optimizer rule for pushing (moving) filter expressions down in a plan so /// they are applied as early as possible. @@ -263,6 +264,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Jump), Expr::Exists { .. 
} | Expr::InSubquery(_) + | Expr::SetComparison(_) | Expr::ScalarSubquery(_) | Expr::OuterReferenceColumn(_, _) | Expr::Unnest(_) => { @@ -454,11 +456,11 @@ fn push_down_all_join( } } - // For infer predicates, if they can not push through join, just drop them + // Push predicates inferred from the join expression for predicate in inferred_join_predicates { - if left_preserved && checker.is_left_only(&predicate) { + if checker.is_left_only(&predicate) { left_push.push(predicate); - } else if right_preserved && checker.is_right_only(&predicate) { + } else if checker.is_right_only(&predicate) { right_push.push(predicate); } } @@ -616,7 +618,7 @@ impl InferredPredicates { fn new(join_type: JoinType) -> Self { Self { predicates: vec![], - is_inner_join: matches!(join_type, JoinType::Inner), + is_inner_join: join_type == JoinType::Inner, } } @@ -791,6 +793,13 @@ impl OptimizerRule for PushDownFilter { filter.predicate = new_predicate; } + // If the child has a fetch (limit) or skip (offset), pushing a filter + // below it would change semantics: the limit/offset should apply before + // the filter, not after. + if filter.input.fetch()?.is_some() || filter.input.skip()?.is_some() { + return Ok(Transformed::no(LogicalPlan::Filter(filter))); + } + match Arc::unwrap_or_clone(filter.input) { LogicalPlan::Filter(child_filter) => { let parents_predicates = split_conjunction_owned(filter.predicate); @@ -1294,10 +1303,13 @@ fn rewrite_projection( predicates: Vec, mut projection: Projection, ) -> Result<(Transformed, Option)> { - // A projection is filter-commutable if it do not contain volatile predicates or contain volatile - // predicates that are not used in the filter. However, we should re-writes all predicate expressions. - // collect projection. - let (volatile_map, non_volatile_map): (HashMap<_, _>, HashMap<_, _>) = projection + // Partition projection expressions into non-pushable vs pushable. 
+ // Non-pushable expressions are volatile (must not be duplicated) or + // MoveTowardsLeafNodes (cheap expressions like get_field where re-inlining + // into a filter causes optimizer instability — ExtractLeafExpressions will + // undo the push-down, creating an infinite loop that runs until the + // iteration limit is hit). + let (non_pushable_map, pushable_map): (HashMap<_, _>, HashMap<_, _>) = projection .schema .iter() .zip(projection.expr.iter()) @@ -1307,12 +1319,15 @@ fn rewrite_projection( (qualified_name(qualifier, field.name()), expr) }) - .partition(|(_, value)| value.is_volatile()); + .partition(|(_, value)| { + value.is_volatile() + || value.placement() == ExpressionPlacement::MoveTowardsLeafNodes + }); let mut push_predicates = vec![]; let mut keep_predicates = vec![]; for expr in predicates { - if contain(&expr, &volatile_map) { + if contain(&expr, &non_pushable_map) { keep_predicates.push(expr); } else { push_predicates.push(expr); @@ -1324,7 +1339,7 @@ fn rewrite_projection( // re-write all filters based on this projection // E.g. 
in `Filter: b\n Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1" let new_filter = LogicalPlan::Filter(Filter::try_new( - replace_cols_by_name(expr, &non_volatile_map)?, + replace_cols_by_name(expr, &pushable_map)?, std::mem::take(&mut projection.input), )?); @@ -1335,7 +1350,10 @@ fn rewrite_projection( conjunction(keep_predicates), )) } - None => Ok((Transformed::no(LogicalPlan::Projection(projection)), None)), + None => Ok(( + Transformed::no(LogicalPlan::Projection(projection)), + conjunction(keep_predicates), + )), } } @@ -1445,6 +1463,7 @@ mod tests { use crate::assert_optimized_plan_eq_snapshot; use crate::optimizer::Optimizer; use crate::simplify_expressions::SimplifyExpressions; + use crate::test::udfs::leaf_udf_expr; use crate::test::*; use datafusion_expr::test::function_stub::sum; use insta::assert_snapshot; @@ -2331,7 +2350,7 @@ mod tests { plan, @r" Projection: test.a, test1.d - Cross Join: + Cross Join: Projection: test.a, test.b, test.c TableScan: test, full_filters=[test.a = Int32(1)] Projection: test1.d, test1.e, test1.f @@ -2361,7 +2380,7 @@ mod tests { plan, @r" Projection: test.a, test1.a - Cross Join: + Cross Join: Projection: test.a, test.b, test.c TableScan: test, full_filters=[test.a = Int32(1)] Projection: test1.a, test1.b, test1.c @@ -2720,8 +2739,7 @@ mod tests { ) } - /// post-left-join predicate on a column common to both sides is only pushed to the left side - /// i.e. 
- not duplicated to the right side + /// post-left-join predicate on a column common to both sides is pushed to both sides #[test] fn filter_using_left_join_on_common() -> Result<()> { let table_scan = test_table_scan()?; @@ -2749,20 +2767,19 @@ mod tests { TableScan: test2 ", ); - // filter sent to left side of the join, not the right + // filter sent to left side of the join and to the right assert_optimized_plan_equal!( plan, @r" Left Join: Using test.a = test2.a TableScan: test, full_filters=[test.a <= Int64(1)] Projection: test2.a - TableScan: test2 + TableScan: test2, full_filters=[test2.a <= Int64(1)] " ) } - /// post-right-join predicate on a column common to both sides is only pushed to the right side - /// i.e. - not duplicated to the left side. + /// post-right-join predicate on a column common to both sides is pushed to both sides #[test] fn filter_using_right_join_on_common() -> Result<()> { let table_scan = test_table_scan()?; @@ -2790,12 +2807,12 @@ mod tests { TableScan: test2 ", ); - // filter sent to right side of join, not duplicated to the left + // filter sent to right side of join, sent to the left as well assert_optimized_plan_equal!( plan, @r" Right Join: Using test.a = test2.a - TableScan: test + TableScan: test, full_filters=[test.a <= Int64(1)] Projection: test2.a TableScan: test2, full_filters=[test2.a <= Int64(1)] " @@ -2977,7 +2994,7 @@ mod tests { Projection: test.a, test.b, test.c TableScan: test Projection: test2.a, test2.b, test2.c - TableScan: test2, full_filters=[test2.c > UInt32(4)] + TableScan: test2, full_filters=[test2.a > UInt32(1), test2.c > UInt32(4)] " ) } @@ -4222,4 +4239,127 @@ mod tests { " ) } + + /// Test that filters are NOT pushed through MoveTowardsLeafNodes projections. 
+ /// These are cheap expressions (like get_field) where re-inlining into a filter + /// has no benefit and causes optimizer instability — ExtractLeafExpressions will + /// undo the push-down, creating an infinite loop that runs until the iteration + /// limit is hit. + #[test] + fn filter_not_pushed_through_move_towards_leaves_projection() -> Result<()> { + let table_scan = test_table_scan()?; + + // Create a projection with a MoveTowardsLeafNodes expression + let proj = LogicalPlanBuilder::from(table_scan) + .project(vec![ + leaf_udf_expr(col("a")).alias("val"), + col("b"), + col("c"), + ])? + .build()?; + + // Put a filter on the MoveTowardsLeafNodes column + let plan = LogicalPlanBuilder::from(proj) + .filter(col("val").gt(lit(150i64)))? + .build()?; + + // Filter should NOT be pushed through — val maps to a MoveTowardsLeafNodes expr + assert_optimized_plan_equal!( + plan, + @r" + Filter: val > Int64(150) + Projection: leaf_udf(test.a) AS val, test.b, test.c + TableScan: test + " + ) + } + + /// Test mixed predicates: Column predicate pushed, MoveTowardsLeafNodes kept. + #[test] + fn filter_mixed_predicates_partial_push() -> Result<()> { + let table_scan = test_table_scan()?; + + // Create a projection with both MoveTowardsLeafNodes and Column expressions + let proj = LogicalPlanBuilder::from(table_scan) + .project(vec![ + leaf_udf_expr(col("a")).alias("val"), + col("b"), + col("c"), + ])? + .build()?; + + // Filter with both: val > 150 (MoveTowardsLeafNodes) AND b > 5 (Column) + let plan = LogicalPlanBuilder::from(proj) + .filter(col("val").gt(lit(150i64)).and(col("b").gt(lit(5i64))))? 
+ .build()?; + + // val > 150 should be kept above, b > 5 should be pushed through + assert_optimized_plan_equal!( + plan, + @r" + Filter: val > Int64(150) + Projection: leaf_udf(test.a) AS val, test.b, test.c + TableScan: test, full_filters=[test.b > Int64(5)] + " + ) + } + + #[test] + fn filter_not_pushed_down_through_table_scan_with_fetch() -> Result<()> { + let scan = test_table_scan()?; + let scan_with_fetch = match scan { + LogicalPlan::TableScan(scan) => LogicalPlan::TableScan(TableScan { + fetch: Some(10), + ..scan + }), + _ => unreachable!(), + }; + let plan = LogicalPlanBuilder::from(scan_with_fetch) + .filter(col("a").gt(lit(10i64)))? + .build()?; + // Filter must NOT be pushed into the table scan when it has a fetch (limit) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test.a > Int64(10) + TableScan: test, fetch=10 + " + ) + } + + #[test] + fn filter_push_down_through_sort_without_fetch() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .sort(vec![col("a").sort(true, true)])? + .filter(col("a").gt(lit(10i64)))? + .build()?; + // Filter should be pushed below the sort + assert_optimized_plan_equal!( + plan, + @r" + Sort: test.a ASC NULLS FIRST + TableScan: test, full_filters=[test.a > Int64(10)] + " + ) + } + + #[test] + fn filter_not_pushed_down_through_sort_with_fetch() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .sort_with_limit(vec![col("a").sort(true, true)], Some(5))? + .filter(col("a").gt(lit(10i64)))? + .build()?; + // Filter must NOT be pushed below the sort when it has a fetch (limit), + // because the limit should apply before the filter. 
+ assert_optimized_plan_equal!( + plan, + @r" + Filter: test.a > Int64(10) + Sort: test.a ASC NULLS FIRST, fetch=5 + TableScan: test + " + ) + } } diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 7b302adf22ac..755e192e340d 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -1044,7 +1044,7 @@ mod test { plan, @r" Limit: skip=0, fetch=1000 - Cross Join: + Cross Join: Limit: skip=0, fetch=1000 TableScan: test, fetch=1000 Limit: skip=0, fetch=1000 @@ -1067,7 +1067,7 @@ mod test { plan, @r" Limit: skip=1000, fetch=1000 - Cross Join: + Cross Join: Limit: skip=0, fetch=2000 TableScan: test, fetch=2000 Limit: skip=0, fetch=2000 diff --git a/datafusion/optimizer/src/rewrite_set_comparison.rs b/datafusion/optimizer/src/rewrite_set_comparison.rs new file mode 100644 index 000000000000..c8c35b518743 --- /dev/null +++ b/datafusion/optimizer/src/rewrite_set_comparison.rs @@ -0,0 +1,171 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Optimizer rule rewriting `SetComparison` subqueries (e.g. `= ANY`, +//! `> ALL`) into boolean expressions built from `EXISTS` subqueries +//! 
that capture SQL three-valued logic. + +use crate::{OptimizerConfig, OptimizerRule}; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{Column, DFSchema, ExprSchema, Result, ScalarValue, plan_err}; +use datafusion_expr::expr::{self, Exists, SetComparison, SetQuantifier}; +use datafusion_expr::logical_plan::Subquery; +use datafusion_expr::logical_plan::builder::LogicalPlanBuilder; +use datafusion_expr::{Expr, LogicalPlan, lit}; +use std::sync::Arc; + +use datafusion_expr::utils::merge_schema; + +/// Rewrite `SetComparison` expressions to scalar subqueries that return the +/// correct boolean value (including SQL NULL semantics). After this rule +/// runs, later rules such as `ScalarSubqueryToJoin` can decorrelate and +/// remove the remaining subquery. +#[derive(Debug, Default)] +pub struct RewriteSetComparison; + +impl RewriteSetComparison { + /// Create a new `RewriteSetComparison` optimizer rule. + pub fn new() -> Self { + Self + } + + fn rewrite_plan(&self, plan: LogicalPlan) -> Result> { + let schema = merge_schema(&plan.inputs()); + plan.map_expressions(|expr| { + expr.transform_up(|expr| rewrite_set_comparison(expr, &schema)) + }) + } +} + +impl OptimizerRule for RewriteSetComparison { + fn name(&self) -> &str { + "rewrite_set_comparison" + } + + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + plan.transform_up_with_subqueries(|plan| self.rewrite_plan(plan)) + } +} + +fn rewrite_set_comparison( + expr: Expr, + outer_schema: &DFSchema, +) -> Result> { + match expr { + Expr::SetComparison(set_comparison) => { + let rewritten = build_set_comparison_subquery(set_comparison, outer_schema)?; + Ok(Transformed::yes(rewritten)) + } + _ => Ok(Transformed::no(expr)), + } +} + +fn build_set_comparison_subquery( + set_comparison: SetComparison, + outer_schema: &DFSchema, +) -> Result { + let SetComparison { + expr, + subquery, + op, + quantifier, + } = set_comparison; + + let left_expr = 
to_outer_reference(*expr, outer_schema)?; + let subquery_schema = subquery.subquery.schema(); + if subquery_schema.fields().is_empty() { + return plan_err!("single expression required."); + } + // avoid `head_output_expr` for aggr/window plan, it will gives group-by expr if exists + let right_expr = Expr::Column(Column::from(subquery_schema.qualified_field(0))); + + let comparison = Expr::BinaryExpr(expr::BinaryExpr::new( + Box::new(left_expr), + op, + Box::new(right_expr), + )); + + let true_exists = + exists_subquery(&subquery, Expr::IsTrue(Box::new(comparison.clone())))?; + let null_exists = + exists_subquery(&subquery, Expr::IsNull(Box::new(comparison.clone())))?; + + let result_expr = match quantifier { + SetQuantifier::Any => Expr::Case(expr::Case { + expr: None, + when_then_expr: vec![ + (Box::new(true_exists), Box::new(lit(true))), + ( + Box::new(null_exists), + Box::new(Expr::Literal(ScalarValue::Boolean(None), None)), + ), + ], + else_expr: Some(Box::new(lit(false))), + }), + SetQuantifier::All => { + let false_exists = + exists_subquery(&subquery, Expr::IsFalse(Box::new(comparison.clone())))?; + Expr::Case(expr::Case { + expr: None, + when_then_expr: vec![ + (Box::new(false_exists), Box::new(lit(false))), + ( + Box::new(null_exists), + Box::new(Expr::Literal(ScalarValue::Boolean(None), None)), + ), + ], + else_expr: Some(Box::new(lit(true))), + }) + } + }; + + Ok(result_expr) +} + +fn exists_subquery(subquery: &Subquery, filter: Expr) -> Result { + let plan = LogicalPlanBuilder::from(subquery.subquery.as_ref().clone()) + .filter(filter)? 
+ .build()?; + let outer_ref_columns = plan.all_out_ref_exprs(); + Ok(Expr::Exists(Exists { + subquery: Subquery { + subquery: Arc::new(plan), + outer_ref_columns, + spans: subquery.spans.clone(), + }, + negated: false, + })) +} + +fn to_outer_reference(expr: Expr, outer_schema: &DFSchema) -> Result { + expr.transform_up(|expr| match expr { + Expr::Column(col) => { + let field = outer_schema.field_from_column(&col)?; + Ok(Transformed::yes(Expr::OuterReferenceColumn( + Arc::clone(field), + col, + ))) + } + Expr::OuterReferenceColumn(_, _) => Ok(Transformed::no(expr)), + _ => Ok(Transformed::no(expr)), + }) + .map(|t| t.data) +} diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 01de44cee1f6..c6644e008645 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -18,7 +18,7 @@ //! Expression simplification API use arrow::{ - array::{AsArray, new_null_array}, + array::{Array, AsArray, new_null_array}, datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, }; @@ -27,6 +27,8 @@ use std::collections::HashSet; use std::ops::Not; use std::sync::Arc; +use datafusion_common::config::ConfigOptions; +use datafusion_common::nested_struct::has_one_of_more_common_fields; use datafusion_common::{ DFSchema, DataFusionError, Result, ScalarValue, exec_datafusion_err, internal_err, }; @@ -37,8 +39,8 @@ use datafusion_common::{ tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, }; use datafusion_expr::{ - BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility, and, - binary::BinaryTypeCoercer, lit, or, + BinaryExpr, Case, ColumnarValue, Expr, ExprSchemable, Like, Operator, Volatility, + and, binary::BinaryTypeCoercer, lit, or, preimage::PreimageResult, }; use datafusion_expr::{Cast, TryCast, simplify::ExprSimplifyResult}; use 
datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; @@ -50,14 +52,17 @@ use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionP use super::inlist_simplifier::ShortenInListSimplifier; use super::utils::*; -use crate::analyzer::type_coercion::TypeCoercionRewriter; -use crate::simplify_expressions::SimplifyInfo; +use crate::simplify_expressions::SimplifyContext; use crate::simplify_expressions::regex::simplify_regex_expr; use crate::simplify_expressions::unwrap_cast::{ is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary, is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist, unwrap_cast_in_comparison_for_binary, }; +use crate::{ + analyzer::type_coercion::TypeCoercionRewriter, + simplify_expressions::udf_preimage::rewrite_with_preimage, +}; use datafusion_expr::expr_rewriter::rewrite_with_guarantees_map; use datafusion_expr_common::casts::try_cast_literal_to_type; use indexmap::IndexSet; @@ -72,7 +77,6 @@ use regex::Regex; /// ``` /// use arrow::datatypes::{DataType, Field, Schema}; /// use datafusion_common::{DataFusionError, ToDFSchema}; -/// use datafusion_expr::execution_props::ExecutionProps; /// use datafusion_expr::simplify::SimplifyContext; /// use datafusion_expr::{col, lit}; /// use datafusion_optimizer::simplify_expressions::ExprSimplifier; @@ -83,8 +87,7 @@ use regex::Regex; /// .unwrap(); /// /// // Create the simplifier -/// let props = ExecutionProps::new(); -/// let context = SimplifyContext::new(&props).with_schema(schema); +/// let context = SimplifyContext::default().with_schema(schema); /// let simplifier = ExprSimplifier::new(context); /// /// // Use the simplifier @@ -96,8 +99,8 @@ use regex::Regex; /// let simplified = simplifier.simplify(expr).unwrap(); /// assert_eq!(simplified, col("b").lt(lit(2))); /// ``` -pub struct ExprSimplifier { - info: S, +pub struct ExprSimplifier { + info: SimplifyContext, /// Guarantees about the values of columns. 
This is provided by the user /// in [ExprSimplifier::with_guarantees()]. guarantees: Vec<(Expr, NullableInterval)>, @@ -111,13 +114,12 @@ pub struct ExprSimplifier { pub const THRESHOLD_INLINE_INLIST: usize = 3; pub const DEFAULT_MAX_SIMPLIFIER_CYCLES: u32 = 3; -impl ExprSimplifier { - /// Create a new `ExprSimplifier` with the given `info` such as an - /// instance of [`SimplifyContext`]. See - /// [`simplify`](Self::simplify) for an example. +impl ExprSimplifier { + /// Create a new `ExprSimplifier` with the given [`SimplifyContext`]. + /// See [`simplify`](Self::simplify) for an example. /// /// [`SimplifyContext`]: datafusion_expr::simplify::SimplifyContext - pub fn new(info: S) -> Self { + pub fn new(info: SimplifyContext) -> Self { Self { info, guarantees: vec![], @@ -142,40 +144,21 @@ impl ExprSimplifier { /// `b > 2` /// /// ``` - /// use arrow::datatypes::DataType; - /// use datafusion_common::DFSchema; + /// use arrow::datatypes::{DataType, Field, Schema}; + /// use datafusion_common::{DFSchema, ToDFSchema}; /// use datafusion_common::Result; - /// use datafusion_expr::execution_props::ExecutionProps; /// use datafusion_expr::simplify::SimplifyContext; - /// use datafusion_expr::simplify::SimplifyInfo; /// use datafusion_expr::{col, lit, Expr}; /// use datafusion_optimizer::simplify_expressions::ExprSimplifier; /// use std::sync::Arc; /// - /// /// Simple implementation that provides `Simplifier` the information it needs - /// /// See SimplifyContext for a structure that does this. 
- /// #[derive(Default)] - /// struct Info { - /// execution_props: ExecutionProps, - /// }; - /// - /// impl SimplifyInfo for Info { - /// fn is_boolean_type(&self, expr: &Expr) -> Result { - /// Ok(false) - /// } - /// fn nullable(&self, expr: &Expr) -> Result { - /// Ok(true) - /// } - /// fn execution_props(&self) -> &ExecutionProps { - /// &self.execution_props - /// } - /// fn get_data_type(&self, expr: &Expr) -> Result { - /// Ok(DataType::Int32) - /// } - /// } - /// + /// // Create a schema and SimplifyContext + /// let schema = Schema::new(vec![Field::new("b", DataType::Int32, true)]) + /// .to_dfschema_ref() + /// .unwrap(); /// // Create the simplifier - /// let simplifier = ExprSimplifier::new(Info::default()); + /// let context = SimplifyContext::default().with_schema(schema); + /// let simplifier = ExprSimplifier::new(context); /// /// // b < 2 /// let b_lt_2 = col("b").gt(lit(2)); @@ -225,7 +208,8 @@ impl ExprSimplifier { mut expr: Expr, ) -> Result<(Transformed, u32)> { let mut simplifier = Simplifier::new(&self.info); - let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?; + let config_options = Some(Arc::clone(self.info.config_options())); + let mut const_evaluator = ConstEvaluator::try_new(config_options)?; let mut shorten_in_list_simplifier = ShortenInListSimplifier::new(); let guarantees_map: HashMap<&Expr, &NullableInterval> = self.guarantees.iter().map(|(k, v)| (k, v)).collect(); @@ -287,7 +271,6 @@ impl ExprSimplifier { /// ```rust /// use arrow::datatypes::{DataType, Field, Schema}; /// use datafusion_common::{Result, ScalarValue, ToDFSchema}; - /// use datafusion_expr::execution_props::ExecutionProps; /// use datafusion_expr::interval_arithmetic::{Interval, NullableInterval}; /// use datafusion_expr::simplify::SimplifyContext; /// use datafusion_expr::{col, lit, Expr}; @@ -302,8 +285,7 @@ impl ExprSimplifier { /// .unwrap(); /// /// // Create the simplifier - /// let props = ExecutionProps::new(); - /// let 
context = SimplifyContext::new(&props).with_schema(schema); + /// let context = SimplifyContext::default().with_schema(schema); /// /// // Expression: (x >= 3) AND (y + 2 < 10) AND (z > 5) /// let expr_x = col("x").gt_eq(lit(3_i64)); @@ -349,7 +331,6 @@ impl ExprSimplifier { /// ```rust /// use arrow::datatypes::{DataType, Field, Schema}; /// use datafusion_common::{Result, ScalarValue, ToDFSchema}; - /// use datafusion_expr::execution_props::ExecutionProps; /// use datafusion_expr::interval_arithmetic::{Interval, NullableInterval}; /// use datafusion_expr::simplify::SimplifyContext; /// use datafusion_expr::{col, lit, Expr}; @@ -364,8 +345,7 @@ impl ExprSimplifier { /// .unwrap(); /// /// // Create the simplifier - /// let props = ExecutionProps::new(); - /// let context = SimplifyContext::new(&props).with_schema(schema); + /// let context = SimplifyContext::default().with_schema(schema); /// let simplifier = ExprSimplifier::new(context); /// /// // Expression: a = c AND 1 = b @@ -410,7 +390,6 @@ impl ExprSimplifier { /// use arrow::datatypes::{DataType, Field, Schema}; /// use datafusion_expr::{col, lit, Expr}; /// use datafusion_common::{Result, ScalarValue, ToDFSchema}; - /// use datafusion_expr::execution_props::ExecutionProps; /// use datafusion_expr::simplify::SimplifyContext; /// use datafusion_optimizer::simplify_expressions::ExprSimplifier; /// @@ -420,9 +399,7 @@ impl ExprSimplifier { /// .to_dfschema_ref().unwrap(); /// /// // Create the simplifier - /// let props = ExecutionProps::new(); - /// let context = SimplifyContext::new(&props) - /// .with_schema(schema); + /// let context = SimplifyContext::default().with_schema(schema); /// let simplifier = ExprSimplifier::new(context); /// /// // Expression: a IS NOT NULL @@ -500,7 +477,7 @@ impl TreeNodeRewriter for Canonicalizer { /// /// Note it does not handle algebraic rewrites such as `(a or false)` /// --> `a`, which is handled by [`Simplifier`] -struct ConstEvaluator<'a> { +struct ConstEvaluator { 
/// `can_evaluate` is used during the depth-first-search of the /// `Expr` tree to track if any siblings (or their descendants) were /// non evaluatable (e.g. had a column reference or volatile @@ -514,8 +491,13 @@ struct ConstEvaluator<'a> { /// means there were no non evaluatable siblings (or their /// descendants) so this `Expr` can be evaluated can_evaluate: Vec, - - execution_props: &'a ExecutionProps, + /// Execution properties needed to call [`create_physical_expr`]. + /// `ConstEvaluator` only evaluates expressions without column references + /// (i.e. constant expressions) and doesn't use the variable binding features + /// of `ExecutionProps` (we explicitly filter out [`Expr::ScalarVariable`]). + /// The `config_options` are passed from the session to allow scalar functions + /// to access configuration like timezone. + execution_props: ExecutionProps, input_schema: DFSchema, input_batch: RecordBatch, } @@ -530,7 +512,7 @@ enum ConstSimplifyResult { SimplifyRuntimeError(DataFusionError, Expr), } -impl TreeNodeRewriter for ConstEvaluator<'_> { +impl TreeNodeRewriter for ConstEvaluator { type Node = Expr; fn f_down(&mut self, expr: Expr) -> Result> { @@ -593,11 +575,17 @@ impl TreeNodeRewriter for ConstEvaluator<'_> { } } -impl<'a> ConstEvaluator<'a> { - /// Create a new `ConstantEvaluator`. Session constants (such as - /// the time for `now()` are taken from the passed - /// `execution_props`. - pub fn try_new(execution_props: &'a ExecutionProps) -> Result { +impl ConstEvaluator { + /// Create a new `ConstantEvaluator`. + /// + /// Note: `ConstEvaluator` filters out expressions with scalar variables + /// (like `$var`) and volatile functions, so it creates its own default + /// `ExecutionProps` internally. The filtered expressions will be evaluated + /// at runtime where proper variable bindings are available. 
+ /// + /// The `config_options` parameter is used to pass session configuration + /// (like timezone) to scalar functions during constant evaluation. + pub fn try_new(config_options: Option>) -> Result { // The dummy column name is unused and doesn't matter as only // expressions without column references can be evaluated static DUMMY_COL_NAME: &str = "."; @@ -611,6 +599,9 @@ impl<'a> ConstEvaluator<'a> { let col = new_null_array(&DataType::Null, 1); let input_batch = RecordBatch::try_new(schema, vec![col])?; + let mut execution_props = ExecutionProps::new(); + execution_props.config_options = config_options; + Ok(Self { can_evaluate: vec![], execution_props, @@ -646,6 +637,7 @@ impl<'a> ConstEvaluator<'a> { | Expr::OuterReferenceColumn(_, _) | Expr::Exists { .. } | Expr::InSubquery(_) + | Expr::SetComparison(_) | Expr::ScalarSubquery(_) | Expr::WindowFunction { .. } | Expr::GroupingSet(_) @@ -654,6 +646,35 @@ impl<'a> ConstEvaluator<'a> { Expr::ScalarFunction(ScalarFunction { func, .. }) => { Self::volatility_ok(func.signature().volatility) } + Expr::Cast(Cast { expr, data_type }) + | Expr::TryCast(TryCast { expr, data_type }) => { + if let ( + Ok(DataType::Struct(source_fields)), + DataType::Struct(target_fields), + ) = (expr.get_type(&DFSchema::empty()), data_type) + { + // Don't const-fold struct casts with different field counts + if source_fields.len() != target_fields.len() { + return false; + } + + // Skip const-folding when there is no field name overlap + if !has_one_of_more_common_fields(&source_fields, target_fields) { + return false; + } + + // Don't const-fold struct casts with empty (0-row) literals + // The simplifier uses a 1-row input batch, which causes dimension mismatches + // when evaluating 0-row struct literals + if let Expr::Literal(ScalarValue::Struct(struct_array), _) = + expr.as_ref() + && struct_array.len() == 0 + { + return false; + } + } + true + } Expr::Literal(_, _) | Expr::Alias(..) 
| Expr::Unnest(_) @@ -672,8 +693,6 @@ impl<'a> ConstEvaluator<'a> { | Expr::Like { .. } | Expr::SimilarTo { .. } | Expr::Case(_) - | Expr::Cast { .. } - | Expr::TryCast { .. } | Expr::InList { .. } => true, } } @@ -684,11 +703,14 @@ impl<'a> ConstEvaluator<'a> { return ConstSimplifyResult::NotSimplified(s, m); } - let phys_expr = - match create_physical_expr(&expr, &self.input_schema, self.execution_props) { - Ok(e) => e, - Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr), - }; + let phys_expr = match create_physical_expr( + &expr, + &self.input_schema, + &self.execution_props, + ) { + Ok(e) => e, + Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr), + }; let metadata = phys_expr .return_field(self.input_batch.schema_ref()) .ok() @@ -745,17 +767,17 @@ impl<'a> ConstEvaluator<'a> { /// * `false = true` and `true = false` to `false` /// * `!!expr` to `expr` /// * `expr = null` and `expr != null` to `null` -struct Simplifier<'a, S> { - info: &'a S, +struct Simplifier<'a> { + info: &'a SimplifyContext, } -impl<'a, S> Simplifier<'a, S> { - pub fn new(info: &'a S) -> Self { +impl<'a> Simplifier<'a> { + pub fn new(info: &'a SimplifyContext) -> Self { Self { info } } } -impl TreeNodeRewriter for Simplifier<'_, S> { +impl TreeNodeRewriter for Simplifier<'_> { type Node = Expr; /// rewrite the expression simplifying any constant expressions @@ -1055,6 +1077,22 @@ impl TreeNodeRewriter for Simplifier<'_, S> { ); } } + // A = L1 AND A != L2 --> A = L1 (when L1 != L2) + Expr::BinaryExpr(BinaryExpr { + left, + op: And, + right, + }) if is_eq_and_ne_with_different_literal(&left, &right) => { + Transformed::yes(*left) + } + // A != L2 AND A = L1 --> A = L1 (when L1 != L2) + Expr::BinaryExpr(BinaryExpr { + left, + op: And, + right, + }) if is_eq_and_ne_with_different_literal(&right, &left) => { + Transformed::yes(*right) + } // // Rules for Multiply @@ -1962,12 +2000,132 @@ impl TreeNodeRewriter for Simplifier<'_, S> { })) } + // 
======================================= + // preimage_in_comparison + // ======================================= + // + // For case: + // date_part('YEAR', expr) op literal + // + // For details see datafusion_expr::ScalarUDFImpl::preimage + Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + use datafusion_expr::Operator::*; + let is_preimage_op = matches!( + op, + Eq | NotEq + | Lt + | LtEq + | Gt + | GtEq + | IsDistinctFrom + | IsNotDistinctFrom + ); + if !is_preimage_op || is_null(&right) { + return Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr { + left, + op, + right, + }))); + } + + if let PreimageResult::Range { interval, expr } = + get_preimage(left.as_ref(), right.as_ref(), info)? + { + rewrite_with_preimage(*interval, op, expr)? + } else if let Some(swapped) = op.swap() { + if let PreimageResult::Range { interval, expr } = + get_preimage(right.as_ref(), left.as_ref(), info)? + { + rewrite_with_preimage(*interval, swapped, expr)? + } else { + Transformed::no(Expr::BinaryExpr(BinaryExpr { left, op, right })) + } + } else { + Transformed::no(Expr::BinaryExpr(BinaryExpr { left, op, right })) + } + } + // For case: + // date_part('YEAR', expr) IN (literal1, literal2, ...) + Expr::InList(InList { + expr, + list, + negated, + }) => { + if list.len() > THRESHOLD_INLINE_INLIST || list.iter().any(is_null) { + return Ok(Transformed::no(Expr::InList(InList { + expr, + list, + negated, + }))); + } + + let (op, combiner): (Operator, fn(Expr, Expr) -> Expr) = + if negated { (NotEq, and) } else { (Eq, or) }; + + let mut rewritten: Option = None; + for item in &list { + let PreimageResult::Range { interval, expr } = + get_preimage(expr.as_ref(), item, info)? 
+ else { + return Ok(Transformed::no(Expr::InList(InList { + expr, + list, + negated, + }))); + }; + + let range_expr = rewrite_with_preimage(*interval, op, expr)?.data; + rewritten = Some(match rewritten { + None => range_expr, + Some(acc) => combiner(acc, range_expr), + }); + } + + if let Some(rewritten) = rewritten { + Transformed::yes(rewritten) + } else { + Transformed::no(Expr::InList(InList { + expr, + list, + negated, + })) + } + } + // no additional rewrites possible expr => Transformed::no(expr), }) } } +fn get_preimage( + left_expr: &Expr, + right_expr: &Expr, + info: &SimplifyContext, +) -> Result { + let Expr::ScalarFunction(ScalarFunction { func, args }) = left_expr else { + return Ok(PreimageResult::None); + }; + if !is_literal_or_literal_cast(right_expr) { + return Ok(PreimageResult::None); + } + if func.signature().volatility != Volatility::Immutable { + return Ok(PreimageResult::None); + } + func.preimage(args, right_expr, info) +} + +fn is_literal_or_literal_cast(expr: &Expr) -> bool { + match expr { + Expr::Literal(_, _) => true, + Expr::Cast(Cast { expr, .. }) => matches!(expr.as_ref(), Expr::Literal(_, _)), + Expr::TryCast(TryCast { expr, .. }) => { + matches!(expr.as_ref(), Expr::Literal(_, _)) + } + _ => false, + } +} + fn as_string_scalar(expr: &Expr) -> Option<(DataType, &Option)> { match expr { Expr::Literal(ScalarValue::Utf8(s), _) => Some((DataType::Utf8, s)), @@ -2117,7 +2275,7 @@ fn inlist_except(mut l1: InList, l2: &InList) -> Result { } /// Returns expression testing a boolean `expr` for being exactly `true` (not `false` or NULL). -fn is_exactly_true(expr: Expr, info: &impl SimplifyInfo) -> Result { +fn is_exactly_true(expr: Expr, info: &SimplifyContext) -> Result { if !info.nullable(&expr)? 
{ Ok(expr) } else { @@ -2133,8 +2291,8 @@ fn is_exactly_true(expr: Expr, info: &impl SimplifyInfo) -> Result { // A / 1 -> A // // Move this function body out of the large match branch avoid stack overflow -fn simplify_right_is_one_case( - info: &S, +fn simplify_right_is_one_case( + info: &SimplifyContext, left: Box, op: &Operator, right: &Expr, @@ -2160,7 +2318,10 @@ mod tests { use super::*; use crate::simplify_expressions::SimplifyContext; use crate::test::test_table_scan_with_name; - use arrow::datatypes::FieldRef; + use arrow::{ + array::{Int32Array, StructArray}, + datatypes::{FieldRef, Fields}, + }; use datafusion_common::{DFSchemaRef, ToDFSchema, assert_contains}; use datafusion_expr::{ expr::WindowFunction, @@ -2187,9 +2348,8 @@ mod tests { // ------------------------------ #[test] fn api_basic() { - let props = ExecutionProps::new(); let simplifier = - ExprSimplifier::new(SimplifyContext::new(&props).with_schema(test_schema())); + ExprSimplifier::new(SimplifyContext::default().with_schema(test_schema())); let expr = lit(1) + lit(2); let expected = lit(3); @@ -2199,9 +2359,8 @@ mod tests { #[test] fn basic_coercion() { let schema = test_schema(); - let props = ExecutionProps::new(); let simplifier = ExprSimplifier::new( - SimplifyContext::new(&props).with_schema(Arc::clone(&schema)), + SimplifyContext::default().with_schema(Arc::clone(&schema)), ); // Note expr type is int32 (not int64) @@ -2229,9 +2388,8 @@ mod tests { #[test] fn simplify_and_constant_prop() { - let props = ExecutionProps::new(); let simplifier = - ExprSimplifier::new(SimplifyContext::new(&props).with_schema(test_schema())); + ExprSimplifier::new(SimplifyContext::default().with_schema(test_schema())); // should be able to simplify to false // (i * (1 - 2)) > 0 @@ -2242,9 +2400,8 @@ mod tests { #[test] fn simplify_and_constant_prop_with_case() { - let props = ExecutionProps::new(); let simplifier = - ExprSimplifier::new(SimplifyContext::new(&props).with_schema(test_schema())); + 
ExprSimplifier::new(SimplifyContext::default().with_schema(test_schema())); // CASE // WHEN i>5 AND false THEN i > 5 @@ -2412,6 +2569,27 @@ mod tests { assert_eq!(simplify(expr_b), expected); } + #[test] + fn test_simplify_eq_and_neq_with_different_literals() { + // A = 1 AND A != 0 --> A = 1 (when 1 != 0) + let expr = col("c2").eq(lit(1)).and(col("c2").not_eq(lit(0))); + let expected = col("c2").eq(lit(1)); + assert_eq!(simplify(expr), expected); + + // A != 0 AND A = 1 --> A = 1 (when 1 != 0) + let expr = col("c2").not_eq(lit(0)).and(col("c2").eq(lit(1))); + let expected = col("c2").eq(lit(1)); + assert_eq!(simplify(expr), expected); + + // Should NOT simplify when literals are the same (A = 1 AND A != 1) + // This is a contradiction but handled by other rules + let expr = col("c2").eq(lit(1)).and(col("c2").not_eq(lit(1))); + // Should not be simplified by this rule (left unchanged or handled elsewhere) + let result = simplify(expr.clone()); + // The expression should not have been simplified + assert_eq!(result, expr); + } + #[test] fn test_simplify_multiply_by_one() { let expr_a = col("c2") * lit(1); @@ -3358,18 +3536,15 @@ mod tests { fn try_simplify(expr: Expr) -> Result { let schema = expr_test_schema(); - let execution_props = ExecutionProps::new(); - let simplifier = ExprSimplifier::new( - SimplifyContext::new(&execution_props).with_schema(schema), - ); + let simplifier = + ExprSimplifier::new(SimplifyContext::default().with_schema(schema)); simplifier.simplify(expr) } fn coerce(expr: Expr) -> Expr { let schema = expr_test_schema(); - let execution_props = ExecutionProps::new(); let simplifier = ExprSimplifier::new( - SimplifyContext::new(&execution_props).with_schema(Arc::clone(&schema)), + SimplifyContext::default().with_schema(Arc::clone(&schema)), ); simplifier.coerce(expr, schema.as_ref()).unwrap() } @@ -3380,10 +3555,8 @@ mod tests { fn try_simplify_with_cycle_count(expr: Expr) -> Result<(Expr, u32)> { let schema = expr_test_schema(); - let 
execution_props = ExecutionProps::new(); - let simplifier = ExprSimplifier::new( - SimplifyContext::new(&execution_props).with_schema(schema), - ); + let simplifier = + ExprSimplifier::new(SimplifyContext::default().with_schema(schema)); let (expr, count) = simplifier.simplify_with_cycle_count_transformed(expr)?; Ok((expr.data, count)) } @@ -3397,11 +3570,9 @@ mod tests { guarantees: Vec<(Expr, NullableInterval)>, ) -> Expr { let schema = expr_test_schema(); - let execution_props = ExecutionProps::new(); - let simplifier = ExprSimplifier::new( - SimplifyContext::new(&execution_props).with_schema(schema), - ) - .with_guarantees(guarantees); + let simplifier = + ExprSimplifier::new(SimplifyContext::default().with_schema(schema)) + .with_guarantees(guarantees); simplifier.simplify(expr).unwrap() } @@ -4303,8 +4474,7 @@ mod tests { fn just_simplifier_simplify_null_in_empty_inlist() { let simplify = |expr: Expr| -> Expr { let schema = expr_test_schema(); - let execution_props = ExecutionProps::new(); - let info = SimplifyContext::new(&execution_props).with_schema(schema); + let info = SimplifyContext::default().with_schema(schema); let simplifier = &mut Simplifier::new(&info); expr.rewrite(simplifier) .expect("Failed to simplify expression") @@ -4670,10 +4840,9 @@ mod tests { #[test] fn simplify_common_factor_conjunction_in_disjunction() { - let props = ExecutionProps::new(); let schema = boolean_test_schema(); let simplifier = - ExprSimplifier::new(SimplifyContext::new(&props).with_schema(schema)); + ExprSimplifier::new(SimplifyContext::default().with_schema(schema)); let a = || col("A"); let b = || col("B"); @@ -5003,9 +5172,8 @@ mod tests { // The simplification should now fail with an error at plan time let schema = test_schema(); - let props = ExecutionProps::new(); let simplifier = - ExprSimplifier::new(SimplifyContext::new(&props).with_schema(schema)); + ExprSimplifier::new(SimplifyContext::default().with_schema(schema)); let result = simplifier.simplify(expr); 
assert!(result.is_err(), "Expected error for invalid cast"); let err_msg = result.unwrap_err().to_string(); @@ -5019,4 +5187,156 @@ mod tests { else_expr: None, }) } + + // -------------------------------- + // --- Struct Cast Tests ----- + // -------------------------------- + + /// Helper to create a `Struct` literal cast expression from `source_fields` and `target_fields`. + fn make_struct_cast_expr(source_fields: Fields, target_fields: Fields) -> Expr { + // Create 1-row struct array (not 0-row) so it can be evaluated by simplifier + let arrays: Vec> = vec![ + Arc::new(Int32Array::from(vec![Some(1)])), + Arc::new(Int32Array::from(vec![Some(2)])), + ]; + let struct_array = StructArray::try_new(source_fields, arrays, None).unwrap(); + + Expr::Cast(Cast::new( + Box::new(Expr::Literal( + ScalarValue::Struct(Arc::new(struct_array)), + None, + )), + DataType::Struct(target_fields), + )) + } + + #[test] + fn test_struct_cast_different_field_counts_not_foldable() { + // Test that struct casts with different field counts are NOT marked as foldable + // When field counts differ, const-folding should not be attempted + + let source_fields = Fields::from(vec![ + Arc::new(Field::new("a", DataType::Int32, true)), + Arc::new(Field::new("b", DataType::Int32, true)), + ]); + + let target_fields = Fields::from(vec![ + Arc::new(Field::new("x", DataType::Int32, true)), + Arc::new(Field::new("y", DataType::Int32, true)), + Arc::new(Field::new("z", DataType::Int32, true)), + ]); + + let expr = make_struct_cast_expr(source_fields, target_fields); + + let simplifier = + ExprSimplifier::new(SimplifyContext::default().with_schema(test_schema())); + + // The cast should remain unchanged since field counts differ + let result = simplifier.simplify(expr.clone()).unwrap(); + // Ensure const-folding was not attempted (the expression remains exactly the same) + assert_eq!( + result, expr, + "Struct cast with different field counts should remain unchanged (no const-folding)" + ); + } + + 
#[test] + fn test_struct_cast_same_field_count_foldable() { + // Test that struct casts with same field counts can be considered for const-folding + + let source_fields = Fields::from(vec![ + Arc::new(Field::new("a", DataType::Int32, true)), + Arc::new(Field::new("b", DataType::Int32, true)), + ]); + + let target_fields = Fields::from(vec![ + Arc::new(Field::new("a", DataType::Int32, true)), + Arc::new(Field::new("b", DataType::Int32, true)), + ]); + + let expr = make_struct_cast_expr(source_fields, target_fields); + + let simplifier = + ExprSimplifier::new(SimplifyContext::default().with_schema(test_schema())); + + // The cast should be simplified + let result = simplifier.simplify(expr.clone()).unwrap(); + // Struct casts with same field count should be const-folded to a literal + assert!(matches!(result, Expr::Literal(_, _))); + // Ensure the simplifier made a change (not identical to original) + assert_ne!( + result, expr, + "Struct cast with same field count should be simplified (not identical to input)" + ); + } + + #[test] + fn test_struct_cast_different_names_same_count() { + // Test struct cast with same field count but different names + // Field count matches; simplification should be skipped because names do not overlap + + let source_fields = Fields::from(vec![ + Arc::new(Field::new("a", DataType::Int32, true)), + Arc::new(Field::new("b", DataType::Int32, true)), + ]); + + let target_fields = Fields::from(vec![ + Arc::new(Field::new("x", DataType::Int32, true)), + Arc::new(Field::new("y", DataType::Int32, true)), + ]); + + let expr = make_struct_cast_expr(source_fields, target_fields); + + let simplifier = + ExprSimplifier::new(SimplifyContext::default().with_schema(test_schema())); + + // The cast should remain unchanged because there is no name overlap + let result = simplifier.simplify(expr.clone()).unwrap(); + assert_eq!( + result, expr, + "Struct cast with different names but same field count should not be simplified" + ); + } + + #[test] + fn 
test_struct_cast_empty_array_not_foldable() { + // Test that struct casts with 0-row (empty) struct arrays are NOT const-folded + // The simplifier uses a 1-row input batch, which causes dimension mismatches + // when evaluating 0-row struct literals + + let source_fields = Fields::from(vec![ + Arc::new(Field::new("a", DataType::Int32, true)), + Arc::new(Field::new("b", DataType::Int32, true)), + ]); + + let target_fields = Fields::from(vec![ + Arc::new(Field::new("a", DataType::Int32, true)), + Arc::new(Field::new("b", DataType::Int32, true)), + ]); + + // Create a 0-row (empty) struct array + let arrays: Vec> = vec![ + Arc::new(Int32Array::new(vec![].into(), None)), + Arc::new(Int32Array::new(vec![].into(), None)), + ]; + let struct_array = StructArray::try_new(source_fields, arrays, None).unwrap(); + + let expr = Expr::Cast(Cast::new( + Box::new(Expr::Literal( + ScalarValue::Struct(Arc::new(struct_array)), + None, + )), + DataType::Struct(target_fields), + )); + + let simplifier = + ExprSimplifier::new(SimplifyContext::default().with_schema(test_schema())); + + // The cast should remain unchanged since the struct array is empty (0-row) + let result = simplifier.simplify(expr.clone()).unwrap(); + assert_eq!( + result, expr, + "Struct cast with empty (0-row) array should remain unchanged" + ); + } } diff --git a/datafusion/optimizer/src/simplify_expressions/mod.rs b/datafusion/optimizer/src/simplify_expressions/mod.rs index 58a4eadb5c07..b85b000821ad 100644 --- a/datafusion/optimizer/src/simplify_expressions/mod.rs +++ b/datafusion/optimizer/src/simplify_expressions/mod.rs @@ -24,11 +24,12 @@ mod regex; pub mod simplify_exprs; pub mod simplify_literal; mod simplify_predicates; +mod udf_preimage; mod unwrap_cast; mod utils; // backwards compatibility -pub use datafusion_expr::simplify::{SimplifyContext, SimplifyInfo}; +pub use datafusion_expr::simplify::SimplifyContext; pub use expr_simplifier::*; pub use simplify_exprs::*; diff --git 
a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs index 1b25c5ce8a63..f7f100015004 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs @@ -22,7 +22,6 @@ use std::sync::Arc; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result}; use datafusion_expr::Expr; -use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::logical_plan::LogicalPlan; use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::utils::merge_schema; @@ -67,17 +66,14 @@ impl OptimizerRule for SimplifyExpressions { plan: LogicalPlan, config: &dyn OptimizerConfig, ) -> Result, DataFusionError> { - let mut execution_props = ExecutionProps::new(); - execution_props.query_execution_start_time = config.query_execution_start_time(); - execution_props.config_options = Some(config.options()); - Self::optimize_internal(plan, &execution_props) + Self::optimize_internal(plan, config) } } impl SimplifyExpressions { fn optimize_internal( plan: LogicalPlan, - execution_props: &ExecutionProps, + config: &dyn OptimizerConfig, ) -> Result> { let schema = if !plan.inputs().is_empty() { DFSchemaRef::new(merge_schema(&plan.inputs())) @@ -100,7 +96,10 @@ impl SimplifyExpressions { Arc::new(DFSchema::empty()) }; - let info = SimplifyContext::new(execution_props).with_schema(schema); + let info = SimplifyContext::default() + .with_schema(schema) + .with_config_options(config.options()) + .with_query_execution_start_time(config.query_execution_start_time()); // Inputs have already been rewritten (due to bottom-up traversal handled by Optimizer) // Just need to rewrite our own expressions diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_literal.rs b/datafusion/optimizer/src/simplify_expressions/simplify_literal.rs index 
168a6ebb461f..b77240fc5343 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_literal.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_literal.rs @@ -28,7 +28,6 @@ use datafusion_common::{ plan_err, }; use datafusion_expr::Expr; -use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::simplify::SimplifyContext; use std::sync::Arc; @@ -52,10 +51,8 @@ where log::debug!("Parsing expr {:?} to type {}", expr, T::DATA_TYPE); - let execution_props = ExecutionProps::new(); - let simplifier = ExprSimplifier::new( - SimplifyContext::new(&execution_props).with_schema(Arc::clone(&schema)), - ); + let simplifier = + ExprSimplifier::new(SimplifyContext::default().with_schema(Arc::clone(&schema))); // Simplify and coerce expression in case of constant arithmetic operations (e.g., 10 + 5) let simplified_expr: Expr = simplifier diff --git a/datafusion/optimizer/src/simplify_expressions/udf_preimage.rs b/datafusion/optimizer/src/simplify_expressions/udf_preimage.rs new file mode 100644 index 000000000000..da2716d13cb4 --- /dev/null +++ b/datafusion/optimizer/src/simplify_expressions/udf_preimage.rs @@ -0,0 +1,404 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use datafusion_common::{Result, internal_err, tree_node::Transformed}; +use datafusion_expr::{Expr, Operator, and, lit, or}; +use datafusion_expr_common::interval_arithmetic::Interval; + +/// Rewrites a binary expression using its "preimage" +/// +/// Specifically it rewrites expressions of the form ` OP x` (e.g. ` = +/// x`) where `` is known to have a pre-image (aka the entire single +/// range for which it is valid) and `x` is not `NULL` +/// +/// For details see [`datafusion_expr::ScalarUDFImpl::preimage`] +pub(super) fn rewrite_with_preimage( + preimage_interval: Interval, + op: Operator, + expr: Expr, +) -> Result> { + let (lower, upper) = preimage_interval.into_bounds(); + let (lower, upper) = (lit(lower), lit(upper)); + + let rewritten_expr = match op { + // < x ==> < lower + Operator::Lt => expr.lt(lower), + // >= x ==> >= lower + Operator::GtEq => expr.gt_eq(lower), + // > x ==> >= upper + Operator::Gt => expr.gt_eq(upper), + // <= x ==> < upper + Operator::LtEq => expr.lt(upper), + // = x ==> ( >= lower) and ( < upper) + Operator::Eq => and(expr.clone().gt_eq(lower), expr.lt(upper)), + // != x ==> ( < lower) or ( >= upper) + Operator::NotEq => or(expr.clone().lt(lower), expr.gt_eq(upper)), + // is not distinct from x ==> ( is NULL and x is NULL) or (( >= lower) and ( < upper)) + // but since x is always not NULL => ( is not NULL) and ( >= lower) and ( < upper) + Operator::IsNotDistinctFrom => expr + .clone() + .is_not_null() + .and(expr.clone().gt_eq(lower)) + .and(expr.lt(upper)), + // is distinct from x ==> ( < lower) or ( >= upper) or ( is NULL and x is not NULL) or ( is not NULL and x is NULL) + // but given that x is always not NULL => ( < lower) or ( >= upper) or ( is NULL) + Operator::IsDistinctFrom => expr + .clone() + .lt(lower) + .or(expr.clone().gt_eq(upper)) + .or(expr.is_null()), + _ => return internal_err!("Expect comparison operators"), + }; + Ok(Transformed::yes(rewritten_expr)) +} + +#[cfg(test)] +mod test { + use std::any::Any; + use 
std::sync::Arc; + + use arrow::datatypes::{DataType, Field}; + use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue}; + use datafusion_expr::{ + ColumnarValue, Expr, Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, + Signature, Volatility, and, binary_expr, col, lit, or, preimage::PreimageResult, + simplify::SimplifyContext, + }; + + use super::Interval; + use crate::simplify_expressions::ExprSimplifier; + + fn is_distinct_from(left: Expr, right: Expr) -> Expr { + binary_expr(left, Operator::IsDistinctFrom, right) + } + + fn is_not_distinct_from(left: Expr, right: Expr) -> Expr { + binary_expr(left, Operator::IsNotDistinctFrom, right) + } + + #[derive(Debug, PartialEq, Eq, Hash)] + struct PreimageUdf { + /// Defaults to an exact signature with one Int32 argument and Immutable volatility + signature: Signature, + /// If true, returns a preimage; otherwise, returns None + enabled: bool, + } + + impl PreimageUdf { + fn new() -> Self { + Self { + signature: Signature::exact(vec![DataType::Int32], Volatility::Immutable), + enabled: true, + } + } + + /// Set the enabled flag + fn with_enabled(mut self, enabled: bool) -> Self { + self.enabled = enabled; + self + } + + /// Set the volatility + fn with_volatility(mut self, volatility: Volatility) -> Self { + self.signature.volatility = volatility; + self + } + } + + impl ScalarUDFImpl for PreimageUdf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "preimage_func" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int32) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(500)))) + } + + fn preimage( + &self, + args: &[Expr], + lit_expr: &Expr, + _info: &SimplifyContext, + ) -> Result { + if !self.enabled { + return Ok(PreimageResult::None); + } + if args.len() != 1 { + return Ok(PreimageResult::None); + } + + 
let expr = args.first().cloned().expect("Should be column expression"); + match lit_expr { + Expr::Literal(ScalarValue::Int32(Some(500)), _) => { + Ok(PreimageResult::Range { + expr, + interval: Box::new(Interval::try_new( + ScalarValue::Int32(Some(100)), + ScalarValue::Int32(Some(200)), + )?), + }) + } + Expr::Literal(ScalarValue::Int32(Some(600)), _) => { + Ok(PreimageResult::Range { + expr, + interval: Box::new(Interval::try_new( + ScalarValue::Int32(Some(300)), + ScalarValue::Int32(Some(400)), + )?), + }) + } + _ => Ok(PreimageResult::None), + } + } + } + + fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr { + let simplify_context = SimplifyContext::default().with_schema(Arc::clone(schema)); + ExprSimplifier::new(simplify_context) + .simplify(expr) + .unwrap() + } + + fn preimage_udf_expr() -> Expr { + ScalarUDF::new_from_impl(PreimageUdf::new()).call(vec![col("x")]) + } + + fn non_immutable_udf_expr() -> Expr { + ScalarUDF::new_from_impl(PreimageUdf::new().with_volatility(Volatility::Volatile)) + .call(vec![col("x")]) + } + + fn no_preimage_udf_expr() -> Expr { + ScalarUDF::new_from_impl(PreimageUdf::new().with_enabled(false)) + .call(vec![col("x")]) + } + + fn test_schema() -> DFSchemaRef { + Arc::new( + DFSchema::from_unqualified_fields( + vec![Field::new("x", DataType::Int32, true)].into(), + Default::default(), + ) + .unwrap(), + ) + } + + fn test_schema_xy() -> DFSchemaRef { + Arc::new( + DFSchema::from_unqualified_fields( + vec![ + Field::new("x", DataType::Int32, false), + Field::new("y", DataType::Int32, false), + ] + .into(), + Default::default(), + ) + .unwrap(), + ) + } + + #[test] + fn test_preimage_eq_rewrite() { + // Equality rewrite when preimage and column expression are available. 
+ let schema = test_schema(); + let expr = preimage_udf_expr().eq(lit(500)); + let expected = and(col("x").gt_eq(lit(100)), col("x").lt(lit(200))); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_noteq_rewrite() { + // Inequality rewrite expands to disjoint ranges. + let schema = test_schema(); + let expr = preimage_udf_expr().not_eq(lit(500)); + let expected = col("x").lt(lit(100)).or(col("x").gt_eq(lit(200))); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_eq_rewrite_swapped() { + // Equality rewrite works when the literal appears on the left. + let schema = test_schema(); + let expr = lit(500).eq(preimage_udf_expr()); + let expected = and(col("x").gt_eq(lit(100)), col("x").lt(lit(200))); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_lt_rewrite() { + // Less-than comparison rewrites to the lower bound. + let schema = test_schema(); + let expr = preimage_udf_expr().lt(lit(500)); + let expected = col("x").lt(lit(100)); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_lteq_rewrite() { + // Less-than-or-equal comparison rewrites to the upper bound. + let schema = test_schema(); + let expr = preimage_udf_expr().lt_eq(lit(500)); + let expected = col("x").lt(lit(200)); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_gt_rewrite() { + // Greater-than comparison rewrites to the upper bound (inclusive). + let schema = test_schema(); + let expr = preimage_udf_expr().gt(lit(500)); + let expected = col("x").gt_eq(lit(200)); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_gteq_rewrite() { + // Greater-than-or-equal comparison rewrites to the lower bound. 
+ let schema = test_schema(); + let expr = preimage_udf_expr().gt_eq(lit(500)); + let expected = col("x").gt_eq(lit(100)); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_is_not_distinct_from_rewrite() { + // IS NOT DISTINCT FROM rewrites to equality plus expression not-null check + // for non-null literal RHS. + let schema = test_schema(); + let expr = is_not_distinct_from(preimage_udf_expr(), lit(500)); + let expected = col("x") + .is_not_null() + .and(col("x").gt_eq(lit(100))) + .and(col("x").lt(lit(200))); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_is_distinct_from_rewrite() { + // IS DISTINCT FROM adds an explicit NULL branch for the column. + let schema = test_schema(); + let expr = is_distinct_from(preimage_udf_expr(), lit(500)); + let expected = col("x") + .lt(lit(100)) + .or(col("x").gt_eq(lit(200))) + .or(col("x").is_null()); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_in_list_rewrite() { + let schema = test_schema(); + let expr = preimage_udf_expr().in_list(vec![lit(500), lit(600)], false); + let expected = or( + and(col("x").gt_eq(lit(100)), col("x").lt(lit(200))), + and(col("x").gt_eq(lit(300)), col("x").lt(lit(400))), + ); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_not_in_list_rewrite() { + let schema = test_schema(); + let expr = preimage_udf_expr().in_list(vec![lit(500), lit(600)], true); + let expected = and( + or(col("x").lt(lit(100)), col("x").gt_eq(lit(200))), + or(col("x").lt(lit(300)), col("x").gt_eq(lit(400))), + ); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_in_list_long_list_no_rewrite() { + let schema = test_schema(); + let expr = preimage_udf_expr().in_list((1..100).map(lit).collect(), false); + + assert_eq!(optimize_test(expr.clone(), &schema), expr); + } + + #[test] + fn 
test_preimage_non_literal_rhs_no_rewrite() { + // Non-literal RHS should not be rewritten. + let schema = test_schema_xy(); + let expr = preimage_udf_expr().eq(col("y")); + let expected = expr.clone(); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_null_literal_no_rewrite_distinct_ops() { + // NULL literal RHS should not be rewritten for DISTINCTness operators: + // - `expr IS DISTINCT FROM NULL` <=> `NOT (expr IS NULL)` + // - `expr IS NOT DISTINCT FROM NULL` <=> `expr IS NULL` + // + // For normal comparisons (=, !=, <, <=, >, >=), `expr OP NULL` evaluates to NULL + // under SQL tri-state logic, and DataFusion's simplifier constant-folds it. + // https://docs.rs/datafusion/latest/datafusion/physical_optimizer/pruning/struct.PruningPredicate.html#boolean-tri-state-logic + + let schema = test_schema(); + + let expr = is_distinct_from(preimage_udf_expr(), lit(ScalarValue::Int32(None))); + assert_eq!(optimize_test(expr.clone(), &schema), expr); + + let expr = + is_not_distinct_from(preimage_udf_expr(), lit(ScalarValue::Int32(None))); + assert_eq!(optimize_test(expr.clone(), &schema), expr); + } + + #[test] + fn test_preimage_non_immutable_no_rewrite() { + // Non-immutable UDFs should not participate in preimage rewrites. + let schema = test_schema(); + let expr = non_immutable_udf_expr().eq(lit(500)); + let expected = expr.clone(); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_no_preimage_no_rewrite() { + // If the UDF provides no preimage, the expression should remain unchanged. 
+ let schema = test_schema(); + let expr = no_preimage_udf_expr().eq(lit(500)); + let expected = expr.clone(); + + assert_eq!(optimize_test(expr, &schema), expected); + } +} diff --git a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs index b2349db8c460..acf0f32ab223 100644 --- a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs +++ b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs @@ -58,11 +58,11 @@ use arrow::datatypes::DataType; use datafusion_common::{Result, ScalarValue}; use datafusion_common::{internal_err, tree_node::Transformed}; use datafusion_expr::{BinaryExpr, lit}; -use datafusion_expr::{Cast, Expr, Operator, TryCast, simplify::SimplifyInfo}; +use datafusion_expr::{Cast, Expr, Operator, TryCast, simplify::SimplifyContext}; use datafusion_expr_common::casts::{is_supported_type, try_cast_literal_to_type}; -pub(super) fn unwrap_cast_in_comparison_for_binary( - info: &S, +pub(super) fn unwrap_cast_in_comparison_for_binary( + info: &SimplifyContext, cast_expr: Expr, literal: Expr, op: Operator, @@ -104,10 +104,8 @@ pub(super) fn unwrap_cast_in_comparison_for_binary( } } -pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary< - S: SimplifyInfo, ->( - info: &S, +pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary( + info: &SimplifyContext, expr: &Expr, op: Operator, literal: &Expr, @@ -142,10 +140,8 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary< } } -pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist< - S: SimplifyInfo, ->( - info: &S, +pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist( + info: &SimplifyContext, expr: &Expr, list: &[Expr], ) -> bool { @@ -241,7 +237,6 @@ mod tests { use crate::simplify_expressions::ExprSimplifier; use arrow::datatypes::{Field, TimeUnit}; use datafusion_common::{DFSchema, DFSchemaRef}; - use 
datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::{cast, col, in_list, try_cast}; @@ -592,9 +587,8 @@ mod tests { } fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr { - let props = ExecutionProps::new(); let simplifier = ExprSimplifier::new( - SimplifyContext::new(&props).with_schema(Arc::clone(schema)), + SimplifyContext::default().with_schema(Arc::clone(schema)), ); simplifier.simplify(expr).unwrap() diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index 1f214e3d365c..b0908b47602f 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -290,6 +290,54 @@ pub fn is_lit(expr: &Expr) -> bool { matches!(expr, Expr::Literal(_, _)) } +/// Checks if `eq_expr` is `A = L1` and `ne_expr` is `A != L2` where L1 != L2. +/// This pattern can be simplified to just `A = L1` since if A equals L1 +/// and L1 is different from L2, then A is automatically not equal to L2. +pub fn is_eq_and_ne_with_different_literal(eq_expr: &Expr, ne_expr: &Expr) -> bool { + fn extract_var_and_literal(expr: &Expr) -> Option<(&Expr, &Expr)> { + match expr { + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Eq, + right, + }) + | Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::NotEq, + right, + }) => match (left.as_ref(), right.as_ref()) { + (Expr::Literal(_, _), var) => Some((var, left)), + (var, Expr::Literal(_, _)) => Some((var, right)), + _ => None, + }, + _ => None, + } + } + match (eq_expr, ne_expr) { + ( + Expr::BinaryExpr(BinaryExpr { + op: Operator::Eq, .. + }), + Expr::BinaryExpr(BinaryExpr { + op: Operator::NotEq, + .. 
+ }), + ) => { + // Check if both compare the same expression against different literals + if let (Some((var1, lit1)), Some((var2, lit2))) = ( + extract_var_and_literal(eq_expr), + extract_var_and_literal(ne_expr), + ) && var1 == var2 + && lit1 != lit2 + { + return true; + } + false + } + _ => false, + } +} + /// negate a Not clause /// input is the clause to be negated.(args of Not clause) /// For BinaryExpr, use the negation of op instead. diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 05edd230dacc..00c8fab22811 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -184,7 +184,11 @@ impl OptimizerRule for SingleDistinctToGroupBy { func, params: AggregateFunctionParams { - mut args, distinct, .. + mut args, + distinct, + filter, + order_by, + null_treatment, }, }) => { if distinct { @@ -204,9 +208,9 @@ impl OptimizerRule for SingleDistinctToGroupBy { func, vec![col(SINGLE_DISTINCT_ALIAS)], false, // intentional to remove distinct here - None, - vec![], - None, + filter, + order_by, + null_treatment, ))) // if the aggregate function is not distinct, we need to rewrite it like two phase aggregation } else { @@ -217,9 +221,9 @@ impl OptimizerRule for SingleDistinctToGroupBy { Arc::clone(&func), args, false, - None, - vec![], - None, + filter, + order_by, + null_treatment, )) .alias(&alias_str), ); diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index a45983950496..2915e77be2e1 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -24,6 +24,7 @@ use datafusion_common::{Result, assert_contains}; use datafusion_expr::{LogicalPlan, LogicalPlanBuilder, logical_plan::table_scan}; use std::sync::Arc; +pub mod udfs; pub mod user_defined; pub fn test_table_scan_fields() -> Vec { @@ -34,6 +35,28 @@ pub fn test_table_scan_fields() -> Vec 
{ ] } +pub fn test_table_scan_with_struct_fields() -> Vec { + vec![ + Field::new("id", DataType::UInt32, false), + Field::new( + "user", + DataType::Struct( + vec![ + Field::new("name", DataType::Utf8, true), + Field::new("status", DataType::Utf8, true), + ] + .into(), + ), + true, + ), + ] +} + +pub fn test_table_scan_with_struct() -> Result { + let schema = Schema::new(test_table_scan_with_struct_fields()); + table_scan(Some("test"), &schema, None)?.build() +} + /// some tests share a common table with different names pub fn test_table_scan_with_name(name: &str) -> Result { let schema = Schema::new(test_table_scan_fields()); diff --git a/datafusion/optimizer/src/test/udfs.rs b/datafusion/optimizer/src/test/udfs.rs new file mode 100644 index 000000000000..9164603dba3d --- /dev/null +++ b/datafusion/optimizer/src/test/udfs.rs @@ -0,0 +1,103 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; + +use arrow::datatypes::DataType; +use datafusion_common::Result; +use datafusion_expr::{ + ColumnarValue, Expr, ExpressionPlacement, ScalarFunctionArgs, ScalarUDF, + ScalarUDFImpl, Signature, TypeSignature, +}; + +/// A configurable test UDF for optimizer tests. 
+/// Defaults to `MoveTowardsLeafNodes` placement. Use `with_placement()` to override. +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct PlacementTestUDF { + signature: Signature, + placement: ExpressionPlacement, + id: usize, +} + +impl Default for PlacementTestUDF { + fn default() -> Self { + Self::new() + } +} + +impl PlacementTestUDF { + pub fn new() -> Self { + Self { + // Accept any one or two arguments and return UInt32 for testing purposes. + // The actual types don't matter since this UDF is not intended for execution. + signature: Signature::new( + TypeSignature::OneOf(vec![TypeSignature::Any(1), TypeSignature::Any(2)]), + datafusion_expr::Volatility::Immutable, + ), + placement: ExpressionPlacement::MoveTowardsLeafNodes, + id: 0, + } + } + + /// Set the expression placement for this UDF, which is used by optimizer rules to determine where in the plan the expression should be placed. + /// This also resets the name of the UDF to a default based on the placement. + pub fn with_placement(mut self, placement: ExpressionPlacement) -> Self { + self.placement = placement; + self + } + + /// Set the id of the UDF. + /// This is an arbitrary made up field to allow creating multiple distinct UDFs with the same placement. 
+ pub fn with_id(mut self, id: usize) -> Self { + self.id = id; + self + } +} + +impl ScalarUDFImpl for PlacementTestUDF { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + match self.placement { + ExpressionPlacement::MoveTowardsLeafNodes => "leaf_udf", + ExpressionPlacement::KeepInPlace => "keep_in_place_udf", + ExpressionPlacement::Column => "column_udf", + ExpressionPlacement::Literal => "literal_udf", + } + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::UInt32) + } + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + panic!("PlacementTestUDF: not intended for execution") + } + fn placement(&self, _args: &[ExpressionPlacement]) -> ExpressionPlacement { + self.placement + } +} + +/// Create a `leaf_udf(arg)` expression with `MoveTowardsLeafNodes` placement. +pub fn leaf_udf_expr(arg: Expr) -> Expr { + let udf = ScalarUDF::new_from_impl( + PlacementTestUDF::new().with_placement(ExpressionPlacement::MoveTowardsLeafNodes), + ); + udf.call(vec![arg]) +} diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index 36a6df54ddaf..fd4991c24413 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -543,7 +543,7 @@ fn recursive_cte_projection_pushdown() -> Result<()> { RecursiveQuery: is_distinct=false Projection: test.col_int32 AS id TableScan: test projection=[col_int32] - Projection: CAST(CAST(nodes.id AS Int64) + Int64(1) AS Int32) AS id + Projection: CAST(CAST(nodes.id AS Int64) + Int64(1) AS Int32) Filter: nodes.id < Int32(3) TableScan: nodes projection=[id] " @@ -567,7 +567,7 @@ fn recursive_cte_with_aliased_self_reference() -> Result<()> { RecursiveQuery: is_distinct=false Projection: test.col_int32 AS id TableScan: test projection=[col_int32] - Projection: CAST(CAST(child.id AS Int64) + Int64(1) AS 
Int32) AS id + Projection: CAST(CAST(child.id AS Int64) + Int64(1) AS Int32) SubqueryAlias: child Filter: nodes.id < Int32(3) TableScan: nodes projection=[id] @@ -630,7 +630,7 @@ fn recursive_cte_projection_pushdown_baseline() -> Result<()> { Projection: test.col_int32 AS n Filter: test.col_int32 = Int32(5) TableScan: test projection=[col_int32] - Projection: CAST(CAST(countdown.n AS Int64) - Int64(1) AS Int32) AS n + Projection: CAST(CAST(countdown.n AS Int64) - Int64(1) AS Int32) Filter: countdown.n > Int32(1) TableScan: countdown projection=[n] " diff --git a/datafusion/physical-expr-adapter/Cargo.toml b/datafusion/physical-expr-adapter/Cargo.toml index 03e1b1f06578..453c8bdaacb4 100644 --- a/datafusion/physical-expr-adapter/Cargo.toml +++ b/datafusion/physical-expr-adapter/Cargo.toml @@ -24,4 +24,7 @@ datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } itertools = { workspace = true } +[lints] +workspace = true + [dev-dependencies] diff --git a/datafusion/physical-expr-adapter/LICENSE.txt b/datafusion/physical-expr-adapter/LICENSE.txt new file mode 120000 index 000000000000..1ef648f64b34 --- /dev/null +++ b/datafusion/physical-expr-adapter/LICENSE.txt @@ -0,0 +1 @@ +../../LICENSE.txt \ No newline at end of file diff --git a/datafusion/physical-expr-adapter/NOTICE.txt b/datafusion/physical-expr-adapter/NOTICE.txt new file mode 120000 index 000000000000..fb051c92b10b --- /dev/null +++ b/datafusion/physical-expr-adapter/NOTICE.txt @@ -0,0 +1 @@ +../../NOTICE.txt \ No newline at end of file diff --git a/datafusion/physical-expr-adapter/src/lib.rs b/datafusion/physical-expr-adapter/src/lib.rs index d7c750e4a1a1..ea4db19ee110 100644 --- a/datafusion/physical-expr-adapter/src/lib.rs +++ b/datafusion/physical-expr-adapter/src/lib.rs @@ -21,14 +21,13 @@ html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] 
#![cfg_attr(docsrs, feature(doc_cfg))] -// https://github.com/apache/datafusion/issues/18881 -#![deny(clippy::allow_attributes)] //! Physical expression schema adaptation utilities for DataFusion pub mod schema_rewriter; pub use schema_rewriter::{ - DefaultPhysicalExprAdapter, DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, - PhysicalExprAdapterFactory, replace_columns_with_literals, + BatchAdapter, BatchAdapterFactory, DefaultPhysicalExprAdapter, + DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory, + replace_columns_with_literals, }; diff --git a/datafusion/physical-expr-adapter/src/schema_rewriter.rs b/datafusion/physical-expr-adapter/src/schema_rewriter.rs index 83727ac09204..ec5f9139ed22 100644 --- a/datafusion/physical-expr-adapter/src/schema_rewriter.rs +++ b/datafusion/physical-expr-adapter/src/schema_rewriter.rs @@ -24,20 +24,24 @@ use std::collections::HashMap; use std::hash::Hash; use std::sync::Arc; +use arrow::array::RecordBatch; use arrow::compute::can_cast_types; -use arrow::datatypes::{DataType, Schema, SchemaRef}; +use arrow::datatypes::{DataType, Field, SchemaRef}; use datafusion_common::{ Result, ScalarValue, exec_err, nested_struct::validate_struct_compatibility, tree_node::{Transformed, TransformedResult, TreeNode}, }; use datafusion_functions::core::getfield::GetFieldFunc; +use datafusion_physical_expr::PhysicalExprSimplifier; use datafusion_physical_expr::expressions::CastColumnExpr; +use datafusion_physical_expr::projection::{ProjectionExprs, Projector}; use datafusion_physical_expr::{ ScalarFunctionExpr, expressions::{self, Column}, }; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use itertools::Itertools; /// Replace column references in the given physical expression with literal values. 
/// @@ -137,11 +141,11 @@ where /// &self, /// logical_file_schema: SchemaRef, /// physical_file_schema: SchemaRef, -/// ) -> Arc { -/// Arc::new(CustomPhysicalExprAdapter { +/// ) -> Result> { +/// Ok(Arc::new(CustomPhysicalExprAdapter { /// logical_file_schema, /// physical_file_schema, -/// }) +/// })) /// } /// } /// ``` @@ -174,7 +178,7 @@ pub trait PhysicalExprAdapterFactory: Send + Sync + std::fmt::Debug { &self, logical_file_schema: SchemaRef, physical_file_schema: SchemaRef, - ) -> Arc; + ) -> Result>; } #[derive(Debug, Clone)] @@ -185,11 +189,11 @@ impl PhysicalExprAdapterFactory for DefaultPhysicalExprAdapterFactory { &self, logical_file_schema: SchemaRef, physical_file_schema: SchemaRef, - ) -> Arc { - Arc::new(DefaultPhysicalExprAdapter { + ) -> Result> { + Ok(Arc::new(DefaultPhysicalExprAdapter { logical_file_schema, physical_file_schema, - }) + })) } } @@ -228,7 +232,8 @@ impl PhysicalExprAdapterFactory for DefaultPhysicalExprAdapterFactory { /// # logical_file_schema: &Schema, /// # ) -> datafusion_common::Result<()> { /// let factory = DefaultPhysicalExprAdapterFactory; -/// let adapter = factory.create(Arc::new(logical_file_schema.clone()), Arc::new(physical_file_schema.clone())); +/// let adapter = +/// factory.create(Arc::new(logical_file_schema.clone()), Arc::new(physical_file_schema.clone()))?; /// let adapted_predicate = adapter.rewrite(predicate)?; /// # Ok(()) /// # } @@ -255,20 +260,20 @@ impl DefaultPhysicalExprAdapter { impl PhysicalExprAdapter for DefaultPhysicalExprAdapter { fn rewrite(&self, expr: Arc) -> Result> { let rewriter = DefaultPhysicalExprAdapterRewriter { - logical_file_schema: &self.logical_file_schema, - physical_file_schema: &self.physical_file_schema, + logical_file_schema: Arc::clone(&self.logical_file_schema), + physical_file_schema: Arc::clone(&self.physical_file_schema), }; expr.transform(|expr| rewriter.rewrite_expr(Arc::clone(&expr))) .data() } } -struct DefaultPhysicalExprAdapterRewriter<'a> { - 
logical_file_schema: &'a Schema, - physical_file_schema: &'a Schema, +struct DefaultPhysicalExprAdapterRewriter { + logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, } -impl<'a> DefaultPhysicalExprAdapterRewriter<'a> { +impl DefaultPhysicalExprAdapterRewriter { fn rewrite_expr( &self, expr: Arc, @@ -416,18 +421,13 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> { }; let physical_field = self.physical_file_schema.field(physical_column_index); - let column = match ( - column.index() == physical_column_index, - logical_field.data_type() == physical_field.data_type(), - ) { - // If the column index matches and the data types match, we can use the column as is - (true, true) => return Ok(Transformed::no(expr)), - // If the indexes or data types do not match, we need to create a new column expression - (true, _) => column.clone(), - (false, _) => { - Column::new_with_schema(logical_field.name(), self.physical_file_schema)? - } - }; + if column.index() == physical_column_index + && logical_field.data_type() == physical_field.data_type() + { + return Ok(Transformed::no(expr)); + } + + let column = self.resolve_column(column, physical_column_index)?; if logical_field.data_type() == physical_field.data_type() { // If the data types match, we can use the column as is @@ -438,24 +438,63 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> { // TODO: add optimization to move the cast from the column to literal expressions in the case of `col = 123` // since that's much cheaper to evalaute. // See https://github.com/apache/datafusion/issues/15780#issuecomment-2824716928 - // + self.create_cast_column_expr(column, logical_field) + } + + /// Resolves a column expression, handling index and type mismatches. + /// + /// Returns the appropriate Column expression when the column's index or data type + /// don't match the physical schema. Assumes that the early-exit case (both index + /// and type match) has already been checked by the caller. 
+ fn resolve_column( + &self, + column: &Column, + physical_column_index: usize, + ) -> Result { + if column.index() == physical_column_index { + Ok(column.clone()) + } else { + Column::new_with_schema(column.name(), self.physical_file_schema.as_ref()) + } + } + + /// Validates type compatibility and creates a CastColumnExpr if needed. + /// + /// Checks whether the physical field can be cast to the logical field type, + /// handling both struct and scalar types. Returns a CastColumnExpr with the + /// appropriate configuration. + fn create_cast_column_expr( + &self, + column: Column, + logical_field: &Field, + ) -> Result>> { + // Look up the column index in the physical schema by name to ensure correctness. + let physical_column_index = self.physical_file_schema.index_of(column.name())?; + let actual_physical_field = + self.physical_file_schema.field(physical_column_index); + // For struct types, use validate_struct_compatibility which handles: // - Missing fields in source (filled with nulls) // - Extra fields in source (ignored) // - Recursive validation of nested structs // For non-struct types, use Arrow's can_cast_types - match (physical_field.data_type(), logical_field.data_type()) { + match (actual_physical_field.data_type(), logical_field.data_type()) { (DataType::Struct(physical_fields), DataType::Struct(logical_fields)) => { - validate_struct_compatibility(physical_fields, logical_fields)?; + validate_struct_compatibility( + physical_fields.as_ref(), + logical_fields.as_ref(), + )?; } _ => { - let is_compatible = - can_cast_types(physical_field.data_type(), logical_field.data_type()); + let is_compatible = can_cast_types( + actual_physical_field.data_type(), + logical_field.data_type(), + ); if !is_compatible { return exec_err!( "Cannot cast column '{}' from '{}' (physical data type) to '{}' (logical data type)", column.name(), - physical_field.data_type(), + actual_physical_field.data_type(), logical_field.data_type() ); } @@ -464,7 +503,7 @@ impl<'a> 
DefaultPhysicalExprAdapterRewriter<'a> { let cast_expr = Arc::new(CastColumnExpr::new( Arc::new(column), - Arc::new(physical_field.clone()), + Arc::new(actual_physical_field.clone()), Arc::new(logical_field.clone()), None, )); @@ -473,6 +512,141 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> { } } +/// Factory for creating [`BatchAdapter`] instances to adapt record batches +/// to a target schema. +/// +/// This binds a target schema and allows creating adapters for different source schemas. +/// It handles: +/// - **Column reordering**: Columns are reordered to match the target schema +/// - **Type casting**: Automatic type conversion (e.g., Int32 to Int64) +/// - **Missing columns**: Nullable columns missing from source are filled with nulls +/// - **Struct field adaptation**: Nested struct fields are recursively adapted +/// +/// ## Examples +/// +/// ```rust +/// use arrow::array::{Int32Array, Int64Array, StringArray, RecordBatch}; +/// use arrow::datatypes::{DataType, Field, Schema}; +/// use datafusion_physical_expr_adapter::BatchAdapterFactory; +/// use std::sync::Arc; +/// +/// // Target schema has different column order and types +/// let target_schema = Arc::new(Schema::new(vec![ +/// Field::new("name", DataType::Utf8, true), +/// Field::new("id", DataType::Int64, false), // Int64 in target +/// Field::new("score", DataType::Float64, true), // Missing from source +/// ])); +/// +/// // Source schema has different column order and Int32 for id +/// let source_schema = Arc::new(Schema::new(vec![ +/// Field::new("id", DataType::Int32, false), // Int32 in source +/// Field::new("name", DataType::Utf8, true), +/// // Note: 'score' column is missing from source +/// ])); +/// +/// // Create factory with target schema +/// let factory = BatchAdapterFactory::new(Arc::clone(&target_schema)); +/// +/// // Create adapter for this specific source schema +/// let adapter = factory.make_adapter(&source_schema).unwrap(); +/// +/// // Create a source batch +/// let 
source_batch = RecordBatch::try_new( +/// source_schema, +/// vec![ +/// Arc::new(Int32Array::from(vec![1, 2, 3])), +/// Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol"])), +/// ], +/// ).unwrap(); +/// +/// // Adapt the batch to match target schema +/// let adapted = adapter.adapt_batch(&source_batch).unwrap(); +/// +/// assert_eq!(adapted.num_columns(), 3); +/// assert_eq!(adapted.column(0).data_type(), &DataType::Utf8); // name +/// assert_eq!(adapted.column(1).data_type(), &DataType::Int64); // id (cast from Int32) +/// assert_eq!(adapted.column(2).data_type(), &DataType::Float64); // score (filled with nulls) +/// ``` +#[derive(Debug)] +pub struct BatchAdapterFactory { + target_schema: SchemaRef, + expr_adapter_factory: Arc, +} + +impl BatchAdapterFactory { + /// Create a new [`BatchAdapterFactory`] with the given target schema. + pub fn new(target_schema: SchemaRef) -> Self { + let expr_adapter_factory = Arc::new(DefaultPhysicalExprAdapterFactory); + Self { + target_schema, + expr_adapter_factory, + } + } + + /// Set a custom [`PhysicalExprAdapterFactory`] to use when adapting expressions. + /// + /// Use this to customize behavior when adapting batches, e.g. to fill in missing values + /// with defaults instead of nulls. + /// + /// See [`PhysicalExprAdapter`] for more details. + pub fn with_adapter_factory( + self, + factory: Arc, + ) -> Self { + Self { + expr_adapter_factory: factory, + ..self + } + } + + /// Create a new [`BatchAdapter`] for the given source schema. + /// + /// Batches fed into this [`BatchAdapter`] *must* conform to the source schema, + /// no validation is performed at runtime to minimize overheads. 
+ pub fn make_adapter(&self, source_schema: &SchemaRef) -> Result { + let expr_adapter = self + .expr_adapter_factory + .create(Arc::clone(&self.target_schema), Arc::clone(source_schema))?; + + let simplifier = PhysicalExprSimplifier::new(&self.target_schema); + + let projection = ProjectionExprs::from_indices( + &(0..self.target_schema.fields().len()).collect_vec(), + &self.target_schema, + ); + + let adapted = projection + .try_map_exprs(|e| simplifier.simplify(expr_adapter.rewrite(e)?))?; + let projector = adapted.make_projector(source_schema)?; + + Ok(BatchAdapter { projector }) + } +} + +/// Adapter for transforming record batches to match a target schema. +/// +/// Create instances via [`BatchAdapterFactory`]. +/// +/// ## Performance +/// +/// The adapter pre-computes the projection expressions during creation, +/// so the [`adapt_batch`](BatchAdapter::adapt_batch) call is efficient and suitable +/// for use in hot paths like streaming file scans. +#[derive(Debug)] +pub struct BatchAdapter { + projector: Projector, +} + +impl BatchAdapter { + /// Adapt the given record batch to match the target schema. + /// + /// The input batch *must* conform to the source schema used when + /// creating this adapter. 
+ pub fn adapt_batch(&self, batch: &RecordBatch) -> Result { + self.projector.project_batch(batch) + } +} + #[cfg(test)] mod tests { use super::*; @@ -508,7 +682,9 @@ mod tests { let (physical_schema, logical_schema) = create_test_schema(); let factory = DefaultPhysicalExprAdapterFactory; - let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema)); + let adapter = factory + .create(Arc::new(logical_schema), Arc::new(physical_schema)) + .unwrap(); let column_expr = Arc::new(Column::new("a", 0)); let result = adapter.rewrite(column_expr).unwrap(); @@ -521,7 +697,9 @@ mod tests { fn test_rewrite_multi_column_expr_with_type_cast() { let (physical_schema, logical_schema) = create_test_schema(); let factory = DefaultPhysicalExprAdapterFactory; - let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema)); + let adapter = factory + .create(Arc::new(logical_schema), Arc::new(physical_schema)) + .unwrap(); // Create a complex expression: (a + 5) OR (c > 0.0) that tests the recursive case of the rewriter let column_a = Arc::new(Column::new("a", 0)) as Arc; @@ -529,7 +707,7 @@ mod tests { let expr = expressions::BinaryExpr::new( Arc::clone(&column_a), Operator::Plus, - Arc::new(expressions::Literal::new(ScalarValue::Int64(Some(5)))), + Arc::new(Literal::new(ScalarValue::Int64(Some(5)))), ); let expr = expressions::BinaryExpr::new( Arc::new(expr), @@ -537,7 +715,7 @@ mod tests { Arc::new(expressions::BinaryExpr::new( Arc::clone(&column_c), Operator::Gt, - Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.0)))), + Arc::new(Literal::new(ScalarValue::Float64(Some(0.0)))), )), ); @@ -552,7 +730,7 @@ mod tests { None, )), Operator::Plus, - Arc::new(expressions::Literal::new(ScalarValue::Int64(Some(5)))), + Arc::new(Literal::new(ScalarValue::Int64(Some(5)))), ); let expected = Arc::new(expressions::BinaryExpr::new( Arc::new(expected), @@ -560,7 +738,7 @@ mod tests { Arc::new(expressions::BinaryExpr::new( 
lit(ScalarValue::Float64(None)), // c is missing, so it becomes null Operator::Gt, - Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.0)))), + Arc::new(Literal::new(ScalarValue::Float64(Some(0.0)))), )), )) as Arc; @@ -586,7 +764,9 @@ mod tests { )]); let factory = DefaultPhysicalExprAdapterFactory; - let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema)); + let adapter = factory + .create(Arc::new(logical_schema), Arc::new(physical_schema)) + .unwrap(); let column_expr = Arc::new(Column::new("data", 0)); let error_msg = adapter.rewrite(column_expr).unwrap_err().to_string(); @@ -624,35 +804,39 @@ mod tests { )]); let factory = DefaultPhysicalExprAdapterFactory; - let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema)); + let adapter = factory + .create(Arc::new(logical_schema), Arc::new(physical_schema)) + .unwrap(); let column_expr = Arc::new(Column::new("data", 0)); let result = adapter.rewrite(column_expr).unwrap(); + let physical_struct_fields: Fields = vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ] + .into(); + let physical_field = Arc::new(Field::new( + "data", + DataType::Struct(physical_struct_fields), + false, + )); + + let logical_struct_fields: Fields = vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8View, true), + ] + .into(); + let logical_field = Arc::new(Field::new( + "data", + DataType::Struct(logical_struct_fields), + false, + )); + let expected = Arc::new(CastColumnExpr::new( Arc::new(Column::new("data", 0)), - Arc::new(Field::new( - "data", - DataType::Struct( - vec![ - Field::new("id", DataType::Int32, false), - Field::new("name", DataType::Utf8, true), - ] - .into(), - ), - false, - )), - Arc::new(Field::new( - "data", - DataType::Struct( - vec![ - Field::new("id", DataType::Int64, false), - Field::new("name", DataType::Utf8View, true), - ] - .into(), - ), - false, - )), + physical_field, + 
logical_field, None, )) as Arc; @@ -664,13 +848,15 @@ mod tests { let (physical_schema, logical_schema) = create_test_schema(); let factory = DefaultPhysicalExprAdapterFactory; - let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema)); + let adapter = factory + .create(Arc::new(logical_schema), Arc::new(physical_schema)) + .unwrap(); let column_expr = Arc::new(Column::new("c", 2)); let result = adapter.rewrite(column_expr)?; // Should be replaced with a literal null - if let Some(literal) = result.as_any().downcast_ref::() { + if let Some(literal) = result.as_any().downcast_ref::() { assert_eq!(*literal.value(), ScalarValue::Float64(None)); } else { panic!("Expected literal expression"); @@ -688,7 +874,9 @@ mod tests { ]); let factory = DefaultPhysicalExprAdapterFactory; - let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema)); + let adapter = factory + .create(Arc::new(logical_schema), Arc::new(physical_schema)) + .unwrap(); let column_expr = Arc::new(Column::new("b", 1)); let error_msg = adapter.rewrite(column_expr).unwrap_err().to_string(); @@ -704,7 +892,9 @@ mod tests { ]); let factory = DefaultPhysicalExprAdapterFactory; - let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema)); + let adapter = factory + .create(Arc::new(logical_schema), Arc::new(physical_schema)) + .unwrap(); let column_expr = Arc::new(Column::new("b", 1)); let result = adapter.rewrite(column_expr).unwrap(); @@ -727,7 +917,7 @@ mod tests { // Should be replaced with the partition value let literal = result .as_any() - .downcast_ref::() + .downcast_ref::() .expect("Expected literal expression"); assert_eq!(*literal.value(), partition_value); @@ -770,7 +960,9 @@ mod tests { let (physical_schema, logical_schema) = create_test_schema(); let factory = DefaultPhysicalExprAdapterFactory; - let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema)); + let adapter = factory + 
.create(Arc::new(logical_schema), Arc::new(physical_schema)) + .unwrap(); let column_expr = Arc::new(Column::new("b", 1)) as Arc; let result = adapter.rewrite(Arc::clone(&column_expr))?; @@ -794,7 +986,9 @@ mod tests { ]); let factory = DefaultPhysicalExprAdapterFactory; - let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema)); + let adapter = factory + .create(Arc::new(logical_schema), Arc::new(physical_schema)) + .unwrap(); let column_expr = Arc::new(Column::new("b", 1)); let result = adapter.rewrite(column_expr); @@ -852,8 +1046,9 @@ mod tests { ]; let factory = DefaultPhysicalExprAdapterFactory; - let adapter = - factory.create(Arc::clone(&logical_schema), Arc::clone(&physical_schema)); + let adapter = factory + .create(Arc::clone(&logical_schema), Arc::clone(&physical_schema)) + .unwrap(); let adapted_projection = projection .into_iter() @@ -880,7 +1075,7 @@ mod tests { assert_eq!( res.column(0) .as_any() - .downcast_ref::() + .downcast_ref::() .unwrap() .iter() .collect_vec(), @@ -889,7 +1084,7 @@ mod tests { assert_eq!( res.column(1) .as_any() - .downcast_ref::() + .downcast_ref::() .unwrap() .iter() .collect_vec(), @@ -954,8 +1149,9 @@ mod tests { let projection = vec![col("data", &logical_schema).unwrap()]; let factory = DefaultPhysicalExprAdapterFactory; - let adapter = - factory.create(Arc::clone(&logical_schema), Arc::clone(&physical_schema)); + let adapter = factory + .create(Arc::clone(&logical_schema), Arc::clone(&physical_schema)) + .unwrap(); let adapted_projection = projection .into_iter() @@ -1033,8 +1229,8 @@ mod tests { )]); let rewriter = DefaultPhysicalExprAdapterRewriter { - logical_file_schema: &logical_schema, - physical_file_schema: &physical_schema, + logical_file_schema: Arc::new(logical_schema), + physical_file_schema: Arc::new(physical_schema), }; // Test that when a field exists in physical schema, it returns None @@ -1046,4 +1242,295 @@ mod tests { // with ScalarUDF, which is complex to set up in a unit 
test. The integration tests in // datafusion/core/tests/parquet/schema_adapter.rs provide better coverage for this functionality. } + + // ============================================================================ + // BatchAdapterFactory and BatchAdapter tests + // ============================================================================ + + #[test] + fn test_batch_adapter_factory_basic() { + // Target schema + let target_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new("b", DataType::Utf8, true), + ])); + + // Source schema with different column order and type + let source_schema = Arc::new(Schema::new(vec![ + Field::new("b", DataType::Utf8, true), + Field::new("a", DataType::Int32, false), // Int32 -> Int64 + ])); + + let factory = BatchAdapterFactory::new(Arc::clone(&target_schema)); + let adapter = factory.make_adapter(&source_schema).unwrap(); + + // Create source batch + let source_batch = RecordBatch::try_new( + Arc::clone(&source_schema), + vec![ + Arc::new(StringArray::from(vec![Some("hello"), None, Some("world")])), + Arc::new(Int32Array::from(vec![1, 2, 3])), + ], + ) + .unwrap(); + + let adapted = adapter.adapt_batch(&source_batch).unwrap(); + + // Verify schema matches target + assert_eq!(adapted.num_columns(), 2); + assert_eq!(adapted.schema().field(0).name(), "a"); + assert_eq!(adapted.schema().field(0).data_type(), &DataType::Int64); + assert_eq!(adapted.schema().field(1).name(), "b"); + assert_eq!(adapted.schema().field(1).data_type(), &DataType::Utf8); + + // Verify data + let col_a = adapted + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(col_a.iter().collect_vec(), vec![Some(1), Some(2), Some(3)]); + + let col_b = adapted + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + col_b.iter().collect_vec(), + vec![Some("hello"), None, Some("world")] + ); + } + + #[test] + fn test_batch_adapter_factory_missing_column() { + // Target schema with a column missing 
from source + let target_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, true), // exists in source + Field::new("c", DataType::Float64, true), // missing from source + ])); + + let source_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, true), + ])); + + let factory = BatchAdapterFactory::new(Arc::clone(&target_schema)); + let adapter = factory.make_adapter(&source_schema).unwrap(); + + let source_batch = RecordBatch::try_new( + Arc::clone(&source_schema), + vec![ + Arc::new(Int32Array::from(vec![1, 2])), + Arc::new(StringArray::from(vec!["x", "y"])), + ], + ) + .unwrap(); + + let adapted = adapter.adapt_batch(&source_batch).unwrap(); + + assert_eq!(adapted.num_columns(), 3); + + // Missing column should be filled with nulls + let col_c = adapted.column(2); + assert_eq!(col_c.data_type(), &DataType::Float64); + assert_eq!(col_c.null_count(), 2); // All nulls + } + + #[test] + fn test_batch_adapter_factory_with_struct() { + // Target has struct with Int64 id + let target_struct_fields: Fields = vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8, true), + ] + .into(); + let target_schema = Arc::new(Schema::new(vec![Field::new( + "data", + DataType::Struct(target_struct_fields), + false, + )])); + + // Source has struct with Int32 id + let source_struct_fields: Fields = vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ] + .into(); + let source_schema = Arc::new(Schema::new(vec![Field::new( + "data", + DataType::Struct(source_struct_fields.clone()), + false, + )])); + + let struct_array = StructArray::new( + source_struct_fields, + vec![ + Arc::new(Int32Array::from(vec![10, 20])) as _, + Arc::new(StringArray::from(vec!["a", "b"])) as _, + ], + None, + ); + + let source_batch = RecordBatch::try_new( + Arc::clone(&source_schema), + vec![Arc::new(struct_array)], + 
) + .unwrap(); + + let factory = BatchAdapterFactory::new(Arc::clone(&target_schema)); + let adapter = factory.make_adapter(&source_schema).unwrap(); + let adapted = adapter.adapt_batch(&source_batch).unwrap(); + + let result_struct = adapted + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + // Verify id was cast to Int64 + let id_col = result_struct.column_by_name("id").unwrap(); + assert_eq!(id_col.data_type(), &DataType::Int64); + let id_values = id_col.as_any().downcast_ref::().unwrap(); + assert_eq!(id_values.iter().collect_vec(), vec![Some(10), Some(20)]); + } + + #[test] + fn test_batch_adapter_factory_identity() { + // When source and target schemas are identical, should pass through efficiently + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, true), + ])); + + let factory = BatchAdapterFactory::new(Arc::clone(&schema)); + let adapter = factory.make_adapter(&schema).unwrap(); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec!["a", "b", "c"])), + ], + ) + .unwrap(); + + let adapted = adapter.adapt_batch(&batch).unwrap(); + + assert_eq!(adapted.num_columns(), 2); + assert_eq!(adapted.schema().field(0).data_type(), &DataType::Int32); + assert_eq!(adapted.schema().field(1).data_type(), &DataType::Utf8); + } + + #[test] + fn test_batch_adapter_factory_reuse() { + // Factory can create multiple adapters for different source schemas + let target_schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::Int64, false), + Field::new("y", DataType::Utf8, true), + ])); + + let factory = BatchAdapterFactory::new(Arc::clone(&target_schema)); + + // First source schema + let source1 = Arc::new(Schema::new(vec![ + Field::new("x", DataType::Int32, false), + Field::new("y", DataType::Utf8, true), + ])); + let adapter1 = factory.make_adapter(&source1).unwrap(); + + // Second source schema 
(different order) + let source2 = Arc::new(Schema::new(vec![ + Field::new("y", DataType::Utf8, true), + Field::new("x", DataType::Int64, false), + ])); + let adapter2 = factory.make_adapter(&source2).unwrap(); + + // Both should work correctly + assert!(format!("{adapter1:?}").contains("BatchAdapter")); + assert!(format!("{adapter2:?}").contains("BatchAdapter")); + } + + #[test] + fn test_rewrite_column_index_and_type_mismatch() { + let physical_schema = Schema::new(vec![ + Field::new("b", DataType::Utf8, true), + Field::new("a", DataType::Int32, false), // Index 1 + ]); + + let logical_schema = Schema::new(vec![ + Field::new("a", DataType::Int64, false), // Index 0, Different Type + Field::new("b", DataType::Utf8, true), + ]); + + let factory = DefaultPhysicalExprAdapterFactory; + let adapter = factory + .create(Arc::new(logical_schema), Arc::new(physical_schema)) + .unwrap(); + + // Logical column "a" is at index 0 + let column_expr = Arc::new(Column::new("a", 0)); + + let result = adapter.rewrite(column_expr).unwrap(); + + // Should be a CastColumnExpr + let cast_expr = result + .as_any() + .downcast_ref::() + .expect("Expected CastColumnExpr"); + + // Verify the inner column points to the correct physical index (1) + let inner_col = cast_expr + .expr() + .as_any() + .downcast_ref::() + .expect("Expected inner Column"); + assert_eq!(inner_col.name(), "a"); + assert_eq!(inner_col.index(), 1); // Physical index is 1 + + // Verify cast types + assert_eq!( + cast_expr.data_type(&Schema::empty()).unwrap(), + DataType::Int64 + ); + } + + #[test] + fn test_create_cast_column_expr_uses_name_lookup_not_column_index() { + // Physical schema has column `a` at index 1; index 0 is an incompatible type. 
+ let physical_schema = Arc::new(Schema::new(vec![ + Field::new("b", DataType::Binary, true), + Field::new("a", DataType::Int32, false), + ])); + + let logical_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new("b", DataType::Binary, true), + ])); + + let rewriter = DefaultPhysicalExprAdapterRewriter { + logical_file_schema: Arc::clone(&logical_schema), + physical_file_schema: Arc::clone(&physical_schema), + }; + + // Deliberately provide the wrong index for column `a`. + // Regression: this must still resolve against physical field `a` by name. + let transformed = rewriter + .create_cast_column_expr( + Column::new("a", 0), + logical_schema.field_with_name("a").unwrap(), + ) + .unwrap(); + + let cast_expr = transformed + .data + .as_any() + .downcast_ref::() + .expect("Expected CastColumnExpr"); + + assert_eq!(cast_expr.input_field().name(), "a"); + assert_eq!(cast_expr.input_field().data_type(), &DataType::Int32); + assert_eq!(cast_expr.target_field().data_type(), &DataType::Int64); + } } diff --git a/datafusion/physical-expr-common/src/binary_map.rs b/datafusion/physical-expr-common/src/binary_map.rs index ab95302bbb04..95d085ddfdb6 100644 --- a/datafusion/physical-expr-common/src/binary_map.rs +++ b/datafusion/physical-expr-common/src/binary_map.rs @@ -389,7 +389,7 @@ where // is value is already present in the set? 
let entry = self.map.find_mut(hash, |header| { // compare value if hashes match - if header.len != value_len { + if header.hash != hash || header.len != value_len { return false; } // value is stored inline so no need to consult buffer @@ -427,7 +427,7 @@ where // Check if the value is already present in the set let entry = self.map.find_mut(hash, |header| { // compare value if hashes match - if header.len != value_len { + if header.hash != hash { return false; } // Need to compare the bytes in the buffer diff --git a/datafusion/physical-expr-common/src/binary_view_map.rs b/datafusion/physical-expr-common/src/binary_view_map.rs index 2de563472c78..aa0d186f9ea0 100644 --- a/datafusion/physical-expr-common/src/binary_view_map.rs +++ b/datafusion/physical-expr-common/src/binary_view_map.rs @@ -17,16 +17,17 @@ //! [`ArrowBytesViewMap`] and [`ArrowBytesViewSet`] for storing maps/sets of values from //! `StringViewArray`/`BinaryViewArray`. -//! Much of the code is from `binary_map.rs`, but with simpler implementation because we directly use the -//! [`GenericByteViewBuilder`]. 
use crate::binary_map::OutputType; use ahash::RandomState; +use arrow::array::NullBufferBuilder; use arrow::array::cast::AsArray; -use arrow::array::{Array, ArrayBuilder, ArrayRef, GenericByteViewBuilder}; +use arrow::array::{Array, ArrayRef, BinaryViewArray, ByteView, make_view}; +use arrow::buffer::{Buffer, ScalarBuffer}; use arrow::datatypes::{BinaryViewType, ByteViewType, DataType, StringViewType}; use datafusion_common::hash_utils::create_hashes; use datafusion_common::utils::proxy::{HashTableAllocExt, VecAllocExt}; use std::fmt::Debug; +use std::mem::size_of; use std::sync::Arc; /// HashSet optimized for storing string or binary values that can produce that @@ -113,6 +114,9 @@ impl ArrowBytesViewSet { /// This map is used by the special `COUNT DISTINCT` aggregate function to /// store the distinct values, and by the `GROUP BY` operator to store /// group values when they are a single string array. +/// Max size of the in-progress buffer before flushing to completed buffers +const BYTE_VIEW_MAX_BLOCK_SIZE: usize = 2 * 1024 * 1024; + pub struct ArrowBytesViewMap where V: Debug + PartialEq + Eq + Clone + Copy + Default, @@ -124,8 +128,15 @@ where /// Total size of the map in bytes map_size: usize, - /// Builder for output array - builder: GenericByteViewBuilder, + /// Views for all stored values (in insertion order) + views: Vec, + /// In-progress buffer for out-of-line string data + in_progress: Vec, + /// Completed buffers containing string data + completed: Vec, + /// Tracks null values (true = null) + nulls: NullBufferBuilder, + /// random state used to generate hashes random_state: RandomState, /// buffer that stores hash values (reused across batches to save allocations) @@ -148,7 +159,10 @@ where output_type, map: hashbrown::hash_table::HashTable::with_capacity(INITIAL_MAP_CAPACITY), map_size: 0, - builder: GenericByteViewBuilder::new(), + views: Vec::new(), + in_progress: Vec::new(), + completed: Vec::new(), + nulls: NullBufferBuilder::new(0), 
random_state: RandomState::new(), hashes_buffer: vec![], null: None, @@ -250,53 +264,92 @@ where // step 2: insert each value into the set, if not already present let values = values.as_byte_view::(); + // Get raw views buffer for direct comparison + let input_views = values.views(); + // Ensure lengths are equivalent - assert_eq!(values.len(), batch_hashes.len()); + assert_eq!(values.len(), self.hashes_buffer.len()); + + for i in 0..values.len() { + let view_u128 = input_views[i]; + let hash = self.hashes_buffer[i]; - for (value, &hash) in values.iter().zip(batch_hashes.iter()) { - // handle null value - let Some(value) = value else { + // handle null value via validity bitmap check + if values.is_null(i) { let payload = if let Some(&(payload, _offset)) = self.null.as_ref() { payload } else { let payload = make_payload_fn(None); - let null_index = self.builder.len(); - self.builder.append_null(); + let null_index = self.views.len(); + self.views.push(0); + self.nulls.append_null(); self.null = Some((payload, null_index)); payload }; observe_payload_fn(payload); continue; - }; - - // get the value as bytes - let value: &[u8] = value.as_ref(); + } - let entry = self.map.find_mut(hash, |header| { - let v = self.builder.get_value(header.view_idx); + // Extract length from the view (first 4 bytes of u128 in little-endian) + let len = view_u128 as u32; - if v.len() != value.len() { - return false; - } + // Check if value already exists + let maybe_payload = { + // Borrow completed and in_progress for comparison + let completed = &self.completed; + let in_progress = &self.in_progress; - v == value - }); + self.map + .find(hash, |header| { + if header.hash != hash { + return false; + } + + // Fast path: inline strings can be compared directly + if len <= 12 { + return header.view == view_u128; + } + + // For larger strings: first compare the 4-byte prefix + let stored_prefix = (header.view >> 32) as u32; + let input_prefix = (view_u128 >> 32) as u32; + if stored_prefix != 
input_prefix { + return false; + } + + // Prefix matched - compare full bytes + let byte_view = ByteView::from(header.view); + let stored_len = byte_view.length as usize; + let buffer_index = byte_view.buffer_index as usize; + let offset = byte_view.offset as usize; + + let stored_value = if buffer_index < completed.len() { + &completed[buffer_index].as_slice() + [offset..offset + stored_len] + } else { + &in_progress[offset..offset + stored_len] + }; + let input_value: &[u8] = values.value(i).as_ref(); + stored_value == input_value + }) + .map(|entry| entry.payload) + }; - let payload = if let Some(entry) = entry { - entry.payload + let payload = if let Some(payload) = maybe_payload { + payload } else { - // no existing value, make a new one. + // no existing value, make a new one + let value: &[u8] = values.value(i).as_ref(); let payload = make_payload_fn(Some(value)); - let inner_view_idx = self.builder.len(); + // Create view pointing to our buffers + let new_view = self.append_value(value); let new_header = Entry { - view_idx: inner_view_idx, + view: new_view, hash, payload, }; - self.builder.append_value(value); - self.map .insert_accounted(new_header, |h| h.hash, &mut self.map_size); payload @@ -311,29 +364,58 @@ where /// /// The values are guaranteed to be returned in the same order in which /// they were first seen. 
- pub fn into_state(self) -> ArrayRef { - let mut builder = self.builder; - match self.output_type { - OutputType::BinaryView => { - let array = builder.finish(); + pub fn into_state(mut self) -> ArrayRef { + // Flush any remaining in-progress buffer + if !self.in_progress.is_empty() { + let flushed = std::mem::take(&mut self.in_progress); + self.completed.push(Buffer::from_vec(flushed)); + } - Arc::new(array) - } + // Build null buffer if we have any nulls + let null_buffer = self.nulls.finish(); + + let views = ScalarBuffer::from(self.views); + let array = + unsafe { BinaryViewArray::new_unchecked(views, self.completed, null_buffer) }; + + match self.output_type { + OutputType::BinaryView => Arc::new(array), OutputType::Utf8View => { - // SAFETY: - // we asserted the input arrays were all the correct type and - // thus since all the values that went in were valid (e.g. utf8) - // so are all the values that come out - let array = builder.finish(); + // SAFETY: all input was valid utf8 let array = unsafe { array.to_string_view_unchecked() }; Arc::new(array) } - _ => { - unreachable!("Utf8/Binary should use `ArrowBytesMap`") - } + _ => unreachable!("Utf8/Binary should use `ArrowBytesMap`"), } } + /// Append a value to our buffers and return the view pointing to it + fn append_value(&mut self, value: &[u8]) -> u128 { + let len = value.len(); + let view = if len <= 12 { + make_view(value, 0, 0) + } else { + // Ensure buffer is big enough + if self.in_progress.len() + len > BYTE_VIEW_MAX_BLOCK_SIZE { + let flushed = std::mem::replace( + &mut self.in_progress, + Vec::with_capacity(BYTE_VIEW_MAX_BLOCK_SIZE), + ); + self.completed.push(Buffer::from_vec(flushed)); + } + + let buffer_index = self.completed.len() as u32; + let offset = self.in_progress.len() as u32; + self.in_progress.extend_from_slice(value); + + make_view(value, buffer_index, offset) + }; + + self.views.push(view); + self.nulls.append_non_null(); + view + } + /// Total number of entries (including null, if 
present) pub fn len(&self) -> usize { self.non_null_len() + self.null.map(|_| 1).unwrap_or(0) @@ -352,8 +434,16 @@ where /// Return the total size, in bytes, of memory used to store the data in /// this set, not including `self` pub fn size(&self) -> usize { + let views_size = self.views.len() * size_of::(); + let in_progress_size = self.in_progress.capacity(); + let completed_size: usize = self.completed.iter().map(|b| b.len()).sum(); + let nulls_size = self.nulls.allocated_size(); + self.map_size - + self.builder.allocated_size() + + views_size + + in_progress_size + + completed_size + + nulls_size + self.hashes_buffer.allocated_size() } } @@ -366,7 +456,8 @@ where f.debug_struct("ArrowBytesMap") .field("map", &"") .field("map_size", &self.map_size) - .field("view_builder", &self.builder) + .field("views_len", &self.views.len()) + .field("completed_buffers", &self.completed.len()) .field("random_state", &self.random_state) .field("hashes_buffer", &self.hashes_buffer) .finish() @@ -374,13 +465,20 @@ where } /// Entry in the hash table -- see [`ArrowBytesViewMap`] for more details +/// +/// Stores the view pointing to our internal buffers, eliminating the need +/// for a separate builder index. For inline strings (<=12 bytes), the view +/// contains the entire value. For out-of-line strings, the view contains +/// buffer_index and offset pointing directly to our storage. #[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] struct Entry where V: Debug + PartialEq + Eq + Clone + Copy + Default, { - /// The idx into the views array - view_idx: usize, + /// The u128 view pointing to our internal buffers. For inline strings, + /// this contains the complete value. For larger strings, this contains + /// the buffer_index/offset into our completed/in_progress buffers. 
+ view: u128, hash: u64, diff --git a/datafusion/physical-expr-common/src/datum.rs b/datafusion/physical-expr-common/src/datum.rs index 16ef38b0940b..9efaca0f6b6a 100644 --- a/datafusion/physical-expr-common/src/datum.rs +++ b/datafusion/physical-expr-common/src/datum.rs @@ -189,14 +189,14 @@ pub fn compare_op_for_nested( (false, false) | (true, true) => NullBuffer::union(l.nulls(), r.nulls()), (true, false) => { // When left is null-scalar and right is array, expand left nulls to match result length - match l.nulls().filter(|nulls| !nulls.is_valid(0)) { + match l.nulls().filter(|nulls| nulls.is_null(0)) { Some(_) => Some(NullBuffer::new_null(len)), // Left scalar is null None => r.nulls().cloned(), // Left scalar is non-null } } (false, true) => { // When right is null-scalar and left is array, expand right nulls to match result length - match r.nulls().filter(|nulls| !nulls.is_valid(0)) { + match r.nulls().filter(|nulls| nulls.is_null(0)) { Some(_) => Some(NullBuffer::new_null(len)), // Right scalar is null None => l.nulls().cloned(), // Right scalar is non-null } diff --git a/datafusion/physical-expr-common/src/lib.rs b/datafusion/physical-expr-common/src/lib.rs index 84378a3d26ee..b6eaacdca250 100644 --- a/datafusion/physical-expr-common/src/lib.rs +++ b/datafusion/physical-expr-common/src/lib.rs @@ -24,8 +24,6 @@ // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] -// https://github.com/apache/datafusion/issues/18881 -#![deny(clippy::allow_attributes)] //! Physical Expr Common packages for [DataFusion] //! 
This package contains high level PhysicalExpr trait diff --git a/datafusion/physical-expr-common/src/metrics/value.rs b/datafusion/physical-expr-common/src/metrics/value.rs index 9a14b804a20b..26f68980bad8 100644 --- a/datafusion/physical-expr-common/src/metrics/value.rs +++ b/datafusion/physical-expr-common/src/metrics/value.rs @@ -372,19 +372,31 @@ impl Drop for ScopedTimerGuard<'_> { pub struct PruningMetrics { pruned: Arc, matched: Arc, + fully_matched: Arc, } impl Display for PruningMetrics { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { let matched = self.matched.load(Ordering::Relaxed); let total = self.pruned.load(Ordering::Relaxed) + matched; + let fully_matched = self.fully_matched.load(Ordering::Relaxed); - write!( - f, - "{} total → {} matched", - human_readable_count(total), - human_readable_count(matched) - ) + if fully_matched != 0 { + write!( + f, + "{} total → {} matched -> {} fully matched", + human_readable_count(total), + human_readable_count(matched), + human_readable_count(fully_matched) + ) + } else { + write!( + f, + "{} total → {} matched", + human_readable_count(total), + human_readable_count(matched) + ) + } } } @@ -400,6 +412,7 @@ impl PruningMetrics { Self { pruned: Arc::new(AtomicUsize::new(0)), matched: Arc::new(AtomicUsize::new(0)), + fully_matched: Arc::new(AtomicUsize::new(0)), } } @@ -417,6 +430,13 @@ impl PruningMetrics { self.matched.fetch_add(n, Ordering::Relaxed); } + /// Add `n` to the metric's fully matched value + pub fn add_fully_matched(&self, n: usize) { + // relaxed ordering for operations on `value` poses no issues + // we're purely using atomic ops with no associated memory ops + self.fully_matched.fetch_add(n, Ordering::Relaxed); + } + /// Subtract `n` to the metric's matched value. 
pub fn subtract_matched(&self, n: usize) { // relaxed ordering for operations on `value` poses no issues @@ -433,6 +453,11 @@ impl PruningMetrics { pub fn matched(&self) -> usize { self.matched.load(Ordering::Relaxed) } + + /// Number of items fully matched + pub fn fully_matched(&self) -> usize { + self.fully_matched.load(Ordering::Relaxed) + } } /// Counters tracking ratio metrics (e.g. matched vs total) @@ -906,8 +931,11 @@ impl MetricValue { ) => { let pruned = other_pruning_metrics.pruned.load(Ordering::Relaxed); let matched = other_pruning_metrics.matched.load(Ordering::Relaxed); + let fully_matched = + other_pruning_metrics.fully_matched.load(Ordering::Relaxed); pruning_metrics.add_pruned(pruned); pruning_metrics.add_matched(matched); + pruning_metrics.add_fully_matched(fully_matched); } ( Self::Ratio { ratio_metrics, .. }, @@ -956,20 +984,21 @@ impl MetricValue { "files_ranges_pruned_statistics" => 4, "row_groups_pruned_statistics" => 5, "row_groups_pruned_bloom_filter" => 6, - "page_index_rows_pruned" => 7, - _ => 8, + "page_index_pages_pruned" => 7, + "page_index_rows_pruned" => 8, + _ => 9, }, - Self::SpillCount(_) => 9, - Self::SpilledBytes(_) => 10, - Self::SpilledRows(_) => 11, - Self::CurrentMemoryUsage(_) => 12, - Self::Count { .. } => 13, - Self::Gauge { .. } => 14, - Self::Time { .. } => 15, - Self::Ratio { .. } => 16, - Self::StartTimestamp(_) => 17, // show timestamps last - Self::EndTimestamp(_) => 18, - Self::Custom { .. } => 19, + Self::SpillCount(_) => 10, + Self::SpilledBytes(_) => 11, + Self::SpilledRows(_) => 12, + Self::CurrentMemoryUsage(_) => 13, + Self::Count { .. } => 14, + Self::Gauge { .. } => 15, + Self::Time { .. } => 16, + Self::Ratio { .. } => 17, + Self::StartTimestamp(_) => 18, // show timestamps last + Self::EndTimestamp(_) => 19, + Self::Custom { .. 
} => 20, } } diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index 2358a2194091..7107b0a9004d 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -35,6 +35,7 @@ use datafusion_common::{ }; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::interval_arithmetic::Interval; +use datafusion_expr_common::placement::ExpressionPlacement; use datafusion_expr_common::sort_properties::ExprProperties; use datafusion_expr_common::statistics::Distribution; @@ -430,6 +431,16 @@ pub trait PhysicalExpr: Any + Send + Sync + Display + Debug + DynEq + DynHash { fn is_volatile_node(&self) -> bool { false } + + /// Returns placement information for this expression. + /// + /// This is used by optimizers to make decisions about expression placement, + /// such as whether to push expressions down through projections. + /// + /// The default implementation returns [`ExpressionPlacement::KeepInPlace`]. 
+ fn placement(&self) -> ExpressionPlacement { + ExpressionPlacement::KeepInPlace + } } #[deprecated( diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index 1b23beeaa37c..7e61be3a16ae 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -85,5 +85,9 @@ name = "is_null" harness = false name = "binary_op" +[[bench]] +harness = false +name = "simplify" + [package.metadata.cargo-machete] ignored = ["half"] diff --git a/datafusion/physical-expr/benches/case_when.rs b/datafusion/physical-expr/benches/case_when.rs index eb0886a31e8d..33931a2ba98e 100644 --- a/datafusion/physical-expr/benches/case_when.rs +++ b/datafusion/physical-expr/benches/case_when.rs @@ -20,6 +20,7 @@ use arrow::datatypes::{ArrowNativeTypeOp, Field, Schema}; use arrow::record_batch::RecordBatch; use arrow::util::test_util::seedable_rng; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{BinaryExpr, case, col, lit}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; @@ -93,6 +94,7 @@ fn criterion_benchmark(c: &mut Criterion) { run_benchmarks(c, &make_batch(8192, 100)); benchmark_lookup_table_case_when(c, 8192); + benchmark_divide_by_zero_protection(c, 8192); } fn run_benchmarks(c: &mut Criterion, batch: &RecordBatch) { @@ -517,5 +519,83 @@ fn benchmark_lookup_table_case_when(c: &mut Criterion, batch_size: usize) { } } +fn benchmark_divide_by_zero_protection(c: &mut Criterion, batch_size: usize) { + let mut group = c.benchmark_group("divide_by_zero_protection"); + + for zero_percentage in [0.0, 0.1, 0.5, 0.9] { + let rng = &mut seedable_rng(); + + let numerator: Int32Array = + (0..batch_size).map(|_| Some(rng.random::())).collect(); + + let divisor_values: Vec> = (0..batch_size) + .map(|_| { + let roll: f32 = rng.random(); + if roll < zero_percentage { + Some(0) + } else { + let 
mut val = rng.random::(); + while val == 0 { + val = rng.random::(); + } + Some(val) + } + }) + .collect(); + + let divisor: Int32Array = divisor_values.iter().cloned().collect(); + let divisor_copy: Int32Array = divisor_values.iter().cloned().collect(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("numerator", numerator.data_type().clone(), true), + Field::new("divisor", divisor.data_type().clone(), true), + Field::new("divisor_copy", divisor_copy.data_type().clone(), true), + ])); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(numerator), + Arc::new(divisor), + Arc::new(divisor_copy), + ], + ) + .unwrap(); + + let numerator_col = col("numerator", &batch.schema()).unwrap(); + let divisor_col = col("divisor", &batch.schema()).unwrap(); + + // DivideByZeroProtection: WHEN condition checks `divisor_col != 0` and division + // uses `divisor_col` as divisor. Since the checked column matches the divisor, + // this triggers the DivideByZeroProtection optimization. 
+ group.bench_function( + format!( + "{} rows, {}% zeros: DivideByZeroProtection", + batch_size, + (zero_percentage * 100.0) as i32 + ), + |b| { + let when = Arc::new(BinaryExpr::new( + Arc::clone(&divisor_col), + Operator::NotEq, + lit(0i32), + )); + let then = Arc::new(BinaryExpr::new( + Arc::clone(&numerator_col), + Operator::Divide, + Arc::clone(&divisor_col), + )); + let else_null: Arc = lit(ScalarValue::Int32(None)); + let expr = + Arc::new(case(None, vec![(when, then)], Some(else_null)).unwrap()); + + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }, + ); + } + + group.finish(); +} + criterion_group!(benches, criterion_benchmark); criterion_main!(benches); diff --git a/datafusion/physical-expr/benches/in_list.rs b/datafusion/physical-expr/benches/in_list.rs index 954715d0e5a9..021d8259cdfd 100644 --- a/datafusion/physical-expr/benches/in_list.rs +++ b/datafusion/physical-expr/benches/in_list.rs @@ -23,6 +23,7 @@ use arrow::datatypes::{Field, Schema}; use arrow::record_batch::RecordBatch; use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_common::ScalarValue; +use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_expr::expressions::{col, in_list, lit}; use rand::distr::Alphanumeric; use rand::prelude::*; @@ -50,7 +51,9 @@ fn random_string(rng: &mut StdRng, len: usize) -> String { } const IN_LIST_LENGTHS: [usize; 4] = [3, 8, 28, 100]; +const LIST_WITH_COLUMNS_LENGTHS: [usize; 3] = [3, 8, 28]; const NULL_PERCENTS: [f64; 2] = [0., 0.2]; +const MATCH_PERCENTS: [f64; 3] = [0.0, 0.5, 1.0]; const STRING_LENGTHS: [usize; 3] = [3, 12, 100]; const ARRAY_LENGTH: usize = 8192; @@ -219,6 +222,165 @@ fn bench_realistic_mixed_strings( } } +/// Benchmarks the column-reference evaluation path (no static filter) by including +/// a column reference in the IN list, which prevents static filter creation. 
+/// +/// This simulates SQL like: +/// ```sql +/// CREATE TABLE t (a INT, b0 INT, b1 INT, b2 INT); +/// SELECT * FROM t WHERE a IN (b0, b1, b2); +/// ``` +/// +/// - `values`: the "needle" column (`a`) +/// - `list_cols`: the "haystack" columns (`b0`, `b1`, …) +fn do_bench_with_columns( + c: &mut Criterion, + name: &str, + values: ArrayRef, + list_cols: &[ArrayRef], +) { + let mut fields = vec![Field::new("a", values.data_type().clone(), true)]; + let mut columns: Vec = vec![values]; + + // Build list expressions: column refs (forces non-constant evaluation path) + let schema_fields: Vec = list_cols + .iter() + .enumerate() + .map(|(i, col_arr)| { + let name = format!("b{i}"); + fields.push(Field::new(&name, col_arr.data_type().clone(), true)); + columns.push(Arc::clone(col_arr)); + Field::new(&name, col_arr.data_type().clone(), true) + }) + .collect(); + + let schema = Schema::new(fields); + let list_exprs: Vec> = schema_fields + .iter() + .map(|f| col(f.name(), &schema).unwrap()) + .collect(); + + let expr = in_list(col("a", &schema).unwrap(), list_exprs, &false, &schema).unwrap(); + let batch = RecordBatch::try_new(Arc::new(schema), columns).unwrap(); + + c.bench_function(name, |b| { + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); +} + +/// Benchmarks the IN list path with column references for Int32 arrays. 
+/// +/// Equivalent SQL: +/// ```sql +/// CREATE TABLE t (a INT, b0 INT, b1 INT, ...); +/// SELECT * FROM t WHERE a IN (b0, b1, ...); +/// ``` +fn bench_with_columns_int32(c: &mut Criterion) { + let mut rng = StdRng::seed_from_u64(42); + + for list_size in LIST_WITH_COLUMNS_LENGTHS { + for match_percent in MATCH_PERCENTS { + for null_percent in NULL_PERCENTS { + // Generate the "needle" column + let values: Int32Array = (0..ARRAY_LENGTH) + .map(|_| { + rng.random_bool(1.0 - null_percent) + .then(|| rng.random_range(0..1000)) + }) + .collect(); + + // Generate list columns with controlled match rate + let list_cols: Vec = (0..list_size) + .map(|_| { + let col: Int32Array = (0..ARRAY_LENGTH) + .map(|row| { + if rng.random_bool(1.0 - null_percent) { + if rng.random_bool(match_percent) { + // Copy from values to create a match + if values.is_null(row) { + Some(rng.random_range(0..1000)) + } else { + Some(values.value(row)) + } + } else { + // Random value (unlikely to match) + Some(rng.random_range(1000..2000)) + } + } else { + None + } + }) + .collect(); + Arc::new(col) as ArrayRef + }) + .collect(); + + do_bench_with_columns( + c, + &format!( + "in_list_cols/Int32/list={}/match={}%/nulls={}%", + list_size, + (match_percent * 100.0) as u32, + (null_percent * 100.0) as u32 + ), + Arc::new(values), + &list_cols, + ); + } + } + } +} + +/// Benchmarks the IN list path with column references for Utf8 arrays. 
+/// +/// Equivalent SQL: +/// ```sql +/// CREATE TABLE t (a VARCHAR, b0 VARCHAR, b1 VARCHAR, ...); +/// SELECT * FROM t WHERE a IN (b0, b1, ...); +/// ``` +fn bench_with_columns_utf8(c: &mut Criterion) { + let mut rng = StdRng::seed_from_u64(99); + + for list_size in LIST_WITH_COLUMNS_LENGTHS { + for match_percent in MATCH_PERCENTS { + // Generate the "needle" column + let value_strings: Vec> = (0..ARRAY_LENGTH) + .map(|_| rng.random_bool(0.8).then(|| random_string(&mut rng, 12))) + .collect(); + let values: StringArray = + value_strings.iter().map(|s| s.as_deref()).collect(); + + // Generate list columns with controlled match rate + let list_cols: Vec = (0..list_size) + .map(|_| { + let col: StringArray = (0..ARRAY_LENGTH) + .map(|row| { + if rng.random_bool(match_percent) { + // Copy from values to create a match + value_strings[row].as_deref() + } else { + Some("no_match_value_xyz") + } + }) + .collect(); + Arc::new(col) as ArrayRef + }) + .collect(); + + do_bench_with_columns( + c, + &format!( + "in_list_cols/Utf8/list={}/match={}%", + list_size, + (match_percent * 100.0) as u32, + ), + Arc::new(values), + &list_cols, + ); + } + } +} + /// Entry point: registers in_list benchmarks for string and numeric array types. fn criterion_benchmark(c: &mut Criterion) { let mut rng = StdRng::seed_from_u64(120320); @@ -266,6 +428,10 @@ fn criterion_benchmark(c: &mut Criterion) { |rng| rng.random(), |v| ScalarValue::TimestampNanosecond(Some(v), None), ); + + // Column-reference path benchmarks (non-constant list expressions) + bench_with_columns_int32(c); + bench_with_columns_utf8(c); } criterion_group! { diff --git a/datafusion/physical-expr/benches/simplify.rs b/datafusion/physical-expr/benches/simplify.rs new file mode 100644 index 000000000000..cc00c710004e --- /dev/null +++ b/datafusion/physical-expr/benches/simplify.rs @@ -0,0 +1,299 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This is an attempt at reproducing some predicates generated by TPC-DS query #76, +//! and trying to figure out how long it takes to simplify them. + +use arrow::datatypes::{DataType, Field, Schema}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_expr::simplifier::PhysicalExprSimplifier; +use std::hint::black_box; +use std::sync::Arc; + +use datafusion_common::ScalarValue; +use datafusion_expr::Operator; + +use datafusion_physical_expr::expressions::{ + BinaryExpr, CaseExpr, Column, IsNullExpr, Literal, +}; + +fn catalog_sales_schema() -> Schema { + Schema::new(vec![ + Field::new("cs_sold_date_sk", DataType::Int64, true), // 0 + Field::new("cs_sold_time_sk", DataType::Int64, true), // 1 + Field::new("cs_ship_date_sk", DataType::Int64, true), // 2 + Field::new("cs_bill_customer_sk", DataType::Int64, true), // 3 + Field::new("cs_bill_cdemo_sk", DataType::Int64, true), // 4 + Field::new("cs_bill_hdemo_sk", DataType::Int64, true), // 5 + Field::new("cs_bill_addr_sk", DataType::Int64, true), // 6 + Field::new("cs_ship_customer_sk", DataType::Int64, true), // 7 + Field::new("cs_ship_cdemo_sk", DataType::Int64, true), // 8 + Field::new("cs_ship_hdemo_sk", DataType::Int64, true), // 9 + 
Field::new("cs_ship_addr_sk", DataType::Int64, true), // 10 + Field::new("cs_call_center_sk", DataType::Int64, true), // 11 + Field::new("cs_catalog_page_sk", DataType::Int64, true), // 12 + Field::new("cs_ship_mode_sk", DataType::Int64, true), // 13 + Field::new("cs_warehouse_sk", DataType::Int64, true), // 14 + Field::new("cs_item_sk", DataType::Int64, true), // 15 + Field::new("cs_promo_sk", DataType::Int64, true), // 16 + Field::new("cs_order_number", DataType::Int64, true), // 17 + Field::new("cs_quantity", DataType::Int64, true), // 18 + Field::new("cs_wholesale_cost", DataType::Decimal128(7, 2), true), + Field::new("cs_list_price", DataType::Decimal128(7, 2), true), + Field::new("cs_sales_price", DataType::Decimal128(7, 2), true), + Field::new("cs_ext_discount_amt", DataType::Decimal128(7, 2), true), + Field::new("cs_ext_sales_price", DataType::Decimal128(7, 2), true), + Field::new("cs_ext_wholesale_cost", DataType::Decimal128(7, 2), true), + Field::new("cs_ext_list_price", DataType::Decimal128(7, 2), true), + Field::new("cs_ext_tax", DataType::Decimal128(7, 2), true), + Field::new("cs_coupon_amt", DataType::Decimal128(7, 2), true), + Field::new("cs_ext_ship_cost", DataType::Decimal128(7, 2), true), + Field::new("cs_net_paid", DataType::Decimal128(7, 2), true), + Field::new("cs_net_paid_inc_tax", DataType::Decimal128(7, 2), true), + Field::new("cs_net_paid_inc_ship", DataType::Decimal128(7, 2), true), + Field::new("cs_net_paid_inc_ship_tax", DataType::Decimal128(7, 2), true), + Field::new("cs_net_profit", DataType::Decimal128(7, 2), true), + ]) +} + +fn web_sales_schema() -> Schema { + Schema::new(vec![ + Field::new("ws_sold_date_sk", DataType::Int64, true), + Field::new("ws_sold_time_sk", DataType::Int64, true), + Field::new("ws_ship_date_sk", DataType::Int64, true), + Field::new("ws_item_sk", DataType::Int64, true), + Field::new("ws_bill_customer_sk", DataType::Int64, true), + Field::new("ws_bill_cdemo_sk", DataType::Int64, true), + 
Field::new("ws_bill_hdemo_sk", DataType::Int64, true), + Field::new("ws_bill_addr_sk", DataType::Int64, true), + Field::new("ws_ship_customer_sk", DataType::Int64, true), + Field::new("ws_ship_cdemo_sk", DataType::Int64, true), + Field::new("ws_ship_hdemo_sk", DataType::Int64, true), + Field::new("ws_ship_addr_sk", DataType::Int64, true), + Field::new("ws_web_page_sk", DataType::Int64, true), + Field::new("ws_web_site_sk", DataType::Int64, true), + Field::new("ws_ship_mode_sk", DataType::Int64, true), + Field::new("ws_warehouse_sk", DataType::Int64, true), + Field::new("ws_promo_sk", DataType::Int64, true), + Field::new("ws_order_number", DataType::Int64, true), + Field::new("ws_quantity", DataType::Int64, true), + Field::new("ws_wholesale_cost", DataType::Decimal128(7, 2), true), + Field::new("ws_list_price", DataType::Decimal128(7, 2), true), + Field::new("ws_sales_price", DataType::Decimal128(7, 2), true), + Field::new("ws_ext_discount_amt", DataType::Decimal128(7, 2), true), + Field::new("ws_ext_sales_price", DataType::Decimal128(7, 2), true), + Field::new("ws_ext_wholesale_cost", DataType::Decimal128(7, 2), true), + Field::new("ws_ext_list_price", DataType::Decimal128(7, 2), true), + Field::new("ws_ext_tax", DataType::Decimal128(7, 2), true), + Field::new("ws_coupon_amt", DataType::Decimal128(7, 2), true), + Field::new("ws_ext_ship_cost", DataType::Decimal128(7, 2), true), + Field::new("ws_net_paid", DataType::Decimal128(7, 2), true), + Field::new("ws_net_paid_inc_tax", DataType::Decimal128(7, 2), true), + Field::new("ws_net_paid_inc_ship", DataType::Decimal128(7, 2), true), + Field::new("ws_net_paid_inc_ship_tax", DataType::Decimal128(7, 2), true), + Field::new("ws_net_profit", DataType::Decimal128(7, 2), true), + ]) +} + +// Helper to create a literal +fn lit_i64(val: i64) -> Arc { + Arc::new(Literal::new(ScalarValue::Int64(Some(val)))) +} + +fn lit_i32(val: i32) -> Arc { + Arc::new(Literal::new(ScalarValue::Int32(Some(val)))) +} + +fn lit_bool(val: bool) -> 
Arc { + Arc::new(Literal::new(ScalarValue::Boolean(Some(val)))) +} + +// Helper to create binary expressions +fn and( + left: Arc, + right: Arc, +) -> Arc { + Arc::new(BinaryExpr::new(left, Operator::And, right)) +} + +fn gte( + left: Arc, + right: Arc, +) -> Arc { + Arc::new(BinaryExpr::new(left, Operator::GtEq, right)) +} + +fn lte( + left: Arc, + right: Arc, +) -> Arc { + Arc::new(BinaryExpr::new(left, Operator::LtEq, right)) +} + +fn modulo( + left: Arc, + right: Arc, +) -> Arc { + Arc::new(BinaryExpr::new(left, Operator::Modulo, right)) +} + +fn eq( + left: Arc, + right: Arc, +) -> Arc { + Arc::new(BinaryExpr::new(left, Operator::Eq, right)) +} + +/// Build a predicate similar to TPC-DS q76 catalog_sales filter. +/// Uses placeholder columns instead of hash expressions. +pub fn catalog_sales_predicate(num_partitions: usize) -> Arc { + let cs_sold_date_sk: Arc = + Arc::new(Column::new("cs_sold_date_sk", 0)); + let cs_ship_addr_sk: Arc = + Arc::new(Column::new("cs_ship_addr_sk", 10)); + let cs_item_sk: Arc = Arc::new(Column::new("cs_item_sk", 15)); + + // Use a simple modulo expression as placeholder for hash + let item_hash_mod = modulo(cs_item_sk.clone(), lit_i64(num_partitions as i64)); + let date_hash_mod = modulo(cs_sold_date_sk.clone(), lit_i64(num_partitions as i64)); + + // cs_ship_addr_sk IS NULL + let is_null_expr: Arc = Arc::new(IsNullExpr::new(cs_ship_addr_sk)); + + // Build item_sk CASE expression with num_partitions branches + let item_when_then: Vec<(Arc, Arc)> = (0 + ..num_partitions) + .map(|partition| { + let when_expr = eq(item_hash_mod.clone(), lit_i32(partition as i32)); + let then_expr = and( + gte(cs_item_sk.clone(), lit_i64(partition as i64)), + lte(cs_item_sk.clone(), lit_i64(18000)), + ); + (when_expr, then_expr) + }) + .collect(); + + let item_case_expr: Arc = + Arc::new(CaseExpr::try_new(None, item_when_then, Some(lit_bool(false))).unwrap()); + + // Build sold_date_sk CASE expression with num_partitions branches + let date_when_then: 
Vec<(Arc, Arc)> = (0 + ..num_partitions) + .map(|partition| { + let when_expr = eq(date_hash_mod.clone(), lit_i32(partition as i32)); + let then_expr = and( + gte(cs_sold_date_sk.clone(), lit_i64(2415000 + partition as i64)), + lte(cs_sold_date_sk.clone(), lit_i64(2488070)), + ); + (when_expr, then_expr) + }) + .collect(); + + let date_case_expr: Arc = + Arc::new(CaseExpr::try_new(None, date_when_then, Some(lit_bool(false))).unwrap()); + + // Final: is_null AND item_case AND date_case + and(and(is_null_expr, item_case_expr), date_case_expr) +} +/// Build a predicate similar to TPC-DS q76 web_sales filter. +/// Uses placeholder columns instead of hash expressions. +fn web_sales_predicate(num_partitions: usize) -> Arc { + let ws_sold_date_sk: Arc = + Arc::new(Column::new("ws_sold_date_sk", 0)); + let ws_item_sk: Arc = Arc::new(Column::new("ws_item_sk", 3)); + let ws_ship_customer_sk: Arc = + Arc::new(Column::new("ws_ship_customer_sk", 8)); + + // Use simple modulo expression as placeholder for hash + let item_hash_mod = modulo(ws_item_sk.clone(), lit_i64(num_partitions as i64)); + let date_hash_mod = modulo(ws_sold_date_sk.clone(), lit_i64(num_partitions as i64)); + + // ws_ship_customer_sk IS NULL + let is_null_expr: Arc = + Arc::new(IsNullExpr::new(ws_ship_customer_sk)); + + // Build item_sk CASE expression with num_partitions branches + let item_when_then: Vec<(Arc, Arc)> = (0 + ..num_partitions) + .map(|partition| { + let when_expr = eq(item_hash_mod.clone(), lit_i32(partition as i32)); + let then_expr = and( + gte(ws_item_sk.clone(), lit_i64(partition as i64)), + lte(ws_item_sk.clone(), lit_i64(18000)), + ); + (when_expr, then_expr) + }) + .collect(); + + let item_case_expr: Arc = + Arc::new(CaseExpr::try_new(None, item_when_then, Some(lit_bool(false))).unwrap()); + + // Build sold_date_sk CASE expression with num_partitions branches + let date_when_then: Vec<(Arc, Arc)> = (0 + ..num_partitions) + .map(|partition| { + let when_expr = eq(date_hash_mod.clone(), 
lit_i32(partition as i32)); + let then_expr = and( + gte(ws_sold_date_sk.clone(), lit_i64(2415000 + partition as i64)), + lte(ws_sold_date_sk.clone(), lit_i64(2488070)), + ); + (when_expr, then_expr) + }) + .collect(); + + let date_case_expr: Arc = + Arc::new(CaseExpr::try_new(None, date_when_then, Some(lit_bool(false))).unwrap()); + + and(and(is_null_expr, item_case_expr), date_case_expr) +} + +/// Measures how long `PhysicalExprSimplifier::simplify` takes for a given expression. +fn bench_simplify( + c: &mut Criterion, + name: &str, + schema: &Schema, + expr: &Arc, +) { + let simplifier = PhysicalExprSimplifier::new(schema); + c.bench_function(name, |b| { + b.iter(|| black_box(simplifier.simplify(black_box(Arc::clone(expr))).unwrap())) + }); +} + +fn criterion_benchmark(c: &mut Criterion) { + let cs_schema = catalog_sales_schema(); + let ws_schema = web_sales_schema(); + + for num_partitions in [16, 128] { + bench_simplify( + c, + &format!("tpc-ds/q76/cs/{num_partitions}"), + &cs_schema, + &catalog_sales_predicate(num_partitions), + ); + bench_simplify( + c, + &format!("tpc-ds/q76/ws/{num_partitions}"), + &ws_schema, + &web_sales_predicate(num_partitions), + ); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/physical-expr/src/analysis.rs b/datafusion/physical-expr/src/analysis.rs index d734c86726f1..11a60afc90a1 100644 --- a/datafusion/physical-expr/src/analysis.rs +++ b/datafusion/physical-expr/src/analysis.rs @@ -178,7 +178,7 @@ pub fn analyze( "ExprBoundaries has a non-zero distinct count although it represents an empty table" ); assert_or_internal_err!( - context.selectivity == Some(0.0), + context.selectivity.unwrap_or(0.0) == 0.0, "AnalysisContext has a non-zero selectivity although it represents an empty table" ); Ok(context) diff --git a/datafusion/physical-expr/src/equivalence/properties/mod.rs b/datafusion/physical-expr/src/equivalence/properties/mod.rs index 70f97139f8af..a98341b10765 100644 
--- a/datafusion/physical-expr/src/equivalence/properties/mod.rs +++ b/datafusion/physical-expr/src/equivalence/properties/mod.rs @@ -207,8 +207,13 @@ impl EquivalenceProperties { } /// Adds constraints to the properties. - pub fn with_constraints(mut self, constraints: Constraints) -> Self { + pub fn set_constraints(&mut self, constraints: Constraints) { self.constraints = constraints; + } + + /// Adds constraints to the properties. + pub fn with_constraints(mut self, constraints: Constraints) -> Self { + self.set_constraints(constraints); self } @@ -1277,7 +1282,7 @@ impl EquivalenceProperties { // Rewriting equivalence properties in terms of new schema is not // safe when schemas are not aligned: return plan_err!( - "Schemas have to be aligned to rewrite equivalences:\n Old schema: {:?}\n New schema: {:?}", + "Schemas have to be aligned to rewrite equivalences:\n Old schema: {}\n New schema: {}", self.schema, schema ); diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 8df09c22bbd8..02628b405ec6 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -30,6 +30,7 @@ use arrow::datatypes::*; use arrow::error::ArrowError; use datafusion_common::cast::as_boolean_array; use datafusion_common::{Result, ScalarValue, internal_err, not_impl_err}; + use datafusion_expr::binary::BinaryTypeCoercer; use datafusion_expr::interval_arithmetic::{Interval, apply_operator}; use datafusion_expr::sort_properties::ExprProperties; @@ -162,6 +163,94 @@ fn boolean_op( op(ll, rr).map(|t| Arc::new(t) as _) } +/// Returns true if both operands are Date types (Date32 or Date64) +/// Used to detect Date - Date operations which should return Int64 (days difference) +fn is_date_minus_date(lhs: &DataType, rhs: &DataType) -> bool { + matches!( + (lhs, rhs), + (DataType::Date32, DataType::Date32) | (DataType::Date64, DataType::Date64) + ) +} + +/// Computes 
the difference between two dates and returns the result as Int64 (days) +/// This aligns with PostgreSQL, DuckDB, and MySQL behavior where date - date returns an integer +/// +/// Implementation: Uses Arrow's sub_wrapping to get Duration, then converts to Int64 days +fn apply_date_subtraction( + lhs: &ColumnarValue, + rhs: &ColumnarValue, +) -> Result { + use arrow::compute::kernels::numeric::sub_wrapping; + + // Use Arrow's sub_wrapping to compute the Duration result + let duration_result = apply(lhs, rhs, sub_wrapping)?; + + // Convert Duration to Int64 (days) + match duration_result { + ColumnarValue::Array(array) => { + let int64_array = duration_to_days(&array)?; + Ok(ColumnarValue::Array(int64_array)) + } + ColumnarValue::Scalar(scalar) => { + // Convert scalar Duration to Int64 days + let array = scalar.to_array_of_size(1)?; + let int64_array = duration_to_days(&array)?; + let int64_scalar = ScalarValue::try_from_array(int64_array.as_ref(), 0)?; + Ok(ColumnarValue::Scalar(int64_scalar)) + } + } +} + +/// Converts a Duration array to Int64 days +/// Handles different Duration time units (Second, Millisecond, Microsecond, Nanosecond) +fn duration_to_days(array: &ArrayRef) -> Result { + use datafusion_common::cast::{ + as_duration_microsecond_array, as_duration_millisecond_array, + as_duration_nanosecond_array, as_duration_second_array, + }; + + const SECONDS_PER_DAY: i64 = 86_400; + const MILLIS_PER_DAY: i64 = 86_400_000; + const MICROS_PER_DAY: i64 = 86_400_000_000; + const NANOS_PER_DAY: i64 = 86_400_000_000_000; + + match array.data_type() { + DataType::Duration(TimeUnit::Second) => { + let duration_array = as_duration_second_array(array)?; + let result: Int64Array = duration_array + .iter() + .map(|v| v.map(|val| val / SECONDS_PER_DAY)) + .collect(); + Ok(Arc::new(result)) + } + DataType::Duration(TimeUnit::Millisecond) => { + let duration_array = as_duration_millisecond_array(array)?; + let result: Int64Array = duration_array + .iter() + .map(|v| 
v.map(|val| val / MILLIS_PER_DAY)) + .collect(); + Ok(Arc::new(result)) + } + DataType::Duration(TimeUnit::Microsecond) => { + let duration_array = as_duration_microsecond_array(array)?; + let result: Int64Array = duration_array + .iter() + .map(|v| v.map(|val| val / MICROS_PER_DAY)) + .collect(); + Ok(Arc::new(result)) + } + DataType::Duration(TimeUnit::Nanosecond) => { + let duration_array = as_duration_nanosecond_array(array)?; + let result: Int64Array = duration_array + .iter() + .map(|v| v.map(|val| val / NANOS_PER_DAY)) + .collect(); + Ok(Arc::new(result)) + } + other => internal_err!("duration_to_days expected Duration type, got: {}", other), + } +} + impl PhysicalExpr for BinaryExpr { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { @@ -251,6 +340,11 @@ impl PhysicalExpr for BinaryExpr { match self.op { Operator::Plus if self.fail_on_overflow => return apply(&lhs, &rhs, add), Operator::Plus => return apply(&lhs, &rhs, add_wrapping), + // Special case: Date - Date returns Int64 (days difference) + // This aligns with PostgreSQL, DuckDB, and MySQL behavior + Operator::Minus if is_date_minus_date(&left_data_type, &right_data_type) => { + return apply_date_subtraction(&lhs, &rhs); + } Operator::Minus if self.fail_on_overflow => return apply(&lhs, &rhs, sub), Operator::Minus => return apply(&lhs, &rhs, sub_wrapping), Operator::Multiply if self.fail_on_overflow => return apply(&lhs, &rhs, mul), @@ -621,7 +715,7 @@ impl BinaryExpr { StringConcat => concat_elements(&left, &right), AtArrow | ArrowAt | Arrow | LongArrow | HashArrow | HashLongArrow | AtAt | HashMinus | AtQuestion | Question | QuestionAnd | QuestionPipe - | IntegerDivide => { + | IntegerDivide | Colon => { not_impl_err!( "Binary operator '{:?}' is not supported in the physical expr", self.op diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 758317d3d279..f1d867dddf36 100644 --- 
a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -42,6 +42,7 @@ use crate::expressions::case::literal_lookup_table::LiteralLookupTable; use arrow::compute::kernels::merge::{MergeIndex, merge, merge_n}; use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_physical_expr_common::datum::compare_with_eq; +use datafusion_physical_expr_common::utils::scatter; use itertools::Itertools; use std::fmt::{Debug, Formatter}; @@ -64,7 +65,7 @@ enum EvalMethod { /// for expressions that are infallible and can be cheaply computed for the entire /// record batch rather than just for the rows where the predicate is true. /// - /// CASE WHEN condition THEN column [ELSE NULL] END + /// CASE WHEN condition THEN infallible_expression [ELSE NULL] END InfallibleExprOrNull, /// This is a specialization for a specific use case where we can take a fast path /// if there is just one when/then pair and both the `then` and `else` expressions @@ -72,9 +73,13 @@ enum EvalMethod { /// CASE WHEN condition THEN literal ELSE literal END ScalarOrScalar, /// This is a specialization for a specific use case where we can take a fast path - /// if there is just one when/then pair and both the `then` and `else` are expressions + /// if there is just one when/then pair, the `then` is an expression, and `else` is either + /// an expression, literal NULL or absent. /// - /// CASE WHEN condition THEN expression ELSE expression END + /// In contrast to [`EvalMethod::InfallibleExprOrNull`], this specialization can handle fallible + /// `then` expressions. 
+ /// + /// CASE WHEN condition THEN expression [ELSE expression] END ExpressionOrExpression(ProjectedCaseBody), /// This is a specialization for [`EvalMethod::WithExpression`] when the value and results are literals @@ -659,7 +664,7 @@ impl CaseExpr { && body.else_expr.as_ref().unwrap().as_any().is::() { EvalMethod::ScalarOrScalar - } else if body.when_then_expr.len() == 1 && body.else_expr.is_some() { + } else if body.when_then_expr.len() == 1 { EvalMethod::ExpressionOrExpression(body.project()?) } else { EvalMethod::NoExpression(body.project()?) @@ -961,32 +966,40 @@ impl CaseBody { let then_batch = filter_record_batch(batch, &when_filter)?; let then_value = self.when_then_expr[0].1.evaluate(&then_batch)?; - let else_selection = not(&when_value)?; - let else_filter = create_filter(&else_selection, optimize_filter); - let else_batch = filter_record_batch(batch, &else_filter)?; - - // keep `else_expr`'s data type and return type consistent - let e = self.else_expr.as_ref().unwrap(); - let return_type = self.data_type(&batch.schema())?; - let else_expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone()) - .unwrap_or_else(|_| Arc::clone(e)); - - let else_value = else_expr.evaluate(&else_batch)?; - - Ok(ColumnarValue::Array(match (then_value, else_value) { - (ColumnarValue::Array(t), ColumnarValue::Array(e)) => { - merge(&when_value, &t, &e) - } - (ColumnarValue::Scalar(t), ColumnarValue::Array(e)) => { - merge(&when_value, &t.to_scalar()?, &e) - } - (ColumnarValue::Array(t), ColumnarValue::Scalar(e)) => { - merge(&when_value, &t, &e.to_scalar()?) + match &self.else_expr { + None => { + let then_array = then_value.to_array(when_value.true_count())?; + scatter(&when_value, then_array.as_ref()).map(ColumnarValue::Array) } - (ColumnarValue::Scalar(t), ColumnarValue::Scalar(e)) => { - merge(&when_value, &t.to_scalar()?, &e.to_scalar()?) 
+ Some(else_expr) => { + let else_selection = not(&when_value)?; + let else_filter = create_filter(&else_selection, optimize_filter); + let else_batch = filter_record_batch(batch, &else_filter)?; + + // keep `else_expr`'s data type and return type consistent + let return_type = self.data_type(&batch.schema())?; + let else_expr = + try_cast(Arc::clone(else_expr), &batch.schema(), return_type.clone()) + .unwrap_or_else(|_| Arc::clone(else_expr)); + + let else_value = else_expr.evaluate(&else_batch)?; + + Ok(ColumnarValue::Array(match (then_value, else_value) { + (ColumnarValue::Array(t), ColumnarValue::Array(e)) => { + merge(&when_value, &t, &e) + } + (ColumnarValue::Scalar(t), ColumnarValue::Array(e)) => { + merge(&when_value, &t.to_scalar()?, &e) + } + (ColumnarValue::Array(t), ColumnarValue::Scalar(e)) => { + merge(&when_value, &t, &e.to_scalar()?) + } + (ColumnarValue::Scalar(t), ColumnarValue::Scalar(e)) => { + merge(&when_value, &t.to_scalar()?, &e.to_scalar()?) + } + }?)) } - }?)) + } } } @@ -1137,7 +1150,15 @@ impl CaseExpr { self.body.when_then_expr[0].1.evaluate(batch) } else if true_count == 0 { // All input rows are false/null, just call the 'else' expression - self.body.else_expr.as_ref().unwrap().evaluate(batch) + match &self.body.else_expr { + Some(else_expr) => else_expr.evaluate(batch), + None => { + let return_type = self.data_type(&batch.schema())?; + Ok(ColumnarValue::Scalar(ScalarValue::try_new_null( + &return_type, + )?)) + } + } } else if projected.projection.len() < batch.num_columns() { // The case expressions do not use all the columns of the input batch. // Project first to reduce time spent filtering. @@ -2258,7 +2279,7 @@ mod tests { make_lit_i32(250), )); let expr = CaseExpr::try_new(None, vec![(predicate, make_col("c2", 1))], None)?; - assert!(matches!(expr.eval_method, EvalMethod::InfallibleExprOrNull)); + assert_eq!(expr.eval_method, EvalMethod::InfallibleExprOrNull); match expr.evaluate(&batch)? 
{ ColumnarValue::Array(array) => { assert_eq!(1000, array.len()); diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index bd5c63a69979..2d44215cf2d5 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -26,6 +26,7 @@ use arrow::compute::{CastOptions, can_cast_types}; use arrow::datatypes::{DataType, DataType::*, FieldRef, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; +use datafusion_common::nested_struct::validate_struct_compatibility; use datafusion_common::{Result, not_impl_err}; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::interval_arithmetic::Interval; @@ -41,6 +42,22 @@ const DEFAULT_SAFE_CAST_OPTIONS: CastOptions<'static> = CastOptions { format_options: DEFAULT_FORMAT_OPTIONS, }; +/// Check if struct-to-struct casting is allowed by validating field compatibility. +/// +/// This function applies the same validation rules as execution time to ensure +/// planning-time validation matches runtime validation, enabling fail-fast behavior +/// instead of deferring errors to execution. +fn can_cast_struct_types(source: &DataType, target: &DataType) -> bool { + match (source, target) { + (Struct(source_fields), Struct(target_fields)) => { + // Apply the same struct compatibility rules as at execution time. + // This ensures planning-time validation matches execution-time validation. 
+ validate_struct_compatibility(source_fields, target_fields).is_ok() + } + _ => false, + } +} + /// CAST expression casts an expression to a specific data type and returns a runtime error on invalid cast #[derive(Debug, Clone, Eq)] pub struct CastExpr { @@ -129,7 +146,7 @@ impl CastExpr { impl fmt::Display for CastExpr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "CAST({} AS {:?})", self.expr, self.cast_type) + write!(f, "CAST({} AS {})", self.expr, self.cast_type) } } @@ -237,6 +254,12 @@ pub fn cast_with_options( Ok(Arc::clone(&expr)) } else if can_cast_types(&expr_type, &cast_type) { Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options))) + } else if can_cast_struct_types(&expr_type, &cast_type) { + // Allow struct-to-struct casts that pass name-based compatibility validation. + // This validation is applied at planning time (now) to fail fast, rather than + // deferring errors to execution time. The name-based casting logic will be + // executed at runtime via ColumnarValue::cast_to. 
+ Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options))) } else { not_impl_err!("Unsupported CAST from {expr_type} to {cast_type}") } @@ -289,10 +312,7 @@ mod tests { cast_with_options(col("a", &schema)?, &schema, $TYPE, $CAST_OPTIONS)?; // verify that its display is correct - assert_eq!( - format!("CAST(a@0 AS {:?})", $TYPE), - format!("{}", expression) - ); + assert_eq!(format!("CAST(a@0 AS {})", $TYPE), format!("{}", expression)); // verify that the expression's type is correct assert_eq!(expression.data_type(&schema)?, $TYPE); @@ -316,7 +336,7 @@ mod tests { for (i, x) in $VEC.iter().enumerate() { match x { Some(x) => assert_eq!(result.value(i), *x), - None => assert!(!result.is_valid(i)), + None => assert!(result.is_null(i)), } } }}; @@ -341,10 +361,7 @@ mod tests { cast_with_options(col("a", &schema)?, &schema, $TYPE, $CAST_OPTIONS)?; // verify that its display is correct - assert_eq!( - format!("CAST(a@0 AS {:?})", $TYPE), - format!("{}", expression) - ); + assert_eq!(format!("CAST(a@0 AS {})", $TYPE), format!("{}", expression)); // verify that the expression's type is correct assert_eq!(expression.data_type(&schema)?, $TYPE); @@ -371,7 +388,7 @@ mod tests { for (i, x) in $VEC.iter().enumerate() { match x { Some(x) => assert_eq!(result.value(i), *x), - None => assert!(!result.is_valid(i)), + None => assert!(result.is_null(i)), } } }}; diff --git a/datafusion/physical-expr/src/expressions/cast_column.rs b/datafusion/physical-expr/src/expressions/cast_column.rs index 3dc0293da83d..d80b6f4a588a 100644 --- a/datafusion/physical-expr/src/expressions/cast_column.rs +++ b/datafusion/physical-expr/src/expressions/cast_column.rs @@ -114,7 +114,7 @@ impl Display for CastColumnExpr { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "CAST_COLUMN({} AS {:?})", + "CAST_COLUMN({} AS {})", self.expr, self.target_field.data_type() ) diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs 
index 8c7e8c319fff..cf844790a002 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -30,6 +30,7 @@ use arrow::{ use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{Result, internal_err, plan_err}; use datafusion_expr::ColumnarValue; +use datafusion_expr_common::placement::ExpressionPlacement; /// Represents the column at a given index in a RecordBatch /// @@ -146,6 +147,10 @@ impl PhysicalExpr for Column { fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.name) } + + fn placement(&self) -> ExpressionPlacement { + ExpressionPlacement::Column + } } impl Column { diff --git a/datafusion/physical-expr/src/expressions/dynamic_filters.rs b/datafusion/physical-expr/src/expressions/dynamic_filters.rs index fd8b2667259f..d285f8b377ec 100644 --- a/datafusion/physical-expr/src/expressions/dynamic_filters.rs +++ b/datafusion/physical-expr/src/expressions/dynamic_filters.rs @@ -26,7 +26,7 @@ use datafusion_common::{ tree_node::{Transformed, TransformedResult, TreeNode}, }; use datafusion_expr::ColumnarValue; -use datafusion_physical_expr_common::physical_expr::{DynEq, DynHash}; +use datafusion_physical_expr_common::physical_expr::DynHash; /// State of a dynamic filter, tracking both updates and completion. #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -51,6 +51,10 @@ impl FilterState { /// Any `ExecutionPlan` that uses this expression and holds a reference to it internally should probably also /// implement `ExecutionPlan::reset_state` to remain compatible with recursive queries and other situations where /// the same `ExecutionPlan` is reused with different data. 
+/// +/// For more background, please also see the [Dynamic Filters: Passing Information Between Operators During Execution for 25x Faster Queries blog] +/// +/// [Dynamic Filters: Passing Information Between Operators During Execution for 25x Faster Queries blog]: https://datafusion.apache.org/blog/2025/09/10/dynamic-filters #[derive(Debug)] pub struct DynamicFilterPhysicalExpr { /// The original children of this PhysicalExpr, if any. @@ -103,8 +107,11 @@ impl Inner { impl Hash for DynamicFilterPhysicalExpr { fn hash(&self, state: &mut H) { - let inner = self.current().expect("Failed to get current expression"); - inner.dyn_hash(state); + // Use pointer identity of the inner Arc for stable hashing. + // This is stable across update() calls and consistent with Eq. + // See issue #19641 for details on why content-based hashing violates + // the Hash/Eq contract when the underlying expression can change. + Arc::as_ptr(&self.inner).hash(state); self.children.dyn_hash(state); self.remapped_children.dyn_hash(state); } @@ -112,11 +119,13 @@ impl Hash for DynamicFilterPhysicalExpr { impl PartialEq for DynamicFilterPhysicalExpr { fn eq(&self, other: &Self) -> bool { - let inner = self.current().expect("Failed to get current expression"); - let our_children = self.remapped_children.as_ref().unwrap_or(&self.children); - let other_children = other.remapped_children.as_ref().unwrap_or(&other.children); - let other = other.current().expect("Failed to get current expression"); - inner.dyn_eq(other.as_any()) && our_children == other_children + // Two dynamic filters are equal if they share the same inner source + // AND have the same children configuration. + // This is consistent with Hash using Arc::as_ptr. + // See issue #19641 for details on the Hash/Eq contract violation fix. 
+ Arc::ptr_eq(&self.inner, &other.inner) + && self.children == other.children + && self.remapped_children == other.remapped_children } } @@ -267,6 +276,10 @@ impl DynamicFilterPhysicalExpr { /// /// This method will return when [`Self::update`] is called and the generation increases. /// It does not guarantee that the filter is complete. + /// + /// Producers (e.g.) HashJoinExec may never update the expression or mark it as completed if there are no consumers. + /// If you call this method on a dynamic filter created by such a producer and there are no consumers registered this method would wait indefinitely. + /// This should not happen under normal operation and would indicate a programming error either in your producer or in DataFusion if the producer is a built in node. pub async fn wait_update(&self) { let mut rx = self.state_watch.subscribe(); // Get the current generation @@ -278,17 +291,16 @@ impl DynamicFilterPhysicalExpr { /// Wait asynchronously until this dynamic filter is marked as complete. /// - /// This method returns immediately if the filter is already complete or if the filter - /// is not being used by any consumers. + /// This method returns immediately if the filter is already complete. /// Otherwise, it waits until [`Self::mark_complete`] is called. /// /// Unlike [`Self::wait_update`], this method guarantees that when it returns, /// the filter is fully complete with no more updates expected. - pub async fn wait_complete(self: &Arc) { - if !self.is_used() { - return; - } - + /// + /// Producers (e.g.) HashJoinExec may never update the expression or mark it as completed if there are no consumers. + /// If you call this method on a dynamic filter created by such a producer and there are no consumers registered this method would wait indefinitely. + /// This should not happen under normal operation and would indicate a programming error either in your producer or in DataFusion if the producer is a built in node. 
+ pub async fn wait_complete(&self) { if self.inner.read().is_complete { return; } @@ -305,14 +317,14 @@ impl DynamicFilterPhysicalExpr { /// that created the filter). This is useful to avoid computing expensive filter /// expressions when no consumer will actually use them. /// - /// Note: We check the inner Arc's strong_count, not the outer Arc's count, because - /// when filters are transformed (e.g., via reassign_expr_columns during filter pushdown), - /// new outer Arc instances are created via with_new_children(), but they all share the - /// same inner `Arc>`. This is what allows filter updates to propagate to - /// consumers even after transformation. + /// # Implementation Details + /// + /// We check both Arc counts to handle two cases: + /// - Transformed filters (via `with_new_children`) share the inner Arc (inner count > 1) + /// - Direct clones (via `Arc::clone`) increment the outer count (outer count > 1) pub fn is_used(self: &Arc) -> bool { // Strong count > 1 means at least one consumer is holding a reference beyond the producer. - Arc::strong_count(&self.inner) > 1 + Arc::strong_count(self) > 1 || Arc::strong_count(&self.inner) > 1 } fn render( @@ -753,4 +765,106 @@ mod test { "Filter should still be used with multiple consumers" ); } + + /// Test that verifies the Hash/Eq contract is now satisfied (issue #19641 fix). + /// + /// After the fix, Hash uses Arc::as_ptr(&self.inner) which is stable across + /// update() calls, fixing the HashMap key instability issue. 
+ #[test] + fn test_hash_stable_after_update() { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + // Create filter with initial value + let filter = + DynamicFilterPhysicalExpr::new(vec![], lit(true) as Arc); + + // Compute hash BEFORE update + let mut hasher_before = DefaultHasher::new(); + filter.hash(&mut hasher_before); + let hash_before = hasher_before.finish(); + + // Update changes the underlying expression + filter + .update(lit(false) as Arc) + .expect("Update should succeed"); + + // Compute hash AFTER update + let mut hasher_after = DefaultHasher::new(); + filter.hash(&mut hasher_after); + let hash_after = hasher_after.finish(); + + // FIXED: Hash should now be STABLE after update() because we use + // Arc::as_ptr for identity-based hashing instead of expression content. + assert_eq!( + hash_before, hash_after, + "Hash should be stable after update() - fix for issue #19641" + ); + + // Self-equality should still hold + assert!(filter.eq(&filter), "Self-equality should hold"); + } + + /// Test that verifies separate DynamicFilterPhysicalExpr instances + /// with the same expression are NOT equal (identity-based comparison). + #[test] + fn test_identity_based_equality() { + // Create two separate filters with identical initial expressions + let filter1 = + DynamicFilterPhysicalExpr::new(vec![], lit(true) as Arc); + let filter2 = + DynamicFilterPhysicalExpr::new(vec![], lit(true) as Arc); + + // Different instances should NOT be equal even with same expression + // because they have independent inner Arcs (different update lifecycles) + assert!( + !filter1.eq(&filter2), + "Different instances should not be equal (identity-based)" + ); + + // Self-equality should hold + assert!(filter1.eq(&filter1), "Self-equality should hold"); + } + + /// Test that hash is stable for the same filter instance. + /// After the fix, hash uses Arc::as_ptr which is pointer-based. 
+ #[test] + fn test_hash_stable_for_same_instance() { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + let filter = + DynamicFilterPhysicalExpr::new(vec![], lit(true) as Arc); + + // Compute hash twice for the same instance + let hash1 = { + let mut h = DefaultHasher::new(); + filter.hash(&mut h); + h.finish() + }; + let hash2 = { + let mut h = DefaultHasher::new(); + filter.hash(&mut h); + h.finish() + }; + + assert_eq!(hash1, hash2, "Same instance should have stable hash"); + + // Update the expression + filter + .update(lit(false) as Arc) + .expect("Update should succeed"); + + // Hash should STILL be the same (identity-based) + let hash3 = { + let mut h = DefaultHasher::new(); + filter.hash(&mut h); + h.finish() + }; + + assert_eq!( + hash1, hash3, + "Hash should be stable after update (identity-based)" + ); + } } diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 5c2f1adcd0cf..6c81fcc11c6c 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -28,6 +28,7 @@ use crate::physical_expr::physical_exprs_bag_equal; use arrow::array::*; use arrow::buffer::{BooleanBuffer, NullBuffer}; use arrow::compute::kernels::boolean::{not, or_kleene}; +use arrow::compute::kernels::cmp::eq as arrow_eq; use arrow::compute::{SortOptions, take}; use arrow::datatypes::*; use arrow::util::bit_iterator::BitIndexIterator; @@ -98,11 +99,18 @@ impl StaticFilter for ArrayStaticFilter { )); } + // Unwrap dictionary-encoded needles when the value type matches + // in_array, evaluating against the dictionary values and mapping + // back via keys. downcast_dictionary_array! 
{ v => { - let values_contains = self.contains(v.values().as_ref(), negated)?; - let result = take(&values_contains, v.keys(), None)?; - return Ok(downcast_array(result.as_ref())) + // Only unwrap when the haystack (in_array) type matches + // the dictionary value type + if v.values().data_type() == self.in_array.data_type() { + let values_contains = self.contains(v.values().as_ref(), negated)?; + let result = take(&values_contains, v.keys(), None)?; + return Ok(downcast_array(result.as_ref())); + } } _ => {} } @@ -138,6 +146,21 @@ impl StaticFilter for ArrayStaticFilter { } } +/// Returns true if Arrow's vectorized `eq` kernel supports this data type. +/// +/// Supported: primitives, boolean, strings (Utf8/LargeUtf8/Utf8View), +/// binary (Binary/LargeBinary/BinaryView/FixedSizeBinary), Null, and +/// Dictionary-encoded variants of the above. +/// Unsupported: nested types (Struct, List, Map, Union) and RunEndEncoded. +fn supports_arrow_eq(dt: &DataType) -> bool { + use DataType::*; + match dt { + Boolean | Binary | LargeBinary | BinaryView | FixedSizeBinary(_) => true, + Dictionary(_, v) => supports_arrow_eq(v.as_ref()), + _ => dt.is_primitive() || dt.is_null() || dt.is_string(), + } +} + fn instantiate_static_filter( in_array: ArrayRef, ) -> Result> { @@ -771,32 +794,45 @@ impl PhysicalExpr for InListExpr { } } None => { - // No static filter: iterate through each expression, compare, and OR results + // No static filter: iterate through each expression, compare, and OR results. + // Use Arrow's vectorized eq kernel for types it supports (primitive, + // boolean, string, binary, dictionary), falling back to row-by-row + // comparator for unsupported types (nested, RunEndEncoded, etc.). 
let value = value.into_array(num_rows)?; + let lhs_supports_arrow_eq = supports_arrow_eq(value.data_type()); let found = self.list.iter().map(|expr| expr.evaluate(batch)).try_fold( BooleanArray::new(BooleanBuffer::new_unset(num_rows), None), |result, expr| -> Result { let rhs = match expr? { ColumnarValue::Array(array) => { - let cmp = make_comparator( - value.as_ref(), - array.as_ref(), - SortOptions::default(), - )?; - (0..num_rows) - .map(|i| { - if value.is_null(i) || array.is_null(i) { - return None; - } - Some(cmp(i, i).is_eq()) - }) - .collect::() + if lhs_supports_arrow_eq + && supports_arrow_eq(array.data_type()) + { + arrow_eq(&value, &array)? + } else { + let cmp = make_comparator( + value.as_ref(), + array.as_ref(), + SortOptions::default(), + )?; + (0..num_rows) + .map(|i| { + if value.is_null(i) || array.is_null(i) { + return None; + } + Some(cmp(i, i).is_eq()) + }) + .collect::() + } } ColumnarValue::Scalar(scalar) => { // Check if scalar is null once, before the loop if scalar.is_null() { // If scalar is null, all comparisons return null BooleanArray::from(vec![None; num_rows]) + } else if lhs_supports_arrow_eq { + let scalar_datum = scalar.to_scalar()?; + arrow_eq(&value, &scalar_datum)? } else { // Convert scalar to 1-element array let array = scalar.to_array()?; @@ -3507,4 +3543,536 @@ mod tests { Ok(()) } + + /// Helper: creates an InListExpr with `static_filter = None` + /// to force the column-reference evaluation path. 
+ fn make_in_list_with_columns( + expr: Arc, + list: Vec>, + negated: bool, + ) -> Arc { + Arc::new(InListExpr::new(expr, list, negated, None)) + } + + #[test] + fn test_in_list_with_columns_int32_scalars() -> Result<()> { + // Column-reference path with scalar literals (bypassing static filter) + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(Int32Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + ]))], + )?; + + let list = vec![ + lit(ScalarValue::Int32(Some(1))), + lit(ScalarValue::Int32(Some(3))), + ]; + let expr = make_in_list_with_columns(col_a, list, false); + + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + assert_eq!( + result, + &BooleanArray::from(vec![Some(true), Some(false), Some(true), None,]) + ); + Ok(()) + } + + #[test] + fn test_in_list_with_columns_int32_column_refs() -> Result<()> { + // IN list with column references + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + ]); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![ + Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3), None])), + Arc::new(Int32Array::from(vec![ + Some(1), + Some(99), + Some(99), + Some(99), + ])), + Arc::new(Int32Array::from(vec![Some(99), Some(99), Some(3), None])), + ], + )?; + + let col_a = col("a", &schema)?; + let list = vec![col("b", &schema)?, col("c", &schema)?]; + let expr = make_in_list_with_columns(col_a, list, false); + + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + // row 0: 1 IN (1, 99) → true + // row 1: 2 IN (99, 99) → false + // row 2: 3 IN (99, 3) → true + // row 3: NULL IN (99, NULL) → NULL + assert_eq!( + result, + &BooleanArray::from(vec![Some(true), 
Some(false), Some(true), None,]) + ); + Ok(()) + } + + #[test] + fn test_in_list_with_columns_utf8_column_refs() -> Result<()> { + // IN list with Utf8 column references + let schema = Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + ]); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![ + Arc::new(StringArray::from(vec!["x", "y", "z"])), + Arc::new(StringArray::from(vec!["x", "x", "z"])), + ], + )?; + + let col_a = col("a", &schema)?; + let list = vec![col("b", &schema)?]; + let expr = make_in_list_with_columns(col_a, list, false); + + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + // row 0: "x" IN ("x") → true + // row 1: "y" IN ("x") → false + // row 2: "z" IN ("z") → true + assert_eq!(result, &BooleanArray::from(vec![true, false, true])); + Ok(()) + } + + #[test] + fn test_in_list_with_columns_negated() -> Result<()> { + // NOT IN with column references + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(Int32Array::from(vec![1, 99, 3])), + ], + )?; + + let col_a = col("a", &schema)?; + let list = vec![col("b", &schema)?]; + let expr = make_in_list_with_columns(col_a, list, true); + + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + // row 0: 1 NOT IN (1) → false + // row 1: 2 NOT IN (99) → true + // row 2: 3 NOT IN (3) → false + assert_eq!(result, &BooleanArray::from(vec![false, true, false])); + Ok(()) + } + + #[test] + fn test_in_list_with_columns_null_in_list() -> Result<()> { + // IN list with NULL scalar (column-reference path) + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let col_a = col("a", &schema)?; + let batch = 
RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(Int32Array::from(vec![1, 2]))], + )?; + + let list = vec![ + lit(ScalarValue::Int32(None)), + lit(ScalarValue::Int32(Some(1))), + ]; + let expr = make_in_list_with_columns(col_a, list, false); + + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + // row 0: 1 IN (NULL, 1) → true (true OR null = true) + // row 1: 2 IN (NULL, 1) → NULL (false OR null = null) + assert_eq!(result, &BooleanArray::from(vec![Some(true), None])); + Ok(()) + } + + #[test] + fn test_in_list_with_columns_float_nan() -> Result<()> { + // Verify NaN == NaN is true in the column-reference path + // (consistent with Arrow's totalOrder semantics) + let schema = Schema::new(vec![ + Field::new("a", DataType::Float64, false), + Field::new("b", DataType::Float64, false), + ]); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![ + Arc::new(Float64Array::from(vec![f64::NAN, 1.0, f64::NAN])), + Arc::new(Float64Array::from(vec![f64::NAN, 2.0, 0.0])), + ], + )?; + + let col_a = col("a", &schema)?; + let list = vec![col("b", &schema)?]; + let expr = make_in_list_with_columns(col_a, list, false); + + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + // row 0: NaN IN (NaN) → true + // row 1: 1.0 IN (2.0) → false + // row 2: NaN IN (0.0) → false + assert_eq!(result, &BooleanArray::from(vec![true, false, false])); + Ok(()) + } + /// Tests that short-circuit evaluation produces correct results. + /// When all rows match after the first list item, remaining items + /// should be skipped without affecting correctness. 
+ #[test] + fn test_in_list_with_columns_short_circuit() -> Result<()> { + // a IN (b, c) where b already matches every row of a + // The short-circuit should skip evaluating c + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ]); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(Int32Array::from(vec![1, 2, 3])), // b == a for all rows + Arc::new(Int32Array::from(vec![99, 99, 99])), + ], + )?; + + let col_a = col("a", &schema)?; + let list = vec![col("b", &schema)?, col("c", &schema)?]; + let expr = make_in_list_with_columns(col_a, list, false); + + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + assert_eq!(result, &BooleanArray::from(vec![true, true, true])); + Ok(()) + } + + /// Short-circuit must NOT skip when nulls are present (three-valued logic). + /// Even if all non-null values are true, null rows keep the result as null. 
+ #[test] + fn test_in_list_with_columns_short_circuit_with_nulls() -> Result<()> { + // a IN (b, c) where a has nulls + // Even if b matches all non-null rows, result should preserve nulls + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ]); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![ + Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])), + Arc::new(Int32Array::from(vec![1, 2, 3])), // matches non-null rows + Arc::new(Int32Array::from(vec![99, 99, 99])), + ], + )?; + + let col_a = col("a", &schema)?; + let list = vec![col("b", &schema)?, col("c", &schema)?]; + let expr = make_in_list_with_columns(col_a, list, false); + + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + // row 0: 1 IN (1, 99) → true + // row 1: NULL IN (2, 99) → NULL + // row 2: 3 IN (3, 99) → true + assert_eq!( + result, + &BooleanArray::from(vec![Some(true), None, Some(true)]) + ); + Ok(()) + } + + /// Tests the make_comparator + collect_bool fallback path using + /// struct column references (nested types don't support arrow_eq). 
+ #[test] + fn test_in_list_with_columns_struct() -> Result<()> { + let struct_fields = Fields::from(vec![ + Field::new("x", DataType::Int32, false), + Field::new("y", DataType::Utf8, false), + ]); + let struct_dt = DataType::Struct(struct_fields.clone()); + + let schema = Schema::new(vec![ + Field::new("a", struct_dt.clone(), true), + Field::new("b", struct_dt.clone(), false), + Field::new("c", struct_dt.clone(), false), + ]); + + // a: [{1,"a"}, {2,"b"}, NULL, {4,"d"}] + // b: [{1,"a"}, {9,"z"}, {3,"c"}, {4,"d"}] + // c: [{9,"z"}, {2,"b"}, {9,"z"}, {9,"z"}] + let a = Arc::new(StructArray::new( + struct_fields.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3, 4])), + Arc::new(StringArray::from(vec!["a", "b", "c", "d"])), + ], + Some(vec![true, true, false, true].into()), + )); + let b = Arc::new(StructArray::new( + struct_fields.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 9, 3, 4])), + Arc::new(StringArray::from(vec!["a", "z", "c", "d"])), + ], + None, + )); + let c = Arc::new(StructArray::new( + struct_fields.clone(), + vec![ + Arc::new(Int32Array::from(vec![9, 2, 9, 9])), + Arc::new(StringArray::from(vec!["z", "b", "z", "z"])), + ], + None, + )); + + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a, b, c])?; + + let col_a = col("a", &schema)?; + let list = vec![col("b", &schema)?, col("c", &schema)?]; + let expr = make_in_list_with_columns(col_a, list, false); + + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + // row 0: {1,"a"} IN ({1,"a"}, {9,"z"}) → true (matches b) + // row 1: {2,"b"} IN ({9,"z"}, {2,"b"}) → true (matches c) + // row 2: NULL IN ({3,"c"}, {9,"z"}) → NULL + // row 3: {4,"d"} IN ({4,"d"}, {9,"z"}) → true (matches b) + assert_eq!( + result, + &BooleanArray::from(vec![Some(true), Some(true), None, Some(true)]) + ); + + // Also test NOT IN + let col_a = col("a", &schema)?; + let list = vec![col("b", &schema)?, col("c", &schema)?]; + let expr = 
make_in_list_with_columns(col_a, list, true); + + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + // row 0: {1,"a"} NOT IN ({1,"a"}, {9,"z"}) → false + // row 1: {2,"b"} NOT IN ({9,"z"}, {2,"b"}) → false + // row 2: NULL NOT IN ({3,"c"}, {9,"z"}) → NULL + // row 3: {4,"d"} NOT IN ({4,"d"}, {9,"z"}) → false + assert_eq!( + result, + &BooleanArray::from(vec![Some(false), Some(false), None, Some(false)]) + ); + Ok(()) + } + + // ----------------------------------------------------------------------- + // Tests for try_new_from_array: evaluates `needle IN in_array`. + // + // This exercises the code path used by HashJoin dynamic filter pushdown, + // where in_array is built directly from the join's build-side arrays. + // Unlike try_new (used by SQL IN expressions), which always produces a + // non-Dictionary in_array because evaluate_list() flattens Dictionary + // scalars, try_new_from_array passes the array directly and can produce + // a Dictionary in_array. + // ----------------------------------------------------------------------- + + fn wrap_in_dict(array: ArrayRef) -> ArrayRef { + let keys = Int32Array::from((0..array.len() as i32).collect::>()); + Arc::new(DictionaryArray::new(keys, array)) + } + + /// Evaluates `needle IN in_array` via try_new_from_array, the same + /// path used by HashJoin dynamic filter pushdown (not the SQL literal + /// IN path which goes through try_new). + fn eval_in_list_from_array( + needle: ArrayRef, + in_array: ArrayRef, + ) -> Result { + let schema = + Schema::new(vec![Field::new("a", needle.data_type().clone(), false)]); + let col_a = col("a", &schema)?; + let expr = Arc::new(InListExpr::try_new_from_array(col_a, in_array, false)?) 
+ as Arc<dyn PhysicalExpr>; + let batch = RecordBatch::try_new(Arc::new(schema), vec![needle])?; + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + Ok(as_boolean_array(&result).clone()) + } + + #[test] + fn test_in_list_from_array_type_combinations() -> Result<()> { + use arrow::compute::cast; + + // All cases: needle[0] and needle[2] match, needle[1] does not. + let expected = BooleanArray::from(vec![Some(true), Some(false), Some(true)]); + + // Base arrays cast to each target type + let base_in = Arc::new(Int64Array::from(vec![1i64, 2, 3])) as ArrayRef; + let base_needle = Arc::new(Int64Array::from(vec![1i64, 4, 2])) as ArrayRef; + + // Test all specializations in instantiate_static_filter + let primitive_types = vec![ + DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, + DataType::UInt8, + DataType::UInt16, + DataType::UInt32, + DataType::UInt64, + DataType::Float32, + DataType::Float64, + ]; + + for dt in &primitive_types { + let in_array = cast(&base_in, dt)?; + let needle = cast(&base_needle, dt)?; + + // T in_array, T needle + assert_eq!( + expected, + eval_in_list_from_array(Arc::clone(&needle), Arc::clone(&in_array))?, + "same-type failed for {dt:?}" + ); + + // T in_array, Dict(Int32, T) needle + assert_eq!( + expected, + eval_in_list_from_array(wrap_in_dict(needle), in_array)?, + "dict-needle failed for {dt:?}" + ); + } + + // Utf8 (falls through to ArrayStaticFilter) + let utf8_in = Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef; + let utf8_needle = Arc::new(StringArray::from(vec!["a", "d", "b"])) as ArrayRef; + + // Utf8 in_array, Utf8 needle + assert_eq!( + expected, + eval_in_list_from_array(Arc::clone(&utf8_needle), Arc::clone(&utf8_in),)? + ); + + // Utf8 in_array, Dict(Utf8) needle + assert_eq!( + expected, + eval_in_list_from_array( + wrap_in_dict(Arc::clone(&utf8_needle)), + Arc::clone(&utf8_in), + )?
+ ); + + // Dict(Utf8) in_array, Dict(Utf8) needle: the #20937 bug + assert_eq!( + expected, + eval_in_list_from_array( + wrap_in_dict(Arc::clone(&utf8_needle)), + wrap_in_dict(Arc::clone(&utf8_in)), + )? + ); + + // Struct in_array, Struct needle: multi-column join + let struct_fields = Fields::from(vec![ + Field::new("c0", DataType::Utf8, true), + Field::new("c1", DataType::Int64, true), + ]); + let make_struct = |c0: ArrayRef, c1: ArrayRef| -> ArrayRef { + let pairs: Vec<(FieldRef, ArrayRef)> = + struct_fields.iter().cloned().zip([c0, c1]).collect(); + Arc::new(StructArray::from(pairs)) + }; + assert_eq!( + expected, + eval_in_list_from_array( + make_struct( + Arc::clone(&utf8_needle), + Arc::new(Int64Array::from(vec![1, 4, 2])), + ), + make_struct( + Arc::clone(&utf8_in), + Arc::new(Int64Array::from(vec![1, 2, 3])), + ), + )? + ); + + // Struct with Dict fields: multi-column Dict join + let dict_struct_fields = Fields::from(vec![ + Field::new( + "c0", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + ), + Field::new("c1", DataType::Int64, true), + ]); + let make_dict_struct = |c0: ArrayRef, c1: ArrayRef| -> ArrayRef { + let pairs: Vec<(FieldRef, ArrayRef)> = + dict_struct_fields.iter().cloned().zip([c0, c1]).collect(); + Arc::new(StructArray::from(pairs)) + }; + assert_eq!( + expected, + eval_in_list_from_array( + make_dict_struct( + wrap_in_dict(Arc::clone(&utf8_needle)), + Arc::new(Int64Array::from(vec![1, 4, 2])), + ), + make_dict_struct( + wrap_in_dict(Arc::clone(&utf8_in)), + Arc::new(Int64Array::from(vec![1, 2, 3])), + ), + )? 
+ ); + + Ok(()) + } + + #[test] + fn test_in_list_from_array_type_mismatch_errors() -> Result<()> { + // Utf8 needle, Dict(Utf8) in_array + let err = eval_in_list_from_array( + Arc::new(StringArray::from(vec!["a", "d", "b"])), + wrap_in_dict(Arc::new(StringArray::from(vec!["a", "b", "c"]))), + ) + .unwrap_err() + .to_string(); + assert!( + err.contains("Can't compare arrays of different types"), + "{err}" + ); + + // Dict(Utf8) needle, Int64 in_array: specialized Int64StaticFilter + // rejects the Utf8 dictionary values at construction time + let err = eval_in_list_from_array( + wrap_in_dict(Arc::new(StringArray::from(vec!["a", "d", "b"]))), + Arc::new(Int64Array::from(vec![1, 2, 3])), + ) + .unwrap_err() + .to_string(); + assert!(err.contains("Failed to downcast"), "{err}"); + + // Dict(Int64) needle, Dict(Utf8) in_array: both Dict but different + // value types, make_comparator rejects the comparison + let err = eval_in_list_from_array( + wrap_in_dict(Arc::new(Int64Array::from(vec![1, 4, 2]))), + wrap_in_dict(Arc::new(StringArray::from(vec!["a", "b", "c"]))), + ) + .unwrap_err() + .to_string(); + assert!( + err.contains("Can't compare arrays of different types"), + "{err}" + ); + Ok(()) + } } diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs index 1f3fefc60b7a..9105297c96d6 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -33,6 +33,7 @@ use datafusion_common::{Result, ScalarValue}; use datafusion_expr::Expr; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::interval_arithmetic::Interval; +use datafusion_expr_common::placement::ExpressionPlacement; use datafusion_expr_common::sort_properties::{ExprProperties, SortProperties}; /// Represents a literal value @@ -134,6 +135,10 @@ impl PhysicalExpr for Literal { fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 
std::fmt::Display::fmt(self, f) } + + fn placement(&self) -> ExpressionPlacement { + ExpressionPlacement::Literal + } } /// Create a literal expression diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index 0c9476bebaaf..c727c8fa5f77 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ b/datafusion/physical-expr/src/expressions/negative.rs @@ -37,7 +37,7 @@ use datafusion_expr::statistics::Distribution::{ }; use datafusion_expr::{ ColumnarValue, - type_coercion::{is_interval, is_null, is_signed_numeric, is_timestamp}, + type_coercion::{is_interval, is_signed_numeric, is_timestamp}, }; /// Negative expression @@ -190,7 +190,7 @@ pub fn negative( input_schema: &Schema, ) -> Result> { let data_type = arg.data_type(input_schema)?; - if is_null(&data_type) { + if data_type.is_null() { Ok(arg) } else if !is_signed_numeric(&data_type) && !is_interval(&data_type) diff --git a/datafusion/physical-expr/src/expressions/try_cast.rs b/datafusion/physical-expr/src/expressions/try_cast.rs index c9ace3239c64..306f14b48fa3 100644 --- a/datafusion/physical-expr/src/expressions/try_cast.rs +++ b/datafusion/physical-expr/src/expressions/try_cast.rs @@ -72,7 +72,7 @@ impl TryCastExpr { impl fmt::Display for TryCastExpr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "TRY_CAST({} AS {:?})", self.expr, self.cast_type) + write!(f, "TRY_CAST({} AS {})", self.expr, self.cast_type) } } @@ -180,7 +180,7 @@ mod tests { // verify that its display is correct assert_eq!( - format!("TRY_CAST(a@0 AS {:?})", $TYPE), + format!("TRY_CAST(a@0 AS {})", $TYPE), format!("{}", expression) ); @@ -206,7 +206,7 @@ mod tests { for (i, x) in $VEC.iter().enumerate() { match x { Some(x) => assert_eq!(result.value(i), *x), - None => assert!(!result.is_valid(i)), + None => assert!(result.is_null(i)), } } }}; @@ -231,7 +231,7 @@ mod tests { // verify that its display is correct assert_eq!( - 
format!("TRY_CAST(a@0 AS {:?})", $TYPE), + format!("TRY_CAST(a@0 AS {})", $TYPE), format!("{}", expression) ); @@ -260,7 +260,7 @@ mod tests { for (i, x) in $VEC.iter().enumerate() { match x { Some(x) => assert_eq!(result.value(i), *x), - None => assert!(!result.is_valid(i)), + None => assert!(result.is_null(i)), } } }}; diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 988e14c28e17..bedd348dab92 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -24,8 +24,6 @@ // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] -// https://github.com/apache/datafusion/issues/18881 -#![deny(clippy::allow_attributes)] // Backward compatibility pub mod aggregate; diff --git a/datafusion/physical-expr/src/partitioning.rs b/datafusion/physical-expr/src/partitioning.rs index 54e1cd3675d1..d24c60b63e6b 100644 --- a/datafusion/physical-expr/src/partitioning.rs +++ b/datafusion/physical-expr/src/partitioning.rs @@ -157,7 +157,7 @@ impl PartitioningSatisfaction { } pub fn is_subset(&self) -> bool { - matches!(self, Self::Subset) + *self == Self::Subset } } diff --git a/datafusion/physical-expr/src/projection.rs b/datafusion/physical-expr/src/projection.rs index 540fd620c92c..dbbd28941527 100644 --- a/datafusion/physical-expr/src/projection.rs +++ b/datafusion/physical-expr/src/projection.rs @@ -29,7 +29,8 @@ use arrow::datatypes::{Field, Schema, SchemaRef}; use datafusion_common::stats::{ColumnStatistics, Precision}; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{ - Result, ScalarValue, assert_or_internal_err, internal_datafusion_err, plan_err, + Result, ScalarValue, Statistics, assert_or_internal_err, internal_datafusion_err, + plan_err, }; use datafusion_physical_expr_common::metrics::ExecutionPlanMetricsSet; @@ -125,7 +126,8 @@ impl From for (Arc, String) { 
/// indices. #[derive(Debug, Clone, PartialEq, Eq)] pub struct ProjectionExprs { - exprs: Vec, + /// [`Arc`] used for a cheap clone, which improves physical plan optimization performance. + exprs: Arc<[ProjectionExpr]>, } impl std::fmt::Display for ProjectionExprs { @@ -137,14 +139,16 @@ impl std::fmt::Display for ProjectionExprs { impl From> for ProjectionExprs { fn from(value: Vec) -> Self { - Self { exprs: value } + Self { + exprs: value.into(), + } } } impl From<&[ProjectionExpr]> for ProjectionExprs { fn from(value: &[ProjectionExpr]) -> Self { Self { - exprs: value.to_vec(), + exprs: value.iter().cloned().collect(), } } } @@ -152,7 +156,7 @@ impl From<&[ProjectionExpr]> for ProjectionExprs { impl FromIterator for ProjectionExprs { fn from_iter>(exprs: T) -> Self { Self { - exprs: exprs.into_iter().collect::>(), + exprs: exprs.into_iter().collect(), } } } @@ -164,12 +168,17 @@ impl AsRef<[ProjectionExpr]> for ProjectionExprs { } impl ProjectionExprs { - pub fn new(exprs: I) -> Self - where - I: IntoIterator, - { + /// Make a new [`ProjectionExprs`] from expressions iterator. + pub fn new(exprs: impl IntoIterator) -> Self { Self { - exprs: exprs.into_iter().collect::>(), + exprs: exprs.into_iter().collect(), + } + } + + /// Make a new [`ProjectionExprs`] from expressions. + pub fn from_expressions(exprs: impl Into>) -> Self { + Self { + exprs: exprs.into(), } } @@ -285,13 +294,14 @@ impl ProjectionExprs { { let exprs = self .exprs - .into_iter() + .iter() + .cloned() .map(|mut proj| { proj.expr = f(proj.expr)?; Ok(proj) }) - .collect::>>()?; - Ok(Self::new(exprs)) + .collect::>>()?; + Ok(Self::from_expressions(exprs)) } /// Apply another projection on top of this projection, returning the combined projection. @@ -361,17 +371,9 @@ impl ProjectionExprs { /// applied on top of this projection. 
pub fn try_merge(&self, other: &ProjectionExprs) -> Result { let mut new_exprs = Vec::with_capacity(other.exprs.len()); - for proj_expr in &other.exprs { - let new_expr = update_expr(&proj_expr.expr, &self.exprs, true)? - .ok_or_else(|| { - internal_datafusion_err!( - "Failed to combine projections: expression {} could not be applied on top of existing projections {}", - proj_expr.expr, - self.exprs.iter().map(|e| format!("{e}")).join(", ") - ) - })?; + for proj_expr in other.exprs.iter() { new_exprs.push(ProjectionExpr { - expr: new_expr, + expr: self.unproject_expr(&proj_expr.expr)?, alias: proj_expr.alias.clone(), }); } @@ -440,9 +442,16 @@ impl ProjectionExprs { } /// Project a schema according to this projection. - /// For example, for a projection `SELECT a AS x, b + 1 AS y`, where `a` is at index 0 and `b` is at index 1, - /// if the input schema is `[a: Int32, b: Int32, c: Int32]`, the output schema would be `[x: Int32, y: Int32]`. - /// Fields' metadata are preserved from the input schema. + /// + /// For example, given a projection: + /// * `SELECT a AS x, b + 1 AS y` + /// * where `a` is at index 0 + /// * `b` is at index 1 + /// + /// If the input schema is `[a: Int32, b: Int32, c: Int32]`, the output + /// schema would be `[x: Int32, y: Int32]`. + /// + /// Note that [`Field`] metadata are preserved from the input schema. pub fn project_schema(&self, input_schema: &Schema) -> Result { let fields: Result> = self .exprs @@ -471,6 +480,48 @@ impl ProjectionExprs { )) } + /// "unproject" an expression by applying this projection in reverse, + /// returning a new set of expressions that reference the original input + /// columns. 
+ /// + /// For example, consider + /// * an expression `c1_c2 > 5`, and a schema `[c1, c2]` + /// * a projection `c1 + c2 as c1_c2` + /// + /// This method would rewrite the expression to `c1 + c2 > 5` + pub fn unproject_expr( + &self, + expr: &Arc<dyn PhysicalExpr>, + ) -> Result<Arc<dyn PhysicalExpr>> { + update_expr(expr, &self.exprs, true)?.ok_or_else(|| { + internal_datafusion_err!( + "Failed to unproject an expression {} with ProjectionExprs {}", + expr, + self.exprs.iter().map(|e| format!("{e}")).join(", ") + ) + }) + } + + /// "project" an expression using these projection's expressions + /// + /// For example, consider + /// * an expression `c1 + c2 > 5`, and a schema `[c1, c2]` + /// * a projection `c1 + c2 as c1_c2` + /// + /// * This method would rewrite the expression to `c1_c2 > 5` + pub fn project_expr( + &self, + expr: &Arc<dyn PhysicalExpr>, + ) -> Result<Arc<dyn PhysicalExpr>> { + update_expr(expr, &self.exprs, false)?.ok_or_else(|| { + internal_datafusion_err!( + "Failed to project an expression {} with ProjectionExprs {}", + expr, + self.exprs.iter().map(|e| format!("{e}")).join(", ") + ) + }) + } + /// Create a new [`Projector`] from this projection and an input schema. /// /// A [`Projector`] can be used to apply this projection to record batches. @@ -602,12 +653,12 @@ impl ProjectionExprs { /// ``` pub fn project_statistics( &self, - mut stats: datafusion_common::Statistics, + mut stats: Statistics, output_schema: &Schema, - ) -> Result<datafusion_common::Statistics> { + ) -> Result<Statistics> { let mut column_statistics = vec![]; - for proj_expr in &self.exprs { + for proj_expr in self.exprs.iter() { let expr = &proj_expr.expr; let col_stats = if let Some(col) = expr.as_any().downcast_ref::<Column>() { std::mem::take(&mut stats.column_statistics[col.index()]) } else { @@ -754,35 +805,92 @@ impl Projector { } } -impl IntoIterator for ProjectionExprs { - type Item = ProjectionExpr; - type IntoIter = std::vec::IntoIter<ProjectionExpr>; +/// Describes an immutable reference counted projection. +/// +/// This structure represents projecting a set of columns by index.
+/// [`Arc`] is used to make it cheap to clone. +pub type ProjectionRef = Arc<[usize]>; - fn into_iter(self) -> Self::IntoIter { - self.exprs.into_iter() - } +/// Combine two projections. +/// +/// If `p1` is [`None`] then there are no changes. +/// Otherwise, if passed `p2` is not [`None`] then it is remapped +/// according to the `p1`. Otherwise, there are no changes. +/// +/// # Example +/// +/// If stored projection is [0, 2] and we call `apply_projection([0, 2, 3])`, +/// then the resulting projection will be [0, 3]. +/// +/// # Error +/// +/// Returns an internal error if `p1` contains index that is greater than `p2` len. +/// +pub fn combine_projections( + p1: Option<&ProjectionRef>, + p2: Option<&ProjectionRef>, +) -> Result> { + let Some(p1) = p1 else { + return Ok(None); + }; + let Some(p2) = p2 else { + return Ok(Some(Arc::clone(p1))); + }; + + Ok(Some( + p1.iter() + .map(|i| { + let idx = *i; + assert_or_internal_err!( + idx < p2.len(), + "unable to apply projection: index {} is greater than new projection len {}", + idx, + p2.len(), + ); + Ok(p2[*i]) + }) + .collect::>>()?, + )) } -/// The function operates in two modes: +/// The function projects / unprojects an expression with respect to set of +/// projection expressions. +/// +/// See also [`ProjectionExprs::unproject_expr`] and [`ProjectionExprs::project_expr`] +/// +/// 1) When `unproject` is `true`: +/// +/// Rewrites an expression with respect to the projection expressions, +/// effectively "unprojecting" it to reference the original input columns. +/// +/// For example, given +/// * the expressions `a@1 + b@2` and `c@0` +/// * and projection expressions `c@2, a@0, b@1` +/// +/// Then +/// * `a@1 + b@2` becomes `a@0 + b@1` +/// * `c@0` becomes `c@2` +/// +/// 2) When `unproject` is `false`: /// -/// 1) When `sync_with_child` is `true`: +/// Rewrites the expression to reference the projected expressions, +/// effectively "projecting" it. 
The resulting expression will reference the +/// indices as they appear in the projection. /// -/// The function updates the indices of `expr` if the expression resides -/// in the input plan. For instance, given the expressions `a@1 + b@2` -/// and `c@0` with the input schema `c@2, a@0, b@1`, the expressions are -/// updated to `a@0 + b@1` and `c@2`. +/// If the expression cannot be rewritten after the projection, it returns +/// `None`. /// -/// 2) When `sync_with_child` is `false`: +/// For example, given +/// * the expressions `c@0`, `a@1` and `b@2` +/// * the projection `a@1 as a, c@0 as c_new`, /// -/// The function determines how the expression would be updated if a projection -/// was placed before the plan associated with the expression. If the expression -/// cannot be rewritten after the projection, it returns `None`. For example, -/// given the expressions `c@0`, `a@1` and `b@2`, and the projection with -/// an output schema of `a, c_new`, then `c@0` becomes `c_new@1`, `a@1` becomes -/// `a@0`, but `b@2` results in `None` since the projection does not include `b`. +/// Then +/// * `c@0` becomes `c_new@1` +/// * `a@1` becomes `a@0` +/// * `b@2` results in `None` since the projection does not include `b`. /// /// # Errors -/// This function returns an error if `sync_with_child` is `true` and if any expression references +/// This function returns an error if `unproject` is `true` and if any expression references /// an index that is out of bounds for `projected_exprs`. 
/// For example: /// @@ -793,7 +901,7 @@ impl IntoIterator for ProjectionExprs { pub fn update_expr( expr: &Arc, projected_exprs: &[ProjectionExpr], - sync_with_child: bool, + unproject: bool, ) -> Result>> { #[derive(Debug, PartialEq)] enum RewriteState { @@ -817,7 +925,7 @@ pub fn update_expr( let Some(column) = expr.as_any().downcast_ref::() else { return Ok(Transformed::no(expr)); }; - if sync_with_child { + if unproject { state = RewriteState::RewrittenValid; // Update the index of `column`: let projected_expr = projected_exprs.get(column.index()).ok_or_else(|| { diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index aa090743ad44..dab4153fa682 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -45,8 +45,8 @@ use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::type_coercion::functions::fields_with_udf; use datafusion_expr::{ - ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, Volatility, - expr_vec_fmt, + ColumnarValue, ExpressionPlacement, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, + Volatility, expr_vec_fmt, }; /// Physical expression of a scalar function @@ -362,6 +362,12 @@ impl PhysicalExpr for ScalarFunctionExpr { fn is_volatile_node(&self) -> bool { self.fun.signature().volatility == Volatility::Volatile } + + fn placement(&self) -> ExpressionPlacement { + let arg_placements: Vec<_> = + self.args.iter().map(|arg| arg.placement()).collect(); + self.fun.placement(&arg_placements) + } } #[cfg(test)] diff --git a/datafusion/physical-expr/src/simplifier/const_evaluator.rs b/datafusion/physical-expr/src/simplifier/const_evaluator.rs index 65111b291165..1f3781c537dd 100644 --- a/datafusion/physical-expr/src/simplifier/const_evaluator.rs +++ b/datafusion/physical-expr/src/simplifier/const_evaluator.rs @@ -25,7 +25,6 @@ use 
arrow::record_batch::RecordBatch; use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr_common::columnar_value::ColumnarValue; -use datafusion_physical_expr_common::physical_expr::is_volatile; use crate::PhysicalExpr; use crate::expressions::{Column, Literal}; @@ -40,15 +39,18 @@ use crate::expressions::{Column, Literal}; /// - `1 + 2` -> `3` /// - `(1 + 2) * 3` -> `9` (with bottom-up traversal) /// - `'hello' || ' world'` -> `'hello world'` +#[deprecated( + since = "53.0.0", + note = "This function will be removed in a future release in favor of a private implementation that depends on other implementation details. Please open an issue if you have a use case for keeping it." +)] pub fn simplify_const_expr( - expr: &Arc, + expr: Arc, ) -> Result>> { - if is_volatile(expr) || has_column_references(expr) { - return Ok(Transformed::no(Arc::clone(expr))); - } - - // Create a 1-row dummy batch for evaluation let batch = create_dummy_batch()?; + // If expr is already a const literal or can't be evaluated into one. + if expr.as_any().is::() || (!can_evaluate_as_constant(&expr)) { + return Ok(Transformed::no(expr)); + } // Evaluate the expression match expr.evaluate(&batch) { @@ -62,13 +64,77 @@ pub fn simplify_const_expr( } Ok(_) => { // Unexpected result - keep original expression - Ok(Transformed::no(Arc::clone(expr))) + Ok(Transformed::no(expr)) + } + Err(_) => { + // On error, keep original expression + // The expression might succeed at runtime due to short-circuit evaluation + // or other runtime conditions + Ok(Transformed::no(expr)) + } + } +} + +/// Simplify expressions whose immediate children are all literals. +/// +/// This function only checks the direct children of the expression, +/// not the entire subtree. It is designed to be used with bottom-up tree +/// traversal, where children are simplified before parents. 
+/// +/// # Example transformations +/// - `1 + 2` -> `3` +/// - `(1 + 2) * 3` -> `9` (with bottom-up traversal, inner expr simplified first) +/// - `'hello' || ' world'` -> `'hello world'` +pub(crate) fn simplify_const_expr_immediate( + expr: Arc, + batch: &RecordBatch, +) -> Result>> { + // Already a literal - nothing to do + if expr.as_any().is::() { + return Ok(Transformed::no(expr)); + } + + // Column references cannot be evaluated at plan time + if expr.as_any().is::() { + return Ok(Transformed::no(expr)); + } + + // Volatile nodes cannot be evaluated at plan time + if expr.is_volatile_node() { + return Ok(Transformed::no(expr)); + } + + // Since transform visits bottom-up, children have already been simplified. + // If all children are now Literals, this node can be const-evaluated. + // This is O(k) where k = number of children, instead of O(subtree). + let all_children_literal = expr + .children() + .iter() + .all(|child| child.as_any().is::()); + + if !all_children_literal { + return Ok(Transformed::no(expr)); + } + + // Evaluate the expression + match expr.evaluate(batch) { + Ok(ColumnarValue::Scalar(scalar)) => { + Ok(Transformed::yes(Arc::new(Literal::new(scalar)))) + } + Ok(ColumnarValue::Array(arr)) if arr.len() == 1 => { + // Some operations return an array even for scalar inputs + let scalar = ScalarValue::try_from_array(&arr, 0)?; + Ok(Transformed::yes(Arc::new(Literal::new(scalar)))) + } + Ok(_) => { + // Unexpected result - keep original expression + Ok(Transformed::no(expr)) } Err(_) => { // On error, keep original expression // The expression might succeed at runtime due to short-circuit evaluation // or other runtime conditions - Ok(Transformed::no(Arc::clone(expr))) + Ok(Transformed::no(expr)) } } } @@ -80,14 +146,34 @@ pub fn simplify_const_expr( /// that only contain literals, the batch content is irrelevant. /// /// This is the same approach used in the logical expression `ConstEvaluator`. 
-fn create_dummy_batch() -> Result { +pub(crate) fn create_dummy_batch() -> Result { // RecordBatch requires at least one column let dummy_schema = Arc::new(Schema::new(vec![Field::new("_", DataType::Null, true)])); let col = new_null_array(&DataType::Null, 1); Ok(RecordBatch::try_new(dummy_schema, vec![col])?) } +fn can_evaluate_as_constant(expr: &Arc) -> bool { + let mut can_evaluate = true; + + expr.apply(|e| { + if e.as_any().is::() || e.is_volatile_node() { + can_evaluate = false; + Ok(TreeNodeRecursion::Stop) + } else { + Ok(TreeNodeRecursion::Continue) + } + }) + .expect("apply should not fail"); + + can_evaluate +} + /// Check if this expression has any column references. +#[deprecated( + since = "53.0.0", + note = "This function isn't used internally and is trivial to implement, therefore it will be removed in a future release." +)] pub fn has_column_references(expr: &Arc) -> bool { let mut has_columns = false; expr.apply(|expr| { diff --git a/datafusion/physical-expr/src/simplifier/mod.rs b/datafusion/physical-expr/src/simplifier/mod.rs index 97395f4fe8a2..3f3f8573449e 100644 --- a/datafusion/physical-expr/src/simplifier/mod.rs +++ b/datafusion/physical-expr/src/simplifier/mod.rs @@ -21,7 +21,12 @@ use arrow::datatypes::Schema; use datafusion_common::{Result, tree_node::TreeNode}; use std::sync::Arc; -use crate::{PhysicalExpr, simplifier::not::simplify_not_expr}; +use crate::{ + PhysicalExpr, + simplifier::{ + const_evaluator::create_dummy_batch, unwrap_cast::unwrap_cast_in_comparison, + }, +}; pub mod const_evaluator; pub mod not; @@ -50,21 +55,24 @@ impl<'a> PhysicalExprSimplifier<'a> { let mut count = 0; let schema = self.schema; + let batch = create_dummy_batch()?; + while count < MAX_LOOP_COUNT { count += 1; let result = current_expr.transform(|node| { - #[cfg(test)] + #[cfg(debug_assertions)] let original_type = node.data_type(schema).unwrap(); // Apply NOT expression simplification first, then unwrap cast optimization, // then constant expression 
evaluation - let rewritten = simplify_not_expr(&node, schema)? + #[expect(deprecated, reason = "`simplify_not_expr` is marked as deprecated until it's made private.")] + let rewritten = not::simplify_not_expr(node, schema)? + .transform_data(|node| unwrap_cast_in_comparison(node, schema))? .transform_data(|node| { - unwrap_cast::unwrap_cast_in_comparison(node, schema) - })? - .transform_data(|node| const_evaluator::simplify_const_expr(&node))?; + const_evaluator::simplify_const_expr_immediate(node, &batch) + })?; - #[cfg(test)] + #[cfg(debug_assertions)] assert_eq!( rewritten.data.data_type(schema).unwrap(), original_type, diff --git a/datafusion/physical-expr/src/simplifier/not.rs b/datafusion/physical-expr/src/simplifier/not.rs index 9b65d5cba95a..709260aa4879 100644 --- a/datafusion/physical-expr/src/simplifier/not.rs +++ b/datafusion/physical-expr/src/simplifier/not.rs @@ -43,14 +43,18 @@ use crate::expressions::{BinaryExpr, InListExpr, Literal, NotExpr, in_list, lit} /// This function applies a single simplification rule and returns. When used with /// TreeNodeRewriter, multiple passes will automatically be applied until no more /// transformations are possible. +#[deprecated( + since = "53.0.0", + note = "This function will be made private in a future release, please file an issue if you have a reason for keeping it public." 
+)] pub fn simplify_not_expr( - expr: &Arc, + expr: Arc, schema: &Schema, ) -> Result>> { // Check if this is a NOT expression let not_expr = match expr.as_any().downcast_ref::() { Some(not_expr) => not_expr, - None => return Ok(Transformed::no(Arc::clone(expr))), + None => return Ok(Transformed::no(expr)), }; let inner_expr = not_expr.arg(); @@ -120,5 +124,5 @@ pub fn simplify_not_expr( } // If no simplification possible, return the original expression - Ok(Transformed::no(Arc::clone(expr))) + Ok(Transformed::no(expr)) } diff --git a/datafusion/physical-expr/src/simplifier/unwrap_cast.rs b/datafusion/physical-expr/src/simplifier/unwrap_cast.rs index ae6da9c5e0dc..0de517cd36c8 100644 --- a/datafusion/physical-expr/src/simplifier/unwrap_cast.rs +++ b/datafusion/physical-expr/src/simplifier/unwrap_cast.rs @@ -34,10 +34,7 @@ use std::sync::Arc; use arrow::datatypes::{DataType, Schema}; -use datafusion_common::{ - Result, ScalarValue, - tree_node::{Transformed, TreeNode}, -}; +use datafusion_common::{Result, ScalarValue, tree_node::Transformed}; use datafusion_expr::Operator; use datafusion_expr_common::casts::try_cast_literal_to_type; @@ -49,14 +46,12 @@ pub(crate) fn unwrap_cast_in_comparison( expr: Arc, schema: &Schema, ) -> Result>> { - expr.transform_down(|e| { - if let Some(binary) = e.as_any().downcast_ref::() - && let Some(unwrapped) = try_unwrap_cast_binary(binary, schema)? - { - return Ok(Transformed::yes(unwrapped)); - } - Ok(Transformed::no(e)) - }) + if let Some(binary) = expr.as_any().downcast_ref::() + && let Some(unwrapped) = try_unwrap_cast_binary(binary, schema)? 
+ { + return Ok(Transformed::yes(unwrapped)); + } + Ok(Transformed::no(expr)) } /// Try to unwrap casts in binary expressions @@ -144,7 +139,7 @@ mod tests { use super::*; use crate::expressions::{col, lit}; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::ScalarValue; + use datafusion_common::{ScalarValue, tree_node::TreeNode}; use datafusion_expr::Operator; /// Check if an expression is a cast expression @@ -484,8 +479,10 @@ mod tests { let and_expr = Arc::new(BinaryExpr::new(compare1, Operator::And, compare2)); - // Apply unwrap cast optimization - let result = unwrap_cast_in_comparison(and_expr, &schema).unwrap(); + // Apply unwrap cast optimization recursively + let result = (and_expr as Arc) + .transform_down(|node| unwrap_cast_in_comparison(node, &schema)) + .unwrap(); // Should be transformed assert!(result.transformed); @@ -602,8 +599,10 @@ mod tests { // Create AND expression let and_expr = Arc::new(BinaryExpr::new(c1_binary, Operator::And, c2_binary)); - // Apply unwrap cast optimization - let result = unwrap_cast_in_comparison(and_expr, &schema).unwrap(); + // Apply unwrap cast optimization recursively + let result = (and_expr as Arc) + .transform_down(|node| unwrap_cast_in_comparison(node, &schema)) + .unwrap(); // Should be transformed assert!(result.transformed); diff --git a/datafusion/physical-optimizer/src/aggregate_statistics.rs b/datafusion/physical-optimizer/src/aggregate_statistics.rs index cf3c15509c29..5caee8b047d8 100644 --- a/datafusion/physical-optimizer/src/aggregate_statistics.rs +++ b/datafusion/physical-optimizer/src/aggregate_statistics.rs @@ -20,7 +20,7 @@ use datafusion_common::Result; use datafusion_common::config::ConfigOptions; use datafusion_common::scalar::ScalarValue; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_physical_plan::aggregates::AggregateExec; +use datafusion_physical_plan::aggregates::{AggregateExec, AggregateInputMode}; use 
datafusion_physical_plan::placeholder_row::PlaceholderRowExec; use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr}; use datafusion_physical_plan::udaf::{AggregateFunctionExpr, StatisticsArgs}; @@ -116,13 +116,13 @@ impl PhysicalOptimizerRule for AggregateStatistics { /// the `ExecutionPlan.children()` method that returns an owned reference. fn take_optimizable(node: &dyn ExecutionPlan) -> Option> { if let Some(final_agg_exec) = node.as_any().downcast_ref::() - && !final_agg_exec.mode().is_first_stage() + && final_agg_exec.mode().input_mode() == AggregateInputMode::Partial && final_agg_exec.group_expr().is_empty() { let mut child = Arc::clone(final_agg_exec.input()); loop { if let Some(partial_agg_exec) = child.as_any().downcast_ref::() - && partial_agg_exec.mode().is_first_stage() + && partial_agg_exec.mode().input_mode() == AggregateInputMode::Raw && partial_agg_exec.group_expr().is_empty() && partial_agg_exec.filter_expr().iter().all(|e| e.is_none()) { diff --git a/datafusion/physical-optimizer/src/coalesce_batches.rs b/datafusion/physical-optimizer/src/coalesce_batches.rs deleted file mode 100644 index bedb7f6be049..000000000000 --- a/datafusion/physical-optimizer/src/coalesce_batches.rs +++ /dev/null @@ -1,87 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. 
See the License for the -// specific language governing permissions and limitations -// under the License. - -//! CoalesceBatches optimizer that groups batches together rows -//! in bigger batches to avoid overhead with small batches - -use crate::PhysicalOptimizerRule; - -use std::sync::Arc; - -use datafusion_common::assert_eq_or_internal_err; -use datafusion_common::config::ConfigOptions; -use datafusion_common::error::Result; -use datafusion_physical_plan::{ - ExecutionPlan, async_func::AsyncFuncExec, coalesce_batches::CoalesceBatchesExec, -}; - -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; - -/// Optimizer rule that introduces CoalesceBatchesExec to avoid overhead with small batches that -/// are produced by highly selective filters -#[derive(Default, Debug)] -pub struct CoalesceBatches {} - -impl CoalesceBatches { - #[expect(missing_docs)] - pub fn new() -> Self { - Self::default() - } -} -impl PhysicalOptimizerRule for CoalesceBatches { - fn optimize( - &self, - plan: Arc, - config: &ConfigOptions, - ) -> Result> { - if !config.execution.coalesce_batches { - return Ok(plan); - } - - let target_batch_size = config.execution.batch_size; - plan.transform_up(|plan| { - let plan_any = plan.as_any(); - if let Some(async_exec) = plan_any.downcast_ref::() { - // Coalesce inputs to async functions to reduce number of async function invocations - let children = async_exec.children(); - assert_eq_or_internal_err!( - children.len(), - 1, - "Expected AsyncFuncExec to have exactly one child" - ); - - let coalesce_exec = Arc::new(CoalesceBatchesExec::new( - Arc::clone(children[0]), - target_batch_size, - )); - let new_plan = plan.with_new_children(vec![coalesce_exec])?; - Ok(Transformed::yes(new_plan)) - } else { - Ok(Transformed::no(plan)) - } - }) - .data() - } - - fn name(&self) -> &str { - "coalesce_batches" - } - - fn schema_check(&self) -> bool { - true - } -} diff --git a/datafusion/physical-optimizer/src/combine_partial_final_agg.rs 
b/datafusion/physical-optimizer/src/combine_partial_final_agg.rs index 782e0754b7d2..860406118c1b 100644 --- a/datafusion/physical-optimizer/src/combine_partial_final_agg.rs +++ b/datafusion/physical-optimizer/src/combine_partial_final_agg.rs @@ -72,7 +72,7 @@ impl PhysicalOptimizerRule for CombinePartialFinalAggregate { return Ok(Transformed::no(plan)); }; - let transformed = if matches!(input_agg_exec.mode(), AggregateMode::Partial) + let transformed = if *input_agg_exec.mode() == AggregateMode::Partial && can_combine( ( agg_exec.group_expr(), @@ -98,7 +98,9 @@ impl PhysicalOptimizerRule for CombinePartialFinalAggregate { Arc::clone(input_agg_exec.input()), input_agg_exec.input_schema(), ) - .map(|combined_agg| combined_agg.with_limit(agg_exec.limit())) + .map(|combined_agg| { + combined_agg.with_limit_options(agg_exec.limit_options()) + }) .ok() .map(Arc::new) } else { diff --git a/datafusion/physical-optimizer/src/enforce_distribution.rs b/datafusion/physical-optimizer/src/enforce_distribution.rs index 6120e1f3b582..d23a699f715d 100644 --- a/datafusion/physical-optimizer/src/enforce_distribution.rs +++ b/datafusion/physical-optimizer/src/enforce_distribution.rs @@ -36,7 +36,7 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::error::Result; use datafusion_common::stats::Precision; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_expr::logical_plan::JoinType; +use datafusion_expr::logical_plan::{Aggregate, JoinType}; use datafusion_physical_expr::expressions::{Column, NoOp}; use datafusion_physical_expr::utils::map_columns_before_projection; use datafusion_physical_expr::{ @@ -286,17 +286,15 @@ pub fn adjust_input_keys_ordering( ) -> Result> { let plan = Arc::clone(&requirements.plan); - if let Some(HashJoinExec { - left, - right, - on, - filter, - join_type, - projection, - mode, - null_equality, - .. 
- }) = plan.as_any().downcast_ref::() + if let Some( + exec @ HashJoinExec { + left, + on, + join_type, + mode, + .. + }, + ) = plan.as_any().downcast_ref::() { match mode { PartitionMode::Partitioned => { @@ -304,18 +302,10 @@ pub fn adjust_input_keys_ordering( Vec<(PhysicalExprRef, PhysicalExprRef)>, Vec, )| { - HashJoinExec::try_new( - Arc::clone(left), - Arc::clone(right), - new_conditions.0, - filter.clone(), - join_type, - // TODO: although projection is not used in the join here, because projection pushdown is after enforce_distribution. Maybe we need to handle it later. Same as filter. - projection.clone(), - PartitionMode::Partitioned, - *null_equality, - ) - .map(|e| Arc::new(e) as _) + exec.builder() + .with_partition_mode(PartitionMode::Partitioned) + .with_on(new_conditions.0) + .build_exec() }; return reorder_partitioned_join_keys( requirements, @@ -495,7 +485,7 @@ pub fn reorder_aggregate_keys( && !physical_exprs_equal(&output_exprs, parent_required) && let Some(positions) = expected_expr_positions(&output_exprs, parent_required) && let Some(agg_exec) = agg_exec.input().as_any().downcast_ref::() - && matches!(agg_exec.mode(), &AggregateMode::Partial) + && *agg_exec.mode() == AggregateMode::Partial { let group_exprs = agg_exec.group_expr().expr(); let new_group_exprs = positions @@ -609,19 +599,17 @@ pub fn reorder_join_keys_to_inputs( plan: Arc, ) -> Result> { let plan_any = plan.as_any(); - if let Some(HashJoinExec { - left, - right, - on, - filter, - join_type, - projection, - mode, - null_equality, - .. - }) = plan_any.downcast_ref::() + if let Some( + exec @ HashJoinExec { + left, + right, + on, + mode, + .. 
+ }, + ) = plan_any.downcast_ref::() { - if matches!(mode, PartitionMode::Partitioned) { + if *mode == PartitionMode::Partitioned { let (join_keys, positions) = reorder_current_join_keys( extract_join_keys(on), Some(left.output_partitioning()), @@ -635,16 +623,11 @@ pub fn reorder_join_keys_to_inputs( right_keys, } = join_keys; let new_join_on = new_join_conditions(&left_keys, &right_keys); - return Ok(Arc::new(HashJoinExec::try_new( - Arc::clone(left), - Arc::clone(right), - new_join_on, - filter.clone(), - join_type, - projection.clone(), - PartitionMode::Partitioned, - *null_equality, - )?)); + return exec + .builder() + .with_partition_mode(PartitionMode::Partitioned) + .with_on(new_join_on) + .build_exec(); } } } else if let Some(SortMergeJoinExec { @@ -1256,7 +1239,7 @@ pub fn ensure_distribution( let is_partitioned_join = plan .as_any() .downcast_ref::() - .is_some_and(|join| matches!(join.mode, PartitionMode::Partitioned)) + .is_some_and(|join| join.mode == PartitionMode::Partitioned) || plan.as_any().is::(); let repartition_status_flags = @@ -1297,10 +1280,25 @@ pub fn ensure_distribution( // Allow subset satisfaction when: // 1. Current partition count >= threshold // 2. Not a partitioned join since must use exact hash matching for joins + // 3. Not a grouping set aggregate (requires exact hash including __grouping_id) let current_partitions = child.plan.output_partitioning().partition_count(); + + // Check if the hash partitioning requirement includes __grouping_id column. + // Grouping set aggregates (ROLLUP, CUBE, GROUPING SETS) require exact hash + // partitioning on all group columns including __grouping_id to ensure partial + // aggregates from different partitions are correctly combined. 
+ let requires_grouping_id = matches!(&requirement, Distribution::HashPartitioned(exprs) + if exprs.iter().any(|expr| { + expr.as_any() + .downcast_ref::() + .is_some_and(|col| col.name() == Aggregate::INTERNAL_GROUPING_ID) + }) + ); + let allow_subset_satisfy_partitioning = current_partitions >= subset_satisfaction_threshold - && !is_partitioned_join; + && !is_partitioned_join + && !requires_grouping_id; // When `repartition_file_scans` is set, attempt to increase // parallelism at the source. diff --git a/datafusion/physical-optimizer/src/enforce_sorting/mod.rs b/datafusion/physical-optimizer/src/enforce_sorting/mod.rs index a5fafb9e87e1..247ebb2785dd 100644 --- a/datafusion/physical-optimizer/src/enforce_sorting/mod.rs +++ b/datafusion/physical-optimizer/src/enforce_sorting/mod.rs @@ -581,11 +581,17 @@ fn analyze_immediate_sort_removal( // Remove the sort: node.children = node.children.swap_remove(0).children; if let Some(fetch) = sort_exec.fetch() { + let required_ordering = sort_exec.properties().output_ordering().cloned(); // If the sort has a fetch, we need to add a limit: if properties.output_partitioning().partition_count() == 1 { - Arc::new(GlobalLimitExec::new(Arc::clone(sort_input), 0, Some(fetch))) + let mut global_limit = + GlobalLimitExec::new(Arc::clone(sort_input), 0, Some(fetch)); + global_limit.set_required_ordering(required_ordering); + Arc::new(global_limit) } else { - Arc::new(LocalLimitExec::new(Arc::clone(sort_input), fetch)) + let mut local_limit = LocalLimitExec::new(Arc::clone(sort_input), fetch); + local_limit.set_required_ordering(required_ordering); + Arc::new(local_limit) } } else { Arc::clone(sort_input) diff --git a/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs b/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs index 698fdea8e766..2d9bfe217f40 100644 --- a/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs +++ b/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs 
@@ -35,6 +35,7 @@ use datafusion_physical_expr_common::sort_expr::{ LexOrdering, LexRequirement, OrderingRequirements, PhysicalSortExpr, PhysicalSortRequirement, }; +use datafusion_physical_plan::aggregates::AggregateExec; use datafusion_physical_plan::execution_plan::CardinalityEffect; use datafusion_physical_plan::filter::FilterExec; use datafusion_physical_plan::joins::utils::{ @@ -353,6 +354,8 @@ fn pushdown_requirement_to_children( Ok(None) } } + } else if let Some(aggregate_exec) = plan.as_any().downcast_ref::() { + handle_aggregate_pushdown(aggregate_exec, parent_required) } else if maintains_input_order.is_empty() || !maintains_input_order.iter().any(|o| *o) || plan.as_any().is::() @@ -388,6 +391,77 @@ fn pushdown_requirement_to_children( // TODO: Add support for Projection push down } +/// Try to push sorting through [`AggregateExec`] +/// +/// `AggregateExec` only preserves the input order of its group by columns +/// (not aggregates in general, which are formed from arbitrary expressions over +/// input) +/// +/// Thus function rewrites the parent required ordering in terms of the +/// aggregate input if possible. This rewritten requirement represents the +/// ordering of the `AggregateExec`'s **input** that would also satisfy the +/// **parent** ordering. +/// +/// If no such mapping is possible (e.g. because the sort references aggregate +/// columns), returns None. +fn handle_aggregate_pushdown( + aggregate_exec: &AggregateExec, + parent_required: OrderingRequirements, +) -> Result>>> { + if !aggregate_exec + .maintains_input_order() + .into_iter() + .any(|o| o) + { + return Ok(None); + } + + let group_expr = aggregate_exec.group_expr(); + // GROUPING SETS introduce additional output columns and NULL substitutions; + // skip pushdown until we can map those cases safely. 
+ if group_expr.has_grouping_set() { + return Ok(None); + } + + let group_input_exprs = group_expr.input_exprs(); + let parent_requirement = parent_required.into_single(); + let mut child_requirement = Vec::with_capacity(parent_requirement.len()); + + for req in parent_requirement { + // Sort above AggregateExec should reference its output columns. Map each + // output group-by column to its original input expression. + let Some(column) = req.expr.as_any().downcast_ref::() else { + return Ok(None); + }; + if column.index() >= group_input_exprs.len() { + // AggregateExec does not produce output that is sorted on aggregate + // columns so those can not be pushed through. + return Ok(None); + } + child_requirement.push(PhysicalSortRequirement::new( + Arc::clone(&group_input_exprs[column.index()]), + req.options, + )); + } + + let Some(child_requirement) = LexRequirement::new(child_requirement) else { + return Ok(None); + }; + + // Keep sort above aggregate unless input ordering already satisfies the + // mapped requirement. + if aggregate_exec + .input() + .equivalence_properties() + .ordering_satisfy_requirement(child_requirement.iter().cloned())? 
+ { + let child_requirements = OrderingRequirements::new(child_requirement); + Ok(Some(vec![Some(child_requirements)])) + } else { + Ok(None) + } +} + /// Return true if pushing the sort requirements through a node would violate /// the input sorting requirements for the plan fn pushdown_would_violate_requirements( @@ -723,7 +797,7 @@ fn handle_hash_join( .collect(); let column_indices = build_join_column_index(plan); - let projected_indices: Vec<_> = if let Some(projection) = &plan.projection { + let projected_indices: Vec<_> = if let Some(projection) = plan.projection.as_ref() { projection.iter().map(|&i| &column_indices[i]).collect() } else { column_indices.iter().collect() diff --git a/datafusion/physical-optimizer/src/ensure_coop.rs b/datafusion/physical-optimizer/src/ensure_coop.rs index dfa97fc84033..ef8946f9a49d 100644 --- a/datafusion/physical-optimizer/src/ensure_coop.rs +++ b/datafusion/physical-optimizer/src/ensure_coop.rs @@ -27,7 +27,7 @@ use crate::PhysicalOptimizerRule; use datafusion_common::Result; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_physical_plan::ExecutionPlan; use datafusion_physical_plan::coop::CooperativeExec; use datafusion_physical_plan::execution_plan::{EvaluationType, SchedulingType}; @@ -67,23 +67,57 @@ impl PhysicalOptimizerRule for EnsureCooperative { plan: Arc, _config: &ConfigOptions, ) -> Result> { - plan.transform_up(|plan| { - let is_leaf = plan.children().is_empty(); - let is_exchange = plan.properties().evaluation_type == EvaluationType::Eager; - if (is_leaf || is_exchange) - && plan.properties().scheduling_type != SchedulingType::Cooperative - { - // Wrap non-cooperative leaves or eager evaluation roots in a cooperative exec to - // ensure the plans they participate in are properly cooperative. 
- Ok(Transformed::new( - Arc::new(CooperativeExec::new(Arc::clone(&plan))), - true, - TreeNodeRecursion::Continue, - )) - } else { + use std::cell::RefCell; + + let ancestry_stack = RefCell::new(Vec::<(SchedulingType, EvaluationType)>::new()); + + plan.transform_down_up( + // Down phase: Push parent properties into the stack + |plan| { + let props = plan.properties(); + ancestry_stack + .borrow_mut() + .push((props.scheduling_type, props.evaluation_type)); Ok(Transformed::no(plan)) - } - }) + }, + // Up phase: Wrap nodes with CooperativeExec if needed + |plan| { + ancestry_stack.borrow_mut().pop(); + + let props = plan.properties(); + let is_cooperative = props.scheduling_type == SchedulingType::Cooperative; + let is_leaf = plan.children().is_empty(); + let is_exchange = props.evaluation_type == EvaluationType::Eager; + + let mut is_under_cooperative_context = false; + for (scheduling_type, evaluation_type) in + ancestry_stack.borrow().iter().rev() + { + // If nearest ancestor is cooperative, we are under a cooperative context + if *scheduling_type == SchedulingType::Cooperative { + is_under_cooperative_context = true; + break; + // If nearest ancestor is eager, the cooperative context will be reset + } else if *evaluation_type == EvaluationType::Eager { + is_under_cooperative_context = false; + break; + } + } + + // Wrap if: + // 1. Node is a leaf or exchange point + // 2. Node is not already cooperative + // 3. Not under any Cooperative context + if (is_leaf || is_exchange) + && !is_cooperative + && !is_under_cooperative_context + { + return Ok(Transformed::yes(Arc::new(CooperativeExec::new(plan)))); + } + + Ok(Transformed::no(plan)) + }, + ) .map(|t| t.data) } @@ -115,4 +149,264 @@ mod tests { DataSourceExec: partitions=1, partition_sizes=[1] "); } + + #[tokio::test] + async fn test_optimizer_is_idempotent() { + // Comprehensive idempotency test: verify f(f(...f(x))) = f(x) + // This test covers: + // 1. Multiple runs on unwrapped plan + // 2. 
Multiple runs on already-wrapped plan + // 3. No accumulation of CooperativeExec nodes + + let config = ConfigOptions::new(); + let rule = EnsureCooperative::new(); + + // Test 1: Start with unwrapped plan, run multiple times + let unwrapped_plan = scan_partitioned(1); + let mut current = unwrapped_plan; + let mut stable_result = String::new(); + + for run in 1..=5 { + current = rule.optimize(current, &config).unwrap(); + let display = displayable(current.as_ref()).indent(true).to_string(); + + if run == 1 { + stable_result = display.clone(); + assert_eq!(display.matches("CooperativeExec").count(), 1); + } else { + assert_eq!( + display, stable_result, + "Run {run} should match run 1 (idempotent)" + ); + assert_eq!( + display.matches("CooperativeExec").count(), + 1, + "Should always have exactly 1 CooperativeExec, not accumulate" + ); + } + } + + // Test 2: Start with already-wrapped plan, verify no double wrapping + let pre_wrapped = Arc::new(CooperativeExec::new(scan_partitioned(1))); + let result = rule.optimize(pre_wrapped, &config).unwrap(); + let display = displayable(result.as_ref()).indent(true).to_string(); + + assert_eq!( + display.matches("CooperativeExec").count(), + 1, + "Should not double-wrap already cooperative plans" + ); + assert_eq!( + display, stable_result, + "Pre-wrapped plan should produce same result as unwrapped after optimization" + ); + } + + #[tokio::test] + async fn test_selective_wrapping() { + // Test that wrapping is selective: only leaf/eager nodes, not intermediate nodes + // Also verify depth tracking prevents double wrapping in subtrees + use datafusion_physical_expr::expressions::lit; + use datafusion_physical_plan::filter::FilterExec; + + let config = ConfigOptions::new(); + let rule = EnsureCooperative::new(); + + // Case 1: Filter -> Scan (middle node should not be wrapped) + let scan = scan_partitioned(1); + let filter = Arc::new(FilterExec::try_new(lit(true), scan).unwrap()); + let optimized = rule.optimize(filter, 
&config).unwrap(); + let display = displayable(optimized.as_ref()).indent(true).to_string(); + + assert_eq!(display.matches("CooperativeExec").count(), 1); + assert!(display.contains("FilterExec")); + + // Case 2: Filter -> CoopExec -> Scan (depth tracking prevents double wrap) + let scan2 = scan_partitioned(1); + let wrapped_scan = Arc::new(CooperativeExec::new(scan2)); + let filter2 = Arc::new(FilterExec::try_new(lit(true), wrapped_scan).unwrap()); + let optimized2 = rule.optimize(filter2, &config).unwrap(); + let display2 = displayable(optimized2.as_ref()).indent(true).to_string(); + + assert_eq!(display2.matches("CooperativeExec").count(), 1); + } + + #[tokio::test] + async fn test_multiple_leaf_nodes() { + // When there are multiple leaf nodes, each should be wrapped separately + use datafusion_physical_plan::union::UnionExec; + + let scan1 = scan_partitioned(1); + let scan2 = scan_partitioned(1); + let union = UnionExec::try_new(vec![scan1, scan2]).unwrap(); + + let config = ConfigOptions::new(); + let optimized = EnsureCooperative::new() + .optimize(union as Arc, &config) + .unwrap(); + + let display = displayable(optimized.as_ref()).indent(true).to_string(); + + // Each leaf should have its own CooperativeExec + assert_eq!( + display.matches("CooperativeExec").count(), + 2, + "Each leaf node should be wrapped separately" + ); + assert_eq!( + display.matches("DataSourceExec").count(), + 2, + "Both data sources should be present" + ); + } + + #[tokio::test] + async fn test_eager_evaluation_resets_cooperative_context() { + // Test that cooperative context is reset when encountering an eager evaluation boundary. 
+ use arrow::datatypes::Schema; + use datafusion_common::{Result, internal_err}; + use datafusion_execution::TaskContext; + use datafusion_physical_expr::EquivalenceProperties; + use datafusion_physical_plan::{ + DisplayAs, DisplayFormatType, Partitioning, PlanProperties, + SendableRecordBatchStream, + execution_plan::{Boundedness, EmissionType}, + }; + use std::any::Any; + use std::fmt::Formatter; + + #[derive(Debug)] + struct DummyExec { + name: String, + input: Arc, + scheduling_type: SchedulingType, + evaluation_type: EvaluationType, + properties: Arc, + } + + impl DummyExec { + fn new( + name: &str, + input: Arc, + scheduling_type: SchedulingType, + evaluation_type: EvaluationType, + ) -> Self { + let properties = PlanProperties::new( + EquivalenceProperties::new(Arc::new(Schema::empty())), + Partitioning::UnknownPartitioning(1), + EmissionType::Incremental, + Boundedness::Bounded, + ) + .with_scheduling_type(scheduling_type) + .with_evaluation_type(evaluation_type); + + Self { + name: name.to_string(), + input, + scheduling_type, + evaluation_type, + properties: Arc::new(properties), + } + } + } + + impl DisplayAs for DummyExec { + fn fmt_as( + &self, + _: DisplayFormatType, + f: &mut Formatter, + ) -> std::fmt::Result { + write!(f, "{}", self.name) + } + } + + impl ExecutionPlan for DummyExec { + fn name(&self) -> &str { + &self.name + } + fn as_any(&self) -> &dyn Any { + self + } + fn properties(&self) -> &Arc { + &self.properties + } + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(DummyExec::new( + &self.name, + Arc::clone(&children[0]), + self.scheduling_type, + self.evaluation_type, + ))) + } + fn execute( + &self, + _: usize, + _: Arc, + ) -> Result { + internal_err!("DummyExec does not support execution") + } + } + + // Build a plan similar to the original test: + // scan -> exch1(NonCoop,Eager) -> CoopExec -> filter -> exch2(Coop,Eager) -> filter + let 
scan = scan_partitioned(1); + let exch1 = Arc::new(DummyExec::new( + "exch1", + scan, + SchedulingType::NonCooperative, + EvaluationType::Eager, + )); + let coop = Arc::new(CooperativeExec::new(exch1)); + let filter1 = Arc::new(DummyExec::new( + "filter1", + coop, + SchedulingType::NonCooperative, + EvaluationType::Lazy, + )); + let exch2 = Arc::new(DummyExec::new( + "exch2", + filter1, + SchedulingType::Cooperative, + EvaluationType::Eager, + )); + let filter2 = Arc::new(DummyExec::new( + "filter2", + exch2, + SchedulingType::NonCooperative, + EvaluationType::Lazy, + )); + + let config = ConfigOptions::new(); + let optimized = EnsureCooperative::new().optimize(filter2, &config).unwrap(); + + let display = displayable(optimized.as_ref()).indent(true).to_string(); + + // Expected wrapping: + // - Scan (leaf) gets wrapped + // - exch1 (eager+noncoop) keeps its manual CooperativeExec wrapper + // - filter1 is protected by exch2's cooperative context, no extra wrap + // - exch2 (already Cooperative) does NOT get wrapped + // - filter2 (not leaf or eager) does NOT get wrapped + assert_eq!( + display.matches("CooperativeExec").count(), + 2, + "Should have 2 CooperativeExec: one wrapping scan, one wrapping exch1" + ); + + assert_snapshot!(display, @r" + filter2 + exch2 + filter1 + CooperativeExec + exch1 + CooperativeExec + DataSourceExec: partitions=1, partition_sizes=[1] + "); + } } diff --git a/datafusion/physical-optimizer/src/join_selection.rs b/datafusion/physical-optimizer/src/join_selection.rs index f837c79a4e39..29bbc8e10888 100644 --- a/datafusion/physical-optimizer/src/join_selection.rs +++ b/datafusion/physical-optimizer/src/join_selection.rs @@ -184,35 +184,30 @@ pub(crate) fn try_collect_left( match (left_can_collect, right_can_collect) { (true, true) => { + // Don't swap null-aware anti joins as they have specific side requirements if hash_join.join_type().supports_swap() + && !hash_join.null_aware && should_swap_join_order(&**left, &**right)? 
{ Ok(Some(hash_join.swap_inputs(PartitionMode::CollectLeft)?)) } else { - Ok(Some(Arc::new(HashJoinExec::try_new( - Arc::clone(left), - Arc::clone(right), - hash_join.on().to_vec(), - hash_join.filter().cloned(), - hash_join.join_type(), - hash_join.projection.clone(), - PartitionMode::CollectLeft, - hash_join.null_equality(), - )?))) + Ok(Some(Arc::new( + hash_join + .builder() + .with_partition_mode(PartitionMode::CollectLeft) + .build()?, + ))) } } - (true, false) => Ok(Some(Arc::new(HashJoinExec::try_new( - Arc::clone(left), - Arc::clone(right), - hash_join.on().to_vec(), - hash_join.filter().cloned(), - hash_join.join_type(), - hash_join.projection.clone(), - PartitionMode::CollectLeft, - hash_join.null_equality(), - )?))), + (true, false) => Ok(Some(Arc::new( + hash_join + .builder() + .with_partition_mode(PartitionMode::CollectLeft) + .build()?, + ))), (false, true) => { - if hash_join.join_type().supports_swap() { + // Don't swap null-aware anti joins as they have specific side requirements + if hash_join.join_type().supports_swap() && !hash_join.null_aware { hash_join.swap_inputs(PartitionMode::CollectLeft).map(Some) } else { Ok(None) @@ -232,20 +227,29 @@ pub(crate) fn partitioned_hash_join( ) -> Result> { let left = hash_join.left(); let right = hash_join.right(); - if hash_join.join_type().supports_swap() && should_swap_join_order(&**left, &**right)? + // Don't swap null-aware anti joins as they have specific side requirements + if hash_join.join_type().supports_swap() + && !hash_join.null_aware + && should_swap_join_order(&**left, &**right)? 
{ hash_join.swap_inputs(PartitionMode::Partitioned) } else { - Ok(Arc::new(HashJoinExec::try_new( - Arc::clone(left), - Arc::clone(right), - hash_join.on().to_vec(), - hash_join.filter().cloned(), - hash_join.join_type(), - hash_join.projection.clone(), - PartitionMode::Partitioned, - hash_join.null_equality(), - )?)) + // Null-aware anti joins must use CollectLeft mode because they track probe-side state + // (probe_side_non_empty, probe_side_has_null) per-partition, but need global knowledge + // for correct null handling. With partitioning, a partition might not see probe rows + // even if the probe side is globally non-empty, leading to incorrect NULL row handling. + let partition_mode = if hash_join.null_aware { + PartitionMode::CollectLeft + } else { + PartitionMode::Partitioned + }; + + Ok(Arc::new( + hash_join + .builder() + .with_partition_mode(partition_mode) + .build()?, + )) } } @@ -277,7 +281,9 @@ fn statistical_join_selection_subrule( PartitionMode::Partitioned => { let left = hash_join.left(); let right = hash_join.right(); + // Don't swap null-aware anti joins as they have specific side requirements if hash_join.join_type().supports_swap() + && !hash_join.null_aware && should_swap_join_order(&**left, &**right)? 
{ hash_join @@ -484,6 +490,7 @@ pub fn hash_join_swap_subrule( if let Some(hash_join) = input.as_any().downcast_ref::() && hash_join.left.boundedness().is_unbounded() && !hash_join.right.boundedness().is_unbounded() + && !hash_join.null_aware // Don't swap null-aware anti joins && matches!( *hash_join.join_type(), JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti diff --git a/datafusion/physical-optimizer/src/lib.rs b/datafusion/physical-optimizer/src/lib.rs index 1b45f02ebd51..3a0d79ae2d23 100644 --- a/datafusion/physical-optimizer/src/lib.rs +++ b/datafusion/physical-optimizer/src/lib.rs @@ -24,11 +24,8 @@ // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] -// https://github.com/apache/datafusion/issues/18881 -#![deny(clippy::allow_attributes)] pub mod aggregate_statistics; -pub mod coalesce_batches; pub mod combine_partial_final_agg; pub mod enforce_distribution; pub mod enforce_sorting; diff --git a/datafusion/physical-optimizer/src/limit_pushdown.rs b/datafusion/physical-optimizer/src/limit_pushdown.rs index 4cb3abe30bae..e7bede494da9 100644 --- a/datafusion/physical-optimizer/src/limit_pushdown.rs +++ b/datafusion/physical-optimizer/src/limit_pushdown.rs @@ -50,6 +50,7 @@ pub struct GlobalRequirements { fetch: Option, skip: usize, satisfied: bool, + preserve_order: bool, } impl LimitPushdown { @@ -69,6 +70,7 @@ impl PhysicalOptimizerRule for LimitPushdown { fetch: None, skip: 0, satisfied: false, + preserve_order: false, }; pushdown_limits(plan, global_state) } @@ -111,6 +113,13 @@ impl LimitExec { Self::Local(_) => 0, } } + + fn preserve_order(&self) -> bool { + match self { + Self::Global(global) => global.required_ordering().is_some(), + Self::Local(local) => local.required_ordering().is_some(), + } + } } impl From for Arc { @@ -145,6 +154,8 @@ pub fn pushdown_limit_helper( ); global_state.skip = skip; global_state.fetch = fetch; + 
global_state.preserve_order = limit_exec.preserve_order(); + global_state.satisfied = false; // Now the global state has the most recent information, we can remove // the `LimitExec` plan. We will decide later if we should add it again @@ -162,7 +173,7 @@ pub fn pushdown_limit_helper( // If we have a non-limit operator with fetch capability, update global // state as necessary: if pushdown_plan.fetch().is_some() { - if global_state.fetch.is_none() { + if global_state.skip == 0 { global_state.satisfied = true; } (global_state.skip, global_state.fetch) = combine_limit( @@ -241,17 +252,28 @@ pub fn pushdown_limit_helper( let maybe_fetchable = pushdown_plan.with_fetch(skip_and_fetch); if global_state.satisfied { if let Some(plan_with_fetch) = maybe_fetchable { - Ok((Transformed::yes(plan_with_fetch), global_state)) + let plan_with_preserve_order = plan_with_fetch + .with_preserve_order(global_state.preserve_order) + .unwrap_or(plan_with_fetch); + Ok((Transformed::yes(plan_with_preserve_order), global_state)) } else { Ok((Transformed::no(pushdown_plan), global_state)) } } else { global_state.satisfied = true; pushdown_plan = if let Some(plan_with_fetch) = maybe_fetchable { + let plan_with_preserve_order = plan_with_fetch + .with_preserve_order(global_state.preserve_order) + .unwrap_or(plan_with_fetch); + if global_skip > 0 { - add_global_limit(plan_with_fetch, global_skip, Some(global_fetch)) + add_global_limit( + plan_with_preserve_order, + global_skip, + Some(global_fetch), + ) } else { - plan_with_fetch + plan_with_preserve_order } } else { add_limit(pushdown_plan, global_skip, global_fetch) diff --git a/datafusion/physical-optimizer/src/limited_distinct_aggregation.rs b/datafusion/physical-optimizer/src/limited_distinct_aggregation.rs index 671d247cf36a..fe9636f67619 100644 --- a/datafusion/physical-optimizer/src/limited_distinct_aggregation.rs +++ b/datafusion/physical-optimizer/src/limited_distinct_aggregation.rs @@ -20,7 +20,7 @@ use std::sync::Arc; -use 
datafusion_physical_plan::aggregates::AggregateExec; +use datafusion_physical_plan::aggregates::{AggregateExec, LimitOptions}; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; @@ -63,7 +63,7 @@ impl LimitedDistinctAggregation { aggr.input_schema(), ) .expect("Unable to copy Aggregate!") - .with_limit(Some(limit)); + .with_limit_options(Some(LimitOptions::new(limit))); Some(Arc::new(new_aggr)) } diff --git a/datafusion/physical-optimizer/src/optimizer.rs b/datafusion/physical-optimizer/src/optimizer.rs index aa1975d98d48..49225db03ac4 100644 --- a/datafusion/physical-optimizer/src/optimizer.rs +++ b/datafusion/physical-optimizer/src/optimizer.rs @@ -21,7 +21,6 @@ use std::fmt::Debug; use std::sync::Arc; use crate::aggregate_statistics::AggregateStatistics; -use crate::coalesce_batches::CoalesceBatches; use crate::combine_partial_final_agg::CombinePartialFinalAggregate; use crate::enforce_distribution::EnforceDistribution; use crate::enforce_sorting::EnforceSorting; @@ -83,6 +82,12 @@ impl Default for PhysicalOptimizer { impl PhysicalOptimizer { /// Create a new optimizer using the recommended list of rules pub fn new() -> Self { + // NOTEs: + // - The order of rules in this list is important, as it determines the + // order in which they are applied. + // - Adding a new rule here is expensive as it will be applied to all + // queries, and will likely increase the optimization time. Please extend + // existing rules when possible, rather than adding a new rule. let rules: Vec> = vec![ // If there is a output requirement of the query, make sure that // this information is not lost across different rules during optimization. @@ -120,9 +125,6 @@ impl PhysicalOptimizer { Arc::new(OptimizeAggregateOrder::new()), // TODO: `try_embed_to_hash_join` in the ProjectionPushdown rule would be block by the CoalesceBatches, so add it before CoalesceBatches. 
Maybe optimize it in the future. Arc::new(ProjectionPushdown::new()), - // The CoalesceBatches rule will not influence the distribution and ordering of the - // whole plan tree. Therefore, to avoid influencing other rules, it should run last. - Arc::new(CoalesceBatches::new()), // Remove the ancillary output requirement operator since we are done with the planning // phase. Arc::new(OutputRequirements::new_remove_mode()), diff --git a/datafusion/physical-optimizer/src/output_requirements.rs b/datafusion/physical-optimizer/src/output_requirements.rs index 0dc6a25fbc0b..75721951f8d8 100644 --- a/datafusion/physical-optimizer/src/output_requirements.rs +++ b/datafusion/physical-optimizer/src/output_requirements.rs @@ -98,7 +98,7 @@ pub struct OutputRequirementExec { input: Arc, order_requirement: Option, dist_requirement: Distribution, - cache: PlanProperties, + cache: Arc, fetch: Option, } @@ -114,7 +114,7 @@ impl OutputRequirementExec { input, order_requirement: requirements, dist_requirement, - cache, + cache: Arc::new(cache), fetch, } } @@ -200,7 +200,7 @@ impl ExecutionPlan for OutputRequirementExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -244,10 +244,6 @@ impl ExecutionPlan for OutputRequirementExec { unreachable!(); } - fn statistics(&self) -> Result { - self.input.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { self.input.partition_statistics(partition) } diff --git a/datafusion/physical-optimizer/src/projection_pushdown.rs b/datafusion/physical-optimizer/src/projection_pushdown.rs index 281d61aecf53..44d0926a8b25 100644 --- a/datafusion/physical-optimizer/src/projection_pushdown.rs +++ b/datafusion/physical-optimizer/src/projection_pushdown.rs @@ -32,7 +32,7 @@ use datafusion_common::tree_node::{ }; use datafusion_common::{JoinSide, JoinType, Result}; use datafusion_physical_expr::expressions::Column; -use 
datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::physical_expr::{PhysicalExpr, is_volatile}; use datafusion_physical_plan::ExecutionPlan; use datafusion_physical_plan::joins::NestedLoopJoinExec; use datafusion_physical_plan::joins::utils::{ColumnIndex, JoinFilter}; @@ -135,7 +135,7 @@ fn try_push_down_join_filter( ); let new_lhs_length = lhs_rewrite.data.0.schema().fields.len(); - let projections = match projections { + let projections = match projections.as_ref() { None => match join.join_type() { JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { // Build projections that ignore the newly projected columns. @@ -349,8 +349,7 @@ impl<'a> JoinFilterRewriter<'a> { // Recurse if there is a dependency to both sides or if the entire expression is volatile. let depends_on_other_side = self.depends_on_join_side(&expr, self.join_side.negate())?; - let is_volatile = is_volatile_expression_tree(expr.as_ref()); - if depends_on_other_side || is_volatile { + if depends_on_other_side || is_volatile(&expr) { return expr.map_children(|expr| self.rewrite(expr)); } @@ -431,18 +430,6 @@ impl<'a> JoinFilterRewriter<'a> { } } -fn is_volatile_expression_tree(expr: &dyn PhysicalExpr) -> bool { - if expr.is_volatile_node() { - return true; - } - - expr.children() - .iter() - .map(|expr| is_volatile_expression_tree(expr.as_ref())) - .reduce(|lhs, rhs| lhs || rhs) - .unwrap_or(false) -} - #[cfg(test)] mod test { use super::*; diff --git a/datafusion/physical-optimizer/src/topk_aggregation.rs b/datafusion/physical-optimizer/src/topk_aggregation.rs index 7eb9e6a76211..cec6bd70a208 100644 --- a/datafusion/physical-optimizer/src/topk_aggregation.rs +++ b/datafusion/physical-optimizer/src/topk_aggregation.rs @@ -20,13 +20,13 @@ use std::sync::Arc; use crate::PhysicalOptimizerRule; -use arrow::datatypes::DataType; use datafusion_common::Result; use datafusion_common::config::ConfigOptions; use 
datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_plan::ExecutionPlan; -use datafusion_physical_plan::aggregates::AggregateExec; +use datafusion_physical_plan::aggregates::LimitOptions; +use datafusion_physical_plan::aggregates::{AggregateExec, topk_types_supported}; use datafusion_physical_plan::execution_plan::CardinalityEffect; use datafusion_physical_plan::projection::ProjectionExec; use datafusion_physical_plan::sorts::sort::SortExec; @@ -48,40 +48,47 @@ impl TopKAggregation { order_desc: bool, limit: usize, ) -> Option> { - // ensure the sort direction matches aggregate function - let (field, desc) = aggr.get_minmax_desc()?; - if desc != order_desc { - return None; - } - let group_key = aggr.group_expr().expr().iter().exactly_one().ok()?; - let kt = group_key.0.data_type(&aggr.input().schema()).ok()?; - if !kt.is_primitive() - && kt != DataType::Utf8 - && kt != DataType::Utf8View - && kt != DataType::LargeUtf8 - { + // Current only support single group key + let (group_key, group_key_alias) = + aggr.group_expr().expr().iter().exactly_one().ok()?; + let kt = group_key.data_type(&aggr.input().schema()).ok()?; + let vt = if let Some((field, _)) = aggr.get_minmax_desc() { + field.data_type().clone() + } else { + kt.clone() + }; + if !topk_types_supported(&kt, &vt) { return None; } if aggr.filter_expr().iter().any(|e| e.is_some()) { return None; } - // ensure the sort is on the same field as the aggregate output - if order_by != field.name() { + // Check if this is ordering by an aggregate function (MIN/MAX) + if let Some((field, desc)) = aggr.get_minmax_desc() { + // ensure the sort direction matches aggregate function + if desc != order_desc { + return None; + } + // ensure the sort is on the same field as the aggregate output + if order_by != field.name() { + return None; + } + } else if aggr.aggr_expr().is_empty() { + // This is a GROUP BY without aggregates, 
check if ordering is on the group key itself + if order_by != group_key_alias { + return None; + } + } else { + // Has aggregates but not MIN/MAX, or doesn't DISTINCT return None; } // We found what we want: clone, copy the limit down, and return modified node - let new_aggr = AggregateExec::try_new( - *aggr.mode(), - aggr.group_expr().clone(), - aggr.aggr_expr().to_vec(), - aggr.filter_expr().to_vec(), - Arc::clone(aggr.input()), - aggr.input_schema(), - ) - .expect("Unable to copy Aggregate!") - .with_limit(Some(limit)); + let new_aggr = AggregateExec::with_new_limit_options( + aggr, + Some(LimitOptions::new_with_order(limit, order_desc)), + ); Some(Arc::new(new_aggr)) } diff --git a/datafusion/physical-optimizer/src/update_aggr_exprs.rs b/datafusion/physical-optimizer/src/update_aggr_exprs.rs index c0aab4080da7..67127c2a238f 100644 --- a/datafusion/physical-optimizer/src/update_aggr_exprs.rs +++ b/datafusion/physical-optimizer/src/update_aggr_exprs.rs @@ -25,7 +25,9 @@ use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{Result, plan_datafusion_err}; use datafusion_physical_expr::aggregate::AggregateFunctionExpr; use datafusion_physical_expr::{EquivalenceProperties, PhysicalSortRequirement}; -use datafusion_physical_plan::aggregates::{AggregateExec, concat_slices}; +use datafusion_physical_plan::aggregates::{ + AggregateExec, AggregateInputMode, concat_slices, +}; use datafusion_physical_plan::windows::get_ordered_partition_by_indices; use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; @@ -81,7 +83,7 @@ impl PhysicalOptimizerRule for OptimizeAggregateOrder { // ordering fields may be pruned out by first stage aggregates. // Hence, necessary information for proper merge is added during // the first stage to the state field, which the final stage uses. 
- if !aggr_exec.mode().is_first_stage() { + if aggr_exec.mode().input_mode() == AggregateInputMode::Partial { return Ok(Transformed::no(plan)); } let input = aggr_exec.input(); diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index 68e67fa018f0..6a28486cca5d 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -67,6 +67,7 @@ hashbrown = { workspace = true } indexmap = { workspace = true } itertools = { workspace = true, features = ["use_std"] } log = { workspace = true } +num-traits = { workspace = true } parking_lot = { workspace = true } pin-project-lite = "^0.2.7" tokio = { workspace = true } @@ -97,6 +98,11 @@ name = "spill_io" harness = false name = "sort_preserving_merge" +[[bench]] +harness = false +name = "sort_merge_join" +required-features = ["test_utils"] + [[bench]] harness = false name = "aggregate_vectorized" diff --git a/datafusion/physical-plan/benches/sort_merge_join.rs b/datafusion/physical-plan/benches/sort_merge_join.rs new file mode 100644 index 000000000000..82610b2a54c2 --- /dev/null +++ b/datafusion/physical-plan/benches/sort_merge_join.rs @@ -0,0 +1,204 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Criterion benchmarks for Sort Merge Join +//! +//! These benchmarks measure the join kernel in isolation by feeding +//! pre-sorted RecordBatches directly into SortMergeJoinExec, avoiding +//! sort / scan overhead. + +use std::sync::Arc; + +use arrow::array::{Int64Array, RecordBatch, StringArray}; +use arrow::compute::SortOptions; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use datafusion_common::NullEquality; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::expressions::col; +use datafusion_physical_plan::collect; +use datafusion_physical_plan::joins::{SortMergeJoinExec, utils::JoinOn}; +use datafusion_physical_plan::test::TestMemoryExec; +use tokio::runtime::Runtime; + +/// Build pre-sorted RecordBatches (split into ~8192-row chunks). +/// +/// Schema: (key: Int64, data: Int64, payload: Utf8) +/// +/// `key_mod` controls distinct key count: key = row_index % key_mod. +fn build_sorted_batches( + num_rows: usize, + key_mod: usize, + schema: &SchemaRef, +) -> Vec { + let mut rows: Vec<(i64, i64)> = (0..num_rows) + .map(|i| ((i % key_mod) as i64, i as i64)) + .collect(); + rows.sort(); + + let keys: Vec = rows.iter().map(|(k, _)| *k).collect(); + let data: Vec = rows.iter().map(|(_, d)| *d).collect(); + let payload: Vec = data.iter().map(|d| format!("val_{d}")).collect(); + + let batch = RecordBatch::try_new( + Arc::clone(schema), + vec![ + Arc::new(Int64Array::from(keys)), + Arc::new(Int64Array::from(data)), + Arc::new(StringArray::from(payload)), + ], + ) + .unwrap(); + + let batch_size = 8192; + let mut batches = Vec::new(); + let mut offset = 0; + while offset < batch.num_rows() { + let len = (batch.num_rows() - offset).min(batch_size); + batches.push(batch.slice(offset, len)); + offset += len; + } + batches +} + +fn make_exec( + batches: &[RecordBatch], + schema: &SchemaRef, +) -> Arc { + TestMemoryExec::try_new_exec(&[batches.to_vec()], 
Arc::clone(schema), None).unwrap() +} + +fn schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("key", DataType::Int64, false), + Field::new("data", DataType::Int64, false), + Field::new("payload", DataType::Utf8, false), + ])) +} + +fn do_join( + left: Arc, + right: Arc, + join_type: datafusion_common::JoinType, + rt: &Runtime, +) -> usize { + let on: JoinOn = vec![( + col("key", &left.schema()).unwrap(), + col("key", &right.schema()).unwrap(), + )]; + let join = SortMergeJoinExec::try_new( + left, + right, + on, + None, + join_type, + vec![SortOptions::default()], + NullEquality::NullEqualsNothing, + ) + .unwrap(); + + let task_ctx = Arc::new(TaskContext::default()); + rt.block_on(async { + let batches = collect(Arc::new(join), task_ctx).await.unwrap(); + batches.iter().map(|b| b.num_rows()).sum() + }) +} + +fn bench_smj(c: &mut Criterion) { + let rt = Runtime::new().unwrap(); + let s = schema(); + + let mut group = c.benchmark_group("sort_merge_join"); + + // 1:1 Inner Join — 100K rows each, unique keys + // Best case for contiguous-range optimization: every index array is [0,1,2,...]. 
+ { + let n = 100_000; + let left_batches = build_sorted_batches(n, n, &s); + let right_batches = build_sorted_batches(n, n, &s); + group.bench_function(BenchmarkId::new("inner_1to1", n), |b| { + b.iter(|| { + let left = make_exec(&left_batches, &s); + let right = make_exec(&right_batches, &s); + do_join(left, right, datafusion_common::JoinType::Inner, &rt) + }) + }); + } + + // 1:10 Inner Join — 100K left, 100K right, 10K distinct keys + { + let n = 100_000; + let key_mod = 10_000; + let left_batches = build_sorted_batches(n, key_mod, &s); + let right_batches = build_sorted_batches(n, key_mod, &s); + group.bench_function(BenchmarkId::new("inner_1to10", n), |b| { + b.iter(|| { + let left = make_exec(&left_batches, &s); + let right = make_exec(&right_batches, &s); + do_join(left, right, datafusion_common::JoinType::Inner, &rt) + }) + }); + } + + // Left Join — 100K each, ~5% unmatched on left + { + let n = 100_000; + let left_batches = build_sorted_batches(n, n + n / 20, &s); + let right_batches = build_sorted_batches(n, n, &s); + group.bench_function(BenchmarkId::new("left_1to1_unmatched", n), |b| { + b.iter(|| { + let left = make_exec(&left_batches, &s); + let right = make_exec(&right_batches, &s); + do_join(left, right, datafusion_common::JoinType::Left, &rt) + }) + }); + } + + // Left Semi Join — 100K left, 100K right, 10K keys + { + let n = 100_000; + let key_mod = 10_000; + let left_batches = build_sorted_batches(n, key_mod, &s); + let right_batches = build_sorted_batches(n, key_mod, &s); + group.bench_function(BenchmarkId::new("left_semi_1to10", n), |b| { + b.iter(|| { + let left = make_exec(&left_batches, &s); + let right = make_exec(&right_batches, &s); + do_join(left, right, datafusion_common::JoinType::LeftSemi, &rt) + }) + }); + } + + // Left Anti Join — 100K left, 100K right, partial match + { + let n = 100_000; + let left_batches = build_sorted_batches(n, n + n / 5, &s); + let right_batches = build_sorted_batches(n, n, &s); + 
group.bench_function(BenchmarkId::new("left_anti_partial", n), |b| { + b.iter(|| { + let left = make_exec(&left_batches, &s); + let right = make_exec(&right_batches, &s); + do_join(left, right, datafusion_common::JoinType::LeftAnti, &rt) + }) + }); + } + + group.finish(); +} + +criterion_group!(benches, bench_smj); +criterion_main!(benches); diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs index c46cde8786eb..2b8a2cfa6889 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs @@ -128,7 +128,9 @@ where let hash = key.hash(state); let insert = self.map.entry( hash, - |&(g, _)| unsafe { self.values.get_unchecked(g).is_eq(key) }, + |&(g, h)| unsafe { + hash == h && self.values.get_unchecked(g).is_eq(key) + }, |&(_, h)| h, ); diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 06f12a90195d..85999938510b 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -33,7 +33,7 @@ use crate::filter_pushdown::{ use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::{ DisplayFormatType, Distribution, ExecutionPlan, InputOrderMode, - SendableRecordBatchStream, Statistics, + SendableRecordBatchStream, Statistics, check_if_same_properties, }; use datafusion_common::config::ConfigOptions; use datafusion_physical_expr::utils::collect_columns; @@ -41,7 +41,7 @@ use parking_lot::Mutex; use std::collections::HashSet; use arrow::array::{ArrayRef, UInt8Array, UInt16Array, UInt32Array, UInt64Array}; -use arrow::datatypes::{Field, Schema, SchemaRef}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use arrow_schema::FieldRef; use 
datafusion_common::stats::Precision; @@ -64,6 +64,8 @@ use datafusion_physical_expr_common::sort_expr::{ use datafusion_expr::utils::AggregateOrderSensitivity; use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays; use itertools::Itertools; +use topk::hash_table::is_supported_hash_key_type; +use topk::heap::is_supported_heap_type; pub mod group_values; mod no_grouping; @@ -72,14 +74,69 @@ mod row_hash; mod topk; mod topk_stream; +/// Returns true if TopK aggregation data structures support the provided key and value types. +/// +/// This function checks whether both the key type (used for grouping) and value type +/// (used in min/max aggregation) can be handled by the TopK aggregation heap and hash table. +/// Supported types include Arrow primitives (integers, floats, decimals, intervals) and +/// UTF-8 strings (`Utf8`, `LargeUtf8`, `Utf8View`). +/// ```text +pub fn topk_types_supported(key_type: &DataType, value_type: &DataType) -> bool { + is_supported_hash_key_type(key_type) && is_supported_heap_type(value_type) +} + /// Hard-coded seed for aggregations to ensure hash values differ from `RepartitionExec`, avoiding collisions. const AGGREGATION_HASH_SEED: ahash::RandomState = ahash::RandomState::with_seeds('A' as u64, 'G' as u64, 'G' as u64, 'R' as u64); +/// Whether an aggregate stage consumes raw input data or intermediate +/// accumulator state from a previous aggregation stage. +/// +/// See the [table on `AggregateMode`](AggregateMode#variants-and-their-inputoutput-modes) +/// for how this relates to aggregate modes. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum AggregateInputMode { + /// The stage consumes raw, unaggregated input data and calls + /// [`Accumulator::update_batch`]. + Raw, + /// The stage consumes intermediate accumulator state from a previous + /// aggregation stage and calls [`Accumulator::merge_batch`]. 
+ Partial, +} + +/// Whether an aggregate stage produces intermediate accumulator state +/// or final output values. +/// +/// See the [table on `AggregateMode`](AggregateMode#variants-and-their-inputoutput-modes) +/// for how this relates to aggregate modes. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum AggregateOutputMode { + /// The stage produces intermediate accumulator state, serialized via + /// [`Accumulator::state`]. + Partial, + /// The stage produces final output values via + /// [`Accumulator::evaluate`]. + Final, +} + /// Aggregation modes /// /// See [`Accumulator::state`] for background information on multi-phase /// aggregation and how these modes are used. +/// +/// # Variants and their input/output modes +/// +/// Each variant can be characterized by its [`AggregateInputMode`] and +/// [`AggregateOutputMode`]: +/// +/// ```text +/// | Input: Raw data | Input: Partial state +/// Output: Final values | Single, SinglePartitioned | Final, FinalPartitioned +/// Output: Partial state | Partial | PartialReduce +/// ``` +/// +/// Use [`AggregateMode::input_mode`] and [`AggregateMode::output_mode`] +/// to query these properties. #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum AggregateMode { /// One of multiple layers of aggregation, any input partitioning @@ -131,18 +188,56 @@ pub enum AggregateMode { /// This mode requires that the input has more than one partition, and is /// partitioned by group key (like FinalPartitioned). SinglePartitioned, + /// Combine multiple partial aggregations to produce a new partial + /// aggregation. + /// + /// Input is intermediate accumulator state (like Final), but output is + /// also intermediate accumulator state (like Partial). This enables + /// tree-reduce aggregation strategies where partial results from + /// multiple workers are combined in multiple stages before a final + /// evaluation. 
+ /// + /// ```text + /// Final + /// / \ + /// PartialReduce PartialReduce + /// / \ / \ + /// Partial Partial Partial Partial + /// ``` + PartialReduce, } impl AggregateMode { - /// Checks whether this aggregation step describes a "first stage" calculation. - /// In other words, its input is not another aggregation result and the - /// `merge_batch` method will not be called for these modes. - pub fn is_first_stage(&self) -> bool { + /// Returns the [`AggregateInputMode`] for this mode: whether this + /// stage consumes raw input data or intermediate accumulator state. + /// + /// See the [table above](AggregateMode#variants-and-their-inputoutput-modes) + /// for details. + pub fn input_mode(&self) -> AggregateInputMode { match self { AggregateMode::Partial | AggregateMode::Single - | AggregateMode::SinglePartitioned => true, - AggregateMode::Final | AggregateMode::FinalPartitioned => false, + | AggregateMode::SinglePartitioned => AggregateInputMode::Raw, + AggregateMode::Final + | AggregateMode::FinalPartitioned + | AggregateMode::PartialReduce => AggregateInputMode::Partial, + } + } + + /// Returns the [`AggregateOutputMode`] for this mode: whether this + /// stage produces intermediate accumulator state or final output values. + /// + /// See the [table above](AggregateMode#variants-and-their-inputoutput-modes) + /// for details. 
+ pub fn output_mode(&self) -> AggregateOutputMode { + match self { + AggregateMode::Final + | AggregateMode::FinalPartitioned + | AggregateMode::Single + | AggregateMode::SinglePartitioned => AggregateOutputMode::Final, + AggregateMode::Partial | AggregateMode::PartialReduce => { + AggregateOutputMode::Partial + } } } } @@ -489,19 +584,58 @@ enum DynamicFilterAggregateType { Max, } +/// Configuration for limit-based optimizations in aggregation +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct LimitOptions { + /// The maximum number of rows to return + pub limit: usize, + /// Optional ordering direction (true = descending, false = ascending) + /// This is used for TopK aggregation to maintain a priority queue with the correct ordering + pub descending: Option, +} + +impl LimitOptions { + /// Create a new LimitOptions with a limit and no specific ordering + pub fn new(limit: usize) -> Self { + Self { + limit, + descending: None, + } + } + + /// Create a new LimitOptions with a limit and ordering direction + pub fn new_with_order(limit: usize, descending: bool) -> Self { + Self { + limit, + descending: Some(descending), + } + } + + pub fn limit(&self) -> usize { + self.limit + } + + pub fn descending(&self) -> Option { + self.descending + } +} + /// Hash aggregate execution plan #[derive(Debug, Clone)] pub struct AggregateExec { /// Aggregation mode (full, partial) mode: AggregateMode, /// Group by expressions - group_by: PhysicalGroupBy, + /// [`Arc`] used for a cheap clone, which improves physical plan optimization performance. + group_by: Arc, /// Aggregate expressions - aggr_expr: Vec>, + /// The same reason to [`Arc`] it as for [`Self::group_by`]. + aggr_expr: Arc<[Arc]>, /// FILTER (WHERE clause) expression for each aggregate expression - filter_expr: Vec>>, - /// Set if the output of this aggregation is truncated by a upstream sort/limit clause - limit: Option, + /// The same reason to [`Arc`] it as for [`Self::group_by`]. 
+ filter_expr: Arc<[Option>]>, + /// Configuration for limit-based optimizations + limit_options: Option, /// Input plan, could be a partial aggregate or the input to the aggregate pub input: Arc, /// Schema after the aggregate is applied @@ -517,7 +651,7 @@ pub struct AggregateExec { required_input_ordering: Option, /// Describes how the input is ordered relative to the group by columns input_order_mode: InputOrderMode, - cache: PlanProperties, + cache: Arc, /// During initialization, if the plan supports dynamic filtering (see [`AggrDynFilter`]), /// it is set to `Some(..)` regardless of whether it can be pushed down to a child node. /// @@ -533,19 +667,39 @@ impl AggregateExec { /// Rewrites aggregate exec with new aggregate expressions. pub fn with_new_aggr_exprs( &self, - aggr_expr: Vec>, + aggr_expr: impl Into]>>, ) -> Self { Self { - aggr_expr, + aggr_expr: aggr_expr.into(), + // clone the rest of the fields + required_input_ordering: self.required_input_ordering.clone(), + metrics: ExecutionPlanMetricsSet::new(), + input_order_mode: self.input_order_mode.clone(), + cache: Arc::clone(&self.cache), + mode: self.mode, + group_by: Arc::clone(&self.group_by), + filter_expr: Arc::clone(&self.filter_expr), + limit_options: self.limit_options, + input: Arc::clone(&self.input), + schema: Arc::clone(&self.schema), + input_schema: Arc::clone(&self.input_schema), + dynamic_filter: self.dynamic_filter.clone(), + } + } + + /// Clone this exec, overriding only the limit hint. 
+ pub fn with_new_limit_options(&self, limit_options: Option) -> Self { + Self { + limit_options, // clone the rest of the fields required_input_ordering: self.required_input_ordering.clone(), metrics: ExecutionPlanMetricsSet::new(), input_order_mode: self.input_order_mode.clone(), - cache: self.cache.clone(), + cache: Arc::clone(&self.cache), mode: self.mode, - group_by: self.group_by.clone(), - filter_expr: self.filter_expr.clone(), - limit: self.limit, + group_by: Arc::clone(&self.group_by), + aggr_expr: Arc::clone(&self.aggr_expr), + filter_expr: Arc::clone(&self.filter_expr), input: Arc::clone(&self.input), schema: Arc::clone(&self.schema), input_schema: Arc::clone(&self.input_schema), @@ -560,12 +714,13 @@ impl AggregateExec { /// Create a new hash aggregate execution plan pub fn try_new( mode: AggregateMode, - group_by: PhysicalGroupBy, + group_by: impl Into>, aggr_expr: Vec>, filter_expr: Vec>>, input: Arc, input_schema: SchemaRef, ) -> Result { + let group_by = group_by.into(); let schema = create_schema(&input.schema(), &group_by, &aggr_expr, mode)?; let schema = Arc::new(schema); @@ -590,13 +745,16 @@ impl AggregateExec { /// the schema in such cases. 
fn try_new_with_schema( mode: AggregateMode, - group_by: PhysicalGroupBy, + group_by: impl Into>, mut aggr_expr: Vec>, - filter_expr: Vec>>, + filter_expr: impl Into>]>>, input: Arc, input_schema: SchemaRef, schema: SchemaRef, ) -> Result { + let group_by = group_by.into(); + let filter_expr = filter_expr.into(); + // Make sure arguments are consistent in size assert_eq_or_internal_err!( aggr_expr.len(), @@ -663,22 +821,22 @@ impl AggregateExec { &group_expr_mapping, &mode, &input_order_mode, - aggr_expr.as_slice(), + aggr_expr.as_ref(), )?; let mut exec = AggregateExec { mode, group_by, - aggr_expr, + aggr_expr: aggr_expr.into(), filter_expr, input, schema, input_schema, metrics: ExecutionPlanMetricsSet::new(), required_input_ordering, - limit: None, + limit_options: None, input_order_mode, - cache, + cache: Arc::new(cache), dynamic_filter: None, }; @@ -692,11 +850,17 @@ impl AggregateExec { &self.mode } - /// Set the `limit` of this AggExec - pub fn with_limit(mut self, limit: Option) -> Self { - self.limit = limit; + /// Set the limit options for this AggExec + pub fn with_limit_options(mut self, limit_options: Option) -> Self { + self.limit_options = limit_options; self } + + /// Get the limit options (if set) + pub fn limit_options(&self) -> Option { + self.limit_options + } + /// Grouping expressions pub fn group_expr(&self) -> &PhysicalGroupBy { &self.group_by @@ -727,11 +891,6 @@ impl AggregateExec { Arc::clone(&self.input_schema) } - /// number of rows soft limit of the AggregateExec - pub fn limit(&self) -> Option { - self.limit - } - fn execute_typed( &self, partition: usize, @@ -744,11 +903,11 @@ impl AggregateExec { } // grouping by an expression that has a sort/limit upstream - if let Some(limit) = self.limit + if let Some(config) = self.limit_options && !self.is_unordered_unfiltered_group_by_distinct() { return Ok(StreamType::GroupedPriorityQueue( - GroupedTopKAggregateStream::new(self, context, partition, limit)?, + 
GroupedTopKAggregateStream::new(self, context, partition, config.limit)?, )); } @@ -769,6 +928,13 @@ impl AggregateExec { /// This method qualifies the use of the LimitedDistinctAggregation rewrite rule /// on an AggregateExec. pub fn is_unordered_unfiltered_group_by_distinct(&self) -> bool { + if self + .limit_options() + .and_then(|config| config.descending) + .is_some() + { + return false; + } // ensure there is a group by if self.group_expr().is_empty() && !self.group_expr().has_grouping_set() { return false; @@ -840,14 +1006,15 @@ impl AggregateExec { // Get output partitioning: let input_partitioning = input.output_partitioning().clone(); - let output_partitioning = if mode.is_first_stage() { - // First stage aggregation will not change the output partitioning, - // but needs to respect aliases (e.g. mapping in the GROUP BY - // expression). - let input_eq_properties = input.equivalence_properties(); - input_partitioning.project(group_expr_mapping, input_eq_properties) - } else { - input_partitioning.clone() + let output_partitioning = match mode.input_mode() { + AggregateInputMode::Raw => { + // First stage aggregation will not change the output partitioning, + // but needs to respect aliases (e.g. mapping in the GROUP BY + // expression). + let input_eq_properties = input.equivalence_properties(); + input_partitioning.project(group_expr_mapping, input_eq_properties) + } + AggregateInputMode::Partial => input_partitioning.clone(), }; // TODO: Emission type and boundedness information can be enhanced here @@ -949,7 +1116,7 @@ impl AggregateExec { /// - If yes, init one inside `AggregateExec`'s `dynamic_filter` field. 
/// - If not supported, `self.dynamic_filter` should be kept `None` fn init_dynamic_filter(&mut self) { - if (!self.group_by.is_empty()) || (!matches!(self.mode, AggregateMode::Partial)) { + if (!self.group_by.is_empty()) || (self.mode != AggregateMode::Partial) { debug_assert!( self.dynamic_filter.is_none(), "The current operator node does not support dynamic filter" @@ -980,7 +1147,7 @@ impl AggregateExec { } else if fun_name.eq_ignore_ascii_case("max") { DynamicFilterAggregateType::Max } else { - continue; + return; }; // 2. arg should be only 1 column reference @@ -1027,6 +1194,17 @@ impl AggregateExec { _ => Precision::Absent, } } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + Self { + input: children.swap_remove(0), + metrics: ExecutionPlanMetricsSet::new(), + ..Self::clone(self) + } + } } impl DisplayAs for AggregateExec { @@ -1086,8 +1264,8 @@ impl DisplayAs for AggregateExec { .map(|agg| agg.name().to_string()) .collect(); write!(f, ", aggr=[{}]", a.join(", "))?; - if let Some(limit) = self.limit { - write!(f, ", lim=[{limit}]")?; + if let Some(config) = self.limit_options { + write!(f, ", lim=[{}]", config.limit)?; } if self.input_order_mode != InputOrderMode::Linear { @@ -1146,6 +1324,9 @@ impl DisplayAs for AggregateExec { if !a.is_empty() { writeln!(f, "aggr={}", a.join(", "))?; } + if let Some(config) = self.limit_options { + writeln!(f, "limit={}", config.limit)?; + } } } Ok(()) @@ -1162,13 +1343,13 @@ impl ExecutionPlan for AggregateExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } fn required_input_distribution(&self) -> Vec { match &self.mode { - AggregateMode::Partial => { + AggregateMode::Partial | AggregateMode::PartialReduce => { vec![Distribution::UnspecifiedDistribution] } AggregateMode::FinalPartitioned | AggregateMode::SinglePartitioned => { @@ -1205,16 +1386,18 @@ impl ExecutionPlan for AggregateExec { self: Arc, children: Vec>, ) -> 
Result> { + check_if_same_properties!(self, children); + let mut me = AggregateExec::try_new_with_schema( self.mode, - self.group_by.clone(), - self.aggr_expr.clone(), - self.filter_expr.clone(), + Arc::clone(&self.group_by), + self.aggr_expr.to_vec(), + Arc::clone(&self.filter_expr), Arc::clone(&children[0]), Arc::clone(&self.input_schema), Arc::clone(&self.schema), )?; - me.limit = self.limit; + me.limit_options = self.limit_options; me.dynamic_filter = self.dynamic_filter.clone(); Ok(Arc::new(me)) @@ -1233,10 +1416,6 @@ impl ExecutionPlan for AggregateExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { let child_statistics = self.input().partition_statistics(partition)?; self.statistics_inner(&child_statistics) @@ -1326,7 +1505,7 @@ impl ExecutionPlan for AggregateExec { ); // Include self dynamic filter when it's possible - if matches!(phase, FilterPushdownPhase::Post) + if phase == FilterPushdownPhase::Post && config.optimizer.enable_aggregate_dynamic_filter_pushdown && let Some(self_dyn_filter) = &self.dynamic_filter { @@ -1349,7 +1528,9 @@ impl ExecutionPlan for AggregateExec { // If this node tried to pushdown some dynamic filter before, now we check // if the child accept the filter - if matches!(phase, FilterPushdownPhase::Post) && self.dynamic_filter.is_some() { + if phase == FilterPushdownPhase::Post + && let Some(dyn_filter) = &self.dynamic_filter + { // let child_accepts_dyn_filter = child_pushdown_result // .self_filters // .first() @@ -1370,7 +1551,6 @@ impl ExecutionPlan for AggregateExec { // So here, we try to use ref count to determine if the dynamic filter // has actually be pushed down. 
// Issue: - let dyn_filter = self.dynamic_filter.as_ref().unwrap(); let child_accepts_dyn_filter = Arc::strong_count(dyn_filter) > 1; if !child_accepts_dyn_filter { @@ -1397,20 +1577,17 @@ fn create_schema( let mut fields = Vec::with_capacity(group_by.num_output_exprs() + aggr_expr.len()); fields.extend(group_by.output_fields(input_schema)?); - match mode { - AggregateMode::Partial => { - // in partial mode, the fields of the accumulator's state + match mode.output_mode() { + AggregateOutputMode::Final => { + // in final mode, the field with the final result of the accumulator for expr in aggr_expr { - fields.extend(expr.state_fields()?.iter().cloned()); + fields.push(expr.field()) } } - AggregateMode::Final - | AggregateMode::FinalPartitioned - | AggregateMode::Single - | AggregateMode::SinglePartitioned => { - // in final mode, the field with the final result of the accumulator + AggregateOutputMode::Partial => { + // in partial mode, the fields of the accumulator's state for expr in aggr_expr { - fields.push(expr.field()) + fields.extend(expr.state_fields()?.iter().cloned()); } } } @@ -1450,7 +1627,7 @@ fn get_aggregate_expr_req( // If the aggregation is performing a "second stage" calculation, // then ignore the ordering requirement. Ordering requirement applies // only to the aggregation input data. - if !agg_mode.is_first_stage() { + if agg_mode.input_mode() == AggregateInputMode::Partial { return None; } @@ -1616,10 +1793,8 @@ pub fn aggregate_expressions( mode: &AggregateMode, col_idx_base: usize, ) -> Result>>> { - match mode { - AggregateMode::Partial - | AggregateMode::Single - | AggregateMode::SinglePartitioned => Ok(aggr_expr + match mode.input_mode() { + AggregateInputMode::Raw => Ok(aggr_expr .iter() .map(|agg| { let mut result = agg.expressions(); @@ -1630,8 +1805,8 @@ pub fn aggregate_expressions( result }) .collect()), - // In this mode, we build the merge expressions of the aggregation. 
- AggregateMode::Final | AggregateMode::FinalPartitioned => { + AggregateInputMode::Partial => { + // In merge mode, we build the merge expressions of the aggregation. let mut col_idx_base = col_idx_base; aggr_expr .iter() @@ -1679,8 +1854,15 @@ pub fn finalize_aggregation( accumulators: &mut [AccumulatorItem], mode: &AggregateMode, ) -> Result> { - match mode { - AggregateMode::Partial => { + match mode.output_mode() { + AggregateOutputMode::Final => { + // Merge the state to the final value + accumulators + .iter_mut() + .map(|accumulator| accumulator.evaluate().and_then(|v| v.to_array())) + .collect() + } + AggregateOutputMode::Partial => { // Build the vector of states accumulators .iter_mut() @@ -1694,16 +1876,6 @@ pub fn finalize_aggregation( .flatten_ok() .collect() } - AggregateMode::Final - | AggregateMode::FinalPartitioned - | AggregateMode::Single - | AggregateMode::SinglePartitioned => { - // Merge the state to the final value - accumulators - .iter_mut() - .map(|accumulator| accumulator.evaluate().and_then(|v| v.to_array())) - .collect() - } } } @@ -1803,7 +1975,6 @@ mod tests { use super::*; use crate::RecordBatchStream; - use crate::coalesce_batches::CoalesceBatchesExec; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::common; use crate::common::collect; @@ -2249,14 +2420,17 @@ mod tests { struct TestYieldingExec { /// True if this exec should yield back to runtime the first time it is polled pub yield_first: bool, - cache: PlanProperties, + cache: Arc, } impl TestYieldingExec { fn new(yield_first: bool) -> Self { let schema = some_data().0; let cache = Self::compute_properties(schema); - Self { yield_first, cache } + Self { + yield_first, + cache: Arc::new(cache), + } } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. 
@@ -2297,7 +2471,7 @@ mod tests { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -2326,10 +2500,6 @@ mod tests { Ok(Box::pin(stream)) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { if partition.is_some() { return Ok(Statistics::new_unknown(self.schema().as_ref())); @@ -2601,17 +2771,9 @@ mod tests { #[tokio::test] async fn run_first_last_multi_partitions() -> Result<()> { - for use_coalesce_batches in [false, true] { - for is_first_acc in [false, true] { - for spill in [false, true] { - first_last_multi_partitions( - use_coalesce_batches, - is_first_acc, - spill, - 4200, - ) - .await? - } + for is_first_acc in [false, true] { + for spill in [false, true] { + first_last_multi_partitions(is_first_acc, spill, 4200).await? } } Ok(()) @@ -2654,15 +2816,7 @@ mod tests { .map(Arc::new) } - // This function either constructs the physical plan below, - // - // "AggregateExec: mode=Final, gby=[a@0 as a], aggr=[FIRST_VALUE(b)]", - // " CoalesceBatchesExec: target_batch_size=1024", - // " CoalescePartitionsExec", - // " AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[FIRST_VALUE(b)], ordering_mode=None", - // " DataSourceExec: partitions=4, partition_sizes=[1, 1, 1, 1]", - // - // or + // This function constructs the physical plan below, // // "AggregateExec: mode=Final, gby=[a@0 as a], aggr=[FIRST_VALUE(b)]", // " CoalescePartitionsExec", @@ -2672,7 +2826,6 @@ mod tests { // and checks whether the function `merge_batch` works correctly for // FIRST_VALUE and LAST_VALUE functions. 
async fn first_last_multi_partitions( - use_coalesce_batches: bool, is_first_acc: bool, spill: bool, max_memory: usize, @@ -2720,13 +2873,8 @@ mod tests { memory_exec, Arc::clone(&schema), )?); - let coalesce = if use_coalesce_batches { - let coalesce = Arc::new(CoalescePartitionsExec::new(aggregate_exec)); - Arc::new(CoalesceBatchesExec::new(coalesce, 1024)) as Arc - } else { - Arc::new(CoalescePartitionsExec::new(aggregate_exec)) - as Arc - }; + let coalesce = Arc::new(CoalescePartitionsExec::new(aggregate_exec)) + as Arc; let aggregate_final = Arc::new(AggregateExec::try_new( AggregateMode::Final, groups, @@ -3688,4 +3836,271 @@ mod tests { } Ok(()) } + + /// Tests that when the memory pool is too small to accommodate the sort + /// reservation during spill, the error is properly propagated as + /// ResourcesExhausted rather than silently exceeding memory limits. + #[tokio::test] + async fn test_sort_reservation_fails_during_spill() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("g", DataType::Int64, false), + Field::new("a", DataType::Float64, false), + Field::new("b", DataType::Float64, false), + Field::new("c", DataType::Float64, false), + Field::new("d", DataType::Float64, false), + Field::new("e", DataType::Float64, false), + ])); + + let batches = vec![vec![ + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int64Array::from(vec![1])), + Arc::new(Float64Array::from(vec![10.0])), + Arc::new(Float64Array::from(vec![20.0])), + Arc::new(Float64Array::from(vec![30.0])), + Arc::new(Float64Array::from(vec![40.0])), + Arc::new(Float64Array::from(vec![50.0])), + ], + )?, + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int64Array::from(vec![2])), + Arc::new(Float64Array::from(vec![11.0])), + Arc::new(Float64Array::from(vec![21.0])), + Arc::new(Float64Array::from(vec![31.0])), + Arc::new(Float64Array::from(vec![41.0])), + Arc::new(Float64Array::from(vec![51.0])), + ], + )?, + RecordBatch::try_new( + 
Arc::clone(&schema), + vec![ + Arc::new(Int64Array::from(vec![3])), + Arc::new(Float64Array::from(vec![12.0])), + Arc::new(Float64Array::from(vec![22.0])), + Arc::new(Float64Array::from(vec![32.0])), + Arc::new(Float64Array::from(vec![42.0])), + Arc::new(Float64Array::from(vec![52.0])), + ], + )?, + ]]; + + let scan = TestMemoryExec::try_new(&batches, Arc::clone(&schema), None)?; + + let aggr = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new( + vec![(col("g", schema.as_ref())?, "g".to_string())], + vec![], + vec![vec![false]], + false, + ), + vec![ + Arc::new( + AggregateExprBuilder::new( + avg_udaf(), + vec![col("a", schema.as_ref())?], + ) + .schema(Arc::clone(&schema)) + .alias("AVG(a)") + .build()?, + ), + Arc::new( + AggregateExprBuilder::new( + avg_udaf(), + vec![col("b", schema.as_ref())?], + ) + .schema(Arc::clone(&schema)) + .alias("AVG(b)") + .build()?, + ), + Arc::new( + AggregateExprBuilder::new( + avg_udaf(), + vec![col("c", schema.as_ref())?], + ) + .schema(Arc::clone(&schema)) + .alias("AVG(c)") + .build()?, + ), + Arc::new( + AggregateExprBuilder::new( + avg_udaf(), + vec![col("d", schema.as_ref())?], + ) + .schema(Arc::clone(&schema)) + .alias("AVG(d)") + .build()?, + ), + Arc::new( + AggregateExprBuilder::new( + avg_udaf(), + vec![col("e", schema.as_ref())?], + ) + .schema(Arc::clone(&schema)) + .alias("AVG(e)") + .build()?, + ), + ], + vec![None, None, None, None, None], + Arc::new(scan) as Arc, + Arc::clone(&schema), + )?); + + // Pool must be large enough for accumulation to start but too small for + // sort_memory after clearing. 
+ let task_ctx = new_spill_ctx(1, 500); + let result = collect(aggr.execute(0, Arc::clone(&task_ctx))?).await; + + match &result { + Ok(_) => panic!("Expected ResourcesExhausted error but query succeeded"), + Err(e) => { + let root = e.find_root(); + assert!( + matches!(root, DataFusionError::ResourcesExhausted(_)), + "Expected ResourcesExhausted, got: {root}", + ); + let msg = root.to_string(); + assert!( + msg.contains("Failed to reserve memory for sort during spill"), + "Expected sort reservation error, got: {msg}", + ); + } + } + + Ok(()) + } + + /// Tests that PartialReduce mode: + /// 1. Accepts state as input (like Final) + /// 2. Produces state as output (like Partial) + /// 3. Can be followed by a Final stage to get the correct result + /// + /// This simulates a tree-reduce pattern: + /// Partial -> PartialReduce -> Final + #[tokio::test] + async fn test_partial_reduce_mode() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::UInt32, false), + Field::new("b", DataType::Float64, false), + ])); + + // Produce two partitions of input data + let batch1 = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(UInt32Array::from(vec![1, 2, 3])), + Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0])), + ], + )?; + let batch2 = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(UInt32Array::from(vec![1, 2, 3])), + Arc::new(Float64Array::from(vec![40.0, 50.0, 60.0])), + ], + )?; + + let groups = + PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]); + let aggregates: Vec> = vec![Arc::new( + AggregateExprBuilder::new(sum_udaf(), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("SUM(b)") + .build()?, + )]; + + // Step 1: Partial aggregation on partition 1 + let input1 = + TestMemoryExec::try_new_exec(&[vec![batch1]], Arc::clone(&schema), None)?; + let partial1 = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + groups.clone(), + aggregates.clone(), + vec![None], 
+ input1, + Arc::clone(&schema), + )?); + + // Step 2: Partial aggregation on partition 2 + let input2 = + TestMemoryExec::try_new_exec(&[vec![batch2]], Arc::clone(&schema), None)?; + let partial2 = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + groups.clone(), + aggregates.clone(), + vec![None], + input2, + Arc::clone(&schema), + )?); + + // Collect partial results + let task_ctx = Arc::new(TaskContext::default()); + let partial_result1 = + crate::collect(Arc::clone(&partial1) as _, Arc::clone(&task_ctx)).await?; + let partial_result2 = + crate::collect(Arc::clone(&partial2) as _, Arc::clone(&task_ctx)).await?; + + // The partial results have state schema (group cols + accumulator state) + let partial_schema = partial1.schema(); + + // Step 3: PartialReduce — combine partial results, still producing state + let combined_input = TestMemoryExec::try_new_exec( + &[partial_result1, partial_result2], + Arc::clone(&partial_schema), + None, + )?; + // Coalesce into a single partition for the PartialReduce + let coalesced = Arc::new(CoalescePartitionsExec::new(combined_input)); + + let partial_reduce = Arc::new(AggregateExec::try_new( + AggregateMode::PartialReduce, + groups.clone(), + aggregates.clone(), + vec![None], + coalesced, + Arc::clone(&partial_schema), + )?); + + // Verify PartialReduce output schema matches Partial output schema + // (both produce state, not final values) + assert_eq!(partial_reduce.schema(), partial_schema); + + // Collect PartialReduce results + let reduce_result = + crate::collect(Arc::clone(&partial_reduce) as _, Arc::clone(&task_ctx)) + .await?; + + // Step 4: Final aggregation on the PartialReduce output + let final_input = TestMemoryExec::try_new_exec( + &[reduce_result], + Arc::clone(&partial_schema), + None, + )?; + let final_agg = Arc::new(AggregateExec::try_new( + AggregateMode::Final, + groups.clone(), + aggregates.clone(), + vec![None], + final_input, + Arc::clone(&partial_schema), + )?); + + let result = 
crate::collect(final_agg, Arc::clone(&task_ctx)).await?; + + // Expected: group 1 -> 10+40=50, group 2 -> 20+50=70, group 3 -> 30+60=90 + assert_snapshot!(batches_to_sort_string(&result), @r" + +---+--------+ + | a | SUM(b) | + +---+--------+ + | 1 | 50.0 | + | 2 | 70.0 | + | 3 | 90.0 | + +---+--------+ + "); + + Ok(()) + } } diff --git a/datafusion/physical-plan/src/aggregates/no_grouping.rs b/datafusion/physical-plan/src/aggregates/no_grouping.rs index a55d70ca6fb2..a7dd7c9a66cb 100644 --- a/datafusion/physical-plan/src/aggregates/no_grouping.rs +++ b/datafusion/physical-plan/src/aggregates/no_grouping.rs @@ -18,8 +18,9 @@ //! Aggregate without grouping columns use crate::aggregates::{ - AccumulatorItem, AggrDynFilter, AggregateMode, DynamicFilterAggregateType, - aggregate_expressions, create_accumulators, finalize_aggregation, + AccumulatorItem, AggrDynFilter, AggregateInputMode, AggregateMode, + DynamicFilterAggregateType, aggregate_expressions, create_accumulators, + finalize_aggregation, }; use crate::metrics::{BaselineMetrics, RecordOutput}; use crate::{RecordBatchStream, SendableRecordBatchStream}; @@ -61,7 +62,7 @@ struct AggregateStreamInner { mode: AggregateMode, input: SendableRecordBatchStream, aggregate_expressions: Vec>>, - filter_expressions: Vec>>, + filter_expressions: Arc<[Option>]>, // ==== Runtime States/Buffers ==== accumulators: Vec, @@ -160,6 +161,8 @@ impl AggregateStreamInner { return Ok(()); }; + let mut bounds_changed = false; + for acc_info in &filter_state.supported_accumulators_info { let acc = self.accumulators @@ -175,20 +178,27 @@ impl AggregateStreamInner { let current_bound = acc.evaluate()?; { let mut bound = acc_info.shared_bound.lock(); - match acc_info.aggr_type { + let new_bound = match acc_info.aggr_type { DynamicFilterAggregateType::Max => { - *bound = scalar_max(&bound, ¤t_bound)?; + scalar_max(&bound, ¤t_bound)? 
} DynamicFilterAggregateType::Min => { - *bound = scalar_min(&bound, ¤t_bound)?; + scalar_min(&bound, ¤t_bound)? } + }; + if new_bound != *bound { + *bound = new_bound; + bounds_changed = true; } } } - // Step 2: Sync the dynamic filter physical expression with reader - let predicate = self.build_dynamic_filter_from_accumulator_bounds()?; - filter_state.filter.update(predicate)?; + // Step 2: Sync the dynamic filter physical expression with reader, + // but only if any bound actually changed. + if bounds_changed { + let predicate = self.build_dynamic_filter_from_accumulator_bounds()?; + filter_state.filter.update(predicate)?; + } Ok(()) } @@ -276,19 +286,15 @@ impl AggregateStream { partition: usize, ) -> Result { let agg_schema = Arc::clone(&agg.schema); - let agg_filter_expr = agg.filter_expr.clone(); + let agg_filter_expr = Arc::clone(&agg.filter_expr); let baseline_metrics = BaselineMetrics::new(&agg.metrics, partition); let input = agg.input.execute(partition, Arc::clone(context))?; let aggregate_expressions = aggregate_expressions(&agg.aggr_expr, &agg.mode, 0)?; - let filter_expressions = match agg.mode { - AggregateMode::Partial - | AggregateMode::Single - | AggregateMode::SinglePartitioned => agg_filter_expr, - AggregateMode::Final | AggregateMode::FinalPartitioned => { - vec![None; agg.aggr_expr.len()] - } + let filter_expressions = match agg.mode.input_mode() { + AggregateInputMode::Raw => agg_filter_expr, + AggregateInputMode::Partial => vec![None; agg.aggr_expr.len()].into(), }; let accumulators = create_accumulators(&agg.aggr_expr)?; @@ -455,13 +461,9 @@ fn aggregate_batch( // 1.4 let size_pre = accum.size(); - let res = match mode { - AggregateMode::Partial - | AggregateMode::Single - | AggregateMode::SinglePartitioned => accum.update_batch(&values), - AggregateMode::Final | AggregateMode::FinalPartitioned => { - accum.merge_batch(&values) - } + let res = match mode.input_mode() { + AggregateInputMode::Raw => accum.update_batch(&values), + 
AggregateInputMode::Partial => accum.merge_batch(&values), }; let size_post = accum.size(); allocated += size_post.saturating_sub(size_pre); diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 1ae720271111..8a45e4b503d5 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -26,13 +26,12 @@ use super::order::GroupOrdering; use crate::aggregates::group_values::{GroupByMetrics, GroupValues, new_group_values}; use crate::aggregates::order::GroupOrderingFull; use crate::aggregates::{ - AggregateMode, PhysicalGroupBy, create_schema, evaluate_group_by, evaluate_many, - evaluate_optional, + AggregateInputMode, AggregateMode, AggregateOutputMode, PhysicalGroupBy, + create_schema, evaluate_group_by, evaluate_many, evaluate_optional, }; use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput}; -use crate::sorts::sort::sort_batch; use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder}; -use crate::spill::spill_manager::SpillManager; +use crate::spill::spill_manager::{GetSlicedSize, SpillManager}; use crate::{PhysicalExpr, aggregates, metrics}; use crate::{RecordBatchStream, SendableRecordBatchStream}; @@ -40,7 +39,7 @@ use arrow::array::*; use arrow::datatypes::SchemaRef; use datafusion_common::{ DataFusionError, Result, assert_eq_or_internal_err, assert_or_internal_err, - internal_err, + internal_err, resources_datafusion_err, }; use datafusion_execution::TaskContext; use datafusion_execution::memory_pool::proxy::VecAllocExt; @@ -51,7 +50,9 @@ use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{GroupsAccumulatorAdapter, PhysicalSortExpr}; use datafusion_physical_expr_common::sort_expr::LexOrdering; +use crate::sorts::IncrementalSortIterator; use datafusion_common::instant::Instant; +use datafusion_common::utils::memory::get_record_batch_memory_size; use futures::ready; use 
futures::stream::{Stream, StreamExt}; use log::debug; @@ -377,10 +378,10 @@ pub(crate) struct GroupedHashAggregateStream { /// /// For example, for an aggregate like `SUM(x) FILTER (WHERE x >= 100)`, /// the filter expression is `x > 100`. - filter_expressions: Vec>>, + filter_expressions: Arc<[Option>]>, /// GROUP BY expressions - group_by: PhysicalGroupBy, + group_by: Arc, /// max rows in output RecordBatches batch_size: usize, @@ -465,8 +466,8 @@ impl GroupedHashAggregateStream { ) -> Result { debug!("Creating GroupedHashAggregateStream"); let agg_schema = Arc::clone(&agg.schema); - let agg_group_by = agg.group_by.clone(); - let agg_filter_expr = agg.filter_expr.clone(); + let agg_group_by = Arc::clone(&agg.group_by); + let agg_filter_expr = Arc::clone(&agg.filter_expr); let batch_size = context.session_config().batch_size(); let input = agg.input.execute(partition, Arc::clone(context))?; @@ -475,7 +476,7 @@ impl GroupedHashAggregateStream { let timer = baseline_metrics.elapsed_compute().timer(); - let aggregate_exprs = agg.aggr_expr.clone(); + let aggregate_exprs = Arc::clone(&agg.aggr_expr); // arguments for each aggregate, one vec of expressions per // aggregate @@ -491,13 +492,9 @@ impl GroupedHashAggregateStream { agg_group_by.num_group_exprs(), )?; - let filter_expressions = match agg.mode { - AggregateMode::Partial - | AggregateMode::Single - | AggregateMode::SinglePartitioned => agg_filter_expr, - AggregateMode::Final | AggregateMode::FinalPartitioned => { - vec![None; agg.aggr_expr.len()] - } + let filter_expressions = match agg.mode.input_mode() { + AggregateInputMode::Raw => agg_filter_expr, + AggregateInputMode::Partial => vec![None; agg.aggr_expr.len()].into(), }; // Instantiate the accumulators @@ -679,7 +676,7 @@ impl GroupedHashAggregateStream { group_ordering, input_done: false, spill_state, - group_values_soft_limit: agg.limit, + group_values_soft_limit: agg.limit_options().map(|config| config.limit()), skip_aggregation_probe, reduction_factor, 
}) @@ -982,29 +979,24 @@ impl GroupedHashAggregateStream { // Call the appropriate method on each aggregator with // the entire input row and the relevant group indexes - match self.mode { - AggregateMode::Partial - | AggregateMode::Single - | AggregateMode::SinglePartitioned - if !self.spill_state.is_stream_merging => - { - acc.update_batch( - values, - group_indices, - opt_filter, - total_num_groups, - )?; - } - _ => { - assert_or_internal_err!( - opt_filter.is_none(), - "aggregate filter should be applied in partial stage, there should be no filter in final stage" - ); - - // if aggregation is over intermediate states, - // use merge - acc.merge_batch(values, group_indices, None, total_num_groups)?; - } + if self.mode.input_mode() == AggregateInputMode::Raw + && !self.spill_state.is_stream_merging + { + acc.update_batch( + values, + group_indices, + opt_filter, + total_num_groups, + )?; + } else { + assert_or_internal_err!( + opt_filter.is_none(), + "aggregate filter should be applied in partial stage, there should be no filter in final stage" + ); + + // if aggregation is over intermediate states, + // use merge + acc.merge_batch(values, group_indices, None, total_num_groups)?; } self.group_by_metrics .aggregation_time @@ -1045,7 +1037,19 @@ impl GroupedHashAggregateStream { self.group_values.len() }; - if let Some(batch) = self.emit(EmitTo::First(n), false)? { + // Clamp to the sort boundary when using partial group ordering, + // otherwise remove_groups panics (#20445). + let n = match &self.group_ordering { + GroupOrdering::None => n, + _ => match self.group_ordering.emit_to() { + Some(EmitTo::First(max)) => n.min(max), + _ => 0, + }, + }; + + if n > 0 + && let Some(batch) = self.emit(EmitTo::First(n), false)? 
+ { Ok(Some(ExecutionState::ProducingOutput(batch))) } else { Err(oom) @@ -1057,10 +1061,27 @@ impl GroupedHashAggregateStream { fn update_memory_reservation(&mut self) -> Result<()> { let acc = self.accumulators.iter().map(|x| x.size()).sum::(); - let new_size = acc + let groups_and_acc_size = acc + self.group_values.size() + self.group_ordering.size() + self.current_group_indices.allocated_size(); + + // Reserve extra headroom for sorting during potential spill. + // When OOM triggers, group_aggregate_batch has already processed the + // latest input batch, so the internal state may have grown well beyond + // the last successful reservation. The emit batch reflects this larger + // actual state, and the sort needs memory proportional to it. + // By reserving headroom equal to the data size, we trigger OOM earlier + // (before too much data accumulates), ensuring the freed reservation + // after clear_shrink is sufficient to cover the sort memory. + let sort_headroom = + if self.oom_mode == OutOfMemoryMode::Spill && !self.group_values.is_empty() { + acc + self.group_values.size() + } else { + 0 + }; + + let new_size = groups_and_acc_size + sort_headroom; let reservation_result = self.reservation.try_resize(new_size); if reservation_result.is_ok() { @@ -1092,17 +1113,12 @@ impl GroupedHashAggregateStream { // Next output each aggregate value for acc in self.accumulators.iter_mut() { - match self.mode { - AggregateMode::Partial => output.extend(acc.state(emit_to)?), - _ if spilling => { - // If spilling, output partial state because the spilled data will be - // merged and re-evaluated later. - output.extend(acc.state(emit_to)?) - } - AggregateMode::Final - | AggregateMode::FinalPartitioned - | AggregateMode::Single - | AggregateMode::SinglePartitioned => output.push(acc.evaluate(emit_to)?), + if self.mode.output_mode() == AggregateOutputMode::Final && !spilling { + output.push(acc.evaluate(emit_to)?) 
+ } else { + // Output partial state: either because we're in a non-final mode, + // or because we're spilling and will merge/re-evaluate later. + output.extend(acc.state(emit_to)?) } } drop(timer); @@ -1124,17 +1140,47 @@ impl GroupedHashAggregateStream { let Some(emit) = self.emit(EmitTo::All, true)? else { return Ok(()); }; - let sorted = sort_batch(&emit, &self.spill_state.spill_expr, None)?; - // Spill sorted state to disk + // Free accumulated state now that data has been emitted into `emit`. + // This must happen before reserving sort memory so the pool has room. + // Use 0 to minimize allocated capacity and maximize memory available for sorting. + self.clear_shrink(0); + self.update_memory_reservation()?; + + let batch_size_ratio = self.batch_size as f32 / emit.num_rows() as f32; + let batch_memory = get_record_batch_memory_size(&emit); + // The maximum worst case for a sort is 2X the original underlying buffers(regardless of slicing) + // First we get the underlying buffers' size, then we get the sliced("actual") size of the batch, + // and multiply it by the ratio of batch_size to actual size to get the estimated memory needed for sorting the batch. + // If something goes wrong in get_sliced_size()(double counting or something), + // we fall back to the worst case. + let sort_memory = (batch_memory + + (emit.get_sliced_size()? as f32 * batch_size_ratio) as usize) + .min(batch_memory * 2); + + // If we can't grow even that, we have no choice but to return an error since we can't spill to disk without sorting the data first. 
+ self.reservation.try_grow(sort_memory).map_err(|err| { + resources_datafusion_err!( + "Failed to reserve memory for sort during spill: {err}" + ) + })?; + + let sorted_iter = IncrementalSortIterator::new( + emit, + self.spill_state.spill_expr.clone(), + self.batch_size, + ); let spillfile = self .spill_state .spill_manager - .spill_record_batch_by_size_and_return_max_batch_memory( - &sorted, + .spill_record_batch_iter_and_return_max_batch_memory( + sorted_iter, "HashAggSpill", - self.batch_size, )?; + + // Shrink the memory we allocated for sorting as the sorting is fully done at this point. + self.reservation.shrink(sort_memory); + match spillfile { Some((spillfile, max_record_batch_memory)) => { self.spill_state.spills.push(SortedSpillFile { @@ -1152,14 +1198,14 @@ impl GroupedHashAggregateStream { Ok(()) } - /// Clear memory and shirk capacities to the size of the batch. + /// Clear memory and shrink capacities to the given number of rows. fn clear_shrink(&mut self, num_rows: usize) { self.group_values.clear_shrink(num_rows); self.current_group_indices.clear(); self.current_group_indices.shrink_to(num_rows); } - /// Clear memory and shirk capacities to zero. + /// Clear memory and shrink capacities to zero. fn clear_all(&mut self) { self.clear_shrink(0); } @@ -1198,7 +1244,7 @@ impl GroupedHashAggregateStream { // instead. // Spilling to disk and reading back also ensures batch size is consistent // rather than potentially having one significantly larger last batch. - self.spill()?; // TODO: use sort_batch_chunked instead? + self.spill()?; // Mark that we're switching to stream merging mode. self.spill_state.is_stream_merging = true; @@ -1221,6 +1267,18 @@ impl GroupedHashAggregateStream { // on the grouping columns. self.group_ordering = GroupOrdering::Full(GroupOrderingFull::new()); + // Recreate group_values to use streaming mode (GroupValuesColumn + // with scalarized_intern) which preserves input row order, as required + // by GroupOrderingFull. 
This is only needed for multi-column group by, + // since single-column uses GroupValuesPrimitive which is always safe. + let group_schema = self + .spill_state + .merging_group_by + .group_schema(&self.spill_state.spill_schema)?; + if group_schema.fields().len() > 1 { + self.group_values = new_group_values(group_schema, &self.group_ordering)?; + } + // Use `OutOfMemoryMode::ReportError` from this point on // to ensure we don't spill the spilled data to disk again. self.oom_mode = OutOfMemoryMode::ReportError; @@ -1305,6 +1363,7 @@ impl GroupedHashAggregateStream { #[cfg(test)] mod tests { use super::*; + use crate::InputOrderMode; use crate::execution_plan::ExecutionPlan; use crate::test::TestMemoryExec; use arrow::array::{Int32Array, Int64Array}; @@ -1567,4 +1626,88 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_emit_early_with_partially_sorted() -> Result<()> { + // Reproducer for #20445: EmitEarly with PartiallySorted panics in + // remove_groups because it emits more groups than the sort boundary. + let schema = Arc::new(Schema::new(vec![ + Field::new("sort_col", DataType::Int32, false), + Field::new("group_col", DataType::Int32, false), + Field::new("value_col", DataType::Int64, false), + ])); + + // All rows share sort_col=1 (no sort boundary), with unique group_col + // values to create many groups and trigger memory pressure. 
+ let n = 256; + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1; n])), + Arc::new(Int32Array::from((0..n as i32).collect::>())), + Arc::new(Int64Array::from(vec![1; n])), + ], + )?; + + let runtime = RuntimeEnvBuilder::default() + .with_memory_limit(4096, 1.0) + .build_arc()?; + let mut task_ctx = TaskContext::default().with_runtime(runtime); + let mut cfg = task_ctx.session_config().clone(); + cfg = cfg.set( + "datafusion.execution.batch_size", + &datafusion_common::ScalarValue::UInt64(Some(128)), + ); + cfg = cfg.set( + "datafusion.execution.skip_partial_aggregation_probe_rows_threshold", + &datafusion_common::ScalarValue::UInt64(Some(u64::MAX)), + ); + task_ctx = task_ctx.with_session_config(cfg); + let task_ctx = Arc::new(task_ctx); + + let ordering = LexOrdering::new(vec![PhysicalSortExpr::new_default(Arc::new( + Column::new("sort_col", 0), + ) + as _)]) + .unwrap(); + let exec = TestMemoryExec::try_new(&[vec![batch]], Arc::clone(&schema), None)? 
+ .try_with_sort_information(vec![ordering])?; + let exec = Arc::new(TestMemoryExec::update_cache(&Arc::new(exec))); + + // GROUP BY sort_col, group_col with input sorted on sort_col + // gives PartiallySorted([0]) + let aggregate_exec = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(vec![ + (col("sort_col", &schema)?, "sort_col".to_string()), + (col("group_col", &schema)?, "group_col".to_string()), + ]), + vec![Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![col("value_col", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("count_value") + .build()?, + )], + vec![None], + exec, + Arc::clone(&schema), + )?; + assert!(matches!( + aggregate_exec.input_order_mode(), + InputOrderMode::PartiallySorted(_) + )); + + // Must not panic with "assertion failed: *current_sort >= n" + let mut stream = GroupedHashAggregateStream::new(&aggregate_exec, &task_ctx, 0)?; + while let Some(result) = stream.next().await { + if let Err(e) = result { + if e.to_string().contains("Resources exhausted") { + break; + } + return Err(e); + } + } + + Ok(()) + } } diff --git a/datafusion/physical-plan/src/aggregates/topk/hash_table.rs b/datafusion/physical-plan/src/aggregates/topk/hash_table.rs index 4a3f3ac258f9..418ec49ddd71 100644 --- a/datafusion/physical-plan/src/aggregates/topk/hash_table.rs +++ b/datafusion/physical-plan/src/aggregates/topk/hash_table.rs @@ -72,6 +72,19 @@ pub trait ArrowHashTable { fn find_or_insert(&mut self, row_idx: usize, replace_idx: usize) -> (usize, bool); } +/// Returns true if the given data type can be used as a top-K aggregation hash key. +/// +/// Supported types include Arrow primitives (integers, floats, decimals, intervals) +/// and UTF-8 strings (`Utf8`, `LargeUtf8`, `Utf8View`). This is used internally by +/// `PriorityMap::supports()` to validate grouping key type compatibility. 
+pub fn is_supported_hash_key_type(kt: &DataType) -> bool { + kt.is_primitive() + || matches!( + kt, + DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8 + ) +} + // An implementation of ArrowHashTable for String keys pub struct StringHashTable { owned: ArrayRef, @@ -108,6 +121,34 @@ impl StringHashTable { data_type, } } + + /// Extracts the string value at the given row index, handling nulls and different string types. + /// + /// Returns `None` if the value is null, otherwise `Some(value.to_string())`. + fn extract_string_value(&self, row_idx: usize) -> Option { + let is_null_and_value = match self.data_type { + DataType::Utf8 => { + let arr = self.owned.as_string::(); + (arr.is_null(row_idx), arr.value(row_idx)) + } + DataType::LargeUtf8 => { + let arr = self.owned.as_string::(); + (arr.is_null(row_idx), arr.value(row_idx)) + } + DataType::Utf8View => { + let arr = self.owned.as_string_view(); + (arr.is_null(row_idx), arr.value(row_idx)) + } + _ => panic!("Unsupported data type"), + }; + + let (is_null, value) = is_null_and_value; + if is_null { + None + } else { + Some(value.to_string()) + } + } } impl ArrowHashTable for StringHashTable { @@ -138,63 +179,15 @@ impl ArrowHashTable for StringHashTable { } fn find_or_insert(&mut self, row_idx: usize, replace_idx: usize) -> (usize, bool) { - let id = match self.data_type { - DataType::Utf8 => { - let ids = self - .owned - .as_any() - .downcast_ref::() - .expect("Expected StringArray for DataType::Utf8"); - if ids.is_null(row_idx) { - None - } else { - Some(ids.value(row_idx)) - } - } - DataType::LargeUtf8 => { - let ids = self - .owned - .as_any() - .downcast_ref::() - .expect("Expected LargeStringArray for DataType::LargeUtf8"); - if ids.is_null(row_idx) { - None - } else { - Some(ids.value(row_idx)) - } - } - DataType::Utf8View => { - let ids = self - .owned - .as_any() - .downcast_ref::() - .expect("Expected StringViewArray for DataType::Utf8View"); - if ids.is_null(row_idx) { - None - } else { - 
Some(ids.value(row_idx)) - } - } - _ => panic!("Unsupported data type"), - }; - - // TODO: avoid double lookup by using entry API - - let hash = self.rnd.hash_one(id); - if let Some(map_idx) = self - .map - .find(hash, |mi| id == mi.as_ref().map(|id| id.as_str())) - { - return (map_idx, false); - } + let id = self.extract_string_value(row_idx); - // we're full and this is a better value, so remove the worst - let heap_idx = self.map.remove_if_full(replace_idx); + // Compute hash and create equality closure for hash table lookup. + let hash = self.rnd.hash_one(id.as_deref()); + let id_for_eq = id.clone(); + let eq = move |mi: &Option| id_for_eq.as_deref() == mi.as_deref(); - // add the new group - let id = id.map(|id| id.to_string()); - let map_idx = self.map.insert(hash, &id, heap_idx); - (map_idx, true) + // Use entry API to avoid double lookup + self.map.find_or_insert(hash, id, replace_idx, eq) } } @@ -260,19 +253,12 @@ where } else { Some(ids.value(row_idx)) }; - + // Compute hash and create equality closure for hash table lookup. 
let hash: u64 = id.hash(&self.rnd); - // TODO: avoid double lookup by using entry API - if let Some(map_idx) = self.map.find(hash, |mi| id == *mi) { - return (map_idx, false); - } - - // we're full and this is a better value, so remove the worst - let heap_idx = self.map.remove_if_full(replace_idx); + let eq = |mi: &Option| id == *mi; - // add the new group - let map_idx = self.map.insert(hash, &id, heap_idx); - (map_idx, true) + // Use entry API to avoid double lookup + self.map.find_or_insert(hash, id, replace_idx, eq) } } @@ -287,11 +273,6 @@ impl TopKHashTable { } } - pub fn find(&self, hash: u64, mut eq: impl FnMut(&ID) -> bool) -> Option { - let eq = |&idx: &usize| eq(&self.store[idx].as_ref().unwrap().id); - self.map.find(hash, eq).copied() - } - pub fn heap_idx_at(&self, map_idx: usize) -> usize { self.store[map_idx].as_ref().unwrap().heap_idx } @@ -324,8 +305,27 @@ impl TopKHashTable { } } - pub fn insert(&mut self, hash: u64, id: &ID, heap_idx: usize) -> usize { - let mi = HashTableItem::new(hash, id.clone(), heap_idx); + /// Find an existing entry or insert a new one, avoiding double hash table lookup. + /// Returns (map_idx, is_new) where is_new indicates if this was a new insertion. + /// If inserting a new entry and the table is full, replaces the entry at replace_idx. 
+ pub fn find_or_insert( + &mut self, + hash: u64, + id: ID, + replace_idx: usize, + mut eq: impl FnMut(&ID) -> bool, + ) -> (usize, bool) { + // Check if entry exists - this is the only hash table lookup + { + let eq_fn = |idx: &usize| eq(&self.store[*idx].as_ref().unwrap().id); + if let Some(&map_idx) = self.map.find(hash, eq_fn) { + return (map_idx, false); + } + } + + // Entry doesn't exist - compute heap_idx and prepare item + let heap_idx = self.remove_if_full(replace_idx); + let mi = HashTableItem::new(hash, id, heap_idx); let store_idx = if let Some(idx) = self.free_index.take() { self.store[idx] = Some(mi); idx @@ -334,19 +334,15 @@ impl TopKHashTable { self.store.len() - 1 }; + // Reserve space if needed let hasher = |idx: &usize| self.store[*idx].as_ref().unwrap().hash; if self.map.len() == self.map.capacity() { self.map.reserve(self.limit, hasher); } - let eq_fn = |idx: &usize| self.store[*idx].as_ref().unwrap().id == *id; - match self.map.entry(hash, eq_fn, hasher) { - Entry::Occupied(_) => unreachable!("Item should not exist"), - Entry::Vacant(vacant) => { - vacant.insert(store_idx); - } - } - store_idx + // Insert without checking again since we already confirmed it doesn't exist + self.map.insert_unique(hash, store_idx, hasher); + (store_idx, true) } pub fn len(&self) -> usize { @@ -449,15 +445,29 @@ mod tests { #[test] fn should_resize_properly() -> Result<()> { let mut heap_to_map = BTreeMap::::new(); + // Create TopKHashTable with limit=5 and capacity=3 to force resizing let mut map = TopKHashTable::>::new(5, 3); - for (heap_idx, id) in vec!["1", "2", "3", "4", "5"].into_iter().enumerate() { + + // Insert 5 entries, tracking the heap-to-map index mapping + for (heap_idx, id) in ["1", "2", "3", "4", "5"].iter().enumerate() { + let value = Some(id.to_string()); let hash = heap_idx as u64; - let map_idx = map.insert(hash, &Some(id.to_string()), heap_idx); - let _ = heap_to_map.insert(heap_idx, map_idx); + let (map_idx, is_new) = + 
map.find_or_insert(hash, value.clone(), heap_idx, |v| *v == value); + assert!(is_new, "Entry should be new"); + heap_to_map.insert(heap_idx, map_idx); } + // Verify all 5 entries are present + assert_eq!(map.len(), 5); + + // Verify that the hash table resized properly (capacity should have grown beyond 3) + // This is implicit - if it didn't resize, insertions would have failed or been slow + + // Drain all values in heap order let (_heap_idxs, map_idxs): (Vec<_>, Vec<_>) = heap_to_map.into_iter().unzip(); let ids = map.take_all(map_idxs); + assert_eq!( format!("{ids:?}"), r#"[Some("1"), Some("2"), Some("3"), Some("4"), Some("5")]"# diff --git a/datafusion/physical-plan/src/aggregates/topk/heap.rs b/datafusion/physical-plan/src/aggregates/topk/heap.rs index b4569c3d0811..9f0b697ccabe 100644 --- a/datafusion/physical-plan/src/aggregates/topk/heap.rs +++ b/datafusion/physical-plan/src/aggregates/topk/heap.rs @@ -15,10 +15,18 @@ // specific language governing permissions and limitations // under the License. -//! A custom binary heap implementation for performant top K aggregation +//! A custom binary heap implementation for performant top K aggregation. +//! +//! the `new_heap` //! factory function selects an appropriate heap implementation +//! based on the Arrow data type. +//! +//! Supported value types include Arrow primitives (integers, floats, decimals, intervals) +//! and UTF-8 strings (`Utf8`, `LargeUtf8`, `Utf8View`) using lexicographic ordering. use arrow::array::{ArrayRef, ArrowPrimitiveType, PrimitiveArray, downcast_primitive}; +use arrow::array::{LargeStringBuilder, StringBuilder, StringViewBuilder}; use arrow::array::{ + StringArray, cast::AsArray, types::{IntervalDayTime, IntervalMonthDayNano}, }; @@ -156,6 +164,164 @@ where } } +/// An implementation of `ArrowHeap` that deals with string values. +/// +/// Supports all three UTF-8 string types: `Utf8`, `LargeUtf8`, and `Utf8View`. 
+/// String values are compared lexicographically using the compare-first pattern: +/// borrowed strings are compared before allocation, and only allocated when the +/// heap confirms they improve the top-K set. +/// +pub struct StringHeap { + batch: ArrayRef, + heap: TopKHeap>, + desc: bool, + data_type: DataType, +} + +impl StringHeap { + pub fn new(limit: usize, desc: bool, data_type: DataType) -> Self { + let batch: ArrayRef = Arc::new(StringArray::from(Vec::<&str>::new())); + Self { + batch, + heap: TopKHeap::new(limit, desc), + desc, + data_type, + } + } + + /// Extracts a string value from the current batch at the given row index. + /// + /// Panics if the row index is out of bounds or if the data type is not one of + /// the supported UTF-8 string types. + /// + /// Note: Null values should not appear in the input; the aggregation layer + /// ensures nulls are filtered before reaching this code. + fn value(&self, row_idx: usize) -> &str { + extract_string_value(&self.batch, &self.data_type, row_idx) + } +} + +/// Helper to extract a string value from an ArrayRef at a given index. +/// +/// Supports `Utf8`, `LargeUtf8`, and `Utf8View` data types. +/// +/// # Panics +/// Panics if the index is out of bounds or if the data type is unsupported. 
+fn extract_string_value<'a>( + batch: &'a ArrayRef, + data_type: &DataType, + idx: usize, +) -> &'a str { + match data_type { + DataType::Utf8 => batch.as_string::().value(idx), + DataType::LargeUtf8 => batch.as_string::().value(idx), + DataType::Utf8View => batch.as_string_view().value(idx), + _ => unreachable!("Unsupported string type: {:?}", data_type), + } +} + +impl ArrowHeap for StringHeap { + fn set_batch(&mut self, vals: ArrayRef) { + self.batch = vals; + } + + fn is_worse(&self, row_idx: usize) -> bool { + if !self.heap.is_full() { + return false; + } + // Compare borrowed `&str` against the worst heap value first to avoid + // allocating a `String` unless this row would actually replace an + // existing heap entry. + let new_val = self.value(row_idx); + let worst_val = self.heap.worst_val().expect("Missing root"); + match worst_val { + None => false, + Some(worst_str) => { + (!self.desc && new_val > worst_str.as_str()) + || (self.desc && new_val < worst_str.as_str()) + } + } + } + + fn worst_map_idx(&self) -> usize { + self.heap.worst_map_idx() + } + + fn insert(&mut self, row_idx: usize, map_idx: usize, map: &mut Vec<(usize, usize)>) { + // When appending (heap not full) we must allocate to own the string + // because it will be stored in the heap. For replacements we avoid + // allocation until `replace_if_better` confirms a replacement is + // necessary. + let new_str = self.value(row_idx).to_string(); + let new_val = Some(new_str); + self.heap.append_or_replace(new_val, map_idx, map); + } + + fn replace_if_better( + &mut self, + heap_idx: usize, + row_idx: usize, + map: &mut Vec<(usize, usize)>, + ) { + let new_str = self.value(row_idx); + let existing = self.heap.heap[heap_idx] + .as_ref() + .expect("Missing heap item"); + + // Compare borrowed reference first—no allocation yet. + // We compare the borrowed `&str` with the stored `Option` and + // only allocate (`to_string()`) when a replacement is required. 
+ match &existing.val { + None => { + // Existing is null; new value always wins + let new_val = Some(new_str.to_string()); + self.heap.replace_if_better(heap_idx, new_val, map); + } + Some(existing_str) => { + // Compare borrowed strings first + if (!self.desc && new_str < existing_str.as_str()) + || (self.desc && new_str > existing_str.as_str()) + { + let new_val = Some(new_str.to_string()); + self.heap.replace_if_better(heap_idx, new_val, map); + } + // Else: no improvement, no allocation + } + } + } + + fn drain(&mut self) -> (ArrayRef, Vec) { + let (vals, map_idxs) = self.heap.drain(); + // Use Arrow builders to safely construct arrays from the owned + // `Option` values. Builders avoid needing to maintain + // references to temporary storage. + + // Macro to eliminate duplication across string builder types. + // All three builders share the same interface for append_value, + // append_null, and finish, differing only in their concrete types. + macro_rules! build_string_array { + ($builder_type:ty) => {{ + let mut builder = <$builder_type>::new(); + for val in vals { + match val { + Some(s) => builder.append_value(&s), + None => builder.append_null(), + } + } + Arc::new(builder.finish()) + }}; + } + + let arr: ArrayRef = match self.data_type { + DataType::Utf8 => build_string_array!(StringBuilder), + DataType::LargeUtf8 => build_string_array!(LargeStringBuilder), + DataType::Utf8View => build_string_array!(StringViewBuilder), + _ => unreachable!("Unsupported string type: {:?}", self.data_type), + }; + (arr, map_idxs) + } +} + impl TopKHeap { pub fn new(limit: usize, desc: bool) -> Self { Self { @@ -438,11 +604,31 @@ compare_integer!(u8, u16, u32, u64); compare_integer!(IntervalDayTime, IntervalMonthDayNano); compare_float!(f16, f32, f64); +/// Returns true if the given data type can be stored in a top-K aggregation heap. 
+/// +/// Supported types include Arrow primitives (integers, floats, decimals, intervals) +/// and UTF-8 strings (`Utf8`, `LargeUtf8`, `Utf8View`). This is used internally by +/// `PriorityMap::supports()` to validate aggregate value type compatibility. +pub fn is_supported_heap_type(vt: &DataType) -> bool { + vt.is_primitive() + || matches!( + vt, + DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8 + ) +} + pub fn new_heap( limit: usize, desc: bool, vt: DataType, ) -> Result> { + if matches!( + vt, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View + ) { + return Ok(Box::new(StringHeap::new(limit, desc, vt))); + } + macro_rules! downcast_helper { ($vt:ty, $d:ident) => { return Ok(Box::new(PrimitiveHeap::<$vt>::new(limit, desc, vt))) @@ -454,7 +640,9 @@ pub fn new_heap( _ => {} } - Err(exec_datafusion_err!("Can't group type: {vt:?}")) + Err(exec_datafusion_err!( + "Unsupported TopK aggregate value type: {vt:?}" + )) } #[cfg(test)] diff --git a/datafusion/physical-plan/src/aggregates/topk/priority_map.rs b/datafusion/physical-plan/src/aggregates/topk/priority_map.rs index 8e093d213e78..c74b648d373c 100644 --- a/datafusion/physical-plan/src/aggregates/topk/priority_map.rs +++ b/datafusion/physical-plan/src/aggregates/topk/priority_map.rs @@ -373,6 +373,102 @@ mod tests { Ok(()) } + #[test] + fn should_track_lexicographic_min_utf8_value() -> Result<()> { + let ids: ArrayRef = Arc::new(Int64Array::from(vec![1, 1])); + let vals: ArrayRef = Arc::new(StringArray::from(vec!["zulu", "alpha"])); + let mut agg = PriorityMap::new(DataType::Int64, DataType::Utf8, 1, false)?; + agg.set_batch(ids, vals); + agg.insert(0)?; + agg.insert(1)?; + + let cols = agg.emit()?; + let batch = RecordBatch::try_new(test_schema_value(DataType::Utf8), cols)?; + let actual = format!("{}", pretty_format_batches(&[batch])?); + + assert_snapshot!(actual, @r#" ++----------+--------------+ +| trace_id | timestamp_ms | ++----------+--------------+ +| 1 | alpha | 
++----------+--------------+ + "#); + + Ok(()) + } + + #[test] + fn should_track_lexicographic_max_utf8_value_desc() -> Result<()> { + let ids: ArrayRef = Arc::new(Int64Array::from(vec![1, 1])); + let vals: ArrayRef = Arc::new(StringArray::from(vec!["alpha", "zulu"])); + let mut agg = PriorityMap::new(DataType::Int64, DataType::Utf8, 1, true)?; + agg.set_batch(ids, vals); + agg.insert(0)?; + agg.insert(1)?; + + let cols = agg.emit()?; + let batch = RecordBatch::try_new(test_schema_value(DataType::Utf8), cols)?; + let actual = format!("{}", pretty_format_batches(&[batch])?); + + assert_snapshot!(actual, @r#" ++----------+--------------+ +| trace_id | timestamp_ms | ++----------+--------------+ +| 1 | zulu | ++----------+--------------+ + "#); + + Ok(()) + } + + #[test] + fn should_track_large_utf8_values() -> Result<()> { + let ids: ArrayRef = Arc::new(Int64Array::from(vec![1, 1])); + let vals: ArrayRef = Arc::new(LargeStringArray::from(vec!["zulu", "alpha"])); + let mut agg = PriorityMap::new(DataType::Int64, DataType::LargeUtf8, 1, false)?; + agg.set_batch(ids, vals); + agg.insert(0)?; + agg.insert(1)?; + + let cols = agg.emit()?; + let batch = RecordBatch::try_new(test_schema_value(DataType::LargeUtf8), cols)?; + let actual = format!("{}", pretty_format_batches(&[batch])?); + + assert_snapshot!(actual, @r#" ++----------+--------------+ +| trace_id | timestamp_ms | ++----------+--------------+ +| 1 | alpha | ++----------+--------------+ + "#); + + Ok(()) + } + + #[test] + fn should_track_utf8_view_values() -> Result<()> { + let ids: ArrayRef = Arc::new(Int64Array::from(vec![1, 1])); + let vals: ArrayRef = Arc::new(StringViewArray::from(vec!["alpha", "zulu"])); + let mut agg = PriorityMap::new(DataType::Int64, DataType::Utf8View, 1, true)?; + agg.set_batch(ids, vals); + agg.insert(0)?; + agg.insert(1)?; + + let cols = agg.emit()?; + let batch = RecordBatch::try_new(test_schema_value(DataType::Utf8View), cols)?; + let actual = format!("{}", 
pretty_format_batches(&[batch])?); + + assert_snapshot!(actual, @r#" ++----------+--------------+ +| trace_id | timestamp_ms | ++----------+--------------+ +| 1 | zulu | ++----------+--------------+ + "#); + + Ok(()) + } + #[test] fn should_handle_null_ids() -> Result<()> { let ids: ArrayRef = Arc::new(StringArray::from(vec![Some("1"), None, None])); @@ -419,4 +515,11 @@ mod tests { Field::new("timestamp_ms", DataType::Int64, true), ])) } + + fn test_schema_value(value_type: DataType) -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("trace_id", DataType::Int64, true), + Field::new("timestamp_ms", value_type, true), + ])) + } } diff --git a/datafusion/physical-plan/src/aggregates/topk_stream.rs b/datafusion/physical-plan/src/aggregates/topk_stream.rs index 1096eb64d3ae..4aa566ccfcd0 100644 --- a/datafusion/physical-plan/src/aggregates/topk_stream.rs +++ b/datafusion/physical-plan/src/aggregates/topk_stream.rs @@ -19,6 +19,8 @@ use crate::aggregates::group_values::GroupByMetrics; use crate::aggregates::topk::priority_map::PriorityMap; +#[cfg(debug_assertions)] +use crate::aggregates::topk_types_supported; use crate::aggregates::{ AggregateExec, PhysicalGroupBy, aggregate_expressions, evaluate_group_by, evaluate_many, @@ -32,6 +34,7 @@ use datafusion_common::Result; use datafusion_common::internal_datafusion_err; use datafusion_execution::TaskContext; use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::metrics::RecordOutput; use futures::stream::{Stream, StreamExt}; use log::{Level, trace}; use std::pin::Pin; @@ -47,7 +50,7 @@ pub struct GroupedTopKAggregateStream { baseline_metrics: BaselineMetrics, group_by_metrics: GroupByMetrics, aggregate_arguments: Vec>>, - group_by: PhysicalGroupBy, + group_by: Arc, priority_map: PriorityMap, } @@ -59,20 +62,47 @@ impl GroupedTopKAggregateStream { limit: usize, ) -> Result { let agg_schema = Arc::clone(&aggr.schema); - let group_by = aggr.group_by.clone(); + let group_by = 
Arc::clone(&aggr.group_by); let input = aggr.input.execute(partition, Arc::clone(context))?; let baseline_metrics = BaselineMetrics::new(&aggr.metrics, partition); let group_by_metrics = GroupByMetrics::new(&aggr.metrics, partition); let aggregate_arguments = aggregate_expressions(&aggr.aggr_expr, &aggr.mode, group_by.expr.len())?; - let (val_field, desc) = aggr - .get_minmax_desc() - .ok_or_else(|| internal_datafusion_err!("Min/max required"))?; let (expr, _) = &aggr.group_expr().expr()[0]; let kt = expr.data_type(&aggr.input().schema())?; - let vt = val_field.data_type().clone(); + // Check if this is a MIN/MAX aggregate or a DISTINCT-like operation + let (vt, desc) = if let Some((val_field, desc)) = aggr.get_minmax_desc() { + // MIN/MAX case: use the aggregate output type + (val_field.data_type().clone(), desc) + } else { + // DISTINCT case: use the group key type and get ordering from limit_order_descending + // The ordering direction is set by the optimizer when it pushes down the limit + let desc = aggr + .limit_options() + .and_then(|config| config.descending) + .ok_or_else(|| { + internal_datafusion_err!( + "Ordering direction required for DISTINCT with limit" + ) + })?; + (kt.clone(), desc) + }; + + // Type validation is performed by the optimizer and can_use_topk() check. + // This debug assertion documents the contract without runtime overhead in release builds. + #[cfg(debug_assertions)] + { + debug_assert!( + topk_types_supported(&kt, &vt), + "TopK type validation should have been performed by optimizer and can_use_topk(). \ + Found unsupported types: key={kt:?}, value={vt:?}" + ); + } + + // Note: Null values in aggregate columns are filtered by the aggregation layer + // before reaching the heap, so the heap implementations don't need explicit null handling. 
let priority_map = PriorityMap::new(kt, vt, limit, desc)?; Ok(GroupedTopKAggregateStream { @@ -154,18 +184,21 @@ impl Stream for GroupedTopKAggregateStream { "Exactly 1 group value required" ); let group_by_values = Arc::clone(&group_by_values[0][0]); - let input_values = { - let _timer = (!self.aggregate_arguments.is_empty()).then(|| { - self.group_by_metrics.aggregate_arguments_time.timer() - }); - evaluate_many( + let input_values = if self.aggregate_arguments.is_empty() { + // DISTINCT case: use group key as both key and value + Arc::clone(&group_by_values) + } else { + // MIN/MAX case: evaluate aggregate expressions + let _timer = + self.group_by_metrics.aggregate_arguments_time.timer(); + let input_values = evaluate_many( &self.aggregate_arguments, batches.first().unwrap(), - )? + )?; + assert_eq!(input_values.len(), 1, "Exactly 1 input required"); + assert_eq!(input_values[0].len(), 1, "Exactly 1 input required"); + Arc::clone(&input_values[0][0]) }; - assert_eq!(input_values.len(), 1, "Exactly 1 input required"); - assert_eq!(input_values[0].len(), 1, "Exactly 1 input required"); - let input_values = Arc::clone(&input_values[0][0]); // iterate over each column of group_by values (*self).intern(&group_by_values, &input_values)?; @@ -178,9 +211,15 @@ impl Stream for GroupedTopKAggregateStream { } let batch = { let _timer = emitting_time.timer(); - let cols = self.priority_map.emit()?; + let mut cols = self.priority_map.emit()?; + // For DISTINCT case (no aggregate expressions), only use the group key column + // since the schema only has one field and key/value are the same + if self.aggregate_arguments.is_empty() { + cols.truncate(1); + } RecordBatch::try_new(Arc::clone(&self.schema), cols)? 
}; + let batch = batch.record_output(&self.baseline_metrics); trace!( "partition {} emit batch with {} rows", self.partition, diff --git a/datafusion/physical-plan/src/analyze.rs b/datafusion/physical-plan/src/analyze.rs index 1fb8f93a3878..eca31ea0e194 100644 --- a/datafusion/physical-plan/src/analyze.rs +++ b/datafusion/physical-plan/src/analyze.rs @@ -51,7 +51,7 @@ pub struct AnalyzeExec { pub(crate) input: Arc, /// The output schema for RecordBatches of this exec node schema: SchemaRef, - cache: PlanProperties, + cache: Arc, } impl AnalyzeExec { @@ -70,7 +70,7 @@ impl AnalyzeExec { metric_types, input, schema, - cache, + cache: Arc::new(cache), } } @@ -131,7 +131,7 @@ impl ExecutionPlan for AnalyzeExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } diff --git a/datafusion/physical-plan/src/async_func.rs b/datafusion/physical-plan/src/async_func.rs index 7393116b5ef3..72741f4314e7 100644 --- a/datafusion/physical-plan/src/async_func.rs +++ b/datafusion/physical-plan/src/async_func.rs @@ -20,6 +20,7 @@ use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::stream::RecordBatchStreamAdapter; use crate::{ DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, + check_if_same_properties, }; use arrow::array::RecordBatch; use arrow_schema::{Fields, Schema, SchemaRef}; @@ -30,6 +31,7 @@ use datafusion_physical_expr::ScalarFunctionExpr; use datafusion_physical_expr::async_scalar_function::AsyncFuncExpr; use datafusion_physical_expr::equivalence::ProjectionMapping; use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr_common::metrics::{BaselineMetrics, RecordOutput}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use futures::Stream; use futures::stream::StreamExt; @@ -44,12 +46,12 @@ use std::task::{Context, Poll, ready}; /// /// The schema of the output of the AsyncFuncExec is: /// Input columns followed by one column for each 
async expression -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct AsyncFuncExec { /// The async expressions to evaluate async_exprs: Vec>, input: Arc, - cache: PlanProperties, + cache: Arc, metrics: ExecutionPlanMetricsSet, } @@ -83,7 +85,7 @@ impl AsyncFuncExec { Ok(Self { input, async_exprs, - cache, + cache: Arc::new(cache), metrics: ExecutionPlanMetricsSet::new(), }) } @@ -112,6 +114,17 @@ impl AsyncFuncExec { pub fn input(&self) -> &Arc { &self.input } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + Self { + input: children.swap_remove(0), + metrics: ExecutionPlanMetricsSet::new(), + ..Self::clone(self) + } + } } impl DisplayAs for AsyncFuncExec { @@ -148,7 +161,7 @@ impl ExecutionPlan for AsyncFuncExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -158,16 +171,17 @@ impl ExecutionPlan for AsyncFuncExec { fn with_new_children( self: Arc, - children: Vec>, + mut children: Vec>, ) -> Result> { assert_eq_or_internal_err!( children.len(), 1, "AsyncFuncExec wrong number of children" ); + check_if_same_properties!(self, children); Ok(Arc::new(AsyncFuncExec::try_new( self.async_exprs.clone(), - Arc::clone(&children[0]), + children.swap_remove(0), )?)) } @@ -182,11 +196,14 @@ impl ExecutionPlan for AsyncFuncExec { context.session_id(), context.task_id() ); - // TODO figure out how to record metrics // first execute the input stream let input_stream = self.input.execute(partition, Arc::clone(&context))?; + // TODO: Track `elapsed_compute` in `BaselineMetrics` + // Issue: + let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); + // now, for each record batch, evaluate the async expressions and add the columns to the result let async_exprs_captured = Arc::new(self.async_exprs.clone()); let schema_captured = self.schema(); @@ -207,6 +224,7 @@ impl ExecutionPlan for AsyncFuncExec { let async_exprs_captured = Arc::clone(&async_exprs_captured); let 
schema_captured = Arc::clone(&schema_captured); let config_options = Arc::clone(&config_options_ref); + let baseline_metrics_captured = baseline_metrics.clone(); async move { let batch = batch?; @@ -219,7 +237,8 @@ impl ExecutionPlan for AsyncFuncExec { output_arrays.push(output.to_array(batch.num_rows())?); } let batch = RecordBatch::try_new(schema_captured, output_arrays)?; - Ok(batch) + + Ok(batch.record_output(&baseline_metrics_captured)) } }); @@ -386,7 +405,7 @@ mod tests { vec![Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 5, 6]))], )?; - let batches: Vec = (0..50).map(|_| batch.clone()).collect(); + let batches: Vec = std::iter::repeat_n(batch, 50).collect(); let session_config = SessionConfig::new().with_batch_size(200); let task_ctx = TaskContext::default().with_session_config(session_config); diff --git a/datafusion/physical-plan/src/buffer.rs b/datafusion/physical-plan/src/buffer.rs new file mode 100644 index 000000000000..a59d06292997 --- /dev/null +++ b/datafusion/physical-plan/src/buffer.rs @@ -0,0 +1,640 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`BufferExec`] decouples production and consumption on messages by buffering the input in the +//! background up to a certain capacity. 
+ +use crate::execution_plan::{CardinalityEffect, SchedulingType}; +use crate::filter_pushdown::{ + ChildPushdownResult, FilterDescription, FilterPushdownPhase, + FilterPushdownPropagation, +}; +use crate::projection::ProjectionExec; +use crate::stream::RecordBatchStreamAdapter; +use crate::{ + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SortOrderPushdownResult, + check_if_same_properties, +}; +use arrow::array::RecordBatch; +use datafusion_common::config::ConfigOptions; +use datafusion_common::{Result, Statistics, internal_err, plan_err}; +use datafusion_common_runtime::SpawnedTask; +use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_physical_expr_common::metrics::{ + ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, +}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; +use futures::{Stream, StreamExt, TryStreamExt}; +use pin_project_lite::pin_project; +use std::any::Any; +use std::fmt; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::task::{Context, Poll}; +use tokio::sync::mpsc::UnboundedReceiver; +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; + +/// WARNING: EXPERIMENTAL +/// +/// Decouples production and consumption of record batches with an internal queue per partition, +/// eagerly filling up the capacity of the queues even before any message is requested. 
+/// +/// ```text +/// ┌───────────────────────────┐ +/// │ BufferExec │ +/// │ │ +/// │┌────── Partition 0 ──────┐│ +/// ││ ┌────┐ ┌────┐││ ┌────┐ +/// ──background poll────────▶│ │ │ ├┼┼───────▶ │ +/// ││ └────┘ └────┘││ └────┘ +/// │└─────────────────────────┘│ +/// │┌────── Partition 1 ──────┐│ +/// ││ ┌────┐ ┌────┐ ┌────┐││ ┌────┐ +/// ──background poll─▶│ │ │ │ │ ├┼┼───────▶ │ +/// ││ └────┘ └────┘ └────┘││ └────┘ +/// │└─────────────────────────┘│ +/// │ │ +/// │ ... │ +/// │ │ +/// │┌────── Partition N ──────┐│ +/// ││ ┌────┐││ ┌────┐ +/// ──background poll───────────────▶│ ├┼┼───────▶ │ +/// ││ └────┘││ └────┘ +/// │└─────────────────────────┘│ +/// └───────────────────────────┘ +/// ``` +/// +/// The capacity is provided in bytes, and for each buffered record batch it will take into account +/// the size reported by [RecordBatch::get_array_memory_size]. +/// +/// If a single record batch exceeds the maximum capacity set in the `capacity` argument, it's still +/// allowed to pass in order to not deadlock the buffer. +/// +/// This is useful for operators that conditionally start polling one of their children only after +/// other child has finished, allowing to perform some early work and accumulating batches in +/// memory so that they can be served immediately when requested. +#[derive(Debug, Clone)] +pub struct BufferExec { + input: Arc, + properties: Arc, + capacity: usize, + metrics: ExecutionPlanMetricsSet, +} + +impl BufferExec { + /// Builds a new [BufferExec] with the provided capacity in bytes. + pub fn new(input: Arc, capacity: usize) -> Self { + let properties = PlanProperties::clone(input.properties()) + .with_scheduling_type(SchedulingType::Cooperative); + + Self { + input, + properties: Arc::new(properties), + capacity, + metrics: ExecutionPlanMetricsSet::new(), + } + } + + /// Returns the input [ExecutionPlan] of this [BufferExec]. 
+ pub fn input(&self) -> &Arc { + &self.input + } + + /// Returns the per-partition capacity in bytes for this [BufferExec]. + pub fn capacity(&self) -> usize { + self.capacity + } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + Self { + input: children.swap_remove(0), + metrics: ExecutionPlanMetricsSet::new(), + ..Self::clone(self) + } + } +} + +impl DisplayAs for BufferExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "BufferExec: capacity={}", self.capacity) + } + DisplayFormatType::TreeRender => { + writeln!(f, "target_batch_size={}", self.capacity) + } + } + } +} + +impl ExecutionPlan for BufferExec { + fn name(&self) -> &str { + "BufferExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &Arc { + &self.properties + } + + fn maintains_input_order(&self) -> Vec { + vec![true] + } + + fn benefits_from_input_partitioning(&self) -> Vec { + vec![false] + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + mut children: Vec>, + ) -> Result> { + check_if_same_properties!(self, children); + if children.len() != 1 { + return plan_err!("BufferExec can only have one child"); + } + Ok(Arc::new(Self::new(children.swap_remove(0), self.capacity))) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let mem_reservation = MemoryConsumer::new(format!("BufferExec[{partition}]")) + .register(context.memory_pool()); + let in_stream = self.input.execute(partition, context)?; + + // Set up the metrics for the stream. 
+ let curr_mem_in = Arc::new(AtomicUsize::new(0)); + let curr_mem_out = Arc::clone(&curr_mem_in); + let mut max_mem_in = 0; + let max_mem = MetricBuilder::new(&self.metrics).gauge("max_mem_used", partition); + + let curr_queued_in = Arc::new(AtomicUsize::new(0)); + let curr_queued_out = Arc::clone(&curr_queued_in); + let mut max_queued_in = 0; + let max_queued = MetricBuilder::new(&self.metrics).gauge("max_queued", partition); + + // Capture metrics when an element is queued on the stream. + let in_stream = in_stream.inspect_ok(move |v| { + let size = v.get_array_memory_size(); + let curr_size = curr_mem_in.fetch_add(size, Ordering::Relaxed) + size; + if curr_size > max_mem_in { + max_mem_in = curr_size; + max_mem.set(max_mem_in); + } + + let curr_queued = curr_queued_in.fetch_add(1, Ordering::Relaxed) + 1; + if curr_queued > max_queued_in { + max_queued_in = curr_queued; + max_queued.set(max_queued_in); + } + }); + // Buffer the input. + let out_stream = + MemoryBufferedStream::new(in_stream, self.capacity, mem_reservation); + // Update in the metrics that when an element gets out, some memory gets freed. + let out_stream = out_stream.inspect_ok(move |v| { + curr_mem_out.fetch_sub(v.get_array_memory_size(), Ordering::Relaxed); + curr_queued_out.fetch_sub(1, Ordering::Relaxed); + }); + + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + out_stream, + ))) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn partition_statistics(&self, partition: Option) -> Result { + self.input.partition_statistics(partition) + } + + fn supports_limit_pushdown(&self) -> bool { + self.input.supports_limit_pushdown() + } + + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::Equal + } + + fn try_swapping_with_projection( + &self, + projection: &ProjectionExec, + ) -> Result>> { + match self.input.try_swapping_with_projection(projection)? 
{ + Some(new_input) => Ok(Some( + Arc::new(self.clone()).with_new_children(vec![new_input])?, + )), + None => Ok(None), + } + } + + fn gather_filters_for_pushdown( + &self, + _phase: FilterPushdownPhase, + parent_filters: Vec>, + _config: &ConfigOptions, + ) -> Result { + FilterDescription::from_children(parent_filters, &self.children()) + } + + fn handle_child_pushdown_result( + &self, + _phase: FilterPushdownPhase, + child_pushdown_result: ChildPushdownResult, + _config: &ConfigOptions, + ) -> Result>> { + Ok(FilterPushdownPropagation::if_all(child_pushdown_result)) + } + + fn try_pushdown_sort( + &self, + order: &[PhysicalSortExpr], + ) -> Result>> { + // CoalesceBatchesExec is transparent for sort ordering - it preserves order + // Delegate to the child and wrap with a new CoalesceBatchesExec + self.input.try_pushdown_sort(order)?.try_map(|new_input| { + Ok(Arc::new(Self::new(new_input, self.capacity)) as Arc) + }) + } +} + +/// Represents anything that occupies a capacity in a [MemoryBufferedStream]. +pub trait SizedMessage { + fn size(&self) -> usize; +} + +impl SizedMessage for RecordBatch { + fn size(&self) -> usize { + self.get_array_memory_size() + } +} + +pin_project! { +/// Decouples production and consumption of messages in a stream with an internal queue, eagerly +/// filling it up to the specified maximum capacity even before any message is requested. +/// +/// Allows each message to have a different size, which is taken into account for determining if +/// the queue is full or not. +pub struct MemoryBufferedStream { + task: SpawnedTask<()>, + batch_rx: UnboundedReceiver>, + memory_reservation: Arc, +}} + +impl MemoryBufferedStream { + /// Builds a new [MemoryBufferedStream] with the provided capacity and event handler. + /// + /// This immediately spawns a Tokio task that will start consumption of the input stream. 
+ pub fn new( + mut input: impl Stream> + Unpin + Send + 'static, + capacity: usize, + memory_reservation: MemoryReservation, + ) -> Self { + let semaphore = Arc::new(Semaphore::new(capacity)); + let (batch_tx, batch_rx) = tokio::sync::mpsc::unbounded_channel(); + + let memory_reservation = Arc::new(memory_reservation); + let memory_reservation_clone = Arc::clone(&memory_reservation); + let task = SpawnedTask::spawn(async move { + loop { + // Select on both the input stream and the channel being closed. + // By down this, we abort polling the input as soon as the consumer channel is + // closed. Otherwise, we would need to wait for a full new message to be available + // in order to consider aborting the stream + let item_or_err = tokio::select! { + biased; + _ = batch_tx.closed() => break, + item_or_err = input.next() => { + let Some(item_or_err) = item_or_err else { + break; // stream finished + }; + item_or_err + } + }; + + let item = match item_or_err { + Ok(batch) => batch, + Err(err) => { + let _ = batch_tx.send(Err(err)); // If there's an error it means the channel was closed, which is fine. + break; + } + }; + + let size = item.size(); + if let Err(err) = memory_reservation.try_grow(size) { + let _ = batch_tx.send(Err(err)); // If there's an error it means the channel was closed, which is fine. + break; + } + + // We need to cap the minimum between amount of permits and the actual size of the + // message. If at any point we try to acquire more permits than the capacity of the + // semaphore, the stream will deadlock. + let capped_size = size.min(capacity) as u32; + + let semaphore = Arc::clone(&semaphore); + let Ok(permit) = semaphore.acquire_many_owned(capped_size).await else { + let _ = batch_tx.send(internal_err!("Closed semaphore in MemoryBufferedStream. 
This is a bug in DataFusion, please report it!")); + break; + }; + + if batch_tx.send(Ok((item, permit))).is_err() { + break; // stream was closed + }; + } + }); + + Self { + task, + batch_rx, + memory_reservation: memory_reservation_clone, + } + } + + /// Returns the number of queued messages. + pub fn messages_queued(&self) -> usize { + self.batch_rx.len() + } +} + +impl Stream for MemoryBufferedStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let self_project = self.project(); + match self_project.batch_rx.poll_recv(cx) { + Poll::Ready(Some(Ok((item, _semaphore_permit)))) => { + self_project.memory_reservation.shrink(item.size()); + Poll::Ready(Some(Ok(item))) + } + Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } + + fn size_hint(&self) -> (usize, Option) { + if self.batch_rx.is_closed() { + let len = self.batch_rx.len(); + (len, Some(len)) + } else { + (self.batch_rx.len(), None) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion_common::{DataFusionError, assert_contains}; + use datafusion_execution::memory_pool::{ + GreedyMemoryPool, MemoryPool, UnboundedMemoryPool, + }; + use std::error::Error; + use std::fmt::Debug; + use std::sync::Arc; + use std::time::Duration; + use tokio::time::timeout; + + #[tokio::test] + async fn buffers_only_some_messages() -> Result<(), Box> { + let input = futures::stream::iter([1, 2, 3, 4]).map(Ok); + let (_, res) = memory_pool_and_reservation(); + + let buffered = MemoryBufferedStream::new(input, 4, res); + wait_for_buffering().await; + assert_eq!(buffered.messages_queued(), 2); + Ok(()) + } + + #[tokio::test] + async fn yields_all_messages() -> Result<(), Box> { + let input = futures::stream::iter([1, 2, 3, 4]).map(Ok); + let (_, res) = memory_pool_and_reservation(); + + let mut buffered = MemoryBufferedStream::new(input, 10, res); + 
wait_for_buffering().await; + assert_eq!(buffered.messages_queued(), 4); + + pull_ok_msg(&mut buffered).await?; + pull_ok_msg(&mut buffered).await?; + pull_ok_msg(&mut buffered).await?; + pull_ok_msg(&mut buffered).await?; + finished(&mut buffered).await?; + Ok(()) + } + + #[tokio::test] + async fn yields_first_msg_even_if_big() -> Result<(), Box> { + let input = futures::stream::iter([25, 1, 2, 3]).map(Ok); + let (_, res) = memory_pool_and_reservation(); + + let mut buffered = MemoryBufferedStream::new(input, 10, res); + wait_for_buffering().await; + assert_eq!(buffered.messages_queued(), 1); + pull_ok_msg(&mut buffered).await?; + Ok(()) + } + + #[tokio::test] + async fn memory_pool_kills_stream() -> Result<(), Box> { + let input = futures::stream::iter([1, 2, 3, 4]).map(Ok); + let (_, res) = bounded_memory_pool_and_reservation(7); + + let mut buffered = MemoryBufferedStream::new(input, 10, res); + wait_for_buffering().await; + + pull_ok_msg(&mut buffered).await?; + pull_ok_msg(&mut buffered).await?; + pull_ok_msg(&mut buffered).await?; + let msg = pull_err_msg(&mut buffered).await?; + + assert_contains!(msg.to_string(), "Failed to allocate additional 4.0 B"); + Ok(()) + } + + #[tokio::test] + async fn memory_pool_does_not_kill_stream() -> Result<(), Box> { + let input = futures::stream::iter([1, 2, 3, 4]).map(Ok); + let (_, res) = bounded_memory_pool_and_reservation(7); + + let mut buffered = MemoryBufferedStream::new(input, 3, res); + wait_for_buffering().await; + pull_ok_msg(&mut buffered).await?; + + wait_for_buffering().await; + pull_ok_msg(&mut buffered).await?; + + wait_for_buffering().await; + pull_ok_msg(&mut buffered).await?; + + wait_for_buffering().await; + pull_ok_msg(&mut buffered).await?; + + wait_for_buffering().await; + finished(&mut buffered).await?; + Ok(()) + } + + #[tokio::test] + async fn messages_pass_even_if_all_exceed_limit() -> Result<(), Box> { + let input = futures::stream::iter([3, 3, 3, 3]).map(Ok); + let (_, res) = 
memory_pool_and_reservation(); + + let mut buffered = MemoryBufferedStream::new(input, 2, res); + wait_for_buffering().await; + assert_eq!(buffered.messages_queued(), 1); + pull_ok_msg(&mut buffered).await?; + + wait_for_buffering().await; + assert_eq!(buffered.messages_queued(), 1); + pull_ok_msg(&mut buffered).await?; + + wait_for_buffering().await; + assert_eq!(buffered.messages_queued(), 1); + pull_ok_msg(&mut buffered).await?; + + wait_for_buffering().await; + assert_eq!(buffered.messages_queued(), 1); + pull_ok_msg(&mut buffered).await?; + + wait_for_buffering().await; + finished(&mut buffered).await?; + Ok(()) + } + + #[tokio::test] + async fn errors_get_propagated() -> Result<(), Box> { + let input = futures::stream::iter([1, 2, 3, 4]).map(|v| { + if v == 3 { + return internal_err!("Error on 3"); + } + Ok(v) + }); + let (_, res) = memory_pool_and_reservation(); + + let mut buffered = MemoryBufferedStream::new(input, 10, res); + wait_for_buffering().await; + + pull_ok_msg(&mut buffered).await?; + pull_ok_msg(&mut buffered).await?; + pull_err_msg(&mut buffered).await?; + + Ok(()) + } + + #[tokio::test] + async fn memory_gets_released_if_stream_drops() -> Result<(), Box> { + let input = futures::stream::iter([1, 2, 3, 4]).map(Ok); + let (pool, res) = memory_pool_and_reservation(); + + let mut buffered = MemoryBufferedStream::new(input, 10, res); + wait_for_buffering().await; + assert_eq!(buffered.messages_queued(), 4); + assert_eq!(pool.reserved(), 10); + + pull_ok_msg(&mut buffered).await?; + assert_eq!(buffered.messages_queued(), 3); + assert_eq!(pool.reserved(), 9); + + pull_ok_msg(&mut buffered).await?; + assert_eq!(buffered.messages_queued(), 2); + assert_eq!(pool.reserved(), 7); + + drop(buffered); + assert_eq!(pool.reserved(), 0); + Ok(()) + } + + fn memory_pool_and_reservation() -> (Arc, MemoryReservation) { + let pool = Arc::new(UnboundedMemoryPool::default()) as _; + let reservation = MemoryConsumer::new("test").register(&pool); + (pool, reservation) 
+ } + + fn bounded_memory_pool_and_reservation( + size: usize, + ) -> (Arc, MemoryReservation) { + let pool = Arc::new(GreedyMemoryPool::new(size)) as _; + let reservation = MemoryConsumer::new("test").register(&pool); + (pool, reservation) + } + + async fn wait_for_buffering() { + // We do not have control over the spawned task, so the best we can do is to yield some + // cycles to the tokio runtime and let the task make progress on its own. + tokio::time::sleep(Duration::from_millis(1)).await; + } + + async fn pull_ok_msg( + buffered: &mut MemoryBufferedStream, + ) -> Result> { + Ok(timeout(Duration::from_millis(1), buffered.next()) + .await? + .unwrap_or_else(|| internal_err!("Stream should not have finished"))?) + } + + async fn pull_err_msg( + buffered: &mut MemoryBufferedStream, + ) -> Result> { + Ok(timeout(Duration::from_millis(1), buffered.next()) + .await? + .map(|v| match v { + Ok(v) => internal_err!( + "Stream should not have failed, but succeeded with {v:?}" + ), + Err(err) => Ok(err), + }) + .unwrap_or_else(|| internal_err!("Stream should not have finished"))?) + } + + async fn finished( + buffered: &mut MemoryBufferedStream, + ) -> Result<(), Box> { + match timeout(Duration::from_millis(1), buffered.next()) + .await? 
+ .is_none() + { + true => Ok(()), + false => internal_err!("Stream should have finished")?, + } + } + + impl SizedMessage for usize { + fn size(&self) -> usize { + *self + } + } +} diff --git a/datafusion/physical-plan/src/coalesce/mod.rs b/datafusion/physical-plan/src/coalesce/mod.rs index b3947170d9e4..ea1a87d09148 100644 --- a/datafusion/physical-plan/src/coalesce/mod.rs +++ b/datafusion/physical-plan/src/coalesce/mod.rs @@ -134,6 +134,10 @@ impl LimitedBatchCoalescer { Ok(()) } + pub(crate) fn is_finished(&self) -> bool { + self.finished + } + /// Return the next completed batch, if any pub fn next_completed_batch(&mut self) -> Option { self.inner.next_completed_batch() diff --git a/datafusion/physical-plan/src/coalesce_batches.rs b/datafusion/physical-plan/src/coalesce_batches.rs index 13bb862ab937..663b0b51ea59 100644 --- a/datafusion/physical-plan/src/coalesce_batches.rs +++ b/datafusion/physical-plan/src/coalesce_batches.rs @@ -27,6 +27,7 @@ use super::{DisplayAs, ExecutionPlanProperties, PlanProperties, Statistics}; use crate::projection::ProjectionExec; use crate::{ DisplayFormatType, ExecutionPlan, RecordBatchStream, SendableRecordBatchStream, + check_if_same_properties, }; use arrow::datatypes::SchemaRef; @@ -57,6 +58,10 @@ use futures::stream::{Stream, StreamExt}; /// reaches the `fetch` value. 
/// /// See [`LimitedBatchCoalescer`] for more information +#[deprecated( + since = "52.0.0", + note = "We now use BatchCoalescer from arrow-rs instead of a dedicated operator" +)] #[derive(Debug, Clone)] pub struct CoalesceBatchesExec { /// The input plan @@ -67,9 +72,10 @@ pub struct CoalesceBatchesExec { fetch: Option, /// Execution metrics metrics: ExecutionPlanMetricsSet, - cache: PlanProperties, + cache: Arc, } +#[expect(deprecated)] impl CoalesceBatchesExec { /// Create a new CoalesceBatchesExec pub fn new(input: Arc, target_batch_size: usize) -> Self { @@ -79,7 +85,7 @@ impl CoalesceBatchesExec { target_batch_size, fetch: None, metrics: ExecutionPlanMetricsSet::new(), - cache, + cache: Arc::new(cache), } } @@ -110,8 +116,20 @@ impl CoalesceBatchesExec { input.boundedness(), ) } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + Self { + input: children.swap_remove(0), + metrics: ExecutionPlanMetricsSet::new(), + ..Self::clone(self) + } + } } +#[expect(deprecated)] impl DisplayAs for CoalesceBatchesExec { fn fmt_as( &self, @@ -142,6 +160,7 @@ impl DisplayAs for CoalesceBatchesExec { } } +#[expect(deprecated)] impl ExecutionPlan for CoalesceBatchesExec { fn name(&self) -> &'static str { "CoalesceBatchesExec" @@ -152,7 +171,7 @@ impl ExecutionPlan for CoalesceBatchesExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -170,10 +189,11 @@ impl ExecutionPlan for CoalesceBatchesExec { fn with_new_children( self: Arc, - children: Vec>, + mut children: Vec>, ) -> Result> { + check_if_same_properties!(self, children); Ok(Arc::new( - CoalesceBatchesExec::new(Arc::clone(&children[0]), self.target_batch_size) + CoalesceBatchesExec::new(children.swap_remove(0), self.target_batch_size) .with_fetch(self.fetch), )) } @@ -199,10 +219,6 @@ impl ExecutionPlan for CoalesceBatchesExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - 
self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { self.input .partition_statistics(partition)? @@ -215,7 +231,7 @@ impl ExecutionPlan for CoalesceBatchesExec { target_batch_size: self.target_batch_size, fetch: limit, metrics: self.metrics.clone(), - cache: self.cache.clone(), + cache: Arc::clone(&self.cache), })) } diff --git a/datafusion/physical-plan/src/coalesce_partitions.rs b/datafusion/physical-plan/src/coalesce_partitions.rs index d83f90eb3d8c..39906d3680a4 100644 --- a/datafusion/physical-plan/src/coalesce_partitions.rs +++ b/datafusion/physical-plan/src/coalesce_partitions.rs @@ -31,7 +31,7 @@ use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType}; use crate::filter_pushdown::{FilterDescription, FilterPushdownPhase}; use crate::projection::{ProjectionExec, make_with_child}; use crate::sort_pushdown::SortOrderPushdownResult; -use crate::{DisplayFormatType, ExecutionPlan, Partitioning}; +use crate::{DisplayFormatType, ExecutionPlan, Partitioning, check_if_same_properties}; use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use datafusion_common::config::ConfigOptions; @@ -47,7 +47,7 @@ pub struct CoalescePartitionsExec { input: Arc, /// Execution metrics metrics: ExecutionPlanMetricsSet, - cache: PlanProperties, + cache: Arc, /// Optional number of rows to fetch. 
Stops producing rows after this fetch pub(crate) fetch: Option, } @@ -59,7 +59,7 @@ impl CoalescePartitionsExec { CoalescePartitionsExec { input, metrics: ExecutionPlanMetricsSet::new(), - cache, + cache: Arc::new(cache), fetch: None, } } @@ -100,6 +100,17 @@ impl CoalescePartitionsExec { .with_evaluation_type(drive) .with_scheduling_type(scheduling) } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + Self { + input: children.swap_remove(0), + metrics: ExecutionPlanMetricsSet::new(), + ..Self::clone(self) + } + } } impl DisplayAs for CoalescePartitionsExec { @@ -135,7 +146,7 @@ impl ExecutionPlan for CoalescePartitionsExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -149,9 +160,10 @@ impl ExecutionPlan for CoalescePartitionsExec { fn with_new_children( self: Arc, - children: Vec>, + mut children: Vec>, ) -> Result> { - let mut plan = CoalescePartitionsExec::new(Arc::clone(&children[0])); + check_if_same_properties!(self, children); + let mut plan = CoalescePartitionsExec::new(children.swap_remove(0)); plan.fetch = self.fetch; Ok(Arc::new(plan)) } @@ -224,10 +236,6 @@ impl ExecutionPlan for CoalescePartitionsExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, _partition: Option) -> Result { self.input .partition_statistics(None)? 
@@ -274,10 +282,23 @@ impl ExecutionPlan for CoalescePartitionsExec { input: Arc::clone(&self.input), fetch: limit, metrics: self.metrics.clone(), - cache: self.cache.clone(), + cache: Arc::clone(&self.cache), })) } + fn with_preserve_order( + &self, + preserve_order: bool, + ) -> Option> { + self.input + .with_preserve_order(preserve_order) + .and_then(|new_input| { + Arc::new(self.clone()) + .with_new_children(vec![new_input]) + .ok() + }) + } + fn gather_filters_for_pushdown( &self, _phase: FilterPushdownPhase, diff --git a/datafusion/physical-plan/src/column_rewriter.rs b/datafusion/physical-plan/src/column_rewriter.rs new file mode 100644 index 000000000000..7cd865630455 --- /dev/null +++ b/datafusion/physical-plan/src/column_rewriter.rs @@ -0,0 +1,383 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use datafusion_common::{ + DataFusionError, HashMap, + tree_node::{Transformed, TreeNodeRecursion, TreeNodeRewriter}, +}; +use datafusion_physical_expr::{PhysicalExpr, expressions::Column}; + +/// Rewrite column references in a physical expr according to a mapping. 
+/// +/// This rewriter traverses the expression tree and replaces [`Column`] nodes +/// with the corresponding expression found in the `column_map`. +/// +/// If a column is found in the map, it is replaced by the mapped expression. +/// If a column is NOT found in the map, a `DataFusionError::Internal` is +/// returned. +pub struct PhysicalColumnRewriter<'a> { + /// Mapping from original column to new column. + pub column_map: &'a HashMap>, +} + +impl<'a> PhysicalColumnRewriter<'a> { + /// Create a new PhysicalColumnRewriter with the given column mapping. + pub fn new(column_map: &'a HashMap>) -> Self { + Self { column_map } + } +} + +impl<'a> TreeNodeRewriter for PhysicalColumnRewriter<'a> { + type Node = Arc; + + fn f_down( + &mut self, + node: Self::Node, + ) -> datafusion_common::Result> { + if let Some(column) = node.as_any().downcast_ref::() { + if let Some(new_column) = self.column_map.get(column) { + // jump to prevent rewriting the new sub-expression again + return Ok(Transformed::new( + Arc::clone(new_column), + true, + TreeNodeRecursion::Jump, + )); + } else { + // Column not found in mapping + return Err(DataFusionError::Internal(format!( + "Column {column:?} not found in column mapping {:?}", + self.column_map + ))); + } + } + Ok(Transformed::no(node)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::{DataFusionError, Result, tree_node::TreeNode}; + use datafusion_physical_expr::{ + PhysicalExpr, + expressions::{Column, binary, col, lit}, + }; + use std::sync::Arc; + + /// Helper function to create a test schema + fn create_test_schema() -> Arc { + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), + Field::new("new_col", DataType::Int32, true), + Field::new("inner_col", DataType::Int32, 
true), + Field::new("another_col", DataType::Int32, true), + ])) + } + + /// Helper function to create a complex nested expression with multiple columns + /// Create: (col_a + col_b) * (col_c - col_d) + col_e + fn create_complex_expression(schema: &Schema) -> Arc { + let col_a = col("a", schema).unwrap(); + let col_b = col("b", schema).unwrap(); + let col_c = col("c", schema).unwrap(); + let col_d = col("d", schema).unwrap(); + let col_e = col("e", schema).unwrap(); + + let add_expr = + binary(col_a, datafusion_expr::Operator::Plus, col_b, schema).unwrap(); + let sub_expr = + binary(col_c, datafusion_expr::Operator::Minus, col_d, schema).unwrap(); + let mul_expr = binary( + add_expr, + datafusion_expr::Operator::Multiply, + sub_expr, + schema, + ) + .unwrap(); + binary(mul_expr, datafusion_expr::Operator::Plus, col_e, schema).unwrap() + } + + /// Helper function to create a deeply nested expression + /// Create: col_a + (col_b + (col_c + (col_d + col_e))) + fn create_deeply_nested_expression(schema: &Schema) -> Arc { + let col_a = col("a", schema).unwrap(); + let col_b = col("b", schema).unwrap(); + let col_c = col("c", schema).unwrap(); + let col_d = col("d", schema).unwrap(); + let col_e = col("e", schema).unwrap(); + + let inner1 = + binary(col_d, datafusion_expr::Operator::Plus, col_e, schema).unwrap(); + let inner2 = + binary(col_c, datafusion_expr::Operator::Plus, inner1, schema).unwrap(); + let inner3 = + binary(col_b, datafusion_expr::Operator::Plus, inner2, schema).unwrap(); + binary(col_a, datafusion_expr::Operator::Plus, inner3, schema).unwrap() + } + + #[test] + fn test_simple_column_replacement_with_jump() -> Result<()> { + let schema = create_test_schema(); + + // Test that Jump prevents re-processing of replaced columns + let mut column_map = HashMap::new(); + column_map.insert(Column::new_with_schema("a", &schema).unwrap(), lit(42i32)); + column_map.insert( + Column::new_with_schema("b", &schema).unwrap(), + lit("replaced_b"), + ); + 
column_map.insert( + Column::new_with_schema("c", &schema).unwrap(), + col("c", &schema).unwrap(), + ); + column_map.insert( + Column::new_with_schema("d", &schema).unwrap(), + col("d", &schema).unwrap(), + ); + column_map.insert( + Column::new_with_schema("e", &schema).unwrap(), + col("e", &schema).unwrap(), + ); + + let mut rewriter = PhysicalColumnRewriter::new(&column_map); + let expr = create_complex_expression(&schema); + + let result = expr.rewrite(&mut rewriter)?; + + // Verify the transformation occurred + assert!(result.transformed); + + assert_eq!( + format!("{}", result.data), + "(42 + replaced_b) * (c@2 - d@3) + e@4" + ); + + Ok(()) + } + + #[test] + fn test_nested_column_replacement_with_jump() -> Result<()> { + let schema = create_test_schema(); + // Test Jump behavior with deeply nested expressions + let mut column_map = HashMap::new(); + // Replace col_c with a complex expression containing new columns + let replacement_expr = binary( + lit(100i32), + datafusion_expr::Operator::Plus, + col("new_col", &schema).unwrap(), + &schema, + ) + .unwrap(); + column_map.insert( + Column::new_with_schema("c", &schema).unwrap(), + replacement_expr, + ); + column_map.insert( + Column::new_with_schema("a", &schema).unwrap(), + col("a", &schema).unwrap(), + ); + column_map.insert( + Column::new_with_schema("b", &schema).unwrap(), + col("b", &schema).unwrap(), + ); + column_map.insert( + Column::new_with_schema("d", &schema).unwrap(), + col("d", &schema).unwrap(), + ); + column_map.insert( + Column::new_with_schema("e", &schema).unwrap(), + col("e", &schema).unwrap(), + ); + + let mut rewriter = PhysicalColumnRewriter::new(&column_map); + let expr = create_deeply_nested_expression(&schema); + + let result = expr.rewrite(&mut rewriter)?; + + // Verify transformation occurred + assert!(result.transformed); + + assert_eq!( + format!("{}", result.data), + "a@0 + b@1 + 100 + new_col@5 + d@3 + e@4" + ); + + Ok(()) + } + + #[test] + fn test_circular_reference_prevention() 
-> Result<()> { + let schema = create_test_schema(); + // Test that Jump prevents infinite recursion with circular references + let mut column_map = HashMap::new(); + + // Create a circular reference: col_a -> col_b -> col_a (but Jump should prevent the second visit) + column_map.insert( + Column::new_with_schema("a", &schema).unwrap(), + col("b", &schema).unwrap(), + ); + column_map.insert( + Column::new_with_schema("b", &schema).unwrap(), + col("a", &schema).unwrap(), + ); + + let mut rewriter = PhysicalColumnRewriter::new(&column_map); + + // Start with an expression containing col_a + let expr = binary( + col("a", &schema).unwrap(), + datafusion_expr::Operator::Plus, + col("b", &schema).unwrap(), + &schema, + ) + .unwrap(); + + let result = expr.rewrite(&mut rewriter)?; + + // Verify transformation occurred + assert!(result.transformed); + + assert_eq!(format!("{}", result.data), "b@1 + a@0"); + + Ok(()) + } + + #[test] + fn test_multiple_replacements_in_same_expression() -> Result<()> { + let schema = create_test_schema(); + // Test multiple column replacements in the same complex expression + let mut column_map = HashMap::new(); + + // Replace multiple columns with literals + column_map.insert(Column::new_with_schema("a", &schema).unwrap(), lit(10i32)); + column_map.insert(Column::new_with_schema("c", &schema).unwrap(), lit(20i32)); + column_map.insert(Column::new_with_schema("e", &schema).unwrap(), lit(30i32)); + column_map.insert( + Column::new_with_schema("b", &schema).unwrap(), + col("b", &schema).unwrap(), + ); + column_map.insert( + Column::new_with_schema("d", &schema).unwrap(), + col("d", &schema).unwrap(), + ); + + let mut rewriter = PhysicalColumnRewriter::new(&column_map); + let expr = create_complex_expression(&schema); // (col_a + col_b) * (col_c - col_d) + col_e + + let result = expr.rewrite(&mut rewriter)?; + + // Verify transformation occurred + assert!(result.transformed); + assert_eq!(format!("{}", result.data), "(10 + b@1) * (20 - d@3) + 
30"); + + Ok(()) + } + + #[test] + fn test_jump_with_complex_replacement_expression() -> Result<()> { + let schema = create_test_schema(); + // Test Jump behavior when replacing with very complex expressions + let mut column_map = HashMap::new(); + + // Replace col_a with a complex nested expression + let inner_expr = binary( + lit(5i32), + datafusion_expr::Operator::Multiply, + col("a", &schema).unwrap(), + &schema, + ) + .unwrap(); + let middle_expr = binary( + inner_expr, + datafusion_expr::Operator::Plus, + lit(3i32), + &schema, + ) + .unwrap(); + let complex_replacement = binary( + middle_expr, + datafusion_expr::Operator::Minus, + col("another_col", &schema).unwrap(), + &schema, + ) + .unwrap(); + + column_map.insert( + Column::new_with_schema("a", &schema).unwrap(), + complex_replacement, + ); + column_map.insert( + Column::new_with_schema("b", &schema).unwrap(), + col("b", &schema).unwrap(), + ); + + let mut rewriter = PhysicalColumnRewriter::new(&column_map); + + // Create expression: col_a + col_b + let expr = binary( + col("a", &schema).unwrap(), + datafusion_expr::Operator::Plus, + col("b", &schema).unwrap(), + &schema, + ) + .unwrap(); + + let result = expr.rewrite(&mut rewriter)?; + + assert_eq!( + format!("{}", result.data), + "5 * a@0 + 3 - another_col@7 + b@1" + ); + + // Verify transformation occurred + assert!(result.transformed); + + Ok(()) + } + + #[test] + fn test_unmapped_columns_detection() -> Result<()> { + let schema = create_test_schema(); + let mut column_map = HashMap::new(); + + // Only map col_a, leave col_b unmapped + column_map.insert(Column::new_with_schema("a", &schema).unwrap(), lit(42i32)); + + let mut rewriter = PhysicalColumnRewriter::new(&column_map); + + // Create expression: col_a + col_b + let expr = binary( + col("a", &schema).unwrap(), + datafusion_expr::Operator::Plus, + col("b", &schema).unwrap(), + &schema, + ) + .unwrap(); + + let err = expr.rewrite(&mut rewriter).unwrap_err(); + assert!(matches!(err, 
DataFusionError::Internal(_))); + + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index 32dc60b56ad4..590f6f09e8b9 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -181,7 +181,7 @@ pub fn compute_record_batch_statistics( /// Checks if the given projection is valid for the given schema. pub fn can_project( schema: &arrow::datatypes::SchemaRef, - projection: Option<&Vec>, + projection: Option<&[usize]>, ) -> Result<()> { match projection { Some(columns) => { diff --git a/datafusion/physical-plan/src/coop.rs b/datafusion/physical-plan/src/coop.rs index a1fad8677740..5f0040b3ddce 100644 --- a/datafusion/physical-plan/src/coop.rs +++ b/datafusion/physical-plan/src/coop.rs @@ -22,10 +22,15 @@ //! A single call to `poll_next` on a top-level [`Stream`] may potentially perform a lot of work //! before it returns a `Poll::Pending`. Think for instance of calculating an aggregation over a //! large dataset. +//! //! If a `Stream` runs for a long period of time without yielding back to the Tokio executor, //! it can starve other tasks waiting on that executor to execute them. //! Additionally, this prevents the query execution from being cancelled. //! +//! For more background, please also see the [Using Rust async for Query Execution and Cancelling Long-Running Queries blog] +//! +//! [Using Rust async for Query Execution and Cancelling Long-Running Queries blog]: https://datafusion.apache.org/blog/2025/06/30/cancellation +//! //! To ensure that `Stream` implementations yield regularly, operators can insert explicit yield //! points using the utilities in this module. For most operators this is **not** necessary. The //! 
`Stream`s of the built-in DataFusion operators that generate (rather than manipulate) @@ -82,7 +87,7 @@ use crate::filter_pushdown::{ use crate::projection::ProjectionExec; use crate::{ DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, RecordBatchStream, - SendableRecordBatchStream, SortOrderPushdownResult, + SendableRecordBatchStream, SortOrderPushdownResult, check_if_same_properties, }; use arrow::record_batch::RecordBatch; use arrow_schema::Schema; @@ -212,16 +217,15 @@ where #[derive(Debug, Clone)] pub struct CooperativeExec { input: Arc, - properties: PlanProperties, + properties: Arc, } impl CooperativeExec { /// Creates a new `CooperativeExec` operator that wraps the given input execution plan. pub fn new(input: Arc) -> Self { - let properties = input - .properties() - .clone() - .with_scheduling_type(SchedulingType::Cooperative); + let properties = PlanProperties::clone(input.properties()) + .with_scheduling_type(SchedulingType::Cooperative) + .into(); Self { input, properties } } @@ -230,6 +234,16 @@ impl CooperativeExec { pub fn input(&self) -> &Arc { &self.input } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + Self { + input: children.swap_remove(0), + ..Self::clone(self) + } + } } impl DisplayAs for CooperativeExec { @@ -255,7 +269,7 @@ impl ExecutionPlan for CooperativeExec { self.input.schema() } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.properties } @@ -276,6 +290,7 @@ impl ExecutionPlan for CooperativeExec { 1, "CooperativeExec requires exactly one child" ); + check_if_same_properties!(self, children); Ok(Arc::new(CooperativeExec::new(children.swap_remove(0)))) } diff --git a/datafusion/physical-plan/src/display.rs b/datafusion/physical-plan/src/display.rs index 52c37a106b39..44148f2d0e88 100644 --- a/datafusion/physical-plan/src/display.rs +++ b/datafusion/physical-plan/src/display.rs @@ -1153,7 +1153,7 @@ mod tests { self } - fn properties(&self) -> 
&PlanProperties { + fn properties(&self) -> &Arc { unimplemented!() } @@ -1176,10 +1176,6 @@ mod tests { todo!() } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { if partition.is_some() { return Ok(Statistics::new_unknown(self.schema().as_ref())); diff --git a/datafusion/physical-plan/src/empty.rs b/datafusion/physical-plan/src/empty.rs index fcfbcfa3e827..e4d4da4e88fc 100644 --- a/datafusion/physical-plan/src/empty.rs +++ b/datafusion/physical-plan/src/empty.rs @@ -21,7 +21,7 @@ use std::any::Any; use std::sync::Arc; use crate::memory::MemoryStream; -use crate::{DisplayAs, PlanProperties, SendableRecordBatchStream, Statistics, common}; +use crate::{DisplayAs, PlanProperties, SendableRecordBatchStream, Statistics}; use crate::{ DisplayFormatType, ExecutionPlan, Partitioning, execution_plan::{Boundedness, EmissionType}, @@ -29,7 +29,8 @@ use crate::{ use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; -use datafusion_common::{Result, assert_or_internal_err}; +use datafusion_common::stats::Precision; +use datafusion_common::{ColumnStatistics, Result, ScalarValue, assert_or_internal_err}; use datafusion_execution::TaskContext; use datafusion_physical_expr::EquivalenceProperties; @@ -43,7 +44,7 @@ pub struct EmptyExec { schema: SchemaRef, /// Number of partitions partitions: usize, - cache: PlanProperties, + cache: Arc, } impl EmptyExec { @@ -53,7 +54,7 @@ impl EmptyExec { EmptyExec { schema, partitions: 1, - cache, + cache: Arc::new(cache), } } @@ -62,7 +63,7 @@ impl EmptyExec { self.partitions = partitions; // Changing partitions may invalidate output partitioning, so update it: let output_partitioning = Self::output_partitioning_helper(self.partitions); - self.cache = self.cache.with_partitioning(output_partitioning); + Arc::make_mut(&mut self.cache).partitioning = output_partitioning; self } @@ -114,7 +115,7 @@ impl ExecutionPlan for EmptyExec { self } - fn 
properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -155,10 +156,6 @@ impl ExecutionPlan for EmptyExec { )?)) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { if let Some(partition) = partition { assert_or_internal_err!( @@ -169,20 +166,31 @@ impl ExecutionPlan for EmptyExec { ); } - let batch = self - .data() - .expect("Create empty RecordBatch should not fail"); - Ok(common::compute_record_batch_statistics( - &[batch], - &self.schema, - None, - )) + // Build explicit stats: exact zero rows and bytes, with explicit known column stats + let mut stats = Statistics::default() + .with_num_rows(Precision::Exact(0)) + .with_total_byte_size(Precision::Exact(0)); + + // Add explicit column stats for each field in schema + for _ in self.schema.fields() { + stats = stats.add_column_statistics(ColumnStatistics { + null_count: Precision::Exact(0), + distinct_count: Precision::Exact(0), + min_value: Precision::::Absent, + max_value: Precision::::Absent, + sum_value: Precision::::Absent, + byte_size: Precision::Exact(0), + }); + } + + Ok(stats) } } #[cfg(test)] mod tests { use super::*; + use crate::common; use crate::test; use crate::with_new_children_if_necessary; diff --git a/datafusion/physical-plan/src/execution_plan.rs b/datafusion/physical-plan/src/execution_plan.rs index 06da0b8933c1..d1d7b62b5389 100644 --- a/datafusion/physical-plan/src/execution_plan.rs +++ b/datafusion/physical-plan/src/execution_plan.rs @@ -25,7 +25,9 @@ pub use crate::ordering::InputOrderMode; use crate::sort_pushdown::SortOrderPushdownResult; pub use crate::stream::EmptyRecordBatchStream; +use arrow_schema::Schema; pub use datafusion_common::hash_utils; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; pub use datafusion_common::utils::project_schema; pub use datafusion_common::{ColumnStatistics, Statistics, internal_err}; pub use 
datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; @@ -37,7 +39,7 @@ pub use datafusion_physical_expr::{ use std::any::Any; use std::fmt::Debug; -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::display::DisplayableExecutionPlan; @@ -127,7 +129,7 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { /// /// This information is available via methods on [`ExecutionPlanProperties`] /// trait, which is implemented for all `ExecutionPlan`s. - fn properties(&self) -> &PlanProperties; + fn properties(&self) -> &Arc; /// Returns an error if this individual node does not conform to its invariants. /// These invariants are typically only checked in debug mode. @@ -471,17 +473,6 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { None } - /// Returns statistics for this `ExecutionPlan` node. If statistics are not - /// available, should return [`Statistics::new_unknown`] (the default), not - /// an error. - /// - /// For TableScan executors, which supports filter pushdown, special attention - /// needs to be paid to whether the stats returned by this method are exact or not - #[deprecated(since = "48.0.0", note = "Use `partition_statistics` method instead")] - fn statistics(&self) -> Result { - Ok(Statistics::new_unknown(&self.schema())) - } - /// Returns statistics for a specific partition of this `ExecutionPlan` node. /// If statistics are not available, should return [`Statistics::new_unknown`] /// (the default), not an error. @@ -576,6 +567,7 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { } /// Handle the result of a child pushdown. + /// /// This method is called as we recurse back up the plan tree after pushing /// filters down to child nodes via [`ExecutionPlan::gather_filters_for_pushdown`]. 
/// It allows the current node to process the results of filter pushdown from @@ -708,6 +700,19 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { ) -> Result>> { Ok(SortOrderPushdownResult::Unsupported) } + + /// Returns a variant of this `ExecutionPlan` that is aware of order-sensitivity. + /// + /// This is used to signal to data sources that the output ordering must be + /// preserved, even if it might be more efficient to ignore it (e.g. by + /// skipping some row groups in Parquet). + /// + fn with_preserve_order( + &self, + _preserve_order: bool, + ) -> Option> { + None + } } /// [`ExecutionPlan`] Invariant Level @@ -1046,12 +1051,17 @@ impl PlanProperties { self } - /// Overwrite equivalence properties with its new value. - pub fn with_eq_properties(mut self, eq_properties: EquivalenceProperties) -> Self { + /// Set equivalence properties having mut reference. + pub fn set_eq_properties(&mut self, eq_properties: EquivalenceProperties) { // Changing equivalence properties also changes output ordering, so // make sure to overwrite it: self.output_ordering = eq_properties.output_ordering(); self.eq_properties = eq_properties; + } + + /// Overwrite equivalence properties with its new value. + pub fn with_eq_properties(mut self, eq_properties: EquivalenceProperties) -> Self { + self.set_eq_properties(eq_properties); self } @@ -1083,9 +1093,14 @@ impl PlanProperties { self } + /// Set constraints having mut reference. + pub fn set_constraints(&mut self, constraints: Constraints) { + self.eq_properties.set_constraints(constraints); + } + /// Overwrite constraints with its new value. pub fn with_constraints(mut self, constraints: Constraints) -> Self { - self.eq_properties = self.eq_properties.with_constraints(constraints); + self.set_constraints(constraints); self } @@ -1384,6 +1399,68 @@ pub fn check_not_null_constraints( Ok(batch) } +/// Make plan ready to be re-executed returning its clone with state reset for all nodes. 
+/// +/// Some plans will change their internal states after execution, making them unable to be executed again. +/// This function uses [`ExecutionPlan::reset_state`] to reset any internal state within the plan. +/// +/// An example is `CrossJoinExec`, which loads the left table into memory and stores it in the plan. +/// However, if the data of the left table is derived from the work table, it will become outdated +/// as the work table changes. When the next iteration executes this plan again, we must clear the left table. +/// +/// # Limitations +/// +/// While this function enables plan reuse, it does not allow the same plan to be executed if it (OR): +/// +/// * uses dynamic filters, +/// * represents a recursive query. +/// +pub fn reset_plan_states(plan: Arc) -> Result> { + plan.transform_up(|plan| { + let new_plan = Arc::clone(&plan).reset_state()?; + Ok(Transformed::yes(new_plan)) + }) + .data() +} + +/// Check if the `plan` children has the same properties as passed `children`. +/// In this case plan can avoid self properties re-computation when its children +/// replace is requested. +/// The size of `children` must be equal to the size of `ExecutionPlan::children()`. +pub fn has_same_children_properties( + plan: &impl ExecutionPlan, + children: &[Arc], +) -> Result { + let old_children = plan.children(); + assert_eq_or_internal_err!( + children.len(), + old_children.len(), + "Wrong number of children" + ); + for (lhs, rhs) in old_children.iter().zip(children.iter()) { + if !Arc::ptr_eq(lhs.properties(), rhs.properties()) { + return Ok(false); + } + } + Ok(true) +} + +/// Helper macro to avoid properties re-computation if passed children properties +/// the same as plan already has. Could be used to implement fast-path for method +/// [`ExecutionPlan::with_new_children`]. +#[macro_export] +macro_rules! 
check_if_same_properties { + ($plan: expr, $children: expr) => { + if $crate::execution_plan::has_same_children_properties( + $plan.as_ref(), + &$children, + )? { + let plan = $plan.with_new_children_and_same_properties($children); + return Ok(::std::sync::Arc::new(plan)); + } + }; +} + /// Utility function yielding a string representation of the given [`ExecutionPlan`]. pub fn get_plan_string(plan: &Arc) -> Vec { let formatted = displayable(plan.as_ref()).indent(true).to_string(); @@ -1405,6 +1482,20 @@ pub enum CardinalityEffect { GreaterEqual, } +/// Can be used in contexts where properties have not yet been initialized properly. +pub(crate) fn stub_properties() -> Arc { + static STUB_PROPERTIES: LazyLock> = LazyLock::new(|| { + Arc::new(PlanProperties::new( + EquivalenceProperties::new(Arc::new(Schema::empty())), + Partitioning::UnknownPartitioning(1), + EmissionType::Final, + Boundedness::Bounded, + )) + }); + + Arc::clone(&STUB_PROPERTIES) +} + #[cfg(test)] mod tests { use std::any::Any; @@ -1446,7 +1537,7 @@ mod tests { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { unimplemented!() } @@ -1469,10 +1560,6 @@ mod tests { unimplemented!() } - fn statistics(&self) -> Result { - unimplemented!() - } - fn partition_statistics(&self, _partition: Option) -> Result { unimplemented!() } @@ -1513,7 +1600,7 @@ mod tests { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { unimplemented!() } @@ -1536,10 +1623,6 @@ mod tests { unimplemented!() } - fn statistics(&self) -> Result { - unimplemented!() - } - fn partition_statistics(&self, _partition: Option) -> Result { unimplemented!() } diff --git a/datafusion/physical-plan/src/explain.rs b/datafusion/physical-plan/src/explain.rs index aa3c0afefe8b..bf21b0484689 100644 --- a/datafusion/physical-plan/src/explain.rs +++ b/datafusion/physical-plan/src/explain.rs @@ -44,7 +44,7 @@ pub struct ExplainExec { stringified_plans: Vec, /// control which plans to 
print verbose: bool, - cache: PlanProperties, + cache: Arc, } impl ExplainExec { @@ -59,7 +59,7 @@ impl ExplainExec { schema, stringified_plans, verbose, - cache, + cache: Arc::new(cache), } } @@ -112,7 +112,7 @@ impl ExecutionPlan for ExplainExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index 674fe6692adf..6bd779e3d502 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -20,19 +20,20 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll, ready}; +use datafusion_physical_expr::projection::{ProjectionRef, combine_projections}; use itertools::Itertools; use super::{ ColumnStatistics, DisplayAs, ExecutionPlanProperties, PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, }; -use crate::coalesce::LimitedBatchCoalescer; -use crate::coalesce::PushBatchStatus::LimitReached; +use crate::check_if_same_properties; +use crate::coalesce::{LimitedBatchCoalescer, PushBatchStatus}; use crate::common::can_project; use crate::execution_plan::CardinalityEffect; use crate::filter_pushdown::{ ChildFilterDescription, ChildPushdownResult, FilterDescription, FilterPushdownPhase, - FilterPushdownPropagation, PushedDown, PushedDownPredicate, + FilterPushdownPropagation, PushedDown, }; use crate::metrics::{MetricBuilder, MetricType}; use crate::projection::{ @@ -56,12 +57,12 @@ use datafusion_common::{ use datafusion_execution::TaskContext; use datafusion_expr::Operator; use datafusion_physical_expr::equivalence::ProjectionMapping; -use datafusion_physical_expr::expressions::{BinaryExpr, Column, lit}; +use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal, lit}; use datafusion_physical_expr::intervals::utils::check_support; -use datafusion_physical_expr::utils::collect_columns; +use datafusion_physical_expr::utils::{collect_columns, 
reassign_expr_columns}; use datafusion_physical_expr::{ - AcrossPartitions, AnalysisContext, ConstExpr, ExprBoundaries, PhysicalExpr, analyze, - conjunction, split_conjunction, + AcrossPartitions, AnalysisContext, ConstExpr, EquivalenceProperties, ExprBoundaries, + PhysicalExpr, analyze, conjunction, split_conjunction, }; use datafusion_physical_expr_common::physical_expr::fmt_sql; @@ -84,48 +85,168 @@ pub struct FilterExec { /// Selectivity for statistics. 0 = no rows, 100 = all rows default_selectivity: u8, /// Properties equivalence properties, partitioning, etc. - cache: PlanProperties, + cache: Arc, /// The projection indices of the columns in the output schema of join - projection: Option>, + projection: Option, /// Target batch size for output batches batch_size: usize, /// Number of rows to fetch fetch: Option, } +/// Builder for [`FilterExec`] to set optional parameters +pub struct FilterExecBuilder { + predicate: Arc, + input: Arc, + projection: Option, + default_selectivity: u8, + batch_size: usize, + fetch: Option, +} + +impl FilterExecBuilder { + /// Create a new builder with required parameters (predicate and input) + pub fn new(predicate: Arc, input: Arc) -> Self { + Self { + predicate, + input, + projection: None, + default_selectivity: FILTER_EXEC_DEFAULT_SELECTIVITY, + batch_size: FILTER_EXEC_DEFAULT_BATCH_SIZE, + fetch: None, + } + } + + /// Set the input execution plan + pub fn with_input(mut self, input: Arc) -> Self { + self.input = input; + self + } + + /// Set the predicate expression + pub fn with_predicate(mut self, predicate: Arc) -> Self { + self.predicate = predicate; + self + } + + /// Set the projection, composing with any existing projection. + /// + /// If a projection is already set, the new projection indices are mapped + /// through the existing projection. 
For example, if the current projection + /// is `[0, 2, 3]` and `apply_projection(Some(vec![0, 2]))` is called, the + /// resulting projection will be `[0, 3]` (indices 0 and 2 of `[0, 2, 3]`). + /// + /// If no projection is currently set, the new projection is used directly. + /// If `None` is passed, the projection is cleared. + pub fn apply_projection(self, projection: Option>) -> Result { + let projection = projection.map(Into::into); + self.apply_projection_by_ref(projection.as_ref()) + } + + /// The same as [`Self::apply_projection`] but takes projection shared reference. + pub fn apply_projection_by_ref( + mut self, + projection: Option<&ProjectionRef>, + ) -> Result { + // Check if the projection is valid against current output schema + can_project(&self.input.schema(), projection.map(AsRef::as_ref))?; + self.projection = combine_projections(projection, self.projection.as_ref())?; + Ok(self) + } + + /// Set the default selectivity + pub fn with_default_selectivity(mut self, default_selectivity: u8) -> Self { + self.default_selectivity = default_selectivity; + self + } + + /// Set the batch size + pub fn with_batch_size(mut self, batch_size: usize) -> Self { + self.batch_size = batch_size; + self + } + + /// Set the fetch limit + pub fn with_fetch(mut self, fetch: Option) -> Self { + self.fetch = fetch; + self + } + + /// Build the FilterExec, computing properties once with all configured parameters + pub fn build(self) -> Result { + // Validate predicate type + match self.predicate.data_type(self.input.schema().as_ref())? 
{ + DataType::Boolean => {} + other => { + return plan_err!( + "Filter predicate must return BOOLEAN values, got {other:?}" + ); + } + } + + // Validate selectivity + if self.default_selectivity > 100 { + return plan_err!( + "Default filter selectivity value needs to be less than or equal to 100" + ); + } + + // Validate projection if provided + can_project(&self.input.schema(), self.projection.as_deref())?; + + // Compute properties once with all parameters + let cache = FilterExec::compute_properties( + &self.input, + &self.predicate, + self.default_selectivity, + self.projection.as_deref(), + )?; + + Ok(FilterExec { + predicate: self.predicate, + input: self.input, + metrics: ExecutionPlanMetricsSet::new(), + default_selectivity: self.default_selectivity, + cache: Arc::new(cache), + projection: self.projection, + batch_size: self.batch_size, + fetch: self.fetch, + }) + } +} + +impl From<&FilterExec> for FilterExecBuilder { + fn from(exec: &FilterExec) -> Self { + Self { + predicate: Arc::clone(&exec.predicate), + input: Arc::clone(&exec.input), + projection: exec.projection.clone(), + default_selectivity: exec.default_selectivity, + batch_size: exec.batch_size, + fetch: exec.fetch, + // We could cache / copy over PlanProperties + // here but that would require invalidating them in FilterExecBuilder::apply_projection, etc. + // and currently every call to this method ends up invalidating them anyway. + // If useful this can be added in the future as a non-breaking change. + } + } +} + impl FilterExec { - /// Create a FilterExec on an input - #[expect(clippy::needless_pass_by_value)] + /// Create a FilterExec on an input using the builder pattern pub fn try_new( predicate: Arc, input: Arc, ) -> Result { - match predicate.data_type(input.schema().as_ref())? 
{ - DataType::Boolean => { - let default_selectivity = FILTER_EXEC_DEFAULT_SELECTIVITY; - let cache = Self::compute_properties( - &input, - &predicate, - default_selectivity, - None, - )?; - Ok(Self { - predicate, - input: Arc::clone(&input), - metrics: ExecutionPlanMetricsSet::new(), - default_selectivity, - cache, - projection: None, - batch_size: FILTER_EXEC_DEFAULT_BATCH_SIZE, - fetch: None, - }) - } - other => { - plan_err!("Filter predicate must return BOOLEAN values, got {other:?}") - } - } + FilterExecBuilder::new(predicate, input).build() + } + + /// Get a batch size + pub fn batch_size(&self) -> usize { + self.batch_size } + /// Set the default selectivity pub fn with_default_selectivity( mut self, default_selectivity: u8, @@ -140,43 +261,26 @@ impl FilterExec { } /// Return new instance of [FilterExec] with the given projection. + /// + /// # Deprecated + /// Use [`FilterExecBuilder::apply_projection`] instead + #[deprecated( + since = "52.0.0", + note = "Use FilterExecBuilder::apply_projection instead" + )] pub fn with_projection(&self, projection: Option>) -> Result { - // Check if the projection is valid - can_project(&self.schema(), projection.as_ref())?; - - let projection = match projection { - Some(projection) => match &self.projection { - Some(p) => Some(projection.iter().map(|i| p[*i]).collect()), - None => Some(projection), - }, - None => None, - }; - - let cache = Self::compute_properties( - &self.input, - &self.predicate, - self.default_selectivity, - projection.as_ref(), - )?; - Ok(Self { - predicate: Arc::clone(&self.predicate), - input: Arc::clone(&self.input), - metrics: self.metrics.clone(), - default_selectivity: self.default_selectivity, - cache, - projection, - batch_size: self.batch_size, - fetch: self.fetch, - }) + let builder = FilterExecBuilder::from(self); + builder.apply_projection(projection)?.build() } + /// Set the batch size pub fn with_batch_size(&self, batch_size: usize) -> Result { Ok(Self { predicate: 
Arc::clone(&self.predicate), input: Arc::clone(&self.input), metrics: self.metrics.clone(), default_selectivity: self.default_selectivity, - cache: self.cache.clone(), + cache: Arc::clone(&self.cache), projection: self.projection.clone(), batch_size, fetch: self.fetch, @@ -199,8 +303,8 @@ impl FilterExec { } /// Projection - pub fn projection(&self) -> Option<&Vec> { - self.projection.as_ref() + pub fn projection(&self) -> &Option { + &self.projection } /// Calculates `Statistics` for `FilterExec`, by applying selectivity (either default, or estimated) to input statistics. @@ -233,6 +337,7 @@ impl FilterExec { let total_byte_size = total_byte_size.with_estimated_selectivity(selectivity); let column_statistics = collect_new_statistics( + schema, &input_stats.column_statistics, analysis_ctx.boundaries, ); @@ -243,6 +348,20 @@ impl FilterExec { }) } + /// Returns the `AcrossPartitions` value for `expr` if it is constant: + /// either already known constant in `input_eqs`, or a `Literal` + /// (which is inherently constant across all partitions). + fn expr_constant_or_literal( + expr: &Arc, + input_eqs: &EquivalenceProperties, + ) -> Option { + input_eqs.is_expr_constant(expr).or_else(|| { + expr.as_any() + .downcast_ref::() + .map(|l| AcrossPartitions::Uniform(Some(l.value().clone()))) + }) + } + fn extend_constants( input: &Arc, predicate: &Arc, @@ -255,18 +374,24 @@ impl FilterExec { if let Some(binary) = conjunction.as_any().downcast_ref::() && binary.op() == &Operator::Eq { - // Filter evaluates to single value for all partitions - if input_eqs.is_expr_constant(binary.left()).is_some() { - let across = input_eqs - .is_expr_constant(binary.right()) - .unwrap_or_default(); + // Check if either side is constant — either already known + // constant from the input equivalence properties, or a literal + // value (which is inherently constant across all partitions). 
+ let left_const = Self::expr_constant_or_literal(binary.left(), input_eqs); + let right_const = + Self::expr_constant_or_literal(binary.right(), input_eqs); + + if let Some(left_across) = left_const { + // LEFT is constant, so RIGHT must also be constant. + // Use RIGHT's known across value if available, otherwise + // propagate LEFT's (e.g. Uniform from a literal). + let across = right_const.unwrap_or(left_across); res_constants .push(ConstExpr::new(Arc::clone(binary.right()), across)); - } else if input_eqs.is_expr_constant(binary.right()).is_some() { - let across = input_eqs - .is_expr_constant(binary.left()) - .unwrap_or_default(); - res_constants.push(ConstExpr::new(Arc::clone(binary.left()), across)); + } else if let Some(right_across) = right_const { + // RIGHT is constant, so LEFT must also be constant. + res_constants + .push(ConstExpr::new(Arc::clone(binary.left()), right_across)); } } } @@ -277,7 +402,7 @@ impl FilterExec { input: &Arc, predicate: &Arc, default_selectivity: u8, - projection: Option<&Vec>, + projection: Option<&[usize]>, ) -> Result { // Combine the equal predicates with the input equivalence properties // to construct the equivalence properties: @@ -316,7 +441,7 @@ impl FilterExec { if let Some(projection) = projection { let schema = eq_properties.schema(); let projection_mapping = ProjectionMapping::from_indices(projection, schema)?; - let out_schema = project_schema(schema, Some(projection))?; + let out_schema = project_schema(schema, Some(&projection))?; output_partitioning = output_partitioning.project(&projection_mapping, &eq_properties); eq_properties = eq_properties.project(&projection_mapping, out_schema); @@ -329,6 +454,17 @@ impl FilterExec { input.boundedness(), )) } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + Self { + input: children.swap_remove(0), + metrics: ExecutionPlanMetricsSet::new(), + ..Self::clone(self) + } + } } impl DisplayAs for FilterExec { @@ -383,7 +519,7 @@ 
impl ExecutionPlan for FilterExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -400,13 +536,12 @@ impl ExecutionPlan for FilterExec { self: Arc, mut children: Vec>, ) -> Result> { - FilterExec::try_new(Arc::clone(&self.predicate), children.swap_remove(0)) - .and_then(|e| { - let selectivity = e.default_selectivity(); - e.with_default_selectivity(selectivity) - }) - .and_then(|e| e.with_projection(self.projection().cloned())) - .map(|e| e.with_fetch(self.fetch).unwrap()) + check_if_same_properties!(self, children); + let new_input = children.swap_remove(0); + FilterExecBuilder::from(&*self) + .with_input(new_input) + .build() + .map(|e| Arc::new(e) as _) } fn execute( @@ -441,15 +576,10 @@ impl ExecutionPlan for FilterExec { /// The output statistics of a filtering operation can be estimated if the /// predicate's selectivity value can be determined for the incoming data. - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { let input_stats = self.input.partition_statistics(partition)?; - let schema = self.schema(); let stats = Self::statistics_helper( - &schema, + &self.input.schema(), input_stats, self.predicate(), self.default_selectivity, @@ -473,15 +603,15 @@ impl ExecutionPlan for FilterExec { if let Some(new_predicate) = update_expr(self.predicate(), projection.expr(), false)? { - return FilterExec::try_new( - new_predicate, - make_with_child(projection, self.input())?, - ) - .and_then(|e| { - let selectivity = self.default_selectivity(); - e.with_default_selectivity(selectivity) - }) - .map(|e| Some(Arc::new(e) as _)); + return FilterExecBuilder::from(self) + .with_input(make_with_child(projection, self.input())?) + .with_predicate(new_predicate) + // The original FilterExec projection referenced columns from its old + // input. 
After the swap the new input is the ProjectionExec which + // already handles column selection, so clear the projection here. + .apply_projection(None)? + .build() + .map(|e| Some(Arc::new(e) as _)); } } try_embed_projection(projection, self) @@ -493,17 +623,10 @@ impl ExecutionPlan for FilterExec { parent_filters: Vec>, _config: &ConfigOptions, ) -> Result { - if !matches!(phase, FilterPushdownPhase::Pre) { - // For non-pre phase, filters pass through unchanged - let filter_supports = parent_filters - .into_iter() - .map(PushedDownPredicate::supported) - .collect(); - - return Ok(FilterDescription::new().with_child(ChildFilterDescription { - parent_filters: filter_supports, - self_filters: vec![], - })); + if phase != FilterPushdownPhase::Pre { + let child = + ChildFilterDescription::from_child(&parent_filters, self.input())?; + return Ok(FilterDescription::new().with_child(child)); } let child = ChildFilterDescription::from_child(&parent_filters, self.input())? @@ -523,14 +646,30 @@ impl ExecutionPlan for FilterExec { child_pushdown_result: ChildPushdownResult, _config: &ConfigOptions, ) -> Result>> { - if !matches!(phase, FilterPushdownPhase::Pre) { + if phase != FilterPushdownPhase::Pre { return Ok(FilterPushdownPropagation::if_all(child_pushdown_result)); } // We absorb any parent filters that were not handled by our children - let unsupported_parent_filters = - child_pushdown_result.parent_filters.iter().filter_map(|f| { - matches!(f.all(), PushedDown::No).then_some(Arc::clone(&f.filter)) - }); + let mut unsupported_parent_filters: Vec> = + child_pushdown_result + .parent_filters + .iter() + .filter_map(|f| { + matches!(f.all(), PushedDown::No).then_some(Arc::clone(&f.filter)) + }) + .collect(); + + // If this FilterExec has a projection, the unsupported parent filters + // are in the output schema (after projection) coordinates. We need to + // remap them to the input schema coordinates before combining with self filters. 
+ if self.projection.is_some() { + let input_schema = self.input().schema(); + unsupported_parent_filters = unsupported_parent_filters + .into_iter() + .map(|expr| reassign_expr_columns(expr, &input_schema)) + .collect::>>()?; + } + let unsupported_self_filters = child_pushdown_result .self_filters .first() @@ -552,7 +691,7 @@ impl ExecutionPlan for FilterExec { let new_predicate = conjunction(unhandled_filters); let updated_node = if new_predicate.eq(&lit(true)) { // FilterExec is no longer needed, but we may need to leave a projection in place - match self.projection() { + match self.projection().as_ref() { Some(projection_indices) => { let filter_child_schema = filter_input.schema(); let proj_exprs = projection_indices @@ -578,19 +717,19 @@ impl ExecutionPlan for FilterExec { // The new predicate is the same as our current predicate None } else { - // Create a new FilterExec with the new predicate + // Create a new FilterExec with the new predicate, preserving the projection let new = FilterExec { predicate: Arc::clone(&new_predicate), input: Arc::clone(&filter_input), metrics: self.metrics.clone(), default_selectivity: self.default_selectivity, - cache: Self::compute_properties( + cache: Arc::new(Self::compute_properties( &filter_input, &new_predicate, self.default_selectivity, - self.projection.as_ref(), - )?, - projection: None, + self.projection.as_deref(), + )?), + projection: self.projection.clone(), batch_size: self.batch_size, fetch: self.fetch, }; @@ -603,23 +742,57 @@ impl ExecutionPlan for FilterExec { }) } + fn fetch(&self) -> Option { + self.fetch + } + fn with_fetch(&self, fetch: Option) -> Option> { Some(Arc::new(Self { predicate: Arc::clone(&self.predicate), input: Arc::clone(&self.input), metrics: self.metrics.clone(), default_selectivity: self.default_selectivity, - cache: self.cache.clone(), + cache: Arc::clone(&self.cache), projection: self.projection.clone(), batch_size: self.batch_size, fetch, })) } + + fn with_preserve_order( + &self, + 
preserve_order: bool, + ) -> Option> { + self.input + .with_preserve_order(preserve_order) + .and_then(|new_input| { + Arc::new(self.clone()) + .with_new_children(vec![new_input]) + .ok() + }) + } } impl EmbeddedProjection for FilterExec { fn with_projection(&self, projection: Option>) -> Result { - self.with_projection(projection) + FilterExecBuilder::from(self) + .apply_projection(projection)? + .build() + } +} + +/// Converts an interval bound to a [`Precision`] value. NULL bounds (which +/// represent "unbounded" in the interval type) map to [`Precision::Absent`]. +fn interval_bound_to_precision( + bound: ScalarValue, + is_exact: bool, +) -> Precision { + if bound.is_null() { + Precision::Absent + } else if is_exact { + Precision::Exact(bound) + } else { + Precision::Inexact(bound) } } @@ -628,6 +801,7 @@ impl EmbeddedProjection for FilterExec { /// is adjusted by using the next/previous value for its data type to convert /// it into a closed bound. fn collect_new_statistics( + schema: &SchemaRef, input_column_stats: &[ColumnStatistics], analysis_boundaries: Vec, ) -> Vec { @@ -644,22 +818,25 @@ fn collect_new_statistics( }, )| { let Some(interval) = interval else { - // If the interval is `None`, we can say that there are no rows: + // If the interval is `None`, we can say that there are no rows. + // Use a typed null to preserve the column's data type, so that + // downstream interval analysis can still intersect intervals + // of the same type. 
+ let typed_null = ScalarValue::try_from(schema.field(idx).data_type()) + .unwrap_or(ScalarValue::Null); return ColumnStatistics { null_count: Precision::Exact(0), - max_value: Precision::Exact(ScalarValue::Null), - min_value: Precision::Exact(ScalarValue::Null), - sum_value: Precision::Exact(ScalarValue::Null), + max_value: Precision::Exact(typed_null.clone()), + min_value: Precision::Exact(typed_null.clone()), + sum_value: Precision::Exact(typed_null), distinct_count: Precision::Exact(0), byte_size: input_column_stats[idx].byte_size, }; }; let (lower, upper) = interval.into_bounds(); - let (min_value, max_value) = if lower.eq(&upper) { - (Precision::Exact(lower), Precision::Exact(upper)) - } else { - (Precision::Inexact(lower), Precision::Inexact(upper)) - }; + let is_exact = !lower.is_null() && !upper.is_null() && lower == upper; + let min_value = interval_bound_to_precision(lower, is_exact); + let max_value = interval_bound_to_precision(upper, is_exact); ColumnStatistics { null_count: input_column_stats[idx].null_count.to_inexact(), max_value, @@ -685,7 +862,7 @@ struct FilterExecStream { /// Runtime metrics recording metrics: FilterExecMetrics, /// The projection indices of the columns in the input schema - projection: Option>, + projection: Option, /// Batch coalescer to combine small batches batch_coalescer: LimitedBatchCoalescer, } @@ -711,23 +888,6 @@ impl FilterExecMetrics { } } -impl FilterExecStream { - fn flush_remaining_batches( - &mut self, - ) -> Poll>> { - // Flush any remaining buffered batch - match self.batch_coalescer.finish() { - Ok(()) => { - Poll::Ready(self.batch_coalescer.next_completed_batch().map(|batch| { - self.metrics.selectivity.add_part(batch.num_rows()); - Ok(batch) - })) - } - Err(e) => Poll::Ready(Some(Err(e))), - } - } -} - pub fn batch_filter( batch: &RecordBatch, predicate: &Arc, @@ -767,18 +927,34 @@ impl Stream for FilterExecStream { mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { - let poll; let elapsed_compute 
= self.metrics.baseline_metrics.elapsed_compute().clone(); loop { + // If there is a completed batch ready, return it + if let Some(batch) = self.batch_coalescer.next_completed_batch() { + self.metrics.selectivity.add_part(batch.num_rows()); + let poll = Poll::Ready(Some(Ok(batch))); + return self.metrics.baseline_metrics.record_poll(poll); + } + + if self.batch_coalescer.is_finished() { + // If input is done and no batches are ready, return None to signal end of stream. + return Poll::Ready(None); + } + + // Attempt to pull the next batch from the input stream. match ready!(self.input.poll_next_unpin(cx)) { + None => { + self.batch_coalescer.finish()?; + // continue draining the coalescer + } Some(Ok(batch)) => { let timer = elapsed_compute.timer(); let status = self.predicate.as_ref() .evaluate(&batch) .and_then(|v| v.into_array(batch.num_rows())) .and_then(|array| { - Ok(match self.projection { - Some(ref projection) => { + Ok(match self.projection.as_ref() { + Some(projection) => { let projected_batch = batch.project(projection)?; (array, projected_batch) }, @@ -802,37 +978,22 @@ impl Stream for FilterExecStream { })?; timer.done(); - if let LimitReached = status { - poll = self.flush_remaining_batches(); - break; - } - - if let Some(batch) = self.batch_coalescer.next_completed_batch() { - self.metrics.selectivity.add_part(batch.num_rows()); - poll = Poll::Ready(Some(Ok(batch))); - break; - } - continue; - } - None => { - // Flush any remaining buffered batch - match self.batch_coalescer.finish() { - Ok(()) => { - poll = self.flush_remaining_batches(); + match status { + PushBatchStatus::Continue => { + // Keep pushing more batches } - Err(e) => { - poll = Poll::Ready(Some(Err(e))); + PushBatchStatus::LimitReached => { + // limit was reached, so stop early + self.batch_coalescer.finish()?; + // continue draining the coalescer } } - break; - } - value => { - poll = Poll::Ready(value); - break; } + + // Error case + other => return Poll::Ready(other), } } - 
self.metrics.baseline_metrics.record_poll(poll) } fn size_hint(&self) -> (usize, Option) { @@ -866,6 +1027,19 @@ fn collect_columns_from_predicate_inner( let predicates = split_conjunction(predicate); predicates.into_iter().for_each(|p| { if let Some(binary) = p.as_any().downcast_ref::() { + // Only extract pairs where at least one side is a Column reference. + // Pairs like `complex_expr = literal` should not create equivalence + // classes — the literal could appear in many unrelated expressions + // (e.g. sort keys), and normalize_expr's deep traversal would + // replace those occurrences with the complex expression, corrupting + // sort orderings. Constant propagation for such pairs is handled + // separately by `extend_constants`. + let has_direct_column_operand = + binary.left().as_any().downcast_ref::().is_some() + || binary.right().as_any().downcast_ref::().is_some(); + if !has_direct_column_operand { + return; + } match binary.op() { Operator::Eq => { eq_predicate_columns.push((binary.left(), binary.right())) @@ -1358,17 +1532,17 @@ mod tests { statistics.column_statistics, vec![ ColumnStatistics { - min_value: Precision::Exact(ScalarValue::Null), - max_value: Precision::Exact(ScalarValue::Null), - sum_value: Precision::Exact(ScalarValue::Null), + min_value: Precision::Exact(ScalarValue::Int32(None)), + max_value: Precision::Exact(ScalarValue::Int32(None)), + sum_value: Precision::Exact(ScalarValue::Int32(None)), distinct_count: Precision::Exact(0), null_count: Precision::Exact(0), byte_size: Precision::Absent, }, ColumnStatistics { - min_value: Precision::Exact(ScalarValue::Null), - max_value: Precision::Exact(ScalarValue::Null), - sum_value: Precision::Exact(ScalarValue::Null), + min_value: Precision::Exact(ScalarValue::Int32(None)), + max_value: Precision::Exact(ScalarValue::Int32(None)), + sum_value: Precision::Exact(ScalarValue::Int32(None)), distinct_count: Precision::Exact(0), null_count: Precision::Exact(0), byte_size: Precision::Absent, @@ -1379,6 
+1553,70 @@ mod tests { Ok(()) } + /// Regression test: stacking two FilterExecs where the inner filter + /// proves zero selectivity should not panic with a type mismatch + /// during interval intersection. + /// + /// Previously, when a filter proved no rows could match, the column + /// statistics used untyped `ScalarValue::Null` (data type `Null`). + /// If an outer FilterExec then tried to analyze its own predicate + /// against those statistics, `Interval::intersect` would fail with: + /// "Only intervals with the same data type are intersectable, lhs:Null, rhs:Int32" + #[tokio::test] + async fn test_nested_filter_with_zero_selectivity_inner() -> Result<()> { + // Inner table: a: [1, 100], b: [1, 3] + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + let input = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Inexact(1000), + total_byte_size: Precision::Inexact(4000), + column_statistics: vec![ + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), + ..Default::default() + }, + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(3))), + ..Default::default() + }, + ], + }, + schema, + )); + + // Inner filter: a > 200 (impossible given a max=100 → zero selectivity) + let inner_predicate: Arc = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(200)))), + )); + let inner_filter: Arc = + Arc::new(FilterExec::try_new(inner_predicate, input)?); + + // Outer filter: a = 50 + // Before the fix, this would panic because the inner filter's + // zero-selectivity statistics produced Null-typed intervals for + // column `a`, which couldn't intersect with the Int32 literal. 
+ let outer_predicate: Arc = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(50)))), + )); + let outer_filter: Arc = + Arc::new(FilterExec::try_new(outer_predicate, inner_filter)?); + + // Should succeed without error + let statistics = outer_filter.partition_statistics(None)?; + assert_eq!(statistics.num_rows, Precision::Inexact(0)); + + Ok(()) + } + #[tokio::test] async fn test_filter_statistics_more_inputs() -> Result<()> { let schema = Schema::new(vec![ @@ -1557,13 +1795,14 @@ mod tests { #[test] fn test_equivalence_properties_union_type() -> Result<()> { let union_type = DataType::Union( - UnionFields::new( + UnionFields::try_new( vec![0, 1], vec![ Field::new("f1", DataType::Int32, true), Field::new("f2", DataType::Utf8, true), ], - ), + ) + .unwrap(), UnionMode::Sparse, ); @@ -1586,4 +1825,512 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_builder_with_projection() -> Result<()> { + // Create a schema with multiple columns + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ])); + + let input = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + // Create a filter predicate: a > 10 + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )); + + // Create filter with projection [0, 2] (columns a and c) using builder + let projection = Some(vec![0, 2]); + let filter = FilterExecBuilder::new(predicate, input) + .apply_projection(projection.clone()) + .unwrap() + .build()?; + + // Verify projection is set correctly + assert_eq!(filter.projection(), &Some([0, 2].into())); + + // Verify schema contains only projected columns + let output_schema = filter.schema(); + assert_eq!(output_schema.fields().len(), 2); + assert_eq!(output_schema.field(0).name(), "a"); + 
assert_eq!(output_schema.field(1).name(), "c"); + + Ok(()) + } + + #[tokio::test] + async fn test_builder_without_projection() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + + let input = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )); + + // Create filter without projection using builder + let filter = FilterExecBuilder::new(predicate, input).build()?; + + // Verify no projection is set + assert!(filter.projection().is_none()); + + // Verify schema contains all columns + let output_schema = filter.schema(); + assert_eq!(output_schema.fields().len(), 2); + + Ok(()) + } + + #[tokio::test] + async fn test_builder_invalid_projection() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + + let input = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )); + + // Try to create filter with invalid projection (index out of bounds) using builder + let result = + FilterExecBuilder::new(predicate, input).apply_projection(Some(vec![0, 5])); // 5 is out of bounds + + // Should return an error + assert!(result.is_err()); + + Ok(()) + } + + #[tokio::test] + async fn test_builder_vs_with_projection() -> Result<()> { + // This test verifies that the builder with projection produces the same result + // as try_new().with_projection(), but more efficiently (one compute_properties call) + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + Field::new("d", DataType::Int32, 
false), + ]); + + let input = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Inexact(1000), + total_byte_size: Precision::Inexact(4000), + column_statistics: vec![ + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), + ..Default::default() + }, + ColumnStatistics { + ..Default::default() + }, + ColumnStatistics { + ..Default::default() + }, + ColumnStatistics { + ..Default::default() + }, + ], + }, + schema, + )); + let input: Arc = input; + + let predicate: Arc = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Lt, + Arc::new(Literal::new(ScalarValue::Int32(Some(50)))), + )); + + let projection = Some(vec![0, 2]); + + // Method 1: Builder with projection (one call to compute_properties) + let filter1 = FilterExecBuilder::new(Arc::clone(&predicate), Arc::clone(&input)) + .apply_projection(projection.clone()) + .unwrap() + .build()?; + + // Method 2: Also using builder for comparison (deprecated try_new().with_projection() removed) + let filter2 = FilterExecBuilder::new(predicate, input) + .apply_projection(projection) + .unwrap() + .build()?; + + // Both methods should produce equivalent results + assert_eq!(filter1.schema(), filter2.schema()); + assert_eq!(filter1.projection(), filter2.projection()); + + // Verify statistics are the same + let stats1 = filter1.partition_statistics(None)?; + let stats2 = filter2.partition_statistics(None)?; + assert_eq!(stats1.num_rows, stats2.num_rows); + assert_eq!(stats1.total_byte_size, stats2.total_byte_size); + + Ok(()) + } + + #[tokio::test] + async fn test_builder_statistics_with_projection() -> Result<()> { + // Test that statistics are correctly computed when using builder with projection + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ]); + + let input = 
Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Inexact(1000), + total_byte_size: Precision::Inexact(12000), + column_statistics: vec![ + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), + ..Default::default() + }, + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(10))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(200))), + ..Default::default() + }, + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(5))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(50))), + ..Default::default() + }, + ], + }, + schema, + )); + + // Filter: a < 50, Project: [0, 2] + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Lt, + Arc::new(Literal::new(ScalarValue::Int32(Some(50)))), + )); + + let filter = FilterExecBuilder::new(predicate, input) + .apply_projection(Some(vec![0, 2])) + .unwrap() + .build()?; + + let statistics = filter.partition_statistics(None)?; + + // Verify statistics reflect both filtering and projection + assert!(matches!(statistics.num_rows, Precision::Inexact(_))); + + // Schema should only have 2 columns after projection + assert_eq!(filter.schema().fields().len(), 2); + + Ok(()) + } + + #[test] + fn test_builder_predicate_validation() -> Result<()> { + // Test that builder validates predicate type correctly + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + + let input = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + // Create a predicate that doesn't return boolean (returns Int32) + let invalid_predicate = Arc::new(Column::new("a", 0)); + + // Should fail because predicate doesn't return boolean + let result = FilterExecBuilder::new(invalid_predicate, input) + .apply_projection(Some(vec![0])) + .unwrap() + .build(); + + assert!(result.is_err()); + + 
Ok(()) + } + + #[tokio::test] + async fn test_builder_projection_composition() -> Result<()> { + // Test that calling apply_projection multiple times composes projections + // If initial projection is [0, 2, 3] and we call apply_projection([0, 2]), + // the result should be [0, 3] (indices 0 and 2 of [0, 2, 3]) + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + Field::new("d", DataType::Int32, false), + ])); + + let input = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + // Create a filter predicate: a > 10 + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )); + + // First projection: [0, 2, 3] -> select columns a, c, d + // Second projection: [0, 2] -> select indices 0 and 2 of [0, 2, 3] -> [0, 3] + // Final result: columns a and d + let filter = FilterExecBuilder::new(predicate, input) + .apply_projection(Some(vec![0, 2, 3]))? + .apply_projection(Some(vec![0, 2]))? 
+ .build()?; + + // Verify composed projection is [0, 3] + assert_eq!(filter.projection(), &Some([0, 3].into())); + + // Verify schema contains only columns a and d + let output_schema = filter.schema(); + assert_eq!(output_schema.fields().len(), 2); + assert_eq!(output_schema.field(0).name(), "a"); + assert_eq!(output_schema.field(1).name(), "d"); + + Ok(()) + } + + #[tokio::test] + async fn test_builder_projection_composition_none_clears() -> Result<()> { + // Test that passing None clears the projection + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + + let input = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )); + + // Set a projection then clear it with None + let filter = FilterExecBuilder::new(predicate, input) + .apply_projection(Some(vec![0]))? + .apply_projection(None)? + .build()?; + + // Projection should be cleared + assert_eq!(filter.projection(), &None); + + // Schema should have all columns + let output_schema = filter.schema(); + assert_eq!(output_schema.fields().len(), 2); + + Ok(()) + } + + #[test] + fn test_filter_with_projection_remaps_post_phase_parent_filters() -> Result<()> { + // Test that FilterExec with a projection must remap parent dynamic + // filter column indices from its output schema to the input schema + // before passing them to the child. 
+ let input_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Float64, false), + ])); + let input = Arc::new(EmptyExec::new(Arc::clone(&input_schema))); + + // FilterExec: a > 0, projection=[c@2] + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(0)))), + )); + let filter = FilterExecBuilder::new(predicate, input) + .apply_projection(Some(vec![2]))? + .build()?; + + // Output schema should be [c:Float64] + let output_schema = filter.schema(); + assert_eq!(output_schema.fields().len(), 1); + assert_eq!(output_schema.field(0).name(), "c"); + + // Simulate a parent dynamic filter referencing output column c@0 + let parent_filter: Arc = Arc::new(Column::new("c", 0)); + + let config = ConfigOptions::new(); + let desc = filter.gather_filters_for_pushdown( + FilterPushdownPhase::Post, + vec![parent_filter], + &config, + )?; + + // The filter pushed to the child must reference c@2 (input schema), + // not c@0 (output schema). + let parent_filters = desc.parent_filters(); + assert_eq!(parent_filters.len(), 1); // one child + assert_eq!(parent_filters[0].len(), 1); // one filter + let remapped = &parent_filters[0][0].predicate; + let display = format!("{remapped}"); + assert_eq!( + display, "c@2", + "Post-phase parent filter column index must be remapped \ + from output schema (c@0) to input schema (c@2)" + ); + + Ok(()) + } + + /// Regression test for https://github.com/apache/datafusion/issues/20194 + /// + /// `collect_columns_from_predicate_inner` should only extract equality + /// pairs where at least one side is a Column. Pairs like + /// `complex_expr = literal` must not create equivalence classes because + /// `normalize_expr`'s deep traversal would replace the literal inside + /// unrelated expressions (e.g. sort keys) with the complex expression. 
+ #[test] + fn test_collect_columns_skips_non_column_pairs() -> Result<()> { + let schema = test::aggr_test_schema(); + + // Simulate: nvl(c2, 0) = 0 → (c2 IS DISTINCT FROM 0) = 0 + // Neither side is a Column, so this should NOT be extracted. + let complex_expr: Arc = binary( + col("c2", &schema)?, + Operator::IsDistinctFrom, + lit(0u32), + &schema, + )?; + let predicate: Arc = + binary(complex_expr, Operator::Eq, lit(0u32), &schema)?; + + let (equal_pairs, _) = collect_columns_from_predicate_inner(&predicate); + assert_eq!( + 0, + equal_pairs.len(), + "Should not extract equality pairs where neither side is a Column" + ); + + // But col = literal should still be extracted + let predicate: Arc = + binary(col("c2", &schema)?, Operator::Eq, lit(0u32), &schema)?; + let (equal_pairs, _) = collect_columns_from_predicate_inner(&predicate); + assert_eq!( + 1, + equal_pairs.len(), + "Should extract equality pairs where one side is a Column" + ); + + Ok(()) + } + + /// Columns with Absent min/max statistics should remain Absent after + /// FilterExec. 
+ #[tokio::test] + async fn test_filter_statistics_absent_columns_stay_absent() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + let input = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Inexact(1000), + total_byte_size: Precision::Absent, + column_statistics: vec![ + ColumnStatistics::default(), + ColumnStatistics::default(), + ], + }, + schema.clone(), + )); + + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(42)))), + )); + let filter: Arc = + Arc::new(FilterExec::try_new(predicate, input)?); + + let statistics = filter.partition_statistics(None)?; + let col_b_stats = &statistics.column_statistics[1]; + assert_eq!(col_b_stats.min_value, Precision::Absent); + assert_eq!(col_b_stats.max_value, Precision::Absent); + + Ok(()) + } + + /// Regression test: ProjectionExec on top of a FilterExec that already has + /// an explicit projection must not panic when `try_swapping_with_projection` + /// attempts to swap the two nodes. + /// + /// Before the fix, `FilterExecBuilder::from(self)` copied the old projection + /// (e.g. `[0, 1, 2]`) from the FilterExec. After `.with_input` replaced the + /// input with the narrower ProjectionExec (2 columns), `.build()` tried to + /// validate the stale `[0, 1, 2]` projection against the 2-column schema and + /// panicked with "project index 2 out of bounds, max field 2". 
+ #[test] + fn test_filter_with_projection_swap_does_not_panic() -> Result<()> { + use crate::projection::ProjectionExpr; + use datafusion_physical_expr::expressions::col; + + // Schema: [ts: Int64, tokens: Int64, svc: Utf8] + let schema = Arc::new(Schema::new(vec![ + Field::new("ts", DataType::Int64, false), + Field::new("tokens", DataType::Int64, false), + Field::new("svc", DataType::Utf8, false), + ])); + let input = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + // FilterExec: ts > 0, projection=[ts@0, tokens@1, svc@2] (all 3 cols) + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("ts", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int64(Some(0)))), + )); + let filter = Arc::new( + FilterExecBuilder::new(predicate, input) + .apply_projection(Some(vec![0, 1, 2]))? + .build()?, + ); + + // ProjectionExec: narrows to [ts, tokens] (drops svc) + let proj_exprs = vec![ + ProjectionExpr { + expr: col("ts", &filter.schema())?, + alias: "ts".to_string(), + }, + ProjectionExpr { + expr: col("tokens", &filter.schema())?, + alias: "tokens".to_string(), + }, + ]; + let projection = Arc::new(ProjectionExec::try_new( + proj_exprs, + Arc::clone(&filter) as _, + )?); + + // This must not panic + let result = filter.try_swapping_with_projection(&projection)?; + assert!(result.is_some(), "swap should succeed"); + + let new_plan = result.unwrap(); + // Output schema must still be [ts, tokens] + let out_schema = new_plan.schema(); + assert_eq!(out_schema.fields().len(), 2); + assert_eq!(out_schema.field(0).name(), "ts"); + assert_eq!(out_schema.field(1).name(), "tokens"); + + Ok(()) + } } diff --git a/datafusion/physical-plan/src/filter_pushdown.rs b/datafusion/physical-plan/src/filter_pushdown.rs index 1274e954eaeb..7e82b9e8239e 100644 --- a/datafusion/physical-plan/src/filter_pushdown.rs +++ b/datafusion/physical-plan/src/filter_pushdown.rs @@ -37,10 +37,13 @@ use std::collections::HashSet; use std::sync::Arc; -use datafusion_common::Result; -use 
datafusion_physical_expr::utils::{collect_columns, reassign_expr_columns}; +use arrow_schema::SchemaRef; +use datafusion_common::{ + Result, + tree_node::{Transformed, TreeNode}, +}; +use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use itertools::Itertools; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum FilterPushdownPhase { @@ -217,13 +220,13 @@ pub struct ChildPushdownResult { /// Returned from [`ExecutionPlan::handle_child_pushdown_result`] to communicate /// to the optimizer: /// -/// 1. What to do with any parent filters that were could not be pushed down into the children. +/// 1. What to do with any parent filters that could not be pushed down into the children. /// 2. If the node needs to be replaced in the execution plan with a new node or not. /// /// [`ExecutionPlan::handle_child_pushdown_result`]: crate::ExecutionPlan::handle_child_pushdown_result #[derive(Debug, Clone)] pub struct FilterPushdownPropagation { - /// What filters were pushed into the parent node. + /// Which parent filters were pushed down into this node's children. pub filters: Vec, /// The updated node, if it was updated during pushdown pub updated_node: Option, @@ -306,6 +309,83 @@ pub struct ChildFilterDescription { pub(crate) self_filters: Vec>, } +/// Validates and remaps filter column references to a target schema in one step. +/// +/// When pushing filters from a parent to a child node, we need to: +/// 1. Verify that all columns referenced by the filter exist in the target +/// 2. Remap column indices to match the target schema +/// +/// `allowed_indices` controls which column indices (in the parent schema) are +/// considered valid. For single-input nodes this defaults to +/// `0..child_schema.len()` (all columns are reachable). For join nodes it is +/// restricted to the subset of output columns that map to the target child, +/// which is critical when different sides have same-named columns. 
+pub(crate) struct FilterRemapper { + /// The target schema to remap column indices into. + child_schema: SchemaRef, + /// Only columns at these indices (in the *parent* schema) are considered + /// valid. For non-join nodes this defaults to `0..child_schema.len()`. + allowed_indices: HashSet, +} + +impl FilterRemapper { + /// Create a remapper that accepts any column whose index falls within + /// `0..child_schema.len()` and whose name exists in the target schema. + pub(crate) fn new(child_schema: SchemaRef) -> Self { + let allowed_indices = (0..child_schema.fields().len()).collect(); + Self { + child_schema, + allowed_indices, + } + } + + /// Create a remapper that only accepts columns at the given indices. + /// This is used by join nodes to restrict pushdown to one side of the + /// join when both sides have same-named columns. + fn with_allowed_indices( + child_schema: SchemaRef, + allowed_indices: HashSet, + ) -> Self { + Self { + child_schema, + allowed_indices, + } + } + + /// Try to remap a filter's column references to the target schema. + /// + /// Validates and remaps in a single tree traversal: for each column, + /// checks that its index is in the allowed set and that + /// its name exists in the target schema, then remaps the index. + /// Returns `Some(remapped)` if all columns are valid, or `None` if any + /// column fails validation. 
+ pub(crate) fn try_remap( + &self, + filter: &Arc, + ) -> Result>> { + let mut all_valid = true; + let transformed = Arc::clone(filter).transform_down(|expr| { + if let Some(col) = expr.as_any().downcast_ref::() { + if self.allowed_indices.contains(&col.index()) + && let Ok(new_index) = self.child_schema.index_of(col.name()) + { + Ok(Transformed::yes(Arc::new(Column::new( + col.name(), + new_index, + )))) + } else { + all_valid = false; + Ok(Transformed::complete(expr)) + } + } else { + Ok(Transformed::no(expr)) + } + })?; + + Ok(all_valid.then_some(transformed.data)) + } +} + impl ChildFilterDescription { /// Build a child filter description by analyzing which parent filters can be pushed to a specific child. /// @@ -318,36 +398,41 @@ impl ChildFilterDescription { parent_filters: &[Arc], child: &Arc, ) -> Result { - let child_schema = child.schema(); + let remapper = FilterRemapper::new(child.schema()); + Self::remap_filters(parent_filters, &remapper) + } - // Get column names from child schema for quick lookup - let child_column_names: HashSet<&str> = child_schema - .fields() - .iter() - .map(|f| f.name().as_str()) - .collect(); + /// Like [`Self::from_child`], but restricts which parent-level columns are + /// considered reachable through this child. + /// + /// `allowed_indices` is the set of column indices (in the *parent* + /// schema) that map to this child's side of a join. A filter is only + /// eligible for pushdown when **every** column index it references + /// appears in `allowed_indices`. + /// + /// This prevents incorrect pushdown when different join sides have + /// columns with the same name: matching on index ensures a filter + /// referencing the right side's `k@2` is not pushed to the left side + /// which also has a column named `k` but at a different index. 
+ pub fn from_child_with_allowed_indices( + parent_filters: &[Arc], + allowed_indices: HashSet, + child: &Arc, + ) -> Result { + let remapper = + FilterRemapper::with_allowed_indices(child.schema(), allowed_indices); + Self::remap_filters(parent_filters, &remapper) + } - // Analyze each parent filter + fn remap_filters( + parent_filters: &[Arc], + remapper: &FilterRemapper, + ) -> Result { let mut child_parent_filters = Vec::with_capacity(parent_filters.len()); - for filter in parent_filters { - // Check which columns the filter references - let referenced_columns = collect_columns(filter); - - // Check if all referenced columns exist in the child schema - let all_columns_exist = referenced_columns - .iter() - .all(|col| child_column_names.contains(col.name())); - - if all_columns_exist { - // All columns exist in child - we can push down - // Need to reassign column indices to match child schema - let reassigned_filter = - reassign_expr_columns(Arc::clone(filter), &child_schema)?; - child_parent_filters - .push(PushedDownPredicate::supported(reassigned_filter)); + if let Some(remapped) = remapper.try_remap(filter)? { + child_parent_filters.push(PushedDownPredicate::supported(remapped)); } else { - // Some columns don't exist in child - cannot push down child_parent_filters .push(PushedDownPredicate::unsupported(Arc::clone(filter))); } @@ -359,6 +444,17 @@ impl ChildFilterDescription { }) } + /// Mark all parent filters as unsupported for this child. + pub fn all_unsupported(parent_filters: &[Arc]) -> Self { + Self { + parent_filters: parent_filters + .iter() + .map(|f| PushedDownPredicate::unsupported(Arc::clone(f))) + .collect(), + self_filters: vec![], + } + } + /// Add a self filter (from the current node) to be pushed down to this child. 
pub fn with_self_filter(mut self, filter: Arc) -> Self { self.self_filters.push(filter); @@ -434,15 +530,9 @@ impl FilterDescription { children: &[&Arc], ) -> Self { let mut desc = Self::new(); - let child_filters = parent_filters - .iter() - .map(|f| PushedDownPredicate::unsupported(Arc::clone(f))) - .collect_vec(); for _ in 0..children.len() { - desc = desc.with_child(ChildFilterDescription { - parent_filters: child_filters.clone(), - self_filters: vec![], - }); + desc = + desc.with_child(ChildFilterDescription::all_unsupported(parent_filters)); } desc } diff --git a/datafusion/physical-plan/src/joins/array_map.rs b/datafusion/physical-plan/src/joins/array_map.rs new file mode 100644 index 000000000000..ad40d6776df4 --- /dev/null +++ b/datafusion/physical-plan/src/joins/array_map.rs @@ -0,0 +1,547 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow_schema::DataType; +use num_traits::AsPrimitive; +use std::mem::size_of; + +use crate::joins::MapOffset; +use crate::joins::chain::traverse_chain; +use arrow::array::{Array, ArrayRef, AsArray, BooleanArray}; +use arrow::buffer::BooleanBuffer; +use arrow::datatypes::ArrowNumericType; +use datafusion_common::{Result, ScalarValue, internal_err}; + +/// A macro to downcast only supported integer types (up to 64-bit) and invoke a generic function. +/// +/// Usage: `downcast_supported_integer!(data_type => (Method, arg1, arg2, ...))` +/// +/// The `Method` must be an associated method of [`ArrayMap`] that is generic over +/// `` and allow `T::Native: AsPrimitive`. +macro_rules! downcast_supported_integer { + ($DATA_TYPE:expr => ($METHOD:ident $(, $ARGS:expr)*)) => { + match $DATA_TYPE { + arrow::datatypes::DataType::Int8 => ArrayMap::$METHOD::($($ARGS),*), + arrow::datatypes::DataType::Int16 => ArrayMap::$METHOD::($($ARGS),*), + arrow::datatypes::DataType::Int32 => ArrayMap::$METHOD::($($ARGS),*), + arrow::datatypes::DataType::Int64 => ArrayMap::$METHOD::($($ARGS),*), + arrow::datatypes::DataType::UInt8 => ArrayMap::$METHOD::($($ARGS),*), + arrow::datatypes::DataType::UInt16 => ArrayMap::$METHOD::($($ARGS),*), + arrow::datatypes::DataType::UInt32 => ArrayMap::$METHOD::($($ARGS),*), + arrow::datatypes::DataType::UInt64 => ArrayMap::$METHOD::($($ARGS),*), + _ => { + return internal_err!( + "Unsupported type for ArrayMap: {:?}", + $DATA_TYPE + ); + } + } + }; +} + +/// A dense map for single-column integer join keys within a limited range. +/// +/// Maps join keys to build-side indices using direct array indexing: +/// `data[val - min_val_in_build_side] -> val_idx_in_build_side + 1`. +/// +/// NULL values are ignored on both the build side and the probe side. 
+/// +/// # Handling Negative Numbers with `wrapping_sub` +/// +/// This implementation supports signed integer ranges (e.g., `[-5, 5]`) efficiently by +/// treating them as `u64` (Two's Complement) and relying on the bitwise properties of +/// wrapping arithmetic (`wrapping_sub`). +/// +/// In Two's Complement representation, `a_signed - b_signed` produces the same bit pattern +/// as `a_unsigned.wrapping_sub(b_unsigned)` (modulo 2^N). This allows us to perform +/// range calculations and zero-based index mapping uniformly for both signed and unsigned +/// types without branching. +/// +/// ## Examples +/// +/// Consider an `Int64` range `[-5, 5]`. +/// * `min_val (-5)` casts to `u64`: `...11111011` (`u64::MAX - 4`) +/// * `max_val (5)` casts to `u64`: `...00000101` (`5`) +/// +/// **1. Range Calculation** +/// +/// ```text +/// In modular arithmetic, this is equivalent to: +/// (5 - (2^64 - 5)) mod 2^64 +/// = (5 - 2^64 + 5) mod 2^64 +/// = (10 - 2^64) mod 2^64 +/// = 10 +/// +/// ``` +/// The resulting `range` (10) correctly represents the size of the interval `[-5, 5]`. +/// +/// **2. Index Lookup (in `get_matched_indices`)** +/// +/// For a probe value of `0` (which is stored as `0u64`): +/// ```text +/// In modular arithmetic, this is equivalent to: +/// (0 - (2^64 - 5)) mod 2^64 +/// = (-2^64 + 5) mod 2^64 +/// = 5 +/// ``` +/// This correctly maps `-5` to index `0`, `0` to index `5`, etc. +#[derive(Debug)] +pub struct ArrayMap { + // data[probSideVal-offset] -> valIdxInBuildSide + 1; 0 for absent + data: Vec, + // min val in buildSide + offset: u64, + // next[buildSideIdx] -> next matching valIdxInBuildSide + 1; 0 for end of chain. + // If next is empty, it means there are no duplicate keys (no conflicts). + // It uses the same chain-based conflict resolution as [`JoinHashMapType`]. 
+ next: Vec, + num_of_distinct_key: usize, +} + +impl ArrayMap { + pub fn is_supported_type(data_type: &DataType) -> bool { + matches!( + data_type, + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + ) + } + + pub(crate) fn key_to_u64(v: &ScalarValue) -> Option { + match v { + ScalarValue::Int8(Some(v)) => Some(*v as u64), + ScalarValue::Int16(Some(v)) => Some(*v as u64), + ScalarValue::Int32(Some(v)) => Some(*v as u64), + ScalarValue::Int64(Some(v)) => Some(*v as u64), + ScalarValue::UInt8(Some(v)) => Some(*v as u64), + ScalarValue::UInt16(Some(v)) => Some(*v as u64), + ScalarValue::UInt32(Some(v)) => Some(*v as u64), + ScalarValue::UInt64(Some(v)) => Some(*v), + _ => None, + } + } + + /// Estimates the maximum memory usage for an `ArrayMap` with the given parameters. + /// + pub fn estimate_memory_size(min_val: u64, max_val: u64, num_rows: usize) -> usize { + let range = Self::calculate_range(min_val, max_val); + if range >= usize::MAX as u64 { + return usize::MAX; + } + let size = (range + 1) as usize; + size.saturating_mul(size_of::()) + .saturating_add(num_rows.saturating_mul(size_of::())) + } + + pub fn calculate_range(min_val: u64, max_val: u64) -> u64 { + max_val.wrapping_sub(min_val) + } + + /// Creates a new [`ArrayMap`] from the given array of join keys. + /// + /// Note: This function processes only the non-null values in the input `array`, + /// ignoring any rows where the key is `NULL`. 
+ /// + pub(crate) fn try_new(array: &ArrayRef, min_val: u64, max_val: u64) -> Result { + let range = max_val.wrapping_sub(min_val); + if range >= usize::MAX as u64 { + return internal_err!("ArrayMap key range is too large to be allocated."); + } + let size = (range + 1) as usize; + + let mut data: Vec = vec![0; size]; + let mut next: Vec = vec![]; + let mut num_of_distinct_key = 0; + + downcast_supported_integer!( + array.data_type() => ( + fill_data, + array, + min_val, + &mut data, + &mut next, + &mut num_of_distinct_key + ) + )?; + + Ok(Self { + data, + offset: min_val, + next, + num_of_distinct_key, + }) + } + + fn fill_data( + array: &ArrayRef, + offset_val: u64, + data: &mut [u32], + next: &mut Vec, + num_of_distinct_key: &mut usize, + ) -> Result<()> + where + T::Native: AsPrimitive, + { + let arr = array.as_primitive::(); + // Iterate in reverse to maintain FIFO order when there are duplicate keys. + for (i, val) in arr.iter().enumerate().rev() { + if let Some(val) = val { + let key: u64 = val.as_(); + let idx = key.wrapping_sub(offset_val) as usize; + if idx >= data.len() { + return internal_err!("failed build Array idx >= data.len()"); + } + + if data[idx] != 0 { + if next.is_empty() { + *next = vec![0; array.len()] + } + next[i] = data[idx] + } else { + *num_of_distinct_key += 1; + } + data[idx] = (i) as u32 + 1; + } + } + Ok(()) + } + + pub fn num_of_distinct_key(&self) -> usize { + self.num_of_distinct_key + } + + /// Returns the memory usage of this [`ArrayMap`] in bytes. 
+ pub fn size(&self) -> usize { + self.data.capacity() * size_of::() + self.next.capacity() * size_of::() + } + + pub fn get_matched_indices_with_limit_offset( + &self, + prob_side_keys: &[ArrayRef], + limit: usize, + current_offset: MapOffset, + probe_indices: &mut Vec, + build_indices: &mut Vec, + ) -> Result> { + if prob_side_keys.len() != 1 { + return internal_err!( + "ArrayMap expects 1 join key, but got {}", + prob_side_keys.len() + ); + } + let array = &prob_side_keys[0]; + + downcast_supported_integer!( + array.data_type() => ( + lookup_and_get_indices, + self, + array, + limit, + current_offset, + probe_indices, + build_indices + ) + ) + } + + fn lookup_and_get_indices( + &self, + array: &ArrayRef, + limit: usize, + current_offset: MapOffset, + probe_indices: &mut Vec, + build_indices: &mut Vec, + ) -> Result> + where + T::Native: Copy + AsPrimitive, + { + probe_indices.clear(); + build_indices.clear(); + + let arr = array.as_primitive::(); + + let have_null = arr.null_count() > 0; + + if self.next.is_empty() { + for prob_idx in current_offset.0..arr.len() { + if build_indices.len() == limit { + return Ok(Some((prob_idx, None))); + } + + // short circuit + if have_null && arr.is_null(prob_idx) { + continue; + } + // SAFETY: prob_idx is guaranteed to be within bounds by the loop range. 
+ let prob_val: u64 = unsafe { arr.value_unchecked(prob_idx) }.as_(); + let idx_in_build_side = prob_val.wrapping_sub(self.offset) as usize; + + if idx_in_build_side >= self.data.len() + || self.data[idx_in_build_side] == 0 + { + continue; + } + build_indices.push((self.data[idx_in_build_side] - 1) as u64); + probe_indices.push(prob_idx as u32); + } + Ok(None) + } else { + let mut remaining_output = limit; + let to_skip = match current_offset { + // None `initial_next_idx` indicates that `initial_idx` processing hasn't been started + (idx, None) => idx, + // Zero `initial_next_idx` indicates that `initial_idx` has been processed during + // previous iteration, and it should be skipped + (idx, Some(0)) => idx + 1, + // Otherwise, process remaining `initial_idx` matches by traversing `next_chain`, + // to start with the next index + (idx, Some(next_idx)) => { + let is_last = idx == arr.len() - 1; + if let Some(next_offset) = traverse_chain( + &self.next, + idx, + next_idx as u32, + &mut remaining_output, + probe_indices, + build_indices, + is_last, + ) { + return Ok(Some(next_offset)); + } + idx + 1 + } + }; + + for prob_side_idx in to_skip..arr.len() { + if remaining_output == 0 { + return Ok(Some((prob_side_idx, None))); + } + + if arr.is_null(prob_side_idx) { + continue; + } + + let is_last = prob_side_idx == arr.len() - 1; + + // SAFETY: prob_idx is guaranteed to be within bounds by the loop range. 
+ let prob_val: u64 = unsafe { arr.value_unchecked(prob_side_idx) }.as_(); + let idx_in_build_side = prob_val.wrapping_sub(self.offset) as usize; + if idx_in_build_side >= self.data.len() + || self.data[idx_in_build_side] == 0 + { + continue; + } + + let build_idx = self.data[idx_in_build_side]; + + if let Some(offset) = traverse_chain( + &self.next, + prob_side_idx, + build_idx, + &mut remaining_output, + probe_indices, + build_indices, + is_last, + ) { + return Ok(Some(offset)); + } + } + Ok(None) + } + } + + pub fn contain_keys(&self, probe_side_keys: &[ArrayRef]) -> Result { + if probe_side_keys.len() != 1 { + return internal_err!( + "ArrayMap join expects 1 join key, but got {}", + probe_side_keys.len() + ); + } + let array = &probe_side_keys[0]; + + downcast_supported_integer!( + array.data_type() => ( + contain_hashes_helper, + self, + array + ) + ) + } + + fn contain_hashes_helper( + &self, + array: &ArrayRef, + ) -> Result + where + T::Native: AsPrimitive, + { + let arr = array.as_primitive::(); + let buffer = BooleanBuffer::collect_bool(arr.len(), |i| { + if arr.is_null(i) { + return false; + } + // SAFETY: i is within bounds [0, arr.len()) + let key: u64 = unsafe { arr.value_unchecked(i) }.as_(); + let idx = key.wrapping_sub(self.offset) as usize; + idx < self.data.len() && self.data[idx] != 0 + }); + Ok(BooleanArray::new(buffer, None)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::Int32Array; + use arrow::array::Int64Array; + use std::sync::Arc; + + #[test] + fn test_array_map_limit_offset_duplicate_elements() -> Result<()> { + let build: ArrayRef = Arc::new(Int32Array::from(vec![1, 1, 2])); + let map = ArrayMap::try_new(&build, 1, 2)?; + let probe = [Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef]; + + let mut prob_idx = Vec::new(); + let mut build_idx = Vec::new(); + let mut next = Some((0, None)); + let mut results = vec![]; + + while let Some(o) = next { + next = map.get_matched_indices_with_limit_offset( + &probe, + 1, 
+ o, + &mut prob_idx, + &mut build_idx, + )?; + results.push((prob_idx.clone(), build_idx.clone(), next)); + } + + let expected = vec![ + (vec![0], vec![0], Some((0, Some(2)))), + (vec![0], vec![1], Some((0, Some(0)))), + (vec![1], vec![2], None), + ]; + assert_eq!(results, expected); + Ok(()) + } + + #[test] + fn test_array_map_with_limit_and_misses() -> Result<()> { + let build: ArrayRef = Arc::new(Int32Array::from(vec![1, 2])); + let map = ArrayMap::try_new(&build, 1, 2)?; + let probe = [Arc::new(Int32Array::from(vec![10, 1, 2])) as ArrayRef]; + + let (mut p_idx, mut b_idx) = (vec![], vec![]); + // Skip 10, find 1, next is 2 + let next = map.get_matched_indices_with_limit_offset( + &probe, + 1, + (0, None), + &mut p_idx, + &mut b_idx, + )?; + assert_eq!(p_idx, vec![1]); + assert_eq!(b_idx, vec![0]); + assert_eq!(next, Some((2, None))); + + // Find 2, end + let next = map.get_matched_indices_with_limit_offset( + &probe, + 1, + next.unwrap(), + &mut p_idx, + &mut b_idx, + )?; + assert_eq!(p_idx, vec![2]); + assert_eq!(b_idx, vec![1]); + assert!(next.is_none()); + Ok(()) + } + + #[test] + fn test_array_map_with_build_duplicates_and_misses() -> Result<()> { + let build_array: ArrayRef = Arc::new(Int32Array::from(vec![1, 1])); + let array_map = ArrayMap::try_new(&build_array, 1, 1)?; + // prob: 10(m), 1(h1, h2), 20(m), 1(h1, h2) + let probe_array: ArrayRef = Arc::new(Int32Array::from(vec![10, 1, 20, 1])); + let prob_side_keys = [probe_array]; + + let mut prob_indices = Vec::new(); + let mut build_indices = Vec::new(); + + // batch_size=3, should get 2 matches from first '1' and 1 match from second '1' + let result_offset = array_map.get_matched_indices_with_limit_offset( + &prob_side_keys, + 3, + (0, None), + &mut prob_indices, + &mut build_indices, + )?; + + assert_eq!(prob_indices, vec![1, 1, 3]); + assert_eq!(build_indices, vec![0, 1, 0]); + assert_eq!(result_offset, Some((3, Some(2)))); + Ok(()) + } + + #[test] + fn 
test_array_map_i64_with_negative_and_positive_numbers() -> Result<()> { + // Build array with a mix of negative and positive i64 values, no duplicates + let build_array: ArrayRef = Arc::new(Int64Array::from(vec![-5, 0, 5, -2, 3, 10])); + let min_val = -5_i128; + let max_val = 10_i128; + + let array_map = ArrayMap::try_new(&build_array, min_val as u64, max_val as u64)?; + + // Probe array + let probe_array: ArrayRef = Arc::new(Int64Array::from(vec![0, -5, 10, -1])); + let prob_side_keys = [Arc::clone(&probe_array)]; + + let mut prob_indices = Vec::new(); + let mut build_indices = Vec::new(); + + // Call once to get all matches + let result_offset = array_map.get_matched_indices_with_limit_offset( + &prob_side_keys, + 10, // A batch size larger than number of probes + (0, None), + &mut prob_indices, + &mut build_indices, + )?; + + // Expected matches, in probe-side order: + // Probe 0 (value 0) -> Build 1 (value 0) + // Probe 1 (value -5) -> Build 0 (value -5) + // Probe 2 (value 10) -> Build 5 (value 10) + let expected_prob_indices = vec![0, 1, 2]; + let expected_build_indices = vec![1, 0, 5]; + + assert_eq!(prob_indices, expected_prob_indices); + assert_eq!(build_indices, expected_build_indices); + assert!(result_offset.is_none()); + + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/joins/chain.rs b/datafusion/physical-plan/src/joins/chain.rs new file mode 100644 index 000000000000..846b7505d647 --- /dev/null +++ b/datafusion/physical-plan/src/joins/chain.rs @@ -0,0 +1,69 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::fmt::Debug; +use std::ops::Sub; + +use arrow::datatypes::ArrowNativeType; + +use crate::joins::MapOffset; + +/// Traverses the chain of matching indices, collecting results up to the remaining limit. +/// Returns `Some(offset)` if the limit was reached and there are more results to process, +/// or `None` if the chain was fully traversed. +#[inline(always)] +pub(crate) fn traverse_chain( + next_chain: &[T], + prob_idx: usize, + start_chain_idx: T, + remaining: &mut usize, + input_indices: &mut Vec, + match_indices: &mut Vec, + is_last_input: bool, +) -> Option +where + T: Copy + TryFrom + PartialOrd + Into + Sub, + >::Error: Debug, + T: ArrowNativeType, +{ + let zero = T::usize_as(0); + let one = T::usize_as(1); + let mut match_row_idx = start_chain_idx - one; + + loop { + match_indices.push(match_row_idx.into()); + input_indices.push(prob_idx as u32); + *remaining -= 1; + + let next = next_chain[match_row_idx.into() as usize]; + + if *remaining == 0 { + // Limit reached - return offset for next call + return if is_last_input && next == zero { + // Finished processing the last input row + None + } else { + Some((prob_idx, Some(next.into()))) + }; + } + if next == zero { + // End of chain + return None; + } + match_row_idx = next - one; + } +} diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index 4f32b6176ec3..342cb7e70a78 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -34,7 +34,7 @@ use 
crate::projection::{ use crate::{ ColumnStatistics, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, PlanProperties, RecordBatchStream, - SendableRecordBatchStream, Statistics, handle_state, + SendableRecordBatchStream, Statistics, check_if_same_properties, handle_state, }; use arrow::array::{RecordBatch, RecordBatchOptions}; @@ -94,7 +94,7 @@ pub struct CrossJoinExec { /// Execution plan metrics metrics: ExecutionPlanMetricsSet, /// Properties such as schema, equivalence properties, ordering, partitioning, etc. - cache: PlanProperties, + cache: Arc, } impl CrossJoinExec { @@ -125,7 +125,7 @@ impl CrossJoinExec { schema, left_fut: Default::default(), metrics: ExecutionPlanMetricsSet::default(), - cache, + cache: Arc::new(cache), } } @@ -192,6 +192,23 @@ impl CrossJoinExec { &self.right.schema(), ) } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + let left = children.swap_remove(0); + let right = children.swap_remove(0); + + Self { + left, + right, + metrics: ExecutionPlanMetricsSet::new(), + left_fut: Default::default(), + cache: Arc::clone(&self.cache), + schema: Arc::clone(&self.schema), + } + } } /// Asynchronously collect the result of the left child @@ -206,7 +223,7 @@ async fn load_left_input( let (batches, _metrics, reservation) = stream .try_fold( (Vec::new(), metrics, reservation), - |(mut batches, metrics, mut reservation), batch| async { + |(mut batches, metrics, reservation), batch| async { let batch_size = batch.get_array_memory_size(); // Reserve memory for incoming batch reservation.try_grow(batch_size)?; @@ -256,7 +273,7 @@ impl ExecutionPlan for CrossJoinExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -272,6 +289,7 @@ impl ExecutionPlan for CrossJoinExec { self: Arc, children: Vec>, ) -> Result> { + check_if_same_properties!(self, children); Ok(Arc::new(CrossJoinExec::new( Arc::clone(&children[0]), 
Arc::clone(&children[1]), @@ -285,7 +303,7 @@ impl ExecutionPlan for CrossJoinExec { schema: Arc::clone(&self.schema), left_fut: Default::default(), // reset the build side! metrics: ExecutionPlanMetricsSet::default(), - cache: self.cache.clone(), + cache: Arc::clone(&self.cache), }; Ok(Arc::new(new_exec)) } @@ -356,10 +374,6 @@ impl ExecutionPlan for CrossJoinExec { } } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { // Get the all partitions statistics of the left let left_stats = self.left.partition_statistics(None)?; diff --git a/datafusion/physical-plan/src/joins/hash_join/exec.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs index 91fc1ee4436e..25b320f98550 100644 --- a/datafusion/physical-plan/src/joins/hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs @@ -15,18 +15,24 @@ // specific language governing permissions and limitations // under the License. +use std::collections::HashSet; use std::fmt; use std::mem::size_of; -use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::{Arc, OnceLock}; use std::{any::Any, vec}; use crate::ExecutionPlanProperties; -use crate::execution_plan::{EmissionType, boundedness_from_children}; +use crate::execution_plan::{ + EmissionType, boundedness_from_children, has_same_children_properties, + stub_properties, +}; use crate::filter_pushdown::{ - ChildPushdownResult, FilterDescription, FilterPushdownPhase, + ChildFilterDescription, ChildPushdownResult, FilterDescription, FilterPushdownPhase, FilterPushdownPropagation, }; +use crate::joins::Map; +use crate::joins::array_map::ArrayMap; use crate::joins::hash_join::inlist_builder::build_struct_inlist_values; use crate::joins::hash_join::shared_bounds::{ ColumnBounds, PartitionBounds, PushdownStrategy, SharedBuildAccumulator, @@ -40,6 +46,7 @@ use crate::joins::utils::{ swap_join_projection, 
update_hash, }; use crate::joins::{JoinOn, JoinOnRef, PartitionMode, SharedBitmapBuilder}; +use crate::metrics::{Count, MetricBuilder}; use crate::projection::{ EmbeddedProjection, JoinData, ProjectionExec, try_embed_projection, try_pushdown_through_join, @@ -63,12 +70,12 @@ use arrow::compute::concat_batches; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use arrow::util::bit_util; -use arrow_schema::DataType; +use arrow_schema::{DataType, Schema}; use datafusion_common::config::ConfigOptions; use datafusion_common::utils::memory::estimate_memory_size; use datafusion_common::{ - JoinSide, JoinType, NullEquality, Result, assert_or_internal_err, plan_err, - project_schema, + JoinSide, JoinType, NullEquality, Result, assert_or_internal_err, internal_err, + plan_err, project_schema, }; use datafusion_execution::TaskContext; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; @@ -77,7 +84,8 @@ use datafusion_functions_aggregate_common::min_max::{MaxAccumulator, MinAccumula use datafusion_physical_expr::equivalence::{ ProjectionMapping, join_equivalence_properties, }; -use datafusion_physical_expr::expressions::{DynamicFilterPhysicalExpr, lit}; +use datafusion_physical_expr::expressions::{Column, DynamicFilterPhysicalExpr, lit}; +use datafusion_physical_expr::projection::{ProjectionRef, combine_projections}; use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef}; use ahash::RandomState; @@ -92,11 +100,96 @@ use super::partitioned_hash_eval::SeededRandomState; pub(crate) const HASH_JOIN_SEED: SeededRandomState = SeededRandomState::with_seeds('J' as u64, 'O' as u64, 'I' as u64, 'N' as u64); +const ARRAY_MAP_CREATED_COUNT_METRIC_NAME: &str = "array_map_created_count"; + +#[expect(clippy::too_many_arguments)] +fn try_create_array_map( + bounds: &Option, + schema: &SchemaRef, + batches: &[RecordBatch], + on_left: &[PhysicalExprRef], + reservation: &mut MemoryReservation, + perfect_hash_join_small_build_threshold: usize, 
+ perfect_hash_join_min_key_density: f64, + null_equality: NullEquality, +) -> Result)>> { + if on_left.len() != 1 { + return Ok(None); + } + + if null_equality == NullEquality::NullEqualsNull { + for batch in batches.iter() { + let arrays = evaluate_expressions_to_arrays(on_left, batch)?; + if arrays[0].null_count() > 0 { + return Ok(None); + } + } + } + + let (min_val, max_val) = if let Some(bounds) = bounds { + let (min_val, max_val) = if let Some(cb) = bounds.get_column_bounds(0) { + (cb.min.clone(), cb.max.clone()) + } else { + return Ok(None); + }; + + if min_val.is_null() || max_val.is_null() { + return Ok(None); + } + + if min_val > max_val { + return internal_err!("min_val>max_val"); + } + + if let Some((mi, ma)) = + ArrayMap::key_to_u64(&min_val).zip(ArrayMap::key_to_u64(&max_val)) + { + (mi, ma) + } else { + return Ok(None); + } + } else { + return Ok(None); + }; + + let range = ArrayMap::calculate_range(min_val, max_val); + let num_row: usize = batches.iter().map(|x| x.num_rows()).sum(); + + // TODO: support create ArrayMap + if num_row >= u32::MAX as usize { + return Ok(None); + } + + // When the key range spans the full integer domain (e.g. i64::MIN to i64::MAX), + // range is u64::MAX and `range + 1` below would overflow. 
+ if range == usize::MAX as u64 { + return Ok(None); + } + + let dense_ratio = (num_row as f64) / ((range + 1) as f64); + + if range >= perfect_hash_join_small_build_threshold as u64 + && dense_ratio <= perfect_hash_join_min_key_density + { + return Ok(None); + } + + let mem_size = ArrayMap::estimate_memory_size(min_val, max_val, num_row); + reservation.try_grow(mem_size)?; + + let batch = concat_batches(schema, batches)?; + let left_values = evaluate_expressions_to_arrays(on_left, &batch)?; + + let array_map = ArrayMap::try_new(&left_values[0], min_val, max_val)?; + + Ok(Some((array_map, batch, left_values))) +} + /// HashTable and input data for the left (build side) of a join pub(super) struct JoinLeftData { /// The hash table with indices into `batch` /// Arc is used to allow sharing with SharedBuildAccumulator for hash map pushdown - pub(super) hash_map: Arc, + pub(super) map: Arc, /// The input rows for the build side batch: RecordBatch, /// The build side on expressions values @@ -118,12 +211,17 @@ pub(super) struct JoinLeftData { /// Membership testing strategy for filter pushdown /// Contains either InList values for small build sides or hash table reference for large build sides pub(super) membership: PushdownStrategy, + /// Shared atomic flag indicating if any probe partition saw data (for null-aware anti joins) + /// This is shared across all probe partitions to provide global knowledge + pub(super) probe_side_non_empty: AtomicBool, + /// Shared atomic flag indicating if any probe partition saw NULL in join keys (for null-aware anti joins) + pub(super) probe_side_has_null: AtomicBool, } impl JoinLeftData { - /// return a reference to the hash map - pub(super) fn hash_map(&self) -> &dyn JoinHashMapType { - &*self.hash_map + /// return a reference to the map + pub(super) fn map(&self) -> &Map { + &self.map } /// returns a reference to the build side batch @@ -153,6 +251,277 @@ impl JoinLeftData { } } +/// Helps to build [`HashJoinExec`]. 
+/// +/// Builder can be created from an existing [`HashJoinExec`] using [`From::from`]. +/// In this case, all its fields are inherited. If a field that affects the node's +/// properties is modified, they will be automatically recomputed during the build. +/// +/// # Adding setters +/// +/// When adding a new setter, it is necessary to ensure that the `preserve_properties` +/// flag is set to false if modifying the field requires a recomputation of the plan's +/// properties. +/// +pub struct HashJoinExecBuilder { + exec: HashJoinExec, + preserve_properties: bool, +} + +impl HashJoinExecBuilder { + /// Make a new [`HashJoinExecBuilder`]. + pub fn new( + left: Arc, + right: Arc, + on: Vec<(PhysicalExprRef, PhysicalExprRef)>, + join_type: JoinType, + ) -> Self { + Self { + exec: HashJoinExec { + left, + right, + on, + filter: None, + join_type, + left_fut: Default::default(), + random_state: HASH_JOIN_SEED, + mode: PartitionMode::Auto, + fetch: None, + metrics: ExecutionPlanMetricsSet::new(), + projection: None, + column_indices: vec![], + null_equality: NullEquality::NullEqualsNothing, + null_aware: false, + dynamic_filter: None, + // Will be computed at when plan will be built. + cache: stub_properties(), + join_schema: Arc::new(Schema::empty()), + }, + // As `exec` is initialized with stub properties, + // they will be properly computed when plan will be built. + preserve_properties: false, + } + } + + /// Set join type. + pub fn with_type(mut self, join_type: JoinType) -> Self { + self.exec.join_type = join_type; + self.preserve_properties = false; + self + } + + /// Set projection from the vector. + pub fn with_projection(self, projection: Option>) -> Self { + self.with_projection_ref(projection.map(Into::into)) + } + + /// Set projection from the shared reference. + pub fn with_projection_ref(mut self, projection: Option) -> Self { + self.exec.projection = projection; + self.preserve_properties = false; + self + } + + /// Set optional filter. 
+ pub fn with_filter(mut self, filter: Option) -> Self { + self.exec.filter = filter; + self + } + + /// Set expressions to join on. + pub fn with_on(mut self, on: Vec<(PhysicalExprRef, PhysicalExprRef)>) -> Self { + self.exec.on = on; + self.preserve_properties = false; + self + } + + /// Set partition mode. + pub fn with_partition_mode(mut self, mode: PartitionMode) -> Self { + self.exec.mode = mode; + self.preserve_properties = false; + self + } + + /// Set null equality property. + pub fn with_null_equality(mut self, null_equality: NullEquality) -> Self { + self.exec.null_equality = null_equality; + self + } + + /// Set null aware property. + pub fn with_null_aware(mut self, null_aware: bool) -> Self { + self.exec.null_aware = null_aware; + self + } + + /// Set fetch property. + pub fn with_fetch(mut self, fetch: Option) -> Self { + self.exec.fetch = fetch; + self + } + + /// Require to recompute plan properties. + pub fn recompute_properties(mut self) -> Self { + self.preserve_properties = false; + self + } + + /// Replace children. + pub fn with_new_children( + mut self, + mut children: Vec>, + ) -> Result { + assert_or_internal_err!( + children.len() == 2, + "wrong number of children passed into `HashJoinExecBuilder`" + ); + self.preserve_properties &= has_same_children_properties(&self.exec, &children)?; + self.exec.right = children.swap_remove(1); + self.exec.left = children.swap_remove(0); + Ok(self) + } + + /// Reset runtime state. + pub fn reset_state(mut self) -> Self { + self.exec.left_fut = Default::default(); + self.exec.dynamic_filter = None; + self.exec.metrics = ExecutionPlanMetricsSet::new(); + self + } + + /// Build result as a dyn execution plan. + pub fn build_exec(self) -> Result> { + self.build().map(|p| Arc::new(p) as _) + } + + /// Build resulting execution plan. 
+ pub fn build(self) -> Result { + let Self { + exec, + preserve_properties, + } = self; + + // Validate null_aware flag + if exec.null_aware { + let join_type = exec.join_type(); + if !matches!(join_type, JoinType::LeftAnti) { + return plan_err!( + "null_aware can only be true for LeftAnti joins, got {join_type}" + ); + } + let on = exec.on(); + if on.len() != 1 { + return plan_err!( + "null_aware anti join only supports single column join key, got {} columns", + on.len() + ); + } + } + + if preserve_properties { + return Ok(exec); + } + + let HashJoinExec { + left, + right, + on, + filter, + join_type, + left_fut, + random_state, + mode, + metrics, + projection, + null_equality, + null_aware, + dynamic_filter, + fetch, + // Recomputed. + join_schema: _, + column_indices: _, + cache: _, + } = exec; + + let left_schema = left.schema(); + let right_schema = right.schema(); + if on.is_empty() { + return plan_err!("On constraints in HashJoinExec should be non-empty"); + } + + check_join_is_valid(&left_schema, &right_schema, &on)?; + let (join_schema, column_indices) = + build_join_schema(&left_schema, &right_schema, &join_type); + + let join_schema = Arc::new(join_schema); + + // Check if the projection is valid. 
+ can_project(&join_schema, projection.as_deref())?; + + let cache = HashJoinExec::compute_properties( + &left, + &right, + &join_schema, + join_type, + &on, + mode, + projection.as_deref(), + )?; + + Ok(HashJoinExec { + left, + right, + on, + filter, + join_type, + join_schema, + left_fut, + random_state, + mode, + metrics, + projection, + column_indices, + null_equality, + null_aware, + cache: Arc::new(cache), + dynamic_filter, + fetch, + }) + } + + fn with_dynamic_filter(mut self, filter: Option) -> Self { + self.exec.dynamic_filter = filter; + self + } +} + +impl From<&HashJoinExec> for HashJoinExecBuilder { + fn from(exec: &HashJoinExec) -> Self { + Self { + exec: HashJoinExec { + left: Arc::clone(exec.left()), + right: Arc::clone(exec.right()), + on: exec.on.clone(), + filter: exec.filter.clone(), + join_type: exec.join_type, + join_schema: Arc::clone(&exec.join_schema), + left_fut: Arc::clone(&exec.left_fut), + random_state: exec.random_state.clone(), + mode: exec.mode, + metrics: exec.metrics.clone(), + projection: exec.projection.clone(), + column_indices: exec.column_indices.clone(), + null_equality: exec.null_equality, + null_aware: exec.null_aware, + cache: Arc::clone(&exec.cache), + dynamic_filter: exec.dynamic_filter.clone(), + fetch: exec.fetch, + }, + preserve_properties: true, + } + } +} + #[expect(rustdoc::private_intra_doc_links)] /// Join execution plan: Evaluates equijoin predicates in parallel on multiple /// partitions using a hash table and an optional filter list to apply post @@ -168,6 +537,36 @@ impl JoinLeftData { /// ` != `) are known as "filter expressions" and are evaluated /// after the equijoin predicates. /// +/// # ArrayMap Optimization +/// +/// For joins with a single integer-based join key, `HashJoinExec` may use an [`ArrayMap`] +/// (also known as a "perfect hash join") instead of a general-purpose hash map. +/// This optimization is used when: +/// 1. There is exactly one join key. +/// 2. 
The join key is an integer type up to 64 bits wide that can be losslessly converted +/// to `u64` (128-bit integer types such as `i128` and `u128` are not supported). +/// 3. The range of keys is small enough (controlled by `perfect_hash_join_small_build_threshold`) +/// OR the keys are sufficiently dense (controlled by `perfect_hash_join_min_key_density`). +/// 4. build_side.num_rows() < u32::MAX +/// 5. NullEqualsNothing || (NullEqualsNull && build side doesn't contain null) +/// +/// See [`try_create_array_map`] for more details. +/// +/// Note that when using [`PartitionMode::Partitioned`], the build side is split into multiple +/// partitions. This can cause a dense build side to become sparse within each partition, +/// potentially disabling this optimization. +/// +/// For example, consider: +/// ```sql +/// SELECT t1.value, t2.value +/// FROM range(10000) AS t1 +/// JOIN range(10000) AS t2 +/// ON t1.value = t2.value; +/// ``` +/// With 24 partitions, each partition will only receive a subset of the 10,000 rows. +/// The first partition might contain values like `3, 10, 18, 39, 43`, which are sparse +/// relative to the original range, even though the overall data set is dense. +/// /// # "Build Side" vs "Probe Side" /// /// HashJoin takes two inputs, which are referred to as the "build" and the @@ -201,9 +600,9 @@ impl JoinLeftData { /// Resulting hash table stores hashed join-key fields for each row as a key, and /// indices of corresponding rows in concatenated batch. 
/// -/// Hash join uses LIFO data structure as a hash table, and in order to retain -/// original build-side input order while obtaining data during probe phase, hash -/// table is updated by iterating batch sequence in reverse order -- it allows to +/// When using the standard `JoinHashMap`, hash join uses LIFO data structure as a hash table, +/// and in order to retain original build-side input order while obtaining data during probe phase, +/// hash table is updated by iterating batch sequence in reverse order -- it allows to /// keep rows with smaller indices "on the top" of hash table, and still maintain /// correct indexing for concatenated build-side data batch. /// @@ -343,17 +742,21 @@ pub struct HashJoinExec { /// Execution metrics metrics: ExecutionPlanMetricsSet, /// The projection indices of the columns in the output schema of join - pub projection: Option>, + pub projection: Option, /// Information of index and left / right placement of columns column_indices: Vec, /// The equality null-handling behavior of the join algorithm. pub null_equality: NullEquality, + /// Flag to indicate if this is a null-aware anti join + pub null_aware: bool, /// Cache holding plan properties like equivalences, output partitioning etc. - cache: PlanProperties, + cache: Arc, /// Dynamic filter for pushing down to the probe side /// Set when dynamic filter pushdown is detected in handle_child_pushdown_result. /// HashJoinExec also needs to keep a shared bounds accumulator for coordinating updates. dynamic_filter: Option, + /// Maximum number of rows to return + fetch: Option, } #[derive(Clone)] @@ -394,7 +797,7 @@ impl EmbeddedProjection for HashJoinExec { } impl HashJoinExec { - /// Tries to create a new [HashJoinExec]. + /// Tries to create a new [`HashJoinExec`]. /// /// # Error /// This function errors when it is not possible to join the left and right sides on keys `on`. 
@@ -408,55 +811,24 @@ impl HashJoinExec { projection: Option>, partition_mode: PartitionMode, null_equality: NullEquality, + null_aware: bool, ) -> Result { - let left_schema = left.schema(); - let right_schema = right.schema(); - if on.is_empty() { - return plan_err!("On constraints in HashJoinExec should be non-empty"); - } - - check_join_is_valid(&left_schema, &right_schema, &on)?; - - let (join_schema, column_indices) = - build_join_schema(&left_schema, &right_schema, join_type); - - let random_state = HASH_JOIN_SEED; - - let join_schema = Arc::new(join_schema); - - // check if the projection is valid - can_project(&join_schema, projection.as_ref())?; - - let cache = Self::compute_properties( - &left, - &right, - &join_schema, - *join_type, - &on, - partition_mode, - projection.as_ref(), - )?; - - // Initialize both dynamic filter and bounds accumulator to None - // They will be set later if dynamic filtering is enabled + HashJoinExecBuilder::new(left, right, on, *join_type) + .with_filter(filter) + .with_projection(projection) + .with_partition_mode(partition_mode) + .with_null_equality(null_equality) + .with_null_aware(null_aware) + .build() + } - Ok(HashJoinExec { - left, - right, - on, - filter, - join_type: *join_type, - join_schema, - left_fut: Default::default(), - random_state, - mode: partition_mode, - metrics: ExecutionPlanMetricsSet::new(), - projection, - column_indices, - null_equality, - cache, - dynamic_filter: None, - }) + /// Create a builder based on the existing [`HashJoinExec`]. + /// + /// Returned builder preserves all existing fields. If a field requiring properties + /// recomputation is modified, this will be done automatically during the node build. 
+ /// + pub fn builder(&self) -> HashJoinExecBuilder { + self.into() } fn create_dynamic_filter(on: &JoinOn) -> Arc { @@ -467,6 +839,28 @@ impl HashJoinExec { Arc::new(DynamicFilterPhysicalExpr::new(right_keys, lit(true))) } + fn allow_join_dynamic_filter_pushdown(&self, config: &ConfigOptions) -> bool { + if self.join_type != JoinType::Inner + || !config.optimizer.enable_join_dynamic_filter_pushdown + { + return false; + } + + // `preserve_file_partitions` can report Hash partitioning for Hive-style + // file groups, but those partitions are not actually hash-distributed. + // Partitioned dynamic filters rely on hash routing, so disable them in + // this mode to avoid incorrect results. Follow-up work: enable dynamic + // filtering for preserve_file_partitioned scans (issue #20195). + // https://github.com/apache/datafusion/issues/20195 + if config.optimizer.preserve_file_partitions > 0 + && self.mode == PartitionMode::Partitioned + { + return false; + } + + true + } + /// left (build) side which gets hashed pub fn left(&self) -> &Arc { &self.left @@ -513,10 +907,8 @@ impl HashJoinExec { /// /// This method is intended for testing only and should not be used in production code. #[doc(hidden)] - pub fn dynamic_filter_for_test(&self) -> Option> { - self.dynamic_filter - .as_ref() - .map(|df| Arc::clone(&df.filter)) + pub fn dynamic_filter_for_test(&self) -> Option<&Arc> { + self.dynamic_filter.as_ref().map(|df| &df.filter) } /// Calculate order preservation flags for this hash join. @@ -547,25 +939,12 @@ impl HashJoinExec { /// Return new instance of [HashJoinExec] with the given projection. 
pub fn with_projection(&self, projection: Option>) -> Result { + let projection = projection.map(Into::into); // check if the projection is valid - can_project(&self.schema(), projection.as_ref())?; - let projection = match projection { - Some(projection) => match &self.projection { - Some(p) => Some(projection.iter().map(|i| p[*i]).collect()), - None => Some(projection), - }, - None => None, - }; - Self::try_new( - Arc::clone(&self.left), - Arc::clone(&self.right), - self.on.clone(), - self.filter.clone(), - &self.join_type, - projection, - self.mode, - self.null_equality, - ) + can_project(&self.schema(), projection.as_deref())?; + let projection = + combine_projections(projection.as_ref(), self.projection.as_ref())?; + self.builder().with_projection_ref(projection).build() } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. @@ -576,7 +955,7 @@ impl HashJoinExec { join_type: JoinType, on: JoinOnRef, mode: PartitionMode, - projection: Option<&Vec>, + projection: Option<&[usize]>, ) -> Result { // Calculate equivalence properties: let mut eq_properties = join_equivalence_properties( @@ -628,7 +1007,7 @@ impl HashJoinExec { if let Some(projection) = projection { // construct a map from the input expressions to the output expression of the Projection let projection_mapping = ProjectionMapping::from_indices(projection, schema)?; - let out_schema = project_schema(schema, Some(projection))?; + let out_schema = project_schema(schema, Some(&projection))?; output_partitioning = output_partitioning.project(&projection_mapping, &eq_properties); eq_properties = eq_properties.project(&projection_mapping, out_schema); @@ -671,24 +1050,25 @@ impl HashJoinExec { ) -> Result> { let left = self.left(); let right = self.right(); - let new_join = HashJoinExec::try_new( - Arc::clone(right), - Arc::clone(left), - self.on() - .iter() - .map(|(l, r)| (Arc::clone(r), Arc::clone(l))) - .collect(), 
- self.filter().map(JoinFilter::swap), - &self.join_type().swap(), - swap_join_projection( + let new_join = self + .builder() + .with_type(self.join_type.swap()) + .with_new_children(vec![Arc::clone(right), Arc::clone(left)])? + .with_on( + self.on() + .iter() + .map(|(l, r)| (Arc::clone(r), Arc::clone(l))) + .collect(), + ) + .with_filter(self.filter().map(JoinFilter::swap)) + .with_projection(swap_join_projection( left.schema().fields().len(), right.schema().fields().len(), - self.projection.as_ref(), + self.projection.as_deref(), self.join_type(), - ), - partition_mode, - self.null_equality(), - )?; + )) + .with_partition_mode(partition_mode) + .build()?; // In case of anti / semi joins or if there is embedded projection in HashJoinExec, output column order is preserved, no need to add projection again if matches!( self.join_type(), @@ -734,11 +1114,14 @@ impl DisplayAs for HashJoinExec { "".to_string() }; let display_null_equality = - if matches!(self.null_equality(), NullEquality::NullEqualsNull) { + if self.null_equality() == NullEquality::NullEqualsNull { ", NullsEqual: true" } else { "" }; + let display_fetch = self + .fetch + .map_or_else(String::new, |f| format!(", fetch={f}")); let on = self .on .iter() @@ -747,13 +1130,14 @@ impl DisplayAs for HashJoinExec { .join(", "); write!( f, - "HashJoinExec: mode={:?}, join_type={:?}, on=[{}]{}{}{}", + "HashJoinExec: mode={:?}, join_type={:?}, on=[{}]{}{}{}{}", self.mode, self.join_type, on, display_filter, display_projections, display_null_equality, + display_fetch, ) } DisplayFormatType::TreeRender => { @@ -772,7 +1156,7 @@ impl DisplayAs for HashJoinExec { writeln!(f, "on={on}")?; - if matches!(self.null_equality(), NullEquality::NullEqualsNull) { + if self.null_equality() == NullEquality::NullEqualsNull { writeln!(f, "NullsEqual: true")?; } @@ -780,6 +1164,10 @@ impl DisplayAs for HashJoinExec { writeln!(f, "filter={filter}")?; } + if let Some(fetch) = self.fetch { + writeln!(f, "fetch={fetch}")?; + } + 
Ok(()) } } @@ -795,7 +1183,7 @@ impl ExecutionPlan for HashJoinExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -856,54 +1244,11 @@ impl ExecutionPlan for HashJoinExec { self: Arc, children: Vec>, ) -> Result> { - Ok(Arc::new(HashJoinExec { - left: Arc::clone(&children[0]), - right: Arc::clone(&children[1]), - on: self.on.clone(), - filter: self.filter.clone(), - join_type: self.join_type, - join_schema: Arc::clone(&self.join_schema), - left_fut: Arc::clone(&self.left_fut), - random_state: self.random_state.clone(), - mode: self.mode, - metrics: ExecutionPlanMetricsSet::new(), - projection: self.projection.clone(), - column_indices: self.column_indices.clone(), - null_equality: self.null_equality, - cache: Self::compute_properties( - &children[0], - &children[1], - &self.join_schema, - self.join_type, - &self.on, - self.mode, - self.projection.as_ref(), - )?, - // Keep the dynamic filter, bounds accumulator will be reset - dynamic_filter: self.dynamic_filter.clone(), - })) + self.builder().with_new_children(children)?.build_exec() } fn reset_state(self: Arc) -> Result> { - Ok(Arc::new(HashJoinExec { - left: Arc::clone(&self.left), - right: Arc::clone(&self.right), - on: self.on.clone(), - filter: self.filter.clone(), - join_type: self.join_type, - join_schema: Arc::clone(&self.join_schema), - // Reset the left_fut to allow re-execution - left_fut: Arc::new(OnceAsync::default()), - random_state: self.random_state.clone(), - mode: self.mode, - metrics: ExecutionPlanMetricsSet::new(), - projection: self.projection.clone(), - column_indices: self.column_indices.clone(), - null_equality: self.null_equality, - cache: self.cache.clone(), - // Reset dynamic filter and bounds accumulator to initial state - dynamic_filter: None, - })) + self.builder().reset_state().build_exec() } fn execute( @@ -937,11 +1282,8 @@ impl ExecutionPlan for HashJoinExec { // - A dynamic filter exists // - At least one consumer is holding 
a reference to it, this avoids expensive filter // computation when disabled or when no consumer will use it. - let enable_dynamic_filter_pushdown = context - .session_config() - .options() - .optimizer - .enable_join_dynamic_filter_pushdown + let enable_dynamic_filter_pushdown = self + .allow_join_dynamic_filter_pushdown(context.session_config().options()) && self .dynamic_filter .as_ref() @@ -949,6 +1291,10 @@ impl ExecutionPlan for HashJoinExec { .unwrap_or(false); let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics); + + let array_map_created_count = MetricBuilder::new(&self.metrics) + .counter(ARRAY_MAP_CREATED_COUNT_METRIC_NAME, partition); + let left_fut = match self.mode { PartitionMode::CollectLeft => self.left_fut.try_once(|| { let left_stream = self.left.execute(0, Arc::clone(&context))?; @@ -965,16 +1311,9 @@ impl ExecutionPlan for HashJoinExec { need_produce_result_in_final(self.join_type), self.right().output_partitioning().partition_count(), enable_dynamic_filter_pushdown, - context - .session_config() - .options() - .optimizer - .hash_join_inlist_pushdown_max_size, - context - .session_config() - .options() - .optimizer - .hash_join_inlist_pushdown_max_distinct_values, + Arc::clone(context.session_config().options()), + self.null_equality, + array_map_created_count, )) })?, PartitionMode::Partitioned => { @@ -993,16 +1332,9 @@ impl ExecutionPlan for HashJoinExec { need_produce_result_in_final(self.join_type), 1, enable_dynamic_filter_pushdown, - context - .session_config() - .options() - .optimizer - .hash_join_inlist_pushdown_max_size, - context - .session_config() - .options() - .optimizer - .hash_join_inlist_pushdown_max_distinct_values, + Arc::clone(context.session_config().options()), + self.null_equality, + array_map_created_count, )) } PartitionMode::Auto => { @@ -1047,7 +1379,7 @@ impl ExecutionPlan for HashJoinExec { let right_stream = self.right.execute(partition, context)?; // update column indices to reflect the 
projection - let column_indices_after_projection = match &self.projection { + let column_indices_after_projection = match self.projection.as_ref() { Some(projection) => projection .iter() .map(|i| self.column_indices[*i].clone()) @@ -1079,6 +1411,8 @@ impl ExecutionPlan for HashJoinExec { self.right.output_ordering().is_some(), build_accumulator, self.mode, + self.null_aware, + self.fetch, ))) } @@ -1086,10 +1420,6 @@ impl ExecutionPlan for HashJoinExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { if partition.is_some() { return Ok(Statistics::new_unknown(&self.schema())); @@ -1105,7 +1435,9 @@ impl ExecutionPlan for HashJoinExec { &self.join_schema, )?; // Project statistics if there is a projection - Ok(stats.project(self.projection.as_ref())) + let stats = stats.project(self.projection.as_ref()); + // Apply fetch limit to statistics + stats.with_fetch(self.fetch, 0, 1) } /// Tries to push `projection` down through `hash_join`. If possible, performs the @@ -1134,17 +1466,17 @@ impl ExecutionPlan for HashJoinExec { &schema, self.filter(), )? { - Ok(Some(Arc::new(HashJoinExec::try_new( - Arc::new(projected_left_child), - Arc::new(projected_right_child), - join_on, - join_filter, - self.join_type(), + self.builder() + .with_new_children(vec![ + Arc::new(projected_left_child), + Arc::new(projected_right_child), + ])? + .with_on(join_on) + .with_filter(join_filter) // Returned early if projection is not None - None, - *self.partition_mode(), - self.null_equality, - )?))) + .with_projection(None) + .build_exec() + .map(Some) } else { try_embed_projection(projection, self) } @@ -1156,30 +1488,111 @@ impl ExecutionPlan for HashJoinExec { parent_filters: Vec>, config: &ConfigOptions, ) -> Result { - // Other types of joins can support *some* filters, but restrictions are complex and error prone. - // For now we don't support them. 
- // See the logical optimizer rules for more details: datafusion/optimizer/src/push_down_filter.rs - // See https://github.com/apache/datafusion/issues/16973 for tracking. - if self.join_type != JoinType::Inner { - return Ok(FilterDescription::all_unsupported( - &parent_filters, - &self.children(), - )); - } + // This is the physical-plan equivalent of `push_down_all_join` in + // `datafusion/optimizer/src/push_down_filter.rs`. That function uses `lr_is_preserved` + // to decide which parent predicates can be pushed past a logical join to its children, + // then checks column references to route each predicate to the correct side. + // + // We apply the same two-level logic here: + // 1. `lr_is_preserved` gates whether a side is eligible at all. + // 2. For each filter, we check that all column references belong to the + // target child (using `column_indices` to map output column positions + // to join sides). This is critical for correctness: name-based matching + // alone (as done by `ChildFilterDescription::from_child`) can incorrectly + // push filters when different join sides have columns with the same name + // (e.g. nested mark joins both producing "mark" columns). 
+ let (left_preserved, right_preserved) = lr_is_preserved(self.join_type); + + // Build the set of allowed column indices for each side + let column_indices: Vec = match self.projection.as_ref() { + Some(projection) => projection + .iter() + .map(|i| self.column_indices[*i].clone()) + .collect(), + None => self.column_indices.clone(), + }; - // Get basic filter descriptions for both children - let left_child = crate::filter_pushdown::ChildFilterDescription::from_child( - &parent_filters, - self.left(), - )?; - let mut right_child = crate::filter_pushdown::ChildFilterDescription::from_child( - &parent_filters, - self.right(), - )?; + let (mut left_allowed, mut right_allowed) = (HashSet::new(), HashSet::new()); + column_indices + .iter() + .enumerate() + .for_each(|(output_idx, ci)| { + match ci.side { + JoinSide::Left => left_allowed.insert(output_idx), + JoinSide::Right => right_allowed.insert(output_idx), + // Mark columns - don't allow pushdown to either side + JoinSide::None => false, + }; + }); + + // For semi/anti joins, the non-preserved side's columns are not in the + // output, but filters on join key columns can still be pushed there. + // We find output columns that are join keys on the preserved side and + // add their output indices to the non-preserved side's allowed set. + // The name-based remap in FilterRemapper will then match them to the + // corresponding column in the non-preserved child's schema. 
+ match self.join_type { + JoinType::LeftSemi | JoinType::LeftAnti => { + let left_key_indices: HashSet = self + .on + .iter() + .filter_map(|(left_key, _)| { + left_key + .as_any() + .downcast_ref::() + .map(|c| c.index()) + }) + .collect(); + for (output_idx, ci) in column_indices.iter().enumerate() { + if ci.side == JoinSide::Left && left_key_indices.contains(&ci.index) { + right_allowed.insert(output_idx); + } + } + } + JoinType::RightSemi | JoinType::RightAnti => { + let right_key_indices: HashSet = self + .on + .iter() + .filter_map(|(_, right_key)| { + right_key + .as_any() + .downcast_ref::() + .map(|c| c.index()) + }) + .collect(); + for (output_idx, ci) in column_indices.iter().enumerate() { + if ci.side == JoinSide::Right && right_key_indices.contains(&ci.index) + { + left_allowed.insert(output_idx); + } + } + } + _ => {} + } + + let left_child = if left_preserved { + ChildFilterDescription::from_child_with_allowed_indices( + &parent_filters, + left_allowed, + self.left(), + )? + } else { + ChildFilterDescription::all_unsupported(&parent_filters) + }; + + let mut right_child = if right_preserved { + ChildFilterDescription::from_child_with_allowed_indices( + &parent_filters, + right_allowed, + self.right(), + )? + } else { + ChildFilterDescription::all_unsupported(&parent_filters) + }; // Add dynamic filters in Post phase if enabled - if matches!(phase, FilterPushdownPhase::Post) - && config.optimizer.enable_join_dynamic_filter_pushdown + if phase == FilterPushdownPhase::Post + && self.allow_join_dynamic_filter_pushdown(config) { // Add actual dynamic filter to right side (probe side) let dynamic_filter = Self::create_dynamic_filter(&self.on); @@ -1197,19 +1610,6 @@ impl ExecutionPlan for HashJoinExec { child_pushdown_result: ChildPushdownResult, _config: &ConfigOptions, ) -> Result>> { - // Note: this check shouldn't be necessary because we already marked all parent filters as unsupported for - // non-inner joins in `gather_filters_for_pushdown`. 
- // However it's a cheap check and serves to inform future devs touching this function that they need to be really - // careful pushing down filters through non-inner joins. - if self.join_type != JoinType::Inner { - // Other types of joins can support *some* filters, but restrictions are complex and error prone. - // For now we don't support them. - // See the logical optimizer rules for more details: datafusion/optimizer/src/push_down_filter.rs - return Ok(FilterPushdownPropagation::all_unsupported( - child_pushdown_result, - )); - } - let mut result = FilterPushdownPropagation::if_any(child_pushdown_result.clone()); assert_eq!(child_pushdown_result.self_filters.len(), 2); // Should always be 2, we have 2 children let right_child_self_filters = &child_pushdown_result.self_filters[1]; // We only push down filters to the right child @@ -1222,31 +1622,58 @@ impl ExecutionPlan for HashJoinExec { Arc::downcast::(predicate) { // We successfully pushed down our self filter - we need to make a new node with the dynamic filter - let new_node = Arc::new(HashJoinExec { - left: Arc::clone(&self.left), - right: Arc::clone(&self.right), - on: self.on.clone(), - filter: self.filter.clone(), - join_type: self.join_type, - join_schema: Arc::clone(&self.join_schema), - left_fut: Arc::clone(&self.left_fut), - random_state: self.random_state.clone(), - mode: self.mode, - metrics: ExecutionPlanMetricsSet::new(), - projection: self.projection.clone(), - column_indices: self.column_indices.clone(), - null_equality: self.null_equality, - cache: self.cache.clone(), - dynamic_filter: Some(HashJoinExecDynamicFilter { + let new_node = self + .builder() + .with_dynamic_filter(Some(HashJoinExecDynamicFilter { filter: dynamic_filter, build_accumulator: OnceLock::new(), - }), - }); - result = result.with_updated_node(new_node as Arc); + })) + .build_exec()?; + result = result.with_updated_node(new_node); } } Ok(result) } + + fn supports_limit_pushdown(&self) -> bool { + // Hash join execution 
plan does not support pushing limit down through to children + // because the children don't know about the join condition and can't + // determine how many rows to produce + false + } + + fn fetch(&self) -> Option { + self.fetch + } + + fn with_fetch(&self, limit: Option) -> Option> { + self.builder() + .with_fetch(limit) + .build() + .ok() + .map(|exec| Arc::new(exec) as _) + } +} + +/// Determines which sides of a join are "preserved" for filter pushdown. +/// +/// A preserved side means filters on that side's columns can be safely pushed +/// below the join. This mirrors the logic in the logical optimizer's +/// `lr_is_preserved` in `datafusion/optimizer/src/push_down_filter.rs`. +fn lr_is_preserved(join_type: JoinType) -> (bool, bool) { + match join_type { + JoinType::Inner => (true, true), + JoinType::Left => (true, false), + JoinType::Right => (false, true), + JoinType::Full => (false, false), + // Filters in semi/anti joins are either on the preserved side, or on join keys, + // as all output columns come from the preserved side. Join key filters can be + // safely pushed down into the other side. + JoinType::LeftSemi | JoinType::LeftAnti => (true, true), + JoinType::RightSemi | JoinType::RightAnti => (true, true), + JoinType::LeftMark => (true, false), + JoinType::RightMark => (false, true), + } } /// Accumulator for collecting min/max bounds from build-side data during hash join. @@ -1364,6 +1791,19 @@ impl BuildSideState { } } +fn should_collect_min_max_for_perfect_hash( + on_left: &[PhysicalExprRef], + schema: &SchemaRef, +) -> Result { + if on_left.len() != 1 { + return Ok(false); + } + + let expr = &on_left[0]; + let data_type = expr.data_type(schema)?; + Ok(ArrayMap::is_supported_type(&data_type)) +} + /// Collects all batches from the left (build) side stream and creates a hash map for joining. 
/// /// This function is responsible for: @@ -1402,20 +1842,21 @@ async fn collect_left_input( with_visited_indices_bitmap: bool, probe_threads_count: usize, should_compute_dynamic_filters: bool, - max_inlist_size: usize, - max_inlist_distinct_values: usize, + config: Arc, + null_equality: NullEquality, + array_map_created_count: Count, ) -> Result { let schema = left_stream.schema(); - // This operation performs 2 steps at once: - // 1. creates a [JoinHashMap] of all batches from the stream - // 2. stores the batches in a vector. + let should_collect_min_max_for_phj = + should_collect_min_max_for_perfect_hash(&on_left, &schema)?; + let initial = BuildSideState::try_new( metrics, reservation, on_left.clone(), &schema, - should_compute_dynamic_filters, + should_compute_dynamic_filters || should_collect_min_max_for_phj, )?; let state = left_stream @@ -1452,50 +1893,85 @@ async fn collect_left_input( bounds_accumulators, } = state; - // Estimation of memory size, required for hashtable, prior to allocation. 
- // Final result can be verified using `RawTable.allocation_info()` - let fixed_size_u32 = size_of::(); - let fixed_size_u64 = size_of::(); - - // Use `u32` indices for the JoinHashMap when num_rows ≤ u32::MAX, otherwise use the - // `u64` indice variant - // Arc is used instead of Box to allow sharing with SharedBuildAccumulator for hash map pushdown - let mut hashmap: Box = if num_rows > u32::MAX as usize { - let estimated_hashtable_size = - estimate_memory_size::<(u64, u64)>(num_rows, fixed_size_u64)?; - reservation.try_grow(estimated_hashtable_size)?; - metrics.build_mem_used.add(estimated_hashtable_size); - Box::new(JoinHashMapU64::with_capacity(num_rows)) - } else { - let estimated_hashtable_size = - estimate_memory_size::<(u32, u64)>(num_rows, fixed_size_u32)?; - reservation.try_grow(estimated_hashtable_size)?; - metrics.build_mem_used.add(estimated_hashtable_size); - Box::new(JoinHashMapU32::with_capacity(num_rows)) + // Compute bounds + let mut bounds = match bounds_accumulators { + Some(accumulators) if num_rows > 0 => { + let bounds = accumulators + .into_iter() + .map(CollectLeftAccumulator::evaluate) + .collect::>>()?; + Some(PartitionBounds::new(bounds)) + } + _ => None, }; - let mut hashes_buffer = Vec::new(); - let mut offset = 0; - - // Updating hashmap starting from the last batch - let batches_iter = batches.iter().rev(); - for batch in batches_iter.clone() { - hashes_buffer.clear(); - hashes_buffer.resize(batch.num_rows(), 0); - update_hash( + let (join_hash_map, batch, left_values) = + if let Some((array_map, batch, left_value)) = try_create_array_map( + &bounds, + &schema, + &batches, &on_left, - batch, - &mut *hashmap, - offset, - &random_state, - &mut hashes_buffer, - 0, - true, - )?; - offset += batch.num_rows(); - } - // Merge all batches into a single batch, so we can directly index into the arrays - let batch = concat_batches(&schema, batches_iter)?; + &mut reservation, + config.execution.perfect_hash_join_small_build_threshold, + 
config.execution.perfect_hash_join_min_key_density, + null_equality, + )? { + array_map_created_count.add(1); + metrics.build_mem_used.add(array_map.size()); + + (Map::ArrayMap(array_map), batch, left_value) + } else { + // Estimation of memory size, required for hashtable, prior to allocation. + // Final result can be verified using `RawTable.allocation_info()` + let fixed_size_u32 = size_of::(); + let fixed_size_u64 = size_of::(); + + // Use `u32` indices for the JoinHashMap when num_rows ≤ u32::MAX, otherwise use the + // `u64` indice variant + // Arc is used instead of Box to allow sharing with SharedBuildAccumulator for hash map pushdown + let mut hashmap: Box = if num_rows > u32::MAX as usize { + let estimated_hashtable_size = + estimate_memory_size::<(u64, u64)>(num_rows, fixed_size_u64)?; + reservation.try_grow(estimated_hashtable_size)?; + metrics.build_mem_used.add(estimated_hashtable_size); + Box::new(JoinHashMapU64::with_capacity(num_rows)) + } else { + let estimated_hashtable_size = + estimate_memory_size::<(u32, u64)>(num_rows, fixed_size_u32)?; + reservation.try_grow(estimated_hashtable_size)?; + metrics.build_mem_used.add(estimated_hashtable_size); + Box::new(JoinHashMapU32::with_capacity(num_rows)) + }; + + let mut hashes_buffer = Vec::new(); + let mut offset = 0; + + let batches_iter = batches.iter().rev(); + + // Updating hashmap starting from the last batch + for batch in batches_iter.clone() { + hashes_buffer.clear(); + hashes_buffer.resize(batch.num_rows(), 0); + update_hash( + &on_left, + batch, + &mut *hashmap, + offset, + &random_state, + &mut hashes_buffer, + 0, + true, + )?; + offset += batch.num_rows(); + } + + // Merge all batches into a single batch, so we can directly index into the arrays + let batch = concat_batches(&schema, batches_iter.clone())?; + + let left_values = evaluate_expressions_to_arrays(&on_left, &batch)?; + + (Map::HashMap(hashmap), batch, left_values) + }; // Reserve additional memory for visited indices bitmap and 
create shared builder let visited_indices_bitmap = if with_visited_indices_bitmap { @@ -1510,22 +1986,7 @@ async fn collect_left_input( BooleanBufferBuilder::new(0) }; - let left_values = evaluate_expressions_to_arrays(&on_left, &batch)?; - - // Compute bounds for dynamic filter if enabled - let bounds = match bounds_accumulators { - Some(accumulators) if num_rows > 0 => { - let bounds = accumulators - .into_iter() - .map(CollectLeftAccumulator::evaluate) - .collect::>>()?; - Some(PartitionBounds::new(bounds)) - } - _ => None, - }; - - // Convert Box to Arc for sharing with SharedBuildAccumulator - let hash_map: Arc = hashmap.into(); + let map = Arc::new(join_hash_map); let membership = if num_rows == 0 { PushdownStrategy::Empty @@ -1539,19 +2000,26 @@ async fn collect_left_input( .sum::(); if left_values.is_empty() || left_values[0].is_empty() - || estimated_size > max_inlist_size - || hash_map.len() > max_inlist_distinct_values + || estimated_size > config.optimizer.hash_join_inlist_pushdown_max_size + || map.num_of_distinct_key() + > config + .optimizer + .hash_join_inlist_pushdown_max_distinct_values { - PushdownStrategy::HashTable(Arc::clone(&hash_map)) + PushdownStrategy::Map(Arc::clone(&map)) } else if let Some(in_list_values) = build_struct_inlist_values(&left_values)? 
{ PushdownStrategy::InList(in_list_values) } else { - PushdownStrategy::HashTable(Arc::clone(&hash_map)) + PushdownStrategy::Map(Arc::clone(&map)) } }; + if should_collect_min_max_for_phj && !should_compute_dynamic_filters { + bounds = None; + } + let data = JoinLeftData { - hash_map, + map, batch, values: left_values, visited_indices_bitmap: Mutex::new(visited_indices_bitmap), @@ -1559,6 +2027,8 @@ async fn collect_left_input( _reservation: reservation, bounds, membership, + probe_side_non_empty: AtomicBool::new(false), + probe_side_has_null: AtomicBool::new(false), }; Ok(data) @@ -1567,6 +2037,43 @@ async fn collect_left_input( #[cfg(test)] mod tests { use super::*; + + fn assert_phj_used(metrics: &MetricsSet, use_phj: bool) { + if use_phj { + assert!( + metrics + .sum_by_name(ARRAY_MAP_CREATED_COUNT_METRIC_NAME) + .expect("should have array_map_created_count metrics") + .as_usize() + >= 1 + ); + } else { + assert_eq!( + metrics + .sum_by_name(ARRAY_MAP_CREATED_COUNT_METRIC_NAME) + .map(|v| v.as_usize()) + .unwrap_or(0), + 0 + ) + } + } + + fn build_schema_and_on() -> Result<(SchemaRef, SchemaRef, JoinOn)> { + let left_schema = Arc::new(Schema::new(vec![ + Field::new("a1", DataType::Int32, true), + Field::new("b1", DataType::Int32, true), + ])); + let right_schema = Arc::new(Schema::new(vec![ + Field::new("a2", DataType::Int32, true), + Field::new("b1", DataType::Int32, true), + ])); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left_schema)?) as _, + Arc::new(Column::new_with_schema("b1", &right_schema)?) 
as _, + )]; + Ok((left_schema, right_schema, on)) + } + use crate::coalesce_partitions::CoalescePartitionsExec; use crate::joins::hash_join::stream::lookup_join_hashmap; use crate::test::{TestMemoryExec, assert_join_metrics}; @@ -1575,7 +2082,9 @@ mod tests { test::exec::MockExec, }; - use arrow::array::{Date32Array, Int32Array, StructArray, UInt32Array, UInt64Array}; + use arrow::array::{ + Date32Array, Int32Array, Int64Array, StructArray, UInt32Array, UInt64Array, + }; use arrow::buffer::NullBuffer; use arrow::datatypes::{DataType, Field}; use arrow_schema::Schema; @@ -1601,10 +2110,37 @@ mod tests { #[template] #[rstest] - fn batch_sizes(#[values(8192, 10, 5, 2, 1)] batch_size: usize) {} + fn hash_join_exec_configs( + #[values(8192, 10, 5, 2, 1)] batch_size: usize, + #[values(true, false)] use_perfect_hash_join_as_possible: bool, + ) { + } - fn prepare_task_ctx(batch_size: usize) -> Arc { - let session_config = SessionConfig::default().with_batch_size(batch_size); + fn prepare_task_ctx( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Arc { + let mut session_config = SessionConfig::default().with_batch_size(batch_size); + + if use_perfect_hash_join_as_possible { + session_config + .options_mut() + .execution + .perfect_hash_join_small_build_threshold = 819200; + session_config + .options_mut() + .execution + .perfect_hash_join_min_key_density = 0.0; + } else { + session_config + .options_mut() + .execution + .perfect_hash_join_small_build_threshold = 0; + session_config + .options_mut() + .execution + .perfect_hash_join_min_key_density = f64::INFINITY; + } Arc::new(TaskContext::default().with_session_config(session_config)) } @@ -1618,6 +2154,26 @@ mod tests { TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() } + /// Build a table with two columns supporting nullable values + fn build_table_two_cols( + a: (&str, &Vec>), + b: (&str, &Vec>), + ) -> Arc { + let schema = Arc::new(Schema::new(vec![ + Field::new(a.0, 
DataType::Int32, true), + Field::new(b.0, DataType::Int32, true), + ])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(a.1.clone())), + Arc::new(Int32Array::from(b.1.clone())), + ], + ) + .unwrap(); + TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() + } + fn join( left: Arc, right: Arc, @@ -1634,6 +2190,7 @@ mod tests { None, PartitionMode::CollectLeft, null_equality, + false, ) } @@ -1654,6 +2211,7 @@ mod tests { None, PartitionMode::CollectLeft, null_equality, + false, ) } @@ -1752,6 +2310,7 @@ mod tests { None, partition_mode, null_equality, + false, )?; let columns = columns(&join.schema()); @@ -1772,10 +2331,13 @@ mod tests { Ok((columns, batches, metrics)) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_inner_one(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_inner_one( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 5]), // this has a repetition @@ -1818,14 +2380,18 @@ mod tests { } assert_join_metrics!(metrics, 3); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn partitioned_join_inner_one(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn partitioned_join_inner_one( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 5]), // this has a repetition @@ -1866,6 +2432,7 @@ mod tests { } assert_join_metrics!(metrics, 3); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); Ok(()) } @@ 
-1967,10 +2534,13 @@ mod tests { Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_inner_two(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_inner_two( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a1", &vec![1, 2, 2]), ("b2", &vec![1, 2, 2]), @@ -2044,10 +2614,13 @@ mod tests { } /// Test where the left has 2 parts, the right with 1 part => 1 part - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_inner_one_two_parts_left(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_inner_one_two_parts_left( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let batch1 = build_table_i32( ("a1", &vec![1, 2]), ("b2", &vec![1, 2]), @@ -2189,10 +2762,13 @@ mod tests { } /// Test where the left has 1 part, the right has 2 parts => 2 parts - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_inner_one_two_parts_right(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_inner_one_two_parts_right( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 5]), // this has a repetition @@ -2293,6 +2869,9 @@ mod tests { "); } + let metrics = join.metrics().unwrap(); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + Ok(()) } @@ -2306,10 +2885,13 @@ mod tests { TestMemoryExec::try_new_exec(&[vec![batch.clone(), batch]], schema, None).unwrap() } - #[apply(batch_sizes)] + 
#[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_left_multi_batch(batch_size: usize) { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_left_multi_batch( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -2326,9 +2908,9 @@ mod tests { )]; let join = join( - left, - right, - on, + Arc::clone(&left), + Arc::clone(&right), + on.clone(), &JoinType::Left, NullEquality::NullEqualsNothing, ) @@ -2337,8 +2919,15 @@ mod tests { let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); - let stream = join.execute(0, task_ctx).unwrap(); - let batches = common::collect(stream).await.unwrap(); + let (_, batches, metrics) = join_collect( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + &JoinType::Left, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; allow_duplicates! { assert_snapshot!(batches_to_sort_string(&batches), @r" @@ -2353,12 +2942,18 @@ mod tests { +----+----+----+----+----+----+ "); } + + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + return Ok(()); } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_full_multi_batch(batch_size: usize) { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_full_multi_batch( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -2389,6 +2984,7 @@ mod tests { let stream = join.execute(0, task_ctx).unwrap(); let batches = common::collect(stream).await.unwrap(); + let metrics = join.metrics().unwrap(); allow_duplicates! 
{ assert_snapshot!(batches_to_sort_string(&batches), @r" @@ -2405,12 +3001,17 @@ mod tests { +----+----+----+----+----+----+ "); } + + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_left_empty_right(batch_size: usize) { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_left_empty_right( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), @@ -2437,6 +3038,7 @@ mod tests { let stream = join.execute(0, task_ctx).unwrap(); let batches = common::collect(stream).await.unwrap(); + let metrics = join.metrics().unwrap(); allow_duplicates! { assert_snapshot!(batches_to_sort_string(&batches), @r" @@ -2449,12 +3051,17 @@ mod tests { +----+----+----+----+----+----+ "); } + + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_full_empty_right(batch_size: usize) { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_full_empty_right( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), @@ -2481,6 +3088,7 @@ mod tests { let stream = join.execute(0, task_ctx).unwrap(); let batches = common::collect(stream).await.unwrap(); + let metrics = join.metrics().unwrap(); allow_duplicates! 
{ assert_snapshot!(batches_to_sort_string(&batches), @r" @@ -2493,12 +3101,17 @@ mod tests { +----+----+----+----+----+----+ "); } + + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_left_one(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_left_one( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -2539,14 +3152,18 @@ mod tests { } assert_join_metrics!(metrics, 3); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn partitioned_join_left_one(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn partitioned_join_left_one( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -2587,6 +3204,7 @@ mod tests { } assert_join_metrics!(metrics, 3); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); Ok(()) } @@ -2611,10 +3229,13 @@ mod tests { ) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_left_semi(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_left_semi( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); // left_table left semi join 
right_table on left_table.b1 = right_table.b2 @@ -2650,13 +3271,19 @@ mod tests { "); } + let metrics = join.metrics().unwrap(); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_left_semi_with_filter(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_left_semi_with_filter( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); @@ -2712,6 +3339,9 @@ mod tests { "); } + let metrics = join.metrics().unwrap(); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + // left_table left semi join right_table on left_table.b1 = right_table.b2 and right_table.a2 > 10 let filter_expression = Arc::new(BinaryExpr::new( Arc::new(Column::new("x", 0)), @@ -2749,13 +3379,19 @@ mod tests { "); } + let metrics = join.metrics().unwrap(); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_right_semi(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_right_semi( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); @@ -2792,13 +3428,19 @@ mod tests { "); } + let metrics = join.metrics().unwrap(); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_right_semi_with_filter(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn 
join_right_semi_with_filter( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); @@ -2855,6 +3497,9 @@ mod tests { "); } + let metrics = join.metrics().unwrap(); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + // left_table right semi join right_table on left_table.b1 = right_table.b2 on left_table.a1!=9 let filter_expression = Arc::new(BinaryExpr::new( Arc::new(Column::new("x", 0)), @@ -2891,13 +3536,19 @@ mod tests { "); } + let metrics = join.metrics().unwrap(); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_left_anti(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_left_anti( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); // left_table left anti join right_table on left_table.b1 = right_table.b2 @@ -2932,13 +3583,20 @@ mod tests { +----+----+----+ "); } + + let metrics = join.metrics().unwrap(); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_left_anti_with_filter(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_left_anti_with_filter( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); // left_table left anti join right_table on 
left_table.b1 = right_table.b2 and right_table.a2!=8 @@ -2995,6 +3653,9 @@ mod tests { "); } + let metrics = join.metrics().unwrap(); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + // left_table left anti join right_table on left_table.b1 = right_table.b2 and right_table.a2 != 13 let filter_expression = Arc::new(BinaryExpr::new( Arc::new(Column::new("x", 0)), @@ -3038,13 +3699,19 @@ mod tests { "); } + let metrics = join.metrics().unwrap(); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_right_anti(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_right_anti( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); let on = vec![( @@ -3078,13 +3745,20 @@ mod tests { +----+----+-----+ "); } + + let metrics = join.metrics().unwrap(); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_right_anti_with_filter(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_right_anti_with_filter( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); // left_table right anti join right_table on left_table.b1 = right_table.b2 and left_table.a1!=13 @@ -3142,6 +3816,9 @@ mod tests { "); } + let metrics = join.metrics().unwrap(); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + // left_table right anti join right_table on left_table.b1 = right_table.b2 
and right_table.b2!=8 let column_indices = vec![ColumnIndex { index: 1, @@ -3188,13 +3865,19 @@ mod tests { "); } + let metrics = join.metrics().unwrap(); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_right_one(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_right_one( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), @@ -3235,14 +3918,18 @@ mod tests { } assert_join_metrics!(metrics, 3); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn partitioned_join_right_one(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn partitioned_join_right_one( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), @@ -3283,14 +3970,18 @@ mod tests { } assert_join_metrics!(metrics, 3); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_full_one(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_full_one( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -3333,13 +4024,19 @@ mod tests { "); } + let metrics = join.metrics().unwrap(); 
+ assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_left_mark(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_left_mark( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -3380,14 +4077,18 @@ mod tests { } assert_join_metrics!(metrics, 3); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn partitioned_join_left_mark(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn partitioned_join_left_mark( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -3428,14 +4129,18 @@ mod tests { } assert_join_metrics!(metrics, 3); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_right_mark(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_right_mark( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -3475,14 +4180,18 @@ mod tests { assert_batches_sorted_eq!(expected, &batches); assert_join_metrics!(metrics, 3); + assert_phj_used(&metrics, 
use_perfect_hash_join_as_possible); Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn partitioned_join_right_mark(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn partitioned_join_right_mark( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -3523,6 +4232,7 @@ mod tests { assert_batches_sorted_eq!(expected, &batches); assert_join_metrics!(metrics, 4); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); Ok(()) } @@ -3729,10 +4439,13 @@ mod tests { ) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_inner_with_filter(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_inner_with_filter( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a", &vec![0, 1, 2, 2]), ("b", &vec![4, 5, 7, 8]), @@ -3775,13 +4488,19 @@ mod tests { "); } + let metrics = join.metrics().unwrap(); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_left_with_filter(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_left_with_filter( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a", &vec![0, 1, 2, 2]), ("b", &vec![4, 5, 7, 8]), @@ -3827,13 +4546,19 @@ mod tests { "); } + let metrics = join.metrics().unwrap(); + assert_phj_used(&metrics, 
use_perfect_hash_join_as_possible); + Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_right_with_filter(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_right_with_filter( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a", &vec![0, 1, 2, 2]), ("b", &vec![4, 5, 7, 8]), @@ -3878,13 +4603,19 @@ mod tests { "); } + let metrics = join.metrics().unwrap(); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_full_with_filter(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_full_with_filter( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a", &vec![0, 1, 2, 2]), ("b", &vec![4, 5, 7, 8]), @@ -3931,6 +4662,9 @@ mod tests { ]; assert_batches_sorted_eq!(expected, &batches); + let metrics = join.metrics().unwrap(); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + // THIS MIGRATION HALTED DUE TO ISSUE #15312 //allow_duplicates! 
{ // assert_snapshot!(batches_to_sort_string(&batches), @r#" @@ -4280,7 +5014,7 @@ mod tests { // validation of partial join results output for different batch_size setting for join_type in join_types { for batch_size in (1..21).rev() { - let task_ctx = prepare_task_ctx(batch_size); + let task_ctx = prepare_task_ctx(batch_size, true); let join = join( Arc::clone(&left), @@ -4460,6 +5194,7 @@ mod tests { None, PartitionMode::Partitioned, NullEquality::NullEqualsNothing, + false, )?; let stream = join.execute(1, task_ctx)?; @@ -4635,11 +5370,6 @@ mod tests { let dynamic_filter = HashJoinExec::create_dynamic_filter(&on); let dynamic_filter_clone = Arc::clone(&dynamic_filter); - // Simulate a consumer by creating a transformed copy (what happens during filter pushdown) - let _consumer = Arc::clone(&dynamic_filter) - .with_new_children(vec![]) - .unwrap(); - // Create HashJoinExec with the dynamic filter let mut join = HashJoinExec::try_new( left, @@ -4650,6 +5380,7 @@ mod tests { None, PartitionMode::CollectLeft, NullEquality::NullEqualsNothing, + false, )?; join.dynamic_filter = Some(HashJoinExecDynamicFilter { filter: dynamic_filter, @@ -4688,11 +5419,6 @@ mod tests { let dynamic_filter = HashJoinExec::create_dynamic_filter(&on); let dynamic_filter_clone = Arc::clone(&dynamic_filter); - // Simulate a consumer by creating a transformed copy (what happens during filter pushdown) - let _consumer = Arc::clone(&dynamic_filter) - .with_new_children(vec![]) - .unwrap(); - // Create HashJoinExec with the dynamic filter let mut join = HashJoinExec::try_new( left, @@ -4703,6 +5429,7 @@ mod tests { None, PartitionMode::CollectLeft, NullEquality::NullEqualsNothing, + false, )?; join.dynamic_filter = Some(HashJoinExecDynamicFilter { filter: dynamic_filter, @@ -4719,4 +5446,496 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_perfect_hash_join_with_negative_numbers() -> Result<()> { + let task_ctx = prepare_task_ctx(8192, true); + let (left_schema, right_schema, on) = 
build_schema_and_on()?; + + let left_batch = RecordBatch::try_new( + Arc::clone(&left_schema), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef, + Arc::new(Int32Array::from(vec![-1, 0, 1])) as ArrayRef, + ], + )?; + let left = TestMemoryExec::try_new_exec(&[vec![left_batch]], left_schema, None)?; + + let right_batch = RecordBatch::try_new( + Arc::clone(&right_schema), + vec![ + Arc::new(Int32Array::from(vec![10, 20, 30, 40])) as ArrayRef, + Arc::new(Int32Array::from(vec![1, -1, 0, 2])) as ArrayRef, + ], + )?; + let right = + TestMemoryExec::try_new_exec(&[vec![right_batch]], right_schema, None)?; + + let (columns, batches, metrics) = join_collect( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; + + assert_eq!(columns, vec!["a1", "b1", "a2", "b1"]); + + assert_batches_sorted_eq!( + [ + "+----+----+----+----+", + "| a1 | b1 | a2 | b1 |", + "+----+----+----+----+", + "| 1 | -1 | 20 | -1 |", + "| 2 | 0 | 30 | 0 |", + "| 3 | 1 | 10 | 1 |", + "+----+----+----+----+", + ], + &batches + ); + + assert_phj_used(&metrics, true); + + Ok(()) + } + + #[tokio::test] + async fn test_perfect_hash_join_overflow_full_int64_range() -> Result<()> { + let task_ctx = prepare_task_ctx(8192, true); + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int64Array::from(vec![i64::MIN, i64::MAX]))], + )?; + let left = TestMemoryExec::try_new_exec( + &[vec![batch.clone()]], + Arc::clone(&schema), + None, + )?; + let right = TestMemoryExec::try_new_exec(&[vec![batch]], schema, None)?; + let on: JoinOn = vec![( + Arc::new(Column::new_with_schema("a", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a", &right.schema())?) 
as _, + )]; + let (_columns, batches, _metrics) = join_collect( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, 2); + Ok(()) + } + + #[apply(hash_join_exec_configs)] + #[tokio::test] + async fn test_phj_null_equals_null_build_no_nulls_probe_has_nulls( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); + let (left_schema, right_schema, on) = build_schema_and_on()?; + + let left_batch = RecordBatch::try_new( + Arc::clone(&left_schema), + vec![ + Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef, + Arc::new(Int32Array::from(vec![10, 20])) as ArrayRef, + ], + )?; + let left = TestMemoryExec::try_new_exec(&[vec![left_batch]], left_schema, None)?; + + let right_batch = RecordBatch::try_new( + Arc::clone(&right_schema), + vec![ + Arc::new(Int32Array::from(vec![3, 4])) as ArrayRef, + Arc::new(Int32Array::from(vec![Some(10), None])) as ArrayRef, + ], + )?; + let right = + TestMemoryExec::try_new_exec(&[vec![right_batch]], right_schema, None)?; + + let (columns, batches, metrics) = join_collect( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNull, + task_ctx, + ) + .await?; + + assert_eq!(columns, vec!["a1", "b1", "a2", "b1"]); + assert_batches_sorted_eq!( + [ + "+----+----+----+----+", + "| a1 | b1 | a2 | b1 |", + "+----+----+----+----+", + "| 1 | 10 | 3 | 10 |", + "+----+----+----+----+", + ], + &batches + ); + + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + + Ok(()) + } + + #[apply(hash_join_exec_configs)] + #[tokio::test] + async fn test_phj_null_equals_nothing_build_probe_all_have_nulls( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); + let 
(left_schema, right_schema, on) = build_schema_and_on()?; + + let left_batch = RecordBatch::try_new( + Arc::clone(&left_schema), + vec![ + Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef, + Arc::new(Int32Array::from(vec![Some(10), None])) as ArrayRef, + ], + )?; + let left = TestMemoryExec::try_new_exec(&[vec![left_batch]], left_schema, None)?; + + let right_batch = RecordBatch::try_new( + Arc::clone(&right_schema), + vec![ + Arc::new(Int32Array::from(vec![Some(3), Some(4)])) as ArrayRef, + Arc::new(Int32Array::from(vec![Some(10), None])) as ArrayRef, + ], + )?; + let right = + TestMemoryExec::try_new_exec(&[vec![right_batch]], right_schema, None)?; + + let (columns, batches, metrics) = join_collect( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; + + assert_eq!(columns, vec!["a1", "b1", "a2", "b1"]); + assert_batches_sorted_eq!( + [ + "+----+----+----+----+", + "| a1 | b1 | a2 | b1 |", + "+----+----+----+----+", + "| 1 | 10 | 3 | 10 |", + "+----+----+----+----+", + ], + &batches + ); + + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + + Ok(()) + } + + #[tokio::test] + async fn test_phj_null_equals_null_build_have_nulls() -> Result<()> { + let task_ctx = prepare_task_ctx(8192, true); + let (left_schema, right_schema, on) = build_schema_and_on()?; + + let left_batch = RecordBatch::try_new( + Arc::clone(&left_schema), + vec![ + Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)])) as ArrayRef, + Arc::new(Int32Array::from(vec![Some(10), Some(20), None])) as ArrayRef, + ], + )?; + let left = TestMemoryExec::try_new_exec(&[vec![left_batch]], left_schema, None)?; + + let right_batch = RecordBatch::try_new( + Arc::clone(&right_schema), + vec![ + Arc::new(Int32Array::from(vec![Some(3), Some(4)])) as ArrayRef, + Arc::new(Int32Array::from(vec![Some(10), Some(30)])) as ArrayRef, + ], + )?; + let right = + TestMemoryExec::try_new_exec(&[vec![right_batch]], right_schema, None)?; + + 
let (columns, batches, metrics) = join_collect( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNull, + task_ctx, + ) + .await?; + + assert_eq!(columns, vec!["a1", "b1", "a2", "b1"]); + assert_batches_sorted_eq!( + [ + "+----+----+----+----+", + "| a1 | b1 | a2 | b1 |", + "+----+----+----+----+", + "| 1 | 10 | 3 | 10 |", + "+----+----+----+----+", + ], + &batches + ); + + assert_phj_used(&metrics, false); + + Ok(()) + } + + /// Test null-aware anti join when probe side (right) contains NULL + /// Expected: no rows should be output (NULL in subquery means all results are unknown) + #[apply(hash_join_exec_configs)] + #[tokio::test] + async fn test_null_aware_anti_join_probe_null(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, false); + + // Build left table (rows to potentially output) + let left = build_table_two_cols( + ("c1", &vec![Some(1), Some(2), Some(3), Some(4)]), + ("dummy", &vec![Some(10), Some(20), Some(30), Some(40)]), + ); + + // Build right table (subquery with NULL) + let right = build_table_two_cols( + ("c2", &vec![Some(1), Some(2), Some(3), None]), + ("dummy", &vec![Some(100), Some(200), Some(300), Some(400)]), + ); + + let on = vec![( + Arc::new(Column::new_with_schema("c1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("c2", &right.schema())?) as _, + )]; + + // Create null-aware anti join + let join = HashJoinExec::try_new( + left, + right, + on, + None, + &JoinType::LeftAnti, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + true, // null_aware = true + )?; + + let stream = join.execute(0, task_ctx)?; + let batches = common::collect(stream).await?; + + // Expected: empty result (probe side has NULL, so no rows should be output) + allow_duplicates! 
{ + assert_snapshot!(batches_to_sort_string(&batches), @r" + ++ + ++ + "); + } + Ok(()) + } + + /// Test null-aware anti join when build side (left) contains NULL keys + /// Expected: rows with NULL keys should not be output + #[apply(hash_join_exec_configs)] + #[tokio::test] + async fn test_null_aware_anti_join_build_null(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, false); + + // Build left table with NULL key (this row should not be output) + let left = build_table_two_cols( + ("c1", &vec![Some(1), Some(4), None]), + ("dummy", &vec![Some(10), Some(40), Some(0)]), + ); + + // Build right table (no NULL, so probe-side check passes) + let right = build_table_two_cols( + ("c2", &vec![Some(1), Some(2), Some(3)]), + ("dummy", &vec![Some(100), Some(200), Some(300)]), + ); + + let on = vec![( + Arc::new(Column::new_with_schema("c1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("c2", &right.schema())?) as _, + )]; + + // Create null-aware anti join + let join = HashJoinExec::try_new( + left, + right, + on, + None, + &JoinType::LeftAnti, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + true, // null_aware = true + )?; + + let stream = join.execute(0, task_ctx)?; + let batches = common::collect(stream).await?; + + // Expected: only c1=4 (not c1=1 which matches, not c1=NULL) + allow_duplicates! 
{ + assert_snapshot!(batches_to_sort_string(&batches), @r" + +----+-------+ + | c1 | dummy | + +----+-------+ + | 4 | 40 | + +----+-------+ + "); + } + Ok(()) + } + + /// Test null-aware anti join with no NULLs (should work like regular anti join) + #[apply(hash_join_exec_configs)] + #[tokio::test] + async fn test_null_aware_anti_join_no_nulls(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, false); + + // Build left table (no NULLs) + let left = build_table_two_cols( + ("c1", &vec![Some(1), Some(2), Some(4), Some(5)]), + ("dummy", &vec![Some(10), Some(20), Some(40), Some(50)]), + ); + + // Build right table (no NULLs) + let right = build_table_two_cols( + ("c2", &vec![Some(1), Some(2), Some(3)]), + ("dummy", &vec![Some(100), Some(200), Some(300)]), + ); + + let on = vec![( + Arc::new(Column::new_with_schema("c1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("c2", &right.schema())?) as _, + )]; + + // Create null-aware anti join + let join = HashJoinExec::try_new( + left, + right, + on, + None, + &JoinType::LeftAnti, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + true, // null_aware = true + )?; + + let stream = join.execute(0, task_ctx)?; + let batches = common::collect(stream).await?; + + // Expected: c1=4 and c1=5 (they don't match anything in right) + allow_duplicates! 
{ + assert_snapshot!(batches_to_sort_string(&batches), @r" + +----+-------+ + | c1 | dummy | + +----+-------+ + | 4 | 40 | + | 5 | 50 | + +----+-------+ + "); + } + Ok(()) + } + + /// Test that null_aware validation rejects non-LeftAnti join types + #[tokio::test] + async fn test_null_aware_validation_wrong_join_type() { + let left = + build_table_two_cols(("c1", &vec![Some(1)]), ("dummy", &vec![Some(10)])); + let right = + build_table_two_cols(("c2", &vec![Some(1)]), ("dummy", &vec![Some(100)])); + + let on = vec![( + Arc::new(Column::new_with_schema("c1", &left.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("c2", &right.schema()).unwrap()) as _, + )]; + + // Try to create null-aware Inner join (should fail) + let result = HashJoinExec::try_new( + left, + right, + on, + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + true, // null_aware = true (invalid for Inner join) + ); + + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("null_aware can only be true for LeftAnti joins") + ); + } + + /// Test that null_aware validation rejects multi-column joins + #[tokio::test] + async fn test_null_aware_validation_multi_column() { + let left = build_table(("a", &vec![1]), ("b", &vec![2]), ("c", &vec![3])); + let right = build_table(("x", &vec![1]), ("y", &vec![2]), ("z", &vec![3])); + + // Try multi-column join + let on = vec![ + ( + Arc::new(Column::new_with_schema("a", &left.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("x", &right.schema()).unwrap()) as _, + ), + ( + Arc::new(Column::new_with_schema("b", &left.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("y", &right.schema()).unwrap()) as _, + ), + ]; + + // Try to create null-aware anti join with 2 columns (should fail) + let result = HashJoinExec::try_new( + left, + right, + on, + None, + &JoinType::LeftAnti, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + 
true, // null_aware = true (invalid for multi-column) + ); + + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("null_aware anti join only supports single column join key") + ); + } + + #[test] + fn test_lr_is_preserved() { + assert_eq!(lr_is_preserved(JoinType::Inner), (true, true)); + assert_eq!(lr_is_preserved(JoinType::Left), (true, false)); + assert_eq!(lr_is_preserved(JoinType::Right), (false, true)); + assert_eq!(lr_is_preserved(JoinType::Full), (false, false)); + assert_eq!(lr_is_preserved(JoinType::LeftSemi), (true, true)); + assert_eq!(lr_is_preserved(JoinType::LeftAnti), (true, true)); + assert_eq!(lr_is_preserved(JoinType::LeftMark), (true, false)); + assert_eq!(lr_is_preserved(JoinType::RightSemi), (true, true)); + assert_eq!(lr_is_preserved(JoinType::RightAnti), (true, true)); + assert_eq!(lr_is_preserved(JoinType::RightMark), (false, true)); + } } diff --git a/datafusion/physical-plan/src/joins/hash_join/inlist_builder.rs b/datafusion/physical-plan/src/joins/hash_join/inlist_builder.rs index 7dccc5b0ba7c..0ca338265ecc 100644 --- a/datafusion/physical-plan/src/joins/hash_join/inlist_builder.rs +++ b/datafusion/physical-plan/src/joins/hash_join/inlist_builder.rs @@ -21,7 +21,6 @@ use std::sync::Arc; use arrow::array::{ArrayRef, StructArray}; use arrow::datatypes::{Field, FieldRef, Fields}; -use arrow::downcast_dictionary_array; use arrow_schema::DataType; use datafusion_common::Result; @@ -33,18 +32,6 @@ pub(super) fn build_struct_fields(data_types: &[DataType]) -> Result { .collect() } -/// Flattens dictionary-encoded arrays to their underlying value arrays. -/// Non-dictionary arrays are returned as-is. -fn flatten_dictionary_array(array: &ArrayRef) -> ArrayRef { - downcast_dictionary_array! { - array => { - // Recursively flatten in case of nested dictionaries - flatten_dictionary_array(array.values()) - } - _ => Arc::clone(array) - } -} - /// Builds InList values from join key column arrays. 
/// /// If `join_key_arrays` is: @@ -64,20 +51,14 @@ fn flatten_dictionary_array(array: &ArrayRef) -> ArrayRef { pub(super) fn build_struct_inlist_values( join_key_arrays: &[ArrayRef], ) -> Result> { - // Flatten any dictionary-encoded arrays - let flattened_arrays: Vec = join_key_arrays - .iter() - .map(flatten_dictionary_array) - .collect(); - // Build the source array/struct - let source_array: ArrayRef = if flattened_arrays.len() == 1 { + let source_array: ArrayRef = if join_key_arrays.len() == 1 { // Single column: use directly - Arc::clone(&flattened_arrays[0]) + Arc::clone(&join_key_arrays[0]) } else { // Multi-column: build StructArray once from all columns let fields = build_struct_fields( - &flattened_arrays + &join_key_arrays .iter() .map(|arr| arr.data_type().clone()) .collect::>(), @@ -87,7 +68,7 @@ pub(super) fn build_struct_inlist_values( let arrays_with_fields: Vec<(FieldRef, ArrayRef)> = fields .iter() .cloned() - .zip(flattened_arrays.iter().cloned()) + .zip(join_key_arrays.iter().cloned()) .collect(); Arc::new(StructArray::from(arrays_with_fields)) @@ -99,7 +80,9 @@ pub(super) fn build_struct_inlist_values( #[cfg(test)] mod tests { use super::*; - use arrow::array::{Int32Array, StringArray}; + use arrow::array::{ + DictionaryArray, Int8Array, Int32Array, StringArray, StringDictionaryBuilder, + }; use arrow_schema::DataType; use std::sync::Arc; @@ -130,4 +113,48 @@ mod tests { ) ); } + + #[test] + fn test_build_multi_column_inlist_with_dictionary() { + let mut builder = StringDictionaryBuilder::::new(); + builder.append_value("foo"); + builder.append_value("foo"); + builder.append_value("foo"); + let dict_array = Arc::new(builder.finish()) as ArrayRef; + + let int_array = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef; + + let result = build_struct_inlist_values(&[dict_array, int_array]) + .unwrap() + .unwrap(); + + assert_eq!(result.len(), 3); + assert_eq!( + *result.data_type(), + DataType::Struct( + build_struct_fields(&[ + 
DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::Utf8) + ), + DataType::Int32 + ]) + .unwrap() + ) + ); + } + + #[test] + fn test_build_single_column_dictionary_inlist() { + let keys = Int8Array::from(vec![0i8, 0, 0]); + let values = Arc::new(StringArray::from(vec!["foo"])); + let dict_array = Arc::new(DictionaryArray::new(keys, values)) as ArrayRef; + + let result = build_struct_inlist_values(std::slice::from_ref(&dict_array)) + .unwrap() + .unwrap(); + + assert_eq!(result.len(), 3); + assert_eq!(result.data_type(), dict_array.data_type()); + } } diff --git a/datafusion/physical-plan/src/joins/hash_join/mod.rs b/datafusion/physical-plan/src/joins/hash_join/mod.rs index 8592e1d96853..b915802ea401 100644 --- a/datafusion/physical-plan/src/joins/hash_join/mod.rs +++ b/datafusion/physical-plan/src/joins/hash_join/mod.rs @@ -17,7 +17,7 @@ //! [`HashJoinExec`] Partitioned Hash Join Operator -pub use exec::HashJoinExec; +pub use exec::{HashJoinExec, HashJoinExecBuilder}; pub use partitioned_hash_eval::{HashExpr, HashTableLookupExpr, SeededRandomState}; mod exec; diff --git a/datafusion/physical-plan/src/joins/hash_join/partitioned_hash_eval.rs b/datafusion/physical-plan/src/joins/hash_join/partitioned_hash_eval.rs index 4c437e813139..e3d432643cfb 100644 --- a/datafusion/physical-plan/src/joins/hash_join/partitioned_hash_eval.rs +++ b/datafusion/physical-plan/src/joins/hash_join/partitioned_hash_eval.rs @@ -21,18 +21,18 @@ use std::{any::Any, fmt::Display, hash::Hash, sync::Arc}; use ahash::RandomState; use arrow::{ - array::{BooleanArray, UInt64Array}, - buffer::MutableBuffer, + array::{ArrayRef, UInt64Array}, datatypes::{DataType, Schema}, - util::bit_util, + record_batch::RecordBatch, }; -use datafusion_common::{Result, internal_datafusion_err, internal_err}; +use datafusion_common::Result; +use datafusion_common::hash_utils::{create_hashes, with_hashes}; use datafusion_expr::ColumnarValue; use datafusion_physical_expr_common::physical_expr::{ 
DynHash, PhysicalExpr, PhysicalExprRef, }; -use crate::{hash_utils::create_hashes, joins::utils::JoinHashMapType}; +use crate::joins::Map; /// RandomState wrapper that preserves the seeds used to create it. /// @@ -181,18 +181,11 @@ impl PhysicalExpr for HashExpr { Ok(false) } - fn evaluate( - &self, - batch: &arrow::record_batch::RecordBatch, - ) -> Result { + fn evaluate(&self, batch: &RecordBatch) -> Result { let num_rows = batch.num_rows(); // Evaluate columns - let keys_values = self - .on_columns - .iter() - .map(|c| c.evaluate(batch)?.into_array(num_rows)) - .collect::>>()?; + let keys_values = evaluate_columns(&self.on_columns, batch)?; // Compute hashes let mut hashes_buffer = vec![0; num_rows]; @@ -212,15 +205,17 @@ impl PhysicalExpr for HashExpr { } } -/// Physical expression that checks if hash values exist in a hash table +/// Physical expression that checks join keys in a [`Map`] (hash table or array map). /// -/// Takes a UInt64Array of hash values and checks membership in a hash table. -/// Returns a BooleanArray indicating which hashes exist. +/// Returns a [`BooleanArray`](arrow::array::BooleanArray) indicating if join keys (from `on_columns`) exist in the map. 
+// TODO: rename to MapLookupExpr pub struct HashTableLookupExpr { - /// Expression that computes hash values (should be a HashExpr) - hash_expr: PhysicalExprRef, - /// Hash table to check against - hash_map: Arc, + /// Columns in the ON clause used to compute the join key for lookups + on_columns: Vec, + /// Random state for hashing (with seeds preserved for serialization) + random_state: SeededRandomState, + /// Map to check against (hash table or array map) + map: Arc, /// Description for display description: String, } @@ -229,21 +224,23 @@ impl HashTableLookupExpr { /// Create a new HashTableLookupExpr /// /// # Arguments - /// * `hash_expr` - Expression that computes hash values - /// * `hash_map` - Hash table to check membership + /// * `on_columns` - Columns in the ON clause used to compute the join key + /// * `random_state` - SeededRandomState for hashing + /// * `map` - Map to check membership (hash table or array map) /// * `description` - Description for debugging - /// /// # Note /// This is public for internal testing purposes only and is not /// guaranteed to be stable across versions. 
pub fn new( - hash_expr: PhysicalExprRef, - hash_map: Arc, + on_columns: Vec, + random_state: SeededRandomState, + map: Arc, description: String, ) -> Self { Self { - hash_expr, - hash_map, + on_columns, + random_state, + map, description, } } @@ -251,14 +248,22 @@ impl HashTableLookupExpr { impl std::fmt::Debug for HashTableLookupExpr { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}({:?})", self.description, self.hash_expr) + let cols = self + .on_columns + .iter() + .map(|e| e.to_string()) + .collect::>() + .join(", "); + let (s1, s2, s3, s4) = self.random_state.seeds(); + write!(f, "{}({cols}, [{s1},{s2},{s3},{s4}])", self.description) } } impl Hash for HashTableLookupExpr { fn hash(&self, state: &mut H) { - self.hash_expr.dyn_hash(state); + self.on_columns.dyn_hash(state); self.description.hash(state); + self.random_state.seeds().hash(state); // Note that we compare hash_map by pointer equality. // Actually comparing the contents of the hash maps would be expensive. // The way these hash maps are used in actuality is that HashJoinExec creates @@ -266,7 +271,7 @@ impl Hash for HashTableLookupExpr { // hash maps to have the same content in practice. // Theoretically this is a public API and users could create identical hash maps, // but that seems unlikely and not worth paying the cost of deep comparison all the time. - Arc::as_ptr(&self.hash_map).hash(state); + Arc::as_ptr(&self.map).hash(state); } } @@ -279,9 +284,10 @@ impl PartialEq for HashTableLookupExpr { // hash maps to have the same content in practice. // Theoretically this is a public API and users could create identical hash maps, // but that seems unlikely and not worth paying the cost of deep comparison all the time. 
- self.hash_expr.as_ref() == other.hash_expr.as_ref() + self.on_columns == other.on_columns && self.description == other.description - && Arc::ptr_eq(&self.hash_map, &other.hash_map) + && self.random_state.seeds() == other.random_state.seeds() + && Arc::ptr_eq(&self.map, &other.map) } } @@ -299,22 +305,17 @@ impl PhysicalExpr for HashTableLookupExpr { } fn children(&self) -> Vec<&Arc> { - vec![&self.hash_expr] + self.on_columns.iter().collect() } fn with_new_children( self: Arc, children: Vec>, ) -> Result> { - if children.len() != 1 { - return internal_err!( - "HashTableLookupExpr expects exactly 1 child, got {}", - children.len() - ); - } Ok(Arc::new(HashTableLookupExpr::new( - Arc::clone(&children[0]), - Arc::clone(&self.hash_map), + children, + self.random_state.clone(), + Arc::clone(&self.map), self.description.clone(), ))) } @@ -327,36 +328,22 @@ impl PhysicalExpr for HashTableLookupExpr { Ok(false) } - fn evaluate( - &self, - batch: &arrow::record_batch::RecordBatch, - ) -> Result { - let num_rows = batch.num_rows(); - - // Evaluate hash expression to get hash values - let hash_array = self.hash_expr.evaluate(batch)?.into_array(num_rows)?; - let hash_array = hash_array.as_any().downcast_ref::().ok_or( - internal_datafusion_err!( - "HashTableLookupExpr expects UInt64Array from hash expression" - ), - )?; - - // Check each hash against the hash table - let mut buf = MutableBuffer::from_len_zeroed(bit_util::ceil(num_rows, 8)); - for (idx, hash_value) in hash_array.values().iter().enumerate() { - // Use get_matched_indices to check - if it returns any indices, the hash exists - let (matched_indices, _) = self - .hash_map - .get_matched_indices(Box::new(std::iter::once((idx, hash_value))), None); - - if !matched_indices.is_empty() { - bit_util::set_bit(buf.as_slice_mut(), idx); + fn evaluate(&self, batch: &RecordBatch) -> Result { + // Evaluate columns + let join_keys = evaluate_columns(&self.on_columns, batch)?; + + match self.map.as_ref() { + Map::HashMap(map) 
=> { + with_hashes(&join_keys, self.random_state.random_state(), |hashes| { + let array = map.contain_hashes(hashes); + Ok(ColumnarValue::Array(Arc::new(array))) + }) + } + Map::ArrayMap(map) => { + let array = map.contain_keys(&join_keys)?; + Ok(ColumnarValue::Array(Arc::new(array))) } } - - Ok(ColumnarValue::Array(Arc::new( - BooleanArray::new_from_packed(buf, 0, num_rows), - ))) } fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -364,6 +351,17 @@ impl PhysicalExpr for HashTableLookupExpr { } } +fn evaluate_columns( + columns: &[PhysicalExprRef], + batch: &RecordBatch, +) -> Result> { + let num_rows = batch.num_rows(); + columns + .iter() + .map(|c| c.evaluate(batch)?.into_array(num_rows)) + .collect() +} + #[cfg(test)] mod tests { use super::*; @@ -482,22 +480,19 @@ mod tests { #[test] fn test_hash_table_lookup_expr_eq_same() { let col_a: PhysicalExprRef = Arc::new(Column::new("a", 0)); - let hash_expr: PhysicalExprRef = Arc::new(HashExpr::new( - vec![Arc::clone(&col_a)], - SeededRandomState::with_seeds(1, 2, 3, 4), - "inner_hash".to_string(), - )); - let hash_map: Arc = - Arc::new(JoinHashMapU32::with_capacity(10)); + let hash_map = + Arc::new(Map::HashMap(Box::new(JoinHashMapU32::with_capacity(10)))); let expr1 = HashTableLookupExpr::new( - Arc::clone(&hash_expr), + vec![Arc::clone(&col_a)], + SeededRandomState::with_seeds(1, 2, 3, 4), Arc::clone(&hash_map), "lookup".to_string(), ); let expr2 = HashTableLookupExpr::new( - Arc::clone(&hash_expr), + vec![Arc::clone(&col_a)], + SeededRandomState::with_seeds(1, 2, 3, 4), Arc::clone(&hash_map), "lookup".to_string(), ); @@ -506,33 +501,23 @@ mod tests { } #[test] - fn test_hash_table_lookup_expr_eq_different_hash_expr() { + fn test_hash_table_lookup_expr_eq_different_columns() { let col_a: PhysicalExprRef = Arc::new(Column::new("a", 0)); let col_b: PhysicalExprRef = Arc::new(Column::new("b", 1)); - let hash_expr1: PhysicalExprRef = Arc::new(HashExpr::new( - vec![Arc::clone(&col_a)], - 
SeededRandomState::with_seeds(1, 2, 3, 4), - "inner_hash".to_string(), - )); - - let hash_expr2: PhysicalExprRef = Arc::new(HashExpr::new( - vec![Arc::clone(&col_b)], - SeededRandomState::with_seeds(1, 2, 3, 4), - "inner_hash".to_string(), - )); - - let hash_map: Arc = - Arc::new(JoinHashMapU32::with_capacity(10)); + let hash_map = + Arc::new(Map::HashMap(Box::new(JoinHashMapU32::with_capacity(10)))); let expr1 = HashTableLookupExpr::new( - Arc::clone(&hash_expr1), + vec![Arc::clone(&col_a)], + SeededRandomState::with_seeds(1, 2, 3, 4), Arc::clone(&hash_map), "lookup".to_string(), ); let expr2 = HashTableLookupExpr::new( - Arc::clone(&hash_expr2), + vec![Arc::clone(&col_b)], + SeededRandomState::with_seeds(1, 2, 3, 4), Arc::clone(&hash_map), "lookup".to_string(), ); @@ -543,22 +528,19 @@ mod tests { #[test] fn test_hash_table_lookup_expr_eq_different_description() { let col_a: PhysicalExprRef = Arc::new(Column::new("a", 0)); - let hash_expr: PhysicalExprRef = Arc::new(HashExpr::new( - vec![Arc::clone(&col_a)], - SeededRandomState::with_seeds(1, 2, 3, 4), - "inner_hash".to_string(), - )); - let hash_map: Arc = - Arc::new(JoinHashMapU32::with_capacity(10)); + let hash_map = + Arc::new(Map::HashMap(Box::new(JoinHashMapU32::with_capacity(10)))); let expr1 = HashTableLookupExpr::new( - Arc::clone(&hash_expr), + vec![Arc::clone(&col_a)], + SeededRandomState::with_seeds(1, 2, 3, 4), Arc::clone(&hash_map), "lookup_one".to_string(), ); let expr2 = HashTableLookupExpr::new( - Arc::clone(&hash_expr), + vec![Arc::clone(&col_a)], + SeededRandomState::with_seeds(1, 2, 3, 4), Arc::clone(&hash_map), "lookup_two".to_string(), ); @@ -569,26 +551,22 @@ mod tests { #[test] fn test_hash_table_lookup_expr_eq_different_hash_map() { let col_a: PhysicalExprRef = Arc::new(Column::new("a", 0)); - let hash_expr: PhysicalExprRef = Arc::new(HashExpr::new( - vec![Arc::clone(&col_a)], - SeededRandomState::with_seeds(1, 2, 3, 4), - "inner_hash".to_string(), - )); // Two different Arc pointers 
(even with same content) should not be equal - let hash_map1: Arc = - Arc::new(JoinHashMapU32::with_capacity(10)); - let hash_map2: Arc = - Arc::new(JoinHashMapU32::with_capacity(10)); - + let hash_map1 = + Arc::new(Map::HashMap(Box::new(JoinHashMapU32::with_capacity(10)))); + let hash_map2 = + Arc::new(Map::HashMap(Box::new(JoinHashMapU32::with_capacity(10)))); let expr1 = HashTableLookupExpr::new( - Arc::clone(&hash_expr), + vec![Arc::clone(&col_a)], + SeededRandomState::with_seeds(1, 2, 3, 4), hash_map1, "lookup".to_string(), ); let expr2 = HashTableLookupExpr::new( - Arc::clone(&hash_expr), + vec![Arc::clone(&col_a)], + SeededRandomState::with_seeds(1, 2, 3, 4), hash_map2, "lookup".to_string(), ); @@ -600,22 +578,19 @@ mod tests { #[test] fn test_hash_table_lookup_expr_hash_consistency() { let col_a: PhysicalExprRef = Arc::new(Column::new("a", 0)); - let hash_expr: PhysicalExprRef = Arc::new(HashExpr::new( - vec![Arc::clone(&col_a)], - SeededRandomState::with_seeds(1, 2, 3, 4), - "inner_hash".to_string(), - )); - let hash_map: Arc = - Arc::new(JoinHashMapU32::with_capacity(10)); + let hash_map = + Arc::new(Map::HashMap(Box::new(JoinHashMapU32::with_capacity(10)))); let expr1 = HashTableLookupExpr::new( - Arc::clone(&hash_expr), + vec![Arc::clone(&col_a)], + SeededRandomState::with_seeds(1, 2, 3, 4), Arc::clone(&hash_map), "lookup".to_string(), ); let expr2 = HashTableLookupExpr::new( - Arc::clone(&hash_expr), + vec![Arc::clone(&col_a)], + SeededRandomState::with_seeds(1, 2, 3, 4), Arc::clone(&hash_map), "lookup".to_string(), ); diff --git a/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs b/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs index 7d34ce9acbd5..f32dc7fa8026 100644 --- a/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs +++ b/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs @@ -23,13 +23,13 @@ use std::sync::Arc; use crate::ExecutionPlan; use crate::ExecutionPlanProperties; +use crate::joins::Map; 
use crate::joins::PartitionMode; use crate::joins::hash_join::exec::HASH_JOIN_SEED; use crate::joins::hash_join::inlist_builder::build_struct_fields; use crate::joins::hash_join::partitioned_hash_eval::{ HashExpr, HashTableLookupExpr, SeededRandomState, }; -use crate::joins::utils::JoinHashMapType; use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::config::ConfigOptions; @@ -49,9 +49,9 @@ use tokio::sync::Barrier; #[derive(Debug, Clone, PartialEq)] pub(crate) struct ColumnBounds { /// The minimum value observed for this column - min: ScalarValue, + pub(crate) min: ScalarValue, /// The maximum value observed for this column - max: ScalarValue, + pub(crate) max: ScalarValue, } impl ColumnBounds { @@ -128,19 +128,12 @@ fn create_membership_predicate( )?))) } // Use hash table lookup for large build sides - PushdownStrategy::HashTable(hash_map) => { - let lookup_hash_expr = Arc::new(HashExpr::new( - on_right.to_vec(), - random_state.clone(), - "hash_join".to_string(), - )) as Arc; - - Ok(Some(Arc::new(HashTableLookupExpr::new( - lookup_hash_expr, - hash_map, - "hash_lookup".to_string(), - )) as Arc)) - } + PushdownStrategy::Map(hash_map) => Ok(Some(Arc::new(HashTableLookupExpr::new( + on_right.to_vec(), + random_state.clone(), + hash_map, + "hash_lookup".to_string(), + )) as Arc)), // Empty partition - should not create a filter for this PushdownStrategy::Empty => Ok(None), } @@ -240,8 +233,8 @@ pub(crate) struct SharedBuildAccumulator { pub(crate) enum PushdownStrategy { /// Use InList for small build sides (< 128MB) InList(ArrayRef), - /// Use hash table lookup for large build sides - HashTable(Arc), + /// Use map lookup for large build sides + Map(Arc), /// There was no data in this partition, do not build a dynamic filter for it Empty, } diff --git a/datafusion/physical-plan/src/joins/hash_join/stream.rs b/datafusion/physical-plan/src/joins/hash_join/stream.rs index e6735675125b..b31982ea3b7b 100644 --- 
a/datafusion/physical-plan/src/joins/hash_join/stream.rs +++ b/datafusion/physical-plan/src/joins/hash_join/stream.rs @@ -21,8 +21,12 @@ //! [`super::HashJoinExec`]. See comments in [`HashJoinStream`] for more details. use std::sync::Arc; +use std::sync::atomic::Ordering; use std::task::Poll; +use crate::coalesce::{LimitedBatchCoalescer, PushBatchStatus}; +use crate::joins::Map; +use crate::joins::MapOffset; use crate::joins::PartitionMode; use crate::joins::hash_join::exec::JoinLeftData; use crate::joins::hash_join::shared_bounds::{ @@ -34,7 +38,6 @@ use crate::joins::utils::{ use crate::{ RecordBatchStream, SendableRecordBatchStream, handle_state, hash_utils::create_hashes, - joins::join_hash_map::JoinHashMapOffset, joins::utils::{ BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinHashMapType, StatefulStreamResult, adjust_indices_by_join_type, apply_join_filter_to_indices, @@ -44,7 +47,6 @@ use crate::{ }; use arrow::array::{Array, ArrayRef, UInt32Array, UInt64Array}; -use arrow::compute::BatchCoalescer; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::{ @@ -154,13 +156,13 @@ pub(super) struct ProcessProbeBatchState { /// Probe-side on expressions values values: Vec, /// Starting offset for JoinHashMap lookups - offset: JoinHashMapOffset, + offset: MapOffset, /// Max joined probe-side index from current batch joined_probe_idx: Option, } impl ProcessProbeBatchState { - fn advance(&mut self, offset: JoinHashMapOffset, joined_probe_idx: Option) { + fn advance(&mut self, offset: MapOffset, joined_probe_idx: Option) { self.offset = offset; if joined_probe_idx.is_some() { self.joined_probe_idx = joined_probe_idx; @@ -219,10 +221,11 @@ pub(super) struct HashJoinStream { build_waiter: Option>, /// Partitioning mode to use mode: PartitionMode, - /// Output buffer for coalescing small batches into larger ones. - /// Uses `BatchCoalescer` from arrow to efficiently combine batches. 
- /// When batches are already close to target size, they bypass coalescing. - output_buffer: Box, + /// Output buffer for coalescing small batches into larger ones with optional fetch limit. + /// Uses `LimitedBatchCoalescer` to efficiently combine batches and absorb limit with 'fetch' + output_buffer: LimitedBatchCoalescer, + /// Whether this is a null-aware anti join + null_aware: bool, } impl RecordBatchStream for HashJoinStream { @@ -287,10 +290,10 @@ pub(super) fn lookup_join_hashmap( null_equality: NullEquality, hashes_buffer: &[u64], limit: usize, - offset: JoinHashMapOffset, + offset: MapOffset, probe_indices_buffer: &mut Vec, build_indices_buffer: &mut Vec, -) -> Result<(UInt64Array, UInt32Array, Option)> { +) -> Result<(UInt64Array, UInt32Array, Option)> { let next_offset = build_hashmap.get_matched_indices_with_limit_offset( hashes_buffer, limit, @@ -370,14 +373,12 @@ impl HashJoinStream { right_side_ordered: bool, build_accumulator: Option>, mode: PartitionMode, + null_aware: bool, + fetch: Option, ) -> Self { - // Create output buffer with coalescing. - // Use biggest_coalesce_batch_size to bypass coalescing for batches - // that are already close to target size (within 50%). - let output_buffer = Box::new( - BatchCoalescer::new(Arc::clone(&schema), batch_size) - .with_biggest_coalesce_batch_size(Some(batch_size / 2)), - ); + // Create output buffer with coalescing and optional fetch limit. 
+ let output_buffer = + LimitedBatchCoalescer::new(Arc::clone(&schema), batch_size, fetch); Self { partition, @@ -401,6 +402,7 @@ impl HashJoinStream { build_waiter: None, mode, output_buffer, + null_aware, } } @@ -419,6 +421,11 @@ impl HashJoinStream { .record_poll(Poll::Ready(Some(Ok(batch)))); } + // Check if the coalescer has finished (limit reached and flushed) + if self.output_buffer.is_finished() { + return Poll::Ready(None); + } + return match self.state { HashJoinStreamState::WaitBuildSide => { handle_state!(ready!(self.collect_build_side(cx))) @@ -437,7 +444,7 @@ impl HashJoinStream { } HashJoinStreamState::Completed if !self.output_buffer.is_empty() => { // Flush any remaining buffered data - self.output_buffer.finish_buffered_batch()?; + self.output_buffer.finish()?; // Continue loop to emit the flushed batch continue; } @@ -483,6 +490,10 @@ impl HashJoinStream { )?; build_timer.done(); + // Note: For null-aware anti join, we need to check the probe side (right) for NULLs, + // not the build side (left). The probe-side NULL check happens during process_probe_batch. + // The probe_side_has_null flag will be set there if any probe batch contains NULL. 
+ // Handle dynamic filter build-side information accumulation // // Dynamic filter coordination between partitions: @@ -552,9 +563,15 @@ impl HashJoinStream { // Precalculate hash values for fetched batch let keys_values = evaluate_expressions_to_arrays(&self.on_right, &batch)?; - self.hashes_buffer.clear(); - self.hashes_buffer.resize(batch.num_rows(), 0); - create_hashes(&keys_values, &self.random_state, &mut self.hashes_buffer)?; + if let Map::HashMap(_) = self.build_side.try_as_ready()?.left_data.map() { + self.hashes_buffer.clear(); + self.hashes_buffer.resize(batch.num_rows(), 0); + create_hashes( + &keys_values, + &self.random_state, + &mut self.hashes_buffer, + )?; + } self.join_metrics.input_batches.add(1); self.join_metrics.input_rows.add(batch.num_rows()); @@ -588,8 +605,48 @@ impl HashJoinStream { let timer = self.join_metrics.join_time.timer(); + // Null-aware anti join semantics: + // For LeftAnti: output LEFT (build) rows where LEFT.key NOT IN RIGHT.key + // 1. If RIGHT (probe) contains NULL in any batch, no LEFT rows should be output + // 2. 
LEFT rows with NULL keys should not be output (handled in final stage) + if self.null_aware { + // Mark that we've seen a probe batch with actual rows (probe side is non-empty) + // Only set this if batch has rows - empty batches don't count + // Use shared atomic state so all partitions can see this global information + if state.batch.num_rows() > 0 { + build_side + .left_data + .probe_side_non_empty + .store(true, Ordering::Relaxed); + } + + // Check if probe side (RIGHT) contains NULL + // Since null_aware validation ensures single column join, we only check the first column + let probe_key_column = &state.values[0]; + if probe_key_column.null_count() > 0 { + // Found NULL in probe side - set shared flag to prevent any output + build_side + .left_data + .probe_side_has_null + .store(true, Ordering::Relaxed); + } + + // If probe side has NULL (detected in this or any other partition), return empty result + if build_side + .left_data + .probe_side_has_null + .load(Ordering::Relaxed) + { + timer.done(); + self.state = HashJoinStreamState::FetchProbeBatch; + return Ok(StatefulStreamResult::Continue); + } + } + // if the left side is empty, we can skip the (potentially expensive) join operation - if build_side.left_data.hash_map.is_empty() && self.filter.is_none() { + let is_empty = build_side.left_data.map().is_empty(); + + if is_empty && self.filter.is_none() { let result = build_batch_empty_build_side( &self.schema, build_side.left_data.batch(), @@ -605,17 +662,34 @@ impl HashJoinStream { } // get the matched by join keys indices - let (left_indices, right_indices, next_offset) = lookup_join_hashmap( - build_side.left_data.hash_map(), - build_side.left_data.values(), - &state.values, - self.null_equality, - &self.hashes_buffer, - self.batch_size, - state.offset, - &mut self.probe_indices_buffer, - &mut self.build_indices_buffer, - )?; + let (left_indices, right_indices, next_offset) = match build_side.left_data.map() + { + Map::HashMap(map) => lookup_join_hashmap( 
+ map.as_ref(), + build_side.left_data.values(), + &state.values, + self.null_equality, + &self.hashes_buffer, + self.batch_size, + state.offset, + &mut self.probe_indices_buffer, + &mut self.build_indices_buffer, + )?, + Map::ArrayMap(array_map) => { + let next_offset = array_map.get_matched_indices_with_limit_offset( + &state.values, + self.batch_size, + state.offset, + &mut self.probe_indices_buffer, + &mut self.build_indices_buffer, + )?; + ( + UInt64Array::from(self.build_indices_buffer.clone()), + UInt32Array::from(self.probe_indices_buffer.clone()), + next_offset, + ) + } + }; let distinct_right_indices_count = count_distinct_sorted_indices(&right_indices); @@ -639,6 +713,7 @@ impl HashJoinStream { filter, JoinSide::Left, None, + self.join_type, )? } else { (left_indices, right_indices) @@ -707,12 +782,20 @@ impl HashJoinStream { &right_indices, &self.column_indices, join_side, + self.join_type, )?; - self.output_buffer.push_batch(batch)?; + let push_status = self.output_buffer.push_batch(batch)?; timer.done(); + // If limit reached, finish and move to Completed state + if push_status == PushBatchStatus::LimitReached { + self.output_buffer.finish()?; + self.state = HashJoinStreamState::Completed; + return Ok(StatefulStreamResult::Continue); + } + if next_offset.is_none() { self.state = HashJoinStreamState::FetchProbeBatch; } else { @@ -740,18 +823,66 @@ impl HashJoinStream { } let build_side = self.build_side.try_as_ready()?; + + // For null-aware anti join, if probe side had NULL, no rows should be output + // Check shared atomic state to get global knowledge across all partitions + if self.null_aware + && build_side + .left_data + .probe_side_has_null + .load(Ordering::Relaxed) + { + timer.done(); + self.state = HashJoinStreamState::Completed; + return Ok(StatefulStreamResult::Continue); + } if !build_side.left_data.report_probe_completed() { self.state = HashJoinStreamState::Completed; return Ok(StatefulStreamResult::Continue); } // use the global left 
bitmap to produce the left indices and right indices - let (left_side, right_side) = get_final_indices_from_shared_bitmap( + let (mut left_side, mut right_side) = get_final_indices_from_shared_bitmap( build_side.left_data.visited_indices_bitmap(), self.join_type, true, ); + // For null-aware anti join, filter out LEFT rows with NULL in join keys + // BUT only if the probe side (RIGHT) was non-empty. If probe side is empty, + // NULL NOT IN (empty) = TRUE, so NULL rows should be returned. + // Use shared atomic state to get global knowledge across all partitions + if self.null_aware + && self.join_type == JoinType::LeftAnti + && build_side + .left_data + .probe_side_non_empty + .load(Ordering::Relaxed) + { + // Since null_aware validation ensures single column join, we only check the first column + let build_key_column = &build_side.left_data.values()[0]; + + // Filter out indices where the key is NULL + let filtered_indices: Vec = left_side + .iter() + .filter_map(|idx| { + let idx_usize = idx.unwrap() as usize; + if build_key_column.is_null(idx_usize) { + None // Skip rows with NULL keys + } else { + Some(idx.unwrap()) + } + }) + .collect(); + + left_side = UInt64Array::from(filtered_indices); + + // Update right_side to match the new length + let mut builder = arrow::array::UInt32Builder::with_capacity(left_side.len()); + builder.append_nulls(left_side.len()); + right_side = builder.finish(); + } + self.join_metrics.input_batches.add(1); self.join_metrics.input_rows.add(left_side.len()); @@ -770,8 +901,14 @@ impl HashJoinStream { &right_side, &self.column_indices, JoinSide::Left, + self.join_type, )?; - self.output_buffer.push_batch(batch)?; + let push_status = self.output_buffer.push_batch(batch)?; + + // If limit reached, finish the coalescer + if push_status == PushBatchStatus::LimitReached { + self.output_buffer.finish()?; + } } Ok(StatefulStreamResult::Continue) diff --git a/datafusion/physical-plan/src/joins/join_hash_map.rs 
b/datafusion/physical-plan/src/joins/join_hash_map.rs index b0ed6dcc7c25..8f0fb66b64fb 100644 --- a/datafusion/physical-plan/src/joins/join_hash_map.rs +++ b/datafusion/physical-plan/src/joins/join_hash_map.rs @@ -22,6 +22,8 @@ use std::fmt::{self, Debug}; use std::ops::Sub; +use arrow::array::BooleanArray; +use arrow::buffer::BooleanBuffer; use arrow::datatypes::ArrowNativeType; use hashbrown::HashTable; use hashbrown::hash_table::Entry::{Occupied, Vacant}; @@ -119,10 +121,13 @@ pub trait JoinHashMapType: Send + Sync { &self, hash_values: &[u64], limit: usize, - offset: JoinHashMapOffset, + offset: MapOffset, input_indices: &mut Vec, match_indices: &mut Vec, - ) -> Option; + ) -> Option; + + /// Returns a BooleanArray indicating which of the provided hashes exist in the map. + fn contain_hashes(&self, hash_values: &[u64]) -> BooleanArray; /// Returns `true` if the join hash map contains no entries. fn is_empty(&self) -> bool; @@ -181,10 +186,10 @@ impl JoinHashMapType for JoinHashMapU32 { &self, hash_values: &[u64], limit: usize, - offset: JoinHashMapOffset, + offset: MapOffset, input_indices: &mut Vec, match_indices: &mut Vec, - ) -> Option { + ) -> Option { get_matched_indices_with_limit_offset::( &self.map, &self.next, @@ -196,6 +201,10 @@ impl JoinHashMapType for JoinHashMapU32 { ) } + fn contain_hashes(&self, hash_values: &[u64]) -> BooleanArray { + contain_hashes(&self.map, hash_values) + } + fn is_empty(&self) -> bool { self.map.is_empty() } @@ -255,10 +264,10 @@ impl JoinHashMapType for JoinHashMapU64 { &self, hash_values: &[u64], limit: usize, - offset: JoinHashMapOffset, + offset: MapOffset, input_indices: &mut Vec, match_indices: &mut Vec, - ) -> Option { + ) -> Option { get_matched_indices_with_limit_offset::( &self.map, &self.next, @@ -270,6 +279,10 @@ impl JoinHashMapType for JoinHashMapU64 { ) } + fn contain_hashes(&self, hash_values: &[u64]) -> BooleanArray { + contain_hashes(&self.map, hash_values) + } + fn is_empty(&self) -> bool { 
self.map.is_empty() } @@ -279,54 +292,8 @@ impl JoinHashMapType for JoinHashMapU64 { } } -// Type of offsets for obtaining indices from JoinHashMap. -pub(crate) type JoinHashMapOffset = (usize, Option); - -/// Traverses the chain of matching indices, collecting results up to the remaining limit. -/// Returns `Some(offset)` if the limit was reached and there are more results to process, -/// or `None` if the chain was fully traversed. -#[inline(always)] -fn traverse_chain( - next_chain: &[T], - input_idx: usize, - start_chain_idx: T, - remaining: &mut usize, - input_indices: &mut Vec, - match_indices: &mut Vec, - is_last_input: bool, -) -> Option -where - T: Copy + TryFrom + PartialOrd + Into + Sub, - >::Error: Debug, - T: ArrowNativeType, -{ - let zero = T::usize_as(0); - let one = T::usize_as(1); - let mut match_row_idx = start_chain_idx - one; - - loop { - match_indices.push(match_row_idx.into()); - input_indices.push(input_idx as u32); - *remaining -= 1; - - let next = next_chain[match_row_idx.into() as usize]; - - if *remaining == 0 { - // Limit reached - return offset for next call - return if is_last_input && next == zero { - // Finished processing the last input row - None - } else { - Some((input_idx, Some(next.into()))) - }; - } - if next == zero { - // End of chain - return None; - } - match_row_idx = next - one; - } -} +use crate::joins::MapOffset; +use crate::joins::chain::traverse_chain; pub fn update_from_iter<'a, T>( map: &mut HashTable<(u64, T)>, @@ -414,10 +381,10 @@ pub fn get_matched_indices_with_limit_offset( next_chain: &[T], hash_values: &[u64], limit: usize, - offset: JoinHashMapOffset, + offset: MapOffset, input_indices: &mut Vec, match_indices: &mut Vec, -) -> Option +) -> Option where T: Copy + TryFrom + PartialOrd + Into + Sub, >::Error: Debug, @@ -496,3 +463,35 @@ where } None } + +pub fn contain_hashes(map: &HashTable<(u64, T)>, hash_values: &[u64]) -> BooleanArray { + let buffer = BooleanBuffer::collect_bool(hash_values.len(), |i| { + 
let hash = hash_values[i]; + map.find(hash, |(h, _)| hash == *h).is_some() + }); + BooleanArray::new(buffer, None) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_contain_hashes() { + let mut hash_map = JoinHashMapU32::with_capacity(10); + hash_map.update_from_iter(Box::new([10u64, 20u64, 30u64].iter().enumerate()), 0); + + let probe_hashes = vec![10, 11, 20, 21, 30, 31]; + let array = hash_map.contain_hashes(&probe_hashes); + + assert_eq!(array.len(), probe_hashes.len()); + + for (i, &hash) in probe_hashes.iter().enumerate() { + if matches!(hash, 10 | 20 | 30) { + assert!(array.value(i), "Hash {hash} should exist in the map"); + } else { + assert!(!array.value(i), "Hash {hash} should NOT exist in the map"); + } + } + } +} diff --git a/datafusion/physical-plan/src/joins/mod.rs b/datafusion/physical-plan/src/joins/mod.rs index 3ff61ecf1dac..2cdfa1e6ac02 100644 --- a/datafusion/physical-plan/src/joins/mod.rs +++ b/datafusion/physical-plan/src/joins/mod.rs @@ -20,13 +20,16 @@ use arrow::array::BooleanBufferBuilder; pub use cross_join::CrossJoinExec; use datafusion_physical_expr::PhysicalExprRef; -pub use hash_join::{HashExpr, HashJoinExec, HashTableLookupExpr, SeededRandomState}; -pub use nested_loop_join::NestedLoopJoinExec; +pub use hash_join::{ + HashExpr, HashJoinExec, HashJoinExecBuilder, HashTableLookupExpr, SeededRandomState, +}; +pub use nested_loop_join::{NestedLoopJoinExec, NestedLoopJoinExecBuilder}; use parking_lot::Mutex; // Note: SortMergeJoin is not used in plans yet pub use piecewise_merge_join::PiecewiseMergeJoinExec; pub use sort_merge_join::SortMergeJoinExec; pub use symmetric_hash_join::SymmetricHashJoinExec; +pub mod chain; mod cross_join; mod hash_join; mod nested_loop_join; @@ -36,6 +39,7 @@ mod stream_join_utils; mod symmetric_hash_join; pub mod utils; +mod array_map; mod join_filter; /// Hash map implementations for join operations. 
/// @@ -43,6 +47,31 @@ mod join_filter; /// and is not guaranteed to be stable across versions. pub mod join_hash_map; +use array_map::ArrayMap; +use utils::JoinHashMapType; + +pub enum Map { + HashMap(Box), + ArrayMap(ArrayMap), +} + +impl Map { + /// Returns the number of elements in the map. + pub fn num_of_distinct_key(&self) -> usize { + match self { + Map::HashMap(map) => map.len(), + Map::ArrayMap(array_map) => array_map.num_of_distinct_key(), + } + } + + /// Returns `true` if the map contains no elements. + pub fn is_empty(&self) -> bool { + self.num_of_distinct_key() == 0 + } +} + +pub(crate) type MapOffset = (usize, Option); + #[cfg(test)] pub mod test_utils; diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 44637321a7e3..4fb7dabf673d 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -46,6 +46,7 @@ use crate::projection::{ use crate::{ DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, PlanProperties, RecordBatchStream, SendableRecordBatchStream, + check_if_same_properties, }; use arrow::array::{ @@ -71,6 +72,7 @@ use datafusion_physical_expr::equivalence::{ ProjectionMapping, join_equivalence_properties, }; +use datafusion_physical_expr::projection::{ProjectionRef, combine_projections}; use futures::{Stream, StreamExt, TryStreamExt}; use log::debug; use parking_lot::Mutex; @@ -192,50 +194,120 @@ pub struct NestedLoopJoinExec { /// Information of index and left / right placement of columns column_indices: Vec, /// Projection to apply to the output of the join - projection: Option>, + projection: Option, /// Execution metrics metrics: ExecutionPlanMetricsSet, /// Cache holding plan properties like equivalences, output partitioning etc. 
- cache: PlanProperties, + cache: Arc, } -impl NestedLoopJoinExec { - /// Try to create a new [`NestedLoopJoinExec`] - pub fn try_new( +/// Helps to build [`NestedLoopJoinExec`]. +pub struct NestedLoopJoinExecBuilder { + left: Arc, + right: Arc, + join_type: JoinType, + filter: Option, + projection: Option, +} + +impl NestedLoopJoinExecBuilder { + /// Make a new [`NestedLoopJoinExecBuilder`]. + pub fn new( left: Arc, right: Arc, - filter: Option, - join_type: &JoinType, - projection: Option>, - ) -> Result { + join_type: JoinType, + ) -> Self { + Self { + left, + right, + join_type, + filter: None, + projection: None, + } + } + + /// Set projection from the vector. + pub fn with_projection(self, projection: Option>) -> Self { + self.with_projection_ref(projection.map(Into::into)) + } + + /// Set projection from the shared reference. + pub fn with_projection_ref(mut self, projection: Option) -> Self { + self.projection = projection; + self + } + + /// Set optional filter. + pub fn with_filter(mut self, filter: Option) -> Self { + self.filter = filter; + self + } + + /// Build resulting execution plan. 
+ pub fn build(self) -> Result { + let Self { + left, + right, + join_type, + filter, + projection, + } = self; + let left_schema = left.schema(); let right_schema = right.schema(); check_join_is_valid(&left_schema, &right_schema, &[])?; let (join_schema, column_indices) = - build_join_schema(&left_schema, &right_schema, join_type); + build_join_schema(&left_schema, &right_schema, &join_type); let join_schema = Arc::new(join_schema); - let cache = Self::compute_properties( + let cache = NestedLoopJoinExec::compute_properties( &left, &right, &join_schema, - *join_type, - projection.as_ref(), + join_type, + projection.as_deref(), )?; - Ok(NestedLoopJoinExec { left, right, filter, - join_type: *join_type, + join_type, join_schema, build_side_data: Default::default(), column_indices, projection, metrics: Default::default(), - cache, + cache: Arc::new(cache), }) } +} + +impl From<&NestedLoopJoinExec> for NestedLoopJoinExecBuilder { + fn from(exec: &NestedLoopJoinExec) -> Self { + Self { + left: Arc::clone(exec.left()), + right: Arc::clone(exec.right()), + join_type: exec.join_type, + filter: exec.filter.clone(), + projection: exec.projection.clone(), + } + } +} + +impl NestedLoopJoinExec { + /// Try to create a new [`NestedLoopJoinExec`] + pub fn try_new( + left: Arc, + right: Arc, + filter: Option, + join_type: &JoinType, + projection: Option>, + ) -> Result { + NestedLoopJoinExecBuilder::new(left, right, *join_type) + .with_projection(projection) + .with_filter(filter) + .build() + } /// left side pub fn left(&self) -> &Arc { @@ -257,8 +329,8 @@ impl NestedLoopJoinExec { &self.join_type } - pub fn projection(&self) -> Option<&Vec> { - self.projection.as_ref() + pub fn projection(&self) -> &Option { + &self.projection } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. 
@@ -267,7 +339,7 @@ impl NestedLoopJoinExec { right: &Arc, schema: &SchemaRef, join_type: JoinType, - projection: Option<&Vec>, + projection: Option<&[usize]>, ) -> Result { // Calculate equivalence properties: let mut eq_properties = join_equivalence_properties( @@ -310,7 +382,7 @@ impl NestedLoopJoinExec { if let Some(projection) = projection { // construct a map from the input expressions to the output expression of the Projection let projection_mapping = ProjectionMapping::from_indices(projection, schema)?; - let out_schema = project_schema(schema, Some(projection))?; + let out_schema = project_schema(schema, Some(&projection))?; output_partitioning = output_partitioning.project(&projection_mapping, &eq_properties); eq_properties = eq_properties.project(&projection_mapping, out_schema); @@ -334,22 +406,14 @@ impl NestedLoopJoinExec { } pub fn with_projection(&self, projection: Option>) -> Result { + let projection = projection.map(Into::into); // check if the projection is valid - can_project(&self.schema(), projection.as_ref())?; - let projection = match projection { - Some(projection) => match &self.projection { - Some(p) => Some(projection.iter().map(|i| p[*i]).collect()), - None => Some(projection), - }, - None => None, - }; - Self::try_new( - Arc::clone(&self.left), - Arc::clone(&self.right), - self.filter.clone(), - &self.join_type, - projection, - ) + can_project(&self.schema(), projection.as_deref())?; + let projection = + combine_projections(projection.as_ref(), self.projection.as_ref())?; + NestedLoopJoinExecBuilder::from(self) + .with_projection_ref(projection) + .build() } /// Returns a new `ExecutionPlan` that runs NestedLoopsJoins with the left @@ -371,7 +435,7 @@ impl NestedLoopJoinExec { swap_join_projection( left.schema().fields().len(), right.schema().fields().len(), - self.projection.as_ref(), + self.projection.as_deref(), self.join_type(), ), )?; @@ -399,6 +463,27 @@ impl NestedLoopJoinExec { Ok(plan) } + + fn 
with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + let left = children.swap_remove(0); + let right = children.swap_remove(0); + + Self { + left, + right, + metrics: ExecutionPlanMetricsSet::new(), + build_side_data: Default::default(), + cache: Arc::clone(&self.cache), + filter: self.filter.clone(), + join_type: self.join_type, + join_schema: Arc::clone(&self.join_schema), + column_indices: self.column_indices.clone(), + projection: self.projection.clone(), + } + } } impl DisplayAs for NestedLoopJoinExec { @@ -453,7 +538,7 @@ impl ExecutionPlan for NestedLoopJoinExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -476,13 +561,17 @@ impl ExecutionPlan for NestedLoopJoinExec { self: Arc, children: Vec>, ) -> Result> { - Ok(Arc::new(NestedLoopJoinExec::try_new( - Arc::clone(&children[0]), - Arc::clone(&children[1]), - self.filter.clone(), - &self.join_type, - self.projection.clone(), - )?)) + check_if_same_properties!(self, children); + Ok(Arc::new( + NestedLoopJoinExecBuilder::new( + Arc::clone(&children[0]), + Arc::clone(&children[1]), + self.join_type, + ) + .with_filter(self.filter.clone()) + .with_projection_ref(self.projection.clone()) + .build()?, + )) } fn execute( @@ -521,7 +610,7 @@ impl ExecutionPlan for NestedLoopJoinExec { let probe_side_data = self.right.execute(partition, context)?; // update column indices to reflect the projection - let column_indices_after_projection = match &self.projection { + let column_indices_after_projection = match self.projection.as_ref() { Some(projection) => projection .iter() .map(|i| self.column_indices[*i].clone()) @@ -545,10 +634,6 @@ impl ExecutionPlan for NestedLoopJoinExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { // NestedLoopJoinExec is designed for joins without equijoin keys in the // ON clause 
(e.g., `t1 JOIN t2 ON (t1.v1 + t2.v1) % 2 = 0`). Any join @@ -682,10 +767,10 @@ async fn collect_left_input( let schema = stream.schema(); // Load all batches and count the rows - let (batches, metrics, mut reservation) = stream + let (batches, metrics, reservation) = stream .try_fold( (Vec::new(), join_metrics, reservation), - |(mut batches, metrics, mut reservation), batch| async { + |(mut batches, metrics, reservation), batch| async { let batch_size = batch.get_array_memory_size(); // Reserve memory for incoming batch reservation.try_grow(batch_size)?; @@ -1949,9 +2034,10 @@ fn build_row_join_batch( // Broadcast the single build-side row to match the filtered // probe-side batch length let original_left_array = build_side_batch.column(column_index.index); - // Avoid using `ScalarValue::to_array_of_size()` for `List(Utf8View)` to avoid - // deep copies for buffers inside `Utf8View` array. See below for details. - // https://github.com/apache/datafusion/issues/18159 + + // Use `arrow::compute::take` directly for `List(Utf8View)` rather + // than going through `ScalarValue::to_array_of_size()`, which + // avoids some intermediate allocations. // // In other cases, `to_array_of_size()` is faster. 
match original_left_array.data_type() { diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs index 04daa3698d92..bb32a222de96 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs @@ -490,7 +490,7 @@ fn resolve_classic_join( // If we find a match we append all indices and move to the next stream row index match operator { Operator::Gt | Operator::Lt => { - if matches!(compare, Ordering::Less) { + if compare == Ordering::Less { batch_process_state.found = true; let count = buffered_len - buffer_idx; diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs index 508be2e3984f..abb6e34aa295 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs @@ -51,7 +51,9 @@ use crate::joins::piecewise_merge_join::utils::{ }; use crate::joins::utils::asymmetric_join_output_partitioning; use crate::metrics::MetricsSet; -use crate::{DisplayAs, DisplayFormatType, ExecutionPlanProperties}; +use crate::{ + DisplayAs, DisplayFormatType, ExecutionPlanProperties, check_if_same_properties, +}; use crate::{ ExecutionPlan, PlanProperties, joins::{ @@ -86,7 +88,7 @@ use crate::{ /// Both sides are sorted so that we can iterate from index 0 to the end on each side. This ordering ensures /// that when we find the first matching pair of rows, we can emit the current stream row joined with all remaining /// probe rows from the match position onward, without rescanning earlier probe rows. -/// +/// /// For `<` and `<=` operators, both inputs are sorted in **descending** order, while for `>` and `>=` operators /// they are sorted in **ascending** order. 
This choice ensures that the pointer on the buffered side can advance /// monotonically as we stream new batches from the stream side. @@ -129,34 +131,34 @@ use crate::{ /// /// Processing Row 1: /// -/// Sorted Buffered Side Sorted Streamed Side -/// ┌──────────────────┐ ┌──────────────────┐ -/// 1 │ 100 │ 1 │ 100 │ -/// ├──────────────────┤ ├──────────────────┤ -/// 2 │ 200 │ ─┐ 2 │ 200 │ -/// ├──────────────────┤ │ For row 1 on streamed side with ├──────────────────┤ -/// 3 │ 200 │ │ value 100, we emit rows 2 - 5. 3 │ 500 │ +/// Sorted Buffered Side Sorted Streamed Side +/// ┌──────────────────┐ ┌──────────────────┐ +/// 1 │ 100 │ 1 │ 100 │ +/// ├──────────────────┤ ├──────────────────┤ +/// 2 │ 200 │ ─┐ 2 │ 200 │ +/// ├──────────────────┤ │ For row 1 on streamed side with ├──────────────────┤ +/// 3 │ 200 │ │ value 100, we emit rows 2 - 5. 3 │ 500 │ /// ├──────────────────┤ │ as matches when the operator is └──────────────────┘ /// 4 │ 300 │ │ `Operator::Lt` (<) Emitting all /// ├──────────────────┤ │ rows after the first match (row /// 5 │ 400 │ ─┘ 2 buffered side; 100 < 200) -/// └──────────────────┘ +/// └──────────────────┘ /// /// Processing Row 2: /// By sorting the streamed side we know /// -/// Sorted Buffered Side Sorted Streamed Side -/// ┌──────────────────┐ ┌──────────────────┐ -/// 1 │ 100 │ 1 │ 100 │ -/// ├──────────────────┤ ├──────────────────┤ -/// 2 │ 200 │ <- Start here when probing for the 2 │ 200 │ -/// ├──────────────────┤ streamed side row 2. ├──────────────────┤ -/// 3 │ 200 │ 3 │ 500 │ +/// Sorted Buffered Side Sorted Streamed Side +/// ┌──────────────────┐ ┌──────────────────┐ +/// 1 │ 100 │ 1 │ 100 │ +/// ├──────────────────┤ ├──────────────────┤ +/// 2 │ 200 │ <- Start here when probing for the 2 │ 200 │ +/// ├──────────────────┤ streamed side row 2. 
├──────────────────┤ +/// 3 │ 200 │ 3 │ 500 │ /// ├──────────────────┤ └──────────────────┘ -/// 4 │ 300 │ -/// ├──────────────────┤ +/// 4 │ 300 │ +/// ├──────────────────┤ /// 5 │ 400 │ -/// └──────────────────┘ +/// └──────────────────┘ /// ``` /// /// ## Existence Joins (Semi, Anti, Mark) @@ -202,10 +204,10 @@ use crate::{ /// 1 │ 100 │ 1 │ 500 │ /// ├──────────────────┤ ├──────────────────┤ /// 2 │ 200 │ 2 │ 200 │ -/// ├──────────────────┤ ├──────────────────┤ +/// ├──────────────────┤ ├──────────────────┤ /// 3 │ 200 │ 3 │ 300 │ /// ├──────────────────┤ └──────────────────┘ -/// 4 │ 300 │ ─┐ +/// 4 │ 300 │ ─┐ /// ├──────────────────┤ | We emit matches for row 4 - 5 /// 5 │ 400 │ ─┘ on the buffered side. /// └──────────────────┘ @@ -236,11 +238,11 @@ use crate::{ /// /// # Mark Join: /// Sorts the probe side, then computes the min/max range of the probe keys and scans the buffered side only -/// within that range. +/// within that range. /// Complexity: `O(|S| + scan(R[range]))`. /// /// ## Nested Loop Join -/// Compares every row from `S` with every row from `R`. +/// Compares every row from `S` with every row from `R`. /// Complexity: `O(|S| * |R|)`. /// /// ## Nested Loop Join @@ -273,13 +275,12 @@ pub struct PiecewiseMergeJoinExec { left_child_plan_required_order: LexOrdering, /// The right sort order, descending for `<`, `<=` operations + ascending for `>`, `>=` operations /// Unsorted for mark joins - #[expect(dead_code)] right_batch_required_orders: LexOrdering, /// This determines the sort order of all join columns used in sorting the stream and buffered execution plans. sort_options: SortOptions, /// Cache holding plan properties like equivalences, output partitioning etc. 
- cache: PlanProperties, + cache: Arc, /// Number of partitions to process num_partitions: usize, } @@ -373,7 +374,7 @@ impl PiecewiseMergeJoinExec { left_child_plan_required_order, right_batch_required_orders, sort_options, - cache, + cache: Arc::new(cache), num_partitions, }) } @@ -466,6 +467,31 @@ impl PiecewiseMergeJoinExec { pub fn swap_inputs(&self) -> Result> { todo!() } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + let buffered = children.swap_remove(0); + let streamed = children.swap_remove(0); + Self { + buffered, + streamed, + on: self.on.clone(), + operator: self.operator, + join_type: self.join_type, + schema: Arc::clone(&self.schema), + left_child_plan_required_order: self.left_child_plan_required_order.clone(), + right_batch_required_orders: self.right_batch_required_orders.clone(), + sort_options: self.sort_options, + cache: Arc::clone(&self.cache), + num_partitions: self.num_partitions, + + // Re-set state. + metrics: ExecutionPlanMetricsSet::new(), + buffered_fut: Default::default(), + } + } } impl ExecutionPlan for PiecewiseMergeJoinExec { @@ -477,7 +503,7 @@ impl ExecutionPlan for PiecewiseMergeJoinExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -511,6 +537,7 @@ impl ExecutionPlan for PiecewiseMergeJoinExec { self: Arc, children: Vec>, ) -> Result> { + check_if_same_properties!(self, children); match &children[..] 
{ [left, right] => Ok(Arc::new(PiecewiseMergeJoinExec::try_new( Arc::clone(left), @@ -527,6 +554,13 @@ impl ExecutionPlan for PiecewiseMergeJoinExec { } } + fn reset_state(self: Arc) -> Result> { + Ok(Arc::new(self.with_new_children_and_same_properties(vec![ + Arc::clone(&self.buffered), + Arc::clone(&self.streamed), + ]))) + } + fn execute( &self, partition: usize, @@ -620,7 +654,7 @@ async fn build_buffered_data( // Combine batches and record number of rows let initial = (Vec::new(), 0, metrics, reservation); - let (batches, num_rows, metrics, mut reservation) = buffered + let (batches, num_rows, metrics, reservation) = buffered .try_fold(initial, |mut acc, batch| async { let batch_size = get_record_batch_memory_size(&batch); acc.3.try_grow(batch_size)?; diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs b/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs index ae7a5fa764bc..b34e811f9192 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs @@ -39,7 +39,7 @@ use crate::projection::{ }; use crate::{ DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, - PlanProperties, SendableRecordBatchStream, Statistics, + PlanProperties, SendableRecordBatchStream, Statistics, check_if_same_properties, }; use arrow::compute::SortOptions; @@ -127,7 +127,7 @@ pub struct SortMergeJoinExec { /// Defines the null equality for the join. pub null_equality: NullEquality, /// Cache holding plan properties like equivalences, output partitioning etc. 
- cache: PlanProperties, + cache: Arc, } impl SortMergeJoinExec { @@ -198,7 +198,7 @@ impl SortMergeJoinExec { right_sort_exprs, sort_options, null_equality, - cache, + cache: Arc::new(cache), }) } @@ -340,6 +340,20 @@ impl SortMergeJoinExec { reorder_output_after_swap(Arc::new(new_join), &left.schema(), &right.schema()) } } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + let left = children.swap_remove(0); + let right = children.swap_remove(0); + Self { + left, + right, + metrics: ExecutionPlanMetricsSet::new(), + ..Self::clone(self) + } + } } impl DisplayAs for SortMergeJoinExec { @@ -353,7 +367,7 @@ impl DisplayAs for SortMergeJoinExec { .collect::>() .join(", "); let display_null_equality = - if matches!(self.null_equality(), NullEquality::NullEqualsNull) { + if self.null_equality() == NullEquality::NullEqualsNull { ", NullsEqual: true" } else { "" @@ -386,7 +400,7 @@ impl DisplayAs for SortMergeJoinExec { } writeln!(f, "on={on}")?; - if matches!(self.null_equality(), NullEquality::NullEqualsNull) { + if self.null_equality() == NullEquality::NullEqualsNull { writeln!(f, "NullsEqual: true")?; } @@ -405,7 +419,7 @@ impl ExecutionPlan for SortMergeJoinExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -440,6 +454,7 @@ impl ExecutionPlan for SortMergeJoinExec { self: Arc, children: Vec>, ) -> Result> { + check_if_same_properties!(self, children); match &children[..] { [left, right] => Ok(Arc::new(SortMergeJoinExec::try_new( Arc::clone(left), @@ -519,10 +534,6 @@ impl ExecutionPlan for SortMergeJoinExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { // SortMergeJoinExec uses symmetric hash partitioning where both left and right // inputs are hash-partitioned on the join keys. 
This means partition `i` of the diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/filter.rs b/datafusion/physical-plan/src/joins/sort_merge_join/filter.rs new file mode 100644 index 000000000000..d598442b653e --- /dev/null +++ b/datafusion/physical-plan/src/joins/sort_merge_join/filter.rs @@ -0,0 +1,595 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Filter handling for Sort-Merge Join +//! +//! This module encapsulates the complexity of join filter evaluation, including: +//! - Immediate filtering for INNER joins +//! - Deferred filtering for outer/semi/anti/mark joins +//! - Metadata tracking for grouping output rows by input row +//! 
- Correcting filter masks to handle multiple matches per input row + +use std::sync::Arc; + +use arrow::array::{ + Array, ArrayBuilder, ArrayRef, BooleanArray, BooleanBuilder, RecordBatch, + UInt64Array, UInt64Builder, +}; +use arrow::compute::{self, concat_batches, filter_record_batch}; +use arrow::datatypes::SchemaRef; +use datafusion_common::{JoinSide, JoinType, Result}; + +use crate::joins::utils::JoinFilter; + +/// Metadata for tracking filter results during deferred filtering +/// +/// When a join filter is present and we need to ensure each input row produces +/// at least one output (outer joins) or exactly one output (semi joins), we can't +/// filter immediately. Instead, we accumulate all joined rows with metadata, +/// then post-process to determine which rows to output. +#[derive(Debug)] +pub struct FilterMetadata { + /// Did each output row pass the join filter? + /// Used to detect if an input row found ANY match + pub filter_mask: BooleanBuilder, + + /// Which input row (within batch) produced each output row? + /// Used for grouping output rows by input row + pub row_indices: UInt64Builder, + + /// Which input batch did each output row come from? 
+ /// Used to disambiguate row_indices across multiple batches + pub batch_ids: Vec, +} + +impl FilterMetadata { + /// Create new empty filter metadata + pub fn new() -> Self { + Self { + filter_mask: BooleanBuilder::new(), + row_indices: UInt64Builder::new(), + batch_ids: vec![], + } + } + + /// Returns (row_indices, filter_mask, batch_ids_ref) and clears builders + pub fn finish_metadata(&mut self) -> (UInt64Array, BooleanArray, &[usize]) { + let row_indices = self.row_indices.finish(); + let filter_mask = self.filter_mask.finish(); + (row_indices, filter_mask, &self.batch_ids) + } + + /// Add metadata for null-joined rows (no filter applied) + pub fn append_nulls(&mut self, num_rows: usize) { + self.filter_mask.append_nulls(num_rows); + self.row_indices.append_nulls(num_rows); + self.batch_ids.resize( + self.batch_ids.len() + num_rows, + 0, // batch_id = 0 for null-joined rows + ); + } + + /// Add metadata for filtered rows + pub fn append_filter_metadata( + &mut self, + row_indices: &UInt64Array, + filter_mask: &BooleanArray, + batch_id: usize, + ) { + debug_assert_eq!( + row_indices.len(), + filter_mask.len(), + "row_indices and filter_mask must have same length" + ); + + for i in 0..row_indices.len() { + if filter_mask.is_null(i) { + self.filter_mask.append_null(); + } else if filter_mask.value(i) { + self.filter_mask.append_value(true); + } else { + self.filter_mask.append_value(false); + } + + if row_indices.is_null(i) { + self.row_indices.append_null(); + } else { + self.row_indices.append_value(row_indices.value(i)); + } + + self.batch_ids.push(batch_id); + } + } + + /// Verify that metadata arrays are aligned (same length) + pub fn debug_assert_metadata_aligned(&self) { + if self.filter_mask.len() > 0 { + debug_assert_eq!( + self.filter_mask.len(), + self.row_indices.len(), + "filter_mask and row_indices must have same length when metadata is used" + ); + debug_assert_eq!( + self.filter_mask.len(), + self.batch_ids.len(), + "filter_mask and batch_ids 
must have same length when metadata is used" + ); + } else { + debug_assert_eq!( + self.filter_mask.len(), + 0, + "filter_mask should be empty when batches is empty" + ); + } + } +} + +impl Default for FilterMetadata { + fn default() -> Self { + Self::new() + } +} + +/// Determines if a join type needs deferred filtering +/// +/// Deferred filtering is required when: +/// - A filter exists AND +/// - The join type requires ensuring each input row produces at least one output +/// (or exactly one for semi joins) +pub fn needs_deferred_filtering( + filter: &Option, + join_type: JoinType, +) -> bool { + filter.is_some() + && matches!( + join_type, + JoinType::Left + | JoinType::LeftSemi + | JoinType::LeftMark + | JoinType::Right + | JoinType::RightSemi + | JoinType::RightMark + | JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::Full + ) +} + +/// Gets the arrays which join filters are applied on +/// +/// Extracts the columns needed for filter evaluation from left and right batch columns +pub fn get_filter_columns( + join_filter: &Option, + left_columns: &[ArrayRef], + right_columns: &[ArrayRef], +) -> Vec { + let mut filter_columns = vec![]; + + if let Some(f) = join_filter { + let left_columns: Vec = f + .column_indices() + .iter() + .filter(|col_index| col_index.side == JoinSide::Left) + .map(|i| Arc::clone(&left_columns[i.index])) + .collect(); + let right_columns: Vec = f + .column_indices() + .iter() + .filter(|col_index| col_index.side == JoinSide::Right) + .map(|i| Arc::clone(&right_columns[i.index])) + .collect(); + + filter_columns.extend(left_columns); + filter_columns.extend(right_columns); + } + + filter_columns +} + +/// Determines if current index is the last occurrence of a row +/// +/// Used during filter mask correction to detect row boundaries when grouping +/// output rows by input row. 
+fn last_index_for_row( + row_index: usize, + indices: &UInt64Array, + batch_ids: &[usize], + indices_len: usize, +) -> bool { + debug_assert_eq!( + indices.len(), + indices_len, + "indices.len() should match indices_len parameter" + ); + debug_assert_eq!( + batch_ids.len(), + indices_len, + "batch_ids.len() should match indices_len" + ); + debug_assert!( + row_index < indices_len, + "row_index {row_index} should be < indices_len {indices_len}", + ); + + // If this is the last index overall, it's definitely the last for this row + if row_index == indices_len - 1 { + return true; + } + + // Check if next row has different (batch_id, index) pair + let current_batch_id = batch_ids[row_index]; + let next_batch_id = batch_ids[row_index + 1]; + + if current_batch_id != next_batch_id { + return true; + } + + // Same batch_id, check if row index is different + // Both current and next should be non-null (already joined rows) + if indices.is_null(row_index) || indices.is_null(row_index + 1) { + return true; + } + + indices.value(row_index) != indices.value(row_index + 1) +} + +/// Corrects the filter mask for joins with deferred filtering +/// +/// When an input row joins with multiple buffered rows, we get multiple output rows. 
+/// This function groups them by input row and applies join-type-specific logic: +/// +/// - **Outer joins**: Keep first matching row, convert rest to nulls, add null-joined for unmatched +/// - **Semi joins**: Keep first matching row, discard rest +/// - **Anti joins**: Keep row only if NO matches passed filter +/// - **Mark joins**: Like semi but first match only +/// +/// # Arguments +/// * `join_type` - The type of join being performed +/// * `row_indices` - Which input row produced each output row +/// * `batch_ids` - Which batch each output row came from +/// * `filter_mask` - Whether each output row passed the filter +/// * `expected_size` - Total number of input rows (for adding unmatched) +/// +/// # Returns +/// Corrected mask indicating which rows to include in final output: +/// - `true`: Include this row +/// - `false`: Convert to null-joined row (outer joins) or include as unmatched (anti joins) +/// - `null`: Discard this row +pub fn get_corrected_filter_mask( + join_type: JoinType, + row_indices: &UInt64Array, + batch_ids: &[usize], + filter_mask: &BooleanArray, + expected_size: usize, +) -> Option { + let row_indices_length = row_indices.len(); + let mut corrected_mask: BooleanBuilder = + BooleanBuilder::with_capacity(row_indices_length); + let mut seen_true = false; + + match join_type { + JoinType::Left | JoinType::Right => { + // For outer joins: Keep first matching row per input row, + // convert rest to nulls, add null-joined rows for unmatched + for i in 0..row_indices_length { + let last_index = + last_index_for_row(i, row_indices, batch_ids, row_indices_length); + if filter_mask.value(i) { + seen_true = true; + corrected_mask.append_value(true); + } else if seen_true || !filter_mask.value(i) && !last_index { + corrected_mask.append_null(); // to be ignored and not set to output + } else { + corrected_mask.append_value(false); // to be converted to null joined row + } + + if last_index { + seen_true = false; + } + } + + // Generate null 
joined rows for records which have no matching join key + corrected_mask.append_n(expected_size - corrected_mask.len(), false); + Some(corrected_mask.finish()) + } + JoinType::LeftMark | JoinType::RightMark => { + // For mark joins: Like outer but only keep first match, mark with boolean + for i in 0..row_indices_length { + let last_index = + last_index_for_row(i, row_indices, batch_ids, row_indices_length); + if filter_mask.value(i) && !seen_true { + seen_true = true; + corrected_mask.append_value(true); + } else if seen_true || !filter_mask.value(i) && !last_index { + corrected_mask.append_null(); // to be ignored and not set to output + } else { + corrected_mask.append_value(false); // to be converted to null joined row + } + + if last_index { + seen_true = false; + } + } + + // Generate null joined rows for records which have no matching join key + corrected_mask.append_n(expected_size - corrected_mask.len(), false); + Some(corrected_mask.finish()) + } + JoinType::LeftSemi | JoinType::RightSemi => { + // For semi joins: Keep only first matching row per input row, discard rest + for i in 0..row_indices_length { + let last_index = + last_index_for_row(i, row_indices, batch_ids, row_indices_length); + if filter_mask.value(i) && !seen_true { + seen_true = true; + corrected_mask.append_value(true); + } else { + corrected_mask.append_null(); // to be ignored and not set to output + } + + if last_index { + seen_true = false; + } + } + + Some(corrected_mask.finish()) + } + JoinType::LeftAnti | JoinType::RightAnti => { + // For anti joins: Keep row only if NO matches passed the filter + for i in 0..row_indices_length { + let last_index = + last_index_for_row(i, row_indices, batch_ids, row_indices_length); + + if filter_mask.value(i) { + seen_true = true; + } + + if last_index { + if !seen_true { + corrected_mask.append_value(true); + } else { + corrected_mask.append_null(); + } + + seen_true = false; + } else { + corrected_mask.append_null(); + } + } + // Generate null 
joined rows for records which have no matching join key, + // for LeftAnti non-matched considered as true + corrected_mask.append_n(expected_size - corrected_mask.len(), true); + Some(corrected_mask.finish()) + } + JoinType::Full => { + // For full joins: Similar to outer but handle both sides + for i in 0..row_indices_length { + let last_index = + last_index_for_row(i, row_indices, batch_ids, row_indices_length); + + if filter_mask.is_null(i) { + // null joined + corrected_mask.append_value(true); + } else if filter_mask.value(i) { + seen_true = true; + corrected_mask.append_value(true); + } else if seen_true || !filter_mask.value(i) && !last_index { + corrected_mask.append_null(); // to be ignored and not set to output + } else { + corrected_mask.append_value(false); // to be converted to null joined row + } + + if last_index { + seen_true = false; + } + } + // Generate null joined rows for records which have no matching join key + corrected_mask.append_n(expected_size - corrected_mask.len(), false); + Some(corrected_mask.finish()) + } + JoinType::Inner => { + // Inner joins don't need deferred filtering + None + } + } +} + +/// Applies corrected filter mask to record batch based on join type +/// +/// Different join types require different handling of filtered results: +/// - Outer joins: Add null-joined rows for false mask values +/// - Semi/Anti joins: May need projection to remove right columns +/// - Full joins: Add null-joined rows for both sides +pub fn filter_record_batch_by_join_type( + record_batch: &RecordBatch, + corrected_mask: &BooleanArray, + join_type: JoinType, + schema: &SchemaRef, + streamed_schema: &SchemaRef, + buffered_schema: &SchemaRef, +) -> Result { + let filtered_record_batch = filter_record_batch(record_batch, corrected_mask)?; + + match join_type { + JoinType::Left | JoinType::LeftMark => { + // For left joins, add null-joined rows where mask is false + let null_mask = compute::not(corrected_mask)?; + let null_joined_batch = 
filter_record_batch(record_batch, &null_mask)?; + + if null_joined_batch.num_rows() == 0 { + return Ok(filtered_record_batch); + } + + // Create null columns for right side + let null_joined_streamed_batch = create_null_joined_batch( + &null_joined_batch, + buffered_schema, + JoinSide::Left, + join_type, + schema, + )?; + + Ok(concat_batches( + schema, + &[filtered_record_batch, null_joined_streamed_batch], + )?) + } + JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::RightSemi + | JoinType::RightAnti => { + // For semi/anti joins, project to only include the outer side columns + // Both Left and Right semi/anti use streamed_schema.len() because: + // - For Left: columns are [left, right], so we take first streamed_schema.len() + // - For Right: columns are [right, left], and streamed side is right, so we take first streamed_schema.len() + let output_column_indices: Vec = + (0..streamed_schema.fields().len()).collect(); + Ok(filtered_record_batch.project(&output_column_indices)?) + } + JoinType::Right | JoinType::RightMark => { + // For right joins, add null-joined rows where mask is false + let null_mask = compute::not(corrected_mask)?; + let null_joined_batch = filter_record_batch(record_batch, &null_mask)?; + + if null_joined_batch.num_rows() == 0 { + return Ok(filtered_record_batch); + } + + // Create null columns for left side (buffered side for RIGHT join) + let null_joined_buffered_batch = create_null_joined_batch( + &null_joined_batch, + buffered_schema, // Pass buffered (left) schema to create nulls for it + JoinSide::Right, + join_type, + schema, + )?; + + Ok(concat_batches( + schema, + &[filtered_record_batch, null_joined_buffered_batch], + )?) 
+ } + JoinType::Full => { + // For full joins, add null-joined rows for both sides + let joined_filter_not_matched_mask = compute::not(corrected_mask)?; + let joined_filter_not_matched_batch = + filter_record_batch(record_batch, &joined_filter_not_matched_mask)?; + + if joined_filter_not_matched_batch.num_rows() == 0 { + return Ok(filtered_record_batch); + } + + // Create null-joined batches for both sides + let left_null_joined_batch = create_null_joined_batch( + &joined_filter_not_matched_batch, + buffered_schema, + JoinSide::Left, + join_type, + schema, + )?; + + Ok(concat_batches( + schema, + &[filtered_record_batch, left_null_joined_batch], + )?) + } + JoinType::Inner => Ok(filtered_record_batch), + } +} + +/// Creates a batch with null columns for the non-joined side +/// +/// Note: The input `batch` is assumed to be a fully-joined batch that already contains +/// columns from both sides. We need to extract the data side columns and replace the +/// null side columns with actual nulls. +fn create_null_joined_batch( + batch: &RecordBatch, + null_schema: &SchemaRef, + join_side: JoinSide, + join_type: JoinType, + output_schema: &SchemaRef, +) -> Result { + let num_rows = batch.num_rows(); + + // The input batch is a fully-joined batch [left_cols..., right_cols...] 
+ // We need to extract the appropriate side and replace the other with nulls (or mark column) + let columns = match (join_side, join_type) { + (JoinSide::Left, JoinType::LeftMark) => { + // For LEFT mark: output is [left_cols..., mark_col] + // Batch is [left_cols..., right_cols...], extract left from beginning + // Number of left columns = output columns - 1 (mark column) + let left_col_count = output_schema.fields().len() - 1; + let mut result: Vec = batch.columns()[..left_col_count].to_vec(); + result.push(Arc::new(BooleanArray::from(vec![false; num_rows])) as ArrayRef); + result + } + (JoinSide::Right, JoinType::RightMark) => { + // For RIGHT mark: output is [right_cols..., mark_col] + // For RIGHT joins, batch is [right_cols..., left_cols...] (right comes first!) + // Extract right columns from the beginning + let right_col_count = output_schema.fields().len() - 1; // -1 for mark column + let mut result: Vec = batch.columns()[..right_col_count].to_vec(); + result.push(Arc::new(BooleanArray::from(vec![false; num_rows])) as ArrayRef); + result + } + (JoinSide::Left, _) => { + // For LEFT join: output is [left_cols..., right_cols...] + // Extract left columns, then add null right columns + let null_columns: Vec = null_schema + .fields() + .iter() + .map(|field| arrow::array::new_null_array(field.data_type(), num_rows)) + .collect(); + let left_col_count = output_schema.fields().len() - null_columns.len(); + let mut result: Vec = batch.columns()[..left_col_count].to_vec(); + result.extend(null_columns); + result + } + (JoinSide::Right, _) => { + // For RIGHT join: batch is [left_cols..., right_cols...] (same as schema) + // We want: [null_left..., actual_right...] 
+ // Extract left columns from beginning, replace with nulls, keep right columns + let null_columns: Vec = null_schema + .fields() + .iter() + .map(|field| arrow::array::new_null_array(field.data_type(), num_rows)) + .collect(); + let left_col_count = null_columns.len(); + let mut result = null_columns; + // Extract right columns starting after left columns + result.extend_from_slice(&batch.columns()[left_col_count..]); + result + } + (JoinSide::None, _) => { + // This should not happen in normal join operations + unreachable!( + "JoinSide::None should not be used in null-joined batch creation" + ) + } + }; + + // Create the batch - don't validate nullability since outer joins can have + // null values in columns that were originally non-nullable + use arrow::array::RecordBatchOptions; + let mut options = RecordBatchOptions::new(); + options = options.with_row_count(Some(num_rows)); + Ok(RecordBatch::try_new_with_options( + Arc::clone(output_schema), + columns, + &options, + )?) +} diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/mod.rs b/datafusion/physical-plan/src/joins/sort_merge_join/mod.rs index 82f18e741409..06290ec4d090 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/mod.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/mod.rs @@ -20,6 +20,7 @@ pub use exec::SortMergeJoinExec; mod exec; +mod filter; mod metrics; mod stream; diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs b/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs index b36992caf4b4..4dcbe1f64799 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs @@ -33,6 +33,10 @@ use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering::Relaxed; use std::task::{Context, Poll}; +use crate::joins::sort_merge_join::filter::{ + FilterMetadata, filter_record_batch_by_join_type, get_corrected_filter_mask, + get_filter_columns, 
needs_deferred_filtering, +}; use crate::joins::sort_merge_join::metrics::SortMergeJoinMetrics; use crate::joins::utils::{JoinFilter, compare_join_arrays}; use crate::metrics::RecordOutput; @@ -42,15 +46,13 @@ use crate::{PhysicalExpr, RecordBatchStream, SendableRecordBatchStream}; use arrow::array::{types::UInt64Type, *}; use arrow::compute::{ self, BatchCoalescer, SortOptions, concat_batches, filter_record_batch, is_not_null, - take, + take, take_arrays, }; use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; -use arrow::error::ArrowError; use arrow::ipc::reader::StreamReader; use datafusion_common::config::SpillCompression; use datafusion_common::{ - DataFusionError, HashSet, JoinSide, JoinType, NullEquality, Result, exec_err, - internal_err, not_impl_err, + HashSet, JoinType, NullEquality, Result, exec_err, internal_err, not_impl_err, }; use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::MemoryReservation; @@ -68,6 +70,8 @@ pub(super) enum SortMergeJoinState { Polling, /// Joining polled data and making output JoinOutput, + /// Emit ready data if have any and then go back to [`Self::Init`] state + EmitReadyThenInit, /// No more output Exhausted, } @@ -124,6 +128,8 @@ pub(super) struct StreamedBatch { pub join_arrays: Vec, /// Chunks of indices from buffered side (may be nulls) joined to streamed pub output_indices: Vec, + /// Total number of output rows across all chunks in `output_indices` + pub num_output_rows: usize, /// Index of currently scanned batch from buffered data pub buffered_batch_idx: Option, /// Indices that found a match for the given join filter @@ -140,6 +146,7 @@ impl StreamedBatch { idx: 0, join_arrays, output_indices: vec![], + num_output_rows: 0, buffered_batch_idx: None, join_filter_matched_idxs: HashSet::new(), } @@ -151,6 +158,7 @@ impl StreamedBatch { idx: 0, join_arrays: vec![], output_indices: vec![], + num_output_rows: 0, buffered_batch_idx: None, join_filter_matched_idxs: 
HashSet::new(), } @@ -158,10 +166,7 @@ impl StreamedBatch { /// Number of unfrozen output pairs in this streamed batch fn num_output_rows(&self) -> usize { - self.output_indices - .iter() - .map(|chunk| chunk.streamed_indices.len()) - .sum() + self.num_output_rows } /// Appends new pair consisting of current streamed index and `buffered_idx` @@ -171,7 +176,6 @@ impl StreamedBatch { buffered_batch_idx: Option, buffered_idx: Option, batch_size: usize, - num_unfrozen_pairs: usize, ) { // If no current chunk exists or current chunk is not for current buffered batch, // create a new chunk @@ -179,12 +183,13 @@ impl StreamedBatch { { // Compute capacity only when creating a new chunk (infrequent operation). // The capacity is the remaining space to reach batch_size. - // This should always be >= 1 since we only call this when num_unfrozen_pairs < batch_size. + // This should always be >= 1 since we only call this when num_output_rows < batch_size. debug_assert!( - batch_size > num_unfrozen_pairs, - "batch_size ({batch_size}) must be > num_unfrozen_pairs ({num_unfrozen_pairs})" + batch_size > self.num_output_rows, + "batch_size ({batch_size}) must be > num_output_rows ({})", + self.num_output_rows ); - let capacity = batch_size - num_unfrozen_pairs; + let capacity = batch_size - self.num_output_rows; self.output_indices.push(StreamedJoinedChunk { buffered_batch_idx, streamed_indices: UInt64Builder::with_capacity(capacity), @@ -201,6 +206,7 @@ impl StreamedBatch { } else { current_chunk.buffered_indices.append_null(); } + self.num_output_rows += 1; } } @@ -370,12 +376,8 @@ pub(super) struct SortMergeJoinStream { pub(super) struct JoinedRecordBatches { /// Joined batches. Each batch is already joined columns from left and right sources pub(super) joined_batches: BatchCoalescer, - /// Did each output row pass the join filter? (detect if input row found any match) - pub(super) filter_mask: BooleanBuilder, - /// Which input row (within batch) produced each output row? 
(for grouping by input row) - pub(super) row_indices: UInt64Builder, - /// Which input batch did each output row come from? (disambiguate row_indices) - pub(super) batch_ids: Vec, + /// Filter metadata for deferred filtering + pub(super) filter_metadata: FilterMetadata, } impl JoinedRecordBatches { @@ -398,61 +400,28 @@ impl JoinedRecordBatches { } } - /// Finishes and returns the metadata arrays, clearing the builders - /// - /// Returns (row_indices, filter_mask, batch_ids_ref) - /// Note: batch_ids is returned as a reference since it's still needed in the struct - fn finish_metadata(&mut self) -> (UInt64Array, BooleanArray, &[usize]) { - let row_indices = self.row_indices.finish(); - let filter_mask = self.filter_mask.finish(); - (row_indices, filter_mask, &self.batch_ids) - } - /// Clears batches without touching metadata (for early return when no filtering needed) fn clear_batches(&mut self, schema: &SchemaRef, batch_size: usize) { self.joined_batches = BatchCoalescer::new(Arc::clone(schema), batch_size) .with_biggest_coalesce_batch_size(Option::from(batch_size / 2)); } - /// Asserts that internal metadata arrays are consistent with each other - /// Only checks if metadata is actually being used (i.e., not all empty) - #[inline] - fn debug_assert_metadata_aligned(&self) { - // Metadata arrays should be aligned IF they're being used - // (For non-filtered joins, they may all be empty) - if self.filter_mask.len() > 0 - || self.row_indices.len() > 0 - || !self.batch_ids.is_empty() - { - debug_assert_eq!( - self.filter_mask.len(), - self.row_indices.len(), - "filter_mask and row_indices must have same length when metadata is used" - ); - debug_assert_eq!( - self.filter_mask.len(), - self.batch_ids.len(), - "filter_mask and batch_ids must have same length when metadata is used" - ); - } - } - /// Asserts that if batches is empty, metadata is also empty #[inline] fn debug_assert_empty_consistency(&self) { if self.joined_batches.is_empty() { debug_assert_eq!( - 
self.filter_mask.len(), + self.filter_metadata.filter_mask.len(), 0, "filter_mask should be empty when batches is empty" ); debug_assert_eq!( - self.row_indices.len(), + self.filter_metadata.row_indices.len(), 0, "row_indices should be empty when batches is empty" ); debug_assert_eq!( - self.batch_ids.len(), + self.filter_metadata.batch_ids.len(), 0, "batch_ids should be empty when batches is empty" ); @@ -467,20 +436,15 @@ impl JoinedRecordBatches { /// Maintains invariant: N rows → N metadata entries (nulls) fn push_batch_with_null_metadata(&mut self, batch: RecordBatch, join_type: JoinType) { debug_assert!( - matches!(join_type, JoinType::Full), + join_type == JoinType::Full, "push_batch_with_null_metadata should only be called for Full joins" ); let num_rows = batch.num_rows(); - self.filter_mask.append_nulls(num_rows); - self.row_indices.append_nulls(num_rows); - self.batch_ids.resize( - self.batch_ids.len() + num_rows, - 0, // batch_id = 0 for null-joined rows - ); + self.filter_metadata.append_nulls(num_rows); - self.debug_assert_metadata_aligned(); + self.filter_metadata.debug_assert_metadata_aligned(); self.joined_batches .push_batch(batch) .expect("Failed to push batch to BatchCoalescer"); @@ -525,13 +489,13 @@ impl JoinedRecordBatches { "row_indices and filter_mask must have same length" ); - // For Full joins, we keep the pre_mask (with nulls), for others we keep the cleaned mask - self.filter_mask.extend(filter_mask); - self.row_indices.extend(row_indices); - self.batch_ids - .resize(self.batch_ids.len() + row_indices.len(), streamed_batch_id); + self.filter_metadata.append_filter_metadata( + row_indices, + filter_mask, + streamed_batch_id, + ); - self.debug_assert_metadata_aligned(); + self.filter_metadata.debug_assert_metadata_aligned(); self.joined_batches .push_batch(batch) .expect("Failed to push batch to BatchCoalescer"); @@ -551,9 +515,7 @@ impl JoinedRecordBatches { fn clear(&mut self, schema: &SchemaRef, batch_size: usize) { 
self.joined_batches = BatchCoalescer::new(Arc::clone(schema), batch_size) .with_biggest_coalesce_batch_size(Option::from(batch_size / 2)); - self.batch_ids.clear(); - self.filter_mask = BooleanBuilder::new(); - self.row_indices = UInt64Builder::new(); + self.filter_metadata = FilterMetadata::new(); self.debug_assert_empty_consistency(); } } @@ -563,199 +525,6 @@ impl RecordBatchStream for SortMergeJoinStream { } } -/// True if next index refers to either: -/// - another batch id -/// - another row index within same batch id -/// - end of row indices -#[inline(always)] -fn last_index_for_row( - row_index: usize, - indices: &UInt64Array, - batch_ids: &[usize], - indices_len: usize, -) -> bool { - debug_assert_eq!( - indices.len(), - indices_len, - "indices.len() should match indices_len parameter" - ); - debug_assert_eq!( - batch_ids.len(), - indices_len, - "batch_ids.len() should match indices_len" - ); - debug_assert!( - row_index < indices_len, - "row_index {row_index} should be < indices_len {indices_len}", - ); - - row_index == indices_len - 1 - || batch_ids[row_index] != batch_ids[row_index + 1] - || indices.value(row_index) != indices.value(row_index + 1) -} - -// Returns a corrected boolean bitmask for the given join type -// Values in the corrected bitmask can be: true, false, null -// `true` - the row found its match and sent to the output -// `null` - the row ignored, no output -// `false` - the row sent as NULL joined row -pub(super) fn get_corrected_filter_mask( - join_type: JoinType, - row_indices: &UInt64Array, - batch_ids: &[usize], - filter_mask: &BooleanArray, - expected_size: usize, -) -> Option { - let row_indices_length = row_indices.len(); - let mut corrected_mask: BooleanBuilder = - BooleanBuilder::with_capacity(row_indices_length); - let mut seen_true = false; - - match join_type { - JoinType::Left | JoinType::Right => { - for i in 0..row_indices_length { - let last_index = - last_index_for_row(i, row_indices, batch_ids, row_indices_length); - 
if filter_mask.value(i) { - seen_true = true; - corrected_mask.append_value(true); - } else if seen_true || !filter_mask.value(i) && !last_index { - corrected_mask.append_null(); // to be ignored and not set to output - } else { - corrected_mask.append_value(false); // to be converted to null joined row - } - - if last_index { - seen_true = false; - } - } - - // Generate null joined rows for records which have no matching join key - corrected_mask.append_n(expected_size - corrected_mask.len(), false); - Some(corrected_mask.finish()) - } - JoinType::LeftMark | JoinType::RightMark => { - for i in 0..row_indices_length { - let last_index = - last_index_for_row(i, row_indices, batch_ids, row_indices_length); - if filter_mask.value(i) && !seen_true { - seen_true = true; - corrected_mask.append_value(true); - } else if seen_true || !filter_mask.value(i) && !last_index { - corrected_mask.append_null(); // to be ignored and not set to output - } else { - corrected_mask.append_value(false); // to be converted to null joined row - } - - if last_index { - seen_true = false; - } - } - - // Generate null joined rows for records which have no matching join key - corrected_mask.append_n(expected_size - corrected_mask.len(), false); - Some(corrected_mask.finish()) - } - JoinType::LeftSemi | JoinType::RightSemi => { - for i in 0..row_indices_length { - let last_index = - last_index_for_row(i, row_indices, batch_ids, row_indices_length); - if filter_mask.value(i) && !seen_true { - seen_true = true; - corrected_mask.append_value(true); - } else { - corrected_mask.append_null(); // to be ignored and not set to output - } - - if last_index { - seen_true = false; - } - } - - Some(corrected_mask.finish()) - } - JoinType::LeftAnti | JoinType::RightAnti => { - for i in 0..row_indices_length { - let last_index = - last_index_for_row(i, row_indices, batch_ids, row_indices_length); - - if filter_mask.value(i) { - seen_true = true; - } - - if last_index { - if !seen_true { - 
corrected_mask.append_value(true); - } else { - corrected_mask.append_null(); - } - - seen_true = false; - } else { - corrected_mask.append_null(); - } - } - // Generate null joined rows for records which have no matching join key, - // for LeftAnti non-matched considered as true - corrected_mask.append_n(expected_size - corrected_mask.len(), true); - Some(corrected_mask.finish()) - } - JoinType::Full => { - let mut mask: Vec> = vec![Some(true); row_indices_length]; - let mut last_true_idx = 0; - let mut first_row_idx = 0; - let mut seen_false = false; - - for i in 0..row_indices_length { - let last_index = - last_index_for_row(i, row_indices, batch_ids, row_indices_length); - let val = filter_mask.value(i); - let is_null = filter_mask.is_null(i); - - if val { - // memoize the first seen matched row - if !seen_true { - last_true_idx = i; - } - seen_true = true; - } - - if is_null || val { - mask[i] = Some(true); - } else if !is_null && !val && (seen_true || seen_false) { - mask[i] = None; - } else { - mask[i] = Some(false); - } - - if !is_null && !val { - seen_false = true; - } - - if last_index { - // If the left row seen as true its needed to output it once - // To do that we mark all other matches for same row as null to avoid the output - if seen_true { - #[expect(clippy::needless_range_loop)] - for j in first_row_idx..last_true_idx { - mask[j] = None; - } - } - - seen_true = false; - seen_false = false; - last_true_idx = 0; - first_row_idx = i + 1; - } - } - - Some(BooleanArray::from(mask)) - } - // Only outer joins needs to keep track of processed rows and apply corrected filter mask - _ => None, - } -} - impl Stream for SortMergeJoinStream { type Item = Result; @@ -778,7 +547,10 @@ impl Stream for SortMergeJoinStream { match self.current_ordering { Ordering::Less | Ordering::Equal => { if !streamed_exhausted { - if self.needs_deferred_filtering() { + if needs_deferred_filtering( + &self.filter, + self.join_type, + ) { match self.process_filtered_batches()? 
{ Poll::Ready(Some(batch)) => { return Poll::Ready(Some(Ok(batch))); @@ -830,22 +602,56 @@ impl Stream for SortMergeJoinStream { self.current_ordering = self.compare_streamed_buffered()?; self.state = SortMergeJoinState::JoinOutput; } + SortMergeJoinState::EmitReadyThenInit => { + // If have data to emit, emit it and if no more, change to next + + // Verify metadata alignment before checking if we have batches to output + self.joined_record_batches + .filter_metadata + .debug_assert_metadata_aligned(); + + // For filtered joins, skip output and let Init state handle it + if needs_deferred_filtering(&self.filter, self.join_type) { + self.state = SortMergeJoinState::Init; + continue; + } + + // For non-filtered joins, only output if we have a completed batch + // (opportunistic output when target batch size is reached) + if self + .joined_record_batches + .joined_batches + .has_completed_batch() + { + let record_batch = self + .joined_record_batches + .joined_batches + .next_completed_batch() + .expect("has_completed_batch was true"); + (&record_batch) + .record_output(&self.join_metrics.baseline_metrics()); + return Poll::Ready(Some(Ok(record_batch))); + } + self.state = SortMergeJoinState::Init; + } SortMergeJoinState::JoinOutput => { self.join_partial()?; if self.num_unfrozen_pairs() < self.batch_size { if self.buffered_data.scanning_finished() { self.buffered_data.scanning_reset(); - self.state = SortMergeJoinState::Init; + self.state = SortMergeJoinState::EmitReadyThenInit; } } else { self.freeze_all()?; // Verify metadata alignment before checking if we have batches to output - self.joined_record_batches.debug_assert_metadata_aligned(); + self.joined_record_batches + .filter_metadata + .debug_assert_metadata_aligned(); // For filtered joins, skip output and let Init state handle it - if self.needs_deferred_filtering() { + if needs_deferred_filtering(&self.filter, self.join_type) { continue; } @@ -872,10 +678,12 @@ impl Stream for SortMergeJoinStream { 
self.freeze_all()?; // Verify metadata alignment before final output - self.joined_record_batches.debug_assert_metadata_aligned(); + self.joined_record_batches + .filter_metadata + .debug_assert_metadata_aligned(); // For filtered joins, must concat and filter ALL data at once - if self.needs_deferred_filtering() + if needs_deferred_filtering(&self.filter, self.join_type) && !self.joined_record_batches.joined_batches.is_empty() { let record_batch = self.filter_joined_batch()?; @@ -975,9 +783,7 @@ impl SortMergeJoinStream { joined_record_batches: JoinedRecordBatches { joined_batches: BatchCoalescer::new(Arc::clone(&schema), batch_size) .with_biggest_coalesce_batch_size(Option::from(batch_size / 2)), - filter_mask: BooleanBuilder::new(), - row_indices: UInt64Builder::new(), - batch_ids: vec![], + filter_metadata: FilterMetadata::new(), }, output: BatchCoalescer::new(schema, batch_size) .with_biggest_coalesce_batch_size(Option::from(batch_size / 2)), @@ -996,26 +802,6 @@ impl SortMergeJoinStream { self.streamed_batch.num_output_rows() } - /// Returns true if this join needs deferred filtering - /// - /// Deferred filtering is needed when a filter exists and the join type requires - /// ensuring each input row produces at least one output row (or exactly one for semi). - fn needs_deferred_filtering(&self) -> bool { - self.filter.is_some() - && matches!( - self.join_type, - JoinType::Left - | JoinType::LeftSemi - | JoinType::LeftMark - | JoinType::Right - | JoinType::RightSemi - | JoinType::RightMark - | JoinType::LeftAnti - | JoinType::RightAnti - | JoinType::Full - ) - } - /// Process accumulated batches for filtered joins /// /// Freezes unfrozen pairs, applies deferred filtering, and outputs if ready. 
@@ -1023,7 +809,9 @@ impl SortMergeJoinStream { fn process_filtered_batches(&mut self) -> Poll>> { self.freeze_all()?; - self.joined_record_batches.debug_assert_metadata_aligned(); + self.joined_record_batches + .filter_metadata + .debug_assert_metadata_aligned(); if !self.joined_record_batches.joined_batches.is_empty() { let out_filtered_batch = self.filter_joined_batch()?; @@ -1329,7 +1117,7 @@ impl SortMergeJoinStream { } } Ordering::Greater => { - if matches!(self.join_type, JoinType::Full) { + if self.join_type == JoinType::Full { join_buffered = !self.buffered_joined; }; } @@ -1348,13 +1136,10 @@ impl SortMergeJoinStream { let scanning_idx = self.buffered_data.scanning_idx(); if join_streamed { // Join streamed row and buffered row - // Pass batch_size and num_unfrozen_pairs to compute capacity only when - // creating a new chunk (when buffered_batch_idx changes), not on every iteration. self.streamed_batch.append_output_pair( Some(self.buffered_data.scanning_batch_idx), Some(scanning_idx), self.batch_size, - self.num_unfrozen_pairs(), ); } else { // Join nulls and buffered row for FULL join @@ -1380,13 +1165,10 @@ impl SortMergeJoinStream { // For Mark join we store a dummy id to indicate the row has a match let scanning_idx = mark_row_as_match.then_some(0); - // Pass batch_size=1 and num_unfrozen_pairs=0 to get capacity of 1, - // since we only append a single null-joined pair here (not in a loop). 
self.streamed_batch.append_output_pair( scanning_batch_idx, scanning_idx, - 1, - 0, + self.batch_size, ); self.buffered_data.scanning_finish(); self.streamed_joined = true; @@ -1399,7 +1181,9 @@ impl SortMergeJoinStream { self.freeze_streamed()?; // After freezing, metadata should be aligned - self.joined_record_batches.debug_assert_metadata_aligned(); + self.joined_record_batches + .filter_metadata + .debug_assert_metadata_aligned(); Ok(()) } @@ -1414,7 +1198,9 @@ impl SortMergeJoinStream { self.freeze_buffered(1)?; // After freezing, metadata should be aligned - self.joined_record_batches.debug_assert_metadata_aligned(); + self.joined_record_batches + .filter_metadata + .debug_assert_metadata_aligned(); Ok(()) } @@ -1425,7 +1211,7 @@ impl SortMergeJoinStream { // Applicable only in case of Full join. // fn freeze_buffered(&mut self, batch_count: usize) -> Result<()> { - if !matches!(self.join_type, JoinType::Full) { + if self.join_type != JoinType::Full { return Ok(()); } for buffered_batch in self.buffered_data.batches.range_mut(..batch_count) { @@ -1450,7 +1236,7 @@ impl SortMergeJoinStream { &mut self, buffered_batch: &mut BufferedBatch, ) -> Result<()> { - if !matches!(self.join_type, JoinType::Full) { + if self.join_type != JoinType::Full { return Ok(()); } @@ -1490,13 +1276,19 @@ impl SortMergeJoinStream { continue; } - let mut left_columns = self - .streamed_batch - .batch - .columns() - .iter() - .map(|column| take(column, &left_indices, None)) - .collect::, ArrowError>>()?; + let mut left_columns = if let Some(range) = is_contiguous_range(&left_indices) + { + // When indices form a contiguous range (common for the streamed + // side which advances sequentially), use zero-copy slice instead + // of the O(n) take kernel. + self.streamed_batch + .batch + .slice(range.start, range.len()) + .columns() + .to_vec() + } else { + take_arrays(self.streamed_batch.batch.columns(), &left_indices, None)? 
+ }; // The row indices of joined buffered batch let right_indices: UInt64Array = chunk.buffered_indices.finish(); @@ -1529,35 +1321,37 @@ impl SortMergeJoinStream { // Prepare the columns we apply join filter on later. // Only for joined rows between streamed and buffered. - let filter_columns = if chunk.buffered_batch_idx.is_some() { - if !matches!(self.join_type, JoinType::Right) { + let filter_columns = if let Some(buffered_batch_idx) = + chunk.buffered_batch_idx + { + if self.join_type != JoinType::Right { if matches!( self.join_type, JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark ) { let right_cols = fetch_right_columns_by_idxs( &self.buffered_data, - chunk.buffered_batch_idx.unwrap(), + buffered_batch_idx, &right_indices, )?; - get_filter_column(&self.filter, &left_columns, &right_cols) + get_filter_columns(&self.filter, &left_columns, &right_cols) } else if matches!( self.join_type, JoinType::RightAnti | JoinType::RightSemi | JoinType::RightMark ) { let right_cols = fetch_right_columns_by_idxs( &self.buffered_data, - chunk.buffered_batch_idx.unwrap(), + buffered_batch_idx, &right_indices, )?; - get_filter_column(&self.filter, &right_cols, &left_columns) + get_filter_columns(&self.filter, &right_cols, &left_columns) } else { - get_filter_column(&self.filter, &left_columns, &right_columns) + get_filter_columns(&self.filter, &left_columns, &right_columns) } } else { - get_filter_column(&self.filter, &right_columns, &left_columns) + get_filter_columns(&self.filter, &right_columns, &left_columns) } } else { // This chunk is totally for null joined rows (outer join), we don't need to apply join filter. 
@@ -1565,7 +1359,7 @@ impl SortMergeJoinStream { vec![] }; - let columns = if !matches!(self.join_type, JoinType::Right) { + let columns = if self.join_type != JoinType::Right { left_columns.extend(right_columns); left_columns } else { @@ -1618,7 +1412,7 @@ impl SortMergeJoinStream { if needs_deferred_filtering { // Outer/semi/anti/mark joins: push unfiltered batch with metadata for deferred filtering - let mask_to_use = if !matches!(self.join_type, JoinType::Full) { + let mask_to_use = if self.join_type != JoinType::Full { &mask } else { pre_mask @@ -1642,7 +1436,7 @@ impl SortMergeJoinStream { // all joined rows are failed on the join filter. // I.e., if all rows joined from a streamed row are failed with the join filter, // we need to join it with nulls as buffered side. - if matches!(self.join_type, JoinType::Full) { + if self.join_type == JoinType::Full { let buffered_batch = &mut self.buffered_data.batches [chunk.buffered_batch_idx.unwrap()]; @@ -1673,17 +1467,20 @@ impl SortMergeJoinStream { } self.streamed_batch.output_indices.clear(); + self.streamed_batch.num_output_rows = 0; Ok(()) } fn filter_joined_batch(&mut self) -> Result { // Metadata should be aligned before processing - self.joined_record_batches.debug_assert_metadata_aligned(); + self.joined_record_batches + .filter_metadata + .debug_assert_metadata_aligned(); let record_batch = self.joined_record_batches.concat_batches(&self.schema)?; let (mut out_indices, mut out_mask, mut batch_ids) = - self.joined_record_batches.finish_metadata(); + self.joined_record_batches.filter_metadata.finish_metadata(); let default_batch_ids = vec![0; record_batch.num_rows()]; // If only nulls come in and indices sizes doesn't match with expected record batch count @@ -1754,139 +1551,14 @@ impl SortMergeJoinStream { record_batch: &RecordBatch, corrected_mask: &BooleanArray, ) -> Result { - // Corrected mask should have length matching or exceeding record_batch rows - // (for outer joins it may be longer to include 
null-joined rows) - debug_assert!( - corrected_mask.len() >= record_batch.num_rows(), - "corrected_mask length ({}) should be >= record_batch rows ({})", - corrected_mask.len(), - record_batch.num_rows() - ); - - let mut filtered_record_batch = - filter_record_batch(record_batch, corrected_mask)?; - let left_columns_length = self.streamed_schema.fields.len(); - let right_columns_length = self.buffered_schema.fields.len(); - - if matches!( - self.join_type, - JoinType::Left | JoinType::LeftMark | JoinType::Right | JoinType::RightMark - ) { - let null_mask = compute::not(corrected_mask)?; - let null_joined_batch = filter_record_batch(record_batch, &null_mask)?; - - let mut right_columns = create_unmatched_columns( - self.join_type, - &self.buffered_schema, - null_joined_batch.num_rows(), - ); - - let columns = match self.join_type { - JoinType::Right => { - // The first columns are the right columns. - let left_columns = null_joined_batch - .columns() - .iter() - .skip(right_columns_length) - .cloned() - .collect::>(); - - right_columns.extend(left_columns); - right_columns - } - JoinType::Left | JoinType::LeftMark | JoinType::RightMark => { - // The first columns are the left columns. 
- let mut left_columns = null_joined_batch - .columns() - .iter() - .take(left_columns_length) - .cloned() - .collect::>(); - - left_columns.extend(right_columns); - left_columns - } - _ => exec_err!("Did not expect join type {}", self.join_type)?, - }; - - // Push the streamed/buffered batch joined nulls to the output - let null_joined_streamed_batch = - RecordBatch::try_new(Arc::clone(&self.schema), columns)?; - - filtered_record_batch = concat_batches( - &self.schema, - &[filtered_record_batch, null_joined_streamed_batch], - )?; - } else if matches!( + let filtered_record_batch = filter_record_batch_by_join_type( + record_batch, + corrected_mask, self.join_type, - JoinType::LeftSemi - | JoinType::LeftAnti - | JoinType::RightAnti - | JoinType::RightSemi - ) { - let output_column_indices = (0..left_columns_length).collect::>(); - filtered_record_batch = - filtered_record_batch.project(&output_column_indices)?; - } else if matches!(self.join_type, JoinType::Full) - && corrected_mask.false_count() > 0 - { - // Find rows which joined by key but Filter predicate evaluated as false - let joined_filter_not_matched_mask = compute::not(corrected_mask)?; - let joined_filter_not_matched_batch = - filter_record_batch(record_batch, &joined_filter_not_matched_mask)?; - - // Add left unmatched rows adding the right side as nulls - let right_null_columns = self - .buffered_schema - .fields() - .iter() - .map(|f| { - new_null_array( - f.data_type(), - joined_filter_not_matched_batch.num_rows(), - ) - }) - .collect::>(); - - let mut result_joined = joined_filter_not_matched_batch - .columns() - .iter() - .take(left_columns_length) - .cloned() - .collect::>(); - - result_joined.extend(right_null_columns); - - let left_null_joined_batch = - RecordBatch::try_new(Arc::clone(&self.schema), result_joined)?; - - // Add right unmatched rows adding the left side as nulls - let mut result_joined = self - .streamed_schema - .fields() - .iter() - .map(|f| { - new_null_array( - f.data_type(), 
- joined_filter_not_matched_batch.num_rows(), - ) - }) - .collect::>(); - - let right_data = joined_filter_not_matched_batch - .columns() - .iter() - .skip(left_columns_length) - .cloned() - .collect::>(); - - result_joined.extend(right_data); - - filtered_record_batch = concat_batches( - &self.schema, - &[filtered_record_batch, left_null_joined_batch], - )?; - } + &self.schema, + &self.streamed_schema, + &self.buffered_schema, + )?; self.joined_record_batches .clear(&self.schema, self.batch_size); @@ -1911,36 +1583,6 @@ fn create_unmatched_columns( } } -/// Gets the arrays which join filters are applied on. -fn get_filter_column( - join_filter: &Option, - streamed_columns: &[ArrayRef], - buffered_columns: &[ArrayRef], -) -> Vec { - let mut filter_columns = vec![]; - - if let Some(f) = join_filter { - let left_columns = f - .column_indices() - .iter() - .filter(|col_index| col_index.side == JoinSide::Left) - .map(|i| Arc::clone(&streamed_columns[i.index])) - .collect::>(); - - let right_columns = f - .column_indices() - .iter() - .filter(|col_index| col_index.side == JoinSide::Right) - .map(|i| Arc::clone(&buffered_columns[i.index])) - .collect::>(); - - filter_columns.extend(left_columns); - filter_columns.extend(right_columns); - } - - filter_columns -} - fn produce_buffered_null_batch( schema: &SchemaRef, streamed_schema: &SchemaRef, @@ -1970,6 +1612,30 @@ fn produce_buffered_null_batch( )?)) } +/// Checks if a `UInt64Array` contains a contiguous ascending range (e.g. \[3,4,5,6\]). +/// Returns `Some(start..start+len)` if so, `None` otherwise. +/// This allows replacing an O(n) `take` with an O(1) `slice`. 
+#[inline] +fn is_contiguous_range(indices: &UInt64Array) -> Option> { + if indices.is_empty() || indices.null_count() > 0 { + return None; + } + let values = indices.values(); + let start = values[0]; + let len = values.len() as u64; + // Quick rejection: if last element doesn't match expected, not contiguous + if values[values.len() - 1] != start + len - 1 { + return None; + } + // Verify every element is sequential (handles duplicates and gaps) + for i in 1..values.len() { + if values[i] != start + i as u64 { + return None; + } + } + Some(start as usize..(start + len) as usize) +} + /// Get `buffered_indices` rows for `buffered_data[buffered_batch_idx]` by specific column indices #[inline(always)] fn fetch_right_columns_by_idxs( @@ -1990,12 +1656,16 @@ fn fetch_right_columns_from_batch_by_idxs( ) -> Result> { match &buffered_batch.batch { // In memory batch - BufferedBatchState::InMemory(batch) => Ok(batch - .columns() - .iter() - .map(|column| take(column, &buffered_indices, None)) - .collect::, ArrowError>>() - .map_err(Into::::into)?), + // In memory batch + BufferedBatchState::InMemory(batch) => { + // When indices form a contiguous range (common in SMJ since the + // buffered side is scanned sequentially), use zero-copy slice. + if let Some(range) = is_contiguous_range(buffered_indices) { + Ok(batch.slice(range.start, range.len()).columns().to_vec()) + } else { + Ok(take_arrays(batch.columns(), buffered_indices, None)?) + } + } // If the batch was spilled to disk, less likely BufferedBatchState::Spilled(spill_file) => { let mut buffered_cols: Vec = diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs index d0bcc79636f7..b16ad59abc5b 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs @@ -24,42 +24,44 @@ //! //! Add relevant tests under the specified sections. 
-use std::sync::Arc; - +use crate::joins::utils::{ColumnIndex, JoinFilter, JoinOn}; +use crate::joins::{HashJoinExec, PartitionMode, SortMergeJoinExec}; +use crate::test::TestMemoryExec; +use crate::test::exec::BarrierExec; +use crate::test::{build_table_i32, build_table_i32_two_cols}; +use crate::{ExecutionPlan, common}; +use crate::{ + expressions::Column, joins::sort_merge_join::filter::get_corrected_filter_mask, + joins::sort_merge_join::stream::JoinedRecordBatches, +}; use arrow::array::{ BinaryArray, BooleanArray, Date32Array, Date64Array, FixedSizeBinaryArray, Int32Array, RecordBatch, UInt64Array, - builder::{BooleanBuilder, UInt64Builder}, }; use arrow::compute::{BatchCoalescer, SortOptions, filter_record_batch}; use arrow::datatypes::{DataType, Field, Schema}; - +use arrow_ord::sort::SortColumn; +use arrow_schema::SchemaRef; use datafusion_common::JoinType::*; use datafusion_common::{ - JoinSide, + JoinSide, internal_err, test_util::{batches_to_sort_string, batches_to_string}, }; use datafusion_common::{ JoinType, NullEquality, Result, assert_batches_eq, assert_contains, }; -use datafusion_execution::TaskContext; +use datafusion_common_runtime::JoinSet; use datafusion_execution::config::SessionConfig; use datafusion_execution::disk_manager::{DiskManagerBuilder, DiskManagerMode}; use datafusion_execution::runtime_env::RuntimeEnvBuilder; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::BinaryExpr; +use futures::StreamExt; use insta::{allow_duplicates, assert_snapshot}; - -use crate::{ - expressions::Column, - joins::sort_merge_join::stream::{JoinedRecordBatches, get_corrected_filter_mask}, -}; - -use crate::joins::SortMergeJoinExec; -use crate::joins::utils::{ColumnIndex, JoinFilter, JoinOn}; -use crate::test::TestMemoryExec; -use crate::test::{build_table_i32, build_table_i32_two_cols}; -use crate::{ExecutionPlan, common}; +use itertools::Itertools; +use 
std::sync::Arc; +use std::task::Poll; fn build_table( a: (&str, &Vec), @@ -2375,9 +2377,7 @@ fn build_joined_record_batches() -> Result { let mut batches = JoinedRecordBatches { joined_batches: BatchCoalescer::new(Arc::clone(&schema), 8192), - filter_mask: BooleanBuilder::new(), - row_indices: UInt64Builder::new(), - batch_ids: vec![], + filter_metadata: crate::joins::sort_merge_join::filter::FilterMetadata::new(), }; // Insert already prejoined non-filtered rows @@ -2432,44 +2432,73 @@ fn build_joined_record_batches() -> Result { )?)?; let streamed_indices = vec![0, 0]; - batches.batch_ids.extend(vec![0; streamed_indices.len()]); batches + .filter_metadata + .batch_ids + .extend(vec![0; streamed_indices.len()]); + batches + .filter_metadata .row_indices .extend(&UInt64Array::from(streamed_indices)); let streamed_indices = vec![1]; - batches.batch_ids.extend(vec![0; streamed_indices.len()]); batches + .filter_metadata + .batch_ids + .extend(vec![0; streamed_indices.len()]); + batches + .filter_metadata .row_indices .extend(&UInt64Array::from(streamed_indices)); let streamed_indices = vec![0, 0]; - batches.batch_ids.extend(vec![1; streamed_indices.len()]); batches + .filter_metadata + .batch_ids + .extend(vec![1; streamed_indices.len()]); + batches + .filter_metadata .row_indices .extend(&UInt64Array::from(streamed_indices)); let streamed_indices = vec![0]; - batches.batch_ids.extend(vec![2; streamed_indices.len()]); batches + .filter_metadata + .batch_ids + .extend(vec![2; streamed_indices.len()]); + batches + .filter_metadata .row_indices .extend(&UInt64Array::from(streamed_indices)); let streamed_indices = vec![0, 0]; - batches.batch_ids.extend(vec![3; streamed_indices.len()]); batches + .filter_metadata + .batch_ids + .extend(vec![3; streamed_indices.len()]); + batches + .filter_metadata .row_indices .extend(&UInt64Array::from(streamed_indices)); batches + .filter_metadata .filter_mask .extend(&BooleanArray::from(vec![true, false])); - 
batches.filter_mask.extend(&BooleanArray::from(vec![true])); batches + .filter_metadata + .filter_mask + .extend(&BooleanArray::from(vec![true])); + batches + .filter_metadata .filter_mask .extend(&BooleanArray::from(vec![false, true])); - batches.filter_mask.extend(&BooleanArray::from(vec![false])); batches + .filter_metadata + .filter_mask + .extend(&BooleanArray::from(vec![false])); + batches + .filter_metadata .filter_mask .extend(&BooleanArray::from(vec![false, false])); @@ -2482,8 +2511,8 @@ async fn test_left_outer_join_filtered_mask() -> Result<()> { let schema = joined_batches.joined_batches.schema(); let output = joined_batches.concat_batches(&schema)?; - let out_mask = joined_batches.filter_mask.finish(); - let out_indices = joined_batches.row_indices.finish(); + let out_mask = joined_batches.filter_metadata.filter_mask.finish(); + let out_indices = joined_batches.filter_metadata.row_indices.finish(); assert_eq!( get_corrected_filter_mask( @@ -2620,7 +2649,7 @@ async fn test_left_outer_join_filtered_mask() -> Result<()> { let corrected_mask = get_corrected_filter_mask( Left, &out_indices, - &joined_batches.batch_ids, + &joined_batches.filter_metadata.batch_ids, &out_mask, output.num_rows(), ) @@ -2689,8 +2718,8 @@ async fn test_semi_join_filtered_mask() -> Result<()> { let schema = joined_batches.joined_batches.schema(); let output = joined_batches.concat_batches(&schema)?; - let out_mask = joined_batches.filter_mask.finish(); - let out_indices = joined_batches.row_indices.finish(); + let out_mask = joined_batches.filter_metadata.filter_mask.finish(); + let out_indices = joined_batches.filter_metadata.row_indices.finish(); assert_eq!( get_corrected_filter_mask( @@ -2791,7 +2820,7 @@ async fn test_semi_join_filtered_mask() -> Result<()> { let corrected_mask = get_corrected_filter_mask( join_type, &out_indices, - &joined_batches.batch_ids, + &joined_batches.filter_metadata.batch_ids, &out_mask, output.num_rows(), ) @@ -2864,8 +2893,8 @@ async fn 
test_anti_join_filtered_mask() -> Result<()> { let schema = joined_batches.joined_batches.schema(); let output = joined_batches.concat_batches(&schema)?; - let out_mask = joined_batches.filter_mask.finish(); - let out_indices = joined_batches.row_indices.finish(); + let out_mask = joined_batches.filter_metadata.filter_mask.finish(); + let out_indices = joined_batches.filter_metadata.row_indices.finish(); assert_eq!( get_corrected_filter_mask( @@ -2966,7 +2995,7 @@ async fn test_anti_join_filtered_mask() -> Result<()> { let corrected_mask = get_corrected_filter_mask( join_type, &out_indices, - &joined_batches.batch_ids, + &joined_batches.filter_metadata.batch_ids, &out_mask, output.num_rows(), ) @@ -3077,7 +3106,7 @@ fn test_partition_statistics() -> Result<()> { ); // Verify that aggregate statistics have a meaningful num_rows (not Absent) assert!( - !matches!(stats.num_rows, Precision::Absent), + stats.num_rows != Precision::Absent, "Aggregate stats should have meaningful num_rows for {join_type:?}, got {:?}", stats.num_rows ); @@ -3095,7 +3124,7 @@ fn test_partition_statistics() -> Result<()> { ); // When children return unknown stats, the join's partition stats will be Absent assert!( - matches!(partition_stats.num_rows, Precision::Absent), + partition_stats.num_rows == Precision::Absent, "Partition stats should have Absent num_rows when children return unknown for {join_type:?}, got {:?}", partition_stats.num_rows ); @@ -3104,6 +3133,420 @@ fn test_partition_statistics() -> Result<()> { Ok(()) } +fn build_batches( + a: (&str, &[Vec]), + b: (&str, &[Vec]), + c: (&str, &[Vec]), +) -> (Vec, SchemaRef) { + assert_eq!(a.1.len(), b.1.len()); + let mut batches = vec![]; + + let schema = Arc::new(Schema::new(vec![ + Field::new(a.0, DataType::Boolean, false), + Field::new(b.0, DataType::Int32, false), + Field::new(c.0, DataType::Int32, false), + ])); + + for i in 0..a.1.len() { + batches.push( + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + 
Arc::new(BooleanArray::from(a.1[i].clone())), + Arc::new(Int32Array::from(b.1[i].clone())), + Arc::new(Int32Array::from(c.1[i].clone())), + ], + ) + .unwrap(), + ); + } + let schema = batches[0].schema(); + (batches, schema) +} + +fn build_batched_finish_barrier_table( + a: (&str, &[Vec]), + b: (&str, &[Vec]), + c: (&str, &[Vec]), +) -> (Arc, Arc) { + let (batches, schema) = build_batches(a, b, c); + + let memory_exec = TestMemoryExec::try_new_exec( + std::slice::from_ref(&batches), + Arc::clone(&schema), + None, + ) + .unwrap(); + + let barrier_exec = Arc::new( + BarrierExec::new(vec![batches], schema) + .with_log(false) + .without_start_barrier() + .with_finish_barrier(), + ); + + (barrier_exec, memory_exec) +} + +/// Concat and sort batches by all the columns to make sure we can compare them with different join +fn prepare_record_batches_for_cmp(output: Vec) -> RecordBatch { + let output_batch = arrow::compute::concat_batches(output[0].schema_ref(), &output) + .expect("failed to concat batches"); + + // Sort on all columns to make sure we have a deterministic order for the assertion + let sort_columns = output_batch + .columns() + .iter() + .map(|c| SortColumn { + values: Arc::clone(c), + options: None, + }) + .collect::>(); + + let sorted_columns = + arrow::compute::lexsort(&sort_columns, None).expect("failed to sort"); + + RecordBatch::try_new(output_batch.schema(), sorted_columns) + .expect("failed to create batch") +} + +#[expect(clippy::too_many_arguments)] +async fn join_get_stream_and_get_expected( + left: Arc, + right: Arc, + oracle_left: Arc, + oracle_right: Arc, + on: JoinOn, + join_type: JoinType, + filter: Option, + batch_size: usize, +) -> Result<(SendableRecordBatchStream, RecordBatch)> { + let sort_options = vec![SortOptions::default(); on.len()]; + let null_equality = NullEquality::NullEqualsNothing; + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(SessionConfig::default().with_batch_size(batch_size)), + ); + + let 
expected_output = { + let oracle = HashJoinExec::try_new( + oracle_left, + oracle_right, + on.clone(), + filter.clone(), + &join_type, + None, + PartitionMode::Partitioned, + null_equality, + false, + )?; + + let stream = oracle.execute(0, Arc::clone(&task_ctx))?; + + let batches = common::collect(stream).await?; + + prepare_record_batches_for_cmp(batches) + }; + + let join = SortMergeJoinExec::try_new( + left, + right, + on, + filter, + join_type, + sort_options, + null_equality, + )?; + + let stream = join.execute(0, task_ctx)?; + + Ok((stream, expected_output)) +} + +fn generate_data_for_emit_early_test( + batch_size: usize, + number_of_batches: usize, + join_type: JoinType, +) -> ( + Arc, + Arc, + Arc, + Arc, +) { + let number_of_rows_per_batch = number_of_batches * batch_size; + // Prepare data + let left_a1 = (0..number_of_rows_per_batch as i32) + .chunks(batch_size) + .into_iter() + .map(|chunk| chunk.collect::>()) + .collect::>(); + let left_b1 = (0..1000000) + .filter(|item| { + match join_type { + LeftAnti | RightAnti => { + let remainder = item % (batch_size as i32); + + // Make sure to have one that match and one that don't + remainder == 0 || remainder == 1 + } + // Have at least 1 that is not matching + _ => item % batch_size as i32 != 0, + } + }) + .take(number_of_rows_per_batch) + .chunks(batch_size) + .into_iter() + .map(|chunk| chunk.collect::>()) + .collect::>(); + + let left_bool_col1 = left_a1 + .clone() + .into_iter() + .map(|b| { + b.into_iter() + // Mostly true but have some false that not overlap with the right column + .map(|a| a % (batch_size as i32) != (batch_size as i32) - 2) + .collect::>() + }) + .collect::>(); + + let (left, left_memory) = build_batched_finish_barrier_table( + ("bool_col1", left_bool_col1.as_slice()), + ("b1", left_b1.as_slice()), + ("a1", left_a1.as_slice()), + ); + + let right_a2 = (0..number_of_rows_per_batch as i32) + .map(|item| item * 11) + .chunks(batch_size) + .into_iter() + .map(|chunk| chunk.collect::>()) + 
.collect::>(); + let right_b1 = (0..1000000) + .filter(|item| { + match join_type { + LeftAnti | RightAnti => { + let remainder = item % (batch_size as i32); + + // Make sure to have one that match and one that don't + remainder == 1 || remainder == 2 + } + // Have at least 1 that is not matching + _ => item % batch_size as i32 != 1, + } + }) + .take(number_of_rows_per_batch) + .chunks(batch_size) + .into_iter() + .map(|chunk| chunk.collect::>()) + .collect::>(); + let right_bool_col2 = right_a2 + .clone() + .into_iter() + .map(|b| { + b.into_iter() + // Mostly true but have some false that not overlap with the left column + .map(|a| a % (batch_size as i32) != (batch_size as i32) - 1) + .collect::>() + }) + .collect::>(); + + let (right, right_memory) = build_batched_finish_barrier_table( + ("bool_col2", right_bool_col2.as_slice()), + ("b1", right_b1.as_slice()), + ("a2", right_a2.as_slice()), + ); + + (left, right, left_memory, right_memory) +} + +#[tokio::test] +async fn test_should_emit_early_when_have_enough_data_to_emit() -> Result<()> { + for with_filtering in [false, true] { + let join_types = vec![ + Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark, RightMark, + ]; + const BATCH_SIZE: usize = 10; + for join_type in join_types { + for output_batch_size in [ + BATCH_SIZE / 3, + BATCH_SIZE / 2, + BATCH_SIZE, + BATCH_SIZE * 2, + BATCH_SIZE * 3, + ] { + // Make sure the number of batches is enough for all join type to emit some output + let number_of_batches = if output_batch_size <= BATCH_SIZE { + 100 + } else { + // Have enough batches + (output_batch_size * 100) / BATCH_SIZE + }; + + let (left, right, left_memory, right_memory) = + generate_data_for_emit_early_test( + BATCH_SIZE, + number_of_batches, + join_type, + ); + + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) 
as _, + )]; + + let join_filter = if with_filtering { + let filter = JoinFilter::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("bool_col1", 0)), + Operator::And, + Arc::new(Column::new("bool_col2", 1)), + )), + vec![ + ColumnIndex { + index: 0, + side: JoinSide::Left, + }, + ColumnIndex { + index: 0, + side: JoinSide::Right, + }, + ], + Arc::new(Schema::new(vec![ + Field::new("bool_col1", DataType::Boolean, true), + Field::new("bool_col2", DataType::Boolean, true), + ])), + ); + Some(filter) + } else { + None + }; + + // select * + // from t1 + // right join t2 on t1.b1 = t2.b1 and t1.bool_col1 AND t2.bool_col2 + let (mut output_stream, expected) = join_get_stream_and_get_expected( + Arc::clone(&left) as Arc, + Arc::clone(&right) as Arc, + left_memory as Arc, + right_memory as Arc, + on, + join_type, + join_filter, + output_batch_size, + ) + .await?; + + let (output_batched, output_batches_after_finish) = + consume_stream_until_finish_barrier_reached(left, right, &mut output_stream).await.unwrap_or_else(|e| panic!("Failed to consume stream for join type: '{join_type}' and with filtering '{with_filtering}': {e:?}")); + + // It should emit more than that, but we are being generous + // and to make sure the test pass for all + const MINIMUM_OUTPUT_BATCHES: usize = 5; + assert!( + MINIMUM_OUTPUT_BATCHES <= number_of_batches / 5, + "Make sure that the minimum output batches is realistic" + ); + // Test to make sure that we are not waiting for input to be fully consumed to emit some output + assert!( + output_batched.len() >= MINIMUM_OUTPUT_BATCHES, + "[Sort Merge Join {join_type}] Stream must have at least emit {} batches, but only got {} batches", + MINIMUM_OUTPUT_BATCHES, + output_batched.len() + ); + + // Just sanity test to make sure we are still producing valid output + { + let output = [output_batched, output_batches_after_finish].concat(); + let actual_prepared = prepare_record_batches_for_cmp(output); + + assert_eq!(actual_prepared.columns(), 
expected.columns()); + } + } + } + } + Ok(()) +} + +/// Polls the stream until both barriers are reached, +/// collecting the emitted batches along the way. +/// +/// If the stream is pending for too long (5s) without emitting any batches, +/// it panics to avoid hanging the test indefinitely. +/// +/// Note: The left and right BarrierExec might be the input of the output stream +async fn consume_stream_until_finish_barrier_reached( + left: Arc, + right: Arc, + output_stream: &mut SendableRecordBatchStream, +) -> Result<(Vec, Vec)> { + let mut switch_to_finish_barrier = false; + let mut output_batched = vec![]; + let mut after_finish_barrier_reached = vec![]; + let mut background_task = JoinSet::new(); + + let mut start_time_since_last_ready = datafusion_common::instant::Instant::now(); + loop { + let next_item = output_stream.next(); + + // Manual polling + let poll_output = futures::poll!(next_item); + + // Wake up the stream to make sure it makes progress + tokio::task::yield_now().await; + + match poll_output { + Poll::Ready(Some(Ok(batch))) => { + if batch.num_rows() == 0 { + return internal_err!("join stream should not emit empty batch"); + } + if switch_to_finish_barrier { + after_finish_barrier_reached.push(batch); + } else { + output_batched.push(batch); + } + start_time_since_last_ready = datafusion_common::instant::Instant::now(); + } + Poll::Ready(Some(Err(e))) => return Err(e), + Poll::Ready(None) if !switch_to_finish_barrier => { + unreachable!("Stream should not end before manually finishing it") + } + Poll::Ready(None) => { + break; + } + Poll::Pending => { + if right.is_finish_barrier_reached() + && left.is_finish_barrier_reached() + && !switch_to_finish_barrier + { + switch_to_finish_barrier = true; + + let right = Arc::clone(&right); + background_task.spawn(async move { + right.wait_finish().await; + }); + let left = Arc::clone(&left); + background_task.spawn(async move { + left.wait_finish().await; + }); + } + + // Make sure the test doesn't run 
forever + if start_time_since_last_ready.elapsed() + > std::time::Duration::from_secs(5) + { + return internal_err!( + "Stream should have emitted data by now, but it's still pending. Output batches so far: {}", + output_batched.len() + ); + } + } + } + } + + Ok((output_batched, after_finish_barrier_reached)) +} + /// Returns the column names on the schema fn columns(schema: &Schema) -> Vec { schema.fields().iter().map(|f| f.name().clone()).collect() diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index 22cc82a22db5..beed07f562db 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -22,8 +22,9 @@ use std::collections::{HashMap, VecDeque}; use std::mem::size_of; use std::sync::Arc; +use crate::joins::MapOffset; use crate::joins::join_hash_map::{ - JoinHashMapOffset, get_matched_indices, get_matched_indices_with_limit_offset, + contain_hashes, get_matched_indices, get_matched_indices_with_limit_offset, update_from_iter, }; use crate::joins::utils::{JoinFilter, JoinHashMapType}; @@ -31,7 +32,8 @@ use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricBuilder}; use crate::{ExecutionPlan, metrics}; use arrow::array::{ - ArrowPrimitiveType, BooleanBufferBuilder, NativeAdapter, PrimitiveArray, RecordBatch, + ArrowPrimitiveType, BooleanArray, BooleanBufferBuilder, NativeAdapter, + PrimitiveArray, RecordBatch, }; use arrow::compute::concat_batches; use arrow::datatypes::{ArrowNativeType, Schema, SchemaRef}; @@ -77,10 +79,10 @@ impl JoinHashMapType for PruningJoinHashMap { &self, hash_values: &[u64], limit: usize, - offset: JoinHashMapOffset, + offset: MapOffset, input_indices: &mut Vec, match_indices: &mut Vec, - ) -> Option { + ) -> Option { // Flatten the deque let next: Vec = self.next.iter().copied().collect(); get_matched_indices_with_limit_offset::( @@ -94,6 +96,10 @@ impl JoinHashMapType for 
PruningJoinHashMap { ) } + fn contain_hashes(&self, hash_values: &[u64]) -> BooleanArray { + contain_hashes(&self.map, hash_values) + } + fn is_empty(&self) -> bool { self.map.is_empty() } diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 1f6bc703a030..7407b05ea569 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -32,6 +32,7 @@ use std::sync::Arc; use std::task::{Context, Poll}; use std::vec; +use crate::check_if_same_properties; use crate::common::SharedMemoryReservation; use crate::execution_plan::{boundedness_from_children, emission_type_from_children}; use crate::joins::stream_join_utils::{ @@ -52,7 +53,7 @@ use crate::projection::{ }; use crate::{ DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, - PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, + PlanProperties, RecordBatchStream, SendableRecordBatchStream, joins::StreamJoinPartitionMode, metrics::{ExecutionPlanMetricsSet, MetricsSet}, }; @@ -197,7 +198,7 @@ pub struct SymmetricHashJoinExec { /// Partition Mode mode: StreamJoinPartitionMode, /// Cache holding plan properties like equivalences, output partitioning etc. 
- cache: PlanProperties, + cache: Arc, } impl SymmetricHashJoinExec { @@ -253,7 +254,7 @@ impl SymmetricHashJoinExec { left_sort_exprs, right_sort_exprs, mode, - cache, + cache: Arc::new(cache), }) } @@ -360,6 +361,20 @@ impl SymmetricHashJoinExec { } Ok(false) } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + let left = children.swap_remove(0); + let right = children.swap_remove(0); + Self { + left, + right, + metrics: ExecutionPlanMetricsSet::new(), + ..Self::clone(self) + } + } } impl DisplayAs for SymmetricHashJoinExec { @@ -411,7 +426,7 @@ impl ExecutionPlan for SymmetricHashJoinExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -453,6 +468,7 @@ impl ExecutionPlan for SymmetricHashJoinExec { self: Arc, children: Vec>, ) -> Result> { + check_if_same_properties!(self, children); Ok(Arc::new(SymmetricHashJoinExec::try_new( Arc::clone(&children[0]), Arc::clone(&children[1]), @@ -470,11 +486,6 @@ impl ExecutionPlan for SymmetricHashJoinExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - // TODO stats: it is not possible in general to know the output size of joins - Ok(Statistics::new_unknown(&self.schema())) - } - fn execute( &self, partition: usize, @@ -930,6 +941,7 @@ pub(crate) fn build_side_determined_results( &probe_indices, column_indices, build_hash_joiner.build_side, + join_type, ) .map(|batch| (batch.num_rows() > 0).then_some(batch)) } else { @@ -993,6 +1005,7 @@ pub(crate) fn join_with_probe_batch( filter, build_hash_joiner.build_side, None, + join_type, )? 
} else { (build_indices, probe_indices) @@ -1031,6 +1044,7 @@ pub(crate) fn join_with_probe_batch( &probe_indices, column_indices, build_hash_joiner.build_side, + join_type, ) .map(|batch| (batch.num_rows() > 0).then_some(batch)) } diff --git a/datafusion/physical-plan/src/joins/test_utils.rs b/datafusion/physical-plan/src/joins/test_utils.rs index 27284bf546bc..0455fb2a1eb6 100644 --- a/datafusion/physical-plan/src/joins/test_utils.rs +++ b/datafusion/physical-plan/src/joins/test_utils.rs @@ -152,6 +152,7 @@ pub async fn partitioned_hash_join_with_filter( None, PartitionMode::Partitioned, null_equality, + false, // null_aware )?); let mut batches = vec![]; diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index a9243fe04e28..cf4bf2cd163f 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -739,7 +739,7 @@ fn max_distinct_count( { let range_dc = range_dc as usize; // Note that the `unwrap` calls in the below statement are safe. 
- return if matches!(result, Precision::Absent) + return if result == Precision::Absent || &range_dc < result.get_value().unwrap() { if stats.min_value.is_exact().unwrap() @@ -910,6 +910,7 @@ pub(crate) fn get_final_indices_from_bit_map( (left_indices, right_indices) } +#[expect(clippy::too_many_arguments)] pub(crate) fn apply_join_filter_to_indices( build_input_buffer: &RecordBatch, probe_batch: &RecordBatch, @@ -918,6 +919,7 @@ pub(crate) fn apply_join_filter_to_indices( filter: &JoinFilter, build_side: JoinSide, max_intermediate_size: Option, + join_type: JoinType, ) -> Result<(UInt64Array, UInt32Array)> { if build_indices.is_empty() && probe_indices.is_empty() { return Ok((build_indices, probe_indices)); @@ -938,6 +940,7 @@ pub(crate) fn apply_join_filter_to_indices( &probe_indices.slice(i, len), filter.column_indices(), build_side, + join_type, )?; let filter_result = filter .expression() @@ -959,6 +962,7 @@ pub(crate) fn apply_join_filter_to_indices( &probe_indices, filter.column_indices(), build_side, + join_type, )?; filter @@ -977,8 +981,20 @@ pub(crate) fn apply_join_filter_to_indices( )) } +/// Creates a [RecordBatch] with zero columns but the given row count. +/// Used when a join has an empty projection (e.g. `SELECT count(1) ...`). +fn new_empty_schema_batch(schema: &Schema, row_count: usize) -> Result { + let options = RecordBatchOptions::new().with_row_count(Some(row_count)); + Ok(RecordBatch::try_new_with_options( + Arc::new(schema.clone()), + vec![], + &options, + )?) +} + /// Returns a new [RecordBatch] by combining the `left` and `right` according to `indices`. /// The resulting batch has [Schema] `schema`. 
+#[expect(clippy::too_many_arguments)] pub(crate) fn build_batch_from_indices( schema: &Schema, build_input_buffer: &RecordBatch, @@ -987,17 +1003,17 @@ pub(crate) fn build_batch_from_indices( probe_indices: &UInt32Array, column_indices: &[ColumnIndex], build_side: JoinSide, + join_type: JoinType, ) -> Result { if schema.fields().is_empty() { - let options = RecordBatchOptions::new() - .with_match_field_names(true) - .with_row_count(Some(build_indices.len())); - - return Ok(RecordBatch::try_new_with_options( - Arc::new(schema.clone()), - vec![], - &options, - )?); + // For RightAnti and RightSemi joins, after `adjust_indices_by_join_type` + // the build_indices were untouched so only probe_indices hold the actual + // row count. + let row_count = match join_type { + JoinType::RightAnti | JoinType::RightSemi => probe_indices.len(), + _ => build_indices.len(), + }; + return new_empty_schema_batch(schema, row_count); } // build the columns of the new [RecordBatch]: @@ -1057,6 +1073,9 @@ pub(crate) fn build_batch_empty_build_side( // the remaining joins will return data for the right columns and null for the left ones JoinType::Right | JoinType::Full | JoinType::RightAnti | JoinType::RightMark => { let num_rows = probe_batch.num_rows(); + if schema.fields().is_empty() { + return new_empty_schema_batch(schema, num_rows); + } let mut columns: Vec> = Vec::with_capacity(schema.fields().len()); @@ -1674,7 +1693,7 @@ fn swap_reverting_projection( pub fn swap_join_projection( left_schema_len: usize, right_schema_len: usize, - projection: Option<&Vec>, + projection: Option<&[usize]>, join_type: &JoinType, ) -> Option> { match join_type { @@ -1685,7 +1704,7 @@ pub fn swap_join_projection( | JoinType::RightAnti | JoinType::RightSemi | JoinType::LeftMark - | JoinType::RightMark => projection.cloned(), + | JoinType::RightMark => projection.map(|p| p.to_vec()), _ => projection.map(|p| { p.iter() .map(|i| { @@ -2889,4 +2908,35 @@ mod tests { Ok(()) } + + #[test] + fn 
test_build_batch_empty_build_side_empty_schema() -> Result<()> { + // When the output schema has no fields (empty projection pushed into + // the join), build_batch_empty_build_side should return a RecordBatch + // with the correct row count but no columns. + let empty_schema = Schema::empty(); + + let build_batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])), + vec![Arc::new(arrow::array::Int32Array::from(vec![1, 2, 3]))], + )?; + + let probe_batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("b", DataType::Int32, true)])), + vec![Arc::new(arrow::array::Int32Array::from(vec![4, 5, 6, 7]))], + )?; + + let result = build_batch_empty_build_side( + &empty_schema, + &build_batch, + &probe_batch, + &[], // no column indices with empty projection + JoinType::Right, + )?; + + assert_eq!(result.num_rows(), 4); + assert_eq!(result.num_columns(), 0); + + Ok(()) + } } diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index ec8e154caec9..6467d7a2e389 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -24,8 +24,6 @@ // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] -// https://github.com/apache/datafusion/issues/18881 -#![deny(clippy::allow_attributes)] //! Traits for physical query plan, supporting parallel execution for partitioned relations. //! 
@@ -65,9 +63,11 @@ mod visitor; pub mod aggregates; pub mod analyze; pub mod async_func; +pub mod buffer; pub mod coalesce; pub mod coalesce_batches; pub mod coalesce_partitions; +pub mod column_rewriter; pub mod common; pub mod coop; pub mod display; diff --git a/datafusion/physical-plan/src/limit.rs b/datafusion/physical-plan/src/limit.rs index 05d688282147..a78e5c067ff1 100644 --- a/datafusion/physical-plan/src/limit.rs +++ b/datafusion/physical-plan/src/limit.rs @@ -28,13 +28,17 @@ use super::{ SendableRecordBatchStream, Statistics, }; use crate::execution_plan::{Boundedness, CardinalityEffect}; -use crate::{DisplayFormatType, Distribution, ExecutionPlan, Partitioning}; +use crate::{ + DisplayFormatType, Distribution, ExecutionPlan, Partitioning, + check_if_same_properties, +}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_common::{Result, assert_eq_or_internal_err, internal_err}; use datafusion_execution::TaskContext; +use datafusion_physical_expr::LexOrdering; use futures::stream::{Stream, StreamExt}; use log::trace; @@ -50,7 +54,10 @@ pub struct GlobalLimitExec { fetch: Option, /// Execution metrics metrics: ExecutionPlanMetricsSet, - cache: PlanProperties, + /// Does the limit have to preserve the order of its input, and if so what is it? 
+ /// Some optimizations may reorder the input if no particular sort is required + required_ordering: Option, + cache: Arc, } impl GlobalLimitExec { @@ -62,7 +69,8 @@ impl GlobalLimitExec { skip, fetch, metrics: ExecutionPlanMetricsSet::new(), - cache, + required_ordering: None, + cache: Arc::new(cache), } } @@ -91,6 +99,27 @@ impl GlobalLimitExec { Boundedness::Bounded, ) } + + /// Get the required ordering from limit + pub fn required_ordering(&self) -> &Option { + &self.required_ordering + } + + /// Set the required ordering for limit + pub fn set_required_ordering(&mut self, required_ordering: Option) { + self.required_ordering = required_ordering; + } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + Self { + input: children.swap_remove(0), + metrics: ExecutionPlanMetricsSet::new(), + ..Self::clone(self) + } + } } impl DisplayAs for GlobalLimitExec { @@ -129,7 +158,7 @@ impl ExecutionPlan for GlobalLimitExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -151,10 +180,11 @@ impl ExecutionPlan for GlobalLimitExec { fn with_new_children( self: Arc, - children: Vec>, + mut children: Vec>, ) -> Result> { + check_if_same_properties!(self, children); Ok(Arc::new(GlobalLimitExec::new( - Arc::clone(&children[0]), + children.swap_remove(0), self.skip, self.fetch, ))) @@ -194,10 +224,6 @@ impl ExecutionPlan for GlobalLimitExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { self.input .partition_statistics(partition)? 
@@ -214,7 +240,7 @@ impl ExecutionPlan for GlobalLimitExec { } /// LocalLimitExec applies a limit to a single partition -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct LocalLimitExec { /// Input execution plan input: Arc, @@ -222,7 +248,10 @@ pub struct LocalLimitExec { fetch: usize, /// Execution metrics metrics: ExecutionPlanMetricsSet, - cache: PlanProperties, + /// If the child plan is a sort node, after the sort node is removed during + /// physical optimization, we should add the required ordering to the limit node + required_ordering: Option, + cache: Arc, } impl LocalLimitExec { @@ -233,7 +262,8 @@ impl LocalLimitExec { input, fetch, metrics: ExecutionPlanMetricsSet::new(), - cache, + required_ordering: None, + cache: Arc::new(cache), } } @@ -257,6 +287,27 @@ impl LocalLimitExec { Boundedness::Bounded, ) } + + /// Get the required ordering from limit + pub fn required_ordering(&self) -> &Option { + &self.required_ordering + } + + /// Set the required ordering for limit + pub fn set_required_ordering(&mut self, required_ordering: Option) { + self.required_ordering = required_ordering; + } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + Self { + input: children.swap_remove(0), + metrics: ExecutionPlanMetricsSet::new(), + ..Self::clone(self) + } + } } impl DisplayAs for LocalLimitExec { @@ -286,7 +337,7 @@ impl ExecutionPlan for LocalLimitExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -306,6 +357,7 @@ impl ExecutionPlan for LocalLimitExec { self: Arc, children: Vec>, ) -> Result> { + check_if_same_properties!(self, children); match children.len() { 1 => Ok(Arc::new(LocalLimitExec::new( Arc::clone(&children[0]), @@ -340,10 +392,6 @@ impl ExecutionPlan for LocalLimitExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { self.input 
.partition_statistics(partition)? diff --git a/datafusion/physical-plan/src/memory.rs b/datafusion/physical-plan/src/memory.rs index 4a406ca648d5..90fd3f24cf1b 100644 --- a/datafusion/physical-plan/src/memory.rs +++ b/datafusion/physical-plan/src/memory.rs @@ -27,7 +27,7 @@ use crate::execution_plan::{Boundedness, EmissionType, SchedulingType}; use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use crate::{ DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, - RecordBatchStream, SendableRecordBatchStream, Statistics, + RecordBatchStream, SendableRecordBatchStream, }; use arrow::array::RecordBatch; @@ -161,7 +161,7 @@ pub struct LazyMemoryExec { /// Functions to generate batches for each partition batch_generators: Vec>>, /// Plan properties cache storing equivalence properties, partitioning, and execution mode - cache: PlanProperties, + cache: Arc, /// Execution metrics metrics: ExecutionPlanMetricsSet, } @@ -200,7 +200,8 @@ impl LazyMemoryExec { EmissionType::Incremental, boundedness, ) - .with_scheduling_type(SchedulingType::Cooperative); + .with_scheduling_type(SchedulingType::Cooperative) + .into(); Ok(Self { schema, @@ -215,9 +216,9 @@ impl LazyMemoryExec { match projection.as_ref() { Some(columns) => { let projected = Arc::new(self.schema.project(columns).unwrap()); - self.cache = self.cache.with_eq_properties(EquivalenceProperties::new( - Arc::clone(&projected), - )); + Arc::make_mut(&mut self.cache).set_eq_properties( + EquivalenceProperties::new(Arc::clone(&projected)), + ); self.schema = projected; self.projection = projection; self @@ -236,12 +237,12 @@ impl LazyMemoryExec { partition_count, generator_count ); - self.cache.partitioning = partitioning; + Arc::make_mut(&mut self.cache).partitioning = partitioning; Ok(()) } pub fn add_ordering(&mut self, ordering: impl IntoIterator) { - self.cache + Arc::make_mut(&mut self.cache) .eq_properties .add_orderings(std::iter::once(ordering)); } @@ -306,7 +307,7 @@ 
impl ExecutionPlan for LazyMemoryExec { Arc::clone(&self.schema) } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -352,10 +353,6 @@ impl ExecutionPlan for LazyMemoryExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - Ok(Statistics::new_unknown(&self.schema)) - } - fn reset_state(self: Arc) -> Result> { let generators = self .generators() @@ -365,7 +362,7 @@ impl ExecutionPlan for LazyMemoryExec { Ok(Arc::new(LazyMemoryExec { schema: Arc::clone(&self.schema), batch_generators: generators, - cache: self.cache.clone(), + cache: Arc::clone(&self.cache), metrics: ExecutionPlanMetricsSet::new(), projection: self.projection.clone(), })) diff --git a/datafusion/physical-plan/src/placeholder_row.rs b/datafusion/physical-plan/src/placeholder_row.rs index 4d00b73cff39..5dbd7b303254 100644 --- a/datafusion/physical-plan/src/placeholder_row.rs +++ b/datafusion/physical-plan/src/placeholder_row.rs @@ -43,7 +43,7 @@ pub struct PlaceholderRowExec { schema: SchemaRef, /// Number of partitions partitions: usize, - cache: PlanProperties, + cache: Arc, } impl PlaceholderRowExec { @@ -54,7 +54,7 @@ impl PlaceholderRowExec { PlaceholderRowExec { schema, partitions, - cache, + cache: Arc::new(cache), } } @@ -63,7 +63,7 @@ impl PlaceholderRowExec { self.partitions = partitions; // Update output partitioning when updating partitions: let output_partitioning = Self::output_partitioning_helper(self.partitions); - self.cache = self.cache.with_partitioning(output_partitioning); + Arc::make_mut(&mut self.cache).partitioning = output_partitioning; self } @@ -132,7 +132,7 @@ impl ExecutionPlan for PlaceholderRowExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -169,10 +169,6 @@ impl ExecutionPlan for PlaceholderRowExec { Ok(Box::pin(cooperative(ms))) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: 
Option) -> Result { let batches = self .data() diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index e8608f17a1b2..db3a71fc70ae 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -20,19 +20,20 @@ //! of a projection on table `t1` where the expressions `a`, `b`, and `a+b` are the //! projection expressions. `SELECT` without `FROM` will only evaluate expressions. -use super::expressions::{Column, Literal}; +use super::expressions::Column; use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use super::{ DisplayAs, ExecutionPlanProperties, PlanProperties, RecordBatchStream, SendableRecordBatchStream, SortOrderPushdownResult, Statistics, }; +use crate::column_rewriter::PhysicalColumnRewriter; use crate::execution_plan::CardinalityEffect; use crate::filter_pushdown::{ - ChildPushdownResult, FilterDescription, FilterPushdownPhase, - FilterPushdownPropagation, + ChildFilterDescription, ChildPushdownResult, FilterDescription, FilterPushdownPhase, + FilterPushdownPropagation, FilterRemapper, PushedDownPredicate, }; use crate::joins::utils::{ColumnIndex, JoinFilter, JoinOn, JoinOnRef}; -use crate::{DisplayFormatType, ExecutionPlan, PhysicalExpr}; +use crate::{DisplayFormatType, ExecutionPlan, PhysicalExpr, check_if_same_properties}; use std::any::Any; use std::collections::HashMap; use std::pin::Pin; @@ -45,8 +46,9 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; -use datafusion_common::{JoinSide, Result, internal_err}; +use datafusion_common::{DataFusionError, JoinSide, Result, internal_err}; use datafusion_execution::TaskContext; +use datafusion_expr::ExpressionPlacement; use datafusion_physical_expr::equivalence::ProjectionMapping; use datafusion_physical_expr::projection::Projector; use datafusion_physical_expr::utils::collect_columns; @@ 
-77,7 +79,7 @@ pub struct ProjectionExec { /// Execution metrics metrics: ExecutionPlanMetricsSet, /// Cache holding plan properties like equivalences, output partitioning etc. - cache: PlanProperties, + cache: Arc, } impl ProjectionExec { @@ -136,13 +138,19 @@ impl ProjectionExec { E: Into, { let input_schema = input.schema(); - // convert argument to Vec - let expr_vec = expr.into_iter().map(Into::into).collect::>(); - let projection = ProjectionExprs::new(expr_vec); + let expr_arc = expr.into_iter().map(Into::into).collect::>(); + let projection = ProjectionExprs::from_expressions(expr_arc); let projector = projection.make_projector(&input_schema)?; + Self::try_from_projector(projector, input) + } + fn try_from_projector( + projector: Projector, + input: Arc, + ) -> Result { // Construct a map from the input expressions to the output expression of the Projection - let projection_mapping = projection.projection_mapping(&input_schema)?; + let projection_mapping = + projector.projection().projection_mapping(&input.schema())?; let cache = Self::compute_properties( &input, &projection_mapping, @@ -152,7 +160,7 @@ impl ProjectionExec { projector, input, metrics: ExecutionPlanMetricsSet::new(), - cache, + cache: Arc::new(cache), }) } @@ -192,6 +200,40 @@ impl ProjectionExec { input.boundedness(), )) } + + /// Collect reverse alias mapping from projection expressions. + /// The result hash map is a map from aliased Column in parent to original expr. 
+ fn collect_reverse_alias( + &self, + ) -> Result>> { + let mut alias_map = datafusion_common::HashMap::new(); + for projection in self.projection_expr().iter() { + let (aliased_index, _output_field) = self + .projector + .output_schema() + .column_with_name(&projection.alias) + .ok_or_else(|| { + DataFusionError::Internal(format!( + "Expr {} with alias {} not found in output schema", + projection.expr, projection.alias + )) + })?; + let aliased_col = Column::new(&projection.alias, aliased_index); + alias_map.insert(aliased_col, Arc::clone(&projection.expr)); + } + Ok(alias_map) + } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + Self { + input: children.swap_remove(0), + metrics: ExecutionPlanMetricsSet::new(), + ..Self::clone(self) + } + } } impl DisplayAs for ProjectionExec { @@ -245,7 +287,7 @@ impl ExecutionPlan for ProjectionExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -261,10 +303,13 @@ impl ExecutionPlan for ProjectionExec { .as_ref() .iter() .all(|proj_expr| { - proj_expr.expr.as_any().is::() - || proj_expr.expr.as_any().is::() + !matches!( + proj_expr.expr.placement(), + ExpressionPlacement::KeepInPlace + ) }); - // If expressions are all either column_expr or Literal, then all computations in this projection are reorder or rename, + // If expressions are all either column_expr or Literal (or other cheap expressions), + // then all computations in this projection are reorder or rename, // and projection would not benefit from the repartition, benefits_from_input_partitioning will return false. 
vec![!all_simple_exprs] } @@ -277,8 +322,9 @@ impl ExecutionPlan for ProjectionExec { self: Arc, mut children: Vec>, ) -> Result> { - ProjectionExec::try_new( - self.projector.projection().clone(), + check_if_same_properties!(self, children); + ProjectionExec::try_from_projector( + self.projector.clone(), children.swap_remove(0), ) .map(|p| Arc::new(p) as _) @@ -308,10 +354,6 @@ impl ExecutionPlan for ProjectionExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { let input_stats = self.input.partition_statistics(partition)?; let output_schema = self.schema(); @@ -347,10 +389,28 @@ impl ExecutionPlan for ProjectionExec { parent_filters: Vec>, _config: &ConfigOptions, ) -> Result { - // TODO: In future, we can try to handle inverting aliases here. - // For the time being, we pass through untransformed filters, so filters on aliases are not handled. - // https://github.com/apache/datafusion/issues/17246 - FilterDescription::from_children(parent_filters, &self.children()) + // expand alias column to original expr in parent filters + let invert_alias_map = self.collect_reverse_alias()?; + let output_schema = self.schema(); + let remapper = FilterRemapper::new(output_schema); + let mut child_parent_filters = Vec::with_capacity(parent_filters.len()); + + for filter in parent_filters { + // Check that column exists in child, then reassign column indices to match child schema + if let Some(reassigned) = remapper.try_remap(&filter)? 
{ + // rewrite filter expression using invert alias map + let mut rewriter = PhysicalColumnRewriter::new(&invert_alias_map); + let rewritten = reassigned.rewrite(&mut rewriter)?.data; + child_parent_filters.push(PushedDownPredicate::supported(rewritten)); + } else { + child_parent_filters.push(PushedDownPredicate::unsupported(filter)); + } + } + + Ok(FilterDescription::new().with_child(ChildFilterDescription { + parent_filters: child_parent_filters, + self_filters: vec![], + })) } fn handle_child_pushdown_result( @@ -427,6 +487,19 @@ impl ExecutionPlan for ProjectionExec { } } } + + fn with_preserve_order( + &self, + preserve_order: bool, + ) -> Option> { + self.input + .with_preserve_order(preserve_order) + .and_then(|new_input| { + Arc::new(self.clone()) + .with_new_children(vec![new_input]) + .ok() + }) + } } impl ProjectionStream { @@ -485,6 +558,15 @@ impl RecordBatchStream for ProjectionStream { } } +/// Trait for execution plans that can embed a projection, avoiding a separate +/// [`ProjectionExec`] wrapper. +/// +/// # Empty projections +/// +/// `Some(vec![])` is a valid projection that produces zero output columns while +/// preserving the correct row count. Implementors must ensure that runtime batch +/// construction still returns batches with the right number of rows even when no +/// columns are selected (e.g. for `SELECT count(1) … JOIN …`). pub trait EmbeddedProjection: ExecutionPlan + Sized { fn with_projection(&self, projection: Option>) -> Result; } @@ -495,6 +577,15 @@ pub fn try_embed_projection( projection: &ProjectionExec, execution_plan: &Exec, ) -> Result>> { + // If the projection has no expressions at all (e.g., ProjectionExec: expr=[]), + // embed an empty projection into the execution plan so it outputs zero columns. + // This avoids allocating throwaway null arrays for build-side columns + // when no output columns are actually needed (e.g., count(1) over a right join). 
+ if projection.expr().is_empty() { + let new_execution_plan = Arc::new(execution_plan.with_projection(Some(vec![]))?); + return Ok(Some(new_execution_plan)); + } + // Collect all column indices from the given projection expressions. let projection_index = collect_column_indices(projection.expr()); @@ -945,11 +1036,15 @@ fn try_unifying_projections( .unwrap(); }); // Merging these projections is not beneficial, e.g - // If an expression is not trivial and it is referred more than 1, unifies projections will be + // If an expression is not trivial (KeepInPlace) and it is referred more than 1, unifies projections will be // beneficial as caching mechanism for non-trivial computations. // See discussion in: https://github.com/apache/datafusion/issues/8296 if column_ref_map.iter().any(|(column, count)| { - *count > 1 && !is_expr_trivial(&Arc::clone(&child.expr()[column.index()].expr)) + *count > 1 + && !child.expr()[column.index()] + .expr + .placement() + .should_push_to_leaves() }) { return Ok(None); } @@ -1059,13 +1154,6 @@ fn new_columns_for_join_on( (new_columns.len() == hash_join_on.len()).then_some(new_columns) } -/// Checks if the given expression is trivial. -/// An expression is considered trivial if it is either a `Column` or a `Literal`. 
-fn is_expr_trivial(expr: &Arc) -> bool { - expr.as_any().downcast_ref::().is_some() - || expr.as_any().downcast_ref::().is_some() -} - #[cfg(test)] mod tests { use super::*; @@ -1073,6 +1161,7 @@ mod tests { use crate::common::collect; + use crate::filter_pushdown::PushedDown; use crate::test; use crate::test::exec::StatisticsExec; @@ -1081,7 +1170,9 @@ mod tests { use datafusion_common::stats::{ColumnStatistics, Precision, Statistics}; use datafusion_expr::Operator; - use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal, col}; + use datafusion_physical_expr::expressions::{ + BinaryExpr, Column, DynamicFilterPhysicalExpr, Literal, binary, col, lit, + }; #[test] fn test_collect_column_indices() -> Result<()> { @@ -1270,4 +1361,431 @@ mod tests { ); assert!(stats.total_byte_size.is_exact().unwrap_or(false)); } + + #[test] + fn test_filter_pushdown_with_alias() -> Result<()> { + let input_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let input = Arc::new(StatisticsExec::new( + Statistics::new_unknown(&input_schema), + input_schema.clone(), + )); + + // project "a" as "b" + let projection = ProjectionExec::try_new( + vec![ProjectionExpr { + expr: Arc::new(Column::new("a", 0)), + alias: "b".to_string(), + }], + input, + )?; + + // filter "b > 5" + let filter = Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )) as Arc; + + let description = projection.gather_filters_for_pushdown( + FilterPushdownPhase::Post, + vec![filter], + &ConfigOptions::default(), + )?; + + // Should be converted to "a > 5" + // "a" is index 0 in input + let expected_filter = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )) as Arc; + + assert_eq!(description.self_filters(), vec![vec![]]); + let pushed_filters = &description.parent_filters()[0]; + assert_eq!( + format!("{}", 
pushed_filters[0].predicate), + format!("{}", expected_filter) + ); + // Verify the predicate was actually pushed down + assert!(matches!(pushed_filters[0].discriminant, PushedDown::Yes)); + + Ok(()) + } + + #[test] + fn test_filter_pushdown_with_multiple_aliases() -> Result<()> { + let input_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + let input = Arc::new(StatisticsExec::new( + Statistics { + column_statistics: vec![Default::default(); input_schema.fields().len()], + ..Default::default() + }, + input_schema.clone(), + )); + + // project "a" as "x", "b" as "y" + let projection = ProjectionExec::try_new( + vec![ + ProjectionExpr { + expr: Arc::new(Column::new("a", 0)), + alias: "x".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("b", 1)), + alias: "y".to_string(), + }, + ], + input, + )?; + + // filter "x > 5" + let filter1 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("x", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )) as Arc; + + // filter "y < 10" + let filter2 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("y", 1)), + Operator::Lt, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )) as Arc; + + let description = projection.gather_filters_for_pushdown( + FilterPushdownPhase::Post, + vec![filter1, filter2], + &ConfigOptions::default(), + )?; + + // Should be converted to "a > 5" and "b < 10" + let expected_filter1 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )) as Arc; + + let expected_filter2 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 1)), + Operator::Lt, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )) as Arc; + + let pushed_filters = &description.parent_filters()[0]; + assert_eq!(pushed_filters.len(), 2); + // Note: The order of filters is preserved + assert_eq!( + format!("{}", 
pushed_filters[0].predicate), + format!("{}", expected_filter1) + ); + assert_eq!( + format!("{}", pushed_filters[1].predicate), + format!("{}", expected_filter2) + ); + // Verify the predicates were actually pushed down + assert!(matches!(pushed_filters[0].discriminant, PushedDown::Yes)); + assert!(matches!(pushed_filters[1].discriminant, PushedDown::Yes)); + + Ok(()) + } + + #[test] + fn test_filter_pushdown_with_swapped_aliases() -> Result<()> { + let input_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + let input = Arc::new(StatisticsExec::new( + Statistics { + column_statistics: vec![Default::default(); input_schema.fields().len()], + ..Default::default() + }, + input_schema.clone(), + )); + + // project "a" as "b", "b" as "a" + let projection = ProjectionExec::try_new( + vec![ + ProjectionExpr { + expr: Arc::new(Column::new("a", 0)), + alias: "b".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("b", 1)), + alias: "a".to_string(), + }, + ], + input, + )?; + + // filter "b > 5" (output column 0, which is "a" in input) + let filter1 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )) as Arc; + + // filter "a < 10" (output column 1, which is "b" in input) + let filter2 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 1)), + Operator::Lt, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )) as Arc; + + let description = projection.gather_filters_for_pushdown( + FilterPushdownPhase::Post, + vec![filter1, filter2], + &ConfigOptions::default(), + )?; + + let pushed_filters = &description.parent_filters()[0]; + assert_eq!(pushed_filters.len(), 2); + + // "b" (output index 0) -> "a" (input index 0) + let expected_filter1 = "a@0 > 5"; + // "a" (output index 1) -> "b" (input index 1) + let expected_filter2 = "b@1 < 10"; + + assert_eq!(format!("{}", pushed_filters[0].predicate), 
expected_filter1); + assert_eq!(format!("{}", pushed_filters[1].predicate), expected_filter2); + // Verify the predicates were actually pushed down + assert!(matches!(pushed_filters[0].discriminant, PushedDown::Yes)); + assert!(matches!(pushed_filters[1].discriminant, PushedDown::Yes)); + + Ok(()) + } + + #[test] + fn test_filter_pushdown_with_mixed_columns() -> Result<()> { + let input_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + let input = Arc::new(StatisticsExec::new( + Statistics { + column_statistics: vec![Default::default(); input_schema.fields().len()], + ..Default::default() + }, + input_schema.clone(), + )); + + // project "a" as "x", "b" as "b" (pass through) + let projection = ProjectionExec::try_new( + vec![ + ProjectionExpr { + expr: Arc::new(Column::new("a", 0)), + alias: "x".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("b", 1)), + alias: "b".to_string(), + }, + ], + input, + )?; + + // filter "x > 5" + let filter1 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("x", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )) as Arc; + + // filter "b < 10" (using output index 1 which corresponds to 'b') + let filter2 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 1)), + Operator::Lt, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )) as Arc; + + let description = projection.gather_filters_for_pushdown( + FilterPushdownPhase::Post, + vec![filter1, filter2], + &ConfigOptions::default(), + )?; + + let pushed_filters = &description.parent_filters()[0]; + assert_eq!(pushed_filters.len(), 2); + // "x" -> "a" (index 0) + let expected_filter1 = "a@0 > 5"; + // "b" -> "b" (index 1) + let expected_filter2 = "b@1 < 10"; + + assert_eq!(format!("{}", pushed_filters[0].predicate), expected_filter1); + assert_eq!(format!("{}", pushed_filters[1].predicate), expected_filter2); + // Verify the predicates were actually pushed down 
+ assert!(matches!(pushed_filters[0].discriminant, PushedDown::Yes)); + assert!(matches!(pushed_filters[1].discriminant, PushedDown::Yes)); + + Ok(()) + } + + #[test] + fn test_filter_pushdown_with_complex_expression() -> Result<()> { + let input_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let input = Arc::new(StatisticsExec::new( + Statistics { + column_statistics: vec![Default::default(); input_schema.fields().len()], + ..Default::default() + }, + input_schema.clone(), + )); + + // project "a + 1" as "z" + let projection = ProjectionExec::try_new( + vec![ProjectionExpr { + expr: Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Plus, + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + )), + alias: "z".to_string(), + }], + input, + )?; + + // filter "z > 10" + let filter = Arc::new(BinaryExpr::new( + Arc::new(Column::new("z", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )) as Arc; + + let description = projection.gather_filters_for_pushdown( + FilterPushdownPhase::Post, + vec![filter], + &ConfigOptions::default(), + )?; + + // expand to `a + 1 > 10` + let pushed_filters = &description.parent_filters()[0]; + assert!(matches!(pushed_filters[0].discriminant, PushedDown::Yes)); + assert_eq!(format!("{}", pushed_filters[0].predicate), "a@0 + 1 > 10"); + + Ok(()) + } + + #[test] + fn test_filter_pushdown_with_unknown_column() -> Result<()> { + let input_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let input = Arc::new(StatisticsExec::new( + Statistics { + column_statistics: vec![Default::default(); input_schema.fields().len()], + ..Default::default() + }, + input_schema.clone(), + )); + + // project "a" as "a" + let projection = ProjectionExec::try_new( + vec![ProjectionExpr { + expr: Arc::new(Column::new("a", 0)), + alias: "a".to_string(), + }], + input, + )?; + + // filter "unknown_col > 5" - using a column name that doesn't exist in projection output + // 
Column constructor: name, index. Index 1 doesn't exist. + let filter = Arc::new(BinaryExpr::new( + Arc::new(Column::new("unknown_col", 1)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )) as Arc; + + let description = projection.gather_filters_for_pushdown( + FilterPushdownPhase::Post, + vec![filter], + &ConfigOptions::default(), + )?; + + let pushed_filters = &description.parent_filters()[0]; + assert!(matches!(pushed_filters[0].discriminant, PushedDown::No)); + // The column shouldn't be found in the alias map, so it remains unchanged with its index + assert_eq!( + format!("{}", pushed_filters[0].predicate), + "unknown_col@1 > 5" + ); + + Ok(()) + } + + /// Basic test for `DynamicFilterPhysicalExpr` can correctly update its child expression + /// i.e. starting with lit(true) and after update it becomes `a > 5` + /// with projection [b - 1 as a], the pushed down filter should be `b - 1 > 5` + #[test] + fn test_basic_dyn_filter_projection_pushdown_update_child() -> Result<()> { + let input_schema = + Arc::new(Schema::new(vec![Field::new("b", DataType::Int32, false)])); + + let input = Arc::new(StatisticsExec::new( + Statistics { + column_statistics: vec![Default::default(); input_schema.fields().len()], + ..Default::default() + }, + input_schema.as_ref().clone(), + )); + + // project "b" - 1 as "a" + let projection = ProjectionExec::try_new( + vec![ProjectionExpr { + expr: binary( + Arc::new(Column::new("b", 0)), + Operator::Minus, + lit(1), + &input_schema, + ) + .unwrap(), + alias: "a".to_string(), + }], + input, + )?; + + // simulate projection's parent create a dynamic filter on "a" + let projected_schema = projection.schema(); + let col_a = col("a", &projected_schema)?; + let dynamic_filter = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::clone(&col_a)], + lit(true), + )); + // Initial state should be lit(true) + let current = dynamic_filter.current()?; + assert_eq!(format!("{current}"), "true"); + + let dyn_phy_expr: Arc = 
Arc::clone(&dynamic_filter) as _; + + let description = projection.gather_filters_for_pushdown( + FilterPushdownPhase::Post, + vec![dyn_phy_expr], + &ConfigOptions::default(), + )?; + + let pushed_filters = &description.parent_filters()[0][0]; + + // Check currently pushed_filters is lit(true) + assert_eq!( + format!("{}", pushed_filters.predicate), + "DynamicFilter [ empty ]" + ); + + // Update to a > 5 (after projection, b is now called a) + let new_expr = + Arc::new(BinaryExpr::new(Arc::clone(&col_a), Operator::Gt, lit(5i32))); + dynamic_filter.update(new_expr)?; + + // Now it should be a > 5 + let current = dynamic_filter.current()?; + assert_eq!(format!("{current}"), "a@0 > 5"); + + // Check currently pushed_filters is b - 1 > 5 (because b - 1 is projected as a) + assert_eq!( + format!("{}", pushed_filters.predicate), + "DynamicFilter [ b@0 - 1 > 5 ]" + ); + + Ok(()) + } } diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs index 683dbb4e4976..995aa4822a40 100644 --- a/datafusion/physical-plan/src/recursive_query.rs +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -24,13 +24,13 @@ use std::task::{Context, Poll}; use super::work_table::{ReservedBatches, WorkTable}; use crate::aggregates::group_values::{GroupValues, new_group_values}; use crate::aggregates::order::GroupOrdering; -use crate::execution_plan::{Boundedness, EmissionType}; +use crate::execution_plan::{Boundedness, EmissionType, reset_plan_states}; use crate::metrics::{ BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput, }; use crate::{ DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, RecordBatchStream, - SendableRecordBatchStream, Statistics, + SendableRecordBatchStream, }; use arrow::array::{BooleanArray, BooleanBuilder}; use arrow::compute::filter_record_batch; @@ -74,7 +74,7 @@ pub struct RecursiveQueryExec { /// Execution metrics metrics: ExecutionPlanMetricsSet, /// Cache holding plan properties like 
equivalences, output partitioning etc. - cache: PlanProperties, + cache: Arc, } impl RecursiveQueryExec { @@ -97,7 +97,7 @@ impl RecursiveQueryExec { is_distinct, work_table, metrics: ExecutionPlanMetricsSet::new(), - cache, + cache: Arc::new(cache), }) } @@ -143,7 +143,7 @@ impl ExecutionPlan for RecursiveQueryExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -208,10 +208,6 @@ impl ExecutionPlan for RecursiveQueryExec { fn metrics(&self) -> Option { Some(self.metrics.clone_inner()) } - - fn statistics(&self) -> Result { - Ok(Statistics::new_unknown(&self.schema())) - } } impl DisplayAs for RecursiveQueryExec { @@ -387,20 +383,6 @@ fn assign_work_table( .data() } -/// Some plans will change their internal states after execution, making them unable to be executed again. -/// This function uses [`ExecutionPlan::reset_state`] to reset any internal state within the plan. -/// -/// An example is `CrossJoinExec`, which loads the left table into memory and stores it in the plan. -/// However, if the data of the left table is derived from the work table, it will become outdated -/// as the work table changes. When the next iteration executes this plan again, we must clear the left table. 
-fn reset_plan_states(plan: Arc) -> Result> { - plan.transform_up(|plan| { - let new_plan = Arc::clone(&plan).reset_state()?; - Ok(Transformed::yes(new_plan)) - }) - .data() -} - impl Stream for RecursiveQueryStream { type Item = Result; diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index 1efdaaabc7d6..da4329e2cc2a 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -39,7 +39,10 @@ use crate::sorts::streaming_merge::StreamingMergeBuilder; use crate::spill::spill_manager::SpillManager; use crate::spill::spill_pool::{self, SpillPoolWriter}; use crate::stream::RecordBatchStreamAdapter; -use crate::{DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, Statistics}; +use crate::{ + DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, Statistics, + check_if_same_properties, +}; use arrow::array::{PrimitiveArray, RecordBatch, RecordBatchOptions}; use arrow::compute::take_arrays; @@ -48,7 +51,8 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::stats::Precision; use datafusion_common::utils::transpose; use datafusion_common::{ - ColumnStatistics, DataFusionError, HashMap, assert_or_internal_err, internal_err, + ColumnStatistics, DataFusionError, HashMap, assert_or_internal_err, + internal_datafusion_err, internal_err, }; use datafusion_common::{Result, not_impl_err}; use datafusion_common_runtime::SpawnedTask; @@ -421,6 +425,7 @@ enum BatchPartitionerState { exprs: Vec>, num_partitions: usize, hash_buffer: Vec, + indices: Vec>, }, RoundRobin { num_partitions: usize, @@ -434,33 +439,91 @@ pub const REPARTITION_RANDOM_STATE: SeededRandomState = SeededRandomState::with_seeds(0, 0, 0, 0); impl BatchPartitioner { - /// Create a new [`BatchPartitioner`] with the provided [`Partitioning`] + /// Create a new [`BatchPartitioner`] for hash-based repartitioning. 
+ /// + /// # Parameters + /// - `exprs`: Expressions used to compute the hash for each input row. + /// - `num_partitions`: Total number of output partitions. + /// - `timer`: Metric used to record time spent during repartitioning. + /// + /// # Notes + /// This constructor cannot fail and performs no validation. + pub fn new_hash_partitioner( + exprs: Vec>, + num_partitions: usize, + timer: metrics::Time, + ) -> Self { + Self { + state: BatchPartitionerState::Hash { + exprs, + num_partitions, + hash_buffer: vec![], + indices: vec![vec![]; num_partitions], + }, + timer, + } + } + + /// Create a new [`BatchPartitioner`] for round-robin repartitioning. + /// + /// # Parameters + /// - `num_partitions`: Total number of output partitions. + /// - `timer`: Metric used to record time spent during repartitioning. + /// - `input_partition`: Index of the current input partition. + /// - `num_input_partitions`: Total number of input partitions. + /// + /// # Notes + /// The starting output partition is derived from the input partition + /// to avoid skew when multiple input partitions are used. + pub fn new_round_robin_partitioner( + num_partitions: usize, + timer: metrics::Time, + input_partition: usize, + num_input_partitions: usize, + ) -> Self { + Self { + state: BatchPartitionerState::RoundRobin { + num_partitions, + next_idx: (input_partition * num_partitions) / num_input_partitions, + }, + timer, + } + } + /// Create a new [`BatchPartitioner`] based on the provided [`Partitioning`] scheme. + /// + /// This is a convenience constructor that delegates to the specialized + /// hash or round-robin constructors depending on the partitioning variant. /// - /// The time spent repartitioning will be recorded to `timer` + /// # Parameters + /// - `partitioning`: Partitioning scheme to apply (hash or round-robin). + /// - `timer`: Metric used to record time spent during repartitioning. + /// - `input_partition`: Index of the current input partition. 
+ /// - `num_input_partitions`: Total number of input partitions. + /// + /// # Errors + /// Returns an error if the provided partitioning scheme is not supported. pub fn try_new( partitioning: Partitioning, timer: metrics::Time, input_partition: usize, num_input_partitions: usize, ) -> Result { - let state = match partitioning { + match partitioning { + Partitioning::Hash(exprs, num_partitions) => { + Ok(Self::new_hash_partitioner(exprs, num_partitions, timer)) + } Partitioning::RoundRobinBatch(num_partitions) => { - BatchPartitionerState::RoundRobin { + Ok(Self::new_round_robin_partitioner( num_partitions, - // Distribute starting index evenly based on input partition, number of input partitions and number of partitions - // to avoid they all start at partition 0 and heavily skew on the lower partitions - next_idx: ((input_partition * num_partitions) / num_input_partitions), - } + timer, + input_partition, + num_input_partitions, + )) } - Partitioning::Hash(exprs, num_partitions) => BatchPartitionerState::Hash { - exprs, - num_partitions, - hash_buffer: vec![], - }, - other => return not_impl_err!("Unsupported repartitioning scheme {other:?}"), - }; - - Ok(Self { state, timer }) + other => { + not_impl_err!("Unsupported repartitioning scheme {other:?}") + } + } } /// Partition the provided [`RecordBatch`] into one or more partitioned [`RecordBatch`] @@ -505,6 +568,7 @@ impl BatchPartitioner { exprs, num_partitions: partitions, hash_buffer, + indices, } => { // Tracking time required for distributing indexes across output partitions let timer = self.timer.timer(); @@ -521,9 +585,7 @@ impl BatchPartitioner { hash_buffer, )?; - let mut indices: Vec<_> = (0..*partitions) - .map(|_| Vec::with_capacity(batch.num_rows())) - .collect(); + indices.iter_mut().for_each(|v| v.clear()); for (index, hash) in hash_buffer.iter().enumerate() { indices[(*hash % *partitions as u64) as usize].push(index as u32); @@ -534,22 +596,23 @@ impl BatchPartitioner { // Borrowing partitioner 
timer to prevent moving `self` to closure let partitioner_timer = &self.timer; - let it = indices - .into_iter() - .enumerate() - .filter_map(|(partition, indices)| { - let indices: PrimitiveArray = indices.into(); - (!indices.is_empty()).then_some((partition, indices)) - }) - .map(move |(partition, indices)| { + + let mut partitioned_batches = vec![]; + for (partition, p_indices) in indices.iter_mut().enumerate() { + if !p_indices.is_empty() { + let taken_indices = std::mem::take(p_indices); + let indices_array: PrimitiveArray = + taken_indices.into(); + // Tracking time required for repartitioned batches construction let _timer = partitioner_timer.timer(); // Produce batches based on indices - let columns = take_arrays(batch.columns(), &indices, None)?; + let columns = + take_arrays(batch.columns(), &indices_array, None)?; let mut options = RecordBatchOptions::new(); - options = options.with_row_count(Some(indices.len())); + options = options.with_row_count(Some(indices_array.len())); let batch = RecordBatch::try_new_with_options( batch.schema(), columns, @@ -557,10 +620,22 @@ impl BatchPartitioner { ) .unwrap(); - Ok((partition, batch)) - }); + partitioned_batches.push(Ok((partition, batch))); + + // Return the taken vec + let (_, buffer, _) = indices_array.into_parts(); + let mut vec = + buffer.into_inner().into_vec::().map_err(|e| { + internal_datafusion_err!( + "Could not convert buffer to vec: {e:?}" + ) + })?; + vec.clear(); + *p_indices = vec; + } + } - Box::new(it) + Box::new(partitioned_batches.into_iter()) } }; @@ -674,6 +749,10 @@ impl BatchPartitioner { /// system Paper](https://dl.acm.org/doi/pdf/10.1145/93605.98720) /// which uses the term "Exchange" for the concept of repartitioning /// data across threads. +/// +/// For more background, please also see the [Optimizing Repartitions in DataFusion] blog. 
+/// +/// [Optimizing Repartitions in DataFusion]: https://datafusion.apache.org/blog/2025/12/15/avoid-consecutive-repartitions #[derive(Debug, Clone)] pub struct RepartitionExec { /// Input execution plan @@ -687,7 +766,7 @@ pub struct RepartitionExec { /// `SortPreservingRepartitionExec`, false means `RepartitionExec`. preserve_order: bool, /// Cache holding plan properties like equivalences, output partitioning etc. - cache: PlanProperties, + cache: Arc, } #[derive(Debug, Clone)] @@ -756,6 +835,18 @@ impl RepartitionExec { pub fn name(&self) -> &str { "RepartitionExec" } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + Self { + input: children.swap_remove(0), + metrics: ExecutionPlanMetricsSet::new(), + state: Default::default(), + ..Self::clone(self) + } + } } impl DisplayAs for RepartitionExec { @@ -815,7 +906,7 @@ impl ExecutionPlan for RepartitionExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -827,6 +918,7 @@ impl ExecutionPlan for RepartitionExec { self: Arc, mut children: Vec>, ) -> Result> { + check_if_same_properties!(self, children); let mut repartition = RepartitionExec::try_new( children.swap_remove(0), self.partitioning().clone(), @@ -994,10 +1086,6 @@ impl ExecutionPlan for RepartitionExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.input.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { if let Some(partition) = partition { let partition_count = self.partitioning().partition_count(); @@ -1128,7 +1216,7 @@ impl ExecutionPlan for RepartitionExec { _config: &ConfigOptions, ) -> Result>> { use Partitioning::*; - let mut new_properties = self.cache.clone(); + let mut new_properties = PlanProperties::clone(&self.cache); new_properties.partitioning = match new_properties.partitioning { RoundRobinBatch(_) => RoundRobinBatch(target_partitions), Hash(hash, _) => Hash(hash, 
target_partitions), @@ -1139,7 +1227,7 @@ impl ExecutionPlan for RepartitionExec { state: Arc::clone(&self.state), metrics: self.metrics.clone(), preserve_order: self.preserve_order, - cache: new_properties, + cache: new_properties.into(), }))) } } @@ -1159,7 +1247,7 @@ impl RepartitionExec { state: Default::default(), metrics: ExecutionPlanMetricsSet::new(), preserve_order, - cache, + cache: Arc::new(cache), }) } @@ -1220,7 +1308,7 @@ impl RepartitionExec { // to maintain order self.input.output_partitioning().partition_count() > 1; let eq_properties = Self::eq_properties_helper(&self.input, self.preserve_order); - self.cache = self.cache.with_eq_properties(eq_properties); + Arc::make_mut(&mut self.cache).set_eq_properties(eq_properties); self } @@ -1245,12 +1333,26 @@ impl RepartitionExec { input_partition: usize, num_input_partitions: usize, ) -> Result<()> { - let mut partitioner = BatchPartitioner::try_new( - partitioning, - metrics.repartition_time.clone(), - input_partition, - num_input_partitions, - )?; + let mut partitioner = match &partitioning { + Partitioning::Hash(exprs, num_partitions) => { + BatchPartitioner::new_hash_partitioner( + exprs.clone(), + *num_partitions, + metrics.repartition_time.clone(), + ) + } + Partitioning::RoundRobinBatch(num_partitions) => { + BatchPartitioner::new_round_robin_partitioner( + *num_partitions, + metrics.repartition_time.clone(), + input_partition, + num_input_partitions, + ) + } + other => { + return not_impl_err!("Unsupported repartitioning scheme {other:?}"); + } + }; // While there are still outputs to send to, keep pulling inputs let mut batches_until_yield = partitioner.num_partitions(); @@ -2397,7 +2499,7 @@ mod tests { /// Create vector batches fn create_vec_batches(n: usize) -> Vec { let batch = create_batch(); - (0..n).map(|_| batch.clone()).collect() + std::iter::repeat_n(batch, n).collect() } /// Create batch diff --git a/datafusion/physical-plan/src/sorts/mod.rs 
b/datafusion/physical-plan/src/sorts/mod.rs index 9c72e34fe343..a73872a175b9 100644 --- a/datafusion/physical-plan/src/sorts/mod.rs +++ b/datafusion/physical-plan/src/sorts/mod.rs @@ -26,3 +26,5 @@ pub mod sort; pub mod sort_preserving_merge; mod stream; pub mod streaming_merge; + +pub(crate) use stream::IncrementalSortIterator; diff --git a/datafusion/physical-plan/src/sorts/partial_sort.rs b/datafusion/physical-plan/src/sorts/partial_sort.rs index 73ba889c9e40..0dbb75f2ef47 100644 --- a/datafusion/physical-plan/src/sorts/partial_sort.rs +++ b/datafusion/physical-plan/src/sorts/partial_sort.rs @@ -62,6 +62,7 @@ use crate::sorts::sort::sort_batch; use crate::{ DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, Partitioning, PlanProperties, SendableRecordBatchStream, Statistics, + check_if_same_properties, }; use arrow::compute::concat_batches; @@ -93,7 +94,7 @@ pub struct PartialSortExec { /// Fetch highest/lowest n results fetch: Option, /// Cache holding plan properties like equivalences, output partitioning etc. - cache: PlanProperties, + cache: Arc, } impl PartialSortExec { @@ -114,7 +115,7 @@ impl PartialSortExec { metrics_set: ExecutionPlanMetricsSet::new(), preserve_partitioning, fetch: None, - cache, + cache: Arc::new(cache), } } @@ -132,12 +133,8 @@ impl PartialSortExec { /// input partitions producing a single, sorted partition. 
pub fn with_preserve_partitioning(mut self, preserve_partitioning: bool) -> Self { self.preserve_partitioning = preserve_partitioning; - self.cache = self - .cache - .with_partitioning(Self::output_partitioning_helper( - &self.input, - self.preserve_partitioning, - )); + Arc::make_mut(&mut self.cache).partitioning = + Self::output_partitioning_helper(&self.input, self.preserve_partitioning); self } @@ -207,6 +204,17 @@ impl PartialSortExec { input.boundedness(), )) } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + Self { + input: children.swap_remove(0), + metrics_set: ExecutionPlanMetricsSet::new(), + ..Self::clone(self) + } + } } impl DisplayAs for PartialSortExec { @@ -255,7 +263,7 @@ impl ExecutionPlan for PartialSortExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -283,6 +291,7 @@ impl ExecutionPlan for PartialSortExec { self: Arc, children: Vec>, ) -> Result> { + check_if_same_properties!(self, children); let new_partial_sort = PartialSortExec::new( self.expr.clone(), Arc::clone(&children[0]), @@ -329,10 +338,6 @@ impl ExecutionPlan for PartialSortExec { Some(self.metrics_set.clone_inner()) } - fn statistics(&self) -> Result { - self.input.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { self.input.partition_statistics(partition) } diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 3e8fdf1f3ed7..ae881dcd4b79 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -27,7 +27,9 @@ use std::sync::Arc; use parking_lot::RwLock; use crate::common::spawn_buffered; -use crate::execution_plan::{Boundedness, CardinalityEffect, EmissionType}; +use crate::execution_plan::{ + Boundedness, CardinalityEffect, EmissionType, has_same_children_properties, +}; use crate::expressions::PhysicalSortExpr; use 
crate::filter_pushdown::{ ChildFilterDescription, FilterDescription, FilterPushdownPhase, @@ -37,6 +39,7 @@ use crate::metrics::{ BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, SpillMetrics, }; use crate::projection::{ProjectionExec, make_with_child, update_ordering}; +use crate::sorts::IncrementalSortIterator; use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder}; use crate::spill::get_record_batch_memory_size; use crate::spill::in_progress_spill_file::InProgressSpillFile; @@ -709,7 +712,7 @@ impl ExternalSorter { &self, batch: RecordBatch, metrics: &BaselineMetrics, - mut reservation: MemoryReservation, + reservation: MemoryReservation, ) -> Result { assert_eq!( get_reserved_bytes_for_record_batch(&batch)?, @@ -726,39 +729,28 @@ impl ExternalSorter { // Sort the batch immediately and get all output batches let sorted_batches = sort_batch_chunked(&batch, &expressions, batch_size)?; - drop(batch); - // Free the old reservation and grow it to match the actual sorted output size - reservation.free(); + // Resize the reservation to match the actual sorted output size. + // Using try_resize avoids a release-then-reacquire cycle, which + // matters for MemoryPool implementations where grow/shrink have + // non-trivial cost (e.g. JNI calls in Comet). 
+ let total_sorted_size: usize = sorted_batches + .iter() + .map(get_record_batch_memory_size) + .sum(); + reservation + .try_resize(total_sorted_size) + .map_err(Self::err_with_oom_context)?; - Result::<_, DataFusionError>::Ok((schema, sorted_batches, reservation)) - }) - .then({ - move |batches| async move { - match batches { - Ok((schema, sorted_batches, mut reservation)) => { - // Calculate the total size of sorted batches - let total_sorted_size: usize = sorted_batches - .iter() - .map(get_record_batch_memory_size) - .sum(); - reservation - .try_grow(total_sorted_size) - .map_err(Self::err_with_oom_context)?; - - // Wrap in ReservationStream to hold the reservation - Ok(Box::pin(ReservationStream::new( - Arc::clone(&schema), - Box::pin(RecordBatchStreamAdapter::new( - schema, - futures::stream::iter(sorted_batches.into_iter().map(Ok)), - )), - reservation, - )) as SendableRecordBatchStream) - } - Err(e) => Err(e), - } - } + // Wrap in ReservationStream to hold the reservation + Result::<_, DataFusionError>::Ok(Box::pin(ReservationStream::new( + Arc::clone(&schema), + Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&schema), + futures::stream::iter(sorted_batches.into_iter().map(Ok)), + )), + reservation, + )) as SendableRecordBatchStream) }) .try_flatten() .map(move |batch| match batch { @@ -819,7 +811,8 @@ impl ExternalSorter { match e { DataFusionError::ResourcesExhausted(_) => e.context( "Not enough memory to continue external sort. \ - Consider increasing the memory limit, or decreasing sort_spill_reservation_bytes" + Consider increasing the memory limit config: 'datafusion.runtime.memory_limit', \ + or decreasing the config: 'datafusion.execution.sort_spill_reservation_bytes'." ), // This is not an OOM error, so just return it as is. _ => e, @@ -850,11 +843,13 @@ pub(crate) fn get_reserved_bytes_for_record_batch_size( /// Estimate how much memory is needed to sort a `RecordBatch`. 
/// This will just call `get_reserved_bytes_for_record_batch_size` with the /// memory size of the record batch and its sliced size. -pub(super) fn get_reserved_bytes_for_record_batch(batch: &RecordBatch) -> Result { - Ok(get_reserved_bytes_for_record_batch_size( - get_record_batch_memory_size(batch), - batch.get_sliced_size()?, - )) +pub(crate) fn get_reserved_bytes_for_record_batch(batch: &RecordBatch) -> Result { + batch.get_sliced_size().map(|sliced_size| { + get_reserved_bytes_for_record_batch_size( + get_record_batch_memory_size(batch), + sliced_size, + ) + }) } impl Debug for ExternalSorter { @@ -897,38 +892,7 @@ pub fn sort_batch_chunked( expressions: &LexOrdering, batch_size: usize, ) -> Result> { - let sort_columns = expressions - .iter() - .map(|expr| expr.evaluate_to_sort_column(batch)) - .collect::>>()?; - - let indices = lexsort_to_indices(&sort_columns, None)?; - - // Split indices into chunks of batch_size - let num_rows = indices.len(); - let num_chunks = num_rows.div_ceil(batch_size); - - let result_batches = (0..num_chunks) - .map(|chunk_idx| { - let start = chunk_idx * batch_size; - let end = (start + batch_size).min(num_rows); - let chunk_len = end - start; - - // Create a slice of indices for this chunk - let chunk_indices = indices.slice(start, chunk_len); - - // Take the columns using this chunk of indices - let columns = take_arrays(batch.columns(), &chunk_indices, None)?; - - let options = RecordBatchOptions::new().with_row_count(Some(chunk_len)); - let chunk_batch = - RecordBatch::try_new_with_options(batch.schema(), columns, &options)?; - - Ok(chunk_batch) - }) - .collect::>>()?; - - Ok(result_batches) + IncrementalSortIterator::new(batch.clone(), expressions.clone(), batch_size).collect() } /// Sort execution plan. 
@@ -951,7 +915,7 @@ pub struct SortExec { /// Normalized common sort prefix between the input and the sort expressions (only used with fetch) common_sort_prefix: Vec, /// Cache holding plan properties like equivalences, output partitioning etc. - cache: PlanProperties, + cache: Arc, /// Filter matching the state of the sort for dynamic filter pushdown. /// If `fetch` is `Some`, this will also be set and a TopK operator may be used. /// If `fetch` is `None`, this will be `None`. @@ -973,7 +937,7 @@ impl SortExec { preserve_partitioning, fetch: None, common_sort_prefix: sort_prefix, - cache, + cache: Arc::new(cache), filter: None, } } @@ -992,12 +956,8 @@ impl SortExec { /// input partitions producing a single, sorted partition. pub fn with_preserve_partitioning(mut self, preserve_partitioning: bool) -> Self { self.preserve_partitioning = preserve_partitioning; - self.cache = self - .cache - .with_partitioning(Self::output_partitioning_helper( - &self.input, - self.preserve_partitioning, - )); + Arc::make_mut(&mut self.cache).partitioning = + Self::output_partitioning_helper(&self.input, self.preserve_partitioning); self } @@ -1021,7 +981,7 @@ impl SortExec { preserve_partitioning: self.preserve_partitioning, common_sort_prefix: self.common_sort_prefix.clone(), fetch: self.fetch, - cache: self.cache.clone(), + cache: Arc::clone(&self.cache), filter: self.filter.clone(), } } @@ -1034,12 +994,12 @@ impl SortExec { /// operation since rows that are not going to be included /// can be dropped. pub fn with_fetch(&self, fetch: Option) -> Self { - let mut cache = self.cache.clone(); + let mut cache = PlanProperties::clone(&self.cache); // If the SortExec can emit incrementally (that means the sort requirements // and properties of the input match), the SortExec can generate its result // without scanning the entire input when a fetch value exists. 
let is_pipeline_friendly = matches!( - self.cache.emission_type, + cache.emission_type, EmissionType::Incremental | EmissionType::Both ); if fetch.is_some() && is_pipeline_friendly { @@ -1051,7 +1011,7 @@ impl SortExec { }); let mut new_sort = self.cloned(); new_sort.fetch = fetch; - new_sort.cache = cache; + new_sort.cache = cache.into(); new_sort.filter = filter; new_sort } @@ -1206,7 +1166,7 @@ impl ExecutionPlan for SortExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -1235,14 +1195,17 @@ impl ExecutionPlan for SortExec { let mut new_sort = self.cloned(); assert_eq!(children.len(), 1, "SortExec should have exactly one child"); new_sort.input = Arc::clone(&children[0]); - // Recompute the properties based on the new input since they may have changed - let (cache, sort_prefix) = Self::compute_properties( - &new_sort.input, - new_sort.expr.clone(), - new_sort.preserve_partitioning, - )?; - new_sort.cache = cache; - new_sort.common_sort_prefix = sort_prefix; + + if !has_same_children_properties(self.as_ref(), &children)? 
{ + // Recompute the properties based on the new input since they may have changed + let (cache, sort_prefix) = Self::compute_properties( + &new_sort.input, + new_sort.expr.clone(), + new_sort.preserve_partitioning, + )?; + new_sort.cache = Arc::new(cache); + new_sort.common_sort_prefix = sort_prefix; + } Ok(Arc::new(new_sort)) } @@ -1352,10 +1315,6 @@ impl ExecutionPlan for SortExec { Some(self.metrics_set.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { if !self.preserve_partitioning() { return self @@ -1414,12 +1373,23 @@ impl ExecutionPlan for SortExec { parent_filters: Vec>, config: &datafusion_common::config::ConfigOptions, ) -> Result { - if !matches!(phase, FilterPushdownPhase::Post) { + if phase != FilterPushdownPhase::Post { + if self.fetch.is_some() { + return Ok(FilterDescription::all_unsupported( + &parent_filters, + &self.children(), + )); + } return FilterDescription::from_children(parent_filters, &self.children()); } - let mut child = - ChildFilterDescription::from_child(&parent_filters, self.input())?; + // In Post phase: block parent filters when fetch is set, + // but still push the TopK dynamic filter (self-filter). + let mut child = if self.fetch.is_some() { + ChildFilterDescription::all_unsupported(&parent_filters) + } else { + ChildFilterDescription::from_child(&parent_filters, self.input())? 
+ }; if let Some(filter) = &self.filter && config.optimizer.enable_topk_dynamic_filter_pushdown @@ -1440,8 +1410,10 @@ mod tests { use super::*; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::collect; + use crate::empty::EmptyExec; use crate::execution_plan::Boundedness; use crate::expressions::col; + use crate::filter_pushdown::{FilterPushdownPhase, PushedDown}; use crate::test; use crate::test::TestMemoryExec; use crate::test::exec::{BlockingExec, assert_strong_count_converges_to_zero}; @@ -1451,6 +1423,7 @@ mod tests { use arrow::compute::SortOptions; use arrow::datatypes::*; use datafusion_common::cast::as_primitive_array; + use datafusion_common::config::ConfigOptions; use datafusion_common::test_util::batches_to_string; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_execution::RecordBatchStream; @@ -1466,7 +1439,7 @@ mod tests { pub struct SortedUnboundedExec { schema: Schema, batch_size: u64, - cache: PlanProperties, + cache: Arc, } impl DisplayAs for SortedUnboundedExec { @@ -1506,7 +1479,7 @@ mod tests { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -1736,6 +1709,21 @@ mod tests { "Assertion failed: expected a ResourcesExhausted error, but got: {err:?}" ); + // Verify external sorter error message when resource is exhausted + let config_vector = vec![ + "datafusion.runtime.memory_limit", + "datafusion.execution.sort_spill_reservation_bytes", + ]; + let error_message = err.message().to_string(); + for config in config_vector.into_iter() { + assert!( + error_message.as_str().contains(config), + "Config: '{}' should be contained in error message: {}.", + config, + error_message.as_str() + ); + } + Ok(()) } @@ -1756,7 +1744,7 @@ mod tests { // The input has 200 partitions, each partition has a batch containing 100 rows. // Each row has a single Utf8 column, the Utf8 string values are roughly 42 bytes. - // The total size of the input is roughly 8.4 KB. 
+ // The total size of the input is roughly 820 KB. let input = test::scan_partitioned_utf8(200); let schema = input.schema(); @@ -2259,7 +2247,9 @@ mod tests { let source = SortedUnboundedExec { schema: schema.clone(), batch_size: 2, - cache: SortedUnboundedExec::compute_properties(Arc::new(schema.clone())), + cache: Arc::new(SortedUnboundedExec::compute_properties(Arc::new( + schema.clone(), + ))), }; let mut plan = SortExec::new( [PhysicalSortExpr::new_default(Arc::new(Column::new( @@ -2715,4 +2705,68 @@ mod tests { Ok(()) } + + fn make_sort_exec_with_fetch(fetch: Option) -> SortExec { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let input = Arc::new(EmptyExec::new(schema)); + SortExec::new( + [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into(), + input, + ) + .with_fetch(fetch) + } + + #[test] + fn test_sort_with_fetch_blocks_filter_pushdown() -> Result<()> { + let sort = make_sort_exec_with_fetch(Some(10)); + let desc = sort.gather_filters_for_pushdown( + FilterPushdownPhase::Pre, + vec![Arc::new(Column::new("a", 0))], + &ConfigOptions::new(), + )?; + // Sort with fetch (TopK) must not allow filters to be pushed below it. + assert!(matches!( + desc.parent_filters()[0][0].discriminant, + PushedDown::No + )); + Ok(()) + } + + #[test] + fn test_sort_without_fetch_allows_filter_pushdown() -> Result<()> { + let sort = make_sort_exec_with_fetch(None); + let desc = sort.gather_filters_for_pushdown( + FilterPushdownPhase::Pre, + vec![Arc::new(Column::new("a", 0))], + &ConfigOptions::new(), + )?; + // Plain sort (no fetch) is filter-commutative. 
+ assert!(matches!( + desc.parent_filters()[0][0].discriminant, + PushedDown::Yes + )); + Ok(()) + } + + #[test] + fn test_sort_with_fetch_allows_topk_self_filter_in_post_phase() -> Result<()> { + let sort = make_sort_exec_with_fetch(Some(10)); + assert!(sort.filter.is_some(), "TopK filter should be created"); + + let mut config = ConfigOptions::new(); + config.optimizer.enable_topk_dynamic_filter_pushdown = true; + let desc = sort.gather_filters_for_pushdown( + FilterPushdownPhase::Post, + vec![Arc::new(Column::new("a", 0))], + &config, + )?; + // Parent filters are still blocked in the Post phase. + assert!(matches!( + desc.parent_filters()[0][0].discriminant, + PushedDown::No + )); + // But the TopK self-filter should be pushed down. + assert_eq!(desc.self_filters()[0].len(), 1); + Ok(()) + } } diff --git a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs index 4b26f8409950..763b72a66048 100644 --- a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs +++ b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs @@ -28,6 +28,7 @@ use crate::sorts::streaming_merge::StreamingMergeBuilder; use crate::{ DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, Partitioning, PlanProperties, SendableRecordBatchStream, Statistics, + check_if_same_properties, }; use datafusion_common::{Result, assert_eq_or_internal_err, internal_err}; @@ -93,7 +94,7 @@ pub struct SortPreservingMergeExec { /// Optional number of rows to fetch. Stops producing rows after this fetch fetch: Option, /// Cache holding plan properties like equivalences, output partitioning etc. - cache: PlanProperties, + cache: Arc, /// Use round-robin selection of tied winners of loser tree /// /// See [`Self::with_round_robin_repartition`] for more information. 
@@ -109,7 +110,7 @@ impl SortPreservingMergeExec { expr, metrics: ExecutionPlanMetricsSet::new(), fetch: None, - cache, + cache: Arc::new(cache), enable_round_robin_repartition: true, } } @@ -180,6 +181,17 @@ impl SortPreservingMergeExec { .with_evaluation_type(drive) .with_scheduling_type(scheduling) } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + Self { + input: children.swap_remove(0), + metrics: ExecutionPlanMetricsSet::new(), + ..Self::clone(self) + } + } } impl DisplayAs for SortPreservingMergeExec { @@ -225,7 +237,7 @@ impl ExecutionPlan for SortPreservingMergeExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -240,11 +252,24 @@ impl ExecutionPlan for SortPreservingMergeExec { expr: self.expr.clone(), metrics: self.metrics.clone(), fetch: limit, - cache: self.cache.clone(), + cache: Arc::clone(&self.cache), enable_round_robin_repartition: true, })) } + fn with_preserve_order( + &self, + preserve_order: bool, + ) -> Option> { + self.input + .with_preserve_order(preserve_order) + .and_then(|new_input| { + Arc::new(self.clone()) + .with_new_children(vec![new_input]) + .ok() + }) + } + fn required_input_distribution(&self) -> Vec { vec![Distribution::UnspecifiedDistribution] } @@ -267,10 +292,11 @@ impl ExecutionPlan for SortPreservingMergeExec { fn with_new_children( self: Arc, - children: Vec>, + mut children: Vec>, ) -> Result> { + check_if_same_properties!(self, children); Ok(Arc::new( - SortPreservingMergeExec::new(self.expr.clone(), Arc::clone(&children[0])) + SortPreservingMergeExec::new(self.expr.clone(), children.swap_remove(0)) .with_fetch(self.fetch), )) } @@ -359,10 +385,6 @@ impl ExecutionPlan for SortPreservingMergeExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.input.partition_statistics(None) - } - fn partition_statistics(&self, _partition: Option) -> Result { self.input.partition_statistics(None) } @@ -408,7 
+430,6 @@ mod tests { use std::time::Duration; use super::*; - use crate::coalesce_batches::CoalesceBatchesExec; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::execution_plan::{Boundedness, EmissionType}; use crate::expressions::col; @@ -444,11 +465,14 @@ mod tests { // The number in the function is highly related to the memory limit we are testing // any change of the constant should be aware of - fn generate_task_ctx_for_round_robin_tie_breaker() -> Result> { + fn generate_task_ctx_for_round_robin_tie_breaker( + target_batch_size: usize, + ) -> Result> { let runtime = RuntimeEnvBuilder::new() .with_memory_limit(20_000_000, 1.0) .build_arc()?; - let config = SessionConfig::new(); + let mut config = SessionConfig::new(); + config.options_mut().execution.batch_size = target_batch_size; let task_ctx = TaskContext::default() .with_runtime(runtime) .with_session_config(config); @@ -459,16 +483,14 @@ mod tests { fn generate_spm_for_round_robin_tie_breaker( enable_round_robin_repartition: bool, ) -> Result> { - let target_batch_size = 12500; let row_size = 12500; let a: ArrayRef = Arc::new(Int32Array::from(vec![1; row_size])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"); row_size])); let c: ArrayRef = Arc::new(Int64Array::from_iter(vec![0; row_size])); let rb = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)])?; - - let rbs = (0..1024).map(|_| rb.clone()).collect::>(); - let schema = rb.schema(); + + let rbs = std::iter::repeat_n(rb, 1024).collect::>(); let sort = [ PhysicalSortExpr { expr: col("b", &schema)?, @@ -485,9 +507,7 @@ mod tests { TestMemoryExec::try_new_exec(&[rbs], schema, None)?, Partitioning::RoundRobinBatch(2), )?; - let coalesce_batches_exec = - CoalesceBatchesExec::new(Arc::new(repartition_exec), target_batch_size); - let spm = SortPreservingMergeExec::new(sort, Arc::new(coalesce_batches_exec)) + let spm = SortPreservingMergeExec::new(sort, Arc::new(repartition_exec)) 
.with_round_robin_repartition(enable_round_robin_repartition); Ok(Arc::new(spm)) } @@ -499,7 +519,8 @@ mod tests { /// based on whether the tie breaker is enabled or disabled. #[tokio::test(flavor = "multi_thread")] async fn test_round_robin_tie_breaker_success() -> Result<()> { - let task_ctx = generate_task_ctx_for_round_robin_tie_breaker()?; + let target_batch_size = 12500; + let task_ctx = generate_task_ctx_for_round_robin_tie_breaker(target_batch_size)?; let spm = generate_spm_for_round_robin_tie_breaker(true)?; let _collected = collect(spm, task_ctx).await?; Ok(()) @@ -512,7 +533,7 @@ mod tests { /// based on whether the tie breaker is enabled or disabled. #[tokio::test(flavor = "multi_thread")] async fn test_round_robin_tie_breaker_fail() -> Result<()> { - let task_ctx = generate_task_ctx_for_round_robin_tie_breaker()?; + let task_ctx = generate_task_ctx_for_round_robin_tie_breaker(8192)?; let spm = generate_spm_for_round_robin_tie_breaker(false)?; let _err = collect(spm, task_ctx).await.unwrap_err(); Ok(()) @@ -1350,7 +1371,7 @@ mod tests { #[derive(Debug, Clone)] struct CongestedExec { schema: Schema, - cache: PlanProperties, + cache: Arc, congestion: Arc, } @@ -1386,7 +1407,7 @@ mod tests { fn as_any(&self) -> &dyn Any { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } fn children(&self) -> Vec<&Arc> { @@ -1479,7 +1500,7 @@ mod tests { }; let source = CongestedExec { schema: schema.clone(), - cache: properties, + cache: Arc::new(properties), congestion: Arc::new(Congestion::new(partition_count)), }; let spm = SortPreservingMergeExec::new( diff --git a/datafusion/physical-plan/src/sorts/stream.rs b/datafusion/physical-plan/src/sorts/stream.rs index a510f44e4f4d..ff7f259dd134 100644 --- a/datafusion/physical-plan/src/sorts/stream.rs +++ b/datafusion/physical-plan/src/sorts/stream.rs @@ -18,16 +18,20 @@ use crate::SendableRecordBatchStream; use crate::sorts::cursor::{ArrayValues, CursorArray, RowValues}; use 
crate::{PhysicalExpr, PhysicalSortExpr}; -use arrow::array::Array; +use arrow::array::{Array, UInt32Array}; +use arrow::compute::take_record_batch; use arrow::datatypes::Schema; use arrow::record_batch::RecordBatch; use arrow::row::{RowConverter, Rows, SortField}; +use arrow_ord::sort::lexsort_to_indices; use datafusion_common::{Result, internal_datafusion_err}; use datafusion_execution::memory_pool::MemoryReservation; use datafusion_physical_expr_common::sort_expr::LexOrdering; use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays; use futures::stream::{Fuse, StreamExt}; +use std::iter::FusedIterator; use std::marker::PhantomData; +use std::mem; use std::sync::Arc; use std::task::{Context, Poll, ready}; @@ -103,7 +107,7 @@ impl ReusableRows { self.inner[stream_idx][1] = Some(Arc::clone(rows)); // swap the current with the previous one, so that the next poll can reuse the Rows from the previous poll let [a, b] = &mut self.inner[stream_idx]; - std::mem::swap(a, b); + mem::swap(a, b); } } @@ -180,7 +184,7 @@ impl RowCursorStream { self.rows.save(stream_idx, &rows); // track the memory in the newly created Rows. 
- let mut rows_reservation = self.reservation.new_empty(); + let rows_reservation = self.reservation.new_empty(); rows_reservation.try_grow(rows.size())?; Ok(RowValues::new(rows, rows_reservation)) } @@ -246,7 +250,7 @@ impl FieldCursorStream { let array = value.into_array(batch.num_rows())?; let size_in_mem = array.get_buffer_memory_size(); let array = array.as_any().downcast_ref::().expect("field values"); - let mut array_reservation = self.reservation.new_empty(); + let array_reservation = self.reservation.new_empty(); array_reservation.try_grow(size_in_mem)?; Ok(ArrayValues::new( self.sort.options, @@ -276,3 +280,159 @@ impl PartitionedStream for FieldCursorStream { })) } } + +/// A lazy, memory-efficient sort iterator used as a fallback during aggregate +/// spill when there is not enough memory for an eager sort (which requires ~2x +/// peak memory to hold both the unsorted and sorted copies simultaneously). +/// +/// On the first call to `next()`, a sorted index array (`UInt32Array`) is +/// computed via `lexsort_to_indices`. Subsequent calls yield chunks of +/// `batch_size` rows by `take`-ing from the original batch using slices of +/// this index array. Each `take` copies data for the chunk (not zero-copy), +/// but only one chunk is live at a time since the caller consumes it before +/// requesting the next. Once all rows have been yielded, the original batch +/// and index array are dropped to free memory. +/// +/// The caller must reserve `sizeof(batch) + sizeof(one chunk)` for this iterator, +/// and free the reservation once the iterator is depleted. 
+pub(crate) struct IncrementalSortIterator { + batch: RecordBatch, + expressions: LexOrdering, + batch_size: usize, + indices: Option, + cursor: usize, +} + +impl IncrementalSortIterator { + pub(crate) fn new( + batch: RecordBatch, + expressions: LexOrdering, + batch_size: usize, + ) -> Self { + Self { + batch, + expressions, + batch_size, + cursor: 0, + indices: None, + } + } +} + +impl Iterator for IncrementalSortIterator { + type Item = Result; + + fn next(&mut self) -> Option { + if self.cursor >= self.batch.num_rows() { + return None; + } + + match self.indices.as_ref() { + None => { + let sort_columns = match self + .expressions + .iter() + .map(|expr| expr.evaluate_to_sort_column(&self.batch)) + .collect::>>() + { + Ok(cols) => cols, + Err(e) => return Some(Err(e)), + }; + + let indices = match lexsort_to_indices(&sort_columns, None) { + Ok(indices) => indices, + Err(e) => return Some(Err(e.into())), + }; + self.indices = Some(indices); + + // Call again, this time it will hit the Some(indices) branch and return the first batch + self.next() + } + Some(indices) => { + let batch_size = self.batch_size.min(self.batch.num_rows() - self.cursor); + + // Perform the take to produce the next batch + let new_batch_indices = indices.slice(self.cursor, batch_size); + let new_batch = match take_record_batch(&self.batch, &new_batch_indices) { + Ok(batch) => batch, + Err(e) => return Some(Err(e.into())), + }; + + self.cursor += batch_size; + + // If this is the last batch, we can release the memory + if self.cursor >= self.batch.num_rows() { + let schema = self.batch.schema(); + let _ = mem::replace(&mut self.batch, RecordBatch::new_empty(schema)); + self.indices = None; + } + + // Return the new batch + Some(Ok(new_batch)) + } + } + } + + fn size_hint(&self) -> (usize, Option) { + let num_rows = self.batch.num_rows(); + let batch_size = self.batch_size; + let num_batches = num_rows.div_ceil(batch_size); + (num_batches, Some(num_batches)) + } +} + +impl FusedIterator for 
IncrementalSortIterator {} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{AsArray, Int32Array}; + use arrow::datatypes::{DataType, Field, Int32Type}; + use datafusion_common::DataFusionError; + use datafusion_physical_expr::expressions::col; + + /// Verifies that `take_record_batch` in `IncrementalSortIterator` actually + /// copies the data into a new allocation rather than returning a zero-copy + /// slice of the original batch. If the output arrays were slices, their + /// underlying buffer length would match the original array's length; a true + /// copy will have a buffer sized to fit only the chunk. + #[test] + fn incremental_sort_iterator_copies_data() -> Result<()> { + let original_len = 10; + let batch_size = 3; + + // Build a batch with a single Int32 column of descending values + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let col_a: Int32Array = Int32Array::from(vec![0; original_len]); + let batch = RecordBatch::try_new(schema, vec![Arc::new(col_a)])?; + + // Sort ascending on column "a" + let expressions = LexOrdering::new(vec![PhysicalSortExpr::new_default(col( + "a", + &batch.schema(), + )?)]) + .unwrap(); + + let mut total_rows = 0; + IncrementalSortIterator::new(batch.clone(), expressions, batch_size).try_for_each( + |result| { + let chunk = result?; + total_rows += chunk.num_rows(); + + // Every output column must be a fresh allocation whose length + // equals the chunk size, NOT the original array length. 
+ chunk.columns().iter().zip(batch.columns()).for_each(|(arr, original_arr)| { + let (_, scalar_buf, _) = arr.as_primitive::().clone().into_parts(); + let (_, original_scalar_buf, _) = original_arr.as_primitive::().clone().into_parts(); + + assert_ne!(scalar_buf.inner().data_ptr(), original_scalar_buf.inner().data_ptr(), "Expected a copy of the data for each chunk, but got a slice that shares the same buffer as the original array"); + }); + + Result::<_, DataFusionError>::Ok(()) + }, + )?; + + assert_eq!(total_rows, original_len); + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/spill/in_progress_spill_file.rs b/datafusion/physical-plan/src/spill/in_progress_spill_file.rs index d2acf4993b85..9084ea449d6b 100644 --- a/datafusion/physical-plan/src/spill/in_progress_spill_file.rs +++ b/datafusion/physical-plan/src/spill/in_progress_spill_file.rs @@ -62,8 +62,12 @@ impl InProgressSpillFile { )); } if self.writer.is_none() { - let schema = batch.schema(); - if let Some(ref in_progress_file) = self.in_progress_file { + // Use the SpillManager's declared schema rather than the batch's schema. + // Individual batches may have different schemas (e.g., different nullability) + // when they come from different branches of a UnionExec. The SpillManager's + // schema represents the canonical schema that all batches should conform to. 
+ let schema = self.spill_writer.schema(); + if let Some(in_progress_file) = &mut self.in_progress_file { self.writer = Some(IPCStreamWriter::new( in_progress_file.path(), schema.as_ref(), @@ -72,18 +76,38 @@ impl InProgressSpillFile { // Update metrics self.spill_writer.metrics.spill_file_count.add(1); + + // Update initial size (schema/header) + in_progress_file.update_disk_usage()?; + let initial_size = in_progress_file.current_disk_usage(); + self.spill_writer + .metrics + .spilled_bytes + .add(initial_size as usize); } } if let Some(writer) = &mut self.writer { let (spilled_rows, _) = writer.write(batch)?; if let Some(in_progress_file) = &mut self.in_progress_file { + let pre_size = in_progress_file.current_disk_usage(); in_progress_file.update_disk_usage()?; + let post_size = in_progress_file.current_disk_usage(); + + self.spill_writer.metrics.spilled_rows.add(spilled_rows); + self.spill_writer + .metrics + .spilled_bytes + .add((post_size - pre_size) as usize); } else { unreachable!() // Already checked inside current function } + } + Ok(()) + } - // Update metrics - self.spill_writer.metrics.spilled_rows.add(spilled_rows); + pub fn flush(&mut self) -> Result<()> { + if let Some(writer) = &mut self.writer { + writer.flush()?; } Ok(()) } @@ -106,11 +130,89 @@ impl InProgressSpillFile { // Since spill files are append-only, add the file size to spilled_bytes if let Some(in_progress_file) = &mut self.in_progress_file { // Since writer.finish() writes continuation marker and message length at the end + let pre_size = in_progress_file.current_disk_usage(); in_progress_file.update_disk_usage()?; - let size = in_progress_file.current_disk_usage(); - self.spill_writer.metrics.spilled_bytes.add(size as usize); + let post_size = in_progress_file.current_disk_usage(); + self.spill_writer + .metrics + .spilled_bytes + .add((post_size - pre_size) as usize); } Ok(self.in_progress_file.take()) } } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::Int64Array; 
+ use arrow_schema::{DataType, Field, Schema}; + use datafusion_execution::runtime_env::RuntimeEnvBuilder; + use datafusion_physical_expr_common::metrics::{ + ExecutionPlanMetricsSet, SpillMetrics, + }; + use futures::TryStreamExt; + + #[tokio::test] + async fn test_spill_file_uses_spill_manager_schema() -> Result<()> { + let nullable_schema = Arc::new(Schema::new(vec![ + Field::new("key", DataType::Int64, false), + Field::new("val", DataType::Int64, true), + ])); + let non_nullable_schema = Arc::new(Schema::new(vec![ + Field::new("key", DataType::Int64, false), + Field::new("val", DataType::Int64, false), + ])); + + let runtime = Arc::new(RuntimeEnvBuilder::new().build()?); + let metrics_set = ExecutionPlanMetricsSet::new(); + let spill_metrics = SpillMetrics::new(&metrics_set, 0); + let spill_manager = Arc::new(SpillManager::new( + runtime, + spill_metrics, + Arc::clone(&nullable_schema), + )); + + let mut in_progress = spill_manager.create_in_progress_file("test")?; + + // First batch: non-nullable val (simulates literal-0 UNION branch) + let non_nullable_batch = RecordBatch::try_new( + Arc::clone(&non_nullable_schema), + vec![ + Arc::new(Int64Array::from(vec![1, 2, 3])), + Arc::new(Int64Array::from(vec![0, 0, 0])), + ], + )?; + in_progress.append_batch(&non_nullable_batch)?; + + // Second batch: nullable val with NULLs (simulates table UNION branch) + let nullable_batch = RecordBatch::try_new( + Arc::clone(&nullable_schema), + vec![ + Arc::new(Int64Array::from(vec![4, 5, 6])), + Arc::new(Int64Array::from(vec![Some(10), None, Some(30)])), + ], + )?; + in_progress.append_batch(&nullable_batch)?; + + let spill_file = in_progress.finish()?.unwrap(); + + let stream = spill_manager.read_spill_as_stream(spill_file, None)?; + + // Stream schema should be nullable + assert_eq!(stream.schema(), nullable_schema); + + let batches = stream.try_collect::>().await?; + assert_eq!(batches.len(), 2); + + // Both batches must have the SpillManager's nullable schema + assert_eq!( 
+ batches[0], + non_nullable_batch.with_schema(Arc::clone(&nullable_schema))? + ); + assert_eq!(batches[1], nullable_batch); + + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/spill/mod.rs b/datafusion/physical-plan/src/spill/mod.rs index 78dea99ac820..f6ce546a4223 100644 --- a/datafusion/physical-plan/src/spill/mod.rs +++ b/datafusion/physical-plan/src/spill/mod.rs @@ -49,7 +49,7 @@ use datafusion_common_runtime::SpawnedTask; use datafusion_execution::RecordBatchStream; use datafusion_execution::disk_manager::RefCountedTempFile; use futures::{FutureExt as _, Stream}; -use log::warn; +use log::debug; /// Stream that reads spill files from disk where each batch is read in a spawned blocking task /// It will read one batch at a time and will not do any buffering, to buffer data use [`crate::common::spawn_buffered`] @@ -154,7 +154,7 @@ impl SpillReaderStream { > max_record_batch_memory + SPILL_BATCH_MEMORY_MARGIN { - warn!( + debug!( "Record batch memory usage ({actual_size} bytes) exceeds the expected limit ({max_record_batch_memory} bytes) \n\ by more than the allowed tolerance ({SPILL_BATCH_MEMORY_MARGIN} bytes).\n\ This likely indicates a bug in memory accounting during spilling.\n\ @@ -310,6 +310,11 @@ impl IPCStreamWriter { Ok((delta_num_rows, delta_num_bytes)) } + pub fn flush(&mut self) -> Result<()> { + self.writer.flush()?; + Ok(()) + } + /// Finish the writer pub fn finish(&mut self) -> Result<()> { self.writer.finish().map_err(Into::into) @@ -472,11 +477,12 @@ mod tests { let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); let spill_manager = SpillManager::new(env, metrics, Arc::clone(&schema)); + let row_batches: Vec = + (0..batch1.num_rows()).map(|i| batch1.slice(i, 1)).collect(); let (spill_file, max_batch_mem) = spill_manager - .spill_record_batch_by_size_and_return_max_batch_memory( - &batch1, + .spill_record_batch_iter_and_return_max_batch_memory( + row_batches.iter().map(Ok), "Test Spill", - 1, )? 
.unwrap(); assert!(spill_file.path().exists()); @@ -685,13 +691,13 @@ mod tests { Arc::new(StringArray::from(vec!["d", "e", "f"])), ], )?; - // After appending each batch, spilled_rows should increase, while spill_file_count and - // spilled_bytes remain the same (spilled_bytes is updated only after finish() is called) + // After appending each batch, spilled_rows and spilled_bytes should increase incrementally, + // while spill_file_count remains 1 (since we're writing to the same file) in_progress_file.append_batch(&batch1)?; - verify_metrics(&in_progress_file, 1, 0, 3)?; + verify_metrics(&in_progress_file, 1, 440, 3)?; in_progress_file.append_batch(&batch2)?; - verify_metrics(&in_progress_file, 1, 0, 6)?; + verify_metrics(&in_progress_file, 1, 704, 6)?; let completed_file = in_progress_file.finish()?; assert!(completed_file.is_some()); @@ -726,7 +732,7 @@ mod tests { let completed_file = spill_manager.spill_record_batch_and_finish(&[], "Test")?; assert!(completed_file.is_none()); - // Test write empty batch with interface `spill_record_batch_by_size_and_return_max_batch_memory()` + // Test write empty batch with interface `spill_record_batch_iter_and_return_max_batch_memory()` let empty_batch = RecordBatch::try_new( Arc::clone(&schema), vec![ @@ -735,10 +741,9 @@ mod tests { ], )?; let completed_file = spill_manager - .spill_record_batch_by_size_and_return_max_batch_memory( - &empty_batch, + .spill_record_batch_iter_and_return_max_batch_memory( + std::iter::once(Ok(&empty_batch)), "Test", - 1, )?; assert!(completed_file.is_none()); @@ -799,4 +804,70 @@ mod tests { assert_eq!(alignment, 8); Ok(()) } + #[tokio::test] + async fn test_real_time_spill_metrics() -> Result<()> { + let env = Arc::new(RuntimeEnv::default()); + let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), + ])); + + let spill_manager = 
Arc::new(SpillManager::new( + Arc::clone(&env), + metrics.clone(), + Arc::clone(&schema), + )); + let mut in_progress_file = spill_manager.create_in_progress_file("Test")?; + + let batch1 = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec!["a", "b", "c"])), + ], + )?; + + // Before any batch, metrics should be 0 + assert_eq!(metrics.spilled_bytes.value(), 0); + assert_eq!(metrics.spill_file_count.value(), 0); + + // Append first batch + in_progress_file.append_batch(&batch1)?; + + // Metrics should be updated immediately (at least schema and first batch) + let bytes_after_batch1 = metrics.spilled_bytes.value(); + assert_eq!(bytes_after_batch1, 440); + assert_eq!(metrics.spill_file_count.value(), 1); + + // Check global progress + let progress = env.spilling_progress(); + assert_eq!(progress.current_bytes, bytes_after_batch1 as u64); + assert_eq!(progress.active_files_count, 1); + + // Append another batch + in_progress_file.append_batch(&batch1)?; + let bytes_after_batch2 = metrics.spilled_bytes.value(); + assert!(bytes_after_batch2 > bytes_after_batch1); + + // Check global progress again + let progress = env.spilling_progress(); + assert_eq!(progress.current_bytes, bytes_after_batch2 as u64); + + // Finish the file + let spilled_file = in_progress_file.finish()?; + let final_bytes = metrics.spilled_bytes.value(); + assert!(final_bytes > bytes_after_batch2); + + // Even after finish, file is still "active" until dropped + let progress = env.spilling_progress(); + assert!(progress.current_bytes > 0); + assert_eq!(progress.active_files_count, 1); + + drop(spilled_file); + assert_eq!(env.spilling_progress().active_files_count, 0); + assert_eq!(env.spilling_progress().current_bytes, 0); + + Ok(()) + } } diff --git a/datafusion/physical-plan/src/spill/spill_manager.rs b/datafusion/physical-plan/src/spill/spill_manager.rs index 89b027620677..07ba6d3989bc 100644 --- 
a/datafusion/physical-plan/src/spill/spill_manager.rs +++ b/datafusion/physical-plan/src/spill/spill_manager.rs @@ -17,19 +17,20 @@ //! Define the `SpillManager` struct, which is responsible for reading and writing `RecordBatch`es to raw files based on the provided configurations. +use super::{SpillReaderStream, in_progress_spill_file::InProgressSpillFile}; +use crate::coop::cooperative; +use crate::{common::spawn_buffered, metrics::SpillMetrics}; use arrow::array::StringViewArray; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; -use datafusion_common::{Result, config::SpillCompression}; +use datafusion_common::utils::memory::get_record_batch_memory_size; +use datafusion_common::{DataFusionError, Result, config::SpillCompression}; use datafusion_execution::SendableRecordBatchStream; use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::runtime_env::RuntimeEnv; +use std::borrow::Borrow; use std::sync::Arc; -use super::{SpillReaderStream, in_progress_spill_file::InProgressSpillFile}; -use crate::coop::cooperative; -use crate::{common::spawn_buffered, metrics::SpillMetrics}; - /// The `SpillManager` is responsible for the following tasks: /// - Reading and writing `RecordBatch`es to raw files based on the provided configurations. /// - Updating the associated metrics. @@ -109,39 +110,29 @@ impl SpillManager { in_progress_file.finish() } - /// Refer to the documentation for [`Self::spill_record_batch_and_finish`]. This method - /// additionally spills the `RecordBatch` into smaller batches, divided by `row_limit`. 
- /// - /// # Errors - /// - Returns an error if spilling would exceed the disk usage limit configured - /// by `max_temp_directory_size` in `DiskManager` - pub(crate) fn spill_record_batch_by_size_and_return_max_batch_memory( + /// Spill an iterator of `RecordBatch`es to disk and return the spill file and the size of the largest batch in memory + /// Note that this expects the caller to provide *non-sliced* batches, so the memory calculation of each batch is accurate. + pub(crate) fn spill_record_batch_iter_and_return_max_batch_memory( &self, - batch: &RecordBatch, + mut iter: impl Iterator>>, request_description: &str, - row_limit: usize, ) -> Result> { - let total_rows = batch.num_rows(); - let mut batches = Vec::new(); - let mut offset = 0; - - // It's ok to calculate all slices first, because slicing is zero-copy. - while offset < total_rows { - let length = std::cmp::min(total_rows - offset, row_limit); - let sliced_batch = batch.slice(offset, length); - batches.push(sliced_batch); - offset += length; - } - let mut in_progress_file = self.create_in_progress_file(request_description)?; let mut max_record_batch_size = 0; - for batch in batches { - in_progress_file.append_batch(&batch)?; + iter.try_for_each(|batch| { + let batch = batch?; + let borrowed = batch.borrow(); + if borrowed.num_rows() == 0 { + return Ok(()); + } + in_progress_file.append_batch(borrowed)?; - max_record_batch_size = max_record_batch_size.max(batch.get_sliced_size()?); - } + max_record_batch_size = + max_record_batch_size.max(get_record_batch_memory_size(borrowed)); + Result::<_, DataFusionError>::Ok(()) + })?; let file = in_progress_file.finish()?; @@ -188,6 +179,19 @@ impl SpillManager { Ok(spawn_buffered(stream, self.batch_read_buffer_capacity)) } + + /// Same as `read_spill_as_stream`, but without buffering. 
+ pub fn read_spill_as_stream_unbuffered( + &self, + spill_file_path: RefCountedTempFile, + max_record_batch_memory: Option, + ) -> Result { + Ok(Box::pin(cooperative(SpillReaderStream::new( + Arc::clone(&self.schema), + spill_file_path, + max_record_batch_memory, + )))) + } } pub(crate) trait GetSlicedSize { diff --git a/datafusion/physical-plan/src/spill/spill_pool.rs b/datafusion/physical-plan/src/spill/spill_pool.rs index e3b547b5731f..2777b753bb37 100644 --- a/datafusion/physical-plan/src/spill/spill_pool.rs +++ b/datafusion/physical-plan/src/spill/spill_pool.rs @@ -61,6 +61,10 @@ struct SpillPoolShared { /// Writer's reference to the current file (shared by all cloned writers). /// Has its own lock to allow I/O without blocking queue access. current_write_file: Option>>, + /// Number of active writer clones. Only when this reaches zero should + /// `writer_dropped` be set to true. This prevents premature EOF signaling + /// when one writer clone is dropped while others are still active. + active_writer_count: usize, } impl SpillPoolShared { @@ -72,6 +76,7 @@ impl SpillPoolShared { waker: None, writer_dropped: false, current_write_file: None, + active_writer_count: 1, } } @@ -97,7 +102,6 @@ impl SpillPoolShared { /// The writer automatically manages file rotation based on the `max_file_size_bytes` /// configured in [`channel`]. When the last writer clone is dropped, it finalizes the /// current file so readers can access all written data. -#[derive(Clone)] pub struct SpillPoolWriter { /// Maximum size in bytes before rotating to a new file. /// Typically set from configuration `datafusion.execution.max_spill_file_size_bytes`. @@ -106,6 +110,18 @@ pub struct SpillPoolWriter { shared: Arc>, } +impl Clone for SpillPoolWriter { + fn clone(&self) -> Self { + // Increment the active writer count so that `writer_dropped` is only + // set to true when the *last* clone is dropped. 
+ self.shared.lock().active_writer_count += 1; + Self { + max_file_size_bytes: self.max_file_size_bytes, + shared: Arc::clone(&self.shared), + } + } +} + impl SpillPoolWriter { /// Spills a batch to the pool, rotating files when necessary. /// @@ -194,6 +210,8 @@ impl SpillPoolWriter { // Append the batch if let Some(ref mut writer) = file_shared.writer { writer.append_batch(batch)?; + // make sure we flush the writer for readers + writer.flush()?; file_shared.batches_written += 1; file_shared.estimated_size += batch_size; } @@ -231,6 +249,15 @@ impl Drop for SpillPoolWriter { fn drop(&mut self) { let mut shared = self.shared.lock(); + shared.active_writer_count -= 1; + let is_last_writer = shared.active_writer_count == 0; + + if !is_last_writer { + // Other writer clones are still active; do not finalize or + // signal EOF to readers. + return; + } + // Finalize the current file when the last writer is dropped if let Some(current_file) = shared.current_write_file.take() { // Release shared lock before locking file @@ -535,7 +562,11 @@ impl Stream for SpillFile { // Step 2: Lazy-create reader stream if needed if self.reader.is_none() && should_read { if let Some(file) = file { - match self.spill_manager.read_spill_as_stream(file, None) { + // we want this unbuffered because files are actively being written to + match self + .spill_manager + .read_spill_as_stream_unbuffered(file, None) + { Ok(stream) => { self.reader = Some(SpillFileReader { stream, @@ -879,8 +910,8 @@ mod tests { ); assert_eq!( metrics.spilled_bytes.value(), - 0, - "Spilled bytes should be 0 before file finalization" + 320, + "Spilled bytes should reflect data written (header + 1 batch)" ); assert_eq!( metrics.spilled_rows.value(), @@ -1300,11 +1331,11 @@ mod tests { writer.push_batch(&batch)?; } - // Check metrics before drop - spilled_bytes should be 0 since file isn't finalized yet + // Check metrics before drop - spilled_bytes already reflects written data let spilled_bytes_before = 
metrics.spilled_bytes.value(); assert_eq!( - spilled_bytes_before, 0, - "Spilled bytes should be 0 before writer is dropped" + spilled_bytes_before, 1088, + "Spilled bytes should reflect data written (header + 5 batches)" ); // Explicitly drop the writer - this should finalize the current file @@ -1337,6 +1368,81 @@ mod tests { Ok(()) } + /// Verifies that the reader stays alive as long as any writer clone exists. + /// + /// `SpillPoolWriter` is `Clone`, and in non-preserve-order repartitioning + /// mode multiple input partition tasks share clones of the same writer. + /// The reader must not see EOF until **all** clones have been dropped, + /// even if the queue is temporarily empty between writes from different + /// clones. + /// + /// The test sequence is: + /// + /// 1. writer1 writes a batch, then is dropped. + /// 2. The reader consumes that batch (queue is now empty). + /// 3. writer2 (still alive) writes a batch. + /// 4. The reader must see that batch. + /// 5. EOF is only signalled after writer2 is also dropped. + #[tokio::test] + async fn test_clone_drop_does_not_signal_eof_prematurely() -> Result<()> { + let (writer1, mut reader) = create_spill_channel(1024 * 1024); + let writer2 = writer1.clone(); + + // Synchronization: tell writer2 when it may proceed. + let (proceed_tx, proceed_rx) = tokio::sync::oneshot::channel::<()>(); + + // Spawn writer2 — it waits for the signal before writing. + let writer2_handle = SpawnedTask::spawn(async move { + proceed_rx.await.unwrap(); + writer2.push_batch(&create_test_batch(10, 10)).unwrap(); + // writer2 is dropped here (last clone → true EOF) + }); + + // Writer1 writes one batch, then drops. + writer1.push_batch(&create_test_batch(0, 10))?; + drop(writer1); + + // Read writer1's batch. 
+ let batch1 = reader.next().await.unwrap()?; + assert_eq!(batch1.num_rows(), 10); + let col = batch1 + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(col.value(0), 0); + + // Signal writer2 to write its batch. It will execute when the + // current task yields (i.e. when reader.next() returns Pending). + proceed_tx.send(()).unwrap(); + + // The reader should wait (Pending) for writer2's data, not EOF. + let batch2 = + tokio::time::timeout(std::time::Duration::from_secs(5), reader.next()) + .await + .expect("Reader timed out — should not hang"); + + assert!( + batch2.is_some(), + "Reader must not return EOF while a writer clone is still alive" + ); + let batch2 = batch2.unwrap()?; + assert_eq!(batch2.num_rows(), 10); + let col = batch2 + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(col.value(0), 10); + + writer2_handle.await.unwrap(); + + // All writers dropped — reader should see real EOF now. + assert!(reader.next().await.is_none()); + + Ok(()) + } + #[tokio::test] async fn test_disk_usage_decreases_as_files_consumed() -> Result<()> { use datafusion_execution::runtime_env::RuntimeEnvBuilder; diff --git a/datafusion/physical-plan/src/stream.rs b/datafusion/physical-plan/src/stream.rs index 80c2233d05db..4b7e707fcced 100644 --- a/datafusion/physical-plan/src/stream.rs +++ b/datafusion/physical-plan/src/stream.rs @@ -1005,7 +1005,7 @@ mod test { .build_arc() .unwrap(); - let mut reservation = MemoryConsumer::new("test").register(&runtime.memory_pool); + let reservation = MemoryConsumer::new("test").register(&runtime.memory_pool); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); @@ -1071,7 +1071,7 @@ mod test { .build_arc() .unwrap(); - let mut reservation = MemoryConsumer::new("test").register(&runtime.memory_pool); + let reservation = MemoryConsumer::new("test").register(&runtime.memory_pool); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); diff --git 
a/datafusion/physical-plan/src/streaming.rs b/datafusion/physical-plan/src/streaming.rs index c8b8d95718cb..153548237411 100644 --- a/datafusion/physical-plan/src/streaming.rs +++ b/datafusion/physical-plan/src/streaming.rs @@ -67,7 +67,7 @@ pub struct StreamingTableExec { projected_output_ordering: Vec, infinite: bool, limit: Option, - cache: PlanProperties, + cache: Arc, metrics: ExecutionPlanMetricsSet, } @@ -111,7 +111,7 @@ impl StreamingTableExec { projected_output_ordering, infinite, limit, - cache, + cache: Arc::new(cache), metrics: ExecutionPlanMetricsSet::new(), }) } @@ -236,7 +236,7 @@ impl ExecutionPlan for StreamingTableExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -335,7 +335,7 @@ impl ExecutionPlan for StreamingTableExec { projected_output_ordering: self.projected_output_ordering.clone(), infinite: self.infinite, limit, - cache: self.cache.clone(), + cache: Arc::clone(&self.cache), metrics: self.metrics.clone(), })) } diff --git a/datafusion/physical-plan/src/test.rs b/datafusion/physical-plan/src/test.rs index c94b5a413139..0e7b900eb6fc 100644 --- a/datafusion/physical-plan/src/test.rs +++ b/datafusion/physical-plan/src/test.rs @@ -75,7 +75,7 @@ pub struct TestMemoryExec { /// The maximum number of records to read from this plan. If `None`, /// all records after filtering are returned. 
fetch: Option, - cache: PlanProperties, + cache: Arc, } impl DisplayAs for TestMemoryExec { @@ -134,7 +134,7 @@ impl ExecutionPlan for TestMemoryExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -146,7 +146,7 @@ impl ExecutionPlan for TestMemoryExec { self: Arc, _: Vec>, ) -> Result> { - unimplemented!() + Ok(self) } fn repartitioned( @@ -169,10 +169,6 @@ impl ExecutionPlan for TestMemoryExec { unimplemented!() } - fn statistics(&self) -> Result { - self.statistics_inner() - } - fn partition_statistics(&self, partition: Option) -> Result { if partition.is_some() { Ok(Statistics::new_unknown(&self.schema)) @@ -239,7 +235,7 @@ impl TestMemoryExec { Ok(Self { partitions: partitions.to_vec(), schema, - cache: PlanProperties::new( + cache: Arc::new(PlanProperties::new( EquivalenceProperties::new_with_orderings( Arc::clone(&projected_schema), Vec::::new(), @@ -247,7 +243,7 @@ impl TestMemoryExec { Partitioning::UnknownPartitioning(partitions.len()), EmissionType::Incremental, Boundedness::Bounded, - ), + )), projected_schema, projection, sort_information: vec![], @@ -265,7 +261,7 @@ impl TestMemoryExec { ) -> Result> { let mut source = Self::try_new(partitions, schema, projection)?; let cache = source.compute_properties(); - source.cache = cache; + source.cache = Arc::new(cache); Ok(Arc::new(source)) } @@ -273,7 +269,7 @@ impl TestMemoryExec { pub fn update_cache(source: &Arc) -> TestMemoryExec { let cache = source.compute_properties(); let mut source = (**source).clone(); - source.cache = cache; + source.cache = Arc::new(cache); source } @@ -342,7 +338,7 @@ impl TestMemoryExec { } self.sort_information = sort_information; - self.cache = self.compute_properties(); + self.cache = Arc::new(self.compute_properties()); Ok(self) } diff --git a/datafusion/physical-plan/src/test/exec.rs b/datafusion/physical-plan/src/test/exec.rs index 4507cccba05a..d628fb819f85 100644 --- a/datafusion/physical-plan/src/test/exec.rs +++ 
b/datafusion/physical-plan/src/test/exec.rs @@ -17,13 +17,6 @@ //! Simple iterator over batches for use in testing -use std::{ - any::Any, - pin::Pin, - sync::{Arc, Weak}, - task::{Context, Poll}, -}; - use crate::{ DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, common, @@ -33,6 +26,13 @@ use crate::{ execution_plan::EmissionType, stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter}, }; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::{ + any::Any, + pin::Pin, + sync::{Arc, Weak}, + task::{Context, Poll}, +}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; @@ -125,7 +125,7 @@ pub struct MockExec { /// if true (the default), sends data using a separate task to ensure the /// batches are not available without this stream yielding first use_task: bool, - cache: PlanProperties, + cache: Arc, } impl MockExec { @@ -142,7 +142,7 @@ impl MockExec { data, schema, use_task: true, - cache, + cache: Arc::new(cache), } } @@ -192,7 +192,7 @@ impl ExecutionPlan for MockExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -254,10 +254,6 @@ impl ExecutionPlan for MockExec { } // Panics if one of the batches is an error - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { if partition.is_some() { return Ok(Statistics::new_unknown(&self.schema)); @@ -298,29 +294,91 @@ pub struct BarrierExec { schema: SchemaRef, /// all streams wait on this barrier to produce - barrier: Arc, - cache: PlanProperties, + start_data_barrier: Option>, + + /// the stream wait for this to return Poll::Ready(None) + finish_barrier: Option>, + + cache: Arc, + + log: bool, } impl BarrierExec { /// Create a new exec with some number of partitions. 
pub fn new(data: Vec>, schema: SchemaRef) -> Self { // wait for all streams and the input - let barrier = Arc::new(Barrier::new(data.len() + 1)); + let barrier = Some(Arc::new(Barrier::new(data.len() + 1))); let cache = Self::compute_properties(Arc::clone(&schema), &data); Self { data, schema, - barrier, - cache, + start_data_barrier: barrier, + cache: Arc::new(cache), + finish_barrier: None, + log: true, } } + pub fn with_log(mut self, log: bool) -> Self { + self.log = log; + self + } + + pub fn without_start_barrier(mut self) -> Self { + self.start_data_barrier = None; + self + } + + pub fn with_finish_barrier(mut self) -> Self { + let barrier = Arc::new(( + // wait for all streams and the input + Barrier::new(self.data.len() + 1), + AtomicUsize::new(0), + )); + + self.finish_barrier = Some(barrier); + self + } + /// wait until all the input streams and this function is ready pub async fn wait(&self) { - println!("BarrierExec::wait waiting on barrier"); - self.barrier.wait().await; - println!("BarrierExec::wait done waiting"); + let barrier = &self + .start_data_barrier + .as_ref() + .expect("Must only be called when having a start barrier"); + if self.log { + println!("BarrierExec::wait waiting on barrier"); + } + barrier.wait().await; + if self.log { + println!("BarrierExec::wait done waiting"); + } + } + + pub async fn wait_finish(&self) { + let (barrier, _) = &self + .finish_barrier + .as_deref() + .expect("Must only be called when having a finish barrier"); + + if self.log { + println!("BarrierExec::wait_finish waiting on barrier"); + } + barrier.wait().await; + if self.log { + println!("BarrierExec::wait_finish done waiting"); + } + } + + /// Return true if the finish barrier has been reached in all partitions + pub fn is_finish_barrier_reached(&self) -> bool { + let (_, reached_finish) = self + .finish_barrier + .as_deref() + .expect("Must only be called when having finish barrier"); + + reached_finish.load(Ordering::Relaxed) == self.data.len() } /// This 
function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. @@ -364,7 +422,7 @@ impl ExecutionPlan for BarrierExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -391,17 +449,32 @@ impl ExecutionPlan for BarrierExec { // task simply sends data in order after barrier is reached let data = self.data[partition].clone(); - let b = Arc::clone(&self.barrier); + let start_barrier = self.start_data_barrier.as_ref().map(Arc::clone); + let finish_barrier = self.finish_barrier.as_ref().map(Arc::clone); + let log = self.log; let tx = builder.tx(); builder.spawn(async move { - println!("Partition {partition} waiting on barrier"); - b.wait().await; + if let Some(barrier) = start_barrier { + if log { + println!("Partition {partition} waiting on barrier"); + } + barrier.wait().await; + } for batch in data { - println!("Partition {partition} sending batch"); + if log { + println!("Partition {partition} sending batch"); + } if let Err(e) = tx.send(Ok(batch)).await { println!("ERROR batch via barrier stream stream: {e}"); } } + if let Some((barrier, reached_finish)) = finish_barrier.as_deref() { + if log { + println!("Partition {partition} waiting on finish barrier"); + } + reached_finish.fetch_add(1, Ordering::Relaxed); + barrier.wait().await; + } Ok(()) }); @@ -410,10 +483,6 @@ impl ExecutionPlan for BarrierExec { Ok(builder.build()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { if partition.is_some() { return Ok(Statistics::new_unknown(&self.schema)); @@ -429,7 +498,7 @@ impl ExecutionPlan for BarrierExec { /// A mock execution plan that errors on a call to execute #[derive(Debug)] pub struct ErrorExec { - cache: PlanProperties, + cache: Arc, } impl Default for ErrorExec { @@ -446,7 +515,9 @@ impl ErrorExec { true, )])); let cache = 
Self::compute_properties(schema); - Self { cache } + Self { + cache: Arc::new(cache), + } } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. @@ -487,7 +558,7 @@ impl ExecutionPlan for ErrorExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -517,7 +588,7 @@ impl ExecutionPlan for ErrorExec { pub struct StatisticsExec { stats: Statistics, schema: Arc, - cache: PlanProperties, + cache: Arc, } impl StatisticsExec { pub fn new(stats: Statistics, schema: Schema) -> Self { @@ -530,7 +601,7 @@ impl StatisticsExec { Self { stats, schema: Arc::new(schema), - cache, + cache: Arc::new(cache), } } @@ -577,7 +648,7 @@ impl ExecutionPlan for StatisticsExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -600,10 +671,6 @@ impl ExecutionPlan for StatisticsExec { unimplemented!("This plan only serves for testing statistics") } - fn statistics(&self) -> Result { - Ok(self.stats.clone()) - } - fn partition_statistics(&self, partition: Option) -> Result { Ok(if partition.is_some() { Statistics::new_unknown(&self.schema) @@ -623,7 +690,7 @@ pub struct BlockingExec { /// Ref-counting helper to check if the plan and the produced stream are still in memory. refs: Arc<()>, - cache: PlanProperties, + cache: Arc, } impl BlockingExec { @@ -633,7 +700,7 @@ impl BlockingExec { Self { schema, refs: Default::default(), - cache, + cache: Arc::new(cache), } } @@ -684,7 +751,7 @@ impl ExecutionPlan for BlockingExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -766,7 +833,7 @@ pub struct PanicExec { /// Number of output partitions. 
Each partition will produce this /// many empty output record batches prior to panicking batches_until_panics: Vec, - cache: PlanProperties, + cache: Arc, } impl PanicExec { @@ -778,7 +845,7 @@ impl PanicExec { Self { schema, batches_until_panics, - cache, + cache: Arc::new(cache), } } @@ -830,7 +897,7 @@ impl ExecutionPlan for PanicExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index ebac497f4fbc..e0b91f25161c 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -131,6 +131,9 @@ pub struct TopK { pub(crate) finished: bool, } +/// For more background, please also see the [Dynamic Filters: Passing Information Between Operators During Execution for 25x Faster Queries blog] +/// +/// [Dynamic Filters: Passing Information Between Operators During Execution for 25x Faster Queries blog]: https://datafusion.apache.org/blog/2025/09/10/dynamic-filters #[derive(Debug, Clone)] pub struct TopKDynamicFilters { /// The current *global* threshold for the dynamic filter. 
@@ -721,8 +724,8 @@ impl TopKHeap { let row = row.as_ref(); // Reuse storage for evicted item if possible - let new_top_k = if self.inner.len() == self.k { - let prev_min = self.inner.pop().unwrap(); + if self.inner.len() == self.k { + let mut prev_min = self.inner.peek_mut().unwrap(); // Update batch use if prev_min.batch_id == batch_entry.id { @@ -733,15 +736,16 @@ impl TopKHeap { // update memory accounting self.owned_bytes -= prev_min.owned_size(); - prev_min.with_new_row(row, batch_id, index) - } else { - TopKRow::new(row, batch_id, index) - }; - self.owned_bytes += new_top_k.owned_size(); + prev_min.replace_with(row, batch_id, index); - // put the new row into the heap - self.inner.push(new_top_k) + self.owned_bytes += prev_min.owned_size(); + } else { + let new_row = TopKRow::new(row, batch_id, index); + self.owned_bytes += new_row.owned_size(); + // put the new row into the heap + self.inner.push(new_row); + }; } /// Returns the values stored in this heap, from values low to @@ -908,26 +912,13 @@ impl TopKRow { } } - /// Create a new TopKRow reusing the existing allocation - fn with_new_row( - self, - new_row: impl AsRef<[u8]>, - batch_id: u32, - index: usize, - ) -> Self { - let Self { - mut row, - batch_id: _, - index: _, - } = self; - row.clear(); - row.extend_from_slice(new_row.as_ref()); + // Replace the existing row capacity with new values + fn replace_with(&mut self, new_row: impl AsRef<[u8]>, batch_id: u32, index: usize) { + self.row.clear(); + self.row.extend_from_slice(new_row.as_ref()); - Self { - row, - batch_id, - index, - } + self.batch_id = batch_id; + self.index = index; } /// Returns the number of bytes owned by this row in the heap (not diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index d27c81b96849..9fc02e730d02 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -32,11 +32,16 @@ use super::{ SendableRecordBatchStream, Statistics, 
metrics::{ExecutionPlanMetricsSet, MetricsSet}, }; +use crate::check_if_same_properties; use crate::execution_plan::{ InvariantLevel, boundedness_from_children, check_default_invariants, emission_type_from_children, }; -use crate::filter_pushdown::{FilterDescription, FilterPushdownPhase}; +use crate::filter::FilterExec; +use crate::filter_pushdown::{ + ChildPushdownResult, FilterDescription, FilterPushdownPhase, + FilterPushdownPropagation, PushedDown, +}; use crate::metrics::BaselineMetrics; use crate::projection::{ProjectionExec, make_with_child}; use crate::stream::ObservedStream; @@ -49,7 +54,9 @@ use datafusion_common::{ Result, assert_or_internal_err, exec_err, internal_datafusion_err, }; use datafusion_execution::TaskContext; -use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr, calculate_union}; +use datafusion_physical_expr::{ + EquivalenceProperties, PhysicalExpr, calculate_union, conjunction, +}; use futures::Stream; use itertools::Itertools; @@ -100,7 +107,7 @@ pub struct UnionExec { /// Execution metrics metrics: ExecutionPlanMetricsSet, /// Cache holding plan properties like equivalences, output partitioning etc. 
- cache: PlanProperties, + cache: Arc, } impl UnionExec { @@ -118,7 +125,7 @@ impl UnionExec { UnionExec { inputs, metrics: ExecutionPlanMetricsSet::new(), - cache, + cache: Arc::new(cache), } } @@ -147,7 +154,7 @@ impl UnionExec { Ok(Arc::new(UnionExec { inputs, metrics: ExecutionPlanMetricsSet::new(), - cache, + cache: Arc::new(cache), })) } } @@ -183,6 +190,17 @@ impl UnionExec { boundedness_from_children(inputs), )) } + + fn with_new_children_and_same_properties( + &self, + children: Vec>, + ) -> Self { + Self { + inputs: children, + metrics: ExecutionPlanMetricsSet::new(), + ..Self::clone(self) + } + } } impl DisplayAs for UnionExec { @@ -210,7 +228,7 @@ impl ExecutionPlan for UnionExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -259,6 +277,7 @@ impl ExecutionPlan for UnionExec { self: Arc, children: Vec>, ) -> Result> { + check_if_same_properties!(self, children); UnionExec::try_new(children) } @@ -304,10 +323,6 @@ impl ExecutionPlan for UnionExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { if let Some(partition_idx) = partition { // For a specific partition, find which input it belongs to @@ -370,6 +385,83 @@ impl ExecutionPlan for UnionExec { ) -> Result { FilterDescription::from_children(parent_filters, &self.children()) } + + fn handle_child_pushdown_result( + &self, + phase: FilterPushdownPhase, + child_pushdown_result: ChildPushdownResult, + _config: &ConfigOptions, + ) -> Result>> { + // Pre phase: handle heterogeneous pushdown by wrapping individual + // children with FilterExec and reporting all filters as handled. + // Post phase: use default behavior to let the filter creator decide how to handle + // filters that weren't fully pushed down. 
+ if phase != FilterPushdownPhase::Pre { + return Ok(FilterPushdownPropagation::if_all(child_pushdown_result)); + } + + // UnionExec needs specialized filter pushdown handling when children have + // heterogeneous pushdown support. Without this, when some children support + // pushdown and others don't, the default behavior would leave FilterExec + // above UnionExec, re-applying filters to outputs of all children—including + // those that already applied the filters via pushdown. This specialized + // implementation adds FilterExec only to children that don't support + // pushdown, avoiding redundant filtering and improving performance. + // + // Example: Given Child1 (no pushdown support) and Child2 (has pushdown support) + // Default behavior: This implementation: + // FilterExec UnionExec + // UnionExec FilterExec + // Child1 Child1 + // Child2(filter) Child2(filter) + + // Collect unsupported filters for each child + let mut unsupported_filters_per_child = vec![Vec::new(); self.inputs.len()]; + for parent_filter_result in child_pushdown_result.parent_filters.iter() { + for (child_idx, &child_result) in + parent_filter_result.child_results.iter().enumerate() + { + if matches!(child_result, PushedDown::No) { + unsupported_filters_per_child[child_idx] + .push(Arc::clone(&parent_filter_result.filter)); + } + } + } + + // Wrap children that have unsupported filters with FilterExec + let mut new_children = self.inputs.clone(); + for (child_idx, unsupported_filters) in + unsupported_filters_per_child.iter().enumerate() + { + if !unsupported_filters.is_empty() { + let combined_filter = conjunction(unsupported_filters.clone()); + new_children[child_idx] = Arc::new(FilterExec::try_new( + combined_filter, + Arc::clone(&self.inputs[child_idx]), + )?); + } + } + + // Check if any children were modified + let children_modified = new_children + .iter() + .zip(self.inputs.iter()) + .any(|(new, old)| !Arc::ptr_eq(new, old)); + + let all_filters_pushed = + vec![PushedDown::Yes; 
child_pushdown_result.parent_filters.len()]; + let propagation = if children_modified { + let updated_node = UnionExec::try_new(new_children)?; + FilterPushdownPropagation::with_parent_pushdown_result(all_filters_pushed) + .with_updated_node(updated_node) + } else { + FilterPushdownPropagation::with_parent_pushdown_result(all_filters_pushed) + }; + + // Report all parent filters as supported since we've ensured they're applied + // on all children (either pushed down or via FilterExec) + Ok(propagation) + } } /// Combines multiple input streams by interleaving them. @@ -411,7 +503,7 @@ pub struct InterleaveExec { /// Execution metrics metrics: ExecutionPlanMetricsSet, /// Cache holding plan properties like equivalences, output partitioning etc. - cache: PlanProperties, + cache: Arc, } impl InterleaveExec { @@ -425,7 +517,7 @@ impl InterleaveExec { Ok(InterleaveExec { inputs, metrics: ExecutionPlanMetricsSet::new(), - cache, + cache: Arc::new(cache), }) } @@ -447,6 +539,17 @@ impl InterleaveExec { boundedness_from_children(inputs), )) } + + fn with_new_children_and_same_properties( + &self, + children: Vec>, + ) -> Self { + Self { + inputs: children, + metrics: ExecutionPlanMetricsSet::new(), + ..Self::clone(self) + } + } } impl DisplayAs for InterleaveExec { @@ -474,7 +577,7 @@ impl ExecutionPlan for InterleaveExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -495,6 +598,7 @@ impl ExecutionPlan for InterleaveExec { can_interleave(children.iter()), "Can not create InterleaveExec: new children can not be interleaved" ); + check_if_same_properties!(self, children); Ok(Arc::new(InterleaveExec::try_new(children)?)) } @@ -545,10 +649,6 @@ impl ExecutionPlan for InterleaveExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { let stats = self .inputs @@ -593,8 +693,20 @@ fn 
union_schema(inputs: &[Arc]) -> Result { } let first_schema = inputs[0].schema(); + let first_field_count = first_schema.fields().len(); + + // validate that all inputs have the same number of fields + for (idx, input) in inputs.iter().enumerate().skip(1) { + let field_count = input.schema().fields().len(); + if field_count != first_field_count { + return exec_err!( + "UnionExec/InterleaveExec requires all inputs to have the same number of fields. \ + Input 0 has {first_field_count} fields, but input {idx} has {field_count} fields" + ); + } + } - let fields = (0..first_schema.fields().len()) + let fields = (0..first_field_count) .map(|i| { // We take the name from the left side of the union to match how names are coerced during logical planning, // which also uses the left side names. @@ -763,6 +875,18 @@ mod tests { Ok(schema) } + fn create_test_schema2() -> Result { + let a = Field::new("a", DataType::Int32, true); + let b = Field::new("b", DataType::Int32, true); + let c = Field::new("c", DataType::Int32, true); + let d = Field::new("d", DataType::Int32, true); + let e = Field::new("e", DataType::Int32, true); + let f = Field::new("f", DataType::Int32, true); + let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f])); + + Ok(schema) + } + #[tokio::test] async fn test_union_partitions() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); @@ -1052,4 +1176,23 @@ mod tests { Ok(()) } + + #[test] + fn test_union_schema_mismatch() { + // Test that UnionExec properly rejects inputs with different field counts + let schema = create_test_schema().unwrap(); + let schema2 = create_test_schema2().unwrap(); + let memory_exec1 = + Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema), None).unwrap()); + let memory_exec2 = + Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema2), None).unwrap()); + + let result = UnionExec::try_new(vec![memory_exec1, memory_exec2]); + assert!(result.is_err()); + assert!( + result.unwrap_err().to_string().contains( + 
"UnionExec/InterleaveExec requires all inputs to have the same number of fields" + ) + ); + } } diff --git a/datafusion/physical-plan/src/unnest.rs b/datafusion/physical-plan/src/unnest.rs index 5fef754e8078..422a9dd0d32b 100644 --- a/datafusion/physical-plan/src/unnest.rs +++ b/datafusion/physical-plan/src/unnest.rs @@ -28,7 +28,7 @@ use super::metrics::{ use super::{DisplayAs, ExecutionPlanProperties, PlanProperties}; use crate::{ DisplayFormatType, Distribution, ExecutionPlan, RecordBatchStream, - SendableRecordBatchStream, + SendableRecordBatchStream, check_if_same_properties, }; use arrow::array::{ @@ -74,7 +74,7 @@ pub struct UnnestExec { /// Execution metrics metrics: ExecutionPlanMetricsSet, /// Cache holding plan properties like equivalences, output partitioning etc. - cache: PlanProperties, + cache: Arc, } impl UnnestExec { @@ -100,7 +100,7 @@ impl UnnestExec { struct_column_indices, options, metrics: Default::default(), - cache, + cache: Arc::new(cache), }) } @@ -193,6 +193,17 @@ impl UnnestExec { pub fn options(&self) -> &UnnestOptions { &self.options } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + Self { + input: children.swap_remove(0), + metrics: ExecutionPlanMetricsSet::new(), + ..Self::clone(self) + } + } } impl DisplayAs for UnnestExec { @@ -221,7 +232,7 @@ impl ExecutionPlan for UnnestExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -231,10 +242,11 @@ impl ExecutionPlan for UnnestExec { fn with_new_children( self: Arc, - children: Vec>, + mut children: Vec>, ) -> Result> { + check_if_same_properties!(self, children); Ok(Arc::new(UnnestExec::new( - Arc::clone(&children[0]), + children.swap_remove(0), self.list_column_indices.clone(), self.struct_column_indices.clone(), Arc::clone(&self.schema), diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index 
987a400ec369..a31268b9c685 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -36,7 +36,7 @@ use crate::windows::{ use crate::{ ColumnStatistics, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, InputOrderMode, PlanProperties, RecordBatchStream, - SendableRecordBatchStream, Statistics, WindowExpr, + SendableRecordBatchStream, Statistics, WindowExpr, check_if_same_properties, }; use arrow::compute::take_record_batch; @@ -93,7 +93,7 @@ pub struct BoundedWindowAggExec { // See `get_ordered_partition_by_indices` for more details. ordered_partition_by_indices: Vec, /// Cache holding plan properties like equivalences, output partitioning etc. - cache: PlanProperties, + cache: Arc, /// If `can_rerepartition` is false, partition_keys is always empty. can_repartition: bool, } @@ -134,7 +134,7 @@ impl BoundedWindowAggExec { metrics: ExecutionPlanMetricsSet::new(), input_order_mode, ordered_partition_by_indices, - cache, + cache: Arc::new(cache), can_repartition, }) } @@ -248,6 +248,17 @@ impl BoundedWindowAggExec { total_byte_size: Precision::Absent, }) } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + Self { + input: children.swap_remove(0), + metrics: ExecutionPlanMetricsSet::new(), + ..Self::clone(self) + } + } } impl DisplayAs for BoundedWindowAggExec { @@ -304,7 +315,7 @@ impl ExecutionPlan for BoundedWindowAggExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -339,6 +350,7 @@ impl ExecutionPlan for BoundedWindowAggExec { self: Arc, children: Vec>, ) -> Result> { + check_if_same_properties!(self, children); Ok(Arc::new(BoundedWindowAggExec::try_new( self.window_expr.clone(), Arc::clone(&children[0]), @@ -368,10 +380,6 @@ impl ExecutionPlan for BoundedWindowAggExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - 
self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { let input_stat = self.input.partition_statistics(partition)?; self.statistics_helper(input_stat) diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index d0e1eab09987..b72a65cf996b 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -226,6 +226,18 @@ impl WindowUDFExpr { pub fn fun(&self) -> &Arc { &self.fun } + + /// Returns all arguments passed to this window function. + /// + /// Unlike [`StandardWindowFunctionExpr::expressions`], which returns + /// only the expressions that need batch evaluation (and may filter out + /// literal offset/default args like those for `lead`/`lag`), this + /// method returns the complete, unfiltered argument list. This is + /// needed for serialization so that all arguments survive a + /// protobuf round-trip. + pub fn args(&self) -> &[Arc] { + &self.args + } } impl StandardWindowFunctionExpr for WindowUDFExpr { diff --git a/datafusion/physical-plan/src/windows/window_agg_exec.rs b/datafusion/physical-plan/src/windows/window_agg_exec.rs index aa99f4f49885..0a146d51d62d 100644 --- a/datafusion/physical-plan/src/windows/window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/window_agg_exec.rs @@ -32,7 +32,7 @@ use crate::windows::{ use crate::{ ColumnStatistics, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties, RecordBatchStream, - SendableRecordBatchStream, Statistics, WindowExpr, + SendableRecordBatchStream, Statistics, WindowExpr, check_if_same_properties, }; use arrow::array::ArrayRef; @@ -65,7 +65,7 @@ pub struct WindowAggExec { // see `get_ordered_partition_by_indices` for more details. ordered_partition_by_indices: Vec, /// Cache holding plan properties like equivalences, output partitioning etc. 
- cache: PlanProperties, + cache: Arc, /// If `can_partition` is false, partition_keys is always empty. can_repartition: bool, } @@ -89,7 +89,7 @@ impl WindowAggExec { schema, metrics: ExecutionPlanMetricsSet::new(), ordered_partition_by_indices, - cache, + cache: Arc::new(cache), can_repartition, }) } @@ -158,6 +158,17 @@ impl WindowAggExec { .unwrap_or_else(Vec::new) } } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + Self { + input: children.swap_remove(0), + metrics: ExecutionPlanMetricsSet::new(), + ..Self::clone(self) + } + } } impl DisplayAs for WindowAggExec { @@ -206,7 +217,7 @@ impl ExecutionPlan for WindowAggExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -242,11 +253,12 @@ impl ExecutionPlan for WindowAggExec { fn with_new_children( self: Arc, - children: Vec>, + mut children: Vec>, ) -> Result> { + check_if_same_properties!(self, children); Ok(Arc::new(WindowAggExec::try_new( self.window_expr.clone(), - Arc::clone(&children[0]), + children.swap_remove(0), true, )?)) } @@ -272,10 +284,6 @@ impl ExecutionPlan for WindowAggExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { let input_stat = self.input.partition_statistics(partition)?; let win_cols = self.window_expr.len(); diff --git a/datafusion/physical-plan/src/work_table.rs b/datafusion/physical-plan/src/work_table.rs index f1b9e3e88d12..4c7f77e0ff98 100644 --- a/datafusion/physical-plan/src/work_table.rs +++ b/datafusion/physical-plan/src/work_table.rs @@ -109,7 +109,7 @@ pub struct WorkTableExec { /// Execution metrics metrics: ExecutionPlanMetricsSet, /// Cache holding plan properties like equivalences, output partitioning etc. 
- cache: PlanProperties, + cache: Arc, } impl WorkTableExec { @@ -129,7 +129,7 @@ impl WorkTableExec { projection, work_table: Arc::new(WorkTable::new(name)), metrics: ExecutionPlanMetricsSet::new(), - cache, + cache: Arc::new(cache), }) } @@ -181,7 +181,7 @@ impl ExecutionPlan for WorkTableExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -231,10 +231,6 @@ impl ExecutionPlan for WorkTableExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - Ok(Statistics::new_unknown(&self.schema())) - } - fn partition_statistics(&self, _partition: Option) -> Result { Ok(Statistics::new_unknown(&self.schema())) } @@ -263,7 +259,7 @@ impl ExecutionPlan for WorkTableExec { projection: self.projection.clone(), metrics: ExecutionPlanMetricsSet::new(), work_table, - cache: self.cache.clone(), + cache: Arc::clone(&self.cache), })) } } @@ -283,7 +279,7 @@ mod tests { assert!(work_table.take().is_err()); let pool = Arc::new(UnboundedMemoryPool::default()) as _; - let mut reservation = MemoryConsumer::new("test_work_table").register(&pool); + let reservation = MemoryConsumer::new("test_work_table").register(&pool); // Update batch to work_table let array: ArrayRef = Arc::new((0..5).collect::()); diff --git a/datafusion/proto-common/gen/Cargo.toml b/datafusion/proto-common/gen/Cargo.toml index 2d2557811d0d..f0e60819d42a 100644 --- a/datafusion/proto-common/gen/Cargo.toml +++ b/datafusion/proto-common/gen/Cargo.toml @@ -37,5 +37,5 @@ workspace = true [dependencies] # Pin these dependencies so that the generated output is deterministic -pbjson-build = "=0.8.0" -prost-build = "=0.14.1" +pbjson-build = "=0.9.0" +prost-build = "=0.14.3" diff --git a/datafusion/proto-common/proto/datafusion_common.proto b/datafusion/proto-common/proto/datafusion_common.proto index 08bb25bd715b..62c6bbe85612 100644 --- a/datafusion/proto-common/proto/datafusion_common.proto +++ b/datafusion/proto-common/proto/datafusion_common.proto 
@@ -183,6 +183,11 @@ message Map { bool keys_sorted = 2; } +message RunEndEncoded { + Field run_ends_field = 1; + Field values_field = 2; +} + enum UnionMode{ sparse = 0; dense = 1; @@ -236,6 +241,12 @@ message ScalarDictionaryValue { ScalarValue value = 2; } +message ScalarRunEndEncodedValue { + Field run_ends_field = 1; + Field values_field = 2; + ScalarValue value = 3; +} + message IntervalDayTimeValue { int32 days = 1; int32 milliseconds = 2; @@ -321,6 +332,8 @@ message ScalarValue{ IntervalMonthDayNanoValue interval_month_day_nano = 31; ScalarFixedSizeBinary fixed_size_binary_value = 34; UnionValue union_value = 42; + + ScalarRunEndEncodedValue run_end_encoded_value = 45; } } @@ -389,6 +402,7 @@ message ArrowType{ Union UNION = 29; Dictionary DICTIONARY = 30; Map MAP = 33; + RunEndEncoded RUN_END_ENCODED = 42; } } @@ -469,6 +483,7 @@ message JsonOptions { CompressionTypeVariant compression = 1; // Compression type optional uint64 schema_infer_max_rec = 2; // Optional max records for schema inference optional uint32 compression_level = 3; // Optional compression level + optional bool newline_delimited = 4; // Whether to read as newline-delimited JSON (default true). When false, expects JSON array format [{},...] 
} message TableParquetOptions { diff --git a/datafusion/proto-common/src/from_proto/mod.rs b/datafusion/proto-common/src/from_proto/mod.rs index e8e71c388458..ca8a269958d7 100644 --- a/datafusion/proto-common/src/from_proto/mod.rs +++ b/datafusion/proto-common/src/from_proto/mod.rs @@ -28,7 +28,12 @@ use arrow::datatypes::{ DataType, Field, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, Schema, TimeUnit, UnionFields, UnionMode, i256, }; -use arrow::ipc::{reader::read_record_batch, root_as_message}; +use arrow::ipc::{ + convert::fb_to_schema, + reader::{read_dictionary, read_record_batch}, + root_as_message, + writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions}, +}; use datafusion_common::{ Column, ColumnStatistics, Constraint, Constraints, DFSchema, DFSchemaRef, @@ -304,13 +309,16 @@ impl TryFrom<&protobuf::arrow_type::ArrowTypeEnum> for DataType { }; let union_fields = parse_proto_fields_to_fields(&union.union_types)?; - // Default to index based type ids if not provided - let type_ids: Vec<_> = match union.type_ids.is_empty() { - true => (0..union_fields.len() as i8).collect(), - false => union.type_ids.iter().map(|i| *i as i8).collect(), + // Default to index based type ids if not explicitly provided + let union_fields = if union.type_ids.is_empty() { + UnionFields::from_fields(union_fields) + } else { + let type_ids = union.type_ids.iter().map(|i| *i as i8); + UnionFields::try_new(type_ids, union_fields).map_err(|e| { + DataFusionError::from(e).context("Deserializing Union DataType") + })? 
}; - - DataType::Union(UnionFields::new(type_ids, union_fields), union_mode) + DataType::Union(union_fields, union_mode) } arrow_type::ArrowTypeEnum::Dictionary(dict) => { let key_datatype = dict.as_ref().key.as_deref().required("key")?; @@ -323,6 +331,19 @@ impl TryFrom<&protobuf::arrow_type::ArrowTypeEnum> for DataType { let keys_sorted = map.keys_sorted; DataType::Map(Arc::new(field), keys_sorted) } + arrow_type::ArrowTypeEnum::RunEndEncoded(run_end_encoded) => { + let run_ends_field: Field = run_end_encoded + .as_ref() + .run_ends_field + .as_deref() + .required("run_ends_field")?; + let value_field: Field = run_end_encoded + .as_ref() + .values_field + .as_deref() + .required("values_field")?; + DataType::RunEndEncoded(run_ends_field.into(), value_field.into()) + } }) } } @@ -381,7 +402,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { Value::Float32Value(v) => Self::Float32(Some(*v)), Value::Float64Value(v) => Self::Float64(Some(*v)), Value::Date32Value(v) => Self::Date32(Some(*v)), - // ScalarValue::List is serialized using arrow IPC format + // Nested ScalarValue types are serialized using arrow IPC format Value::ListValue(v) | Value::FixedSizeListValue(v) | Value::LargeListValue(v) @@ -398,55 +419,83 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { schema_ref.try_into()? } else { return Err(Error::General( - "Invalid schema while deserializing ScalarValue::List" + "Invalid schema while deserializing nested ScalarValue" .to_string(), )); }; + // IPC dictionary batch IDs are assigned when encoding the schema, but our protobuf + // `Schema` doesn't preserve those IDs. Reconstruct them deterministically by + // round-tripping the schema through IPC. 
+ let schema: Schema = { + let ipc_gen = IpcDataGenerator {}; + let write_options = IpcWriteOptions::default(); + let mut dict_tracker = DictionaryTracker::new(false); + let encoded_schema = ipc_gen.schema_to_bytes_with_dictionary_tracker( + &schema, + &mut dict_tracker, + &write_options, + ); + let message = + root_as_message(encoded_schema.ipc_message.as_slice()).map_err( + |e| { + Error::General(format!( + "Error IPC schema message while deserializing nested ScalarValue: {e}" + )) + }, + )?; + let ipc_schema = message.header_as_schema().ok_or_else(|| { + Error::General( + "Unexpected message type deserializing nested ScalarValue schema" + .to_string(), + ) + })?; + fb_to_schema(ipc_schema) + }; + let message = root_as_message(ipc_message.as_slice()).map_err(|e| { Error::General(format!( - "Error IPC message while deserializing ScalarValue::List: {e}" + "Error IPC message while deserializing nested ScalarValue: {e}" )) })?; let buffer = Buffer::from(arrow_data.as_slice()); let ipc_batch = message.header_as_record_batch().ok_or_else(|| { Error::General( - "Unexpected message type deserializing ScalarValue::List" + "Unexpected message type deserializing nested ScalarValue" .to_string(), ) })?; - let dict_by_id: HashMap = dictionaries.iter().map(|protobuf::scalar_nested_value::Dictionary { ipc_message, arrow_data }| { + let mut dict_by_id: HashMap = HashMap::new(); + for protobuf::scalar_nested_value::Dictionary { + ipc_message, + arrow_data, + } in dictionaries + { let message = root_as_message(ipc_message.as_slice()).map_err(|e| { Error::General(format!( - "Error IPC message while deserializing ScalarValue::List dictionary message: {e}" + "Error IPC message while deserializing nested ScalarValue dictionary message: {e}" )) })?; let buffer = Buffer::from(arrow_data.as_slice()); let dict_batch = message.header_as_dictionary_batch().ok_or_else(|| { Error::General( - "Unexpected message type deserializing ScalarValue::List dictionary message" + "Unexpected message 
type deserializing nested ScalarValue dictionary message" .to_string(), ) })?; - - let id = dict_batch.id(); - - let record_batch = read_record_batch( + read_dictionary( &buffer, - dict_batch.data().unwrap(), - Arc::new(schema.clone()), - &Default::default(), - None, + dict_batch, + &schema, + &mut dict_by_id, &message.version(), - )?; - - let values: ArrayRef = Arc::clone(record_batch.column(0)); - - Ok((id, values)) - }).collect::>>()?; + ) + .map_err(|e| arrow_datafusion_err!(e)) + .map_err(|e| e.context("Decoding nested ScalarValue dictionary"))?; + } let record_batch = read_record_batch( &buffer, @@ -457,7 +506,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { &message.version(), ) .map_err(|e| arrow_datafusion_err!(e)) - .map_err(|e| e.context("Decoding ScalarValue::List Value"))?; + .map_err(|e| e.context("Decoding nested ScalarValue value"))?; let arr = record_batch.column(0); match value { Value::ListValue(_) => { @@ -575,6 +624,32 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { Self::Dictionary(Box::new(index_type), Box::new(value)) } + Value::RunEndEncodedValue(v) => { + let run_ends_field: Field = v + .run_ends_field + .as_ref() + .ok_or_else(|| Error::required("run_ends_field"))? + .try_into()?; + + let values_field: Field = v + .values_field + .as_ref() + .ok_or_else(|| Error::required("values_field"))? + .try_into()?; + + let value: Self = v + .value + .as_ref() + .ok_or_else(|| Error::required("value"))? 
+ .as_ref() + .try_into()?; + + Self::RunEndEncoded( + run_ends_field.into(), + values_field.into(), + Box::new(value), + ) + } Value::BinaryValue(v) => Self::Binary(Some(v.clone())), Value::BinaryViewValue(v) => Self::BinaryView(Some(v.clone())), Value::LargeBinaryValue(v) => Self::LargeBinary(Some(v.clone())), @@ -602,7 +677,9 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { .collect::>>(); let fields = fields.ok_or_else(|| Error::required("UnionField"))?; let fields = parse_proto_fields_to_fields(&fields)?; - let fields = UnionFields::new(ids, fields); + let union_fields = UnionFields::try_new(ids, fields).map_err(|e| { + DataFusionError::from(e).context("Deserializing Union ScalarValue") + })?; let v_id = val.value_id as i8; let val = match &val.value { None => None, @@ -614,7 +691,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { Some((v_id, Box::new(val))) } }; - Self::Union(val, fields, mode) + Self::Union(val, union_fields, mode) } Value::FixedSizeBinaryValue(v) => { Self::FixedSizeBinary(v.length, Some(v.clone().values)) @@ -1100,6 +1177,7 @@ impl TryFrom<&protobuf::JsonOptions> for JsonOptions { compression: compression.into(), compression_level: proto_opts.compression_level, schema_infer_max_rec: proto_opts.schema_infer_max_rec.map(|h| h as usize), + newline_delimited: proto_opts.newline_delimited.unwrap_or(true), }) } } diff --git a/datafusion/proto-common/src/generated/mod.rs b/datafusion/proto-common/src/generated/mod.rs index 08cd75b622db..9c2ca9385aa5 100644 --- a/datafusion/proto-common/src/generated/mod.rs +++ b/datafusion/proto-common/src/generated/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
+// This code is generated so we don't want to fix any lint violations manually #[allow(clippy::allow_attributes)] #[allow(clippy::all)] #[rustfmt::skip] diff --git a/datafusion/proto-common/src/generated/pbjson.rs b/datafusion/proto-common/src/generated/pbjson.rs index d38cf86825d4..b00e7546bba2 100644 --- a/datafusion/proto-common/src/generated/pbjson.rs +++ b/datafusion/proto-common/src/generated/pbjson.rs @@ -29,7 +29,7 @@ impl<'de> serde::Deserialize<'de> for ArrowFormat { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -100,7 +100,7 @@ impl<'de> serde::Deserialize<'de> for ArrowOptions { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -276,6 +276,9 @@ impl serde::Serialize for ArrowType { arrow_type::ArrowTypeEnum::Map(v) => { struct_ser.serialize_field("MAP", v)?; } + arrow_type::ArrowTypeEnum::RunEndEncoded(v) => { + struct_ser.serialize_field("RUNENDENCODED", v)?; + } } } struct_ser.end() @@ -333,6 +336,8 @@ impl<'de> serde::Deserialize<'de> for ArrowType { "UNION", "DICTIONARY", "MAP", + "RUN_END_ENCODED", + "RUNENDENCODED", ]; #[allow(clippy::enum_variant_names)] @@ -375,6 +380,7 @@ impl<'de> serde::Deserialize<'de> for ArrowType { Union, Dictionary, Map, + RunEndEncoded, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -383,7 +389,7 @@ impl<'de> serde::Deserialize<'de> for ArrowType { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut 
std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -434,6 +440,7 @@ impl<'de> serde::Deserialize<'de> for ArrowType { "UNION" => Ok(GeneratedField::Union), "DICTIONARY" => Ok(GeneratedField::Dictionary), "MAP" => Ok(GeneratedField::Map), + "RUNENDENCODED" | "RUN_END_ENCODED" => Ok(GeneratedField::RunEndEncoded), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -715,6 +722,13 @@ impl<'de> serde::Deserialize<'de> for ArrowType { return Err(serde::de::Error::duplicate_field("MAP")); } arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Map) +; + } + GeneratedField::RunEndEncoded => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("RUNENDENCODED")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::RunEndEncoded) ; } } @@ -758,7 +772,7 @@ impl<'de> serde::Deserialize<'de> for AvroFormat { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -829,7 +843,7 @@ impl<'de> serde::Deserialize<'de> for AvroOptions { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -916,7 +930,7 @@ impl<'de> serde::Deserialize<'de> for Column { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -1016,7 +1030,7 @@ impl<'de> serde::Deserialize<'de> for ColumnRelation { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor 
{ + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -1153,7 +1167,7 @@ impl<'de> serde::Deserialize<'de> for ColumnStats { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -1282,7 +1296,7 @@ impl<'de> serde::Deserialize<'de> for CompressionTypeVariant { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = CompressionTypeVariant; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -1379,7 +1393,7 @@ impl<'de> serde::Deserialize<'de> for Constraint { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -1479,7 +1493,7 @@ impl<'de> serde::Deserialize<'de> for Constraints { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -1570,7 +1584,7 @@ impl<'de> serde::Deserialize<'de> for CsvFormat { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -1840,7 +1854,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for 
GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -2204,7 +2218,7 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -2407,7 +2421,7 @@ impl<'de> serde::Deserialize<'de> for Decimal128 { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -2530,7 +2544,7 @@ impl<'de> serde::Deserialize<'de> for Decimal128Type { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -2656,7 +2670,7 @@ impl<'de> serde::Deserialize<'de> for Decimal256 { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -2779,7 +2793,7 @@ impl<'de> serde::Deserialize<'de> for Decimal256Type { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -2905,7 +2919,7 @@ impl<'de> serde::Deserialize<'de> for Decimal32 { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = 
GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -3028,7 +3042,7 @@ impl<'de> serde::Deserialize<'de> for Decimal32Type { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -3154,7 +3168,7 @@ impl<'de> serde::Deserialize<'de> for Decimal64 { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -3277,7 +3291,7 @@ impl<'de> serde::Deserialize<'de> for Decimal64Type { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -3389,7 +3403,7 @@ impl<'de> serde::Deserialize<'de> for DfField { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -3497,7 +3511,7 @@ impl<'de> serde::Deserialize<'de> for DfSchema { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -3607,7 +3621,7 @@ impl<'de> serde::Deserialize<'de> for Dictionary { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut 
std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -3699,7 +3713,7 @@ impl<'de> serde::Deserialize<'de> for EmptyMessage { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -3811,7 +3825,7 @@ impl<'de> serde::Deserialize<'de> for Field { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -3950,7 +3964,7 @@ impl<'de> serde::Deserialize<'de> for FixedSizeList { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4060,7 +4074,7 @@ impl<'de> serde::Deserialize<'de> for IntervalDayTimeValue { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4182,7 +4196,7 @@ impl<'de> serde::Deserialize<'de> for IntervalMonthDayNanoValue { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4286,7 +4300,7 @@ impl<'de> serde::Deserialize<'de> for IntervalUnit { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = IntervalUnit; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> 
std::fmt::Result { @@ -4358,7 +4372,7 @@ impl<'de> serde::Deserialize<'de> for JoinConstraint { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = JoinConstraint; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4431,7 +4445,7 @@ impl<'de> serde::Deserialize<'de> for JoinSide { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = JoinSide; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4519,7 +4533,7 @@ impl<'de> serde::Deserialize<'de> for JoinType { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = JoinType; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4589,6 +4603,9 @@ impl serde::Serialize for JsonOptions { if self.compression_level.is_some() { len += 1; } + if self.newline_delimited.is_some() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion_common.JsonOptions", len)?; if self.compression != 0 { let v = CompressionTypeVariant::try_from(self.compression) @@ -4603,6 +4620,9 @@ impl serde::Serialize for JsonOptions { if let Some(v) = self.compression_level.as_ref() { struct_ser.serialize_field("compressionLevel", v)?; } + if let Some(v) = self.newline_delimited.as_ref() { + struct_ser.serialize_field("newlineDelimited", v)?; + } struct_ser.end() } } @@ -4618,6 +4638,8 @@ impl<'de> serde::Deserialize<'de> for JsonOptions { "schemaInferMaxRec", "compression_level", "compressionLevel", + "newline_delimited", + "newlineDelimited", ]; #[allow(clippy::enum_variant_names)] @@ -4625,6 +4647,7 @@ impl<'de> serde::Deserialize<'de> for JsonOptions { Compression, SchemaInferMaxRec, CompressionLevel, + NewlineDelimited, } impl<'de> 
serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -4633,7 +4656,7 @@ impl<'de> serde::Deserialize<'de> for JsonOptions { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4649,6 +4672,7 @@ impl<'de> serde::Deserialize<'de> for JsonOptions { "compression" => Ok(GeneratedField::Compression), "schemaInferMaxRec" | "schema_infer_max_rec" => Ok(GeneratedField::SchemaInferMaxRec), "compressionLevel" | "compression_level" => Ok(GeneratedField::CompressionLevel), + "newlineDelimited" | "newline_delimited" => Ok(GeneratedField::NewlineDelimited), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -4671,6 +4695,7 @@ impl<'de> serde::Deserialize<'de> for JsonOptions { let mut compression__ = None; let mut schema_infer_max_rec__ = None; let mut compression_level__ = None; + let mut newline_delimited__ = None; while let Some(k) = map_.next_key()? 
{ match k { GeneratedField::Compression => { @@ -4695,12 +4720,19 @@ impl<'de> serde::Deserialize<'de> for JsonOptions { map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| x.0) ; } + GeneratedField::NewlineDelimited => { + if newline_delimited__.is_some() { + return Err(serde::de::Error::duplicate_field("newlineDelimited")); + } + newline_delimited__ = map_.next_value()?; + } } } Ok(JsonOptions { compression: compression__.unwrap_or_default(), schema_infer_max_rec: schema_infer_max_rec__, compression_level: compression_level__, + newline_delimited: newline_delimited__, }) } } @@ -4748,7 +4780,7 @@ impl<'de> serde::Deserialize<'de> for JsonWriterOptions { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4840,7 +4872,7 @@ impl<'de> serde::Deserialize<'de> for List { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4941,7 +4973,7 @@ impl<'de> serde::Deserialize<'de> for Map { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -5041,7 +5073,7 @@ impl<'de> serde::Deserialize<'de> for NdJsonFormat { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -5119,7 +5151,7 @@ impl<'de> serde::Deserialize<'de> for NullEquality { struct 
GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = NullEquality; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -5286,7 +5318,7 @@ impl<'de> serde::Deserialize<'de> for ParquetColumnOptions { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -5440,7 +5472,7 @@ impl<'de> serde::Deserialize<'de> for ParquetColumnSpecificOptions { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -5540,7 +5572,7 @@ impl<'de> serde::Deserialize<'de> for ParquetFormat { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -5976,7 +6008,7 @@ impl<'de> serde::Deserialize<'de> for ParquetOptions { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -6371,7 +6403,7 @@ impl<'de> serde::Deserialize<'de> for Precision { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -6460,7 +6492,7 @@ impl<'de> serde::Deserialize<'de> for PrecisionInfo { struct GeneratedVisitor; - 
impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = PrecisionInfo; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -6545,7 +6577,7 @@ impl<'de> serde::Deserialize<'de> for PrimaryKeyConstraint { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -6600,6 +6632,116 @@ impl<'de> serde::Deserialize<'de> for PrimaryKeyConstraint { deserializer.deserialize_struct("datafusion_common.PrimaryKeyConstraint", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for RunEndEncoded { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.run_ends_field.is_some() { + len += 1; + } + if self.values_field.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.RunEndEncoded", len)?; + if let Some(v) = self.run_ends_field.as_ref() { + struct_ser.serialize_field("runEndsField", v)?; + } + if let Some(v) = self.values_field.as_ref() { + struct_ser.serialize_field("valuesField", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for RunEndEncoded { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "run_ends_field", + "runEndsField", + "values_field", + "valuesField", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + RunEndsField, + ValuesField, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl serde::de::Visitor<'_> for 
GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "runEndsField" | "run_ends_field" => Ok(GeneratedField::RunEndsField), + "valuesField" | "values_field" => Ok(GeneratedField::ValuesField), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = RunEndEncoded; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.RunEndEncoded") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut run_ends_field__ = None; + let mut values_field__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::RunEndsField => { + if run_ends_field__.is_some() { + return Err(serde::de::Error::duplicate_field("runEndsField")); + } + run_ends_field__ = map_.next_value()?; + } + GeneratedField::ValuesField => { + if values_field__.is_some() { + return Err(serde::de::Error::duplicate_field("valuesField")); + } + values_field__ = map_.next_value()?; + } + } + } + Ok(RunEndEncoded { + run_ends_field: run_ends_field__, + values_field: values_field__, + }) + } + } + deserializer.deserialize_struct("datafusion_common.RunEndEncoded", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for ScalarDictionaryValue { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -6648,7 +6790,7 @@ impl<'de> serde::Deserialize<'de> for ScalarDictionaryValue { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -6758,7 +6900,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFixedSizeBinary { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -6892,7 +7034,7 @@ impl<'de> serde::Deserialize<'de> for ScalarNestedValue { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -7028,7 +7170,7 @@ impl<'de> serde::Deserialize<'de> for scalar_nested_value::Dictionary { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, 
formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -7093,6 +7235,133 @@ impl<'de> serde::Deserialize<'de> for scalar_nested_value::Dictionary { deserializer.deserialize_struct("datafusion_common.ScalarNestedValue.Dictionary", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for ScalarRunEndEncodedValue { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.run_ends_field.is_some() { + len += 1; + } + if self.values_field.is_some() { + len += 1; + } + if self.value.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.ScalarRunEndEncodedValue", len)?; + if let Some(v) = self.run_ends_field.as_ref() { + struct_ser.serialize_field("runEndsField", v)?; + } + if let Some(v) = self.values_field.as_ref() { + struct_ser.serialize_field("valuesField", v)?; + } + if let Some(v) = self.value.as_ref() { + struct_ser.serialize_field("value", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ScalarRunEndEncodedValue { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "run_ends_field", + "runEndsField", + "values_field", + "valuesField", + "value", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + RunEndsField, + ValuesField, + Value, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl serde::de::Visitor<'_> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + 
E: serde::de::Error, + { + match value { + "runEndsField" | "run_ends_field" => Ok(GeneratedField::RunEndsField), + "valuesField" | "values_field" => Ok(GeneratedField::ValuesField), + "value" => Ok(GeneratedField::Value), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ScalarRunEndEncodedValue; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.ScalarRunEndEncodedValue") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut run_ends_field__ = None; + let mut values_field__ = None; + let mut value__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::RunEndsField => { + if run_ends_field__.is_some() { + return Err(serde::de::Error::duplicate_field("runEndsField")); + } + run_ends_field__ = map_.next_value()?; + } + GeneratedField::ValuesField => { + if values_field__.is_some() { + return Err(serde::de::Error::duplicate_field("valuesField")); + } + values_field__ = map_.next_value()?; + } + GeneratedField::Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("value")); + } + value__ = map_.next_value()?; + } + } + } + Ok(ScalarRunEndEncodedValue { + run_ends_field: run_ends_field__, + values_field: values_field__, + value: value__, + }) + } + } + deserializer.deserialize_struct("datafusion_common.ScalarRunEndEncodedValue", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for ScalarTime32Value { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -7143,7 +7412,7 @@ impl<'de> serde::Deserialize<'de> for ScalarTime32Value { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for 
GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -7256,7 +7525,7 @@ impl<'de> serde::Deserialize<'de> for ScalarTime64Value { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -7393,7 +7662,7 @@ impl<'de> serde::Deserialize<'de> for ScalarTimestampValue { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -7635,6 +7904,9 @@ impl serde::Serialize for ScalarValue { scalar_value::Value::UnionValue(v) => { struct_ser.serialize_field("unionValue", v)?; } + scalar_value::Value::RunEndEncodedValue(v) => { + struct_ser.serialize_field("runEndEncodedValue", v)?; + } } } struct_ser.end() @@ -7731,6 +8003,8 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "fixedSizeBinaryValue", "union_value", "unionValue", + "run_end_encoded_value", + "runEndEncodedValue", ]; #[allow(clippy::enum_variant_names)] @@ -7777,6 +8051,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { IntervalMonthDayNano, FixedSizeBinaryValue, UnionValue, + RunEndEncodedValue, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -7785,7 +8060,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -7840,6 +8115,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "intervalMonthDayNano" | 
"interval_month_day_nano" => Ok(GeneratedField::IntervalMonthDayNano), "fixedSizeBinaryValue" | "fixed_size_binary_value" => Ok(GeneratedField::FixedSizeBinaryValue), "unionValue" | "union_value" => Ok(GeneratedField::UnionValue), + "runEndEncodedValue" | "run_end_encoded_value" => Ok(GeneratedField::RunEndEncodedValue), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -8130,6 +8406,13 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { return Err(serde::de::Error::duplicate_field("unionValue")); } value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::UnionValue) +; + } + GeneratedField::RunEndEncodedValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("runEndEncodedValue")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::RunEndEncodedValue) ; } } @@ -8189,7 +8472,7 @@ impl<'de> serde::Deserialize<'de> for Schema { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -8310,7 +8593,7 @@ impl<'de> serde::Deserialize<'de> for Statistics { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -8420,7 +8703,7 @@ impl<'de> serde::Deserialize<'de> for Struct { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -8529,7 +8812,7 @@ impl<'de> serde::Deserialize<'de> for TableParquetOptions { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for 
GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -8631,7 +8914,7 @@ impl<'de> serde::Deserialize<'de> for TimeUnit { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = TimeUnit; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -8728,7 +9011,7 @@ impl<'de> serde::Deserialize<'de> for Timestamp { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -8849,7 +9132,7 @@ impl<'de> serde::Deserialize<'de> for Union { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -8970,7 +9253,7 @@ impl<'de> serde::Deserialize<'de> for UnionField { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -9059,7 +9342,7 @@ impl<'de> serde::Deserialize<'de> for UnionMode { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = UnionMode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -9170,7 +9453,7 @@ impl<'de> serde::Deserialize<'de> for UnionValue { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type 
Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -9290,7 +9573,7 @@ impl<'de> serde::Deserialize<'de> for UniqueConstraint { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { diff --git a/datafusion/proto-common/src/generated/prost.rs b/datafusion/proto-common/src/generated/prost.rs index 16601dcf4697..a09826a29be5 100644 --- a/datafusion/proto-common/src/generated/prost.rs +++ b/datafusion/proto-common/src/generated/prost.rs @@ -176,6 +176,13 @@ pub struct Map { pub keys_sorted: bool, } #[derive(Clone, PartialEq, ::prost::Message)] +pub struct RunEndEncoded { + #[prost(message, optional, boxed, tag = "1")] + pub run_ends_field: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, boxed, tag = "2")] + pub values_field: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[derive(Clone, PartialEq, ::prost::Message)] pub struct Union { #[prost(message, repeated, tag = "1")] pub union_types: ::prost::alloc::vec::Vec, @@ -264,6 +271,15 @@ pub struct ScalarDictionaryValue { #[prost(message, optional, boxed, tag = "2")] pub value: ::core::option::Option<::prost::alloc::boxed::Box>, } +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ScalarRunEndEncodedValue { + #[prost(message, optional, tag = "1")] + pub run_ends_field: ::core::option::Option, + #[prost(message, optional, tag = "2")] + pub values_field: ::core::option::Option, + #[prost(message, optional, boxed, tag = "3")] + pub value: ::core::option::Option<::prost::alloc::boxed::Box>, +} #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct IntervalDayTimeValue { #[prost(int32, tag = "1")] @@ -311,7 +327,7 @@ pub struct ScalarFixedSizeBinary { pub struct ScalarValue { #[prost( oneof = 
"scalar_value::Value", - tags = "33, 1, 2, 3, 23, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 41, 43, 44, 20, 39, 21, 24, 35, 36, 37, 38, 26, 27, 28, 29, 22, 30, 25, 31, 34, 42" + tags = "33, 1, 2, 3, 23, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 41, 43, 44, 20, 39, 21, 24, 35, 36, 37, 38, 26, 27, 28, 29, 22, 30, 25, 31, 34, 42, 45" )] pub value: ::core::option::Option, } @@ -406,6 +422,8 @@ pub mod scalar_value { FixedSizeBinaryValue(super::ScalarFixedSizeBinary), #[prost(message, tag = "42")] UnionValue(::prost::alloc::boxed::Box), + #[prost(message, tag = "45")] + RunEndEncodedValue(::prost::alloc::boxed::Box), } } #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] @@ -449,7 +467,7 @@ pub struct Decimal256 { pub struct ArrowType { #[prost( oneof = "arrow_type::ArrowTypeEnum", - tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 40, 41, 24, 36, 25, 26, 27, 28, 29, 30, 33" + tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 40, 41, 24, 36, 25, 26, 27, 28, 29, 30, 33, 42" )] pub arrow_type_enum: ::core::option::Option, } @@ -538,6 +556,8 @@ pub mod arrow_type { Dictionary(::prost::alloc::boxed::Box), #[prost(message, tag = "33")] Map(::prost::alloc::boxed::Box), + #[prost(message, tag = "42")] + RunEndEncoded(::prost::alloc::boxed::Box), } } /// Useful for representing an empty enum variant in rust @@ -665,6 +685,9 @@ pub struct JsonOptions { /// Optional compression level #[prost(uint32, optional, tag = "3")] pub compression_level: ::core::option::Option, + /// Whether to read as newline-delimited JSON (default true). 
When false, expects JSON array format \[{},...\] + #[prost(bool, optional, tag = "4")] + pub newline_delimited: ::core::option::Option, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct TableParquetOptions { diff --git a/datafusion/proto-common/src/lib.rs b/datafusion/proto-common/src/lib.rs index b7e1c906d90f..6f7fb7b89c0c 100644 --- a/datafusion/proto-common/src/lib.rs +++ b/datafusion/proto-common/src/lib.rs @@ -24,7 +24,6 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] -#![deny(clippy::allow_attributes)] //! Serialize / Deserialize DataFusion Primitive Types to bytes //! diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs index fee365648200..79e3306a4df1 100644 --- a/datafusion/proto-common/src/to_proto/mod.rs +++ b/datafusion/proto-common/src/to_proto/mod.rs @@ -180,7 +180,9 @@ impl TryFrom<&DataType> for protobuf::arrow_type::ArrowTypeEnum { UnionMode::Dense => protobuf::UnionMode::Dense, }; Self::Union(protobuf::Union { - union_types: convert_arc_fields_to_proto_fields(fields.iter().map(|(_, item)|item))?, + union_types: convert_arc_fields_to_proto_fields( + fields.iter().map(|(_, item)| item), + )?, union_mode: union_mode.into(), type_ids: fields.iter().map(|(x, _)| x as i32).collect(), }) @@ -191,37 +193,44 @@ impl TryFrom<&DataType> for protobuf::arrow_type::ArrowTypeEnum { value: Some(Box::new(value_type.as_ref().try_into()?)), })) } - DataType::Decimal32(precision, scale) => Self::Decimal32(protobuf::Decimal32Type { - precision: *precision as u32, - scale: *scale as i32, - }), - DataType::Decimal64(precision, scale) => Self::Decimal64(protobuf::Decimal64Type { - precision: *precision as u32, - scale: *scale as i32, - }), - DataType::Decimal128(precision, scale) => Self::Decimal128(protobuf::Decimal128Type { - precision: *precision as u32, - scale: *scale as i32, - }), - 
DataType::Decimal256(precision, scale) => Self::Decimal256(protobuf::Decimal256Type { - precision: *precision as u32, - scale: *scale as i32, - }), - DataType::Map(field, sorted) => { - Self::Map(Box::new( - protobuf::Map { - field_type: Some(Box::new(field.as_ref().try_into()?)), - keys_sorted: *sorted, - } - )) - } - DataType::RunEndEncoded(_, _) => { - return Err(Error::General( - "Proto serialization error: The RunEndEncoded data type is not yet supported".to_owned() - )) + DataType::Decimal32(precision, scale) => { + Self::Decimal32(protobuf::Decimal32Type { + precision: *precision as u32, + scale: *scale as i32, + }) + } + DataType::Decimal64(precision, scale) => { + Self::Decimal64(protobuf::Decimal64Type { + precision: *precision as u32, + scale: *scale as i32, + }) + } + DataType::Decimal128(precision, scale) => { + Self::Decimal128(protobuf::Decimal128Type { + precision: *precision as u32, + scale: *scale as i32, + }) + } + DataType::Decimal256(precision, scale) => { + Self::Decimal256(protobuf::Decimal256Type { + precision: *precision as u32, + scale: *scale as i32, + }) + } + DataType::Map(field, sorted) => Self::Map(Box::new(protobuf::Map { + field_type: Some(Box::new(field.as_ref().try_into()?)), + keys_sorted: *sorted, + })), + DataType::RunEndEncoded(run_ends_field, values_field) => { + Self::RunEndEncoded(Box::new(protobuf::RunEndEncoded { + run_ends_field: Some(Box::new(run_ends_field.as_ref().try_into()?)), + values_field: Some(Box::new(values_field.as_ref().try_into()?)), + })) } DataType::ListView(_) | DataType::LargeListView(_) => { - return Err(Error::General(format!("Proto serialization error: {val} not yet supported"))) + return Err(Error::General(format!( + "Proto serialization error: {val} not yet supported" + ))); } }; @@ -680,6 +689,18 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { ))), }) } + + ScalarValue::RunEndEncoded(run_ends_field, values_field, val) => { + Ok(protobuf::ScalarValue { + value: 
Some(Value::RunEndEncodedValue(Box::new( + protobuf::ScalarRunEndEncodedValue { + run_ends_field: Some(run_ends_field.as_ref().try_into()?), + values_field: Some(values_field.as_ref().try_into()?), + value: Some(Box::new(val.as_ref().try_into()?)), + }, + ))), + }) + } } } } @@ -990,6 +1011,7 @@ impl TryFrom<&JsonOptions> for protobuf::JsonOptions { compression: compression.into(), schema_infer_max_rec: opts.schema_infer_max_rec.map(|h| h as u64), compression_level: opts.compression_level, + newline_delimited: Some(opts.newline_delimited), }) } } @@ -1010,7 +1032,7 @@ fn create_proto_scalar protobuf::scalar_value::Value>( Ok(protobuf::ScalarValue { value: Some(value) }) } -// ScalarValue::List / FixedSizeList / LargeList / Struct / Map are serialized using +// Nested ScalarValue types (List / FixedSizeList / LargeList / Struct / Map) are serialized using // Arrow IPC messages as a single column RecordBatch fn encode_scalar_nested_value( arr: ArrayRef, @@ -1018,13 +1040,20 @@ fn encode_scalar_nested_value( ) -> Result { let batch = RecordBatch::try_from_iter(vec![("field_name", arr)]).map_err(|e| { Error::General(format!( - "Error creating temporary batch while encoding ScalarValue::List: {e}" + "Error creating temporary batch while encoding nested ScalarValue: {e}" )) })?; let ipc_gen = IpcDataGenerator {}; let mut dict_tracker = DictionaryTracker::new(false); let write_options = IpcWriteOptions::default(); + // The IPC writer requires pre-allocated dictionary IDs (normally assigned when + // serializing the schema). Populate `dict_tracker` by encoding the schema first. 
+ ipc_gen.schema_to_bytes_with_dictionary_tracker( + batch.schema().as_ref(), + &mut dict_tracker, + &write_options, + ); let mut compression_context = CompressionContext::default(); let (encoded_dictionaries, encoded_message) = ipc_gen .encode( @@ -1034,7 +1063,7 @@ fn encode_scalar_nested_value( &mut compression_context, ) .map_err(|e| { - Error::General(format!("Error encoding ScalarValue::List as IPC: {e}")) + Error::General(format!("Error encoding nested ScalarValue as IPC: {e}")) })?; let schema: protobuf::Schema = batch.schema().try_into()?; diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index b00bd0dcc6bf..3d17ed30d572 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -28,9 +28,6 @@ license = { workspace = true } authors = { workspace = true } rust-version = { workspace = true } -# Exclude proto files so crates.io consumers don't need protoc -exclude = ["*.proto"] - [package.metadata.docs.rs] all-features = true @@ -69,6 +66,7 @@ datafusion-proto-common = { workspace = true } object_store = { workspace = true } pbjson = { workspace = true, optional = true } prost = { workspace = true } +rand = { workspace = true } serde = { version = "1.0", optional = true } serde_json = { workspace = true, optional = true } diff --git a/datafusion/proto/gen/Cargo.toml b/datafusion/proto/gen/Cargo.toml index d446ab0d8974..8b48dfe70e6c 100644 --- a/datafusion/proto/gen/Cargo.toml +++ b/datafusion/proto/gen/Cargo.toml @@ -37,5 +37,5 @@ workspace = true [dependencies] # Pin these dependencies so that the generated output is deterministic -pbjson-build = "=0.8.0" -prost-build = "=0.14.1" +pbjson-build = "=0.9.0" +prost-build = "=0.14.3" diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index bd7dd3a6aff3..37b31a84deab 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -278,6 +278,7 @@ message DmlNode{ INSERT_APPEND = 3; 
INSERT_OVERWRITE = 4; INSERT_REPLACE = 5; + TRUNCATE = 6; } Type dml_type = 1; LogicalPlanNode input = 2; @@ -749,6 +750,8 @@ message PhysicalPlanNode { SortMergeJoinExecNode sort_merge_join = 34; MemoryScanExecNode memory_scan = 35; AsyncFuncExecNode async_func = 36; + BufferExecNode buffer = 37; + ArrowScanExecNode arrow_scan = 38; } } @@ -758,6 +761,16 @@ message PartitionColumn { } +// Determines how file sink output paths are interpreted. +enum FileOutputMode { + // Infer output mode from the URL (extension/trailing `/` heuristic). + FILE_OUTPUT_MODE_AUTOMATIC = 0; + // Write to a single file at the exact output path. + FILE_OUTPUT_MODE_SINGLE_FILE = 1; + // Write to a directory with generated filenames. + FILE_OUTPUT_MODE_DIRECTORY = 2; +} + message FileSinkConfig { reserved 6; // writer_mode reserved 8; // was `overwrite` which has been superseded by `insert_op` @@ -770,6 +783,8 @@ message FileSinkConfig { bool keep_partition_by_columns = 9; InsertOp insert_op = 10; string file_extension = 11; + // Determines how the output path is interpreted. + FileOutputMode file_output_mode = 12; } enum InsertOp { @@ -837,6 +852,14 @@ message PhysicalExprNode { // Was date_time_interval_expr reserved 17; + // Unique identifier for this expression to do deduplication during deserialization. + // When serializing, this is set to a unique identifier for each combination of + // expression, process and serialization run. + // When deserializing, if this ID has been seen before, the cached Arc is returned + // instead of creating a new one, enabling reconstruction of referential integrity + // across serde roundtrips. 
+ optional uint64 expr_id = 30; + oneof ExprType { // column references PhysicalColumn column = 1; @@ -1006,6 +1029,8 @@ message FilterExecNode { PhysicalExprNode expr = 2; uint32 default_filter_selectivity = 3; repeated uint32 projection = 9; + uint32 batch_size = 10; + optional uint32 fetch = 11; } message FileGroup { @@ -1083,6 +1108,10 @@ message AvroScanExecNode { FileScanExecConf base_conf = 1; } +message ArrowScanExecNode { + FileScanExecConf base_conf = 1; +} + message MemoryScanExecNode { repeated bytes partitions = 1; datafusion_common.Schema schema = 2; @@ -1111,6 +1140,7 @@ message HashJoinExecNode { datafusion_common.NullEquality null_equality = 7; JoinFilter filter = 8; repeated uint32 projection = 9; + bool null_aware = 10; } enum StreamPartitionMode { @@ -1190,6 +1220,7 @@ enum AggregateMode { FINAL_PARTITIONED = 2; SINGLE = 3; SINGLE_PARTITIONED = 4; + PARTIAL_REDUCE = 5; } message PartiallySortedInputOrderMode { @@ -1219,6 +1250,8 @@ message MaybePhysicalSortExprs { message AggLimit { // wrap into a message to make it optional uint64 limit = 1; + // Optional ordering direction for TopK aggregation (true = descending, false = ascending) + optional bool descending = 2; } message AggregateExecNode { @@ -1412,3 +1445,8 @@ message AsyncFuncExecNode { repeated PhysicalExprNode async_exprs = 2; repeated string async_expr_names = 3; } + +message BufferExecNode { + PhysicalPlanNode input = 1; + uint64 capacity = 2; +} \ No newline at end of file diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index d95bdd388699..84b15ea9a892 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -21,7 +21,8 @@ use crate::logical_plan::{ self, AsLogicalPlan, DefaultLogicalExtensionCodec, LogicalExtensionCodec, }; use crate::physical_plan::{ - AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, + DefaultPhysicalExtensionCodec, DefaultPhysicalProtoConverter, PhysicalExtensionCodec, + 
PhysicalProtoConverterExtension, }; use crate::protobuf; use datafusion_common::{Result, plan_datafusion_err}; @@ -276,16 +277,18 @@ pub fn logical_plan_from_json_with_extension_codec( /// Serialize a PhysicalPlan as bytes pub fn physical_plan_to_bytes(plan: Arc) -> Result { let extension_codec = DefaultPhysicalExtensionCodec {}; - physical_plan_to_bytes_with_extension_codec(plan, &extension_codec) + let proto_converter = DefaultPhysicalProtoConverter {}; + physical_plan_to_bytes_with_proto_converter(plan, &extension_codec, &proto_converter) } /// Serialize a PhysicalPlan as JSON #[cfg(feature = "json")] pub fn physical_plan_to_json(plan: Arc) -> Result { let extension_codec = DefaultPhysicalExtensionCodec {}; - let protobuf = - protobuf::PhysicalPlanNode::try_from_physical_plan(plan, &extension_codec) - .map_err(|e| plan_datafusion_err!("Error serializing plan: {e}"))?; + let proto_converter = DefaultPhysicalProtoConverter {}; + let protobuf = proto_converter + .execution_plan_to_proto(&plan, &extension_codec) + .map_err(|e| plan_datafusion_err!("Error serializing plan: {e}"))?; serde_json::to_string(&protobuf) .map_err(|e| plan_datafusion_err!("Error serializing plan: {e}")) } @@ -295,8 +298,18 @@ pub fn physical_plan_to_bytes_with_extension_codec( plan: Arc, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result { - let protobuf = - protobuf::PhysicalPlanNode::try_from_physical_plan(plan, extension_codec)?; + let proto_converter = DefaultPhysicalProtoConverter {}; + physical_plan_to_bytes_with_proto_converter(plan, extension_codec, &proto_converter) +} + +/// Serialize a PhysicalPlan as bytes, using the provided extension codec +/// and protobuf converter. 
+pub fn physical_plan_to_bytes_with_proto_converter( + plan: Arc, + extension_codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, +) -> Result { + let protobuf = proto_converter.execution_plan_to_proto(&plan, extension_codec)?; let mut buffer = BytesMut::new(); protobuf .encode(&mut buffer) @@ -313,7 +326,8 @@ pub fn physical_plan_from_json( let back: protobuf::PhysicalPlanNode = serde_json::from_str(json) .map_err(|e| plan_datafusion_err!("Error serializing plan: {e}"))?; let extension_codec = DefaultPhysicalExtensionCodec {}; - back.try_into_physical_plan(ctx, &extension_codec) + let proto_converter = DefaultPhysicalProtoConverter {}; + proto_converter.proto_to_execution_plan(ctx, &extension_codec, &back) } /// Deserialize a PhysicalPlan from bytes @@ -322,7 +336,13 @@ pub fn physical_plan_from_bytes( ctx: &TaskContext, ) -> Result> { let extension_codec = DefaultPhysicalExtensionCodec {}; - physical_plan_from_bytes_with_extension_codec(bytes, ctx, &extension_codec) + let proto_converter = DefaultPhysicalProtoConverter {}; + physical_plan_from_bytes_with_proto_converter( + bytes, + ctx, + &extension_codec, + &proto_converter, + ) } /// Deserialize a PhysicalPlan from bytes @@ -330,8 +350,24 @@ pub fn physical_plan_from_bytes_with_extension_codec( bytes: &[u8], ctx: &TaskContext, extension_codec: &dyn PhysicalExtensionCodec, +) -> Result> { + let proto_converter = DefaultPhysicalProtoConverter {}; + physical_plan_from_bytes_with_proto_converter( + bytes, + ctx, + extension_codec, + &proto_converter, + ) +} + +/// Deserialize a PhysicalPlan from bytes +pub fn physical_plan_from_bytes_with_proto_converter( + bytes: &[u8], + ctx: &TaskContext, + extension_codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let protobuf = protobuf::PhysicalPlanNode::decode(bytes) .map_err(|e| plan_datafusion_err!("Error decoding expr as protobuf: {e}"))?; - 
protobuf.try_into_physical_plan(ctx, extension_codec) + proto_converter.proto_to_execution_plan(ctx, extension_codec, &protobuf) } diff --git a/datafusion/proto/src/generated/datafusion_proto_common.rs b/datafusion/proto/src/generated/datafusion_proto_common.rs index 16601dcf4697..a09826a29be5 100644 --- a/datafusion/proto/src/generated/datafusion_proto_common.rs +++ b/datafusion/proto/src/generated/datafusion_proto_common.rs @@ -176,6 +176,13 @@ pub struct Map { pub keys_sorted: bool, } #[derive(Clone, PartialEq, ::prost::Message)] +pub struct RunEndEncoded { + #[prost(message, optional, boxed, tag = "1")] + pub run_ends_field: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, boxed, tag = "2")] + pub values_field: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[derive(Clone, PartialEq, ::prost::Message)] pub struct Union { #[prost(message, repeated, tag = "1")] pub union_types: ::prost::alloc::vec::Vec, @@ -264,6 +271,15 @@ pub struct ScalarDictionaryValue { #[prost(message, optional, boxed, tag = "2")] pub value: ::core::option::Option<::prost::alloc::boxed::Box>, } +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ScalarRunEndEncodedValue { + #[prost(message, optional, tag = "1")] + pub run_ends_field: ::core::option::Option, + #[prost(message, optional, tag = "2")] + pub values_field: ::core::option::Option, + #[prost(message, optional, boxed, tag = "3")] + pub value: ::core::option::Option<::prost::alloc::boxed::Box>, +} #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct IntervalDayTimeValue { #[prost(int32, tag = "1")] @@ -311,7 +327,7 @@ pub struct ScalarFixedSizeBinary { pub struct ScalarValue { #[prost( oneof = "scalar_value::Value", - tags = "33, 1, 2, 3, 23, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 41, 43, 44, 20, 39, 21, 24, 35, 36, 37, 38, 26, 27, 28, 29, 22, 30, 25, 31, 34, 42" + tags = "33, 1, 2, 3, 23, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 
41, 43, 44, 20, 39, 21, 24, 35, 36, 37, 38, 26, 27, 28, 29, 22, 30, 25, 31, 34, 42, 45" )] pub value: ::core::option::Option, } @@ -406,6 +422,8 @@ pub mod scalar_value { FixedSizeBinaryValue(super::ScalarFixedSizeBinary), #[prost(message, tag = "42")] UnionValue(::prost::alloc::boxed::Box), + #[prost(message, tag = "45")] + RunEndEncodedValue(::prost::alloc::boxed::Box), } } #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] @@ -449,7 +467,7 @@ pub struct Decimal256 { pub struct ArrowType { #[prost( oneof = "arrow_type::ArrowTypeEnum", - tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 40, 41, 24, 36, 25, 26, 27, 28, 29, 30, 33" + tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 40, 41, 24, 36, 25, 26, 27, 28, 29, 30, 33, 42" )] pub arrow_type_enum: ::core::option::Option, } @@ -538,6 +556,8 @@ pub mod arrow_type { Dictionary(::prost::alloc::boxed::Box), #[prost(message, tag = "33")] Map(::prost::alloc::boxed::Box), + #[prost(message, tag = "42")] + RunEndEncoded(::prost::alloc::boxed::Box), } } /// Useful for representing an empty enum variant in rust @@ -665,6 +685,9 @@ pub struct JsonOptions { /// Optional compression level #[prost(uint32, optional, tag = "3")] pub compression_level: ::core::option::Option, + /// Whether to read as newline-delimited JSON (default true). When false, expects JSON array format \[{},...\] + #[prost(bool, optional, tag = "4")] + pub newline_delimited: ::core::option::Option, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct TableParquetOptions { diff --git a/datafusion/proto/src/generated/mod.rs b/datafusion/proto/src/generated/mod.rs index adf5125457c1..ca32b1500d57 100644 --- a/datafusion/proto/src/generated/mod.rs +++ b/datafusion/proto/src/generated/mod.rs @@ -15,8 +15,8 @@ // specific language governing permissions and limitations // under the License. 
-#![allow(clippy::allow_attributes)] - +// This code is generated so we don't want to fix any lint violations manually +#[allow(clippy::allow_attributes)] #[allow(clippy::all)] #[rustfmt::skip] pub mod datafusion { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index e269606d163a..419105c40c79 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -9,12 +9,18 @@ impl serde::Serialize for AggLimit { if self.limit != 0 { len += 1; } + if self.descending.is_some() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.AggLimit", len)?; if self.limit != 0 { #[allow(clippy::needless_borrow)] #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("limit", ToString::to_string(&self.limit).as_str())?; } + if let Some(v) = self.descending.as_ref() { + struct_ser.serialize_field("descending", v)?; + } struct_ser.end() } } @@ -26,11 +32,13 @@ impl<'de> serde::Deserialize<'de> for AggLimit { { const FIELDS: &[&str] = &[ "limit", + "descending", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Limit, + Descending, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -39,7 +47,7 @@ impl<'de> serde::Deserialize<'de> for AggLimit { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -53,6 +61,7 @@ impl<'de> serde::Deserialize<'de> for AggLimit { { match value { "limit" => Ok(GeneratedField::Limit), + "descending" => Ok(GeneratedField::Descending), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -73,6 +82,7 @@ impl<'de> serde::Deserialize<'de> for AggLimit { V: serde::de::MapAccess<'de>, { let mut limit__ = None; + let mut descending__ = None; while 
let Some(k) = map_.next_key()? { match k { GeneratedField::Limit => { @@ -83,10 +93,17 @@ impl<'de> serde::Deserialize<'de> for AggLimit { Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } + GeneratedField::Descending => { + if descending__.is_some() { + return Err(serde::de::Error::duplicate_field("descending")); + } + descending__ = map_.next_value()?; + } } } Ok(AggLimit { limit: limit__.unwrap_or_default(), + descending: descending__, }) } } @@ -230,7 +247,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -393,6 +410,7 @@ impl serde::Serialize for AggregateMode { Self::FinalPartitioned => "FINAL_PARTITIONED", Self::Single => "SINGLE", Self::SinglePartitioned => "SINGLE_PARTITIONED", + Self::PartialReduce => "PARTIAL_REDUCE", }; serializer.serialize_str(variant) } @@ -409,11 +427,12 @@ impl<'de> serde::Deserialize<'de> for AggregateMode { "FINAL_PARTITIONED", "SINGLE", "SINGLE_PARTITIONED", + "PARTIAL_REDUCE", ]; struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = AggregateMode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -454,6 +473,7 @@ impl<'de> serde::Deserialize<'de> for AggregateMode { "FINAL_PARTITIONED" => Ok(AggregateMode::FinalPartitioned), "SINGLE" => Ok(AggregateMode::Single), "SINGLE_PARTITIONED" => Ok(AggregateMode::SinglePartitioned), + "PARTIAL_REDUCE" => Ok(AggregateMode::PartialReduce), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } @@ -518,7 +538,7 @@ impl<'de> serde::Deserialize<'de> for AggregateNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl 
serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -683,7 +703,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -854,7 +874,7 @@ impl<'de> serde::Deserialize<'de> for AliasNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -999,7 +1019,7 @@ impl<'de> serde::Deserialize<'de> for AnalyzeExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -1125,7 +1145,7 @@ impl<'de> serde::Deserialize<'de> for AnalyzeNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -1226,7 +1246,7 @@ impl<'de> serde::Deserialize<'de> for AnalyzedLogicalPlanType { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -1278,6 +1298,98 @@ impl<'de> serde::Deserialize<'de> for AnalyzedLogicalPlanType { deserializer.deserialize_struct("datafusion.AnalyzedLogicalPlanType", FIELDS, GeneratedVisitor) } } +impl 
serde::Serialize for ArrowScanExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.base_conf.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ArrowScanExecNode", len)?; + if let Some(v) = self.base_conf.as_ref() { + struct_ser.serialize_field("baseConf", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ArrowScanExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "base_conf", + "baseConf", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + BaseConf, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl serde::de::Visitor<'_> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "baseConf" | "base_conf" => Ok(GeneratedField::BaseConf), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ArrowScanExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ArrowScanExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut base_conf__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::BaseConf => { + if base_conf__.is_some() { + return Err(serde::de::Error::duplicate_field("baseConf")); + } + base_conf__ = map_.next_value()?; + } + } + } + Ok(ArrowScanExecNode { + base_conf: base_conf__, + }) + } + } + deserializer.deserialize_struct("datafusion.ArrowScanExecNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for AsyncFuncExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -1335,7 +1447,7 @@ impl<'de> serde::Deserialize<'de> for AsyncFuncExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -1445,7 +1557,7 @@ impl<'de> serde::Deserialize<'de> for AvroScanExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -1536,7 +1648,7 @@ impl<'de> serde::Deserialize<'de> for BareTableReference { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -1651,7 +1763,7 @@ impl<'de> serde::Deserialize<'de> for BetweenNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -1777,7 +1889,7 @@ impl<'de> serde::Deserialize<'de> for BinaryExprNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for 
GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -1838,6 +1950,118 @@ impl<'de> serde::Deserialize<'de> for BinaryExprNode { deserializer.deserialize_struct("datafusion.BinaryExprNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for BufferExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.input.is_some() { + len += 1; + } + if self.capacity != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.BufferExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if self.capacity != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("capacity", ToString::to_string(&self.capacity).as_str())?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for BufferExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "input", + "capacity", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Input, + Capacity, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl serde::de::Visitor<'_> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "input" => Ok(GeneratedField::Input), + "capacity" => Ok(GeneratedField::Capacity), + _ => 
Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = BufferExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.BufferExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut input__ = None; + let mut capacity__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::Capacity => { + if capacity__.is_some() { + return Err(serde::de::Error::duplicate_field("capacity")); + } + capacity__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(BufferExecNode { + input: input__, + capacity: capacity__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.BufferExecNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for CaseNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -1895,7 +2119,7 @@ impl<'de> serde::Deserialize<'de> for CaseNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -2013,7 +2237,7 @@ impl<'de> serde::Deserialize<'de> for CastNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -2130,7 +2354,7 @@ impl<'de> serde::Deserialize<'de> for 
CoalesceBatchesExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -2251,7 +2475,7 @@ impl<'de> serde::Deserialize<'de> for CoalescePartitionsExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -2363,7 +2587,7 @@ impl<'de> serde::Deserialize<'de> for ColumnIndex { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -2474,7 +2698,7 @@ impl<'de> serde::Deserialize<'de> for ColumnUnnestListItem { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -2585,7 +2809,7 @@ impl<'de> serde::Deserialize<'de> for ColumnUnnestListRecursion { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -2687,7 +2911,7 @@ impl<'de> serde::Deserialize<'de> for ColumnUnnestListRecursions { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -2778,7 +3002,7 @@ impl<'de> 
serde::Deserialize<'de> for CooperativeExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -2898,7 +3122,7 @@ impl<'de> serde::Deserialize<'de> for CopyToNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -3036,7 +3260,7 @@ impl<'de> serde::Deserialize<'de> for CreateCatalogNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -3163,7 +3387,7 @@ impl<'de> serde::Deserialize<'de> for CreateCatalogSchemaNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -3382,7 +3606,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -3627,7 +3851,7 @@ impl<'de> serde::Deserialize<'de> for CreateViewNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -3762,7 +3986,7 @@ impl<'de> 
serde::Deserialize<'de> for CrossJoinExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -3870,7 +4094,7 @@ impl<'de> serde::Deserialize<'de> for CrossJoinNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4038,7 +4262,7 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4201,7 +4425,7 @@ impl<'de> serde::Deserialize<'de> for CsvSink { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4327,7 +4551,7 @@ impl<'de> serde::Deserialize<'de> for CsvSinkExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4453,7 +4677,7 @@ impl<'de> serde::Deserialize<'de> for CteWorkTableScanNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4553,7 +4777,7 @@ impl<'de> serde::Deserialize<'de> for 
CubeNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4680,7 +4904,7 @@ impl<'de> serde::Deserialize<'de> for CustomTableScanNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4796,7 +5020,7 @@ impl<'de> serde::Deserialize<'de> for DateUnit { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = DateUnit; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4880,7 +5104,7 @@ impl<'de> serde::Deserialize<'de> for DistinctNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4998,7 +5222,7 @@ impl<'de> serde::Deserialize<'de> for DistinctOnNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -5144,7 +5368,7 @@ impl<'de> serde::Deserialize<'de> for DmlNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -5236,6 +5460,7 @@ impl serde::Serialize for dml_node::Type { Self::InsertAppend => "INSERT_APPEND", 
Self::InsertOverwrite => "INSERT_OVERWRITE", Self::InsertReplace => "INSERT_REPLACE", + Self::Truncate => "TRUNCATE", }; serializer.serialize_str(variant) } @@ -5253,11 +5478,12 @@ impl<'de> serde::Deserialize<'de> for dml_node::Type { "INSERT_APPEND", "INSERT_OVERWRITE", "INSERT_REPLACE", + "TRUNCATE", ]; struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = dml_node::Type; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -5299,6 +5525,7 @@ impl<'de> serde::Deserialize<'de> for dml_node::Type { "INSERT_APPEND" => Ok(dml_node::Type::InsertAppend), "INSERT_OVERWRITE" => Ok(dml_node::Type::InsertOverwrite), "INSERT_REPLACE" => Ok(dml_node::Type::InsertReplace), + "TRUNCATE" => Ok(dml_node::Type::Truncate), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } @@ -5362,7 +5589,7 @@ impl<'de> serde::Deserialize<'de> for DropViewNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -5471,7 +5698,7 @@ impl<'de> serde::Deserialize<'de> for EmptyExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -5563,7 +5790,7 @@ impl<'de> serde::Deserialize<'de> for EmptyRelationNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -5671,7 +5898,7 @@ impl<'de> serde::Deserialize<'de> for ExplainExecNode { { struct GeneratedVisitor; 
- impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -5788,7 +6015,7 @@ impl<'de> serde::Deserialize<'de> for ExplainNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -5888,7 +6115,7 @@ impl<'de> serde::Deserialize<'de> for FileGroup { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -5940,6 +6167,80 @@ impl<'de> serde::Deserialize<'de> for FileGroup { deserializer.deserialize_struct("datafusion.FileGroup", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for FileOutputMode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::Automatic => "FILE_OUTPUT_MODE_AUTOMATIC", + Self::SingleFile => "FILE_OUTPUT_MODE_SINGLE_FILE", + Self::Directory => "FILE_OUTPUT_MODE_DIRECTORY", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for FileOutputMode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "FILE_OUTPUT_MODE_AUTOMATIC", + "FILE_OUTPUT_MODE_SINGLE_FILE", + "FILE_OUTPUT_MODE_DIRECTORY", + ]; + + struct GeneratedVisitor; + + impl serde::de::Visitor<'_> for GeneratedVisitor { + type Value = FileOutputMode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + 
+ fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "FILE_OUTPUT_MODE_AUTOMATIC" => Ok(FileOutputMode::Automatic), + "FILE_OUTPUT_MODE_SINGLE_FILE" => Ok(FileOutputMode::SingleFile), + "FILE_OUTPUT_MODE_DIRECTORY" => Ok(FileOutputMode::Directory), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} impl serde::Serialize for FileRange { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -5991,7 +6292,7 @@ impl<'de> serde::Deserialize<'de> for FileRange { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -6183,7 +6484,7 @@ impl<'de> serde::Deserialize<'de> for FileScanExecConf { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -6362,6 +6663,9 @@ impl serde::Serialize for FileSinkConfig { if !self.file_extension.is_empty() { len += 1; } + if self.file_output_mode != 0 { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.FileSinkConfig", len)?; if 
!self.object_store_url.is_empty() { struct_ser.serialize_field("objectStoreUrl", &self.object_store_url)?; @@ -6389,6 +6693,11 @@ impl serde::Serialize for FileSinkConfig { if !self.file_extension.is_empty() { struct_ser.serialize_field("fileExtension", &self.file_extension)?; } + if self.file_output_mode != 0 { + let v = FileOutputMode::try_from(self.file_output_mode) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.file_output_mode)))?; + struct_ser.serialize_field("fileOutputMode", &v)?; + } struct_ser.end() } } @@ -6415,6 +6724,8 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { "insertOp", "file_extension", "fileExtension", + "file_output_mode", + "fileOutputMode", ]; #[allow(clippy::enum_variant_names)] @@ -6427,6 +6738,7 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { KeepPartitionByColumns, InsertOp, FileExtension, + FileOutputMode, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -6435,7 +6747,7 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -6456,6 +6768,7 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { "keepPartitionByColumns" | "keep_partition_by_columns" => Ok(GeneratedField::KeepPartitionByColumns), "insertOp" | "insert_op" => Ok(GeneratedField::InsertOp), "fileExtension" | "file_extension" => Ok(GeneratedField::FileExtension), + "fileOutputMode" | "file_output_mode" => Ok(GeneratedField::FileOutputMode), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -6483,6 +6796,7 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { let mut keep_partition_by_columns__ = None; let mut insert_op__ = None; let mut file_extension__ = None; + let mut file_output_mode__ 
= None; while let Some(k) = map_.next_key()? { match k { GeneratedField::ObjectStoreUrl => { @@ -6533,6 +6847,12 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { } file_extension__ = Some(map_.next_value()?); } + GeneratedField::FileOutputMode => { + if file_output_mode__.is_some() { + return Err(serde::de::Error::duplicate_field("fileOutputMode")); + } + file_output_mode__ = Some(map_.next_value::()? as i32); + } } } Ok(FileSinkConfig { @@ -6544,6 +6864,7 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { keep_partition_by_columns: keep_partition_by_columns__.unwrap_or_default(), insert_op: insert_op__.unwrap_or_default(), file_extension: file_extension__.unwrap_or_default(), + file_output_mode: file_output_mode__.unwrap_or_default(), }) } } @@ -6570,6 +6891,12 @@ impl serde::Serialize for FilterExecNode { if !self.projection.is_empty() { len += 1; } + if self.batch_size != 0 { + len += 1; + } + if self.fetch.is_some() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.FilterExecNode", len)?; if let Some(v) = self.input.as_ref() { struct_ser.serialize_field("input", v)?; @@ -6583,6 +6910,12 @@ impl serde::Serialize for FilterExecNode { if !self.projection.is_empty() { struct_ser.serialize_field("projection", &self.projection)?; } + if self.batch_size != 0 { + struct_ser.serialize_field("batchSize", &self.batch_size)?; + } + if let Some(v) = self.fetch.as_ref() { + struct_ser.serialize_field("fetch", v)?; + } struct_ser.end() } } @@ -6598,6 +6931,9 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { "default_filter_selectivity", "defaultFilterSelectivity", "projection", + "batch_size", + "batchSize", + "fetch", ]; #[allow(clippy::enum_variant_names)] @@ -6606,6 +6942,8 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { Expr, DefaultFilterSelectivity, Projection, + BatchSize, + Fetch, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -6614,7 
+6952,7 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -6631,6 +6969,8 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { "expr" => Ok(GeneratedField::Expr), "defaultFilterSelectivity" | "default_filter_selectivity" => Ok(GeneratedField::DefaultFilterSelectivity), "projection" => Ok(GeneratedField::Projection), + "batchSize" | "batch_size" => Ok(GeneratedField::BatchSize), + "fetch" => Ok(GeneratedField::Fetch), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -6654,6 +6994,8 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { let mut expr__ = None; let mut default_filter_selectivity__ = None; let mut projection__ = None; + let mut batch_size__ = None; + let mut fetch__ = None; while let Some(k) = map_.next_key()? 
{ match k { GeneratedField::Input => { @@ -6685,6 +7027,22 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { .into_iter().map(|x| x.0).collect()) ; } + GeneratedField::BatchSize => { + if batch_size__.is_some() { + return Err(serde::de::Error::duplicate_field("batchSize")); + } + batch_size__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::Fetch => { + if fetch__.is_some() { + return Err(serde::de::Error::duplicate_field("fetch")); + } + fetch__ = + map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| x.0) + ; + } } } Ok(FilterExecNode { @@ -6692,6 +7050,8 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { expr: expr__, default_filter_selectivity: default_filter_selectivity__.unwrap_or_default(), projection: projection__.unwrap_or_default(), + batch_size: batch_size__.unwrap_or_default(), + fetch: fetch__, }) } } @@ -6737,7 +7097,7 @@ impl<'de> serde::Deserialize<'de> for FixedSizeBinary { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -6846,7 +7206,7 @@ impl<'de> serde::Deserialize<'de> for FullTableReference { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -6957,7 +7317,7 @@ impl<'de> serde::Deserialize<'de> for GenerateSeriesArgsContainsNull { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -7087,7 +7447,7 @@ impl<'de> serde::Deserialize<'de> for 
GenerateSeriesArgsDate { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -7259,7 +7619,7 @@ impl<'de> serde::Deserialize<'de> for GenerateSeriesArgsInt64 { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -7439,7 +7799,7 @@ impl<'de> serde::Deserialize<'de> for GenerateSeriesArgsTimestamp { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -7566,7 +7926,7 @@ impl<'de> serde::Deserialize<'de> for GenerateSeriesName { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GenerateSeriesName; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -7690,7 +8050,7 @@ impl<'de> serde::Deserialize<'de> for GenerateSeriesNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -7844,7 +8204,7 @@ impl<'de> serde::Deserialize<'de> for GlobalLimitExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -7957,7 +8317,7 @@ impl<'de> 
serde::Deserialize<'de> for GroupingSetNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -8041,6 +8401,9 @@ impl serde::Serialize for HashJoinExecNode { if !self.projection.is_empty() { len += 1; } + if self.null_aware { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.HashJoinExecNode", len)?; if let Some(v) = self.left.as_ref() { struct_ser.serialize_field("left", v)?; @@ -8072,6 +8435,9 @@ impl serde::Serialize for HashJoinExecNode { if !self.projection.is_empty() { struct_ser.serialize_field("projection", &self.projection)?; } + if self.null_aware { + struct_ser.serialize_field("nullAware", &self.null_aware)?; + } struct_ser.end() } } @@ -8093,6 +8459,8 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { "nullEquality", "filter", "projection", + "null_aware", + "nullAware", ]; #[allow(clippy::enum_variant_names)] @@ -8105,6 +8473,7 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { NullEquality, Filter, Projection, + NullAware, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -8113,7 +8482,7 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -8134,6 +8503,7 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { "nullEquality" | "null_equality" => Ok(GeneratedField::NullEquality), "filter" => Ok(GeneratedField::Filter), "projection" => Ok(GeneratedField::Projection), + "nullAware" | "null_aware" => Ok(GeneratedField::NullAware), _ => Err(serde::de::Error::unknown_field(value, 
FIELDS)), } } @@ -8161,6 +8531,7 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { let mut null_equality__ = None; let mut filter__ = None; let mut projection__ = None; + let mut null_aware__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Left => { @@ -8214,6 +8585,12 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { .into_iter().map(|x| x.0).collect()) ; } + GeneratedField::NullAware => { + if null_aware__.is_some() { + return Err(serde::de::Error::duplicate_field("nullAware")); + } + null_aware__ = Some(map_.next_value()?); + } } } Ok(HashJoinExecNode { @@ -8225,6 +8602,7 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { null_equality: null_equality__.unwrap_or_default(), filter: filter__, projection: projection__.unwrap_or_default(), + null_aware: null_aware__.unwrap_or_default(), }) } } @@ -8282,7 +8660,7 @@ impl<'de> serde::Deserialize<'de> for HashRepartition { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -8409,7 +8787,7 @@ impl<'de> serde::Deserialize<'de> for ILikeNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -8543,7 +8921,7 @@ impl<'de> serde::Deserialize<'de> for InListNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -8641,7 +9019,7 @@ impl<'de> serde::Deserialize<'de> for InsertOp { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl 
serde::de::Visitor<'_> for GeneratedVisitor { type Value = InsertOp; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -8726,7 +9104,7 @@ impl<'de> serde::Deserialize<'de> for InterleaveExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -8817,7 +9195,7 @@ impl<'de> serde::Deserialize<'de> for IsFalse { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -8908,7 +9286,7 @@ impl<'de> serde::Deserialize<'de> for IsNotFalse { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -8999,7 +9377,7 @@ impl<'de> serde::Deserialize<'de> for IsNotNull { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -9090,7 +9468,7 @@ impl<'de> serde::Deserialize<'de> for IsNotTrue { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -9181,7 +9559,7 @@ impl<'de> serde::Deserialize<'de> for IsNotUnknown { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value 
= GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -9272,7 +9650,7 @@ impl<'de> serde::Deserialize<'de> for IsNull { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -9363,7 +9741,7 @@ impl<'de> serde::Deserialize<'de> for IsTrue { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -9454,7 +9832,7 @@ impl<'de> serde::Deserialize<'de> for IsUnknown { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -9562,7 +9940,7 @@ impl<'de> serde::Deserialize<'de> for JoinFilter { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -9738,7 +10116,7 @@ impl<'de> serde::Deserialize<'de> for JoinNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -9900,7 +10278,7 @@ impl<'de> serde::Deserialize<'de> for JoinOn { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut 
std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -10001,7 +10379,7 @@ impl<'de> serde::Deserialize<'de> for JsonScanExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -10101,7 +10479,7 @@ impl<'de> serde::Deserialize<'de> for JsonSink { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -10227,7 +10605,7 @@ impl<'de> serde::Deserialize<'de> for JsonSinkExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -10370,7 +10748,7 @@ impl<'de> serde::Deserialize<'de> for LikeNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -10508,7 +10886,7 @@ impl<'de> serde::Deserialize<'de> for LimitNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -10621,7 +10999,7 @@ impl<'de> serde::Deserialize<'de> for ListIndex { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> 
std::fmt::Result { @@ -10728,7 +11106,7 @@ impl<'de> serde::Deserialize<'de> for ListRange { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -10846,7 +11224,7 @@ impl<'de> serde::Deserialize<'de> for ListUnnest { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -11060,7 +11438,7 @@ impl<'de> serde::Deserialize<'de> for ListingTableScanNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -11284,7 +11662,7 @@ impl<'de> serde::Deserialize<'de> for LocalLimitExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -11386,7 +11764,7 @@ impl<'de> serde::Deserialize<'de> for LogicalExprList { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -11649,7 +12027,7 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result 
{ @@ -11982,7 +12360,7 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNodeCollection { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -12083,7 +12461,7 @@ impl<'de> serde::Deserialize<'de> for LogicalExtensionNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -12359,7 +12737,7 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -12699,7 +13077,7 @@ impl<'de> serde::Deserialize<'de> for MaybeFilter { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -12791,7 +13169,7 @@ impl<'de> serde::Deserialize<'de> for MaybePhysicalSortExprs { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -12924,7 +13302,7 @@ impl<'de> serde::Deserialize<'de> for MemoryScanExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> 
std::fmt::Result { @@ -13068,7 +13446,7 @@ impl<'de> serde::Deserialize<'de> for NamedStructField { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -13159,7 +13537,7 @@ impl<'de> serde::Deserialize<'de> for NegativeNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -13285,7 +13663,7 @@ impl<'de> serde::Deserialize<'de> for NestedLoopJoinExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -13415,7 +13793,7 @@ impl<'de> serde::Deserialize<'de> for Not { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -13493,7 +13871,7 @@ impl<'de> serde::Deserialize<'de> for NullTreatment { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = NullTreatment; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -13578,7 +13956,7 @@ impl<'de> serde::Deserialize<'de> for OptimizedLogicalPlanType { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result 
{ @@ -13670,7 +14048,7 @@ impl<'de> serde::Deserialize<'de> for OptimizedPhysicalPlanType { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -13779,7 +14157,7 @@ impl<'de> serde::Deserialize<'de> for ParquetScanExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -13897,7 +14275,7 @@ impl<'de> serde::Deserialize<'de> for ParquetSink { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -14023,7 +14401,7 @@ impl<'de> serde::Deserialize<'de> for ParquetSinkExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -14149,7 +14527,7 @@ impl<'de> serde::Deserialize<'de> for PartialTableReference { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -14249,7 +14627,7 @@ impl<'de> serde::Deserialize<'de> for PartiallySortedInputOrderMode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) 
-> std::fmt::Result { @@ -14352,7 +14730,7 @@ impl<'de> serde::Deserialize<'de> for PartitionColumn { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -14441,7 +14819,7 @@ impl<'de> serde::Deserialize<'de> for PartitionMode { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = PartitionMode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -14560,7 +14938,7 @@ impl<'de> serde::Deserialize<'de> for PartitionStats { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -14730,7 +15108,7 @@ impl<'de> serde::Deserialize<'de> for PartitionedFile { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -14889,7 +15267,7 @@ impl<'de> serde::Deserialize<'de> for Partitioning { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -15054,7 +15432,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> 
std::fmt::Result { @@ -15209,7 +15587,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalAliasNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -15325,7 +15703,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalBinaryExprNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -15452,7 +15830,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalCaseNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -15570,7 +15948,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalCastNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -15678,7 +16056,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalColumn { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -15796,7 +16174,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalDateTimeIntervalExprNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut 
std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -15874,10 +16252,18 @@ impl serde::Serialize for PhysicalExprNode { { use serde::ser::SerializeStruct; let mut len = 0; + if self.expr_id.is_some() { + len += 1; + } if self.expr_type.is_some() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalExprNode", len)?; + if let Some(v) = self.expr_id.as_ref() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("exprId", ToString::to_string(&v).as_str())?; + } if let Some(v) = self.expr_type.as_ref() { match v { physical_expr_node::ExprType::Column(v) => { @@ -15949,6 +16335,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ + "expr_id", + "exprId", "column", "literal", "binary_expr", @@ -15985,6 +16373,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { #[allow(clippy::enum_variant_names)] enum GeneratedField { + ExprId, Column, Literal, BinaryExpr, @@ -16012,7 +16401,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -16025,6 +16414,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { E: serde::de::Error, { match value { + "exprId" | "expr_id" => Ok(GeneratedField::ExprId), "column" => Ok(GeneratedField::Column), "literal" => Ok(GeneratedField::Literal), "binaryExpr" | "binary_expr" => Ok(GeneratedField::BinaryExpr), @@ -16063,9 +16453,18 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { where V: serde::de::MapAccess<'de>, { + let mut expr_id__ = None; let mut expr_type__ = None; while let Some(k) = map_.next_key()? 
{ match k { + GeneratedField::ExprId => { + if expr_id__.is_some() { + return Err(serde::de::Error::duplicate_field("exprId")); + } + expr_id__ = + map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| x.0) + ; + } GeneratedField::Column => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("column")); @@ -16202,6 +16601,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { } } Ok(PhysicalExprNode { + expr_id: expr_id__, expr_type: expr_type__, }) } @@ -16258,7 +16658,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExtensionExprNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -16370,7 +16770,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExtensionNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -16521,7 +16921,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalHashExprNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -16677,7 +17077,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalHashRepartition { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -16795,7 +17195,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalInListNode { { struct GeneratedVisitor; - impl<'de> 
serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -16904,7 +17304,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalIsNotNull { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -16995,7 +17395,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalIsNull { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -17111,7 +17511,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalLikeExprNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -17229,7 +17629,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalNegativeNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -17320,7 +17720,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalNot { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -17491,6 +17891,12 @@ impl serde::Serialize for PhysicalPlanNode { physical_plan_node::PhysicalPlanType::AsyncFunc(v) => { 
struct_ser.serialize_field("asyncFunc", v)?; } + physical_plan_node::PhysicalPlanType::Buffer(v) => { + struct_ser.serialize_field("buffer", v)?; + } + physical_plan_node::PhysicalPlanType::ArrowScan(v) => { + struct_ser.serialize_field("arrowScan", v)?; + } } } struct_ser.end() @@ -17558,6 +17964,9 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "memoryScan", "async_func", "asyncFunc", + "buffer", + "arrow_scan", + "arrowScan", ]; #[allow(clippy::enum_variant_names)] @@ -17597,6 +18006,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { SortMergeJoin, MemoryScan, AsyncFunc, + Buffer, + ArrowScan, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -17605,7 +18016,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -17653,6 +18064,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "sortMergeJoin" | "sort_merge_join" => Ok(GeneratedField::SortMergeJoin), "memoryScan" | "memory_scan" => Ok(GeneratedField::MemoryScan), "asyncFunc" | "async_func" => Ok(GeneratedField::AsyncFunc), + "buffer" => Ok(GeneratedField::Buffer), + "arrowScan" | "arrow_scan" => Ok(GeneratedField::ArrowScan), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -17918,6 +18331,20 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { return Err(serde::de::Error::duplicate_field("asyncFunc")); } physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::AsyncFunc) +; + } + GeneratedField::Buffer => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("buffer")); + } + physical_plan_type__ = 
map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Buffer) +; + } + GeneratedField::ArrowScan => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("arrowScan")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::ArrowScan) ; } } @@ -18014,7 +18441,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -18169,7 +18596,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalSortExprNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -18279,7 +18706,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalSortExprNodeCollection { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -18379,7 +18806,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalTryCastNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -18489,7 +18916,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalWhenThen { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, 
formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -18671,7 +19098,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -18868,7 +19295,7 @@ impl<'de> serde::Deserialize<'de> for PlaceholderNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -18988,7 +19415,7 @@ impl<'de> serde::Deserialize<'de> for PlaceholderRowExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -19143,7 +19570,7 @@ impl<'de> serde::Deserialize<'de> for PlanType { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -19356,7 +19783,7 @@ impl<'de> serde::Deserialize<'de> for PrepareNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -19474,7 +19901,7 @@ impl<'de> serde::Deserialize<'de> for ProjectionColumns { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, 
formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -19582,7 +20009,7 @@ impl<'de> serde::Deserialize<'de> for ProjectionExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -19699,7 +20126,7 @@ impl<'de> serde::Deserialize<'de> for ProjectionExpr { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -19799,7 +20226,7 @@ impl<'de> serde::Deserialize<'de> for ProjectionExprs { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -19910,7 +20337,7 @@ impl<'de> serde::Deserialize<'de> for ProjectionNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -20037,7 +20464,7 @@ impl<'de> serde::Deserialize<'de> for RecursionUnnestOption { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -20175,7 +20602,7 @@ impl<'de> serde::Deserialize<'de> for RecursiveQueryNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, 
formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -20301,7 +20728,7 @@ impl<'de> serde::Deserialize<'de> for RepartitionExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -20421,7 +20848,7 @@ impl<'de> serde::Deserialize<'de> for RepartitionNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -20529,7 +20956,7 @@ impl<'de> serde::Deserialize<'de> for RollupNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -20640,7 +21067,7 @@ impl<'de> serde::Deserialize<'de> for ScalarUdfExprNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -20751,7 +21178,7 @@ impl<'de> serde::Deserialize<'de> for ScanLimit { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -20844,7 +21271,7 @@ impl<'de> serde::Deserialize<'de> for SelectionExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut 
std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -20943,7 +21370,7 @@ impl<'de> serde::Deserialize<'de> for SelectionNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -21068,7 +21495,7 @@ impl<'de> serde::Deserialize<'de> for SimilarToNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -21213,7 +21640,7 @@ impl<'de> serde::Deserialize<'de> for SortExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -21350,7 +21777,7 @@ impl<'de> serde::Deserialize<'de> for SortExprNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -21460,7 +21887,7 @@ impl<'de> serde::Deserialize<'de> for SortExprNodeCollection { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -21606,7 +22033,7 @@ impl<'de> serde::Deserialize<'de> for SortMergeJoinExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut 
std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -21769,7 +22196,7 @@ impl<'de> serde::Deserialize<'de> for SortNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -21898,7 +22325,7 @@ impl<'de> serde::Deserialize<'de> for SortPreservingMergeExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -21996,7 +22423,7 @@ impl<'de> serde::Deserialize<'de> for StreamPartitionMode { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = StreamPartitionMode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -22089,7 +22516,7 @@ impl<'de> serde::Deserialize<'de> for StringifiedPlan { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -22197,7 +22624,7 @@ impl<'de> serde::Deserialize<'de> for SubqueryAliasNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -22372,7 +22799,7 @@ impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, 
formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -22549,7 +22976,7 @@ impl<'de> serde::Deserialize<'de> for TableReference { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -22666,7 +23093,7 @@ impl<'de> serde::Deserialize<'de> for TryCastNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -22766,7 +23193,7 @@ impl<'de> serde::Deserialize<'de> for UnionExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -22857,7 +23284,7 @@ impl<'de> serde::Deserialize<'de> for UnionNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -22948,7 +23375,7 @@ impl<'de> serde::Deserialize<'de> for UnknownColumn { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -23039,7 +23466,7 @@ impl<'de> serde::Deserialize<'de> for Unnest { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut 
std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -23164,7 +23591,7 @@ impl<'de> serde::Deserialize<'de> for UnnestExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -23346,7 +23773,7 @@ impl<'de> serde::Deserialize<'de> for UnnestNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -23506,7 +23933,7 @@ impl<'de> serde::Deserialize<'de> for UnnestOptions { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -23618,7 +24045,7 @@ impl<'de> serde::Deserialize<'de> for ValuesNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -23753,7 +24180,7 @@ impl<'de> serde::Deserialize<'de> for ViewTableScanNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -23890,7 +24317,7 @@ impl<'de> serde::Deserialize<'de> for WhenThen { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> 
std::fmt::Result { @@ -23990,7 +24417,7 @@ impl<'de> serde::Deserialize<'de> for Wildcard { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -24122,7 +24549,7 @@ impl<'de> serde::Deserialize<'de> for WindowAggExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -24339,7 +24766,7 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -24535,7 +24962,7 @@ impl<'de> serde::Deserialize<'de> for WindowFrame { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -24657,7 +25084,7 @@ impl<'de> serde::Deserialize<'de> for WindowFrameBound { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -24746,7 +25173,7 @@ impl<'de> serde::Deserialize<'de> for WindowFrameBoundType { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = WindowFrameBoundType; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> 
std::fmt::Result { @@ -24820,7 +25247,7 @@ impl<'de> serde::Deserialize<'de> for WindowFrameUnits { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = WindowFrameUnits; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -24914,7 +25341,7 @@ impl<'de> serde::Deserialize<'de> for WindowNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index cf343e0258d0..a0d4ef9e973c 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -444,6 +444,7 @@ pub mod dml_node { InsertAppend = 3, InsertOverwrite = 4, InsertReplace = 5, + Truncate = 6, } impl Type { /// String value of the enum field names used in the ProtoBuf definition. @@ -458,6 +459,7 @@ pub mod dml_node { Self::InsertAppend => "INSERT_APPEND", Self::InsertOverwrite => "INSERT_OVERWRITE", Self::InsertReplace => "INSERT_REPLACE", + Self::Truncate => "TRUNCATE", } } /// Creates an enum from field names used in the ProtoBuf definition. 
@@ -469,6 +471,7 @@ pub mod dml_node { "INSERT_APPEND" => Some(Self::InsertAppend), "INSERT_OVERWRITE" => Some(Self::InsertOverwrite), "INSERT_REPLACE" => Some(Self::InsertReplace), + "TRUNCATE" => Some(Self::Truncate), _ => None, } } @@ -1076,7 +1079,7 @@ pub mod table_reference { pub struct PhysicalPlanNode { #[prost( oneof = "physical_plan_node::PhysicalPlanType", - tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38" )] pub physical_plan_type: ::core::option::Option, } @@ -1156,6 +1159,10 @@ pub mod physical_plan_node { MemoryScan(super::MemoryScanExecNode), #[prost(message, tag = "36")] AsyncFunc(::prost::alloc::boxed::Box), + #[prost(message, tag = "37")] + Buffer(::prost::alloc::boxed::Box), + #[prost(message, tag = "38")] + ArrowScan(super::ArrowScanExecNode), } } #[derive(Clone, PartialEq, ::prost::Message)] @@ -1183,6 +1190,9 @@ pub struct FileSinkConfig { pub insert_op: i32, #[prost(string, tag = "11")] pub file_extension: ::prost::alloc::string::String, + /// Determines how the output path is interpreted. + #[prost(enumeration = "FileOutputMode", tag = "12")] + pub file_output_mode: i32, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct JsonSink { @@ -1274,6 +1284,14 @@ pub struct PhysicalExtensionNode { /// physical expressions #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalExprNode { + /// Unique identifier for this expression to do deduplication during deserialization. + /// When serializing, this is set to a unique identifier for each combination of + /// expression, process and serialization run. 
+ /// When deserializing, if this ID has been seen before, the cached Arc is returned + /// instead of creating a new one, enabling reconstruction of referential integrity + /// across serde roundtrips. + #[prost(uint64, optional, tag = "30")] + pub expr_id: ::core::option::Option, #[prost( oneof = "physical_expr_node::ExprType", tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 18, 19, 20, 21" @@ -1543,6 +1561,10 @@ pub struct FilterExecNode { pub default_filter_selectivity: u32, #[prost(uint32, repeated, tag = "9")] pub projection: ::prost::alloc::vec::Vec, + #[prost(uint32, tag = "10")] + pub batch_size: u32, + #[prost(uint32, optional, tag = "11")] + pub fetch: ::core::option::Option, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct FileGroup { @@ -1651,6 +1673,11 @@ pub struct AvroScanExecNode { pub base_conf: ::core::option::Option, } #[derive(Clone, PartialEq, ::prost::Message)] +pub struct ArrowScanExecNode { + #[prost(message, optional, tag = "1")] + pub base_conf: ::core::option::Option, +} +#[derive(Clone, PartialEq, ::prost::Message)] pub struct MemoryScanExecNode { #[prost(bytes = "vec", repeated, tag = "1")] pub partitions: ::prost::alloc::vec::Vec<::prost::alloc::vec::Vec>, @@ -1688,6 +1715,8 @@ pub struct HashJoinExecNode { pub filter: ::core::option::Option, #[prost(uint32, repeated, tag = "9")] pub projection: ::prost::alloc::vec::Vec, + #[prost(bool, tag = "10")] + pub null_aware: bool, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct SymmetricHashJoinExecNode { @@ -1830,6 +1859,9 @@ pub struct AggLimit { /// wrap into a message to make it optional #[prost(uint64, tag = "1")] pub limit: u64, + /// Optional ordering direction for TopK aggregation (true = descending, false = ascending) + #[prost(bool, optional, tag = "2")] + pub descending: ::core::option::Option, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct AggregateExecNode { @@ -2134,6 +2166,13 @@ pub struct AsyncFuncExecNode { #[prost(string, repeated, 
tag = "3")] pub async_expr_names: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, } +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct BufferExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(uint64, tag = "2")] + pub capacity: u64, +} #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum WindowFrameUnits { @@ -2244,6 +2283,39 @@ impl DateUnit { } } } +/// Determines how file sink output paths are interpreted. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum FileOutputMode { + /// Infer output mode from the URL (extension/trailing `/` heuristic). + Automatic = 0, + /// Write to a single file at the exact output path. + SingleFile = 1, + /// Write to a directory with generated filenames. + Directory = 2, +} +impl FileOutputMode { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::Automatic => "FILE_OUTPUT_MODE_AUTOMATIC", + Self::SingleFile => "FILE_OUTPUT_MODE_SINGLE_FILE", + Self::Directory => "FILE_OUTPUT_MODE_DIRECTORY", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. 
+ pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "FILE_OUTPUT_MODE_AUTOMATIC" => Some(Self::Automatic), + "FILE_OUTPUT_MODE_SINGLE_FILE" => Some(Self::SingleFile), + "FILE_OUTPUT_MODE_DIRECTORY" => Some(Self::Directory), + _ => None, + } + } +} #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum InsertOp { @@ -2336,6 +2408,7 @@ pub enum AggregateMode { FinalPartitioned = 2, Single = 3, SinglePartitioned = 4, + PartialReduce = 5, } impl AggregateMode { /// String value of the enum field names used in the ProtoBuf definition. @@ -2349,6 +2422,7 @@ impl AggregateMode { Self::FinalPartitioned => "FINAL_PARTITIONED", Self::Single => "SINGLE", Self::SinglePartitioned => "SINGLE_PARTITIONED", + Self::PartialReduce => "PARTIAL_REDUCE", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2359,6 +2433,7 @@ impl AggregateMode { "FINAL_PARTITIONED" => Some(Self::FinalPartitioned), "SINGLE" => Some(Self::Single), "SINGLE_PARTITIONED" => Some(Self::SinglePartitioned), + "PARTIAL_REDUCE" => Some(Self::PartialReduce), _ => None, } } diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs index e30d2a22348c..7ddc930fa257 100644 --- a/datafusion/proto/src/lib.rs +++ b/datafusion/proto/src/lib.rs @@ -23,7 +23,6 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] -#![deny(clippy::allow_attributes)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] //! 
Serialize / Deserialize DataFusion Plans to bytes diff --git a/datafusion/proto/src/logical_plan/file_formats.rs b/datafusion/proto/src/logical_plan/file_formats.rs index 436a06493766..08f42b0af729 100644 --- a/datafusion/proto/src/logical_plan/file_formats.rs +++ b/datafusion/proto/src/logical_plan/file_formats.rs @@ -241,6 +241,7 @@ impl JsonOptionsProto { compression: options.compression as i32, schema_infer_max_rec: options.schema_infer_max_rec.map(|v| v as u64), compression_level: options.compression_level, + newline_delimited: Some(options.newline_delimited), } } else { JsonOptionsProto::default() @@ -260,6 +261,7 @@ impl From<&JsonOptionsProto> for JsonOptions { }, schema_infer_max_rec: proto.schema_infer_max_rec.map(|v| v as usize), compression_level: proto.compression_level, + newline_delimited: proto.newline_delimited.unwrap_or(true), } } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 179fe8bb7d7f..a653f517b727 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -239,6 +239,7 @@ impl From for WriteOp { } protobuf::dml_node::Type::InsertReplace => WriteOp::Insert(InsertOp::Replace), protobuf::dml_node::Type::Ctas => WriteOp::Ctas, + protobuf::dml_node::Type::Truncate => WriteOp::Truncate, } } } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 6e4e5d0b6eea..fe63fce6ee26 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -578,6 +578,7 @@ pub fn serialize_expr( Expr::ScalarSubquery(_) | Expr::InSubquery(_) | Expr::Exists { .. } + | Expr::SetComparison(_) | Expr::OuterReferenceColumn { .. 
} => { // we would need to add logical plan operators to datafusion.proto to support this // see discussion in https://github.com/apache/datafusion/issues/2565 @@ -728,6 +729,7 @@ impl From<&WriteOp> for protobuf::dml_node::Type { WriteOp::Delete => protobuf::dml_node::Type::Delete, WriteOp::Update => protobuf::dml_node::Type::Update, WriteOp::Ctas => protobuf::dml_node::Type::Ctas, + WriteOp::Truncate => protobuf::dml_node::Type::Truncate, } } } diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 073fdd858cdd..e424be162648 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -21,14 +21,9 @@ use std::sync::Arc; use arrow::array::RecordBatch; use arrow::compute::SortOptions; -use arrow::datatypes::Field; +use arrow::datatypes::{Field, Schema}; use arrow::ipc::reader::StreamReader; use chrono::{TimeZone, Utc}; -use datafusion_expr::dml::InsertOp; -use object_store::ObjectMeta; -use object_store::path::Path; - -use arrow::datatypes::Schema; use datafusion_common::{DataFusionError, Result, internal_datafusion_err, not_impl_err}; use datafusion_datasource::file::FileSource; use datafusion_datasource::file_groups::FileGroup; @@ -42,6 +37,7 @@ use datafusion_datasource_parquet::file_format::ParquetSink; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::{FunctionRegistry, TaskContext}; use datafusion_expr::WindowFunctionDefinition; +use datafusion_expr::dml::InsertOp; use datafusion_physical_expr::projection::{ProjectionExpr, ProjectionExprs}; use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr, ScalarFunctionExpr}; use datafusion_physical_plan::expressions::{ @@ -52,13 +48,16 @@ use datafusion_physical_plan::joins::{HashExpr, SeededRandomState}; use datafusion_physical_plan::windows::{create_window_expr, schema_add_window_field}; use datafusion_physical_plan::{Partitioning, PhysicalExpr, WindowExpr}; 
use datafusion_proto_common::common::proto_error; +use object_store::ObjectMeta; +use object_store::path::Path; -use crate::convert_required; +use super::{ + DefaultPhysicalProtoConverter, PhysicalExtensionCodec, + PhysicalProtoConverterExtension, +}; use crate::logical_plan::{self}; -use crate::protobuf; use crate::protobuf::physical_expr_node::ExprType; - -use super::PhysicalExtensionCodec; +use crate::{convert_required, protobuf}; impl From<&protobuf::PhysicalColumn> for Column { fn from(c: &protobuf::PhysicalColumn) -> Column { @@ -80,9 +79,15 @@ pub fn parse_physical_sort_expr( ctx: &TaskContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { if let Some(expr) = &proto.expr { - let expr = parse_physical_expr(expr.as_ref(), ctx, input_schema, codec)?; + let expr = proto_converter.proto_to_physical_expr( + expr.as_ref(), + ctx, + input_schema, + codec, + )?; let options = SortOptions { descending: !proto.asc, nulls_first: proto.nulls_first, @@ -107,10 +112,13 @@ pub fn parse_physical_sort_exprs( ctx: &TaskContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { proto .iter() - .map(|sort_expr| parse_physical_sort_expr(sort_expr, ctx, input_schema, codec)) + .map(|sort_expr| { + parse_physical_sort_expr(sort_expr, ctx, input_schema, codec, proto_converter) + }) .collect() } @@ -129,12 +137,25 @@ pub fn parse_physical_window_expr( ctx: &TaskContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let window_node_expr = parse_physical_exprs(&proto.args, ctx, input_schema, codec)?; - let partition_by = - parse_physical_exprs(&proto.partition_by, ctx, input_schema, codec)?; - - let order_by = parse_physical_sort_exprs(&proto.order_by, ctx, input_schema, codec)?; + let window_node_expr = + 
parse_physical_exprs(&proto.args, ctx, input_schema, codec, proto_converter)?; + let partition_by = parse_physical_exprs( + &proto.partition_by, + ctx, + input_schema, + codec, + proto_converter, + )?; + + let order_by = parse_physical_sort_exprs( + &proto.order_by, + ctx, + input_schema, + codec, + proto_converter, + )?; let window_frame = proto .window_frame @@ -188,13 +209,14 @@ pub fn parse_physical_exprs<'a, I>( ctx: &TaskContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result>> where I: IntoIterator, { protos .into_iter() - .map(|p| parse_physical_expr(p, ctx, input_schema, codec)) + .map(|p| proto_converter.proto_to_physical_expr(p, ctx, input_schema, codec)) .collect::>>() } @@ -212,6 +234,32 @@ pub fn parse_physical_expr( ctx: &TaskContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, +) -> Result> { + parse_physical_expr_with_converter( + proto, + ctx, + input_schema, + codec, + &DefaultPhysicalProtoConverter {}, + ) +} + +/// Parses a physical expression from a protobuf. +/// +/// # Arguments +/// +/// * `proto` - Input proto with physical expression node +/// * `registry` - A registry knows how to build logical expressions out of user-defined function names +/// * `input_schema` - The Arrow schema for the input, used for determining expression data types +/// when performing type coercion. +/// * `codec` - An extension codec used to decode custom UDFs. 
+/// * `proto_converter` - Conversion functions for physical plans and expressions +pub fn parse_physical_expr_with_converter( + proto: &protobuf::PhysicalExprNode, + ctx: &TaskContext, + input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let expr_type = proto .expr_type @@ -232,6 +280,7 @@ pub fn parse_physical_expr( "left", input_schema, codec, + proto_converter, )?, logical_plan::from_proto::from_proto_binary_op(&binary_expr.op)?, parse_required_physical_expr( @@ -240,6 +289,7 @@ pub fn parse_physical_expr( "right", input_schema, codec, + proto_converter, )?, )), ExprType::AggregateExpr(_) => { @@ -262,6 +312,7 @@ pub fn parse_physical_expr( "expr", input_schema, codec, + proto_converter, )?)) } ExprType::IsNotNullExpr(e) => { @@ -271,6 +322,7 @@ pub fn parse_physical_expr( "expr", input_schema, codec, + proto_converter, )?)) } ExprType::NotExpr(e) => Arc::new(NotExpr::new(parse_required_physical_expr( @@ -279,6 +331,7 @@ pub fn parse_physical_expr( "expr", input_schema, codec, + proto_converter, )?)), ExprType::Negative(e) => { Arc::new(NegativeExpr::new(parse_required_physical_expr( @@ -287,6 +340,7 @@ pub fn parse_physical_expr( "expr", input_schema, codec, + proto_converter, )?)) } ExprType::InList(e) => in_list( @@ -296,15 +350,23 @@ pub fn parse_physical_expr( "expr", input_schema, codec, + proto_converter, )?, - parse_physical_exprs(&e.list, ctx, input_schema, codec)?, + parse_physical_exprs(&e.list, ctx, input_schema, codec, proto_converter)?, &e.negated, input_schema, )?, ExprType::Case(e) => Arc::new(CaseExpr::try_new( e.expr .as_ref() - .map(|e| parse_physical_expr(e.as_ref(), ctx, input_schema, codec)) + .map(|e| { + proto_converter.proto_to_physical_expr( + e.as_ref(), + ctx, + input_schema, + codec, + ) + }) .transpose()?, e.when_then_expr .iter() @@ -316,6 +378,7 @@ pub fn parse_physical_expr( "when_expr", input_schema, codec, + proto_converter, )?, 
parse_required_physical_expr( e.then_expr.as_ref(), @@ -323,13 +386,21 @@ pub fn parse_physical_expr( "then_expr", input_schema, codec, + proto_converter, )?, )) }) .collect::>>()?, e.else_expr .as_ref() - .map(|e| parse_physical_expr(e.as_ref(), ctx, input_schema, codec)) + .map(|e| { + proto_converter.proto_to_physical_expr( + e.as_ref(), + ctx, + input_schema, + codec, + ) + }) .transpose()?, )?), ExprType::Cast(e) => Arc::new(CastExpr::new( @@ -339,6 +410,7 @@ pub fn parse_physical_expr( "expr", input_schema, codec, + proto_converter, )?, convert_required!(e.arrow_type)?, None, @@ -350,6 +422,7 @@ pub fn parse_physical_expr( "expr", input_schema, codec, + proto_converter, )?, convert_required!(e.arrow_type)?, )), @@ -362,7 +435,8 @@ pub fn parse_physical_expr( }; let scalar_fun_def = Arc::clone(&udf); - let args = parse_physical_exprs(&e.args, ctx, input_schema, codec)?; + let args = + parse_physical_exprs(&e.args, ctx, input_schema, codec, proto_converter)?; let config_options = Arc::clone(ctx.session_config().options()); @@ -391,6 +465,7 @@ pub fn parse_physical_expr( "expr", input_schema, codec, + proto_converter, )?, parse_required_physical_expr( like_expr.pattern.as_deref(), @@ -398,11 +473,17 @@ pub fn parse_physical_expr( "pattern", input_schema, codec, + proto_converter, )?, )), ExprType::HashExpr(hash_expr) => { - let on_columns = - parse_physical_exprs(&hash_expr.on_columns, ctx, input_schema, codec)?; + let on_columns = parse_physical_exprs( + &hash_expr.on_columns, + ctx, + input_schema, + codec, + proto_converter, + )?; Arc::new(HashExpr::new( on_columns, SeededRandomState::with_seeds( @@ -418,9 +499,11 @@ pub fn parse_physical_expr( let inputs: Vec> = extension .inputs .iter() - .map(|e| parse_physical_expr(e, ctx, input_schema, codec)) + .map(|e| { + proto_converter.proto_to_physical_expr(e, ctx, input_schema, codec) + }) .collect::>()?; - (codec.try_decode_expr(extension.expr.as_slice(), &inputs)?) 
as _ + codec.try_decode_expr(extension.expr.as_slice(), &inputs)? as _ } }; @@ -433,8 +516,9 @@ fn parse_required_physical_expr( field: &str, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - expr.map(|e| parse_physical_expr(e, ctx, input_schema, codec)) + expr.map(|e| proto_converter.proto_to_physical_expr(e, ctx, input_schema, codec)) .transpose()? .ok_or_else(|| internal_datafusion_err!("Missing required field {field:?}")) } @@ -444,11 +528,17 @@ pub fn parse_protobuf_hash_partitioning( ctx: &TaskContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { match partitioning { Some(hash_part) => { - let expr = - parse_physical_exprs(&hash_part.hash_expr, ctx, input_schema, codec)?; + let expr = parse_physical_exprs( + &hash_part.hash_expr, + ctx, + input_schema, + codec, + proto_converter, + )?; Ok(Some(Partitioning::Hash( expr, @@ -464,6 +554,7 @@ pub fn parse_protobuf_partitioning( ctx: &TaskContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { match partitioning { Some(protobuf::Partitioning { partition_method }) => match partition_method { @@ -478,6 +569,7 @@ pub fn parse_protobuf_partitioning( ctx, input_schema, codec, + proto_converter, ) } Some(protobuf::partitioning::PartitionMethod::Unknown(partition_count)) => { @@ -532,6 +624,7 @@ pub fn parse_protobuf_file_scan_config( proto: &protobuf::FileScanExecConf, ctx: &TaskContext, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, file_source: Arc, ) -> Result { let schema: Arc = parse_protobuf_file_scan_schema(proto)?; @@ -557,6 +650,7 @@ pub fn parse_protobuf_file_scan_config( ctx, &schema, codec, + proto_converter, )?; output_ordering.extend(LexOrdering::new(sort_exprs)); } @@ -567,7 +661,7 @@ pub fn 
parse_protobuf_file_scan_config( .projections .iter() .map(|proto_expr| { - let expr = parse_physical_expr( + let expr = proto_converter.proto_to_physical_expr( proto_expr.expr.as_ref().ok_or_else(|| { internal_datafusion_err!("ProjectionExpr missing expr field") })?, @@ -616,30 +710,28 @@ impl TryFrom<&protobuf::PartitionedFile> for PartitionedFile { type Error = DataFusionError; fn try_from(val: &protobuf::PartitionedFile) -> Result { - Ok(PartitionedFile { - object_meta: ObjectMeta { - location: Path::parse(val.path.as_str()).map_err(|e| { - proto_error(format!("Invalid object_store path: {e}")) - })?, - last_modified: Utc.timestamp_nanos(val.last_modified_ns as i64), - size: val.size, - e_tag: None, - version: None, - }, - partition_values: val - .partition_values + let mut pf = PartitionedFile::new_from_meta(ObjectMeta { + location: Path::parse(val.path.as_str()) + .map_err(|e| proto_error(format!("Invalid object_store path: {e}")))?, + last_modified: Utc.timestamp_nanos(val.last_modified_ns as i64), + size: val.size, + e_tag: None, + version: None, + }) + .with_partition_values( + val.partition_values .iter() .map(|v| v.try_into()) .collect::, _>>()?, - range: val.range.as_ref().map(|v| v.try_into()).transpose()?, - statistics: val - .statistics - .as_ref() - .map(|v| v.try_into().map(Arc::new)) - .transpose()?, - extensions: None, - metadata_size_hint: None, - }) + ); + if let Some(range) = val.range.as_ref() { + let file_range: FileRange = range.try_into()?; + pf = pf.with_range(file_range.start, file_range.end); + } + if let Some(proto_stats) = val.statistics.as_ref() { + pf = pf.with_statistics(Arc::new(proto_stats.try_into()?)); + } + Ok(pf) } } @@ -729,6 +821,17 @@ impl TryFrom<&protobuf::FileSinkConfig> for FileSinkConfig { protobuf::InsertOp::Overwrite => InsertOp::Overwrite, protobuf::InsertOp::Replace => InsertOp::Replace, }; + let file_output_mode = match conf.file_output_mode() { + protobuf::FileOutputMode::Automatic => { + 
datafusion_datasource::file_sink_config::FileOutputMode::Automatic + } + protobuf::FileOutputMode::SingleFile => { + datafusion_datasource::file_sink_config::FileOutputMode::SingleFile + } + protobuf::FileOutputMode::Directory => { + datafusion_datasource::file_sink_config::FileOutputMode::Directory + } + }; Ok(Self { original_url: String::default(), object_store_url: ObjectStoreUrl::parse(&conf.object_store_url)?, @@ -739,35 +842,30 @@ impl TryFrom<&protobuf::FileSinkConfig> for FileSinkConfig { insert_op, keep_partition_by_columns: conf.keep_partition_by_columns, file_extension: conf.file_extension.clone(), + file_output_mode, }) } } #[cfg(test)] mod tests { - use super::*; use chrono::{TimeZone, Utc}; use datafusion_datasource::PartitionedFile; use object_store::ObjectMeta; use object_store::path::Path; + use super::*; + #[test] fn partitioned_file_path_roundtrip_percent_encoded() { let path_str = "foo/foo%2Fbar/baz%252Fqux"; - let pf = PartitionedFile { - object_meta: ObjectMeta { - location: Path::parse(path_str).unwrap(), - last_modified: Utc.timestamp_nanos(1_000), - size: 42, - e_tag: None, - version: None, - }, - partition_values: vec![], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - }; + let pf = PartitionedFile::new_from_meta(ObjectMeta { + location: Path::parse(path_str).unwrap(), + last_modified: Utc.timestamp_nanos(1_000), + size: 42, + e_tag: None, + version: None, + }); let proto = protobuf::PartitionedFile::try_from(&pf).unwrap(); assert_eq!(proto.path, path_str); diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 4ff90b61eed9..85406e31da61 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -15,33 +15,14 @@ // specific language governing permissions and limitations // under the License. 
+use std::cell::RefCell; +use std::collections::HashMap; use std::fmt::Debug; +use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::Arc; -use self::from_proto::parse_protobuf_partitioning; -use self::to_proto::{serialize_partitioning, serialize_physical_expr}; -use crate::common::{byte_to_string, str_to_byte}; -use crate::physical_plan::from_proto::{ - parse_physical_expr, parse_physical_sort_expr, parse_physical_sort_exprs, - parse_physical_window_expr, parse_protobuf_file_scan_config, parse_record_batches, - parse_table_schema_from_proto, -}; -use crate::physical_plan::to_proto::{ - serialize_file_scan_config, serialize_maybe_filter, serialize_physical_aggr_expr, - serialize_physical_sort_exprs, serialize_physical_window_expr, - serialize_record_batches, -}; -use crate::protobuf::physical_aggregate_expr_node::AggregateFunction; -use crate::protobuf::physical_expr_node::ExprType; -use crate::protobuf::physical_plan_node::PhysicalPlanType; -use crate::protobuf::{ - self, ListUnnest as ProtoListUnnest, SortExprNode, SortMergeJoinExecNode, - proto_error, window_agg_exec_node, -}; -use crate::{convert_required, into_required}; - use arrow::compute::SortOptions; -use arrow::datatypes::{IntervalMonthDayNanoType, SchemaRef}; +use arrow::datatypes::{IntervalMonthDayNanoType, Schema, SchemaRef}; use datafusion_catalog::memory::MemorySourceConfig; use datafusion_common::config::CsvOptions; use datafusion_common::{ @@ -53,6 +34,7 @@ use datafusion_datasource::file_compression_type::FileCompressionType; use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder}; use datafusion_datasource::sink::DataSinkExec; use datafusion_datasource::source::{DataSource, DataSourceExec}; +use datafusion_datasource_arrow::source::ArrowSource; #[cfg(feature = "avro")] use datafusion_datasource_avro::source::AvroSource; use datafusion_datasource_csv::file_format::CsvSink; @@ -60,33 +42,40 @@ use datafusion_datasource_csv::source::CsvSource; use 
datafusion_datasource_json::file_format::JsonSink; use datafusion_datasource_json::source::JsonSource; #[cfg(feature = "parquet")] +use datafusion_datasource_parquet::CachedParquetFileReaderFactory; +#[cfg(feature = "parquet")] use datafusion_datasource_parquet::file_format::ParquetSink; #[cfg(feature = "parquet")] use datafusion_datasource_parquet::source::ParquetSource; +#[cfg(feature = "parquet")] +use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::{FunctionRegistry, TaskContext}; use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; use datafusion_functions_table::generate_series::{ Empty, GenSeriesArgs, GenerateSeriesTable, GenericSeriesState, TimestampValue, }; -use datafusion_physical_expr::aggregate::AggregateExprBuilder; -use datafusion_physical_expr::aggregate::AggregateFunctionExpr; +use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; +use datafusion_physical_expr::async_scalar_function::AsyncFuncExpr; use datafusion_physical_expr::{LexOrdering, LexRequirement, PhysicalExprRef}; -use datafusion_physical_plan::aggregates::AggregateMode; -use datafusion_physical_plan::aggregates::{AggregateExec, PhysicalGroupBy}; +use datafusion_physical_plan::aggregates::{ + AggregateExec, AggregateMode, LimitOptions, PhysicalGroupBy, +}; use datafusion_physical_plan::analyze::AnalyzeExec; +use datafusion_physical_plan::async_func::AsyncFuncExec; +use datafusion_physical_plan::buffer::BufferExec; +#[expect(deprecated)] use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::coop::CooperativeExec; use datafusion_physical_plan::empty::EmptyExec; use datafusion_physical_plan::explain::ExplainExec; use datafusion_physical_plan::expressions::PhysicalSortExpr; -use datafusion_physical_plan::filter::FilterExec; +use datafusion_physical_plan::filter::{FilterExec, FilterExecBuilder}; use 
datafusion_physical_plan::joins::utils::{ColumnIndex, JoinFilter}; use datafusion_physical_plan::joins::{ - CrossJoinExec, NestedLoopJoinExec, SortMergeJoinExec, StreamJoinPartitionMode, - SymmetricHashJoinExec, + CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode, SortMergeJoinExec, + StreamJoinPartitionMode, SymmetricHashJoinExec, }; -use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::memory::LazyMemoryExec; use datafusion_physical_plan::metrics::MetricType; @@ -99,12 +88,31 @@ use datafusion_physical_plan::union::{InterleaveExec, UnionExec}; use datafusion_physical_plan::unnest::{ListUnnest, UnnestExec}; use datafusion_physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use datafusion_physical_plan::{ExecutionPlan, InputOrderMode, PhysicalExpr, WindowExpr}; - -use datafusion_physical_expr::async_scalar_function::AsyncFuncExpr; -use datafusion_physical_plan::async_func::AsyncFuncExec; use prost::Message; use prost::bytes::BufMut; +use self::from_proto::parse_protobuf_partitioning; +use self::to_proto::serialize_partitioning; +use crate::common::{byte_to_string, str_to_byte}; +use crate::physical_plan::from_proto::{ + parse_physical_expr_with_converter, parse_physical_sort_expr, + parse_physical_sort_exprs, parse_physical_window_expr, + parse_protobuf_file_scan_config, parse_record_batches, parse_table_schema_from_proto, +}; +use crate::physical_plan::to_proto::{ + serialize_file_scan_config, serialize_maybe_filter, serialize_physical_aggr_expr, + serialize_physical_expr_with_converter, serialize_physical_sort_exprs, + serialize_physical_window_expr, serialize_record_batches, +}; +use crate::protobuf::physical_aggregate_expr_node::AggregateFunction; +use crate::protobuf::physical_expr_node::ExprType; +use crate::protobuf::physical_plan_node::PhysicalPlanType; +use crate::protobuf::{ + self, ListUnnest as ProtoListUnnest, 
SortExprNode, SortMergeJoinExecNode, + proto_error, window_agg_exec_node, +}; +use crate::{convert_required, into_required}; + pub mod from_proto; pub mod to_proto; @@ -131,8 +139,37 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { fn try_into_physical_plan( &self, ctx: &TaskContext, + codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + self.try_into_physical_plan_with_converter( + ctx, + codec, + &DefaultPhysicalProtoConverter {}, + ) + } - extension_codec: &dyn PhysicalExtensionCodec, + fn try_from_physical_plan( + plan: Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result + where + Self: Sized, + { + Self::try_from_physical_plan_with_converter( + plan, + codec, + &DefaultPhysicalProtoConverter {}, + ) + } +} + +impl protobuf::PhysicalPlanNode { + pub fn try_into_physical_plan_with_converter( + &self, + ctx: &TaskContext, + + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let plan = self.physical_plan_type.as_ref().ok_or_else(|| { proto_error(format!( @@ -141,125 +178,155 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { })?; match plan { PhysicalPlanType::Explain(explain) => { - self.try_into_explain_physical_plan(explain, ctx, extension_codec) - } - PhysicalPlanType::Projection(projection) => { - self.try_into_projection_physical_plan(projection, ctx, extension_codec) + self.try_into_explain_physical_plan(explain, ctx, codec, proto_converter) } + PhysicalPlanType::Projection(projection) => self + .try_into_projection_physical_plan( + projection, + ctx, + codec, + proto_converter, + ), PhysicalPlanType::Filter(filter) => { - self.try_into_filter_physical_plan(filter, ctx, extension_codec) + self.try_into_filter_physical_plan(filter, ctx, codec, proto_converter) } PhysicalPlanType::CsvScan(scan) => { - self.try_into_csv_scan_physical_plan(scan, ctx, extension_codec) + self.try_into_csv_scan_physical_plan(scan, ctx, codec, proto_converter) } PhysicalPlanType::JsonScan(scan) => { - 
self.try_into_json_scan_physical_plan(scan, ctx, extension_codec) - } - PhysicalPlanType::ParquetScan(scan) => { - self.try_into_parquet_scan_physical_plan(scan, ctx, extension_codec) + self.try_into_json_scan_physical_plan(scan, ctx, codec, proto_converter) } + PhysicalPlanType::ParquetScan(scan) => self + .try_into_parquet_scan_physical_plan(scan, ctx, codec, proto_converter), PhysicalPlanType::AvroScan(scan) => { - self.try_into_avro_scan_physical_plan(scan, ctx, extension_codec) + self.try_into_avro_scan_physical_plan(scan, ctx, codec, proto_converter) } PhysicalPlanType::MemoryScan(scan) => { - self.try_into_memory_scan_physical_plan(scan, ctx, extension_codec) + self.try_into_memory_scan_physical_plan(scan, ctx, codec, proto_converter) + } + PhysicalPlanType::ArrowScan(scan) => { + self.try_into_arrow_scan_physical_plan(scan, ctx, codec, proto_converter) } PhysicalPlanType::CoalesceBatches(coalesce_batches) => self .try_into_coalesce_batches_physical_plan( coalesce_batches, ctx, - extension_codec, + codec, + proto_converter, ), PhysicalPlanType::Merge(merge) => { - self.try_into_merge_physical_plan(merge, ctx, extension_codec) - } - PhysicalPlanType::Repartition(repart) => { - self.try_into_repartition_physical_plan(repart, ctx, extension_codec) - } - PhysicalPlanType::GlobalLimit(limit) => { - self.try_into_global_limit_physical_plan(limit, ctx, extension_codec) - } - PhysicalPlanType::LocalLimit(limit) => { - self.try_into_local_limit_physical_plan(limit, ctx, extension_codec) - } - PhysicalPlanType::Window(window_agg) => { - self.try_into_window_physical_plan(window_agg, ctx, extension_codec) - } - PhysicalPlanType::Aggregate(hash_agg) => { - self.try_into_aggregate_physical_plan(hash_agg, ctx, extension_codec) - } - PhysicalPlanType::HashJoin(hashjoin) => { - self.try_into_hash_join_physical_plan(hashjoin, ctx, extension_codec) + self.try_into_merge_physical_plan(merge, ctx, codec, proto_converter) } + PhysicalPlanType::Repartition(repart) => self + 
.try_into_repartition_physical_plan(repart, ctx, codec, proto_converter), + PhysicalPlanType::GlobalLimit(limit) => self + .try_into_global_limit_physical_plan(limit, ctx, codec, proto_converter), + PhysicalPlanType::LocalLimit(limit) => self + .try_into_local_limit_physical_plan(limit, ctx, codec, proto_converter), + PhysicalPlanType::Window(window_agg) => self.try_into_window_physical_plan( + window_agg, + ctx, + codec, + proto_converter, + ), + PhysicalPlanType::Aggregate(hash_agg) => self + .try_into_aggregate_physical_plan(hash_agg, ctx, codec, proto_converter), + PhysicalPlanType::HashJoin(hashjoin) => self + .try_into_hash_join_physical_plan(hashjoin, ctx, codec, proto_converter), PhysicalPlanType::SymmetricHashJoin(sym_join) => self .try_into_symmetric_hash_join_physical_plan( sym_join, ctx, - extension_codec, + codec, + proto_converter, ), PhysicalPlanType::Union(union) => { - self.try_into_union_physical_plan(union, ctx, extension_codec) + self.try_into_union_physical_plan(union, ctx, codec, proto_converter) } - PhysicalPlanType::Interleave(interleave) => { - self.try_into_interleave_physical_plan(interleave, ctx, extension_codec) - } - PhysicalPlanType::CrossJoin(crossjoin) => { - self.try_into_cross_join_physical_plan(crossjoin, ctx, extension_codec) - } - PhysicalPlanType::Empty(empty) => { - self.try_into_empty_physical_plan(empty, ctx, extension_codec) - } - PhysicalPlanType::PlaceholderRow(placeholder) => self - .try_into_placeholder_row_physical_plan( - placeholder, + PhysicalPlanType::Interleave(interleave) => self + .try_into_interleave_physical_plan( + interleave, ctx, - extension_codec, + codec, + proto_converter, ), - PhysicalPlanType::Sort(sort) => { - self.try_into_sort_physical_plan(sort, ctx, extension_codec) + PhysicalPlanType::CrossJoin(crossjoin) => self + .try_into_cross_join_physical_plan( + crossjoin, + ctx, + codec, + proto_converter, + ), + PhysicalPlanType::Empty(empty) => { + self.try_into_empty_physical_plan(empty, ctx, codec, 
proto_converter) } - PhysicalPlanType::SortPreservingMerge(sort) => self - .try_into_sort_preserving_merge_physical_plan(sort, ctx, extension_codec), - PhysicalPlanType::Extension(extension) => { - self.try_into_extension_physical_plan(extension, ctx, extension_codec) + PhysicalPlanType::PlaceholderRow(placeholder) => { + self.try_into_placeholder_row_physical_plan(placeholder, ctx, codec) } - PhysicalPlanType::NestedLoopJoin(join) => { - self.try_into_nested_loop_join_physical_plan(join, ctx, extension_codec) + PhysicalPlanType::Sort(sort) => { + self.try_into_sort_physical_plan(sort, ctx, codec, proto_converter) } + PhysicalPlanType::SortPreservingMerge(sort) => self + .try_into_sort_preserving_merge_physical_plan( + sort, + ctx, + codec, + proto_converter, + ), + PhysicalPlanType::Extension(extension) => self + .try_into_extension_physical_plan(extension, ctx, codec, proto_converter), + PhysicalPlanType::NestedLoopJoin(join) => self + .try_into_nested_loop_join_physical_plan( + join, + ctx, + codec, + proto_converter, + ), PhysicalPlanType::Analyze(analyze) => { - self.try_into_analyze_physical_plan(analyze, ctx, extension_codec) + self.try_into_analyze_physical_plan(analyze, ctx, codec, proto_converter) } PhysicalPlanType::JsonSink(sink) => { - self.try_into_json_sink_physical_plan(sink, ctx, extension_codec) + self.try_into_json_sink_physical_plan(sink, ctx, codec, proto_converter) } PhysicalPlanType::CsvSink(sink) => { - self.try_into_csv_sink_physical_plan(sink, ctx, extension_codec) + self.try_into_csv_sink_physical_plan(sink, ctx, codec, proto_converter) } #[cfg_attr(not(feature = "parquet"), allow(unused_variables))] - PhysicalPlanType::ParquetSink(sink) => { - self.try_into_parquet_sink_physical_plan(sink, ctx, extension_codec) - } + PhysicalPlanType::ParquetSink(sink) => self + .try_into_parquet_sink_physical_plan(sink, ctx, codec, proto_converter), PhysicalPlanType::Unnest(unnest) => { - self.try_into_unnest_physical_plan(unnest, ctx, extension_codec) 
- } - PhysicalPlanType::Cooperative(cooperative) => { - self.try_into_cooperative_physical_plan(cooperative, ctx, extension_codec) + self.try_into_unnest_physical_plan(unnest, ctx, codec, proto_converter) } + PhysicalPlanType::Cooperative(cooperative) => self + .try_into_cooperative_physical_plan( + cooperative, + ctx, + codec, + proto_converter, + ), PhysicalPlanType::GenerateSeries(generate_series) => { self.try_into_generate_series_physical_plan(generate_series) } PhysicalPlanType::SortMergeJoin(sort_join) => { - self.try_into_sort_join(sort_join, ctx, extension_codec) + self.try_into_sort_join(sort_join, ctx, codec, proto_converter) } - PhysicalPlanType::AsyncFunc(async_func) => { - self.try_into_async_func_physical_plan(async_func, ctx, extension_codec) + PhysicalPlanType::AsyncFunc(async_func) => self + .try_into_async_func_physical_plan( + async_func, + ctx, + codec, + proto_converter, + ), + PhysicalPlanType::Buffer(buffer) => { + self.try_into_buffer_physical_plan(buffer, ctx, codec, proto_converter) } } } - fn try_from_physical_plan( + pub fn try_from_physical_plan_with_converter( plan: Arc, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result where Self: Sized, @@ -268,107 +335,113 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { let plan = plan.as_any(); if let Some(exec) = plan.downcast_ref::() { - return protobuf::PhysicalPlanNode::try_from_explain_exec( - exec, - extension_codec, - ); + return protobuf::PhysicalPlanNode::try_from_explain_exec(exec, codec); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_projection_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_analyze_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return 
protobuf::PhysicalPlanNode::try_from_filter_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(limit) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_global_limit_exec( limit, - extension_codec, + codec, + proto_converter, ); } if let Some(limit) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_local_limit_exec( limit, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_hash_join_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_symmetric_hash_join_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_sort_merge_join_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_cross_join_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_aggregate_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(empty) = plan.downcast_ref::() { - return protobuf::PhysicalPlanNode::try_from_empty_exec( - empty, - extension_codec, - ); + return protobuf::PhysicalPlanNode::try_from_empty_exec(empty, codec); } if let Some(empty) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_placeholder_row_exec( - empty, - extension_codec, + empty, codec, ); } + #[expect(deprecated)] if let Some(coalesce_batches) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_coalesce_batches_exec( coalesce_batches, - extension_codec, + codec, + proto_converter, ); } if let Some(data_source_exec) = plan.downcast_ref::() && let Some(node) = protobuf::PhysicalPlanNode::try_from_data_source_exec( 
data_source_exec, - extension_codec, + codec, + proto_converter, )? { return Ok(node); @@ -377,67 +450,80 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_coalesce_partitions_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_repartition_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { - return protobuf::PhysicalPlanNode::try_from_sort_exec(exec, extension_codec); + return protobuf::PhysicalPlanNode::try_from_sort_exec( + exec, + codec, + proto_converter, + ); } if let Some(union) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_union_exec( union, - extension_codec, + codec, + proto_converter, ); } if let Some(interleave) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_interleave_exec( interleave, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_sort_preserving_merge_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_nested_loop_join_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_window_agg_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_bounded_window_agg_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() && let Some(node) = protobuf::PhysicalPlanNode::try_from_data_sink_exec( exec, - extension_codec, + codec, + proto_converter, )? 
{ return Ok(node); @@ -446,14 +532,16 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_unnest_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_cooperative_exec( exec, - extension_codec, + codec, + proto_converter, ); } @@ -467,21 +555,31 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_async_func_exec( exec, - extension_codec, + codec, + proto_converter, + ); + } + + if let Some(exec) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_buffer_exec( + exec, + codec, + proto_converter, ); } let mut buf: Vec = vec![]; - match extension_codec.try_encode(Arc::clone(&plan_clone), &mut buf) { + match codec.try_encode(Arc::clone(&plan_clone), &mut buf) { Ok(_) => { let inputs: Vec = plan_clone .children() .into_iter() .cloned() .map(|i| { - protobuf::PhysicalPlanNode::try_from_physical_plan( + protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( i, - extension_codec, + codec, + proto_converter, ) }) .collect::>()?; @@ -505,7 +603,8 @@ impl protobuf::PhysicalPlanNode { explain: &protobuf::ExplainExecNode, _ctx: &TaskContext, - _extension_codec: &dyn PhysicalExtensionCodec, + _codec: &dyn PhysicalExtensionCodec, + _proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { Ok(Arc::new(ExplainExec::new( Arc::new(explain.schema.as_ref().unwrap().try_into()?), @@ -523,21 +622,22 @@ impl protobuf::PhysicalPlanNode { projection: &protobuf::ProjectionExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&projection.input, ctx, extension_codec)?; + 
into_physical_plan(&projection.input, ctx, codec, proto_converter)?; let exprs = projection .expr .iter() .zip(projection.expr_name.iter()) .map(|(expr, name)| { Ok(( - parse_physical_expr( + proto_converter.proto_to_physical_expr( expr, ctx, input.schema().as_ref(), - extension_codec, + codec, )?, name.to_string(), )) @@ -555,16 +655,22 @@ impl protobuf::PhysicalPlanNode { filter: &protobuf::FilterExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&filter.input, ctx, extension_codec)?; + into_physical_plan(&filter.input, ctx, codec, proto_converter)?; let predicate = filter .expr .as_ref() .map(|expr| { - parse_physical_expr(expr, ctx, input.schema().as_ref(), extension_codec) + proto_converter.proto_to_physical_expr( + expr, + ctx, + input.schema().as_ref(), + codec, + ) }) .transpose()? .ok_or_else(|| { @@ -586,8 +692,11 @@ impl protobuf::PhysicalPlanNode { None }; - let filter = - FilterExec::try_new(predicate, input)?.with_projection(projection)?; + let filter = FilterExecBuilder::new(predicate, input) + .apply_projection(projection)? 
+ .with_batch_size(filter.batch_size as usize) + .with_fetch(filter.fetch.map(|f| f as usize)) + .build()?; match filter_selectivity { Ok(filter_selectivity) => Ok(Arc::new( filter.with_default_selectivity(filter_selectivity)?, @@ -603,7 +712,8 @@ impl protobuf::PhysicalPlanNode { scan: &protobuf::CsvScanExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let escape = if let Some(protobuf::csv_scan_exec_node::OptionalEscape::Escape(escape)) = @@ -644,7 +754,8 @@ impl protobuf::PhysicalPlanNode { let conf = FileScanConfigBuilder::from(parse_protobuf_file_scan_config( scan.base_conf.as_ref().unwrap(), ctx, - extension_codec, + codec, + proto_converter, source, )?) .with_file_compression_type(FileCompressionType::UNCOMPRESSED) @@ -657,25 +768,49 @@ impl protobuf::PhysicalPlanNode { scan: &protobuf::JsonScanExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let base_conf = scan.base_conf.as_ref().unwrap(); let table_schema = parse_table_schema_from_proto(base_conf)?; let scan_conf = parse_protobuf_file_scan_config( base_conf, ctx, - extension_codec, + codec, + proto_converter, Arc::new(JsonSource::new(table_schema)), )?; Ok(DataSourceExec::from_data_source(scan_conf)) } + fn try_into_arrow_scan_physical_plan( + &self, + scan: &protobuf::ArrowScanExecNode, + ctx: &TaskContext, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result> { + let base_conf = scan.base_conf.as_ref().ok_or_else(|| { + internal_datafusion_err!("base_conf in ArrowScanExecNode is missing.") + })?; + let table_schema = parse_table_schema_from_proto(base_conf)?; + let scan_conf = parse_protobuf_file_scan_config( + base_conf, + ctx, + codec, + proto_converter, + 
Arc::new(ArrowSource::new_file_source(table_schema)), + )?; + Ok(DataSourceExec::from_data_source(scan_conf)) + } + #[cfg_attr(not(feature = "parquet"), expect(unused_variables))] fn try_into_parquet_scan_physical_plan( &self, scan: &protobuf::ParquetScanExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { #[cfg(feature = "parquet")] { @@ -692,7 +827,7 @@ impl protobuf::PhysicalPlanNode { .iter() .map(|&i| schema.field(i as usize).clone()) .collect(); - Arc::new(arrow::datatypes::Schema::new(projected_fields)) + Arc::new(Schema::new(projected_fields)) } else { schema }; @@ -701,11 +836,11 @@ impl protobuf::PhysicalPlanNode { .predicate .as_ref() .map(|expr| { - parse_physical_expr( + proto_converter.proto_to_physical_expr( expr, ctx, predicate_schema.as_ref(), - extension_codec, + codec, ) }) .transpose()?; @@ -717,9 +852,19 @@ impl protobuf::PhysicalPlanNode { // Parse table schema with partition columns let table_schema = parse_table_schema_from_proto(base_conf)?; + let object_store_url = match base_conf.object_store_url.is_empty() { + false => ObjectStoreUrl::parse(&base_conf.object_store_url)?, + true => ObjectStoreUrl::local_filesystem(), + }; + let store = ctx.runtime_env().object_store(object_store_url)?; + let metadata_cache = + ctx.runtime_env().cache_manager.get_file_metadata_cache(); + let reader_factory = + Arc::new(CachedParquetFileReaderFactory::new(store, metadata_cache)); - let mut source = - ParquetSource::new(table_schema).with_table_parquet_options(options); + let mut source = ParquetSource::new(table_schema) + .with_parquet_file_reader_factory(reader_factory) + .with_table_parquet_options(options); if let Some(predicate) = predicate { source = source.with_predicate(predicate); @@ -727,7 +872,8 @@ impl protobuf::PhysicalPlanNode { let base_config = parse_protobuf_file_scan_config( base_conf, ctx, - extension_codec, 
+ codec, + proto_converter, Arc::new(source), )?; Ok(DataSourceExec::from_data_source(base_config)) @@ -743,7 +889,8 @@ impl protobuf::PhysicalPlanNode { &self, scan: &protobuf::AvroScanExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { #[cfg(feature = "avro")] { @@ -752,7 +899,8 @@ impl protobuf::PhysicalPlanNode { let conf = parse_protobuf_file_scan_config( scan.base_conf.as_ref().unwrap(), ctx, - extension_codec, + codec, + proto_converter, Arc::new(AvroSource::new(table_schema)), )?; Ok(DataSourceExec::from_data_source(conf)) @@ -767,7 +915,8 @@ impl protobuf::PhysicalPlanNode { scan: &protobuf::MemoryScanExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let partitions = scan .partitions @@ -797,7 +946,8 @@ impl protobuf::PhysicalPlanNode { &ordering.physical_sort_expr_nodes, ctx, &schema, - extension_codec, + codec, + proto_converter, )?; sort_information.extend(LexOrdering::new(sort_exprs)); } @@ -816,11 +966,13 @@ impl protobuf::PhysicalPlanNode { coalesce_batches: &protobuf::CoalesceBatchesExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&coalesce_batches.input, ctx, extension_codec)?; + into_physical_plan(&coalesce_batches.input, ctx, codec, proto_converter)?; Ok(Arc::new( + #[expect(deprecated)] CoalesceBatchesExec::new(input, coalesce_batches.target_batch_size as usize) .with_fetch(coalesce_batches.fetch.map(|f| f as usize)), )) @@ -831,10 +983,11 @@ impl protobuf::PhysicalPlanNode { merge: &protobuf::CoalescePartitionsExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn 
PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&merge.input, ctx, extension_codec)?; + into_physical_plan(&merge.input, ctx, codec, proto_converter)?; Ok(Arc::new( CoalescePartitionsExec::new(input) .with_fetch(merge.fetch.map(|f| f as usize)), @@ -846,15 +999,17 @@ impl protobuf::PhysicalPlanNode { repart: &protobuf::RepartitionExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&repart.input, ctx, extension_codec)?; + into_physical_plan(&repart.input, ctx, codec, proto_converter)?; let partitioning = parse_protobuf_partitioning( repart.partitioning.as_ref(), ctx, input.schema().as_ref(), - extension_codec, + codec, + proto_converter, )?; Ok(Arc::new(RepartitionExec::try_new( input, @@ -867,10 +1022,11 @@ impl protobuf::PhysicalPlanNode { limit: &protobuf::GlobalLimitExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&limit.input, ctx, extension_codec)?; + into_physical_plan(&limit.input, ctx, codec, proto_converter)?; let fetch = if limit.fetch >= 0 { Some(limit.fetch as usize) } else { @@ -888,10 +1044,11 @@ impl protobuf::PhysicalPlanNode { limit: &protobuf::LocalLimitExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&limit.input, ctx, extension_codec)?; + into_physical_plan(&limit.input, ctx, codec, proto_converter)?; Ok(Arc::new(LocalLimitExec::new(input, limit.fetch as usize))) } @@ -900,10 +1057,11 @@ impl protobuf::PhysicalPlanNode { window_agg: 
&protobuf::WindowAggExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&window_agg.input, ctx, extension_codec)?; + into_physical_plan(&window_agg.input, ctx, codec, proto_converter)?; let input_schema = input.schema(); let physical_window_expr: Vec> = window_agg @@ -914,7 +1072,8 @@ impl protobuf::PhysicalPlanNode { window_expr, ctx, input_schema.as_ref(), - extension_codec, + codec, + proto_converter, ) }) .collect::, _>>()?; @@ -923,7 +1082,12 @@ impl protobuf::PhysicalPlanNode { .partition_keys .iter() .map(|expr| { - parse_physical_expr(expr, ctx, input.schema().as_ref(), extension_codec) + proto_converter.proto_to_physical_expr( + expr, + ctx, + input.schema().as_ref(), + codec, + ) }) .collect::>>>()?; @@ -958,10 +1122,11 @@ impl protobuf::PhysicalPlanNode { hash_agg: &protobuf::AggregateExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&hash_agg.input, ctx, extension_codec)?; + into_physical_plan(&hash_agg.input, ctx, codec, proto_converter)?; let mode = protobuf::AggregateMode::try_from(hash_agg.mode).map_err(|_| { proto_error(format!( "Received a AggregateNode message with unknown AggregateMode {}", @@ -976,6 +1141,7 @@ impl protobuf::PhysicalPlanNode { protobuf::AggregateMode::SinglePartitioned => { AggregateMode::SinglePartitioned } + protobuf::AggregateMode::PartialReduce => AggregateMode::PartialReduce, }; let num_expr = hash_agg.group_expr.len(); @@ -985,7 +1151,8 @@ impl protobuf::PhysicalPlanNode { .iter() .zip(hash_agg.group_expr_name.iter()) .map(|(expr, name)| { - parse_physical_expr(expr, ctx, input.schema().as_ref(), extension_codec) + proto_converter + .proto_to_physical_expr(expr, ctx, 
input.schema().as_ref(), codec) .map(|expr| (expr, name.to_string())) }) .collect::, _>>()?; @@ -995,7 +1162,8 @@ impl protobuf::PhysicalPlanNode { .iter() .zip(hash_agg.group_expr_name.iter()) .map(|(expr, name)| { - parse_physical_expr(expr, ctx, input.schema().as_ref(), extension_codec) + proto_converter + .proto_to_physical_expr(expr, ctx, input.schema().as_ref(), codec) .map(|expr| (expr, name.to_string())) }) .collect::, _>>()?; @@ -1024,7 +1192,12 @@ impl protobuf::PhysicalPlanNode { expr.expr .as_ref() .map(|e| { - parse_physical_expr(e, ctx, &physical_schema, extension_codec) + proto_converter.proto_to_physical_expr( + e, + ctx, + &physical_schema, + codec, + ) }) .transpose() }) @@ -1045,11 +1218,11 @@ impl protobuf::PhysicalPlanNode { .expr .iter() .map(|e| { - parse_physical_expr( + proto_converter.proto_to_physical_expr( e, ctx, &physical_schema, - extension_codec, + codec, ) }) .collect::>>()?; @@ -1061,7 +1234,8 @@ impl protobuf::PhysicalPlanNode { e, ctx, &physical_schema, - extension_codec, + codec, + proto_converter, ) }) .collect::>()?; @@ -1071,11 +1245,11 @@ impl protobuf::PhysicalPlanNode { .map(|func| match func { AggregateFunction::UserDefinedAggrFunction(udaf_name) => { let agg_udf = match &agg_node.fun_definition { - Some(buf) => extension_codec - .try_decode_udaf(udaf_name, buf)?, + Some(buf) => { + codec.try_decode_udaf(udaf_name, buf)? 
+ } None => ctx.udaf(udaf_name).or_else(|_| { - extension_codec - .try_decode_udaf(udaf_name, &[]) + codec.try_decode_udaf(udaf_name, &[]) })?, }; @@ -1102,11 +1276,6 @@ impl protobuf::PhysicalPlanNode { }) .collect::, _>>()?; - let limit = hash_agg - .limit - .as_ref() - .map(|lit_value| lit_value.limit as usize); - let agg = AggregateExec::try_new( agg_mode, PhysicalGroupBy::new(group_expr, null_expr, groups, has_grouping_set), @@ -1116,7 +1285,16 @@ impl protobuf::PhysicalPlanNode { physical_schema, )?; - let agg = agg.with_limit(limit); + let agg = if let Some(limit_proto) = &hash_agg.limit { + let limit = limit_proto.limit as usize; + let limit_options = match limit_proto.descending { + Some(descending) => LimitOptions::new_with_order(limit, descending), + None => LimitOptions::new(limit), + }; + agg.with_limit_options(Some(limit_options)) + } else { + agg + }; Ok(Arc::new(agg)) } @@ -1126,29 +1304,30 @@ impl protobuf::PhysicalPlanNode { hashjoin: &protobuf::HashJoinExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let left: Arc = - into_physical_plan(&hashjoin.left, ctx, extension_codec)?; + into_physical_plan(&hashjoin.left, ctx, codec, proto_converter)?; let right: Arc = - into_physical_plan(&hashjoin.right, ctx, extension_codec)?; + into_physical_plan(&hashjoin.right, ctx, codec, proto_converter)?; let left_schema = left.schema(); let right_schema = right.schema(); let on: Vec<(PhysicalExprRef, PhysicalExprRef)> = hashjoin .on .iter() .map(|col| { - let left = parse_physical_expr( + let left = proto_converter.proto_to_physical_expr( &col.left.clone().unwrap(), ctx, left_schema.as_ref(), - extension_codec, + codec, )?; - let right = parse_physical_expr( + let right = proto_converter.proto_to_physical_expr( &col.right.clone().unwrap(), ctx, right_schema.as_ref(), - extension_codec, + codec, )?; Ok((left, right)) }) @@ 
-1177,12 +1356,12 @@ impl protobuf::PhysicalPlanNode { .ok_or_else(|| proto_error("Missing JoinFilter schema"))? .try_into()?; - let expression = parse_physical_expr( + let expression = proto_converter.proto_to_physical_expr( f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, ctx, &schema, - extension_codec, + codec, )?; let column_indices = f.column_indices .iter() @@ -1236,6 +1415,7 @@ impl protobuf::PhysicalPlanNode { projection, partition_mode, null_equality.into(), + hashjoin.null_aware, )?)) } @@ -1244,27 +1424,28 @@ impl protobuf::PhysicalPlanNode { sym_join: &protobuf::SymmetricHashJoinExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let left = into_physical_plan(&sym_join.left, ctx, extension_codec)?; - let right = into_physical_plan(&sym_join.right, ctx, extension_codec)?; + let left = into_physical_plan(&sym_join.left, ctx, codec, proto_converter)?; + let right = into_physical_plan(&sym_join.right, ctx, codec, proto_converter)?; let left_schema = left.schema(); let right_schema = right.schema(); let on = sym_join .on .iter() .map(|col| { - let left = parse_physical_expr( + let left = proto_converter.proto_to_physical_expr( &col.left.clone().unwrap(), ctx, left_schema.as_ref(), - extension_codec, + codec, )?; - let right = parse_physical_expr( + let right = proto_converter.proto_to_physical_expr( &col.right.clone().unwrap(), ctx, right_schema.as_ref(), - extension_codec, + codec, )?; Ok((left, right)) }) @@ -1293,12 +1474,12 @@ impl protobuf::PhysicalPlanNode { .ok_or_else(|| proto_error("Missing JoinFilter schema"))? 
.try_into()?; - let expression = parse_physical_expr( + let expression = proto_converter.proto_to_physical_expr( f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, ctx, &schema, - extension_codec, + codec, )?; let column_indices = f.column_indices .iter() @@ -1324,7 +1505,8 @@ impl protobuf::PhysicalPlanNode { &sym_join.left_sort_exprs, ctx, &left_schema, - extension_codec, + codec, + proto_converter, )?; let left_sort_exprs = LexOrdering::new(left_sort_exprs); @@ -1332,7 +1514,8 @@ impl protobuf::PhysicalPlanNode { &sym_join.right_sort_exprs, ctx, &right_schema, - extension_codec, + codec, + proto_converter, )?; let right_sort_exprs = LexOrdering::new(right_sort_exprs); @@ -1372,11 +1555,12 @@ impl protobuf::PhysicalPlanNode { union: &protobuf::UnionExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let mut inputs: Vec> = vec![]; for input in &union.inputs { - inputs.push(input.try_into_physical_plan(ctx, extension_codec)?); + inputs.push(proto_converter.proto_to_execution_plan(ctx, codec, input)?); } UnionExec::try_new(inputs) } @@ -1386,11 +1570,12 @@ impl protobuf::PhysicalPlanNode { interleave: &protobuf::InterleaveExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let mut inputs: Vec> = vec![]; for input in &interleave.inputs { - inputs.push(input.try_into_physical_plan(ctx, extension_codec)?); + inputs.push(proto_converter.proto_to_execution_plan(ctx, codec, input)?); } Ok(Arc::new(InterleaveExec::try_new(inputs)?)) } @@ -1400,12 +1585,13 @@ impl protobuf::PhysicalPlanNode { crossjoin: &protobuf::CrossJoinExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn 
PhysicalProtoConverterExtension, ) -> Result> { let left: Arc = - into_physical_plan(&crossjoin.left, ctx, extension_codec)?; + into_physical_plan(&crossjoin.left, ctx, codec, proto_converter)?; let right: Arc = - into_physical_plan(&crossjoin.right, ctx, extension_codec)?; + into_physical_plan(&crossjoin.right, ctx, codec, proto_converter)?; Ok(Arc::new(CrossJoinExec::new(left, right))) } @@ -1414,7 +1600,8 @@ impl protobuf::PhysicalPlanNode { empty: &protobuf::EmptyExecNode, _ctx: &TaskContext, - _extension_codec: &dyn PhysicalExtensionCodec, + _codec: &dyn PhysicalExtensionCodec, + _proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let schema = Arc::new(convert_required!(empty.schema)?); Ok(Arc::new(EmptyExec::new(schema))) @@ -1425,7 +1612,7 @@ impl protobuf::PhysicalPlanNode { placeholder: &protobuf::PlaceholderRowExecNode, _ctx: &TaskContext, - _extension_codec: &dyn PhysicalExtensionCodec, + _codec: &dyn PhysicalExtensionCodec, ) -> Result> { let schema = Arc::new(convert_required!(placeholder.schema)?); Ok(Arc::new(PlaceholderRowExec::new(schema))) @@ -1436,9 +1623,10 @@ impl protobuf::PhysicalPlanNode { sort: &protobuf::SortExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let input = into_physical_plan(&sort.input, ctx, extension_codec)?; + let input = into_physical_plan(&sort.input, ctx, codec, proto_converter)?; let exprs = sort .expr .iter() @@ -1459,7 +1647,7 @@ impl protobuf::PhysicalPlanNode { })? 
.as_ref(); Ok(PhysicalSortExpr { - expr: parse_physical_expr(expr, ctx, input.schema().as_ref(), extension_codec)?, + expr: proto_converter.proto_to_physical_expr(expr, ctx, input.schema().as_ref(), codec)?, options: SortOptions { descending: !sort_expr.asc, nulls_first: sort_expr.nulls_first, @@ -1488,9 +1676,10 @@ impl protobuf::PhysicalPlanNode { sort: &protobuf::SortPreservingMergeExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let input = into_physical_plan(&sort.input, ctx, extension_codec)?; + let input = into_physical_plan(&sort.input, ctx, codec, proto_converter)?; let exprs = sort .expr .iter() @@ -1511,11 +1700,11 @@ impl protobuf::PhysicalPlanNode { })? .as_ref(); Ok(PhysicalSortExpr { - expr: parse_physical_expr( + expr: proto_converter.proto_to_physical_expr( expr, ctx, input.schema().as_ref(), - extension_codec, + codec, )?, options: SortOptions { descending: !sort_expr.asc, @@ -1541,16 +1730,16 @@ impl protobuf::PhysicalPlanNode { extension: &protobuf::PhysicalExtensionNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let inputs: Vec> = extension .inputs .iter() - .map(|i| i.try_into_physical_plan(ctx, extension_codec)) + .map(|i| proto_converter.proto_to_execution_plan(ctx, codec, i)) .collect::>()?; - let extension_node = - extension_codec.try_decode(extension.node.as_slice(), &inputs, ctx)?; + let extension_node = codec.try_decode(extension.node.as_slice(), &inputs, ctx)?; Ok(extension_node) } @@ -1560,12 +1749,13 @@ impl protobuf::PhysicalPlanNode { join: &protobuf::NestedLoopJoinExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let left: Arc 
= - into_physical_plan(&join.left, ctx, extension_codec)?; + into_physical_plan(&join.left, ctx, codec, proto_converter)?; let right: Arc = - into_physical_plan(&join.right, ctx, extension_codec)?; + into_physical_plan(&join.right, ctx, codec, proto_converter)?; let join_type = protobuf::JoinType::try_from(join.join_type).map_err(|_| { proto_error(format!( "Received a NestedLoopJoinExecNode message with unknown JoinType {}", @@ -1582,12 +1772,12 @@ impl protobuf::PhysicalPlanNode { .ok_or_else(|| proto_error("Missing JoinFilter schema"))? .try_into()?; - let expression = parse_physical_expr( + let expression = proto_converter.proto_to_physical_expr( f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, ctx, &schema, - extension_codec, + codec, )?; let column_indices = f.column_indices .iter() @@ -1634,10 +1824,11 @@ impl protobuf::PhysicalPlanNode { analyze: &protobuf::AnalyzeExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&analyze.input, ctx, extension_codec)?; + into_physical_plan(&analyze.input, ctx, codec, proto_converter)?; Ok(Arc::new(AnalyzeExec::new( analyze.verbose, analyze.show_statistics, @@ -1652,9 +1843,10 @@ impl protobuf::PhysicalPlanNode { sink: &protobuf::JsonSinkExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let input = into_physical_plan(&sink.input, ctx, extension_codec)?; + let input = into_physical_plan(&sink.input, ctx, codec, proto_converter)?; let data_sink: JsonSink = sink .sink @@ -1670,7 +1862,8 @@ impl protobuf::PhysicalPlanNode { &collection.physical_sort_expr_nodes, ctx, &sink_schema, - extension_codec, + codec, + proto_converter, ) .map(|sort_exprs| { 
LexRequirement::new(sort_exprs.into_iter().map(Into::into)) @@ -1690,9 +1883,10 @@ impl protobuf::PhysicalPlanNode { sink: &protobuf::CsvSinkExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let input = into_physical_plan(&sink.input, ctx, extension_codec)?; + let input = into_physical_plan(&sink.input, ctx, codec, proto_converter)?; let data_sink: CsvSink = sink .sink @@ -1708,7 +1902,8 @@ impl protobuf::PhysicalPlanNode { &collection.physical_sort_expr_nodes, ctx, &sink_schema, - extension_codec, + codec, + proto_converter, ) .map(|sort_exprs| { LexRequirement::new(sort_exprs.into_iter().map(Into::into)) @@ -1729,11 +1924,12 @@ impl protobuf::PhysicalPlanNode { sink: &protobuf::ParquetSinkExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { #[cfg(feature = "parquet")] { - let input = into_physical_plan(&sink.input, ctx, extension_codec)?; + let input = into_physical_plan(&sink.input, ctx, codec, proto_converter)?; let data_sink: ParquetSink = sink .sink @@ -1749,7 +1945,8 @@ impl protobuf::PhysicalPlanNode { &collection.physical_sort_expr_nodes, ctx, &sink_schema, - extension_codec, + codec, + proto_converter, ) .map(|sort_exprs| { LexRequirement::new(sort_exprs.into_iter().map(Into::into)) @@ -1772,9 +1969,10 @@ impl protobuf::PhysicalPlanNode { unnest: &protobuf::UnnestExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let input = into_physical_plan(&unnest.input, ctx, extension_codec)?; + let input = into_physical_plan(&unnest.input, ctx, codec, proto_converter)?; Ok(Arc::new(UnnestExec::new( input, @@ -1803,11 +2001,12 @@ impl protobuf::PhysicalPlanNode { 
sort_join: &SortMergeJoinExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let left = into_physical_plan(&sort_join.left, ctx, extension_codec)?; + let left = into_physical_plan(&sort_join.left, ctx, codec, proto_converter)?; let left_schema = left.schema(); - let right = into_physical_plan(&sort_join.right, ctx, extension_codec)?; + let right = into_physical_plan(&sort_join.right, ctx, codec, proto_converter)?; let right_schema = right.schema(); let filter = sort_join @@ -1820,13 +2019,13 @@ impl protobuf::PhysicalPlanNode { .ok_or_else(|| proto_error("Missing JoinFilter schema"))? .try_into()?; - let expression = parse_physical_expr( + let expression = proto_converter.proto_to_physical_expr( f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, ctx, &schema, - extension_codec, + codec, )?; let column_indices = f .column_indices @@ -1883,17 +2082,17 @@ impl protobuf::PhysicalPlanNode { .on .iter() .map(|col| { - let left = parse_physical_expr( + let left = proto_converter.proto_to_physical_expr( &col.left.clone().unwrap(), ctx, left_schema.as_ref(), - extension_codec, + codec, )?; - let right = parse_physical_expr( + let right = proto_converter.proto_to_physical_expr( &col.right.clone().unwrap(), ctx, right_schema.as_ref(), - extension_codec, + codec, )?; Ok((left, right)) }) @@ -1980,9 +2179,10 @@ impl protobuf::PhysicalPlanNode { field_stream: &protobuf::CooperativeExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let input = into_physical_plan(&field_stream.input, ctx, extension_codec)?; + let input = into_physical_plan(&field_stream.input, ctx, codec, proto_converter)?; Ok(Arc::new(CooperativeExec::new(input))) } @@ -1990,10 +2190,11 @@ impl 
protobuf::PhysicalPlanNode { &self, async_func: &protobuf::AsyncFuncExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&async_func.input, ctx, extension_codec)?; + into_physical_plan(&async_func.input, ctx, codec, proto_converter)?; if async_func.async_exprs.len() != async_func.async_expr_names.len() { return internal_err!( @@ -2006,11 +2207,11 @@ impl protobuf::PhysicalPlanNode { .iter() .zip(async_func.async_expr_names.iter()) .map(|(expr, name)| { - let physical_expr = parse_physical_expr( + let physical_expr = proto_converter.proto_to_physical_expr( expr, ctx, input.schema().as_ref(), - extension_codec, + codec, )?; Ok(Arc::new(AsyncFuncExpr::try_new( @@ -2024,9 +2225,22 @@ impl protobuf::PhysicalPlanNode { Ok(Arc::new(AsyncFuncExec::try_new(async_exprs, input)?)) } + fn try_into_buffer_physical_plan( + &self, + buffer: &protobuf::BufferExecNode, + ctx: &TaskContext, + extension_codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result> { + let input: Arc = + into_physical_plan(&buffer.input, ctx, extension_codec, proto_converter)?; + + Ok(Arc::new(BufferExec::new(input, buffer.capacity as usize))) + } + fn try_from_explain_exec( exec: &ExplainExec, - _extension_codec: &dyn PhysicalExtensionCodec, + _codec: &dyn PhysicalExtensionCodec, ) -> Result { Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Explain( @@ -2045,16 +2259,20 @@ impl protobuf::PhysicalPlanNode { fn try_from_projection_exec( exec: &ProjectionExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( 
exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; let expr = exec .expr() .iter() - .map(|proj_expr| serialize_physical_expr(&proj_expr.expr, extension_codec)) + .map(|proj_expr| { + proto_converter.physical_expr_to_proto(&proj_expr.expr, codec) + }) .collect::>>()?; let expr_name = exec .expr() @@ -2074,11 +2292,13 @@ impl protobuf::PhysicalPlanNode { fn try_from_analyze_exec( exec: &AnalyzeExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Analyze(Box::new( @@ -2094,24 +2314,28 @@ impl protobuf::PhysicalPlanNode { fn try_from_filter_exec( exec: &FilterExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Filter(Box::new( protobuf::FilterExecNode { input: Some(Box::new(input)), - expr: Some(serialize_physical_expr( - exec.predicate(), - extension_codec, - )?), + expr: Some( + proto_converter + .physical_expr_to_proto(exec.predicate(), codec)?, + ), default_filter_selectivity: exec.default_selectivity() as u32, projection: exec.projection().as_ref().map_or_else(Vec::new, |v| { v.iter().map(|x| *x as u32).collect::>() }), + batch_size: exec.batch_size() as u32, + fetch: exec.fetch().map(|f| f as u32), }, ))), }) @@ -2119,11 +2343,13 @@ 
impl protobuf::PhysicalPlanNode { fn try_from_global_limit_exec( limit: &GlobalLimitExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( limit.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok(protobuf::PhysicalPlanNode { @@ -2142,11 +2368,13 @@ impl protobuf::PhysicalPlanNode { fn try_from_local_limit_exec( limit: &LocalLimitExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( limit.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::LocalLimit(Box::new( @@ -2160,22 +2388,25 @@ impl protobuf::PhysicalPlanNode { fn try_from_hash_join_exec( exec: &HashJoinExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let left = protobuf::PhysicalPlanNode::try_from_physical_plan( + let left = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.left().to_owned(), - extension_codec, + codec, + proto_converter, )?; - let right = protobuf::PhysicalPlanNode::try_from_physical_plan( + let right = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.right().to_owned(), - extension_codec, + codec, + proto_converter, )?; let on: Vec = exec .on() .iter() .map(|tuple| { - let l = serialize_physical_expr(&tuple.0, extension_codec)?; - let r = serialize_physical_expr(&tuple.1, extension_codec)?; + let l = 
proto_converter.physical_expr_to_proto(&tuple.0, codec)?; + let r = proto_converter.physical_expr_to_proto(&tuple.1, codec)?; Ok::<_, DataFusionError>(protobuf::JoinOn { left: Some(l), right: Some(r), @@ -2189,7 +2420,7 @@ impl protobuf::PhysicalPlanNode { .as_ref() .map(|f| { let expression = - serialize_physical_expr(f.expression(), extension_codec)?; + proto_converter.physical_expr_to_proto(f.expression(), codec)?; let column_indices = f .column_indices() .iter() @@ -2229,6 +2460,7 @@ impl protobuf::PhysicalPlanNode { projection: exec.projection.as_ref().map_or_else(Vec::new, |v| { v.iter().map(|x| *x as u32).collect::>() }), + null_aware: exec.null_aware, }, ))), }) @@ -2236,22 +2468,25 @@ impl protobuf::PhysicalPlanNode { fn try_from_symmetric_hash_join_exec( exec: &SymmetricHashJoinExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let left = protobuf::PhysicalPlanNode::try_from_physical_plan( + let left = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.left().to_owned(), - extension_codec, + codec, + proto_converter, )?; - let right = protobuf::PhysicalPlanNode::try_from_physical_plan( + let right = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.right().to_owned(), - extension_codec, + codec, + proto_converter, )?; let on = exec .on() .iter() .map(|tuple| { - let l = serialize_physical_expr(&tuple.0, extension_codec)?; - let r = serialize_physical_expr(&tuple.1, extension_codec)?; + let l = proto_converter.physical_expr_to_proto(&tuple.0, codec)?; + let r = proto_converter.physical_expr_to_proto(&tuple.1, codec)?; Ok::<_, DataFusionError>(protobuf::JoinOn { left: Some(l), right: Some(r), @@ -2265,7 +2500,7 @@ impl protobuf::PhysicalPlanNode { .as_ref() .map(|f| { let expression = - serialize_physical_expr(f.expression(), extension_codec)?; + proto_converter.physical_expr_to_proto(f.expression(), 
codec)?; let column_indices = f .column_indices() .iter() @@ -2302,10 +2537,10 @@ impl protobuf::PhysicalPlanNode { .iter() .map(|expr| { Ok(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(serialize_physical_expr( - &expr.expr, - extension_codec, - )?)), + expr: Some(Box::new( + proto_converter + .physical_expr_to_proto(&expr.expr, codec)?, + )), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }) @@ -2322,10 +2557,10 @@ impl protobuf::PhysicalPlanNode { .iter() .map(|expr| { Ok(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(serialize_physical_expr( - &expr.expr, - extension_codec, - )?)), + expr: Some(Box::new( + proto_converter + .physical_expr_to_proto(&expr.expr, codec)?, + )), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }) @@ -2354,22 +2589,25 @@ impl protobuf::PhysicalPlanNode { fn try_from_sort_merge_join_exec( exec: &SortMergeJoinExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let left = protobuf::PhysicalPlanNode::try_from_physical_plan( + let left = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.left().to_owned(), - extension_codec, + codec, + proto_converter, )?; - let right = protobuf::PhysicalPlanNode::try_from_physical_plan( + let right = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.right().to_owned(), - extension_codec, + codec, + proto_converter, )?; let on = exec .on() .iter() .map(|tuple| { - let l = serialize_physical_expr(&tuple.0, extension_codec)?; - let r = serialize_physical_expr(&tuple.1, extension_codec)?; + let l = proto_converter.physical_expr_to_proto(&tuple.0, codec)?; + let r = proto_converter.physical_expr_to_proto(&tuple.1, codec)?; Ok::<_, DataFusionError>(protobuf::JoinOn { left: Some(l), right: Some(r), @@ -2383,7 +2621,7 @@ impl protobuf::PhysicalPlanNode { .as_ref() .map(|f| { let expression = - 
serialize_physical_expr(f.expression(), extension_codec)?; + proto_converter.physical_expr_to_proto(f.expression(), codec)?; let column_indices = f .column_indices() .iter() @@ -2423,7 +2661,7 @@ impl protobuf::PhysicalPlanNode { Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::SortMergeJoin(Box::new( - protobuf::SortMergeJoinExecNode { + SortMergeJoinExecNode { left: Some(Box::new(left)), right: Some(Box::new(right)), on, @@ -2438,15 +2676,18 @@ impl protobuf::PhysicalPlanNode { fn try_from_cross_join_exec( exec: &CrossJoinExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let left = protobuf::PhysicalPlanNode::try_from_physical_plan( + let left = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.left().to_owned(), - extension_codec, + codec, + proto_converter, )?; - let right = protobuf::PhysicalPlanNode::try_from_physical_plan( + let right = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.right().to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::CrossJoin(Box::new( @@ -2460,7 +2701,8 @@ impl protobuf::PhysicalPlanNode { fn try_from_aggregate_exec( exec: &AggregateExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { let groups: Vec = exec .group_expr() @@ -2480,13 +2722,15 @@ impl protobuf::PhysicalPlanNode { let filter = exec .filter_expr() .iter() - .map(|expr| serialize_maybe_filter(expr.to_owned(), extension_codec)) + .map(|expr| serialize_maybe_filter(expr.to_owned(), codec, proto_converter)) .collect::>>()?; let agg = exec .aggr_expr() .iter() - .map(|expr| serialize_physical_aggr_expr(expr.to_owned(), extension_codec)) + .map(|expr| { + 
serialize_physical_aggr_expr(expr.to_owned(), codec, proto_converter) + }) .collect::>>()?; let agg_names = exec @@ -2503,29 +2747,32 @@ impl protobuf::PhysicalPlanNode { AggregateMode::SinglePartitioned => { protobuf::AggregateMode::SinglePartitioned } + AggregateMode::PartialReduce => protobuf::AggregateMode::PartialReduce, }; let input_schema = exec.input_schema(); - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; let null_expr = exec .group_expr() .null_expr() .iter() - .map(|expr| serialize_physical_expr(&expr.0, extension_codec)) + .map(|expr| proto_converter.physical_expr_to_proto(&expr.0, codec)) .collect::>>()?; let group_expr = exec .group_expr() .expr() .iter() - .map(|expr| serialize_physical_expr(&expr.0, extension_codec)) + .map(|expr| proto_converter.physical_expr_to_proto(&expr.0, codec)) .collect::>>()?; - let limit = exec.limit().map(|value| protobuf::AggLimit { - limit: value as u64, + let limit = exec.limit_options().map(|config| protobuf::AggLimit { + limit: config.limit() as u64, + descending: config.descending(), }); Ok(protobuf::PhysicalPlanNode { @@ -2550,7 +2797,7 @@ impl protobuf::PhysicalPlanNode { fn try_from_empty_exec( empty: &EmptyExec, - _extension_codec: &dyn PhysicalExtensionCodec, + _codec: &dyn PhysicalExtensionCodec, ) -> Result { let schema = empty.schema().as_ref().try_into()?; Ok(protobuf::PhysicalPlanNode { @@ -2562,7 +2809,7 @@ impl protobuf::PhysicalPlanNode { fn try_from_placeholder_row_exec( empty: &PlaceholderRowExec, - _extension_codec: &dyn PhysicalExtensionCodec, + _codec: &dyn PhysicalExtensionCodec, ) -> Result { let schema = empty.schema().as_ref().try_into()?; Ok(protobuf::PhysicalPlanNode { @@ -2574,13 +2821,16 @@ impl protobuf::PhysicalPlanNode { }) } + #[expect(deprecated)] fn try_from_coalesce_batches_exec( coalesce_batches: 
&CoalesceBatchesExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( coalesce_batches.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::CoalesceBatches(Box::new( @@ -2595,7 +2845,8 @@ impl protobuf::PhysicalPlanNode { fn try_from_data_source_exec( data_source_exec: &DataSourceExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let data_source = data_source_exec.data_source(); if let Some(maybe_csv) = data_source.as_any().downcast_ref::() { @@ -2606,7 +2857,8 @@ impl protobuf::PhysicalPlanNode { protobuf::CsvScanExecNode { base_conf: Some(serialize_file_scan_config( maybe_csv, - extension_codec, + codec, + proto_converter, )?), has_header: csv_config.has_header(), delimiter: byte_to_string( @@ -2647,7 +2899,25 @@ impl protobuf::PhysicalPlanNode { protobuf::JsonScanExecNode { base_conf: Some(serialize_file_scan_config( scan_conf, - extension_codec, + codec, + proto_converter, + )?), + }, + )), + })); + } + } + + if let Some(scan_conf) = data_source.as_any().downcast_ref::() { + let source = scan_conf.file_source(); + if let Some(_arrow_source) = source.as_any().downcast_ref::() { + return Ok(Some(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::ArrowScan( + protobuf::ArrowScanExecNode { + base_conf: Some(serialize_file_scan_config( + scan_conf, + codec, + proto_converter, )?), }, )), @@ -2661,14 +2931,15 @@ impl protobuf::PhysicalPlanNode { { let predicate = conf .filter() - .map(|pred| serialize_physical_expr(&pred, extension_codec)) + .map(|pred| 
proto_converter.physical_expr_to_proto(&pred, codec)) .transpose()?; return Ok(Some(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::ParquetScan( protobuf::ParquetScanExecNode { base_conf: Some(serialize_file_scan_config( maybe_parquet, - extension_codec, + codec, + proto_converter, )?), predicate, parquet_options: Some(conf.table_parquet_options().try_into()?), @@ -2686,7 +2957,8 @@ impl protobuf::PhysicalPlanNode { protobuf::AvroScanExecNode { base_conf: Some(serialize_file_scan_config( maybe_avro, - extension_codec, + codec, + proto_converter, )?), }, )), @@ -2719,7 +2991,8 @@ impl protobuf::PhysicalPlanNode { .map(|ordering| { let sort_exprs = serialize_physical_sort_exprs( ordering.to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok::<_, DataFusionError>(protobuf::PhysicalSortExprNodeCollection { physical_sort_expr_nodes: sort_exprs, @@ -2746,11 +3019,13 @@ impl protobuf::PhysicalPlanNode { fn try_from_coalesce_partitions_exec( exec: &CoalescePartitionsExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Merge(Box::new( @@ -2764,15 +3039,17 @@ impl protobuf::PhysicalPlanNode { fn try_from_repartition_exec( exec: &RepartitionExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; let pb_partitioning = 
- serialize_partitioning(exec.partitioning(), extension_codec)?; + serialize_partitioning(exec.partitioning(), codec, proto_converter)?; Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Repartition(Box::new( @@ -2786,25 +3063,23 @@ impl protobuf::PhysicalPlanNode { fn try_from_sort_exec( exec: &SortExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( - exec.input().to_owned(), - extension_codec, - )?; + let input = proto_converter.execution_plan_to_proto(exec.input(), codec)?; let expr = exec .expr() .iter() .map(|expr| { let sort_expr = Box::new(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(serialize_physical_expr( - &expr.expr, - extension_codec, - )?)), + expr: Some(Box::new( + proto_converter.physical_expr_to_proto(&expr.expr, codec)?, + )), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }); Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(ExprType::Sort(sort_expr)), }) }) @@ -2826,14 +3101,18 @@ impl protobuf::PhysicalPlanNode { fn try_from_union_exec( union: &UnionExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { let mut inputs: Vec = vec![]; for input in union.inputs() { - inputs.push(protobuf::PhysicalPlanNode::try_from_physical_plan( - input.to_owned(), - extension_codec, - )?); + inputs.push( + protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( + input.to_owned(), + codec, + proto_converter, + )?, + ); } Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Union(protobuf::UnionExecNode { @@ -2844,14 +3123,18 @@ impl protobuf::PhysicalPlanNode { fn try_from_interleave_exec( interleave: &InterleaveExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn 
PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { let mut inputs: Vec = vec![]; for input in interleave.inputs() { - inputs.push(protobuf::PhysicalPlanNode::try_from_physical_plan( - input.to_owned(), - extension_codec, - )?); + inputs.push( + protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( + input.to_owned(), + codec, + proto_converter, + )?, + ); } Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Interleave( @@ -2862,25 +3145,27 @@ impl protobuf::PhysicalPlanNode { fn try_from_sort_preserving_merge_exec( exec: &SortPreservingMergeExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; let expr = exec .expr() .iter() .map(|expr| { let sort_expr = Box::new(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(serialize_physical_expr( - &expr.expr, - extension_codec, - )?)), + expr: Some(Box::new( + proto_converter.physical_expr_to_proto(&expr.expr, codec)?, + )), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }); Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(ExprType::Sort(sort_expr)), }) }) @@ -2898,15 +3183,18 @@ impl protobuf::PhysicalPlanNode { fn try_from_nested_loop_join_exec( exec: &NestedLoopJoinExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let left = protobuf::PhysicalPlanNode::try_from_physical_plan( + let left = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.left().to_owned(), - extension_codec, + codec, + proto_converter, )?; - let right = 
protobuf::PhysicalPlanNode::try_from_physical_plan( + let right = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.right().to_owned(), - extension_codec, + codec, + proto_converter, )?; let join_type: protobuf::JoinType = exec.join_type().to_owned().into(); @@ -2915,7 +3203,7 @@ impl protobuf::PhysicalPlanNode { .as_ref() .map(|f| { let expression = - serialize_physical_expr(f.expression(), extension_codec)?; + proto_converter.physical_expr_to_proto(f.expression(), codec)?; let column_indices = f .column_indices() .iter() @@ -2943,7 +3231,7 @@ impl protobuf::PhysicalPlanNode { right: Some(Box::new(right)), join_type: join_type.into(), filter, - projection: exec.projection().map_or_else(Vec::new, |v| { + projection: exec.projection().as_ref().map_or_else(Vec::new, |v| { v.iter().map(|x| *x as u32).collect::>() }), }, @@ -2953,23 +3241,25 @@ impl protobuf::PhysicalPlanNode { fn try_from_window_agg_exec( exec: &WindowAggExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; let window_expr = exec .window_expr() .iter() - .map(|e| serialize_physical_window_expr(e, extension_codec)) + .map(|e| serialize_physical_window_expr(e, codec, proto_converter)) .collect::>>()?; let partition_keys = exec .partition_keys() .iter() - .map(|e| serialize_physical_expr(e, extension_codec)) + .map(|e| proto_converter.physical_expr_to_proto(e, codec)) .collect::>>()?; Ok(protobuf::PhysicalPlanNode { @@ -2986,23 +3276,25 @@ impl protobuf::PhysicalPlanNode { fn try_from_bounded_window_agg_exec( exec: &BoundedWindowAggExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn 
PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; let window_expr = exec .window_expr() .iter() - .map(|e| serialize_physical_window_expr(e, extension_codec)) + .map(|e| serialize_physical_window_expr(e, codec, proto_converter)) .collect::>>()?; let partition_keys = exec .partition_keys() .iter() - .map(|e| serialize_physical_expr(e, extension_codec)) + .map(|e| proto_converter.physical_expr_to_proto(e, codec)) .collect::>>()?; let input_order_mode = match &exec.input_order_mode { @@ -3035,12 +3327,14 @@ impl protobuf::PhysicalPlanNode { fn try_from_data_sink_exec( exec: &DataSinkExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: protobuf::PhysicalPlanNode = - protobuf::PhysicalPlanNode::try_from_physical_plan( + protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; let sort_order = match exec.sort_order() { Some(requirements) => { @@ -3049,10 +3343,10 @@ impl protobuf::PhysicalPlanNode { .map(|requirement| { let expr: PhysicalSortExpr = requirement.to_owned().into(); let sort_expr = protobuf::PhysicalSortExprNode { - expr: Some(Box::new(serialize_physical_expr( - &expr.expr, - extension_codec, - )?)), + expr: Some(Box::new( + proto_converter + .physical_expr_to_proto(&expr.expr, codec)?, + )), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }; @@ -3112,11 +3406,13 @@ impl protobuf::PhysicalPlanNode { fn try_from_unnest_exec( exec: &UnnestExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = 
protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok(protobuf::PhysicalPlanNode { @@ -3145,11 +3441,13 @@ impl protobuf::PhysicalPlanNode { fn try_from_cooperative_exec( exec: &CooperativeExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok(protobuf::PhysicalPlanNode { @@ -3278,18 +3576,21 @@ impl protobuf::PhysicalPlanNode { fn try_from_async_func_exec( exec: &AsyncFuncExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( Arc::clone(exec.input()), - extension_codec, + codec, + proto_converter, )?; let mut async_exprs = vec![]; let mut async_expr_names = vec![]; for async_expr in exec.async_exprs() { - async_exprs.push(serialize_physical_expr(&async_expr.func, extension_codec)?); + async_exprs + .push(proto_converter.physical_expr_to_proto(&async_expr.func, codec)?); async_expr_names.push(async_expr.name.clone()) } @@ -3303,6 +3604,27 @@ impl protobuf::PhysicalPlanNode { ))), }) } + + fn try_from_buffer_exec( + exec: &BufferExec, + extension_codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(exec.input()), + extension_codec, + proto_converter, + )?; + + 
Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Buffer(Box::new( + protobuf::BufferExecNode { + input: Some(Box::new(input)), + capacity: exec.capacity() as u64, + }, + ))), + }) + } } pub trait AsExecutionPlan: Debug + Send + Sync + Clone { @@ -3319,12 +3641,12 @@ pub trait AsExecutionPlan: Debug + Send + Sync + Clone { &self, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, ) -> Result>; fn try_from_physical_plan( plan: Arc, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, ) -> Result where Self: Sized; @@ -3405,6 +3727,38 @@ impl PhysicalExtensionCodec for DefaultPhysicalExtensionCodec { } } +/// Controls the conversion of physical plans and expressions to and from their +/// Protobuf variants. Using this trait, users can perform optimizations on the +/// conversion process or collect performance metrics. +pub trait PhysicalProtoConverterExtension { + fn proto_to_execution_plan( + &self, + ctx: &TaskContext, + codec: &dyn PhysicalExtensionCodec, + proto: &protobuf::PhysicalPlanNode, + ) -> Result>; + + fn execution_plan_to_proto( + &self, + plan: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result; + + fn proto_to_physical_expr( + &self, + proto: &protobuf::PhysicalExprNode, + ctx: &TaskContext, + input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, + ) -> Result>; + + fn physical_expr_to_proto( + &self, + expr: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result; +} + /// DataEncoderTuple captures the position of the encoder /// in the codec list that was used to encode the data and actual encoded data #[derive(Clone, PartialEq, prost::Message)] @@ -3418,6 +3772,266 @@ struct DataEncoderTuple { pub blob: Vec, } +pub struct DefaultPhysicalProtoConverter; +impl PhysicalProtoConverterExtension for DefaultPhysicalProtoConverter { + fn proto_to_execution_plan( + &self, + ctx: &TaskContext, + codec: &dyn PhysicalExtensionCodec, 
+ proto: &protobuf::PhysicalPlanNode, + ) -> Result> { + proto.try_into_physical_plan_with_converter(ctx, codec, self) + } + + fn execution_plan_to_proto( + &self, + plan: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result + where + Self: Sized, + { + protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(plan), + codec, + self, + ) + } + + fn proto_to_physical_expr( + &self, + proto: &protobuf::PhysicalExprNode, + ctx: &TaskContext, + input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, + ) -> Result> + where + Self: Sized, + { + // Default implementation calls the free function + parse_physical_expr_with_converter(proto, ctx, input_schema, codec, self) + } + + fn physical_expr_to_proto( + &self, + expr: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result { + serialize_physical_expr_with_converter(expr, codec, self) + } +} + +/// Internal serializer that adds expr_id to expressions. +/// Created fresh for each serialization operation. +struct DeduplicatingSerializer { + /// Random salt combined with pointer addresses and process ID to create globally unique expr_ids. 
+ session_id: u64, +} + +impl DeduplicatingSerializer { + fn new() -> Self { + Self { + session_id: rand::random(), + } + } +} + +impl PhysicalProtoConverterExtension for DeduplicatingSerializer { + fn proto_to_execution_plan( + &self, + _ctx: &TaskContext, + _codec: &dyn PhysicalExtensionCodec, + _proto: &protobuf::PhysicalPlanNode, + ) -> Result> { + internal_err!("DeduplicatingSerializer cannot deserialize execution plans") + } + + fn execution_plan_to_proto( + &self, + plan: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result + where + Self: Sized, + { + protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(plan), + codec, + self, + ) + } + + fn proto_to_physical_expr( + &self, + _proto: &protobuf::PhysicalExprNode, + _ctx: &TaskContext, + _input_schema: &Schema, + _codec: &dyn PhysicalExtensionCodec, + ) -> Result> + where + Self: Sized, + { + internal_err!("DeduplicatingSerializer cannot deserialize physical expressions") + } + + fn physical_expr_to_proto( + &self, + expr: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let mut proto = serialize_physical_expr_with_converter(expr, codec, self)?; + + // Hash session_id, pointer address, and process ID together to create expr_id. + // - session_id: random per serializer, prevents collisions when merging serializations + // - ptr: unique address per Arc within a process + // - pid: prevents collisions if serializer is shared across processes + let mut hasher = DefaultHasher::new(); + self.session_id.hash(&mut hasher); + (Arc::as_ptr(expr) as *const () as u64).hash(&mut hasher); + std::process::id().hash(&mut hasher); + proto.expr_id = Some(hasher.finish()); + + Ok(proto) + } +} + +/// Internal deserializer that caches expressions by expr_id. +/// Created fresh for each deserialization operation. +#[derive(Default)] +struct DeduplicatingDeserializer { + /// Cache mapping expr_id to deserialized expressions. 
+ cache: RefCell>>, +} + +impl PhysicalProtoConverterExtension for DeduplicatingDeserializer { + fn proto_to_execution_plan( + &self, + ctx: &TaskContext, + codec: &dyn PhysicalExtensionCodec, + proto: &protobuf::PhysicalPlanNode, + ) -> Result> { + proto.try_into_physical_plan_with_converter(ctx, codec, self) + } + + fn execution_plan_to_proto( + &self, + _plan: &Arc, + _codec: &dyn PhysicalExtensionCodec, + ) -> Result + where + Self: Sized, + { + internal_err!("DeduplicatingDeserializer cannot serialize execution plans") + } + + fn proto_to_physical_expr( + &self, + proto: &protobuf::PhysicalExprNode, + ctx: &TaskContext, + input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, + ) -> Result> + where + Self: Sized, + { + if let Some(expr_id) = proto.expr_id { + // Check cache first + if let Some(cached) = self.cache.borrow().get(&expr_id) { + return Ok(Arc::clone(cached)); + } + // Deserialize and cache + let expr = parse_physical_expr_with_converter( + proto, + ctx, + input_schema, + codec, + self, + )?; + self.cache.borrow_mut().insert(expr_id, Arc::clone(&expr)); + Ok(expr) + } else { + parse_physical_expr_with_converter(proto, ctx, input_schema, codec, self) + } + } + + fn physical_expr_to_proto( + &self, + _expr: &Arc, + _codec: &dyn PhysicalExtensionCodec, + ) -> Result { + internal_err!("DeduplicatingDeserializer cannot serialize physical expressions") + } +} + +/// A proto converter that adds expression deduplication during serialization +/// and deserialization. +/// +/// During serialization, each expression's Arc pointer address is XORed with a +/// random session_id to create a salted `expr_id`. This prevents cross-process +/// collisions when serialized plans are merged. +/// +/// During deserialization, expressions with the same `expr_id` share the same +/// Arc, reducing memory usage for plans with duplicate expressions (e.g., large +/// IN lists) and supporting correctly linking [`DynamicFilterPhysicalExpr`] instances. 
+/// +/// This converter is stateless - it creates internal serializers/deserializers +/// on demand for each operation. +/// +/// [`DynamicFilterPhysicalExpr`]: https://docs.rs/datafusion-physical-expr/latest/datafusion_physical_expr/expressions/struct.DynamicFilterPhysicalExpr.html +#[derive(Debug, Default, Clone, Copy)] +pub struct DeduplicatingProtoConverter {} + +impl PhysicalProtoConverterExtension for DeduplicatingProtoConverter { + fn proto_to_execution_plan( + &self, + ctx: &TaskContext, + codec: &dyn PhysicalExtensionCodec, + proto: &protobuf::PhysicalPlanNode, + ) -> Result> { + let deserializer = DeduplicatingDeserializer::default(); + proto.try_into_physical_plan_with_converter(ctx, codec, &deserializer) + } + + fn execution_plan_to_proto( + &self, + plan: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result + where + Self: Sized, + { + let serializer = DeduplicatingSerializer::new(); + protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(plan), + codec, + &serializer, + ) + } + + fn proto_to_physical_expr( + &self, + proto: &protobuf::PhysicalExprNode, + ctx: &TaskContext, + input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, + ) -> Result> + where + Self: Sized, + { + let deserializer = DeduplicatingDeserializer::default(); + deserializer.proto_to_physical_expr(proto, ctx, input_schema, codec) + } + + fn physical_expr_to_proto( + &self, + expr: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let serializer = DeduplicatingSerializer::new(); + serializer.physical_expr_to_proto(expr, codec) + } +} + /// A PhysicalExtensionCodec that tries one of multiple inner codecs /// until one works #[derive(Debug)] @@ -3520,10 +4134,11 @@ impl PhysicalExtensionCodec for ComposedPhysicalExtensionCodec { fn into_physical_plan( node: &Option>, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> 
Result> { if let Some(field) = node { - field.try_into_physical_plan(ctx, extension_codec) + proto_converter.proto_to_execution_plan(ctx, codec, field) } else { Err(proto_error("Missing required field in protobuf")) } diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 9558effb8a2a..de2f36e81e3b 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -24,8 +24,7 @@ use datafusion_common::{ DataFusionError, Result, internal_datafusion_err, internal_err, not_impl_err, }; use datafusion_datasource::file_scan_config::FileScanConfig; -use datafusion_datasource::file_sink_config::FileSink; -use datafusion_datasource::file_sink_config::FileSinkConfig; +use datafusion_datasource::file_sink_config::{FileSink, FileSinkConfig}; use datafusion_datasource::{FileRange, PartitionedFile}; use datafusion_datasource_csv::file_format::CsvSink; use datafusion_datasource_json::file_format::JsonSink; @@ -36,36 +35,43 @@ use datafusion_physical_expr::ScalarFunctionExpr; use datafusion_physical_expr::window::{SlidingAggregateWindowExpr, StandardWindowExpr}; use datafusion_physical_expr_common::physical_expr::snapshot_physical_expr; use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; -use datafusion_physical_plan::expressions::LikeExpr; use datafusion_physical_plan::expressions::{ BinaryExpr, CaseExpr, CastExpr, Column, InListExpr, IsNotNullExpr, IsNullExpr, - Literal, NegativeExpr, NotExpr, TryCastExpr, UnKnownColumn, + LikeExpr, Literal, NegativeExpr, NotExpr, TryCastExpr, UnKnownColumn, }; use datafusion_physical_plan::joins::{HashExpr, HashTableLookupExpr}; use datafusion_physical_plan::udaf::AggregateFunctionExpr; use datafusion_physical_plan::windows::{PlainAggregateWindowExpr, WindowUDFExpr}; use datafusion_physical_plan::{Partitioning, PhysicalExpr, WindowExpr}; +use super::{ + DefaultPhysicalProtoConverter, PhysicalExtensionCodec, + 
PhysicalProtoConverterExtension, +}; use crate::protobuf::{ self, PhysicalSortExprNode, PhysicalSortExprNodeCollection, physical_aggregate_expr_node, physical_window_expr_node, }; -use super::PhysicalExtensionCodec; - #[expect(clippy::needless_pass_by_value)] pub fn serialize_physical_aggr_expr( aggr_expr: Arc, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let expressions = serialize_physical_exprs(&aggr_expr.expressions(), codec)?; - let order_bys = - serialize_physical_sort_exprs(aggr_expr.order_bys().iter().cloned(), codec)?; + let expressions = + serialize_physical_exprs(&aggr_expr.expressions(), codec, proto_converter)?; + let order_bys = serialize_physical_sort_exprs( + aggr_expr.order_bys().iter().cloned(), + codec, + proto_converter, + )?; let name = aggr_expr.fun().name().to_string(); let mut buf = Vec::new(); codec.try_encode_udaf(aggr_expr.fun(), &mut buf)?; Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr( protobuf::PhysicalAggregateExprNode { aggregate_function: Some(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(name)), @@ -100,9 +106,10 @@ fn serialize_physical_window_aggr_expr( pub fn serialize_physical_window_expr( window_expr: &Arc, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { let expr = window_expr.as_any(); - let args = window_expr.expressions().to_vec(); + let mut args = window_expr.expressions().to_vec(); let window_frame = window_expr.get_window_frame(); let (window_function, fun_definition, ignore_nulls, distinct) = @@ -138,6 +145,7 @@ pub fn serialize_physical_window_expr( { let mut buf = Vec::new(); codec.try_encode_udwf(expr.fun(), &mut buf)?; + args = expr.args().to_vec(); ( physical_window_expr_node::WindowFunction::UserDefinedWindowFunction( expr.fun().name().to_string(), @@ -155,9 +163,14 @@ pub fn 
serialize_physical_window_expr( return not_impl_err!("WindowExpr not supported: {window_expr:?}"); }; - let args = serialize_physical_exprs(&args, codec)?; - let partition_by = serialize_physical_exprs(window_expr.partition_by(), codec)?; - let order_by = serialize_physical_sort_exprs(window_expr.order_by().to_vec(), codec)?; + let args = serialize_physical_exprs(&args, codec, proto_converter)?; + let partition_by = + serialize_physical_exprs(window_expr.partition_by(), codec, proto_converter)?; + let order_by = serialize_physical_sort_exprs( + window_expr.order_by().to_vec(), + codec, + proto_converter, + )?; let window_frame: protobuf::WindowFrame = window_frame .as_ref() .try_into() @@ -179,22 +192,24 @@ pub fn serialize_physical_window_expr( pub fn serialize_physical_sort_exprs( sort_exprs: I, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> where I: IntoIterator, { sort_exprs .into_iter() - .map(|sort_expr| serialize_physical_sort_expr(sort_expr, codec)) + .map(|sort_expr| serialize_physical_sort_expr(sort_expr, codec, proto_converter)) .collect() } pub fn serialize_physical_sort_expr( sort_expr: PhysicalSortExpr, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { let PhysicalSortExpr { expr, options } = sort_expr; - let expr = serialize_physical_expr(&expr, codec)?; + let expr = proto_converter.physical_expr_to_proto(&expr, codec)?; Ok(PhysicalSortExprNode { expr: Some(Box::new(expr)), asc: !options.descending, @@ -205,13 +220,14 @@ pub fn serialize_physical_sort_expr( pub fn serialize_physical_exprs<'a, I>( values: I, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> where I: IntoIterator>, { values .into_iter() - .map(|value| serialize_physical_expr(value, codec)) + .map(|value| proto_converter.physical_expr_to_proto(value, codec)) .collect() } @@ -222,6 +238,24 @@ where pub fn 
serialize_physical_expr( value: &Arc, codec: &dyn PhysicalExtensionCodec, +) -> Result { + serialize_physical_expr_with_converter( + value, + codec, + &DefaultPhysicalProtoConverter {}, + ) +} + +/// Serialize a `PhysicalExpr` to default protobuf representation. +/// +/// If required, a [`PhysicalExtensionCodec`] can be provided which can handle +/// serialization of udfs requiring specialized serialization (see [`PhysicalExtensionCodec::try_encode_udf`]). +/// A [`PhysicalProtoConverterExtension`] can be provided to handle the +/// conversion process (see [`PhysicalProtoConverterExtension::physical_expr_to_proto`]). +pub fn serialize_physical_expr_with_converter( + value: &Arc, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { // Snapshot the expr in case it has dynamic predicate state so // it can be serialized @@ -248,12 +282,14 @@ pub fn serialize_physical_expr( )), }; return Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::Literal(value)), }); } if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::Column( protobuf::PhysicalColumn { name: expr.name().to_string(), @@ -263,6 +299,7 @@ pub fn serialize_physical_expr( }) } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::UnknownColumn( protobuf::UnknownColumn { name: expr.name().to_string(), @@ -271,18 +308,24 @@ pub fn serialize_physical_expr( }) } else if let Some(expr) = expr.downcast_ref::() { let binary_expr = Box::new(protobuf::PhysicalBinaryExprNode { - l: Some(Box::new(serialize_physical_expr(expr.left(), codec)?)), - r: Some(Box::new(serialize_physical_expr(expr.right(), codec)?)), + l: Some(Box::new( + proto_converter.physical_expr_to_proto(expr.left(), codec)?, + )), + r: Some(Box::new( + 
proto_converter.physical_expr_to_proto(expr.right(), codec)?, + )), op: format!("{:?}", expr.op()), }); Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::BinaryExpr( binary_expr, )), }) } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some( protobuf::physical_expr_node::ExprType::Case( Box::new( @@ -290,14 +333,21 @@ pub fn serialize_physical_expr( expr: expr .expr() .map(|exp| { - serialize_physical_expr(exp, codec).map(Box::new) + proto_converter + .physical_expr_to_proto(exp, codec) + .map(Box::new) }) .transpose()?, when_then_expr: expr .when_then_expr() .iter() .map(|(when_expr, then_expr)| { - serialize_when_then_expr(when_expr, then_expr, codec) + serialize_when_then_expr( + when_expr, + then_expr, + codec, + proto_converter, + ) }) .collect::, @@ -305,7 +355,11 @@ pub fn serialize_physical_expr( >>()?, else_expr: expr .else_expr() - .map(|a| serialize_physical_expr(a, codec).map(Box::new)) + .map(|a| { + proto_converter + .physical_expr_to_proto(a, codec) + .map(Box::new) + }) .transpose()?, }, ), @@ -314,66 +368,88 @@ pub fn serialize_physical_expr( }) } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::NotExpr(Box::new( protobuf::PhysicalNot { - expr: Some(Box::new(serialize_physical_expr(expr.arg(), codec)?)), + expr: Some(Box::new( + proto_converter.physical_expr_to_proto(expr.arg(), codec)?, + )), }, ))), }) } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::IsNullExpr( Box::new(protobuf::PhysicalIsNull { - expr: Some(Box::new(serialize_physical_expr(expr.arg(), codec)?)), + expr: Some(Box::new( + proto_converter.physical_expr_to_proto(expr.arg(), codec)?, + )), }), )), }) } else if let Some(expr) = expr.downcast_ref::() { 
Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::IsNotNullExpr( Box::new(protobuf::PhysicalIsNotNull { - expr: Some(Box::new(serialize_physical_expr(expr.arg(), codec)?)), + expr: Some(Box::new( + proto_converter.physical_expr_to_proto(expr.arg(), codec)?, + )), }), )), }) } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::InList(Box::new( protobuf::PhysicalInListNode { - expr: Some(Box::new(serialize_physical_expr(expr.expr(), codec)?)), - list: serialize_physical_exprs(expr.list(), codec)?, + expr: Some(Box::new( + proto_converter.physical_expr_to_proto(expr.expr(), codec)?, + )), + list: serialize_physical_exprs(expr.list(), codec, proto_converter)?, negated: expr.negated(), }, ))), }) } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::Negative(Box::new( protobuf::PhysicalNegativeNode { - expr: Some(Box::new(serialize_physical_expr(expr.arg(), codec)?)), + expr: Some(Box::new( + proto_converter.physical_expr_to_proto(expr.arg(), codec)?, + )), }, ))), }) } else if let Some(lit) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::Literal( lit.value().try_into()?, )), }) } else if let Some(cast) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::Cast(Box::new( protobuf::PhysicalCastNode { - expr: Some(Box::new(serialize_physical_expr(cast.expr(), codec)?)), + expr: Some(Box::new( + proto_converter.physical_expr_to_proto(cast.expr(), codec)?, + )), arrow_type: Some(cast.cast_type().try_into()?), }, ))), }) } else if let Some(cast) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: 
Some(protobuf::physical_expr_node::ExprType::TryCast(Box::new( protobuf::PhysicalTryCastNode { - expr: Some(Box::new(serialize_physical_expr(cast.expr(), codec)?)), + expr: Some(Box::new( + proto_converter.physical_expr_to_proto(cast.expr(), codec)?, + )), arrow_type: Some(cast.cast_type().try_into()?), }, ))), @@ -382,10 +458,11 @@ pub fn serialize_physical_expr( let mut buf = Vec::new(); codec.try_encode_udf(expr.fun(), &mut buf)?; Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarUdf( protobuf::PhysicalScalarUdfNode { name: expr.name().to_string(), - args: serialize_physical_exprs(expr.args(), codec)?, + args: serialize_physical_exprs(expr.args(), codec, proto_converter)?, fun_definition: (!buf.is_empty()).then_some(buf), return_type: Some(expr.return_type().try_into()?), nullable: expr.nullable(), @@ -398,24 +475,31 @@ pub fn serialize_physical_expr( }) } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::LikeExpr(Box::new( protobuf::PhysicalLikeExprNode { negated: expr.negated(), case_insensitive: expr.case_insensitive(), - expr: Some(Box::new(serialize_physical_expr(expr.expr(), codec)?)), - pattern: Some(Box::new(serialize_physical_expr( - expr.pattern(), - codec, - )?)), + expr: Some(Box::new( + proto_converter.physical_expr_to_proto(expr.expr(), codec)?, + )), + pattern: Some(Box::new( + proto_converter.physical_expr_to_proto(expr.pattern(), codec)?, + )), }, ))), }) } else if let Some(expr) = expr.downcast_ref::() { let (s0, s1, s2, s3) = expr.seeds(); Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::HashExpr( protobuf::PhysicalHashExprNode { - on_columns: serialize_physical_exprs(expr.on_columns(), codec)?, + on_columns: serialize_physical_exprs( + expr.on_columns(), + codec, + proto_converter, + )?, seed0: s0, seed1: s1, seed2: s2, @@ 
-431,9 +515,10 @@ pub fn serialize_physical_expr( let inputs: Vec = value .children() .into_iter() - .map(|e| serialize_physical_expr(e, codec)) + .map(|e| proto_converter.physical_expr_to_proto(e, codec)) .collect::>()?; Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::Extension( protobuf::PhysicalExtensionExprNode { expr: buf, inputs }, )), @@ -449,6 +534,7 @@ pub fn serialize_physical_expr( pub fn serialize_partitioning( partitioning: &Partitioning, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { let serialized_partitioning = match partitioning { Partitioning::RoundRobinBatch(partition_count) => protobuf::Partitioning { @@ -457,7 +543,8 @@ pub fn serialize_partitioning( )), }, Partitioning::Hash(exprs, partition_count) => { - let serialized_exprs = serialize_physical_exprs(exprs, codec)?; + let serialized_exprs = + serialize_physical_exprs(exprs, codec, proto_converter)?; protobuf::Partitioning { partition_method: Some(protobuf::partitioning::PartitionMethod::Hash( protobuf::PhysicalHashRepartition { @@ -480,10 +567,11 @@ fn serialize_when_then_expr( when_expr: &Arc, then_expr: &Arc, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { Ok(protobuf::PhysicalWhenThen { - when_expr: Some(serialize_physical_expr(when_expr, codec)?), - then_expr: Some(serialize_physical_expr(then_expr, codec)?), + when_expr: Some(proto_converter.physical_expr_to_proto(when_expr, codec)?), + then_expr: Some(proto_converter.physical_expr_to_proto(then_expr, codec)?), }) } @@ -539,6 +627,7 @@ impl TryFrom<&[PartitionedFile]> for protobuf::FileGroup { pub fn serialize_file_scan_config( conf: &FileScanConfig, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { let file_groups = conf .file_groups @@ -548,7 +637,8 @@ pub fn serialize_file_scan_config( let mut 
output_orderings = vec![]; for order in &conf.output_ordering { - let ordering = serialize_physical_sort_exprs(order.to_vec(), codec)?; + let ordering = + serialize_physical_sort_exprs(order.to_vec(), codec, proto_converter)?; output_orderings.push(ordering) } @@ -563,8 +653,7 @@ pub fn serialize_file_scan_config( fields.extend(conf.table_partition_cols().iter().cloned()); let schema = Arc::new( - arrow::datatypes::Schema::new(fields.clone()) - .with_metadata(conf.file_schema().metadata.clone()), + Schema::new(fields.clone()).with_metadata(conf.file_schema().metadata.clone()), ); let projection_exprs = conf @@ -579,7 +668,10 @@ pub fn serialize_file_scan_config( .map(|expr| { Ok(protobuf::ProjectionExpr { alias: expr.alias.to_string(), - expr: Some(serialize_physical_expr(&expr.expr, codec)?), + expr: Some( + proto_converter + .physical_expr_to_proto(&expr.expr, codec)?, + ), }) }) .collect::>>()?, @@ -614,11 +706,12 @@ pub fn serialize_file_scan_config( pub fn serialize_maybe_filter( expr: Option>, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { match expr { None => Ok(protobuf::MaybeFilter { expr: None }), Some(expr) => Ok(protobuf::MaybeFilter { - expr: Some(serialize_physical_expr(&expr, codec)?), + expr: Some(proto_converter.physical_expr_to_proto(&expr, codec)?), }), } } @@ -695,6 +788,17 @@ impl TryFrom<&FileSinkConfig> for protobuf::FileSinkConfig { }) }) .collect::>>()?; + let file_output_mode = match conf.file_output_mode { + datafusion_datasource::file_sink_config::FileOutputMode::Automatic => { + protobuf::FileOutputMode::Automatic + } + datafusion_datasource::file_sink_config::FileOutputMode::SingleFile => { + protobuf::FileOutputMode::SingleFile + } + datafusion_datasource::file_sink_config::FileOutputMode::Directory => { + protobuf::FileOutputMode::Directory + } + }; Ok(Self { object_store_url: conf.object_store_url.to_string(), file_groups, @@ -704,6 +808,7 @@ impl TryFrom<&FileSinkConfig> 
for protobuf::FileSinkConfig { keep_partition_by_columns: conf.keep_partition_by_columns, insert_op: conf.insert_op as i32, file_extension: conf.file_extension.to_string(), + file_output_mode: file_output_mode.into(), }) } } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index bcfda648b53e..9407cbf9a074 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -28,7 +28,7 @@ use datafusion::datasource::file_format::json::{JsonFormat, JsonFormatFactory}; use datafusion::datasource::listing::{ ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, }; -use datafusion::execution::options::ArrowReadOptions; +use datafusion::execution::options::{ArrowReadOptions, JsonReadOptions}; use datafusion::optimizer::Optimizer; use datafusion::optimizer::optimize_unions::OptimizeUnions; use datafusion_common::parquet_config::DFParquetWriterVersion; @@ -413,6 +413,7 @@ async fn roundtrip_logical_plan_dml() -> Result<()> { "DELETE FROM T1", "UPDATE T1 SET a = 1", "CREATE TABLE T2 AS SELECT * FROM T1", + "TRUNCATE TABLE T1", ]; for query in queries { let plan = ctx.sql(query).await?.into_optimized_plan()?; @@ -754,7 +755,7 @@ async fn create_json_scan(ctx: &SessionContext) -> Result) -> Result<()> { let ctx = SessionContext::new(); let codec = DefaultPhysicalExtensionCodec {}; - roundtrip_test_and_return(exec_plan, &ctx, &codec)?; + let proto_converter = DefaultPhysicalProtoConverter {}; + roundtrip_test_and_return(exec_plan, &ctx, &codec, &proto_converter)?; Ok(()) } @@ -142,13 +151,19 @@ fn roundtrip_test_and_return( exec_plan: Arc, ctx: &SessionContext, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let proto: protobuf::PhysicalPlanNode = - protobuf::PhysicalPlanNode::try_from_physical_plan(exec_plan.clone(), codec) - .expect("to proto"); - let 
result_exec_plan: Arc = proto - .try_into_physical_plan(&ctx.task_ctx(), codec) - .expect("from proto"); + let bytes = physical_plan_to_bytes_with_proto_converter( + Arc::clone(&exec_plan), + codec, + proto_converter, + )?; + let result_exec_plan = physical_plan_from_bytes_with_proto_converter( + bytes.as_ref(), + ctx.task_ctx().as_ref(), + codec, + proto_converter, + )?; pretty_assertions::assert_eq!( format!("{exec_plan:?}"), @@ -168,7 +183,8 @@ fn roundtrip_test_with_context( ctx: &SessionContext, ) -> Result<()> { let codec = DefaultPhysicalExtensionCodec {}; - roundtrip_test_and_return(exec_plan, ctx, &codec)?; + let proto_converter = DefaultPhysicalProtoConverter {}; + roundtrip_test_and_return(exec_plan, ctx, &codec, &proto_converter)?; Ok(()) } @@ -176,9 +192,10 @@ fn roundtrip_test_with_context( /// query results are identical. async fn roundtrip_test_sql_with_context(sql: &str, ctx: &SessionContext) -> Result<()> { let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = DefaultPhysicalProtoConverter {}; let initial_plan = ctx.sql(sql).await?.create_physical_plan().await?; - roundtrip_test_and_return(initial_plan, ctx, &codec)?; + roundtrip_test_and_return(initial_plan, ctx, &codec, &proto_converter)?; Ok(()) } @@ -285,6 +302,7 @@ fn roundtrip_hash_join() -> Result<()> { None, *partition_mode, NullEquality::NullEqualsNothing, + false, )?))?; } } @@ -615,7 +633,7 @@ fn roundtrip_aggregate_with_limit() -> Result<()> { Arc::new(EmptyExec::new(schema.clone())), schema, )?; - let agg = agg.with_limit(Some(12)); + let agg = agg.with_limit_options(Some(LimitOptions::new_with_order(12, false))); roundtrip_test(Arc::new(agg)) } @@ -777,6 +795,19 @@ fn roundtrip_filter_with_not_and_in_list() -> Result<()> { )?)) } +#[test] +fn roundtrip_filter_with_fetch() -> Result<()> { + let field_a = Field::new("a", DataType::Boolean, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + let 
predicate = col("a", &schema)?; + let filter = FilterExecBuilder::new(predicate, Arc::new(EmptyExec::new(schema))) + .with_fetch(Some(10)) + .build()?; + assert_eq!(filter.fetch(), Some(10)); + roundtrip_test(Arc::new(filter)) +} + #[test] fn roundtrip_sort() -> Result<()> { let field_a = Field::new("a", DataType::Boolean, false); @@ -845,11 +876,13 @@ fn roundtrip_coalesce_batches_with_fetch() -> Result<()> { let field_b = Field::new("b", DataType::Int64, false); let schema = Arc::new(Schema::new(vec![field_a, field_b])); + #[expect(deprecated)] roundtrip_test(Arc::new(CoalesceBatchesExec::new( Arc::new(EmptyExec::new(schema.clone())), 8096, )))?; + #[expect(deprecated)] roundtrip_test(Arc::new( CoalesceBatchesExec::new(Arc::new(EmptyExec::new(schema)), 8096) .with_fetch(Some(10)), @@ -910,6 +943,83 @@ fn roundtrip_parquet_exec_with_pruning_predicate() -> Result<()> { roundtrip_test(DataSourceExec::from_data_source(scan_config)) } +#[test] +fn roundtrip_parquet_exec_attaches_cached_reader_factory_after_roundtrip() -> Result<()> { + let file_schema = + Arc::new(Schema::new(vec![Field::new("col", DataType::Utf8, false)])); + let file_source = Arc::new(ParquetSource::new(Arc::clone(&file_schema))); + let scan_config = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), file_source) + .with_file_groups(vec![FileGroup::new(vec![PartitionedFile::new( + "/path/to/file.parquet".to_string(), + 1024, + )])]) + .with_statistics(Statistics { + num_rows: Precision::Inexact(100), + total_byte_size: Precision::Inexact(1024), + column_statistics: Statistics::unknown_column(&file_schema), + }) + .build(); + let exec_plan = DataSourceExec::from_data_source(scan_config); + + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = DefaultPhysicalProtoConverter {}; + let roundtripped = + roundtrip_test_and_return(exec_plan, &ctx, &codec, &proto_converter)?; + + let data_source = roundtripped + .as_any() + .downcast_ref::() 
+ .ok_or_else(|| { + internal_datafusion_err!("Expected DataSourceExec after roundtrip") + })?; + let file_scan = data_source + .data_source() + .as_any() + .downcast_ref::() + .ok_or_else(|| { + internal_datafusion_err!("Expected FileScanConfig after roundtrip") + })?; + let parquet_source = file_scan + .file_source() + .as_any() + .downcast_ref::() + .ok_or_else(|| { + internal_datafusion_err!("Expected ParquetSource after roundtrip") + })?; + + assert!( + parquet_source.parquet_file_reader_factory().is_some(), + "Parquet reader factory should be attached after decoding from protobuf" + ); + Ok(()) +} + +#[test] +fn roundtrip_arrow_scan() -> Result<()> { + let file_schema = + Arc::new(Schema::new(vec![Field::new("col", DataType::Utf8, false)])); + + let table_schema = TableSchema::new(file_schema.clone(), vec![]); + let file_source = Arc::new(ArrowSource::new_file_source(table_schema)); + + let scan_config = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), file_source) + .with_file_groups(vec![FileGroup::new(vec![PartitionedFile::new( + "/path/to/file.arrow".to_string(), + 1024, + )])]) + .with_statistics(Statistics { + num_rows: Precision::Inexact(100), + total_byte_size: Precision::Inexact(1024), + column_statistics: Statistics::unknown_column(&file_schema), + }) + .build(); + + roundtrip_test(DataSourceExec::from_data_source(scan_config)) +} + #[tokio::test] async fn roundtrip_parquet_exec_with_table_partition_cols() -> Result<()> { let mut file_group = @@ -985,7 +1095,7 @@ fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> { } impl Display for CustomPredicateExpr { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "CustomPredicateExpr") } } @@ -1078,7 +1188,12 @@ fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> { let exec_plan = DataSourceExec::from_data_source(scan_config); let ctx = SessionContext::new(); - 
roundtrip_test_and_return(exec_plan, &ctx, &CustomPhysicalExtensionCodec {})?; + roundtrip_test_and_return( + exec_plan, + &ctx, + &CustomPhysicalExtensionCodec {}, + &DefaultPhysicalProtoConverter {}, + )?; Ok(()) } @@ -1284,7 +1399,8 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { )?); let ctx = SessionContext::new(); - roundtrip_test_and_return(aggregate, &ctx, &UDFExtensionCodec)?; + let proto_converter = DefaultPhysicalProtoConverter {}; + roundtrip_test_and_return(aggregate, &ctx, &UDFExtensionCodec, &proto_converter)?; Ok(()) } @@ -1331,7 +1447,8 @@ fn roundtrip_udwf_extension_codec() -> Result<()> { )?); let ctx = SessionContext::new(); - roundtrip_test_and_return(window, &ctx, &UDFExtensionCodec)?; + let proto_converter = DefaultPhysicalProtoConverter {}; + roundtrip_test_and_return(window, &ctx, &UDFExtensionCodec, &proto_converter)?; Ok(()) } @@ -1402,7 +1519,8 @@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { )?); let ctx = SessionContext::new(); - roundtrip_test_and_return(aggregate, &ctx, &UDFExtensionCodec)?; + let proto_converter = DefaultPhysicalProtoConverter {}; + roundtrip_test_and_return(aggregate, &ctx, &UDFExtensionCodec, &proto_converter)?; Ok(()) } @@ -1472,6 +1590,7 @@ fn roundtrip_json_sink() -> Result<()> { insert_op: InsertOp::Overwrite, keep_partition_by_columns: true, file_extension: "json".into(), + file_output_mode: FileOutputMode::SingleFile, }; let data_sink = Arc::new(JsonSink::new( file_sink_config, @@ -1510,6 +1629,7 @@ fn roundtrip_csv_sink() -> Result<()> { insert_op: InsertOp::Overwrite, keep_partition_by_columns: true, file_extension: "csv".into(), + file_output_mode: FileOutputMode::Directory, }; let data_sink = Arc::new(CsvSink::new( file_sink_config, @@ -1526,12 +1646,14 @@ fn roundtrip_csv_sink() -> Result<()> { let ctx = SessionContext::new(); let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = DefaultPhysicalProtoConverter {}; + let roundtrip_plan = 
roundtrip_test_and_return( Arc::new(DataSinkExec::new(input, data_sink, Some(sort_order))), &ctx, &codec, - ) - .unwrap(); + &proto_converter, + )?; let roundtrip_plan = roundtrip_plan .as_any() @@ -1567,6 +1689,7 @@ fn roundtrip_parquet_sink() -> Result<()> { insert_op: InsertOp::Overwrite, keep_partition_by_columns: true, file_extension: "parquet".into(), + file_output_mode: FileOutputMode::Automatic, }; let data_sink = Arc::new(ParquetSink::new( file_sink_config, @@ -1818,11 +1941,12 @@ async fn roundtrip_projection_source() -> Result<()> { .build(); let filter = Arc::new( - FilterExec::try_new( + FilterExecBuilder::new( Arc::new(BinaryExpr::new(col("c", &schema)?, Operator::Eq, lit(1))), DataSourceExec::from_data_source(scan_config), - )? - .with_projection(Some(vec![0, 1]))?, + ) + .apply_projection(Some(vec![0, 1]))? + .build()?, ); roundtrip_test(filter) @@ -1972,6 +2096,7 @@ async fn test_serialize_deserialize_tpch_queries() -> Result<()> { // serialize the physical plan let codec = DefaultPhysicalExtensionCodec {}; + let proto = PhysicalPlanNode::try_from_physical_plan(physical_plan.clone(), &codec)?; @@ -2093,6 +2218,7 @@ async fn test_tpch_part_in_list_query_with_real_parquet_data() -> Result<()> { // Serialize the physical plan - bug may happen here already but not necessarily manifests let codec = DefaultPhysicalExtensionCodec {}; + let proto = PhysicalPlanNode::try_from_physical_plan(physical_plan.clone(), &codec)?; // This will fail with the bug, but should succeed when fixed @@ -2334,15 +2460,19 @@ async fn roundtrip_async_func_exec() -> Result<()> { /// it's a performance optimization filter, not a correctness requirement. 
#[test] fn roundtrip_hash_table_lookup_expr_to_lit() -> Result<()> { + use datafusion::physical_plan::joins::join_hash_map::JoinHashMapU32; + use datafusion::physical_plan::joins::{HashTableLookupExpr, Map}; + // Create a simple schema and input plan let schema = Arc::new(Schema::new(vec![Field::new("col", DataType::Int64, false)])); let input = Arc::new(EmptyExec::new(schema.clone())); // Create a HashTableLookupExpr - it will be replaced with lit(true) during serialization - let hash_map = Arc::new(JoinHashMapU32::with_capacity(0)); - let hash_expr: Arc = Arc::new(Column::new("col", 0)); + let hash_map = Arc::new(Map::HashMap(Box::new(JoinHashMapU32::with_capacity(0)))); + let on_columns = vec![datafusion::physical_plan::expressions::col("col", &schema)?]; let lookup_expr: Arc = Arc::new(HashTableLookupExpr::new( - hash_expr, + on_columns, + datafusion::physical_plan::joins::SeededRandomState::with_seeds(0, 0, 0, 0), hash_map, "test_lookup".to_string(), )); @@ -2353,8 +2483,9 @@ fn roundtrip_hash_table_lookup_expr_to_lit() -> Result<()> { // Serialize let ctx = SessionContext::new(); let codec = DefaultPhysicalExtensionCodec {}; - let proto: protobuf::PhysicalPlanNode = - protobuf::PhysicalPlanNode::try_from_physical_plan(filter.clone(), &codec) + + let proto: PhysicalPlanNode = + PhysicalPlanNode::try_from_physical_plan(filter.clone(), &codec) .expect("serialization should succeed"); // Deserialize @@ -2404,3 +2535,635 @@ fn roundtrip_hash_expr() -> Result<()> { ); roundtrip_test(filter) } + +#[test] +fn custom_proto_converter_intercepts() -> Result<()> { + #[derive(Default)] + struct CustomConverterInterceptor { + num_proto_plans: RwLock, + num_physical_plans: RwLock, + num_proto_exprs: RwLock, + num_physical_exprs: RwLock, + } + + impl PhysicalProtoConverterExtension for CustomConverterInterceptor { + fn proto_to_execution_plan( + &self, + ctx: &TaskContext, + codec: &dyn PhysicalExtensionCodec, + proto: &protobuf::PhysicalPlanNode, + ) -> Result> { + { + let 
mut counter = self + .num_proto_plans + .write() + .map_err(|err| exec_datafusion_err!("{err}"))?; + *counter += 1; + } + proto.try_into_physical_plan_with_converter(ctx, codec, self) + } + + fn execution_plan_to_proto( + &self, + plan: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result + where + Self: Sized, + { + { + let mut counter = self + .num_physical_plans + .write() + .map_err(|err| exec_datafusion_err!("{err}"))?; + *counter += 1; + } + PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(plan), + codec, + self, + ) + } + + fn proto_to_physical_expr( + &self, + proto: &PhysicalExprNode, + ctx: &TaskContext, + input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, + ) -> Result> + where + Self: Sized, + { + { + let mut counter = self + .num_proto_exprs + .write() + .map_err(|err| exec_datafusion_err!("{err}"))?; + *counter += 1; + } + parse_physical_expr_with_converter(proto, ctx, input_schema, codec, self) + } + + fn physical_expr_to_proto( + &self, + expr: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result { + { + let mut counter = self + .num_physical_exprs + .write() + .map_err(|err| exec_datafusion_err!("{err}"))?; + *counter += 1; + } + serialize_physical_expr_with_converter(expr, codec, self) + } + } + + let field_a = Field::new("a", DataType::Boolean, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + let sort_exprs = [ + PhysicalSortExpr { + expr: col("a", &schema)?, + options: SortOptions { + descending: true, + nulls_first: false, + }, + }, + PhysicalSortExpr { + expr: col("b", &schema)?, + options: SortOptions { + descending: false, + nulls_first: true, + }, + }, + ] + .into(); + + let exec_plan = Arc::new(SortExec::new(sort_exprs, Arc::new(EmptyExec::new(schema)))); + + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = CustomConverterInterceptor::default(); + 
roundtrip_test_and_return(exec_plan, &ctx, &codec, &proto_converter)?; + + assert_eq!(*proto_converter.num_proto_exprs.read().unwrap(), 2); + assert_eq!(*proto_converter.num_physical_exprs.read().unwrap(), 2); + assert_eq!(*proto_converter.num_proto_plans.read().unwrap(), 2); + assert_eq!(*proto_converter.num_physical_plans.read().unwrap(), 2); + + Ok(()) +} + +#[test] +fn roundtrip_call_null_scalar_struct_dict() -> Result<()> { + let data_type = DataType::Struct(Fields::from(vec![Field::new( + "item", + DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)), + true, + )])); + + let schema = Arc::new(Schema::new(vec![Field::new("a", data_type.clone(), true)])); + let scan = Arc::new(EmptyExec::new(Arc::clone(&schema))); + let scalar = lit(ScalarValue::try_from(data_type)?); + let filter = Arc::new(FilterExec::try_new( + Arc::new(BinaryExpr::new(scalar, Operator::Eq, col("a", &schema)?)), + scan, + )?); + + roundtrip_test(filter) +} + +/// Test that expression deduplication works during deserialization. +/// When the same expression Arc is serialized multiple times, it should be +/// deduplicated on deserialization (sharing the same Arc). 
+#[test] +fn test_expression_deduplication() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a])); + + // Create a shared expression that will be used multiple times + let shared_col: Arc = Arc::new(Column::new("a", 0)); + + // Create an InList expression that uses the same column Arc multiple times + // This simulates a real-world scenario where expressions are shared + let in_list_expr = in_list( + Arc::clone(&shared_col), + vec![lit(1i64), lit(2i64), lit(3i64)], + &false, + &schema, + )?; + + // Create a binary expression that uses the shared column and the in_list result + let binary_expr: Arc = Arc::new(BinaryExpr::new( + Arc::clone(&shared_col), + Operator::Eq, + lit(42i64), + )); + + // Create a plan that has both expressions (they share the `shared_col` Arc) + let input = Arc::new(EmptyExec::new(schema.clone())); + let filter = FilterExecBuilder::new(in_list_expr, input).build()?; + let projection_exprs = vec![ProjectionExpr { + expr: binary_expr, + alias: "result".to_string(), + }]; + let exec_plan = + Arc::new(ProjectionExec::try_new(projection_exprs, Arc::new(filter))?); + + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = DeduplicatingProtoConverter {}; + + // Perform roundtrip + let bytes = physical_plan_to_bytes_with_proto_converter( + Arc::clone(&exec_plan) as Arc, + &codec, + &proto_converter, + )?; + + // Create a new converter for deserialization (fresh cache) + let deser_converter = DeduplicatingProtoConverter {}; + let result_plan = physical_plan_from_bytes_with_proto_converter( + bytes.as_ref(), + ctx.task_ctx().as_ref(), + &codec, + &deser_converter, + )?; + + // Verify the plan structure is correct + pretty_assertions::assert_eq!(format!("{exec_plan:?}"), format!("{result_plan:?}")); + + Ok(()) +} + +/// Test that expression deduplication correctly shares Arcs for identical expressions. 
+/// This test verifies the core deduplication behavior. +#[test] +fn test_expression_deduplication_arc_sharing() -> Result<()> { + use datafusion_proto::bytes::{ + physical_plan_from_bytes_with_proto_converter, + physical_plan_to_bytes_with_proto_converter, + }; + + let field_a = Field::new("a", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a])); + + // Create a column expression + let col_expr: Arc = Arc::new(Column::new("a", 0)); + + // Create a projection that uses the SAME Arc twice + // After roundtrip, both should point to the same Arc + let projection_exprs = vec![ + ProjectionExpr { + expr: Arc::clone(&col_expr), + alias: "a1".to_string(), + }, + ProjectionExpr { + expr: Arc::clone(&col_expr), // Same Arc! + alias: "a2".to_string(), + }, + ]; + + let input = Arc::new(EmptyExec::new(schema)); + let exec_plan = Arc::new(ProjectionExec::try_new(projection_exprs, input)?); + + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = DeduplicatingProtoConverter {}; + + // Serialize + let bytes = physical_plan_to_bytes_with_proto_converter( + Arc::clone(&exec_plan) as Arc, + &codec, + &proto_converter, + )?; + + // Deserialize with a fresh converter + let deser_converter = DeduplicatingProtoConverter {}; + let result_plan = physical_plan_from_bytes_with_proto_converter( + bytes.as_ref(), + ctx.task_ctx().as_ref(), + &codec, + &deser_converter, + )?; + + // Get the projection from the result + let projection = result_plan + .as_any() + .downcast_ref::() + .expect("Expected ProjectionExec"); + + let exprs: Vec<_> = projection.expr().iter().collect(); + assert_eq!(exprs.len(), 2); + + // The key test: both expressions should point to the same Arc after deduplication + // This is because they were the same Arc before serialization + assert!( + Arc::ptr_eq(&exprs[0].expr, &exprs[1].expr), + "Expected both expressions to share the same Arc after deduplication" + ); + + Ok(()) +} + +/// Test 
backward compatibility: protos without expr_id should still deserialize correctly. +#[test] +fn test_backward_compatibility_no_expr_id() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a])); + + // Manually create a proto without expr_id set + let proto = PhysicalExprNode { + expr_id: None, // Simulating old proto without this field + expr_type: Some( + datafusion_proto::protobuf::physical_expr_node::ExprType::Column( + datafusion_proto::protobuf::PhysicalColumn { + name: "a".to_string(), + index: 0, + }, + ), + ), + }; + + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = DefaultPhysicalProtoConverter {}; + + // Should deserialize without error + let result = proto_converter.proto_to_physical_expr( + &proto, + ctx.task_ctx().as_ref(), + &schema, + &codec, + )?; + + // Verify the result is correct + let col = result + .as_any() + .downcast_ref::() + .expect("Expected Column"); + assert_eq!(col.name(), "a"); + assert_eq!(col.index(), 0); + + Ok(()) +} + +/// Test that deduplication works within a single plan deserialization and that +/// separate deserializations produce independent expressions (no cross-operation sharing). 
+#[test] +fn test_deduplication_within_plan_deserialization() -> Result<()> { + use datafusion_proto::bytes::{ + physical_plan_from_bytes_with_proto_converter, + physical_plan_to_bytes_with_proto_converter, + }; + + let field_a = Field::new("a", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a])); + + // Create a plan with expressions that will be deduplicated + let col_expr: Arc = Arc::new(Column::new("a", 0)); + let projection_exprs = vec![ + ProjectionExpr { + expr: Arc::clone(&col_expr), + alias: "a1".to_string(), + }, + ProjectionExpr { + expr: Arc::clone(&col_expr), // Same Arc - will be deduplicated + alias: "a2".to_string(), + }, + ]; + let exec_plan = Arc::new(ProjectionExec::try_new( + projection_exprs, + Arc::new(EmptyExec::new(schema)), + )?); + + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = DeduplicatingProtoConverter {}; + + // Serialize + let bytes = physical_plan_to_bytes_with_proto_converter( + Arc::clone(&exec_plan) as Arc, + &codec, + &proto_converter, + )?; + + // First deserialization + let plan1 = physical_plan_from_bytes_with_proto_converter( + bytes.as_ref(), + ctx.task_ctx().as_ref(), + &codec, + &proto_converter, + )?; + + // Check that the plan was deserialized correctly with deduplication + let projection1 = plan1 + .as_any() + .downcast_ref::() + .expect("Expected ProjectionExec"); + let exprs1: Vec<_> = projection1.expr().iter().collect(); + assert_eq!(exprs1.len(), 2); + assert!( + Arc::ptr_eq(&exprs1[0].expr, &exprs1[1].expr), + "Expected both expressions to share the same Arc after deduplication" + ); + + // Second deserialization + let plan2 = physical_plan_from_bytes_with_proto_converter( + bytes.as_ref(), + ctx.task_ctx().as_ref(), + &codec, + &proto_converter, + )?; + + // Check that the second plan was also deserialized correctly + let projection2 = plan2 + .as_any() + .downcast_ref::() + .expect("Expected ProjectionExec"); + let exprs2: Vec<_> 
= projection2.expr().iter().collect(); + assert_eq!(exprs2.len(), 2); + assert!( + Arc::ptr_eq(&exprs2[0].expr, &exprs2[1].expr), + "Expected both expressions to share the same Arc after deduplication" + ); + + // Check that there was no deduplication across deserializations + assert!( + !Arc::ptr_eq(&exprs1[0].expr, &exprs2[0].expr), + "Expected expressions from different deserializations to be different Arcs" + ); + assert!( + !Arc::ptr_eq(&exprs1[1].expr, &exprs2[1].expr), + "Expected expressions from different deserializations to be different Arcs" + ); + + Ok(()) +} + +/// Test that deduplication works within direct expression deserialization and that +/// separate deserializations produce independent expressions (no cross-operation sharing). +#[test] +fn test_deduplication_within_expr_deserialization() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a])); + + // Create a binary expression where both sides are the same Arc + // This allows us to test deduplication within a single deserialization + let col_expr: Arc = Arc::new(Column::new("a", 0)); + let binary_expr: Arc = Arc::new(BinaryExpr::new( + Arc::clone(&col_expr), + Operator::Plus, + Arc::clone(&col_expr), // Same Arc - will be deduplicated + )); + + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = DeduplicatingProtoConverter {}; + + // Serialize the expression + let proto = proto_converter.physical_expr_to_proto(&binary_expr, &codec)?; + + // First expression deserialization + let expr1 = proto_converter.proto_to_physical_expr( + &proto, + ctx.task_ctx().as_ref(), + &schema, + &codec, + )?; + + // Check that deduplication worked within the deserialization + let binary1 = expr1 + .as_any() + .downcast_ref::() + .expect("Expected BinaryExpr"); + assert!( + Arc::ptr_eq(binary1.left(), binary1.right()), + "Expected both sides to share the same Arc after deduplication" + ); + + // 
Second expression deserialization + let expr2 = proto_converter.proto_to_physical_expr( + &proto, + ctx.task_ctx().as_ref(), + &schema, + &codec, + )?; + + // Check that the second expression was also deserialized correctly + let binary2 = expr2 + .as_any() + .downcast_ref::() + .expect("Expected BinaryExpr"); + assert!( + Arc::ptr_eq(binary2.left(), binary2.right()), + "Expected both sides to share the same Arc after deduplication" + ); + + // Check that there was no deduplication across deserializations + assert!( + !Arc::ptr_eq(binary1.left(), binary2.left()), + "Expected expressions from different deserializations to be different Arcs" + ); + assert!( + !Arc::ptr_eq(binary1.right(), binary2.right()), + "Expected expressions from different deserializations to be different Arcs" + ); + + Ok(()) +} + +/// Test that session_id rotates between top-level serialization operations. +/// This verifies that each top-level serialization gets a fresh session_id, +/// which prevents cross-process collisions when serialized plans are merged. 
+#[test] +fn test_session_id_rotation_between_serializations() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, false); + let _schema = Arc::new(Schema::new(vec![field_a])); + + // Create a simple expression + let col_expr: Arc = Arc::new(Column::new("a", 0)); + + let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = DeduplicatingProtoConverter {}; + + // First serialization + let proto1 = proto_converter.physical_expr_to_proto(&col_expr, &codec)?; + let expr_id1 = proto1.expr_id.expect("Expected expr_id to be set"); + + // Second serialization with the same converter + // The session_id should have rotated, so the expr_id should be different + // even though we're serializing the same expression (same pointer address) + let proto2 = proto_converter.physical_expr_to_proto(&col_expr, &codec)?; + let expr_id2 = proto2.expr_id.expect("Expected expr_id to be set"); + + // The expr_ids should be different because session_id rotated + assert_ne!( + expr_id1, expr_id2, + "Expected different expr_ids due to session_id rotation between serializations" + ); + + // Also test that serializing the same expression multiple times within + // the same top-level operation would give the same expr_id (not testable + // here directly since each physical_expr_to_proto is a top-level operation, + // but the deduplication tests verify this indirectly) + + Ok(()) +} + +/// Test that session_id rotation works correctly with execution plans. +/// This verifies the end-to-end behavior with plan serialization. 
+#[test] +fn test_session_id_rotation_with_execution_plans() -> Result<()> { + use datafusion_proto::bytes::physical_plan_to_bytes_with_proto_converter; + + let field_a = Field::new("a", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a])); + + // Create a simple plan + let col_expr: Arc = Arc::new(Column::new("a", 0)); + let projection_exprs = vec![ProjectionExpr { + expr: Arc::clone(&col_expr), + alias: "a1".to_string(), + }]; + let exec_plan = Arc::new(ProjectionExec::try_new( + projection_exprs.clone(), + Arc::new(EmptyExec::new(Arc::clone(&schema))), + )?); + + let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = DeduplicatingProtoConverter {}; + + // First serialization + let bytes1 = physical_plan_to_bytes_with_proto_converter( + Arc::clone(&exec_plan) as Arc, + &codec, + &proto_converter, + )?; + + // Second serialization with the same converter + let bytes2 = physical_plan_to_bytes_with_proto_converter( + Arc::clone(&exec_plan) as Arc, + &codec, + &proto_converter, + )?; + + // The serialized bytes should be different due to different session_ids + // (specifically, the expr_id values embedded in the protobuf will differ) + assert_ne!( + bytes1.as_ref(), + bytes2.as_ref(), + "Expected different serialized bytes due to session_id rotation" + ); + + // But both should deserialize correctly + let ctx = SessionContext::new(); + let deser_converter = DeduplicatingProtoConverter {}; + + let plan1 = datafusion_proto::bytes::physical_plan_from_bytes_with_proto_converter( + bytes1.as_ref(), + ctx.task_ctx().as_ref(), + &codec, + &deser_converter, + )?; + + let plan2 = datafusion_proto::bytes::physical_plan_from_bytes_with_proto_converter( + bytes2.as_ref(), + ctx.task_ctx().as_ref(), + &codec, + &deser_converter, + )?; + + // Verify both plans have the expected structure + assert_eq!(plan1.schema(), plan2.schema()); + + Ok(()) +} + +/// Tests that `lead` window function with offset and default value args +/// survives a 
protobuf round-trip. This is a regression test for a bug +/// where `expressions()` (used during serialization) returns only the +/// column expression for lead/lag, silently dropping the offset and +/// default value literal args. +#[test] +fn roundtrip_lead_with_default_value() -> Result<()> { + use datafusion::functions_window::lead_lag::lead_udwf; + + let field_a = Field::new("a", DataType::Int64, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + + // lead(a, 2, 42) — column a, offset 2, default value 42 + let lead_window = create_udwf_window_expr( + &lead_udwf(), + &[col("a", &schema)?, lit(2i64), lit(42i64)], + schema.as_ref(), + "test lead with default".to_string(), + false, + )?; + + let udwf_expr = Arc::new(StandardWindowExpr::new( + lead_window, + &[col("b", &schema)?], + &[PhysicalSortExpr { + expr: col("a", &schema)?, + options: SortOptions { + descending: false, + nulls_first: false, + }, + }], + Arc::new(WindowFrame::new(None)), + )); + + let input = Arc::new(EmptyExec::new(schema.clone())); + + roundtrip_test(Arc::new(BoundedWindowAggExec::try_new( + vec![udwf_expr], + input, + InputOrderMode::Sorted, + true, + )?)) +} diff --git a/datafusion/pruning/LICENSE.txt b/datafusion/pruning/LICENSE.txt new file mode 120000 index 000000000000..1ef648f64b34 --- /dev/null +++ b/datafusion/pruning/LICENSE.txt @@ -0,0 +1 @@ +../../LICENSE.txt \ No newline at end of file diff --git a/datafusion/pruning/NOTICE.txt b/datafusion/pruning/NOTICE.txt new file mode 120000 index 000000000000..fb051c92b10b --- /dev/null +++ b/datafusion/pruning/NOTICE.txt @@ -0,0 +1 @@ +../../NOTICE.txt \ No newline at end of file diff --git a/datafusion/pruning/src/lib.rs b/datafusion/pruning/src/lib.rs index 9f8142447ba6..be17f29eaafa 100644 --- a/datafusion/pruning/src/lib.rs +++ b/datafusion/pruning/src/lib.rs @@ -16,7 +16,6 @@ // under the License. 
#![cfg_attr(test, allow(clippy::needless_pass_by_value))] -#![deny(clippy::allow_attributes)] mod file_pruner; mod pruning_predicate; diff --git a/datafusion/pruning/src/pruning_predicate.rs b/datafusion/pruning/src/pruning_predicate.rs index b5b8267d7f93..6f6b00e80abc 100644 --- a/datafusion/pruning/src/pruning_predicate.rs +++ b/datafusion/pruning/src/pruning_predicate.rs @@ -492,7 +492,6 @@ impl PruningPredicate { // Simplify the newly created predicate to get rid of redundant casts, comparisons, etc. let predicate_expr = PhysicalExprSimplifier::new(&predicate_schema).simplify(predicate_expr)?; - let literal_guarantees = LiteralGuarantee::analyze(&expr); Ok(Self { @@ -1206,13 +1205,6 @@ fn is_compare_op(op: Operator) -> bool { ) } -fn is_string_type(data_type: &DataType) -> bool { - matches!( - data_type, - DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View - ) -} - // The pruning logic is based on the comparing the min/max bounds. // Must make sure the two type has order. // For example, casts from string to numbers is not correct. @@ -1234,7 +1226,7 @@ fn verify_support_type_for_prune(from_type: &DataType, to_type: &DataType) -> Re // If both types are strings or both are not strings (number, timestamp, etc) // then we can compare them. // PruningPredicate does not support casting of strings to numbers and such. - if is_string_type(from_type) == is_string_type(to_type) { + if from_type.is_string() == to_type.is_string() { Ok(()) } else { plan_err!( @@ -1281,7 +1273,7 @@ fn build_single_column_expr( ) -> Option> { let field = schema.field_with_name(column.name()).ok()?; - if matches!(field.data_type(), &DataType::Boolean) { + if *field.data_type() == DataType::Boolean { let col_ref = Arc::new(column.clone()) as _; let min = required_columns @@ -4682,7 +4674,7 @@ mod tests { true, // s1 ["AB", "A\u{10ffff}\u{10ffff}\u{10ffff}"] ==> some rows could pass (must keep) true, - // s1 ["A\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}"] ==> no row match. 
(min, max) maybe truncate + // s1 ["A\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}"] ==> no row match. (min, max) maybe truncate // original (min, max) maybe ("A\u{10ffff}\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}\u{10ffff}\u{10ffff}") true, ]; diff --git a/datafusion/session/src/lib.rs b/datafusion/session/src/lib.rs index 3d3cb541b5a5..11f734e75745 100644 --- a/datafusion/session/src/lib.rs +++ b/datafusion/session/src/lib.rs @@ -16,7 +16,6 @@ // under the License. #![cfg_attr(test, allow(clippy::needless_pass_by_value))] -#![deny(clippy::allow_attributes)] //! Session management for DataFusion query execution environment //! diff --git a/datafusion/spark/Cargo.toml b/datafusion/spark/Cargo.toml index 673b62c5c348..162b6d814e80 100644 --- a/datafusion/spark/Cargo.toml +++ b/datafusion/spark/Cargo.toml @@ -29,6 +29,10 @@ edition = { workspace = true } [package.metadata.docs.rs] all-features = true +[features] +default = [] +core = ["datafusion"] + # Note: add additional linter rules in lib.rs. 
# Rust does not support workspace + new linter rules in subcrates yet # https://github.com/rust-lang/cargo/issues/13157 @@ -43,20 +47,28 @@ arrow = { workspace = true } bigdecimal = { workspace = true } chrono = { workspace = true } crc32fast = "1.4" +# Optional dependency for SessionStateBuilderSpark extension trait +datafusion = { workspace = true, optional = true, default-features = false } datafusion-catalog = { workspace = true } datafusion-common = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-functions = { workspace = true, features = ["crypto_expressions"] } +datafusion-functions-aggregate = { workspace = true } datafusion-functions-nested = { workspace = true } log = { workspace = true } percent-encoding = "2.3.2" rand = { workspace = true } +serde_json = { workspace = true } sha1 = "0.10" +sha2 = { workspace = true } url = { workspace = true } [dev-dependencies] +arrow = { workspace = true, features = ["test_utils"] } criterion = { workspace = true } +# for SessionStateBuilderSpark tests +datafusion = { workspace = true, default-features = false } [[bench]] harness = false @@ -65,3 +77,23 @@ name = "char" [[bench]] harness = false name = "space" + +[[bench]] +harness = false +name = "hex" + +[[bench]] +harness = false +name = "slice" + +[[bench]] +harness = false +name = "substring" + +[[bench]] +harness = false +name = "unhex" + +[[bench]] +harness = false +name = "sha2" diff --git a/datafusion/spark/benches/char.rs b/datafusion/spark/benches/char.rs index b5f87857ae9c..38d9ebdeb4f5 100644 --- a/datafusion/spark/benches/char.rs +++ b/datafusion/spark/benches/char.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-extern crate criterion; - use arrow::datatypes::{DataType, Field}; use arrow::{array::PrimitiveArray, datatypes::Int64Type}; use criterion::{Criterion, criterion_group, criterion_main}; diff --git a/datafusion/spark/benches/hex.rs b/datafusion/spark/benches/hex.rs new file mode 100644 index 000000000000..9785371cc582 --- /dev/null +++ b/datafusion/spark/benches/hex.rs @@ -0,0 +1,150 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::*; +use arrow::datatypes::*; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_spark::function::math::hex::SparkHex; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::hint::black_box; +use std::sync::Arc; + +fn seedable_rng() -> StdRng { + StdRng::seed_from_u64(42) +} + +fn generate_int64_data(size: usize, null_density: f32) -> PrimitiveArray { + let mut rng = seedable_rng(); + (0..size) + .map(|_| { + if rng.random::() < null_density { + None + } else { + Some(rng.random_range::(-999_999_999_999..999_999_999_999)) + } + }) + .collect() +} + +fn generate_utf8_data(size: usize, null_density: f32) -> StringArray { + let mut rng = seedable_rng(); + let mut builder = StringBuilder::new(); + for _ in 0..size { + if rng.random::() < null_density { + builder.append_null(); + } else { + let len = rng.random_range::(1..=100); + let s: String = + std::iter::repeat_with(|| rng.random_range(b'a'..=b'z') as char) + .take(len) + .collect(); + builder.append_value(&s); + } + } + builder.finish() +} + +fn generate_binary_data(size: usize, null_density: f32) -> BinaryArray { + let mut rng = seedable_rng(); + let mut builder = BinaryBuilder::new(); + for _ in 0..size { + if rng.random::() < null_density { + builder.append_null(); + } else { + let len = rng.random_range::(1..=100); + let bytes: Vec = (0..len).map(|_| rng.random()).collect(); + builder.append_value(&bytes); + } + } + builder.finish() +} + +fn generate_int64_dict_data( + size: usize, + null_density: f32, +) -> DictionaryArray { + let mut rng = seedable_rng(); + let mut builder = PrimitiveDictionaryBuilder::::new(); + for _ in 0..size { + if rng.random::() < null_density { + builder.append_null(); + } else { + builder.append_value( + rng.random_range::(-999_999_999_999..999_999_999_999), + ); + } + } + builder.finish() +} + +fn 
run_benchmark(c: &mut Criterion, name: &str, size: usize, array: Arc) { + let hex_func = SparkHex::new(); + let args = vec![ColumnarValue::Array(array)]; + let arg_fields: Vec<_> = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect(); + let config_options = Arc::new(ConfigOptions::default()); + + c.bench_function(&format!("{name}/size={size}"), |b| { + b.iter(|| { + black_box( + hex_func + .invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Arc::new(Field::new("f", DataType::Utf8, true)), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); +} + +fn criterion_benchmark(c: &mut Criterion) { + let sizes = vec![1024, 4096, 8192]; + let null_density = 0.1; + + for &size in &sizes { + let data = generate_int64_data(size, null_density); + run_benchmark(c, "hex_int64", size, Arc::new(data)); + } + + for &size in &sizes { + let data = generate_utf8_data(size, null_density); + run_benchmark(c, "hex_utf8", size, Arc::new(data)); + } + + for &size in &sizes { + let data = generate_binary_data(size, null_density); + run_benchmark(c, "hex_binary", size, Arc::new(data)); + } + + for &size in &sizes { + let data = generate_int64_dict_data(size, null_density); + run_benchmark(c, "hex_int64_dict", size, Arc::new(data)); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/spark/benches/sha2.rs b/datafusion/spark/benches/sha2.rs new file mode 100644 index 000000000000..6e835984703f --- /dev/null +++ b/datafusion/spark/benches/sha2.rs @@ -0,0 +1,105 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::*; +use arrow::datatypes::*; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_spark::function::hash::sha2::SparkSha2; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::hint::black_box; +use std::sync::Arc; + +fn seedable_rng() -> StdRng { + StdRng::seed_from_u64(42) +} + +fn generate_binary_data(size: usize, null_density: f32) -> BinaryArray { + let mut rng = seedable_rng(); + let mut builder = BinaryBuilder::new(); + for _ in 0..size { + if rng.random::() < null_density { + builder.append_null(); + } else { + let len = rng.random_range::(1..=100); + let bytes: Vec = (0..len).map(|_| rng.random()).collect(); + builder.append_value(&bytes); + } + } + builder.finish() +} + +fn run_benchmark(c: &mut Criterion, name: &str, size: usize, args: &[ColumnarValue]) { + let sha2_func = SparkSha2::new(); + let arg_fields: Vec<_> = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect(); + let config_options = Arc::new(ConfigOptions::default()); + + c.bench_function(&format!("{name}/size={size}"), |b| { + b.iter(|| { + black_box( + sha2_func + .invoke_with_args(ScalarFunctionArgs { + args: args.to_vec(), + 
arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Arc::new(Field::new("f", DataType::Utf8, true)), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); +} + +fn criterion_benchmark(c: &mut Criterion) { + // Scalar benchmark (avoid array expansion) + let scalar_args = vec![ + ColumnarValue::Scalar(ScalarValue::Binary(Some(b"Spark".to_vec()))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(256))), + ]; + run_benchmark(c, "sha2/scalar", 1, &scalar_args); + + let sizes = vec![1024, 4096, 8192]; + let null_density = 0.1; + + for &size in &sizes { + let values: ArrayRef = Arc::new(generate_binary_data(size, null_density)); + let bit_lengths: ArrayRef = Arc::new(Int32Array::from(vec![256; size])); + + let array_args = vec![ + ColumnarValue::Array(Arc::clone(&values)), + ColumnarValue::Array(Arc::clone(&bit_lengths)), + ]; + run_benchmark(c, "sha2/array_binary_256", size, &array_args); + + let array_scalar_args = vec![ + ColumnarValue::Array(Arc::clone(&values)), + ColumnarValue::Scalar(ScalarValue::Int32(Some(256))), + ]; + run_benchmark(c, "sha2/array_scalar_binary_256", size, &array_scalar_args); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/spark/benches/slice.rs b/datafusion/spark/benches/slice.rs new file mode 100644 index 000000000000..da392dc042f9 --- /dev/null +++ b/datafusion/spark/benches/slice.rs @@ -0,0 +1,185 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + Int64Array, ListArray, ListViewArray, NullBufferBuilder, PrimitiveArray, +}; +use arrow::buffer::{OffsetBuffer, ScalarBuffer}; +use arrow::datatypes::{DataType, Field, Int64Type}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_spark::function::array::slice; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::hint::black_box; +use std::sync::Arc; + +fn create_inputs( + rng: &mut StdRng, + size: usize, + child_array_size: usize, + null_density: f32, +) -> (ListArray, ListViewArray) { + let mut nulls_builder = NullBufferBuilder::new(size); + let mut sizes = Vec::with_capacity(size); + + for _ in 0..size { + if rng.random::() < null_density { + nulls_builder.append_null(); + } else { + nulls_builder.append_non_null(); + } + sizes.push(rng.random_range(1..child_array_size)); + } + let nulls = nulls_builder.finish(); + + let length = sizes.iter().sum(); + let values: PrimitiveArray = + (0..length).map(|_| Some(rng.random())).collect(); + let values = Arc::new(values); + + let offsets = OffsetBuffer::from_lengths(sizes.clone()); + let list_array = ListArray::new( + Arc::new(Field::new_list_field(DataType::Int64, true)), + offsets.clone(), + values.clone(), + nulls.clone(), + ); + + let offsets = ScalarBuffer::from(offsets.slice(0, size - 1)); + let sizes = ScalarBuffer::from_iter(sizes.into_iter().map(|v| v as i32)); + let list_view_array = 
ListViewArray::new( + Arc::new(Field::new_list_field(DataType::Int64, true)), + offsets, + sizes, + values, + nulls, + ); + + (list_array, list_view_array) +} + +fn random_from_to( + rng: &mut StdRng, + size: i64, + null_density: f32, +) -> (Option, Option) { + let from = if rng.random::() < null_density { + None + } else { + Some(rng.random_range(1..=size)) + }; + + let to = if rng.random::() < null_density { + None + } else { + match from { + Some(from) => Some(rng.random_range(from..=size)), + None => Some(rng.random_range(1..=size)), + } + }; + + (from, to) +} + +fn array_slice_benchmark( + name: &str, + input: ColumnarValue, + mut args: Vec, + c: &mut Criterion, + size: usize, +) { + args.insert(0, input); + + let array_slice = slice(); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + >::from(Field::new(format!("arg_{idx}"), arg.data_type(), true)) + }) + .collect::>(); + c.bench_function(name, |b| { + b.iter(|| { + black_box( + array_slice + .invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new_list_field(args[0].data_type(), true) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }); +} + +fn criterion_benchmark(c: &mut Criterion) { + let rng = &mut StdRng::seed_from_u64(42); + let size = 1_000_000; + let child_array_size = 100; + let null_density = 0.1; + + let (list_array, list_view_array) = + create_inputs(rng, size, child_array_size, null_density); + + let mut array_from = Vec::with_capacity(size); + let mut array_to = Vec::with_capacity(size); + for child_array_size in list_array.offsets().lengths() { + let (from, to) = random_from_to(rng, child_array_size as i64, null_density); + array_from.push(from); + array_to.push(to); + } + + // input + let list_array = ColumnarValue::Array(Arc::new(list_array)); + let list_view_array = ColumnarValue::Array(Arc::new(list_view_array)); + + // args + let 
array_from = ColumnarValue::Array(Arc::new(Int64Array::from(array_from))); + let array_to = ColumnarValue::Array(Arc::new(Int64Array::from(array_to))); + let scalar_from = ColumnarValue::Scalar(ScalarValue::from(1i64)); + let scalar_to = ColumnarValue::Scalar(ScalarValue::from(child_array_size as i64 / 2)); + + for input in [list_array, list_view_array] { + let input_type = input.data_type().to_string(); + + array_slice_benchmark( + &format!("slice: input {input_type}, array args, no stride"), + input.clone(), + vec![array_from.clone(), array_to.clone()], + c, + size, + ); + + array_slice_benchmark( + &format!("slice: input {input_type}, scalar args, no stride"), + input.clone(), + vec![scalar_from.clone(), scalar_to.clone()], + c, + size, + ); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/spark/benches/space.rs b/datafusion/spark/benches/space.rs index 8ace7219a1dc..bd9d370ca37f 100644 --- a/datafusion/spark/benches/space.rs +++ b/datafusion/spark/benches/space.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::PrimitiveArray; use arrow::datatypes::{DataType, Field, Int32Type}; use criterion::{Criterion, criterion_group, criterion_main}; diff --git a/datafusion/spark/benches/substring.rs b/datafusion/spark/benches/substring.rs new file mode 100644 index 000000000000..d6eac817c322 --- /dev/null +++ b/datafusion/spark/benches/substring.rs @@ -0,0 +1,205 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, Int64Array, OffsetSizeTrait}; +use arrow::datatypes::{DataType, Field}; +use arrow::util::bench_util::{ + create_string_array_with_len, create_string_view_array_with_len, +}; +use criterion::{Criterion, SamplingMode, criterion_group, criterion_main}; +use datafusion_common::DataFusionError; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_spark::function::string::substring; +use std::hint::black_box; +use std::sync::Arc; + +fn create_args_without_count( + size: usize, + str_len: usize, + start_half_way: bool, + force_view_types: bool, +) -> Vec { + let start_array = Arc::new(Int64Array::from( + (0..size) + .map(|_| { + if start_half_way { + (str_len / 2) as i64 + } else { + 1i64 + } + }) + .collect::>(), + )); + + if force_view_types { + let string_array = + Arc::new(create_string_view_array_with_len(size, 0.1, str_len, false)); + vec![ + ColumnarValue::Array(string_array), + ColumnarValue::Array(start_array), + ] + } else { + let string_array = + Arc::new(create_string_array_with_len::(size, 0.1, str_len)); + + vec![ + ColumnarValue::Array(string_array), + ColumnarValue::Array(Arc::clone(&start_array) as ArrayRef), + ] + } +} + +fn create_args_with_count( + size: usize, + str_len: usize, + count_max: usize, + force_view_types: bool, +) -> Vec { + let start_array = + Arc::new(Int64Array::from((0..size).map(|_| 1).collect::>())); + let count = count_max.min(str_len) as i64; + let count_array = Arc::new(Int64Array::from( + (0..size).map(|_| 
count).collect::>(), + )); + + if force_view_types { + let string_array = + Arc::new(create_string_view_array_with_len(size, 0.1, str_len, false)); + vec![ + ColumnarValue::Array(string_array), + ColumnarValue::Array(start_array), + ColumnarValue::Array(count_array), + ] + } else { + let string_array = + Arc::new(create_string_array_with_len::(size, 0.1, str_len)); + + vec![ + ColumnarValue::Array(string_array), + ColumnarValue::Array(Arc::clone(&start_array) as ArrayRef), + ColumnarValue::Array(Arc::clone(&count_array) as ArrayRef), + ] + } +} + +#[expect(clippy::needless_pass_by_value)] +fn invoke_substr_with_args( + args: Vec, + number_rows: usize, +) -> Result { + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect::>(); + let config_options = Arc::new(ConfigOptions::default()); + + substring().invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields, + number_rows, + return_field: Field::new("f", DataType::Utf8View, true).into(), + config_options: Arc::clone(&config_options), + }) +} + +fn criterion_benchmark(c: &mut Criterion) { + for size in [1024, 4096] { + // string_len = 12, substring_len=6 (see `create_args_without_count`) + let len = 12; + let mut group = c.benchmark_group("SHORTER THAN 12"); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + + let args = create_args_without_count::(size, len, true, true); + group.bench_function( + format!("substr_string_view [size={size}, strlen={len}]"), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), + ); + + let args = create_args_without_count::(size, len, false, false); + group.bench_function(format!("substr_string [size={size}, strlen={len}]"), |b| { + b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))) + }); + + let args = create_args_without_count::(size, len, true, false); + group.bench_function( + format!("substr_large_string [size={size}, 
strlen={len}]"), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), + ); + + group.finish(); + + // string_len = 128, start=1, count=64, substring_len=64 + let len = 128; + let count = 64; + let mut group = c.benchmark_group("LONGER THAN 12"); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + + let args = create_args_with_count::(size, len, count, true); + group.bench_function( + format!("substr_string_view [size={size}, count={count}, strlen={len}]",), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), + ); + + let args = create_args_with_count::(size, len, count, false); + group.bench_function( + format!("substr_string [size={size}, count={count}, strlen={len}]",), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), + ); + + let args = create_args_with_count::(size, len, count, false); + group.bench_function( + format!("substr_large_string [size={size}, count={count}, strlen={len}]",), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), + ); + + group.finish(); + + // string_len = 128, start=1, count=6, substring_len=6 + let len = 128; + let count = 6; + let mut group = c.benchmark_group("SRC_LEN > 12, SUB_LEN < 12"); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + + let args = create_args_with_count::(size, len, count, true); + group.bench_function( + format!("substr_string_view [size={size}, count={count}, strlen={len}]",), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), + ); + + let args = create_args_with_count::(size, len, count, false); + group.bench_function( + format!("substr_string [size={size}, count={count}, strlen={len}]",), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), + ); + + let args = create_args_with_count::(size, len, count, false); + group.bench_function( + format!("substr_large_string [size={size}, count={count}, strlen={len}]",), + |b| b.iter(|| 
black_box(invoke_substr_with_args(args.clone(), size))), + ); + + group.finish(); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/spark/benches/unhex.rs b/datafusion/spark/benches/unhex.rs new file mode 100644 index 000000000000..7dce683485bc --- /dev/null +++ b/datafusion/spark/benches/unhex.rs @@ -0,0 +1,146 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::{ + Array, LargeStringArray, LargeStringBuilder, StringArray, StringBuilder, + StringViewArray, StringViewBuilder, +}; +use arrow::datatypes::{DataType, Field}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_spark::function::math::unhex::SparkUnhex; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::hint::black_box; +use std::sync::Arc; + +fn generate_hex_string_data(size: usize, null_density: f32) -> StringArray { + let mut rng = StdRng::seed_from_u64(42); + let mut builder = StringBuilder::with_capacity(size, 0); + let hex_chars = b"0123456789abcdefABCDEF"; + + for _ in 0..size { + if rng.random::() < null_density { + builder.append_null(); + } else { + let len = rng.random_range::(2..=100); + let s: String = std::iter::repeat_with(|| { + hex_chars[rng.random_range(0..hex_chars.len())] as char + }) + .take(len) + .collect(); + builder.append_value(&s); + } + } + builder.finish() +} + +fn generate_hex_large_string_data(size: usize, null_density: f32) -> LargeStringArray { + let mut rng = StdRng::seed_from_u64(42); + let mut builder = LargeStringBuilder::with_capacity(size, 0); + let hex_chars = b"0123456789abcdefABCDEF"; + + for _ in 0..size { + if rng.random::() < null_density { + builder.append_null(); + } else { + let len = rng.random_range::(2..=100); + let s: String = std::iter::repeat_with(|| { + hex_chars[rng.random_range(0..hex_chars.len())] as char + }) + .take(len) + .collect(); + builder.append_value(&s); + } + } + builder.finish() +} + +fn generate_hex_utf8view_data(size: usize, null_density: f32) -> StringViewArray { + let mut rng = StdRng::seed_from_u64(42); + let mut builder = StringViewBuilder::with_capacity(size); + let hex_chars = b"0123456789abcdefABCDEF"; + + for _ in 0..size { + if rng.random::() < null_density { + builder.append_null(); + } else { + let len 
= rng.random_range::(2..=100); + let s: String = std::iter::repeat_with(|| { + hex_chars[rng.random_range(0..hex_chars.len())] as char + }) + .take(len) + .collect(); + builder.append_value(&s); + } + } + builder.finish() +} + +fn run_benchmark(c: &mut Criterion, name: &str, size: usize, array: Arc) { + let unhex_func = SparkUnhex::new(); + let args = vec![ColumnarValue::Array(array)]; + let arg_fields: Vec<_> = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect(); + let config_options = Arc::new(ConfigOptions::default()); + + c.bench_function(&format!("{name}/size={size}"), |b| { + b.iter(|| { + black_box( + unhex_func + .invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Arc::new(Field::new("f", DataType::Binary, true)), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); +} + +fn criterion_benchmark(c: &mut Criterion) { + let sizes = vec![1024, 4096, 8192]; + let null_density = 0.1; + + // Benchmark with hex string + for &size in &sizes { + let data = generate_hex_string_data(size, null_density); + run_benchmark(c, "unhex_utf8", size, Arc::new(data)); + } + + // Benchmark with hex large string + for &size in &sizes { + let data = generate_hex_large_string_data(size, null_density); + run_benchmark(c, "unhex_large_utf8", size, Arc::new(data)); + } + + // Benchmark with hex Utf8View + for &size in &sizes { + let data = generate_hex_utf8view_data(size, null_density); + run_benchmark(c, "unhex_utf8view", size, Arc::new(data)); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/spark/src/function/aggregate/collect.rs b/datafusion/spark/src/function/aggregate/collect.rs new file mode 100644 index 000000000000..50497e282638 --- /dev/null +++ b/datafusion/spark/src/function/aggregate/collect.rs @@ -0,0 +1,200 @@ +// Licensed to the 
Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::ArrayRef; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::utils::SingleRowListArrayBuilder; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; +use datafusion_functions_aggregate::array_agg::{ + ArrayAggAccumulator, DistinctArrayAggAccumulator, +}; +use std::{any::Any, sync::Arc}; + +// Spark implementation of collect_list/collect_set aggregate function. 
+// Differs from DataFusion ArrayAgg in the following ways: +// - ignores NULL inputs +// - returns an empty list when all inputs are NULL +// - does not support ordering + +// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkCollectList { + signature: Signature, +} + +impl Default for SparkCollectList { + fn default() -> Self { + Self::new() + } +} + +impl SparkCollectList { + pub fn new() -> Self { + Self { + signature: Signature::any(1, Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for SparkCollectList { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "collect_list" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(DataType::List(Arc::new(Field::new_list_field( + arg_types[0].clone(), + true, + )))) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + Ok(vec![ + Field::new_list( + format_state_name(args.name, "collect_list"), + Field::new_list_field(args.input_fields[0].data_type().clone(), true), + true, + ) + .into(), + ]) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + let field = &acc_args.expr_fields[0]; + let data_type = field.data_type().clone(); + let ignore_nulls = true; + Ok(Box::new(NullToEmptyListAccumulator::new( + ArrayAggAccumulator::try_new(&data_type, ignore_nulls)?, + data_type, + ))) + } +} + +// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkCollectSet { + signature: Signature, +} + +impl Default for SparkCollectSet { + fn default() -> Self { + Self::new() + } +} + +impl SparkCollectSet { + pub fn new() -> Self { + Self { + signature: Signature::any(1, Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for SparkCollectSet { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "collect_set" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + 
Ok(DataType::List(Arc::new(Field::new_list_field( + arg_types[0].clone(), + true, + )))) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + Ok(vec![ + Field::new_list( + format_state_name(args.name, "collect_set"), + Field::new_list_field(args.input_fields[0].data_type().clone(), true), + true, + ) + .into(), + ]) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + let field = &acc_args.expr_fields[0]; + let data_type = field.data_type().clone(); + let ignore_nulls = true; + Ok(Box::new(NullToEmptyListAccumulator::new( + DistinctArrayAggAccumulator::try_new(&data_type, None, ignore_nulls)?, + data_type, + ))) + } +} + +/// Wrapper accumulator that returns an empty list instead of NULL when all inputs are NULL. +/// This implements Spark's behavior for collect_list and collect_set. +#[derive(Debug)] +struct NullToEmptyListAccumulator { + inner: T, + data_type: DataType, +} + +impl NullToEmptyListAccumulator { + pub fn new(inner: T, data_type: DataType) -> Self { + Self { inner, data_type } + } +} + +impl Accumulator for NullToEmptyListAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.inner.update_batch(values) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.inner.merge_batch(states) + } + + fn state(&mut self) -> Result> { + self.inner.state() + } + + fn evaluate(&mut self) -> Result { + let result = self.inner.evaluate()?; + if result.is_null() { + let empty_array = arrow::array::new_empty_array(&self.data_type); + Ok(SingleRowListArrayBuilder::new(empty_array).build_list_scalar()) + } else { + Ok(result) + } + } + + fn size(&self) -> usize { + self.inner.size() + self.data_type.size() + } +} diff --git a/datafusion/spark/src/function/aggregate/mod.rs b/datafusion/spark/src/function/aggregate/mod.rs index 3db72669d42b..d6a2fe7a8503 100644 --- a/datafusion/spark/src/function/aggregate/mod.rs +++ b/datafusion/spark/src/function/aggregate/mod.rs @@ -19,6 +19,7 @@ 
use datafusion_expr::AggregateUDF; use std::sync::Arc; pub mod avg; +pub mod collect; pub mod try_sum; pub mod expr_fn { @@ -30,6 +31,16 @@ pub mod expr_fn { "Returns the sum of values for a column, or NULL if overflow occurs", arg1 )); + export_functions!(( + collect_list, + "Returns a list created from the values in a column", + arg1 + )); + export_functions!(( + collect_set, + "Returns a set created from the values in a column", + arg1 + )); } // TODO: try use something like datafusion_functions_aggregate::create_func!() @@ -39,7 +50,13 @@ pub fn avg() -> Arc { pub fn try_sum() -> Arc { Arc::new(AggregateUDF::new_from_impl(try_sum::SparkTrySum::new())) } +pub fn collect_list() -> Arc { + Arc::new(AggregateUDF::new_from_impl(collect::SparkCollectList::new())) +} +pub fn collect_set() -> Arc { + Arc::new(AggregateUDF::new_from_impl(collect::SparkCollectSet::new())) +} pub fn functions() -> Vec> { - vec![avg(), try_sum()] + vec![avg(), try_sum(), collect_list(), collect_set()] } diff --git a/datafusion/spark/src/function/array/array_contains.rs b/datafusion/spark/src/function/array/array_contains.rs new file mode 100644 index 000000000000..2bc5d64d8bff --- /dev/null +++ b/datafusion/spark/src/function/array/array_contains.rs @@ -0,0 +1,168 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + Array, AsArray, BooleanArray, BooleanBufferBuilder, GenericListArray, OffsetSizeTrait, +}; +use arrow::buffer::{BooleanBuffer, NullBuffer}; +use arrow::datatypes::DataType; +use datafusion_common::{Result, exec_err}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_functions_nested::array_has::array_has_udf; +use std::any::Any; +use std::sync::Arc; + +/// Spark-compatible `array_contains` function. +/// +/// Calls DataFusion's `array_has` and then applies Spark's null semantics: +/// - If the result from `array_has` is `true`, return `true`. +/// - If the result is `false` and the input array row contains any null elements, +/// return `null` (because the element might have been the null). +/// - If the result is `false` and the input array row has no null elements, +/// return `false`. +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkArrayContains { + signature: Signature, +} + +impl Default for SparkArrayContains { + fn default() -> Self { + Self::new() + } +} + +impl SparkArrayContains { + pub fn new() -> Self { + Self { + signature: Signature::array_and_element(Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkArrayContains { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "array_contains" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _: &[DataType]) -> Result { + Ok(DataType::Boolean) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let haystack = args.args[0].clone(); + let array_has_result = array_has_udf().invoke_with_args(args)?; + + let result_array = array_has_result.to_array(1)?; + let patched = apply_spark_null_semantics(result_array.as_boolean(), &haystack)?; + Ok(ColumnarValue::Array(Arc::new(patched))) + } +} + +/// For each row 
where `array_has` returned `false`, set the output to null +/// if that row's input array contains any null elements. +fn apply_spark_null_semantics( + result: &BooleanArray, + haystack_arg: &ColumnarValue, +) -> Result { + // happy path + if result.false_count() == 0 || haystack_arg.data_type() == DataType::Null { + return Ok(result.clone()); + } + + let haystack = haystack_arg.to_array_of_size(result.len())?; + + let row_has_nulls = compute_row_has_nulls(&haystack)?; + + // A row keeps its validity when result is true OR the row has no nulls. + let keep_mask = result.values() | &!&row_has_nulls; + let new_validity = match result.nulls() { + Some(n) => n.inner() & &keep_mask, + None => keep_mask, + }; + + Ok(BooleanArray::new( + result.values().clone(), + Some(NullBuffer::new(new_validity)), + )) +} + +/// Returns a per-row bitmap where bit i is set if row i's list contains any null element. +fn compute_row_has_nulls(haystack: &dyn Array) -> Result { + match haystack.data_type() { + DataType::List(_) => generic_list_row_has_nulls(haystack.as_list::()), + DataType::LargeList(_) => generic_list_row_has_nulls(haystack.as_list::()), + DataType::FixedSizeList(_, _) => { + let list = haystack.as_fixed_size_list(); + let buf = match list.values().nulls() { + Some(nulls) => { + let validity = nulls.inner(); + let vl = list.value_length() as usize; + let mut builder = BooleanBufferBuilder::new(list.len()); + for i in 0..list.len() { + builder.append(validity.slice(i * vl, vl).count_set_bits() < vl); + } + builder.finish() + } + None => BooleanBuffer::new_unset(list.len()), + }; + Ok(mask_with_list_nulls(buf, list.nulls())) + } + dt => exec_err!("compute_row_has_nulls: unsupported data type {dt}"), + } +} + +/// Computes per-row null presence for `List` and `LargeList` arrays. 
+fn generic_list_row_has_nulls( + list: &GenericListArray, +) -> Result { + let buf = match list.values().nulls() { + Some(nulls) => { + let validity = nulls.inner(); + let offsets = list.offsets(); + let mut builder = BooleanBufferBuilder::new(list.len()); + for i in 0..list.len() { + let s = offsets[i].as_usize(); + let len = offsets[i + 1].as_usize() - s; + builder.append(validity.slice(s, len).count_set_bits() < len); + } + builder.finish() + } + None => BooleanBuffer::new_unset(list.len()), + }; + Ok(mask_with_list_nulls(buf, list.nulls())) +} + +/// Rows where the list itself is null should not be marked as "has nulls". +fn mask_with_list_nulls( + buf: BooleanBuffer, + list_nulls: Option<&NullBuffer>, +) -> BooleanBuffer { + match list_nulls { + Some(n) => &buf & n.inner(), + None => buf, + } +} diff --git a/datafusion/spark/src/function/array/mod.rs b/datafusion/spark/src/function/array/mod.rs index 01056ba95298..6c16e0536164 100644 --- a/datafusion/spark/src/function/array/mod.rs +++ b/datafusion/spark/src/function/array/mod.rs @@ -15,27 +15,54 @@ // specific language governing permissions and limitations // under the License. 
+pub mod array_contains; +pub mod repeat; pub mod shuffle; +pub mod slice; pub mod spark_array; use datafusion_expr::ScalarUDF; use datafusion_functions::make_udf_function; use std::sync::Arc; +make_udf_function!(array_contains::SparkArrayContains, spark_array_contains); make_udf_function!(spark_array::SparkArray, array); make_udf_function!(shuffle::SparkShuffle, shuffle); +make_udf_function!(repeat::SparkArrayRepeat, array_repeat); +make_udf_function!(slice::SparkSlice, slice); pub mod expr_fn { use datafusion_functions::export_functions; + export_functions!(( + spark_array_contains, + "Returns true if the array contains the element (Spark semantics).", + array element + )); export_functions!((array, "Returns an array with the given elements.", args)); export_functions!(( shuffle, "Returns a random permutation of the given array.", args )); + export_functions!(( + array_repeat, + "returns an array containing element count times.", + element count + )); + export_functions!(( + slice, + "Returns a slice of the array from the start index with the given length.", + array start length + )); } pub fn functions() -> Vec> { - vec![array(), shuffle()] + vec![ + spark_array_contains(), + array(), + shuffle(), + array_repeat(), + slice(), + ] } diff --git a/datafusion/spark/src/function/array/repeat.rs b/datafusion/spark/src/function/array/repeat.rs new file mode 100644 index 000000000000..7543300a9107 --- /dev/null +++ b/datafusion/spark/src/function/array/repeat.rs @@ -0,0 +1,128 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::{DataType, Field}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, ScalarValue, exec_err}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_functions_nested::repeat::ArrayRepeat; +use std::any::Any; +use std::sync::Arc; + +use crate::function::null_utils::{ + NullMaskResolution, apply_null_mask, compute_null_mask, +}; + +/// Spark-compatible `array_repeat` expression. The difference with DataFusion's `array_repeat` is the handling of NULL inputs: in spark if any input is NULL, the result is NULL. 
+/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkArrayRepeat { + signature: Signature, +} + +impl Default for SparkArrayRepeat { + fn default() -> Self { + Self::new() + } +} + +impl SparkArrayRepeat { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkArrayRepeat { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "array_repeat" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(DataType::List(Arc::new(Field::new_list_field( + arg_types[0].clone(), + true, + )))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + spark_array_repeat(args) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let [first_type, second_type] = take_function_args(self.name(), arg_types)?; + + // Coerce the second argument to Int64/UInt64 if it's a numeric type + let second = match second_type { + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { + DataType::Int64 + } + DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { + DataType::UInt64 + } + _ => return exec_err!("count must be an integer type"), + }; + + Ok(vec![first_type.clone(), second]) + } +} + +/// This is a Spark-specific wrapper around DataFusion's array_repeat that returns NULL +/// if any argument is NULL (Spark behavior), whereas DataFusion's array_repeat ignores NULLs. 
+fn spark_array_repeat(args: ScalarFunctionArgs) -> Result { + let ScalarFunctionArgs { + args: arg_values, + arg_fields, + number_rows, + return_field, + config_options, + } = args; + let return_type = return_field.data_type().clone(); + + // Step 1: Check for NULL mask in incoming args + let null_mask = compute_null_mask(&arg_values, number_rows)?; + + // If any argument is null then return NULL immediately + if matches!(null_mask, NullMaskResolution::ReturnNull) { + return Ok(ColumnarValue::Scalar(ScalarValue::try_from(return_type)?)); + } + + // Step 2: Delegate to DataFusion's array_repeat + let array_repeat_func = ArrayRepeat::new(); + let func_args = ScalarFunctionArgs { + args: arg_values, + arg_fields, + number_rows, + return_field, + config_options, + }; + let result = array_repeat_func.invoke_with_args(func_args)?; + + // Step 3: Apply NULL mask to result + apply_null_mask(result, null_mask, &return_type) +} diff --git a/datafusion/spark/src/function/array/shuffle.rs b/datafusion/spark/src/function/array/shuffle.rs index eaeff6538c32..8051825acc74 100644 --- a/datafusion/spark/src/function/array/shuffle.rs +++ b/datafusion/spark/src/function/array/shuffle.rs @@ -105,11 +105,8 @@ impl ScalarUDFImpl for SparkShuffle { &self, args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - if args.args.is_empty() { - return exec_err!("shuffle expects at least 1 argument"); - } - if args.args.len() > 2 { - return exec_err!("shuffle expects at most 2 arguments"); + if args.args.is_empty() || args.args.len() > 2 { + return exec_err!("shuffle expects 1 or 2 argument(s)"); } // Extract seed from second argument if present @@ -131,10 +128,10 @@ fn extract_seed(seed_arg: &ColumnarValue) -> Result> { ColumnarValue::Scalar(scalar) => { let seed = match scalar { ScalarValue::Int64(Some(v)) => Some(*v as u64), - ScalarValue::Null => None, + ScalarValue::Null | ScalarValue::Int64(None) => None, _ => { return exec_err!( - "shuffle seed must be Int64 type, got '{}'", + 
"shuffle seed must be Int64 type but got '{}'", scalar.data_type() ); } @@ -164,7 +161,10 @@ fn array_shuffle_with_seed(arg: &[ArrayRef], seed: Option) -> Result Ok(Arc::clone(input_array)), - array_type => exec_err!("shuffle does not support type '{array_type}'."), + array_type => exec_err!( + "shuffle does not support type '{array_type}'; \ + expected types: List, LargeList, FixedSizeList or Null." + ), } } diff --git a/datafusion/spark/src/function/array/slice.rs b/datafusion/spark/src/function/array/slice.rs new file mode 100644 index 000000000000..6c168a4f491b --- /dev/null +++ b/datafusion/spark/src/function/array/slice.rs @@ -0,0 +1,172 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::{Array, ArrayRef, Int64Builder}; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::cast::{as_int64_array, as_list_array}; +use datafusion_common::utils::ListCoercion; +use datafusion_common::{Result, exec_err, internal_err, utils::take_function_args}; +use datafusion_expr::{ + ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, ReturnFieldArgs, + ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, +}; +use datafusion_functions_nested::extract::array_slice_udf; +use std::any::Any; +use std::sync::Arc; + +/// Spark slice function implementation +/// Main difference from DataFusion's array_slice is that the third argument is the length of the slice and not the end index. +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkSlice { + signature: Signature, +} + +impl Default for SparkSlice { + fn default() -> Self { + Self::new() + } +} + +impl SparkSlice { + pub fn new() -> Self { + Self { + signature: Signature { + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Index, + ArrayFunctionArgument::Index, + ], + array_coercion: Some(ListCoercion::FixedSizedListToList), + }, + ), + volatility: Volatility::Immutable, + parameter_names: None, + }, + } + } +} + +impl ScalarUDFImpl for SparkSlice { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "slice" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + + Ok(Arc::new(Field::new( + "slice", + args.arg_fields[0].data_type().clone(), + nullable, + ))) + } + + fn invoke_with_args( + &self, + mut func_args: ScalarFunctionArgs, + ) -> 
Result { + let array_len = func_args + .args + .iter() + .find_map(|arg| match arg { + ColumnarValue::Array(array) => Some(array.len()), + _ => None, + }) + .unwrap_or(func_args.number_rows); + + let arrays = func_args + .args + .iter() + .map(|arg| match arg { + ColumnarValue::Array(array) => Ok(Arc::clone(array)), + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(array_len), + }) + .collect::>>()?; + + let (start, end) = calculate_start_end(&arrays)?; + + array_slice_udf().invoke_with_args(ScalarFunctionArgs { + args: vec![ + func_args.args.swap_remove(0), + ColumnarValue::Array(start), + ColumnarValue::Array(end), + ], + arg_fields: func_args.arg_fields, + number_rows: func_args.number_rows, + return_field: func_args.return_field, + config_options: func_args.config_options, + }) + } +} + +fn calculate_start_end(args: &[ArrayRef]) -> Result<(ArrayRef, ArrayRef)> { + let [values, start, length] = take_function_args("slice", args)?; + + let values_len = values.len(); + + let start = as_int64_array(&start)?; + let length = as_int64_array(&length)?; + + let values = as_list_array(values)?; + + let mut adjusted_start = Int64Builder::with_capacity(values_len); + let mut end = Int64Builder::with_capacity(values_len); + + for row in 0..values_len { + if values.is_null(row) || start.is_null(row) || length.is_null(row) { + adjusted_start.append_null(); + end.append_null(); + continue; + } + let start = start.value(row); + let length = length.value(row); + let value_length = values.value(row).len() as i64; + + if start == 0 { + return exec_err!("Start index must not be zero"); + } + if length < 0 { + return exec_err!("Length must be non-negative, but got {}", length); + } + + let adjusted_start_value = if start < 0 { + start + value_length + 1 + } else { + start + }; + + adjusted_start.append_value(adjusted_start_value); + end.append_value(adjusted_start_value + (length - 1)); + } + + Ok((Arc::new(adjusted_start.finish()), Arc::new(end.finish()))) +} diff --git 
a/datafusion/spark/src/function/array/spark_array.rs b/datafusion/spark/src/function/array/spark_array.rs index 6d9f9a1695e1..1ad0a394b8ca 100644 --- a/datafusion/spark/src/function/array/spark_array.rs +++ b/datafusion/spark/src/function/array/spark_array.rs @@ -23,7 +23,7 @@ use datafusion_common::utils::SingleRowListArrayBuilder; use datafusion_common::{Result, internal_err}; use datafusion_expr::{ ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, - TypeSignature, Volatility, + Volatility, }; use datafusion_functions_nested::make_array::{array_array, coerce_types_inner}; @@ -45,10 +45,7 @@ impl Default for SparkArray { impl SparkArray { pub fn new() -> Self { Self { - signature: Signature::one_of( - vec![TypeSignature::UserDefined, TypeSignature::Nullary], - Volatility::Immutable, - ), + signature: Signature::user_defined(Volatility::Immutable), } } } @@ -104,12 +101,12 @@ impl ScalarUDFImpl for SparkArray { make_scalar_function(make_array_inner)(args.as_slice()) } - fn aliases(&self) -> &[String] { - &[] - } - fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - coerce_types_inner(arg_types, self.name()) + if arg_types.is_empty() { + Ok(vec![]) + } else { + coerce_types_inner(arg_types, self.name()) + } } } diff --git a/datafusion/spark/src/function/bitmap/bitmap_bit_position.rs b/datafusion/spark/src/function/bitmap/bitmap_bit_position.rs new file mode 100644 index 000000000000..262dc07f2704 --- /dev/null +++ b/datafusion/spark/src/function/bitmap/bitmap_bit_position.rs @@ -0,0 +1,143 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, AsArray, Int64Array}; +use arrow::datatypes::Field; +use arrow::datatypes::{DataType, FieldRef, Int8Type, Int16Type, Int32Type, Int64Type}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, internal_err}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + Volatility, +}; +use datafusion_functions::utils::make_scalar_function; +use std::any::Any; +use std::sync::Arc; + +/// Spark-compatible `bitmap_bit_position` expression +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct BitmapBitPosition { + signature: Signature, +} + +impl Default for BitmapBitPosition { + fn default() -> Self { + Self::new() + } +} + +impl BitmapBitPosition { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Exact(vec![DataType::Int8]), + TypeSignature::Exact(vec![DataType::Int16]), + TypeSignature::Exact(vec![DataType::Int32]), + TypeSignature::Exact(vec![DataType::Int64]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for BitmapBitPosition { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "bitmap_bit_position" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args( + &self, + args: datafusion_expr::ReturnFieldArgs, + ) -> Result { + Ok(Arc::new(Field::new( + self.name(), + DataType::Int64, + 
args.arg_fields[0].is_nullable(), + ))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(bitmap_bit_position_inner, vec![])(&args.args) + } +} + +pub fn bitmap_bit_position_inner(arg: &[ArrayRef]) -> Result { + let [array] = take_function_args("bitmap_bit_position", arg)?; + match &array.data_type() { + DataType::Int8 => { + let result: Int64Array = array + .as_primitive::() + .iter() + .map(|opt| opt.map(|value| bitmap_bit_position(value.into()))) + .collect(); + Ok(Arc::new(result)) + } + DataType::Int16 => { + let result: Int64Array = array + .as_primitive::() + .iter() + .map(|opt| opt.map(|value| bitmap_bit_position(value.into()))) + .collect(); + Ok(Arc::new(result)) + } + DataType::Int32 => { + let result: Int64Array = array + .as_primitive::() + .iter() + .map(|opt| opt.map(|value| bitmap_bit_position(value.into()))) + .collect(); + Ok(Arc::new(result)) + } + DataType::Int64 => { + let result: Int64Array = array + .as_primitive::() + .iter() + .map(|opt| opt.map(bitmap_bit_position)) + .collect(); + Ok(Arc::new(result)) + } + data_type => { + internal_err!("bitmap_bit_position does not support {data_type}") + } + } +} + +const NUM_BYTES: i64 = 4 * 1024; +const NUM_BITS: i64 = NUM_BYTES * 8; + +fn bitmap_bit_position(value: i64) -> i64 { + if value > 0 { + (value - 1) % NUM_BITS + } else { + (value.wrapping_neg()) % NUM_BITS + } +} diff --git a/datafusion/spark/src/function/bitmap/bitmap_bucket_number.rs b/datafusion/spark/src/function/bitmap/bitmap_bucket_number.rs new file mode 100644 index 000000000000..9686d1acd883 --- /dev/null +++ b/datafusion/spark/src/function/bitmap/bitmap_bucket_number.rs @@ -0,0 +1,143 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, AsArray, Int64Array}; +use arrow::datatypes::Field; +use arrow::datatypes::{DataType, FieldRef, Int8Type, Int16Type, Int32Type, Int64Type}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, internal_err}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + Volatility, +}; +use datafusion_functions::utils::make_scalar_function; +use std::any::Any; +use std::sync::Arc; + +/// Spark-compatible `bitmap_bucket_number` expression +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct BitmapBucketNumber { + signature: Signature, +} + +impl Default for BitmapBucketNumber { + fn default() -> Self { + Self::new() + } +} + +impl BitmapBucketNumber { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Exact(vec![DataType::Int8]), + TypeSignature::Exact(vec![DataType::Int16]), + TypeSignature::Exact(vec![DataType::Int32]), + TypeSignature::Exact(vec![DataType::Int64]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for BitmapBucketNumber { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "bitmap_bucket_number" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used 
instead") + } + + fn return_field_from_args( + &self, + args: datafusion_expr::ReturnFieldArgs, + ) -> Result { + Ok(Arc::new(Field::new( + self.name(), + DataType::Int64, + args.arg_fields[0].is_nullable(), + ))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(bitmap_bucket_number_inner, vec![])(&args.args) + } +} + +pub fn bitmap_bucket_number_inner(arg: &[ArrayRef]) -> Result { + let [array] = take_function_args("bitmap_bucket_number", arg)?; + match &array.data_type() { + DataType::Int8 => { + let result: Int64Array = array + .as_primitive::() + .iter() + .map(|opt| opt.map(|value| bitmap_bucket_number(value.into()))) + .collect(); + Ok(Arc::new(result)) + } + DataType::Int16 => { + let result: Int64Array = array + .as_primitive::() + .iter() + .map(|opt| opt.map(|value| bitmap_bucket_number(value.into()))) + .collect(); + Ok(Arc::new(result)) + } + DataType::Int32 => { + let result: Int64Array = array + .as_primitive::() + .iter() + .map(|opt| opt.map(|value| bitmap_bucket_number(value.into()))) + .collect(); + Ok(Arc::new(result)) + } + DataType::Int64 => { + let result: Int64Array = array + .as_primitive::() + .iter() + .map(|opt| opt.map(bitmap_bucket_number)) + .collect(); + Ok(Arc::new(result)) + } + data_type => { + internal_err!("bitmap_bucket_number does not support {data_type}") + } + } +} + +const NUM_BYTES: i64 = 4 * 1024; +const NUM_BITS: i64 = NUM_BYTES * 8; + +fn bitmap_bucket_number(value: i64) -> i64 { + if value > 0 { + 1 + (value - 1) / NUM_BITS + } else { + value / NUM_BITS + } +} diff --git a/datafusion/spark/src/function/bitmap/mod.rs b/datafusion/spark/src/function/bitmap/mod.rs index 8532c32ac9c5..4992992aeae8 100644 --- a/datafusion/spark/src/function/bitmap/mod.rs +++ b/datafusion/spark/src/function/bitmap/mod.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. 
+pub mod bitmap_bit_position; +pub mod bitmap_bucket_number; pub mod bitmap_count; use datafusion_expr::ScalarUDF; @@ -22,6 +24,11 @@ use datafusion_functions::make_udf_function; use std::sync::Arc; make_udf_function!(bitmap_count::BitmapCount, bitmap_count); +make_udf_function!(bitmap_bit_position::BitmapBitPosition, bitmap_bit_position); +make_udf_function!( + bitmap_bucket_number::BitmapBucketNumber, + bitmap_bucket_number +); pub mod expr_fn { use datafusion_functions::export_functions; @@ -31,8 +38,22 @@ pub mod expr_fn { "Returns the number of set bits in the input bitmap.", arg )); + export_functions!(( + bitmap_bit_position, + "Returns the bit position for the given input child expression.", + arg + )); + export_functions!(( + bitmap_bucket_number, + "Returns the bucket number for the given input child expression.", + arg + )); } pub fn functions() -> Vec> { - vec![bitmap_count()] + vec![ + bitmap_count(), + bitmap_bit_position(), + bitmap_bucket_number(), + ] } diff --git a/datafusion/spark/src/function/bitwise/bitwise_not.rs b/datafusion/spark/src/function/bitwise/bitwise_not.rs index 5f8cf36911f4..e7285d480495 100644 --- a/datafusion/spark/src/function/bitwise/bitwise_not.rs +++ b/datafusion/spark/src/function/bitwise/bitwise_not.rs @@ -73,25 +73,11 @@ impl ScalarUDFImpl for SparkBitwiseNot { } fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { - if args.arg_fields.len() != 1 { - return plan_err!("bitwise_not expects exactly 1 argument"); - } - - let input_field = &args.arg_fields[0]; - - let out_dt = input_field.data_type().clone(); - let mut out_nullable = input_field.is_nullable(); - - let scalar_null_present = args - .scalar_arguments - .iter() - .any(|opt_s| opt_s.is_some_and(|sv| sv.is_null())); - - if scalar_null_present { - out_nullable = true; - } - - Ok(Arc::new(Field::new(self.name(), out_dt, out_nullable))) + Ok(Arc::new(Field::new( + self.name(), + args.arg_fields[0].data_type().clone(), + args.arg_fields[0].is_nullable(), + 
))) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -196,32 +182,4 @@ mod tests { assert!(out_i64_null.is_nullable()); assert_eq!(out_i64_null.data_type(), &DataType::Int64); } - - #[test] - fn test_bitwise_not_nullability_with_null_scalar() -> Result<()> { - use arrow::datatypes::{DataType, Field}; - use datafusion_common::ScalarValue; - use std::sync::Arc; - - let func = SparkBitwiseNot::new(); - - let non_nullable: FieldRef = Arc::new(Field::new("col", DataType::Int32, false)); - - let out = func.return_field_from_args(ReturnFieldArgs { - arg_fields: &[Arc::clone(&non_nullable)], - scalar_arguments: &[None], - })?; - assert!(!out.is_nullable()); - assert_eq!(out.data_type(), &DataType::Int32); - - let null_scalar = ScalarValue::Int32(None); - let out_with_null_scalar = func.return_field_from_args(ReturnFieldArgs { - arg_fields: &[Arc::clone(&non_nullable)], - scalar_arguments: &[Some(&null_scalar)], - })?; - assert!(out_with_null_scalar.is_nullable()); - assert_eq!(out_with_null_scalar.data_type(), &DataType::Int32); - - Ok(()) - } } diff --git a/datafusion/spark/src/function/collection/mod.rs b/datafusion/spark/src/function/collection/mod.rs index a87df9a2c87a..6871e3aba646 100644 --- a/datafusion/spark/src/function/collection/mod.rs +++ b/datafusion/spark/src/function/collection/mod.rs @@ -15,11 +15,20 @@ // specific language governing permissions and limitations // under the License. 
+pub mod size; + use datafusion_expr::ScalarUDF; +use datafusion_functions::make_udf_function; use std::sync::Arc; -pub mod expr_fn {} +make_udf_function!(size::SparkSize, size); + +pub mod expr_fn { + use datafusion_functions::export_functions; + + export_functions!((size, "Return the size of an array or map.", arg)); +} pub fn functions() -> Vec> { - vec![] + vec![size()] } diff --git a/datafusion/spark/src/function/collection/size.rs b/datafusion/spark/src/function/collection/size.rs new file mode 100644 index 000000000000..05b8ba315675 --- /dev/null +++ b/datafusion/spark/src/function/collection/size.rs @@ -0,0 +1,162 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, ArrayRef, AsArray, Int32Array}; +use arrow::compute::kernels::length::length as arrow_length; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::{Result, plan_err}; +use datafusion_expr::{ + ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, ReturnFieldArgs, + ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, +}; +use datafusion_functions::utils::make_scalar_function; +use std::any::Any; +use std::sync::Arc; + +/// Spark-compatible `size` function. 
+/// +/// Returns the number of elements in an array or the number of key-value pairs in a map. +/// Returns -1 for null input (Spark behavior). +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkSize { + signature: Signature, +} + +impl Default for SparkSize { + fn default() -> Self { + Self::new() + } +} + +impl SparkSize { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + // Array Type + TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ArrayFunctionArgument::Array], + array_coercion: None, + }), + // Map Type + TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkSize { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "size" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int32) + } + + fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result { + // nullable=false for legacy behavior (NULL -> -1); set to input nullability for null-on-null + Ok(Arc::new(Field::new(self.name(), DataType::Int32, false))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(spark_size_inner, vec![])(&args.args) + } +} + +fn spark_size_inner(args: &[ArrayRef]) -> Result { + let array = &args[0]; + + match array.data_type() { + DataType::List(_) => { + if array.null_count() == 0 { + Ok(arrow_length(array)?) + } else { + let list_array = array.as_list::(); + let lengths: Vec = list_array + .offsets() + .lengths() + .enumerate() + .map(|(i, len)| if array.is_null(i) { -1 } else { len as i32 }) + .collect(); + Ok(Arc::new(Int32Array::from(lengths))) + } + } + DataType::FixedSizeList(_, size) => { + if array.null_count() == 0 { + Ok(arrow_length(array)?) 
+ } else { + let length: Vec = (0..array.len()) + .map(|i| if array.is_null(i) { -1 } else { *size }) + .collect(); + Ok(Arc::new(Int32Array::from(length))) + } + } + DataType::LargeList(_) => { + // Arrow length kernel returns Int64 for LargeList + let list_array = array.as_list::(); + if array.null_count() == 0 { + let lengths: Vec = list_array + .offsets() + .lengths() + .map(|len| len as i32) + .collect(); + Ok(Arc::new(Int32Array::from(lengths))) + } else { + let lengths: Vec = list_array + .offsets() + .lengths() + .enumerate() + .map(|(i, len)| if array.is_null(i) { -1 } else { len as i32 }) + .collect(); + Ok(Arc::new(Int32Array::from(lengths))) + } + } + DataType::Map(_, _) => { + let map_array = array.as_map(); + let length: Vec = if array.null_count() == 0 { + map_array + .offsets() + .lengths() + .map(|len| len as i32) + .collect() + } else { + map_array + .offsets() + .lengths() + .enumerate() + .map(|(i, len)| if array.is_null(i) { -1 } else { len as i32 }) + .collect() + }; + Ok(Arc::new(Int32Array::from(length))) + } + DataType::Null => Ok(Arc::new(Int32Array::from(vec![-1; array.len()]))), + dt => { + plan_err!("size function does not support type: {}", dt) + } + } +} diff --git a/datafusion/spark/src/function/conditional/if.rs b/datafusion/spark/src/function/conditional/if.rs index 906b0bc312f2..e423f8264ecc 100644 --- a/datafusion/spark/src/function/conditional/if.rs +++ b/datafusion/spark/src/function/conditional/if.rs @@ -86,7 +86,7 @@ impl ScalarUDFImpl for SparkIf { fn simplify( &self, args: Vec, - _info: &dyn datafusion_expr::simplify::SimplifyInfo, + _info: &datafusion_expr::simplify::SimplifyContext, ) -> Result { let condition = args[0].clone(); let then_expr = args[1].clone(); diff --git a/datafusion/spark/src/function/datetime/add_months.rs b/datafusion/spark/src/function/datetime/add_months.rs new file mode 100644 index 000000000000..fa9f6fa8db94 --- /dev/null +++ b/datafusion/spark/src/function/datetime/add_months.rs @@ -0,0 +1,95 @@ 
+// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::ops::Add; +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field, FieldRef, IntervalUnit}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, internal_err}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{ + ColumnarValue, Expr, ExprSchemable, ReturnFieldArgs, ScalarFunctionArgs, + ScalarUDFImpl, Signature, Volatility, +}; + +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkAddMonths { + signature: Signature, +} + +impl Default for SparkAddMonths { + fn default() -> Self { + Self::new() + } +} + +impl SparkAddMonths { + pub fn new() -> Self { + Self { + signature: Signature::exact( + vec![DataType::Date32, DataType::Int32], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkAddMonths { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "add_months" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: 
ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + + Ok(Arc::new(Field::new( + self.name(), + DataType::Date32, + nullable, + ))) + } + + fn simplify( + &self, + args: Vec, + info: &SimplifyContext, + ) -> Result { + let [date_arg, months_arg] = take_function_args("add_months", args)?; + let interval = months_arg + .cast_to(&DataType::Interval(IntervalUnit::YearMonth), info.schema())?; + Ok(ExprSimplifyResult::Simplified(date_arg.add(interval))) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!("invoke should not be called on a simplified add_months() function") + } +} diff --git a/datafusion/spark/src/function/datetime/date_add.rs b/datafusion/spark/src/function/datetime/date_add.rs index 78b9c904cee3..3745f77969f2 100644 --- a/datafusion/spark/src/function/datetime/date_add.rs +++ b/datafusion/spark/src/function/datetime/date_add.rs @@ -82,12 +82,7 @@ impl ScalarUDFImpl for SparkDateAdd { } fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { - let nullable = args.arg_fields.iter().any(|f| f.is_nullable()) - || args - .scalar_arguments - .iter() - .any(|arg| matches!(arg, Some(sv) if sv.is_null())); - + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); Ok(Arc::new(Field::new( self.name(), DataType::Date32, @@ -142,7 +137,6 @@ fn spark_date_add(args: &[ArrayRef]) -> Result { mod tests { use super::*; use arrow::datatypes::Field; - use datafusion_common::ScalarValue; #[test] fn test_date_add_non_nullable_inputs() { @@ -181,25 +175,4 @@ mod tests { assert_eq!(ret_field.data_type(), &DataType::Date32); assert!(ret_field.is_nullable()); } - - #[test] - fn test_date_add_null_scalar() { - let func = SparkDateAdd::new(); - let args = &[ - Arc::new(Field::new("date", DataType::Date32, false)), - Arc::new(Field::new("num", DataType::Int32, false)), - ]; - - let null_scalar = ScalarValue::Int32(None); - - let ret_field = func - 
.return_field_from_args(ReturnFieldArgs { - arg_fields: args, - scalar_arguments: &[None, Some(&null_scalar)], - }) - .unwrap(); - - assert_eq!(ret_field.data_type(), &DataType::Date32); - assert!(ret_field.is_nullable()); - } } diff --git a/datafusion/spark/src/function/datetime/date_diff.rs b/datafusion/spark/src/function/datetime/date_diff.rs new file mode 100644 index 000000000000..094c35eec56b --- /dev/null +++ b/datafusion/spark/src/function/datetime/date_diff.rs @@ -0,0 +1,119 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::any::Any; +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::types::{NativeType, logical_date, logical_string}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, internal_err}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{ + Coercion, ColumnarValue, Expr, ExprSchemable, Operator, ReturnFieldArgs, + ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, + binary_expr, +}; + +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkDateDiff { + signature: Signature, + aliases: Vec, +} + +impl Default for SparkDateDiff { + fn default() -> Self { + Self::new() + } +} + +impl SparkDateDiff { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![ + Coercion::new_implicit( + TypeSignatureClass::Native(logical_date()), + vec![ + TypeSignatureClass::Native(logical_string()), + TypeSignatureClass::Timestamp, + ], + NativeType::Date, + ), + Coercion::new_implicit( + TypeSignatureClass::Native(logical_date()), + vec![ + TypeSignatureClass::Native(logical_string()), + TypeSignatureClass::Timestamp, + ], + NativeType::Date, + ), + ], + Volatility::Immutable, + ), + aliases: vec!["datediff".to_string()], + } + } +} + +impl ScalarUDFImpl for SparkDateDiff { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "date_diff" + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + Ok(Arc::new(Field::new(self.name(), DataType::Int32, nullable))) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!( + "Apache 
Spark `date_diff` should have been simplified to standard subtraction" + ) + } + + fn simplify( + &self, + args: Vec, + info: &SimplifyContext, + ) -> Result { + let [end, start] = take_function_args(self.name(), args)?; + let end = end.cast_to(&DataType::Date32, info.schema())?; + let start = start.cast_to(&DataType::Date32, info.schema())?; + Ok(ExprSimplifyResult::Simplified( + binary_expr(end, Operator::Minus, start) + .cast_to(&DataType::Int32, info.schema())?, + )) + } +} diff --git a/datafusion/spark/src/function/datetime/date_part.rs b/datafusion/spark/src/function/datetime/date_part.rs new file mode 100644 index 000000000000..e30a162ef42d --- /dev/null +++ b/datafusion/spark/src/function/datetime/date_part.rs @@ -0,0 +1,139 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::types::logical_date; +use datafusion_common::{ + Result, ScalarValue, internal_err, types::logical_string, utils::take_function_args, +}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{ + Coercion, ColumnarValue, Expr, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, + Signature, TypeSignature, TypeSignatureClass, Volatility, +}; +use std::{any::Any, sync::Arc}; + +/// Wrapper around datafusion date_part function to handle +/// Spark behavior returning day of the week 1-indexed instead of 0-indexed and different part aliases. +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkDatePart { + signature: Signature, + aliases: Vec, +} + +impl Default for SparkDatePart { + fn default() -> Self { + Self::new() + } +} + +impl SparkDatePart { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Coercible(vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Timestamp), + ]), + TypeSignature::Coercible(vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Native(logical_date())), + ]), + ], + Volatility::Immutable, + ), + aliases: vec![String::from("datepart")], + } + } +} + +impl ScalarUDFImpl for SparkDatePart { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "date_part" + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("Use return_field_from_args in this case instead.") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + + Ok(Arc::new(Field::new(self.name(), 
DataType::Int32, nullable))) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!("spark date_part should have been simplified to standard date_part") + } + + fn simplify( + &self, + args: Vec, + _info: &SimplifyContext, + ) -> Result { + let [part_expr, date_expr] = take_function_args(self.name(), args)?; + + let part = match part_expr.as_literal() { + Some(ScalarValue::Utf8(Some(v))) + | Some(ScalarValue::Utf8View(Some(v))) + | Some(ScalarValue::LargeUtf8(Some(v))) => v.to_lowercase(), + _ => { + return internal_err!( + "First argument of `DATE_PART` must be non-null scalar Utf8" + ); + } + }; + + // Map Spark-specific date part aliases to datafusion ones + let part = match part.as_str() { + "yearofweek" | "year_iso" => "isoyear", + "dayofweek" => "dow", + "dayofweek_iso" | "dow_iso" => "isodow", + other => other, + }; + + let part_expr = Expr::Literal(ScalarValue::new_utf8(part), None); + + let date_part_expr = Expr::ScalarFunction(ScalarFunction::new_udf( + datafusion_functions::datetime::date_part(), + vec![part_expr, date_expr], + )); + + match part { + // Add 1 for day-of-week parts to convert 0-indexed to 1-indexed + "dow" | "isodow" => Ok(ExprSimplifyResult::Simplified( + date_part_expr + Expr::Literal(ScalarValue::Int32(Some(1)), None), + )), + _ => Ok(ExprSimplifyResult::Simplified(date_part_expr)), + } + } +} diff --git a/datafusion/spark/src/function/datetime/date_sub.rs b/datafusion/spark/src/function/datetime/date_sub.rs index 34894317f67d..af1b8d5a4e91 100644 --- a/datafusion/spark/src/function/datetime/date_sub.rs +++ b/datafusion/spark/src/function/datetime/date_sub.rs @@ -75,12 +75,7 @@ impl ScalarUDFImpl for SparkDateSub { } fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { - let nullable = args.arg_fields.iter().any(|f| f.is_nullable()) - || args - .scalar_arguments - .iter() - .any(|arg| matches!(arg, Some(sv) if sv.is_null())); - + let nullable = args.arg_fields.iter().any(|f| 
f.is_nullable()); Ok(Arc::new(Field::new( self.name(), DataType::Date32, @@ -139,7 +134,6 @@ fn spark_date_sub(args: &[ArrayRef]) -> Result { #[cfg(test)] mod tests { use super::*; - use datafusion_common::ScalarValue; #[test] fn test_date_sub_nullability_non_nullable_args() { @@ -174,22 +168,4 @@ mod tests { assert!(result.is_nullable()); assert_eq!(result.data_type(), &DataType::Date32); } - - #[test] - fn test_date_sub_nullability_scalar_null_argument() { - let udf = SparkDateSub::new(); - let date_field = Arc::new(Field::new("d", DataType::Date32, false)); - let days_field = Arc::new(Field::new("n", DataType::Int32, false)); - let null_scalar = ScalarValue::Int32(None); - - let result = udf - .return_field_from_args(ReturnFieldArgs { - arg_fields: &[date_field, days_field], - scalar_arguments: &[None, Some(&null_scalar)], - }) - .unwrap(); - - assert!(result.is_nullable()); - assert_eq!(result.data_type(), &DataType::Date32); - } } diff --git a/datafusion/spark/src/function/datetime/date_trunc.rs b/datafusion/spark/src/function/datetime/date_trunc.rs new file mode 100644 index 000000000000..2199c90703b3 --- /dev/null +++ b/datafusion/spark/src/function/datetime/date_trunc.rs @@ -0,0 +1,172 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field, FieldRef, TimeUnit}; +use datafusion_common::types::{NativeType, logical_string}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, ScalarValue, internal_err, plan_err}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{ + Coercion, ColumnarValue, Expr, ExprSchemable, ReturnFieldArgs, ScalarFunctionArgs, + ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, +}; + +/// Spark date_trunc supports extra format aliases. +/// It also handles timestamps with timezones by converting to session timezone first. +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkDateTrunc { + signature: Signature, +} + +impl Default for SparkDateTrunc { + fn default() -> Self { + Self::new() + } +} + +impl SparkDateTrunc { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_implicit( + TypeSignatureClass::Timestamp, + vec![TypeSignatureClass::Native(logical_string())], + NativeType::Timestamp(TimeUnit::Microsecond, None), + ), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkDateTrunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "date_trunc" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + + Ok(Arc::new(Field::new( + self.name(), + args.arg_fields[1].data_type().clone(), + nullable, + ))) + } + + fn 
invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!( + "spark date_trunc should have been simplified to standard date_trunc" + ) + } + + fn simplify( + &self, + args: Vec, + info: &SimplifyContext, + ) -> Result { + let [fmt_expr, ts_expr] = take_function_args(self.name(), args)?; + + let fmt = match fmt_expr.as_literal() { + Some(ScalarValue::Utf8(Some(v))) + | Some(ScalarValue::Utf8View(Some(v))) + | Some(ScalarValue::LargeUtf8(Some(v))) => v.to_lowercase(), + _ => { + return plan_err!( + "First argument of `DATE_TRUNC` must be non-null scalar Utf8" + ); + } + }; + + // Map Spark-specific fmt aliases to datafusion ones + let fmt = match fmt.as_str() { + "yy" | "yyyy" => "year", + "mm" | "mon" => "month", + "dd" => "day", + other => other, + }; + + let session_tz = info.config_options().execution.time_zone.clone(); + let ts_type = ts_expr.get_type(info.schema())?; + + // Spark interprets timestamps in the session timezone before truncating, + // then returns a timestamp at microsecond precision. + // See: https://github.com/apache/spark/blob/f310f4fcc95580a6824bc7d22b76006f79b8804a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala#L492 + // + // For sub-second truncations (second, millisecond, microsecond), timezone + // adjustment is unnecessary since timezone offsets are whole seconds. 
+ let ts_expr = match (&ts_type, fmt) { + // Sub-second truncations don't need timezone adjustment + (_, "second" | "millisecond" | "microsecond") => ts_expr, + + // convert to session timezone, strip timezone and convert back to original timezone + (DataType::Timestamp(unit, tz), _) => { + let ts_expr = match &session_tz { + Some(session_tz) => ts_expr.cast_to( + &DataType::Timestamp( + TimeUnit::Microsecond, + Some(Arc::from(session_tz.as_str())), + ), + info.schema(), + )?, + None => ts_expr, + }; + Expr::ScalarFunction(ScalarFunction::new_udf( + datafusion_functions::datetime::to_local_time(), + vec![ts_expr], + )) + .cast_to(&DataType::Timestamp(*unit, tz.clone()), info.schema())? + } + + _ => { + return plan_err!( + "Second argument of `DATE_TRUNC` must be Timestamp, got {}", + ts_type + ); + } + }; + + let fmt_expr = Expr::Literal(ScalarValue::new_utf8(fmt), None); + + Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction( + ScalarFunction::new_udf( + datafusion_functions::datetime::date_trunc(), + vec![fmt_expr, ts_expr], + ), + ))) + } +} diff --git a/datafusion/spark/src/function/datetime/from_utc_timestamp.rs b/datafusion/spark/src/function/datetime/from_utc_timestamp.rs new file mode 100644 index 000000000000..77cc66da5f37 --- /dev/null +++ b/datafusion/spark/src/function/datetime/from_utc_timestamp.rs @@ -0,0 +1,195 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::timezone::Tz; +use arrow::array::{Array, ArrayRef, AsArray, PrimitiveBuilder, StringArrayType}; +use arrow::datatypes::TimeUnit; +use arrow::datatypes::{ + ArrowTimestampType, DataType, Field, FieldRef, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, +}; +use datafusion_common::types::{NativeType, logical_string}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, exec_datafusion_err, exec_err, internal_err}; +use datafusion_expr::{ + Coercion, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, + Signature, TypeSignatureClass, Volatility, +}; +use datafusion_functions::datetime::to_local_time::adjust_to_local_time; +use datafusion_functions::utils::make_scalar_function; + +/// Apache Spark `from_utc_timestamp` function. +/// +/// Interprets the given timestamp as UTC and converts it to the given timezone. +/// +/// Timestamp in Apache Spark represents number of microseconds from the Unix epoch, which is not +/// timezone-agnostic. So in Apache Spark this function just shift the timestamp value from UTC timezone to +/// the given timezone. 
+/// +/// See +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkFromUtcTimestamp { + signature: Signature, +} + +impl Default for SparkFromUtcTimestamp { + fn default() -> Self { + Self::new() + } +} + +impl SparkFromUtcTimestamp { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![ + Coercion::new_implicit( + TypeSignatureClass::Timestamp, + vec![TypeSignatureClass::Native(logical_string())], + NativeType::Timestamp(TimeUnit::Microsecond, None), + ), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkFromUtcTimestamp { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "from_utc_timestamp" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + + Ok(Arc::new(Field::new( + self.name(), + args.arg_fields[0].data_type().clone(), + nullable, + ))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(spark_from_utc_timestamp, vec![])(&args.args) + } +} + +fn spark_from_utc_timestamp(args: &[ArrayRef]) -> Result { + let [timestamp, timezone] = take_function_args("from_utc_timestamp", args)?; + + match timestamp.data_type() { + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { + process_timestamp_with_tz_array::( + timestamp, + timezone, + tz_opt.clone(), + ) + } + DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { + process_timestamp_with_tz_array::( + timestamp, + timezone, + tz_opt.clone(), + ) + } + DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { + process_timestamp_with_tz_array::( + timestamp, + timezone, + tz_opt.clone(), + ) + } + DataType::Timestamp(TimeUnit::Second, tz_opt) => 
{ + process_timestamp_with_tz_array::( + timestamp, + timezone, + tz_opt.clone(), + ) + } + ts_type => { + exec_err!("`from_utc_timestamp`: unsupported argument types: {ts_type}") + } + } +} + +fn process_timestamp_with_tz_array( + ts_array: &ArrayRef, + tz_array: &ArrayRef, + tz_opt: Option>, +) -> Result { + match tz_array.data_type() { + DataType::Utf8 => { + process_arrays::(tz_opt, ts_array, tz_array.as_string::()) + } + DataType::LargeUtf8 => { + process_arrays::(tz_opt, ts_array, tz_array.as_string::()) + } + DataType::Utf8View => { + process_arrays::(tz_opt, ts_array, tz_array.as_string_view()) + } + other => { + exec_err!("`from_utc_timestamp`: timezone must be a string type, got {other}") + } + } +} + +fn process_arrays<'a, T: ArrowTimestampType, S>( + return_tz_opt: Option>, + ts_array: &ArrayRef, + tz_array: &'a S, +) -> Result +where + &'a S: StringArrayType<'a>, +{ + let ts_primitive = ts_array.as_primitive::(); + let mut builder = PrimitiveBuilder::::with_capacity(ts_array.len()); + + for (ts_opt, tz_opt) in ts_primitive.iter().zip(tz_array.iter()) { + match (ts_opt, tz_opt) { + (Some(ts), Some(tz_str)) => { + let tz: Tz = tz_str.parse().map_err(|e| { + exec_datafusion_err!( + "`from_utc_timestamp`: invalid timezone '{tz_str}': {e}" + ) + })?; + let val = adjust_to_local_time::(ts, tz)?; + builder.append_value(val); + } + _ => builder.append_null(), + } + } + + builder = builder.with_timezone_opt(return_tz_opt); + Ok(Arc::new(builder.finish())) +} diff --git a/datafusion/spark/src/function/datetime/last_day.rs b/datafusion/spark/src/function/datetime/last_day.rs index 40834ec345df..4c6f731db18a 100644 --- a/datafusion/spark/src/function/datetime/last_day.rs +++ b/datafusion/spark/src/function/datetime/last_day.rs @@ -114,7 +114,11 @@ impl ScalarUDFImpl for SparkLastDay { } fn spark_last_day(days: i32) -> Result { - let date = Date32Type::to_naive_date(days); + let date = Date32Type::to_naive_date_opt(days).ok_or_else(|| { + exec_datafusion_err!( + 
"Spark `last_day`: Unable to convert days value {days} to date" + ) + })?; let (year, month) = (date.year(), date.month()); let (next_year, next_month) = if month == 12 { diff --git a/datafusion/spark/src/function/datetime/mod.rs b/datafusion/spark/src/function/datetime/mod.rs index 849aa2089599..3133ed7337f2 100644 --- a/datafusion/spark/src/function/datetime/mod.rs +++ b/datafusion/spark/src/function/datetime/mod.rs @@ -15,20 +15,37 @@ // specific language governing permissions and limitations // under the License. +pub mod add_months; pub mod date_add; +pub mod date_diff; +pub mod date_part; pub mod date_sub; +pub mod date_trunc; pub mod extract; +pub mod from_utc_timestamp; pub mod last_day; pub mod make_dt_interval; pub mod make_interval; pub mod next_day; +pub mod time_trunc; +pub mod to_utc_timestamp; +pub mod trunc; +pub mod unix; use datafusion_expr::ScalarUDF; use datafusion_functions::make_udf_function; use std::sync::Arc; +make_udf_function!(add_months::SparkAddMonths, add_months); make_udf_function!(date_add::SparkDateAdd, date_add); +make_udf_function!(date_diff::SparkDateDiff, date_diff); +make_udf_function!(date_part::SparkDatePart, date_part); make_udf_function!(date_sub::SparkDateSub, date_sub); +make_udf_function!(date_trunc::SparkDateTrunc, date_trunc); +make_udf_function!( + from_utc_timestamp::SparkFromUtcTimestamp, + from_utc_timestamp +); make_udf_function!(extract::SparkHour, hour); make_udf_function!(extract::SparkMinute, minute); make_udf_function!(extract::SparkSecond, second); @@ -36,10 +53,34 @@ make_udf_function!(last_day::SparkLastDay, last_day); make_udf_function!(make_dt_interval::SparkMakeDtInterval, make_dt_interval); make_udf_function!(make_interval::SparkMakeInterval, make_interval); make_udf_function!(next_day::SparkNextDay, next_day); +make_udf_function!(time_trunc::SparkTimeTrunc, time_trunc); +make_udf_function!(to_utc_timestamp::SparkToUtcTimestamp, to_utc_timestamp); +make_udf_function!(trunc::SparkTrunc, trunc); 
+make_udf_function!(unix::SparkUnixDate, unix_date); +make_udf_function!( + unix::SparkUnixTimestamp, + unix_micros, + unix::SparkUnixTimestamp::microseconds +); +make_udf_function!( + unix::SparkUnixTimestamp, + unix_millis, + unix::SparkUnixTimestamp::milliseconds +); +make_udf_function!( + unix::SparkUnixTimestamp, + unix_seconds, + unix::SparkUnixTimestamp::seconds +); pub mod expr_fn { use datafusion_functions::export_functions; + export_functions!(( + add_months, + "Returns the date that is months months after start. The function returns NULL if at least one of the input parameters is NULL.", + arg1 arg2 + )); export_functions!(( date_add, "Returns the date that is days days after start. The function returns NULL if at least one of the input parameters is NULL.", @@ -83,18 +124,85 @@ pub mod expr_fn { "Returns the first date which is later than start_date and named as indicated. The function returns NULL if at least one of the input parameters is NULL.", arg1 arg2 )); + export_functions!(( + date_diff, + "Returns the number of days from start `start` to end `end`.", + end start + )); + export_functions!(( + date_trunc, + "Truncates a timestamp `ts` to the unit specified by the format `fmt`.", + fmt ts + )); + export_functions!(( + time_trunc, + "Truncates a time `t` to the unit specified by the format `fmt`.", + fmt t + )); + export_functions!(( + trunc, + "Truncates a date `dt` to the unit specified by the format `fmt`.", + dt fmt + )); + export_functions!(( + date_part, + "Extracts a part of the date or time from a date, time, or timestamp expression.", + arg1 arg2 + )); + export_functions!(( + from_utc_timestamp, + "Interpret a given timestamp `ts` in UTC timezone and then convert it to timezone `tz`.", + ts tz + )); + export_functions!(( + to_utc_timestamp, + "Interpret a given timestamp `ts` in timezone `tz` and then convert it to UTC timezone.", + ts tz + )); + export_functions!(( + unix_date, + "Returns the number of days since epoch (1970-01-01) for 
the given date `dt`.", + dt + )); + export_functions!(( + unix_micros, + "Returns the number of microseconds since epoch (1970-01-01 00:00:00 UTC) for the given timestamp `ts`.", + ts + )); + export_functions!(( + unix_millis, + "Returns the number of milliseconds since epoch (1970-01-01 00:00:00 UTC) for the given timestamp `ts`.", + ts + )); + export_functions!(( + unix_seconds, + "Returns the number of seconds since epoch (1970-01-01 00:00:00 UTC) for the given timestamp `ts`.", + ts + )); } pub fn functions() -> Vec> { vec![ + add_months(), date_add(), + date_diff(), + date_part(), date_sub(), + date_trunc(), + from_utc_timestamp(), hour(), - minute(), - second(), last_day(), make_dt_interval(), make_interval(), + minute(), next_day(), + second(), + time_trunc(), + to_utc_timestamp(), + trunc(), + unix_date(), + unix_micros(), + unix_millis(), + unix_seconds(), ] } diff --git a/datafusion/spark/src/function/datetime/next_day.rs b/datafusion/spark/src/function/datetime/next_day.rs index 2acd295f8f14..a456a7831597 100644 --- a/datafusion/spark/src/function/datetime/next_day.rs +++ b/datafusion/spark/src/function/datetime/next_day.rs @@ -18,7 +18,7 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::{ArrayRef, AsArray, Date32Array, StringArrayType, new_null_array}; +use arrow::array::{ArrayRef, AsArray, Date32Array, StringArrayType}; use arrow::datatypes::{DataType, Date32Type, Field, FieldRef}; use chrono::{Datelike, Duration, Weekday}; use datafusion_common::{Result, ScalarValue, exec_err, internal_err}; @@ -129,10 +129,7 @@ impl ScalarUDFImpl for SparkNextDay { } else { // TODO: if spark.sql.ansi.enabled is false, // returns NULL instead of an error for a malformed dayOfWeek. 
- Ok(ColumnarValue::Array(Arc::new(new_null_array( - &DataType::Date32, - date_array.len(), - )))) + Ok(ColumnarValue::Scalar(ScalarValue::Date32(None))) } } _ => exec_err!( @@ -216,7 +213,7 @@ where } fn spark_next_day(days: i32, day_of_week: &str) -> Option { - let date = Date32Type::to_naive_date(days); + let date = Date32Type::to_naive_date_opt(days)?; let day_of_week = day_of_week.trim().to_uppercase(); let day_of_week = match day_of_week.as_str() { diff --git a/datafusion/spark/src/function/datetime/time_trunc.rs b/datafusion/spark/src/function/datetime/time_trunc.rs new file mode 100644 index 000000000000..718502a05ee6 --- /dev/null +++ b/datafusion/spark/src/function/datetime/time_trunc.rs @@ -0,0 +1,122 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::any::Any; +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::types::logical_string; +use datafusion_common::{Result, ScalarValue, internal_err, plan_err}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{ + Coercion, ColumnarValue, Expr, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, + Signature, TypeSignatureClass, Volatility, +}; + +/// Spark time_trunc function only handles time inputs. +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkTimeTrunc { + signature: Signature, +} + +impl Default for SparkTimeTrunc { + fn default() -> Self { + Self::new() + } +} + +impl SparkTimeTrunc { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Time), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkTimeTrunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "time_trunc" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + + Ok(Arc::new(Field::new( + self.name(), + args.arg_fields[1].data_type().clone(), + nullable, + ))) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!( + "spark time_trunc should have been simplified to standard date_trunc" + ) + } + + fn simplify( + &self, + args: Vec, + _info: &SimplifyContext, + ) -> Result { + let fmt_expr = &args[0]; + + let fmt = match fmt_expr.as_literal() { + Some(ScalarValue::Utf8(Some(v))) + | Some(ScalarValue::Utf8View(Some(v))) + | 
Some(ScalarValue::LargeUtf8(Some(v))) => v.to_lowercase(), + _ => { + return plan_err!( + "First argument of `TIME_TRUNC` must be non-null scalar Utf8" + ); + } + }; + + if !matches!( + fmt.as_str(), + "hour" | "minute" | "second" | "millisecond" | "microsecond" + ) { + return plan_err!( + "The format argument of `TIME_TRUNC` must be one of: hour, minute, second, millisecond, microsecond" + ); + } + + Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction( + ScalarFunction::new_udf(datafusion_functions::datetime::date_trunc(), args), + ))) + } +} diff --git a/datafusion/spark/src/function/datetime/to_utc_timestamp.rs b/datafusion/spark/src/function/datetime/to_utc_timestamp.rs new file mode 100644 index 000000000000..0e8c267a390e --- /dev/null +++ b/datafusion/spark/src/function/datetime/to_utc_timestamp.rs @@ -0,0 +1,225 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::any::Any; +use std::sync::Arc; + +use arrow::array::timezone::Tz; +use arrow::array::{Array, ArrayRef, AsArray, PrimitiveBuilder, StringArrayType}; +use arrow::datatypes::TimeUnit; +use arrow::datatypes::{ + ArrowTimestampType, DataType, Field, FieldRef, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, +}; +use chrono::{DateTime, Offset, TimeZone}; +use datafusion_common::types::{NativeType, logical_string}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{ + Result, exec_datafusion_err, exec_err, internal_datafusion_err, internal_err, +}; +use datafusion_expr::{ + Coercion, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, + Signature, TypeSignatureClass, Volatility, +}; +use datafusion_functions::utils::make_scalar_function; + +/// Apache Spark `to_utc_timestamp` function. +/// +/// Interprets the given timestamp in the provided timezone and then converts it to UTC. +/// +/// Timestamp in Apache Spark represents number of microseconds from the Unix epoch, which is not +/// timezone-agnostic. So in Apache Spark this function just shift the timestamp value from the given +/// timezone to UTC timezone. 
+/// +/// See +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkToUtcTimestamp { + signature: Signature, +} + +impl Default for SparkToUtcTimestamp { + fn default() -> Self { + Self::new() + } +} + +impl SparkToUtcTimestamp { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![ + Coercion::new_implicit( + TypeSignatureClass::Timestamp, + vec![TypeSignatureClass::Native(logical_string())], + NativeType::Timestamp(TimeUnit::Microsecond, None), + ), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkToUtcTimestamp { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "to_utc_timestamp" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + + Ok(Arc::new(Field::new( + self.name(), + args.arg_fields[0].data_type().clone(), + nullable, + ))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(to_utc_timestamp, vec![])(&args.args) + } +} + +fn to_utc_timestamp(args: &[ArrayRef]) -> Result { + let [timestamp, timezone] = take_function_args("to_utc_timestamp", args)?; + + match timestamp.data_type() { + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { + process_timestamp_with_tz_array::( + timestamp, + timezone, + tz_opt.clone(), + ) + } + DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { + process_timestamp_with_tz_array::( + timestamp, + timezone, + tz_opt.clone(), + ) + } + DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { + process_timestamp_with_tz_array::( + timestamp, + timezone, + tz_opt.clone(), + ) + } + DataType::Timestamp(TimeUnit::Second, tz_opt) => { + 
process_timestamp_with_tz_array::( + timestamp, + timezone, + tz_opt.clone(), + ) + } + ts_type => { + exec_err!("`to_utc_timestamp`: unsupported argument types: {ts_type}") + } + } +} + +fn process_timestamp_with_tz_array( + ts_array: &ArrayRef, + tz_array: &ArrayRef, + tz_opt: Option>, +) -> Result { + match tz_array.data_type() { + DataType::Utf8 => { + process_arrays::(tz_opt, ts_array, tz_array.as_string::()) + } + DataType::LargeUtf8 => { + process_arrays::(tz_opt, ts_array, tz_array.as_string::()) + } + DataType::Utf8View => { + process_arrays::(tz_opt, ts_array, tz_array.as_string_view()) + } + other => { + exec_err!("`to_utc_timestamp`: timezone must be a string type, got {other}") + } + } +} + +fn process_arrays<'a, T: ArrowTimestampType, S>( + return_tz_opt: Option>, + ts_array: &ArrayRef, + tz_array: &'a S, +) -> Result +where + &'a S: StringArrayType<'a>, +{ + let ts_primitive = ts_array.as_primitive::(); + let mut builder = PrimitiveBuilder::::with_capacity(ts_array.len()); + + for (ts_opt, tz_opt) in ts_primitive.iter().zip(tz_array.iter()) { + match (ts_opt, tz_opt) { + (Some(ts), Some(tz_str)) => { + let tz: Tz = tz_str.parse().map_err(|e| { + exec_datafusion_err!( + "`to_utc_timestamp`: invalid timezone '{tz_str}': {e}" + ) + })?; + let val = adjust_to_utc_time::(ts, tz)?; + builder.append_value(val); + } + _ => builder.append_null(), + } + } + + builder = builder.with_timezone_opt(return_tz_opt); + Ok(Arc::new(builder.finish())) +} + +fn adjust_to_utc_time(ts: i64, tz: Tz) -> Result { + let dt = match T::UNIT { + TimeUnit::Nanosecond => Some(DateTime::from_timestamp_nanos(ts)), + TimeUnit::Microsecond => DateTime::from_timestamp_micros(ts), + TimeUnit::Millisecond => DateTime::from_timestamp_millis(ts), + TimeUnit::Second => DateTime::from_timestamp(ts, 0), + } + .ok_or_else(|| internal_datafusion_err!("Invalid timestamp"))?; + let naive_dt = dt.naive_utc(); + + let offset_seconds = tz + .offset_from_utc_datetime(&naive_dt) + .fix() + 
.local_minus_utc() as i64; + + let offset_in_unit = match T::UNIT { + TimeUnit::Nanosecond => offset_seconds.checked_mul(1_000_000_000), + TimeUnit::Microsecond => offset_seconds.checked_mul(1_000_000), + TimeUnit::Millisecond => offset_seconds.checked_mul(1_000), + TimeUnit::Second => Some(offset_seconds), + } + .ok_or_else(|| internal_datafusion_err!("Offset overflow"))?; + + ts.checked_sub(offset_in_unit).ok_or_else(|| { + internal_datafusion_err!("Timestamp overflow during timezone adjustment") + }) +} diff --git a/datafusion/spark/src/function/datetime/trunc.rs b/datafusion/spark/src/function/datetime/trunc.rs new file mode 100644 index 000000000000..b584cc9a70d4 --- /dev/null +++ b/datafusion/spark/src/function/datetime/trunc.rs @@ -0,0 +1,143 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::any::Any; +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field, FieldRef, TimeUnit}; +use datafusion_common::types::{NativeType, logical_date, logical_string}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, ScalarValue, internal_err, plan_err}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{ + Coercion, ColumnarValue, Expr, ExprSchemable, ReturnFieldArgs, ScalarFunctionArgs, + ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, +}; + +/// Spark trunc supports date inputs only and extra format aliases. +/// Also spark trunc's argument order is (date, format). +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkTrunc { + signature: Signature, +} + +impl Default for SparkTrunc { + fn default() -> Self { + Self::new() + } +} + +impl SparkTrunc { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![ + Coercion::new_implicit( + TypeSignatureClass::Native(logical_date()), + vec![TypeSignatureClass::Native(logical_string())], + NativeType::Date, + ), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkTrunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "trunc" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + + Ok(Arc::new(Field::new( + self.name(), + args.arg_fields[0].data_type().clone(), + nullable, + ))) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!("spark trunc should have been simplified to standard date_trunc") + } + + fn simplify( + 
&self, + args: Vec, + info: &SimplifyContext, + ) -> Result { + let [dt_expr, fmt_expr] = take_function_args(self.name(), args)?; + + let fmt = match fmt_expr.as_literal() { + Some(ScalarValue::Utf8(Some(v))) + | Some(ScalarValue::Utf8View(Some(v))) + | Some(ScalarValue::LargeUtf8(Some(v))) => v.to_lowercase(), + _ => { + return plan_err!( + "Second argument of `TRUNC` must be non-null scalar Utf8" + ); + } + }; + + // Map Spark-specific fmt aliases to datafusion ones + let fmt = match fmt.as_str() { + "yy" | "yyyy" => "year", + "mm" | "mon" => "month", + "year" | "month" | "day" | "week" | "quarter" => fmt.as_str(), + _ => { + return plan_err!( + "The format argument of `TRUNC` must be one of: year, yy, yyyy, month, mm, mon, day, week, quarter." + ); + } + }; + let return_type = dt_expr.get_type(info.schema())?; + + let fmt_expr = Expr::Literal(ScalarValue::new_utf8(fmt), None); + + // Spark uses Dates so we need to cast to timestamp and back to work with datafusion's date_trunc + Ok(ExprSimplifyResult::Simplified( + Expr::ScalarFunction(ScalarFunction::new_udf( + datafusion_functions::datetime::date_trunc(), + vec![ + fmt_expr, + dt_expr.cast_to( + &DataType::Timestamp(TimeUnit::Nanosecond, None), + info.schema(), + )?, + ], + )) + .cast_to(&return_type, info.schema())?, + )) + } +} diff --git a/datafusion/spark/src/function/datetime/unix.rs b/datafusion/spark/src/function/datetime/unix.rs new file mode 100644 index 000000000000..4254b2ed85d5 --- /dev/null +++ b/datafusion/spark/src/function/datetime/unix.rs @@ -0,0 +1,174 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field, FieldRef, TimeUnit}; +use datafusion_common::types::logical_date; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, internal_err}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{ + Coercion, ColumnarValue, Expr, ExprSchemable, ReturnFieldArgs, ScalarFunctionArgs, + ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, +}; + +/// Returns the number of days since epoch (1970-01-01) for the given date. +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkUnixDate { + signature: Signature, +} + +impl Default for SparkUnixDate { + fn default() -> Self { + Self::new() + } +} + +impl SparkUnixDate { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![Coercion::new_exact(TypeSignatureClass::Native( + logical_date(), + ))], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkUnixDate { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "unix_date" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields[0].is_nullable(); + Ok(Arc::new(Field::new(self.name(), DataType::Int32, nullable))) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + 
internal_err!("invoke_with_args should not be called on SparkUnixDate") + } + + fn simplify( + &self, + args: Vec, + info: &SimplifyContext, + ) -> Result { + let [date] = take_function_args(self.name(), args)?; + Ok(ExprSimplifyResult::Simplified( + date.cast_to(&DataType::Date32, info.schema())? + .cast_to(&DataType::Int32, info.schema())?, + )) + } +} + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkUnixTimestamp { + time_unit: TimeUnit, + signature: Signature, + name: &'static str, +} + +impl SparkUnixTimestamp { + pub fn new(name: &'static str, time_unit: TimeUnit) -> Self { + Self { + signature: Signature::coercible( + vec![Coercion::new_exact(TypeSignatureClass::Timestamp)], + Volatility::Immutable, + ), + time_unit, + name, + } + } + + /// Returns the number of microseconds since epoch (1970-01-01 00:00:00 UTC) for the given timestamp. + /// + pub fn microseconds() -> Self { + Self::new("unix_micros", TimeUnit::Microsecond) + } + + /// Returns the number of milliseconds since epoch (1970-01-01 00:00:00 UTC) for the given timestamp. + /// + pub fn milliseconds() -> Self { + Self::new("unix_millis", TimeUnit::Millisecond) + } + + /// Returns the number of seconds since epoch (1970-01-01 00:00:00 UTC) for the given timestamp. 
+ /// + pub fn seconds() -> Self { + Self::new("unix_seconds", TimeUnit::Second) + } +} + +impl ScalarUDFImpl for SparkUnixTimestamp { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields[0].is_nullable(); + Ok(Arc::new(Field::new(self.name(), DataType::Int64, nullable))) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!("invoke_with_args should not be called on `{}`", self.name()) + } + + fn simplify( + &self, + args: Vec, + info: &SimplifyContext, + ) -> Result { + let [ts] = take_function_args(self.name(), args)?; + Ok(ExprSimplifyResult::Simplified( + ts.cast_to( + &DataType::Timestamp(self.time_unit, Some("UTC".into())), + info.schema(), + )? + .cast_to(&DataType::Int64, info.schema())?, + )) + } +} diff --git a/datafusion/spark/src/function/hash/sha2.rs b/datafusion/spark/src/function/hash/sha2.rs index 1f1727506277..3fa41aba71b5 100644 --- a/datafusion/spark/src/function/hash/sha2.rs +++ b/datafusion/spark/src/function/hash/sha2.rs @@ -15,26 +15,29 @@ // specific language governing permissions and limitations // under the License. 
-extern crate datafusion_functions; - -use crate::function::error_utils::{ - invalid_arg_count_exec_err, unsupported_data_type_exec_err, -}; -use crate::function::math::hex::spark_sha2_hex; -use arrow::array::{ArrayRef, AsArray, StringArray}; +use arrow::array::{ArrayRef, AsArray, BinaryArrayType, Int32Array, StringArray}; use arrow::datatypes::{DataType, Int32Type}; -use datafusion_common::{Result, ScalarValue, exec_err, internal_datafusion_err}; -use datafusion_expr::Signature; -use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility}; -pub use datafusion_functions::crypto::basic::{sha224, sha256, sha384, sha512}; +use datafusion_common::types::{ + NativeType, logical_binary, logical_int32, logical_string, +}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, ScalarValue, internal_err}; +use datafusion_expr::{ + Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignatureClass, Volatility, +}; +use datafusion_functions::utils::make_scalar_function; +use sha2::{self, Digest}; use std::any::Any; use std::sync::Arc; +/// Differs from DataFusion version in allowing array input for bit lengths, and +/// also hex encoding the output. 
+/// /// #[derive(Debug, PartialEq, Eq, Hash)] pub struct SparkSha2 { signature: Signature, - aliases: Vec, } impl Default for SparkSha2 { @@ -46,8 +49,21 @@ impl Default for SparkSha2 { impl SparkSha2 { pub fn new() -> Self { Self { - signature: Signature::user_defined(Volatility::Immutable), - aliases: vec![], + signature: Signature::coercible( + vec![ + Coercion::new_implicit( + TypeSignatureClass::Native(logical_binary()), + vec![TypeSignatureClass::Native(logical_string())], + NativeType::Binary, + ), + Coercion::new_implicit( + TypeSignatureClass::Native(logical_int32()), + vec![TypeSignatureClass::Integer], + NativeType::Int32, + ), + ], + Volatility::Immutable, + ), } } } @@ -65,163 +81,188 @@ impl ScalarUDFImpl for SparkSha2 { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - if arg_types[1].is_null() { - return Ok(DataType::Null); - } - Ok(match arg_types[0] { - DataType::Utf8View - | DataType::LargeUtf8 - | DataType::Utf8 - | DataType::Binary - | DataType::BinaryView - | DataType::LargeBinary => DataType::Utf8, - DataType::Null => DataType::Null, - _ => { - return exec_err!( - "{} function can only accept strings or binary arrays.", - self.name() - ); - } - }) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Utf8) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let args: [ColumnarValue; 2] = args.args.try_into().map_err(|_| { - internal_datafusion_err!("Expected 2 arguments for function sha2") - })?; + let [values, bit_lengths] = take_function_args(self.name(), args.args.iter())?; - sha2(args) - } + match (values, bit_lengths) { + ( + ColumnarValue::Scalar(value_scalar), + ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length))), + ) => { + if value_scalar.is_null() { + return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); + } - fn aliases(&self) -> &[String] { - &self.aliases - } - - fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - if arg_types.len() != 2 
{ - return Err(invalid_arg_count_exec_err( - self.name(), - (2, 2), - arg_types.len(), - )); - } - let expr_type = match &arg_types[0] { - DataType::Utf8View - | DataType::LargeUtf8 - | DataType::Utf8 - | DataType::Binary - | DataType::BinaryView - | DataType::LargeBinary - | DataType::Null => Ok(arg_types[0].clone()), - _ => Err(unsupported_data_type_exec_err( - self.name(), - "String, Binary", - &arg_types[0], - )), - }?; - let bit_length_type = if arg_types[1].is_numeric() { - Ok(DataType::Int32) - } else if arg_types[1].is_null() { - Ok(DataType::Null) - } else { - Err(unsupported_data_type_exec_err( - self.name(), - "Numeric Type", - &arg_types[1], - )) - }?; - - Ok(vec![expr_type, bit_length_type]) - } -} + // Accept both Binary and Utf8 scalars (depending on coercion) + let bytes = match value_scalar { + ScalarValue::Binary(Some(b)) => b.as_slice(), + ScalarValue::LargeBinary(Some(b)) => b.as_slice(), + ScalarValue::BinaryView(Some(b)) => b.as_slice(), + ScalarValue::Utf8(Some(s)) + | ScalarValue::LargeUtf8(Some(s)) + | ScalarValue::Utf8View(Some(s)) => s.as_bytes(), + other => { + return internal_err!( + "Unsupported scalar datatype for sha2: {}", + other.data_type() + ); + } + }; -pub fn sha2(args: [ColumnarValue; 2]) -> Result { - match args { - [ - ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)), - ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length_arg))), - ] => compute_sha2( - bit_length_arg, - &[ColumnarValue::from(ScalarValue::Utf8(expr_arg))], - ), - [ - ColumnarValue::Array(expr_arg), - ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length_arg))), - ] => compute_sha2(bit_length_arg, &[ColumnarValue::from(expr_arg)]), - [ - ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)), - ColumnarValue::Array(bit_length_arg), - ] => { - let arr: StringArray = bit_length_arg - .as_primitive::() - .iter() - .map(|bit_length| { - match sha2([ - ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg.clone())), - 
ColumnarValue::Scalar(ScalarValue::Int32(bit_length)), - ]) - .unwrap() - { - ColumnarValue::Scalar(ScalarValue::Utf8(str)) => str, - ColumnarValue::Array(arr) => arr - .as_string::() - .iter() - .map(|str| str.unwrap().to_string()) - .next(), // first element - _ => unreachable!(), + let out = match bit_length { + 224 => { + let mut digest = sha2::Sha224::default(); + digest.update(bytes); + Some(hex_encode(digest.finalize())) } - }) - .collect(); - Ok(ColumnarValue::Array(Arc::new(arr) as ArrayRef)) - } - [ - ColumnarValue::Array(expr_arg), - ColumnarValue::Array(bit_length_arg), - ] => { - let expr_iter = expr_arg.as_string::().iter(); - let bit_length_iter = bit_length_arg.as_primitive::().iter(); - let arr: StringArray = expr_iter - .zip(bit_length_iter) - .map(|(expr, bit_length)| { - match sha2([ - ColumnarValue::Scalar(ScalarValue::Utf8(Some( - expr.unwrap().to_string(), - ))), - ColumnarValue::Scalar(ScalarValue::Int32(bit_length)), - ]) - .unwrap() - { - ColumnarValue::Scalar(ScalarValue::Utf8(str)) => str, - ColumnarValue::Array(arr) => arr - .as_string::() - .iter() - .map(|str| str.unwrap().to_string()) - .next(), // first element - _ => unreachable!(), + 0 | 256 => { + let mut digest = sha2::Sha256::default(); + digest.update(bytes); + Some(hex_encode(digest.finalize())) + } + 384 => { + let mut digest = sha2::Sha384::default(); + digest.update(bytes); + Some(hex_encode(digest.finalize())) } - }) - .collect(); - Ok(ColumnarValue::Array(Arc::new(arr) as ArrayRef)) + 512 => { + let mut digest = sha2::Sha512::default(); + digest.update(bytes); + Some(hex_encode(digest.finalize())) + } + _ => None, + }; + + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(out))) + } + // Array values + scalar bit length (common case: sha2(col, 256)) + ( + ColumnarValue::Array(values_array), + ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length))), + ) => { + let output: ArrayRef = match values_array.data_type() { + DataType::Binary => sha2_binary_scalar_bitlen( + 
&values_array.as_binary::(), + *bit_length, + ), + DataType::LargeBinary => sha2_binary_scalar_bitlen( + &values_array.as_binary::(), + *bit_length, + ), + DataType::BinaryView => sha2_binary_scalar_bitlen( + &values_array.as_binary_view(), + *bit_length, + ), + dt => return internal_err!("Unsupported datatype for sha2: {dt}"), + }; + Ok(ColumnarValue::Array(output)) + } + ( + ColumnarValue::Scalar(_), + ColumnarValue::Scalar(ScalarValue::Int32(None)), + ) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))), + ( + ColumnarValue::Array(_), + ColumnarValue::Scalar(ScalarValue::Int32(None)), + ) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))), + _ => { + // Fallback to existing behavior for any array/mixed cases + make_scalar_function(sha2_impl, vec![])(&args.args) + } } - _ => exec_err!("Unsupported argument types for sha2 function"), } } -fn compute_sha2( - bit_length_arg: i32, - expr_arg: &[ColumnarValue], -) -> Result { - match bit_length_arg { - 0 | 256 => sha256(expr_arg), - 224 => sha224(expr_arg), - 384 => sha384(expr_arg), - 512 => sha512(expr_arg), - _ => { - // Return null for unsupported bit lengths instead of error, because spark sha2 does not - // error out for this. 
- return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); +fn sha2_impl(args: &[ArrayRef]) -> Result { + let [values, bit_lengths] = take_function_args("sha2", args)?; + + let bit_lengths = bit_lengths.as_primitive::(); + let output = match values.data_type() { + DataType::Binary => sha2_binary_impl(&values.as_binary::(), bit_lengths), + DataType::LargeBinary => { + sha2_binary_impl(&values.as_binary::(), bit_lengths) } + DataType::BinaryView => sha2_binary_impl(&values.as_binary_view(), bit_lengths), + dt => return internal_err!("Unsupported datatype for sha2: {dt}"), + }; + Ok(output) +} + +fn sha2_binary_impl<'a, BinaryArrType>( + values: &BinaryArrType, + bit_lengths: &Int32Array, +) -> ArrayRef +where + BinaryArrType: BinaryArrayType<'a>, +{ + sha2_binary_bitlen_iter(values, bit_lengths.iter()) +} + +fn sha2_binary_scalar_bitlen<'a, BinaryArrType>( + values: &BinaryArrType, + bit_length: i32, +) -> ArrayRef +where + BinaryArrType: BinaryArrayType<'a>, +{ + sha2_binary_bitlen_iter(values, std::iter::repeat(Some(bit_length))) +} + +fn sha2_binary_bitlen_iter<'a, BinaryArrType, I>( + values: &BinaryArrType, + bit_lengths: I, +) -> ArrayRef +where + BinaryArrType: BinaryArrayType<'a>, + I: Iterator>, +{ + let array = values + .iter() + .zip(bit_lengths) + .map(|(value, bit_length)| match (value, bit_length) { + (Some(value), Some(224)) => { + let mut digest = sha2::Sha224::default(); + digest.update(value); + Some(hex_encode(digest.finalize())) + } + (Some(value), Some(0 | 256)) => { + let mut digest = sha2::Sha256::default(); + digest.update(value); + Some(hex_encode(digest.finalize())) + } + (Some(value), Some(384)) => { + let mut digest = sha2::Sha384::default(); + digest.update(value); + Some(hex_encode(digest.finalize())) + } + (Some(value), Some(512)) => { + let mut digest = sha2::Sha512::default(); + digest.update(value); + Some(hex_encode(digest.finalize())) + } + // Unknown bit-lengths go to null, same as in Spark + _ => None, + }) + .collect::(); + 
Arc::new(array) +} + +const HEX_CHARS: [u8; 16] = *b"0123456789abcdef"; + +#[inline] +fn hex_encode>(data: T) -> String { + let bytes = data.as_ref(); + let mut out = Vec::with_capacity(bytes.len() * 2); + for &b in bytes { + let hi = b >> 4; + let lo = b & 0x0F; + out.push(HEX_CHARS[hi as usize]); + out.push(HEX_CHARS[lo as usize]); } - .map(|hashed| spark_sha2_hex(&[hashed]).unwrap()) + // SAFETY: out contains only ASCII + unsafe { String::from_utf8_unchecked(out) } } diff --git a/datafusion/spark/src/function/json/json_tuple.rs b/datafusion/spark/src/function/json/json_tuple.rs new file mode 100644 index 000000000000..f3ba7e91ac3d --- /dev/null +++ b/datafusion/spark/src/function/json/json_tuple.rs @@ -0,0 +1,244 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, NullBufferBuilder, StringBuilder, StructArray}; +use arrow::datatypes::{DataType, Field, FieldRef, Fields}; +use datafusion_common::cast::as_string_array; +use datafusion_common::{Result, exec_err, internal_err}; +use datafusion_expr::{ + ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, +}; + +/// Spark-compatible `json_tuple` expression +/// +/// +/// +/// Extracts top-level fields from a JSON string and returns them as a struct. +/// +/// `json_tuple(json_string, field1, field2, ...) -> Struct` +/// +/// Note: In Spark, `json_tuple` is a Generator that produces multiple columns directly. +/// In DataFusion, a ScalarUDF can only return one value per row, so the result is wrapped +/// in a Struct. The caller (e.g. Comet) is expected to destructure the struct fields. +/// +/// - Returns NULL for each field that is missing from the JSON object +/// - Returns NULL for all fields if the input is NULL or not valid JSON +/// - Non-string JSON values are converted to their JSON string representation +/// - JSON `null` values are returned as NULL (not the string "null") +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct JsonTuple { + signature: Signature, +} + +impl Default for JsonTuple { + fn default() -> Self { + Self::new() + } +} + +impl JsonTuple { + pub fn new() -> Self { + Self { + signature: Signature::variadic(vec![DataType::Utf8], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for JsonTuple { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "json_tuple" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + if args.arg_fields.len() < 2 { + return exec_err!( + "json_tuple requires at least 2 
arguments (json_string, field1), got {}", + args.arg_fields.len() + ); + } + + let num_fields = args.arg_fields.len() - 1; + let fields: Fields = (0..num_fields) + .map(|i| Field::new(format!("c{i}"), DataType::Utf8, true)) + .collect::>() + .into(); + + Ok(Arc::new(Field::new( + self.name(), + DataType::Struct(fields), + true, + ))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let ScalarFunctionArgs { + args: arg_values, + return_field, + .. + } = args; + let arrays = ColumnarValue::values_to_arrays(&arg_values)?; + let result = json_tuple_inner(&arrays, return_field.data_type())?; + + Ok(ColumnarValue::Array(result)) + } +} + +fn json_tuple_inner(args: &[ArrayRef], return_type: &DataType) -> Result { + let num_rows = args[0].len(); + let num_fields = args.len() - 1; + + let json_array = as_string_array(&args[0])?; + + let field_arrays = args[1..] + .iter() + .map(|arg| as_string_array(arg)) + .collect::>>()?; + + let mut builders: Vec = + (0..num_fields).map(|_| StringBuilder::new()).collect(); + + let mut null_buffer = NullBufferBuilder::new(num_rows); + + for row_idx in 0..num_rows { + if json_array.is_null(row_idx) { + for builder in &mut builders { + builder.append_null(); + } + null_buffer.append_null(); + continue; + } + + let json_str = json_array.value(row_idx); + match serde_json::from_str::(json_str) { + Ok(serde_json::Value::Object(map)) => { + null_buffer.append_non_null(); + for (field_idx, builder) in builders.iter_mut().enumerate() { + if field_arrays[field_idx].is_null(row_idx) { + builder.append_null(); + continue; + } + let field_name = field_arrays[field_idx].value(row_idx); + match map.get(field_name) { + Some(serde_json::Value::Null) => { + builder.append_null(); + } + Some(serde_json::Value::String(s)) => { + builder.append_value(s); + } + Some(other) => { + builder.append_value(other.to_string()); + } + None => { + builder.append_null(); + } + } + } + } + _ => { + for builder in &mut builders { + 
builder.append_null(); + } + null_buffer.append_null(); + } + } + } + + let struct_fields = match return_type { + DataType::Struct(fields) => fields.clone(), + _ => { + return internal_err!( + "json_tuple requires a Struct return type, got {:?}", + return_type + ); + } + }; + + let arrays: Vec = builders + .into_iter() + .map(|mut builder| Arc::new(builder.finish()) as ArrayRef) + .collect(); + + let struct_array = StructArray::try_new(struct_fields, arrays, null_buffer.finish())?; + + Ok(Arc::new(struct_array)) +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion_expr::ReturnFieldArgs; + + #[test] + fn test_return_field_shape() { + let func = JsonTuple::new(); + let fields = vec![ + Arc::new(Field::new("json", DataType::Utf8, false)), + Arc::new(Field::new("f1", DataType::Utf8, false)), + Arc::new(Field::new("f2", DataType::Utf8, false)), + ]; + let result = func + .return_field_from_args(ReturnFieldArgs { + arg_fields: &fields, + scalar_arguments: &[None, None, None], + }) + .unwrap(); + + match result.data_type() { + DataType::Struct(inner) => { + assert_eq!(inner.len(), 2); + assert_eq!(inner[0].name(), "c0"); + assert_eq!(inner[1].name(), "c1"); + assert_eq!(inner[0].data_type(), &DataType::Utf8); + assert!(inner[0].is_nullable()); + } + other => panic!("Expected Struct, got {other:?}"), + } + } + + #[test] + fn test_too_few_args() { + let func = JsonTuple::new(); + let fields = vec![Arc::new(Field::new("json", DataType::Utf8, false))]; + let result = func.return_field_from_args(ReturnFieldArgs { + arg_fields: &fields, + scalar_arguments: &[None], + }); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("at least 2 arguments") + ); + } +} diff --git a/datafusion/spark/src/function/json/mod.rs b/datafusion/spark/src/function/json/mod.rs index a87df9a2c87a..01378235d7c6 100644 --- a/datafusion/spark/src/function/json/mod.rs +++ b/datafusion/spark/src/function/json/mod.rs @@ -15,11 +15,24 @@ // specific language 
governing permissions and limitations // under the License. +pub mod json_tuple; + use datafusion_expr::ScalarUDF; +use datafusion_functions::make_udf_function; use std::sync::Arc; -pub mod expr_fn {} +make_udf_function!(json_tuple::JsonTuple, json_tuple); + +pub mod expr_fn { + use datafusion_functions::export_functions; + + export_functions!(( + json_tuple, + "Extracts top-level fields from a JSON string and returns them as a struct.", + args, + )); +} pub fn functions() -> Vec> { - vec![] + vec![json_tuple()] } diff --git a/datafusion/spark/src/function/map/map_from_arrays.rs b/datafusion/spark/src/function/map/map_from_arrays.rs index f6ca02e2fe86..429ed272d772 100644 --- a/datafusion/spark/src/function/map/map_from_arrays.rs +++ b/datafusion/spark/src/function/map/map_from_arrays.rs @@ -96,9 +96,7 @@ impl ScalarUDFImpl for MapFromArrays { fn map_from_arrays_inner(args: &[ArrayRef]) -> Result { let [keys, values] = take_function_args("map_from_arrays", args)?; - if matches!(keys.data_type(), DataType::Null) - || matches!(values.data_type(), DataType::Null) - { + if *keys.data_type() == DataType::Null || *values.data_type() == DataType::Null { return Ok(cast( &NullArray::new(keys.len()), &map_type_from_key_value_types( diff --git a/datafusion/spark/src/function/map/mod.rs b/datafusion/spark/src/function/map/mod.rs index 2f596b19b422..c9ebed6f612e 100644 --- a/datafusion/spark/src/function/map/mod.rs +++ b/datafusion/spark/src/function/map/mod.rs @@ -17,6 +17,7 @@ pub mod map_from_arrays; pub mod map_from_entries; +pub mod str_to_map; mod utils; use datafusion_expr::ScalarUDF; @@ -25,6 +26,7 @@ use std::sync::Arc; make_udf_function!(map_from_arrays::MapFromArrays, map_from_arrays); make_udf_function!(map_from_entries::MapFromEntries, map_from_entries); +make_udf_function!(str_to_map::SparkStrToMap, str_to_map); pub mod expr_fn { use datafusion_functions::export_functions; @@ -40,8 +42,14 @@ pub mod expr_fn { "Creates a map from array>.", arg1 )); + + 
export_functions!(( + str_to_map, + "Creates a map after splitting the text into key/value pairs using delimiters.", + text pair_delim key_value_delim + )); } pub fn functions() -> Vec> { - vec![map_from_arrays(), map_from_entries()] + vec![map_from_arrays(), map_from_entries(), str_to_map()] } diff --git a/datafusion/spark/src/function/map/str_to_map.rs b/datafusion/spark/src/function/map/str_to_map.rs new file mode 100644 index 000000000000..b722fb7abd6b --- /dev/null +++ b/datafusion/spark/src/function/map/str_to_map.rs @@ -0,0 +1,266 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::any::Any; +use std::collections::HashSet; +use std::sync::Arc; + +use arrow::array::{ + Array, ArrayRef, MapBuilder, MapFieldNames, StringArrayType, StringBuilder, +}; +use arrow::buffer::NullBuffer; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::cast::{ + as_large_string_array, as_string_array, as_string_view_array, +}; +use datafusion_common::{Result, exec_err, internal_err}; +use datafusion_expr::{ + ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignature, Volatility, +}; + +use crate::function::map::utils::map_type_from_key_value_types; + +const DEFAULT_PAIR_DELIM: &str = ","; +const DEFAULT_KV_DELIM: &str = ":"; + +/// Spark-compatible `str_to_map` expression +/// +/// +/// Creates a map from a string by splitting on delimiters. +/// str_to_map(text[, pairDelim[, keyValueDelim]]) -> Map +/// +/// - text: The input string +/// - pairDelim: Delimiter between key-value pairs (default: ',') +/// - keyValueDelim: Delimiter between key and value (default: ':') +/// +/// # Duplicate Key Handling +/// Uses EXCEPTION behavior (Spark 3.0+ default): errors on duplicate keys. +/// See `spark.sql.mapKeyDedupPolicy`: +/// +/// +/// TODO: Support configurable `spark.sql.mapKeyDedupPolicy` (LAST_WIN) in a follow-up PR. 
+#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkStrToMap { + signature: Signature, +} + +impl Default for SparkStrToMap { + fn default() -> Self { + Self::new() + } +} + +impl SparkStrToMap { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + // str_to_map(text) + TypeSignature::String(1), + // str_to_map(text, pairDelim) + TypeSignature::String(2), + // str_to_map(text, pairDelim, keyValueDelim) + TypeSignature::String(3), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkStrToMap { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "str_to_map" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + let map_type = map_type_from_key_value_types(&DataType::Utf8, &DataType::Utf8); + Ok(Arc::new(Field::new(self.name(), map_type, nullable))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let arrays: Vec = ColumnarValue::values_to_arrays(&args.args)?; + let result = str_to_map_inner(&arrays)?; + Ok(ColumnarValue::Array(result)) + } +} + +fn str_to_map_inner(args: &[ArrayRef]) -> Result { + match args.len() { + 1 => match args[0].data_type() { + DataType::Utf8 => str_to_map_impl(as_string_array(&args[0])?, None, None), + DataType::LargeUtf8 => { + str_to_map_impl(as_large_string_array(&args[0])?, None, None) + } + DataType::Utf8View => { + str_to_map_impl(as_string_view_array(&args[0])?, None, None) + } + other => exec_err!( + "Unsupported data type {other:?} for str_to_map, \ + expected Utf8, LargeUtf8, or Utf8View" + ), + }, + 2 => match (args[0].data_type(), args[1].data_type()) { + (DataType::Utf8, DataType::Utf8) => str_to_map_impl( + as_string_array(&args[0])?, + 
Some(as_string_array(&args[1])?), + None, + ), + (DataType::LargeUtf8, DataType::LargeUtf8) => str_to_map_impl( + as_large_string_array(&args[0])?, + Some(as_large_string_array(&args[1])?), + None, + ), + (DataType::Utf8View, DataType::Utf8View) => str_to_map_impl( + as_string_view_array(&args[0])?, + Some(as_string_view_array(&args[1])?), + None, + ), + (t1, t2) => exec_err!( + "Unsupported data types ({t1:?}, {t2:?}) for str_to_map, \ + expected matching Utf8, LargeUtf8, or Utf8View" + ), + }, + 3 => match ( + args[0].data_type(), + args[1].data_type(), + args[2].data_type(), + ) { + (DataType::Utf8, DataType::Utf8, DataType::Utf8) => str_to_map_impl( + as_string_array(&args[0])?, + Some(as_string_array(&args[1])?), + Some(as_string_array(&args[2])?), + ), + (DataType::LargeUtf8, DataType::LargeUtf8, DataType::LargeUtf8) => { + str_to_map_impl( + as_large_string_array(&args[0])?, + Some(as_large_string_array(&args[1])?), + Some(as_large_string_array(&args[2])?), + ) + } + (DataType::Utf8View, DataType::Utf8View, DataType::Utf8View) => { + str_to_map_impl( + as_string_view_array(&args[0])?, + Some(as_string_view_array(&args[1])?), + Some(as_string_view_array(&args[2])?), + ) + } + (t1, t2, t3) => exec_err!( + "Unsupported data types ({t1:?}, {t2:?}, {t3:?}) for str_to_map, \ + expected matching Utf8, LargeUtf8, or Utf8View" + ), + }, + n => exec_err!("str_to_map expects 1-3 arguments, got {n}"), + } +} + +fn str_to_map_impl<'a, V: StringArrayType<'a> + Copy>( + text_array: V, + pair_delim_array: Option, + kv_delim_array: Option, +) -> Result { + let num_rows = text_array.len(); + + // Precompute combined null buffer from all input arrays. + // NullBuffer::union performs a bitmap-level AND, which is more efficient + // than checking per-row nullability inline. 
+ let text_nulls = text_array.nulls().cloned(); + let pair_nulls = pair_delim_array.and_then(|a| a.nulls().cloned()); + let kv_nulls = kv_delim_array.and_then(|a| a.nulls().cloned()); + let combined_nulls = [text_nulls.as_ref(), pair_nulls.as_ref(), kv_nulls.as_ref()] + .into_iter() + .fold(None, |acc, nulls| NullBuffer::union(acc.as_ref(), nulls)); + + // Use field names matching map_type_from_key_value_types: "key" and "value" + let field_names = MapFieldNames { + entry: "entries".to_string(), + key: "key".to_string(), + value: "value".to_string(), + }; + let mut map_builder = MapBuilder::new( + Some(field_names), + StringBuilder::new(), + StringBuilder::new(), + ); + + let mut seen_keys = HashSet::new(); + for row_idx in 0..num_rows { + if combined_nulls.as_ref().is_some_and(|n| n.is_null(row_idx)) { + map_builder.append(false)?; + continue; + } + + // Per-row delimiter extraction + let pair_delim = + pair_delim_array.map_or(DEFAULT_PAIR_DELIM, |a| a.value(row_idx)); + let kv_delim = kv_delim_array.map_or(DEFAULT_KV_DELIM, |a| a.value(row_idx)); + + let text = text_array.value(row_idx); + if text.is_empty() { + // Empty string -> map with empty key and NULL value (Spark behavior) + map_builder.keys().append_value(""); + map_builder.values().append_null(); + map_builder.append(true)?; + continue; + } + + seen_keys.clear(); + for pair in text.split(pair_delim) { + if pair.is_empty() { + continue; + } + + let mut kv_iter = pair.splitn(2, kv_delim); + let key = kv_iter.next().unwrap_or(""); + let value = kv_iter.next(); + + // TODO: Support LAST_WIN policy via spark.sql.mapKeyDedupPolicy config + // EXCEPTION policy: error on duplicate keys (Spark 3.0+ default) + if !seen_keys.insert(key) { + return exec_err!( + "Duplicate map key '{key}' was found, please check the input data. \ + If you want to remove the duplicated keys, you can set \ + spark.sql.mapKeyDedupPolicy to \"LAST_WIN\" so that the key \ + inserted at last takes precedence." 
+ ); + } + + map_builder.keys().append_value(key); + match value { + Some(v) => map_builder.values().append_value(v), + None => map_builder.values().append_null(), + } + } + map_builder.append(true)?; + } + + Ok(Arc::new(map_builder.finish())) +} diff --git a/datafusion/spark/src/function/map/utils.rs b/datafusion/spark/src/function/map/utils.rs index 1a25ffb29568..28fa3227fd62 100644 --- a/datafusion/spark/src/function/map/utils.rs +++ b/datafusion/spark/src/function/map/utils.rs @@ -181,8 +181,8 @@ fn map_deduplicate_keys( let num_keys_entries = *next_keys_offset as usize - cur_keys_offset; let num_values_entries = *next_values_offset as usize - cur_values_offset; - let mut keys_mask_one = [false].repeat(num_keys_entries); - let mut values_mask_one = [false].repeat(num_values_entries); + let mut keys_mask_one = vec![false; num_keys_entries]; + let mut values_mask_one = vec![false; num_values_entries]; let key_is_valid = keys_nulls.is_none_or(|buf| buf.is_valid(row_idx)); let value_is_valid = values_nulls.is_none_or(|buf| buf.is_valid(row_idx)); diff --git a/datafusion/spark/src/function/math/abs.rs b/datafusion/spark/src/function/math/abs.rs index 101291ac5f66..5edb40ae8ae9 100644 --- a/datafusion/spark/src/function/math/abs.rs +++ b/datafusion/spark/src/function/math/abs.rs @@ -17,13 +17,15 @@ use arrow::array::*; use arrow::datatypes::{DataType, Field, FieldRef}; +use arrow::error::ArrowError; use datafusion_common::{DataFusionError, Result, ScalarValue, internal_err}; use datafusion_expr::{ ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; use datafusion_functions::{ - downcast_named_arg, make_abs_function, make_wrapping_abs_function, + downcast_named_arg, make_abs_function, make_try_abs_function, + make_wrapping_abs_function, }; use std::any::Any; use std::sync::Arc; @@ -34,8 +36,10 @@ use std::sync::Arc; /// Returns the absolute value of input /// Returns NULL if input is NULL, returns NaN if input is NaN. 
/// -/// TODOs: +/// Differences with DataFusion abs: /// - Spark's ANSI-compliant dialect, when off (i.e. `spark.sql.ansi.enabled=false`), taking absolute value on the minimal value of a signed integer returns the value as is. DataFusion's abs throws "DataFusion error: Arrow error: Compute error" on arithmetic overflow +/// +/// TODOs: /// - Spark's abs also supports ANSI interval types: YearMonthIntervalType and DayTimeIntervalType. DataFusion's abs doesn't. /// #[derive(Debug, PartialEq, Eq, Hash)] @@ -85,19 +89,39 @@ impl ScalarUDFImpl for SparkAbs { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - spark_abs(&args.args) + spark_abs(&args.args, args.config_options.execution.enable_ansi_mode) } } macro_rules! scalar_compute_op { - ($INPUT:ident, $SCALAR_TYPE:ident) => {{ - let result = $INPUT.wrapping_abs(); + ($ENABLE_ANSI_MODE:expr, $INPUT:ident, $SCALAR_TYPE:ident) => {{ + let result = if $ENABLE_ANSI_MODE { + $INPUT.checked_abs().ok_or_else(|| { + ArrowError::ComputeError(format!( + "{} overflow on abs({:?})", + stringify!($SCALAR_TYPE), + $INPUT + )) + })? + } else { + $INPUT.wrapping_abs() + }; Ok(ColumnarValue::Scalar(ScalarValue::$SCALAR_TYPE(Some( result, )))) }}; - ($INPUT:ident, $PRECISION:expr, $SCALE:expr, $SCALAR_TYPE:ident) => {{ - let result = $INPUT.wrapping_abs(); + ($ENABLE_ANSI_MODE:expr, $INPUT:ident, $PRECISION:expr, $SCALE:expr, $SCALAR_TYPE:ident) => {{ + let result = if $ENABLE_ANSI_MODE { + $INPUT.checked_abs().ok_or_else(|| { + ArrowError::ComputeError(format!( + "{} overflow on abs({:?})", + stringify!($SCALAR_TYPE), + $INPUT + )) + })? + } else { + $INPUT.wrapping_abs() + }; Ok(ColumnarValue::Scalar(ScalarValue::$SCALAR_TYPE( Some(result), $PRECISION, @@ -106,7 +130,10 @@ macro_rules! 
scalar_compute_op { }}; } -pub fn spark_abs(args: &[ColumnarValue]) -> Result { +pub fn spark_abs( + args: &[ColumnarValue], + enable_ansi_mode: bool, +) -> Result { if args.len() != 1 { return internal_err!("abs takes exactly 1 argument, but got: {}", args.len()); } @@ -119,19 +146,35 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result Ok(args[0].clone()), DataType::Int8 => { - let abs_fun = make_wrapping_abs_function!(Int8Array); + let abs_fun = if enable_ansi_mode { + make_try_abs_function!(Int8Array) + } else { + make_wrapping_abs_function!(Int8Array) + }; abs_fun(array).map(ColumnarValue::Array) } DataType::Int16 => { - let abs_fun = make_wrapping_abs_function!(Int16Array); + let abs_fun = if enable_ansi_mode { + make_try_abs_function!(Int16Array) + } else { + make_wrapping_abs_function!(Int16Array) + }; abs_fun(array).map(ColumnarValue::Array) } DataType::Int32 => { - let abs_fun = make_wrapping_abs_function!(Int32Array); + let abs_fun = if enable_ansi_mode { + make_try_abs_function!(Int32Array) + } else { + make_wrapping_abs_function!(Int32Array) + }; abs_fun(array).map(ColumnarValue::Array) } DataType::Int64 => { - let abs_fun = make_wrapping_abs_function!(Int64Array); + let abs_fun = if enable_ansi_mode { + make_try_abs_function!(Int64Array) + } else { + make_wrapping_abs_function!(Int64Array) + }; abs_fun(array).map(ColumnarValue::Array) } DataType::Float32 => { @@ -143,11 +186,19 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result { - let abs_fun = make_wrapping_abs_function!(Decimal128Array); + let abs_fun = if enable_ansi_mode { + make_try_abs_function!(Decimal128Array) + } else { + make_wrapping_abs_function!(Decimal128Array) + }; abs_fun(array).map(ColumnarValue::Array) } DataType::Decimal256(_, _) => { - let abs_fun = make_wrapping_abs_function!(Decimal256Array); + let abs_fun = if enable_ansi_mode { + make_try_abs_function!(Decimal256Array) + } else { + make_wrapping_abs_function!(Decimal256Array) + }; abs_fun(array).map(ColumnarValue::Array) } 
dt => internal_err!("Not supported datatype for Spark ABS: {dt}"), @@ -159,10 +210,10 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result Ok(args[0].clone()), sv if sv.is_null() => Ok(args[0].clone()), - ScalarValue::Int8(Some(v)) => scalar_compute_op!(v, Int8), - ScalarValue::Int16(Some(v)) => scalar_compute_op!(v, Int16), - ScalarValue::Int32(Some(v)) => scalar_compute_op!(v, Int32), - ScalarValue::Int64(Some(v)) => scalar_compute_op!(v, Int64), + ScalarValue::Int8(Some(v)) => scalar_compute_op!(enable_ansi_mode, v, Int8), + ScalarValue::Int16(Some(v)) => scalar_compute_op!(enable_ansi_mode, v, Int16), + ScalarValue::Int32(Some(v)) => scalar_compute_op!(enable_ansi_mode, v, Int32), + ScalarValue::Int64(Some(v)) => scalar_compute_op!(enable_ansi_mode, v, Int64), ScalarValue::Float32(Some(v)) => { Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(v.abs())))) } @@ -170,10 +221,10 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result { - scalar_compute_op!(v, *precision, *scale, Decimal128) + scalar_compute_op!(enable_ansi_mode, v, *precision, *scale, Decimal128) } ScalarValue::Decimal256(Some(v), precision, scale) => { - scalar_compute_op!(v, *precision, *scale, Decimal256) + scalar_compute_op!(enable_ansi_mode, v, *precision, *scale, Decimal256) } dt => internal_err!("Not supported datatype for Spark ABS: {dt}"), }, @@ -185,100 +236,12 @@ mod tests { use super::*; use arrow::datatypes::i256; - macro_rules! 
eval_legacy_mode { - ($TYPE:ident, $VAL:expr) => {{ - let args = ColumnarValue::Scalar(ScalarValue::$TYPE(Some($VAL))); - match spark_abs(&[args]) { - Ok(ColumnarValue::Scalar(ScalarValue::$TYPE(Some(result)))) => { - assert_eq!(result, $VAL); - } - _ => unreachable!(), - } - }}; - ($TYPE:ident, $VAL:expr, $RESULT:expr) => {{ - let args = ColumnarValue::Scalar(ScalarValue::$TYPE(Some($VAL))); - match spark_abs(&[args]) { - Ok(ColumnarValue::Scalar(ScalarValue::$TYPE(Some(result)))) => { - assert_eq!(result, $RESULT); - } - _ => unreachable!(), - } - }}; - ($TYPE:ident, $VAL:expr, $PRECISION:expr, $SCALE:expr) => {{ - let args = - ColumnarValue::Scalar(ScalarValue::$TYPE(Some($VAL), $PRECISION, $SCALE)); - match spark_abs(&[args]) { - Ok(ColumnarValue::Scalar(ScalarValue::$TYPE( - Some(result), - precision, - scale, - ))) => { - assert_eq!(result, $VAL); - assert_eq!(precision, $PRECISION); - assert_eq!(scale, $SCALE); - } - _ => unreachable!(), - } - }}; - ($TYPE:ident, $VAL:expr, $PRECISION:expr, $SCALE:expr, $RESULT:expr) => {{ - let args = - ColumnarValue::Scalar(ScalarValue::$TYPE(Some($VAL), $PRECISION, $SCALE)); - match spark_abs(&[args]) { - Ok(ColumnarValue::Scalar(ScalarValue::$TYPE( - Some(result), - precision, - scale, - ))) => { - assert_eq!(result, $RESULT); - assert_eq!(precision, $PRECISION); - assert_eq!(scale, $SCALE); - } - _ => unreachable!(), - } - }}; - } - - #[test] - fn test_abs_scalar_legacy_mode() { - // NumericType MIN - eval_legacy_mode!(UInt8, u8::MIN); - eval_legacy_mode!(UInt16, u16::MIN); - eval_legacy_mode!(UInt32, u32::MIN); - eval_legacy_mode!(UInt64, u64::MIN); - eval_legacy_mode!(Int8, i8::MIN); - eval_legacy_mode!(Int16, i16::MIN); - eval_legacy_mode!(Int32, i32::MIN); - eval_legacy_mode!(Int64, i64::MIN); - eval_legacy_mode!(Float32, f32::MIN, f32::MAX); - eval_legacy_mode!(Float64, f64::MIN, f64::MAX); - eval_legacy_mode!(Decimal128, i128::MIN, 18, 10); - eval_legacy_mode!(Decimal256, i256::MIN, 10, 2); - - // NumericType not 
MIN - eval_legacy_mode!(Int8, -1i8, 1i8); - eval_legacy_mode!(Int16, -1i16, 1i16); - eval_legacy_mode!(Int32, -1i32, 1i32); - eval_legacy_mode!(Int64, -1i64, 1i64); - eval_legacy_mode!(Decimal128, -1i128, 18, 10, 1i128); - eval_legacy_mode!(Decimal256, i256::from(-1i8), 10, 2, i256::from(1i8)); - - // Float32, Float64 - eval_legacy_mode!(Float32, f32::NEG_INFINITY, f32::INFINITY); - eval_legacy_mode!(Float32, f32::INFINITY, f32::INFINITY); - eval_legacy_mode!(Float32, 0.0f32, 0.0f32); - eval_legacy_mode!(Float32, -0.0f32, 0.0f32); - eval_legacy_mode!(Float64, f64::NEG_INFINITY, f64::INFINITY); - eval_legacy_mode!(Float64, f64::INFINITY, f64::INFINITY); - eval_legacy_mode!(Float64, 0.0f64, 0.0f64); - eval_legacy_mode!(Float64, -0.0f64, 0.0f64); - } - macro_rules! eval_array_legacy_mode { ($INPUT:expr, $OUTPUT:expr, $FUNC:ident) => {{ let input = $INPUT; let args = ColumnarValue::Array(Arc::new(input)); let expected = $OUTPUT; - match spark_abs(&[args]) { + match spark_abs(&[args], false) { Ok(ColumnarValue::Array(result)) => { let actual = datafusion_common::cast::$FUNC(&result).unwrap(); assert_eq!(actual, &expected); @@ -367,24 +330,187 @@ mod tests { ); eval_array_legacy_mode!( - Decimal128Array::from(vec![Some(i128::MIN), None]) + Decimal128Array::from(vec![Some(i128::MIN), Some(i128::MIN + 1), None]) .with_precision_and_scale(38, 37) .unwrap(), - Decimal128Array::from(vec![Some(i128::MIN), None]) + Decimal128Array::from(vec![Some(i128::MIN), Some(i128::MAX), None]) .with_precision_and_scale(38, 37) .unwrap(), as_decimal128_array ); eval_array_legacy_mode!( - Decimal256Array::from(vec![Some(i256::MIN), None]) - .with_precision_and_scale(5, 2) + Decimal256Array::from(vec![ + Some(i256::MIN), + Some(i256::MINUS_ONE), + Some(i256::MIN + i256::from(1)), + None + ]) + .with_precision_and_scale(5, 2) + .unwrap(), + Decimal256Array::from(vec![ + Some(i256::MIN), + Some(i256::ONE), + Some(i256::MAX), + None + ]) + .with_precision_and_scale(5, 2) + .unwrap(), + 
as_decimal256_array + ); + } + + macro_rules! eval_array_ansi_mode { + ($INPUT:expr) => {{ + let input = $INPUT; + let args = ColumnarValue::Array(Arc::new(input)); + match spark_abs(&[args], true) { + Err(e) => { + assert!( + e.to_string().contains("overflow on abs"), + "Error message did not match. Actual message: {e}" + ); + } + _ => unreachable!(), + } + }}; + ($INPUT:expr, $OUTPUT:expr, $FUNC:ident) => {{ + let input = $INPUT; + let args = ColumnarValue::Array(Arc::new(input)); + let expected = $OUTPUT; + match spark_abs(&[args], true) { + Ok(ColumnarValue::Array(result)) => { + let actual = datafusion_common::cast::$FUNC(&result).unwrap(); + assert_eq!(actual, &expected); + } + _ => unreachable!(), + } + }}; + } + #[test] + fn test_abs_array_ansi_mode() { + eval_array_ansi_mode!( + UInt64Array::from(vec![Some(u64::MIN), Some(u64::MAX), None]), + UInt64Array::from(vec![Some(u64::MIN), Some(u64::MAX), None]), + as_uint64_array + ); + + eval_array_ansi_mode!(Int8Array::from(vec![ + Some(-1), + Some(i8::MIN), + Some(i8::MAX), + None + ])); + eval_array_ansi_mode!(Int16Array::from(vec![ + Some(-1), + Some(i16::MIN), + Some(i16::MAX), + None + ])); + eval_array_ansi_mode!(Int32Array::from(vec![ + Some(-1), + Some(i32::MIN), + Some(i32::MAX), + None + ])); + eval_array_ansi_mode!(Int64Array::from(vec![ + Some(-1), + Some(i64::MIN), + Some(i64::MAX), + None + ])); + eval_array_ansi_mode!( + Float32Array::from(vec![ + Some(-1f32), + Some(f32::MIN), + Some(f32::MAX), + None, + Some(f32::NAN), + Some(f32::INFINITY), + Some(f32::NEG_INFINITY), + Some(0.0), + Some(-0.0), + ]), + Float32Array::from(vec![ + Some(1f32), + Some(f32::MAX), + Some(f32::MAX), + None, + Some(f32::NAN), + Some(f32::INFINITY), + Some(f32::INFINITY), + Some(0.0), + Some(0.0), + ]), + as_float32_array + ); + + eval_array_ansi_mode!( + Float64Array::from(vec![ + Some(-1f64), + Some(f64::MIN), + Some(f64::MAX), + None, + Some(f64::NAN), + Some(f64::INFINITY), + Some(f64::NEG_INFINITY), + Some(0.0), + 
Some(-0.0), + ]), + Float64Array::from(vec![ + Some(1f64), + Some(f64::MAX), + Some(f64::MAX), + None, + Some(f64::NAN), + Some(f64::INFINITY), + Some(f64::INFINITY), + Some(0.0), + Some(0.0), + ]), + as_float64_array + ); + + // decimal: no arithmetic overflow + eval_array_ansi_mode!( + Decimal128Array::from(vec![Some(-1), Some(-2), Some(i128::MIN + 1)]) + .with_precision_and_scale(38, 37) .unwrap(), - Decimal256Array::from(vec![Some(i256::MIN), None]) - .with_precision_and_scale(5, 2) + Decimal128Array::from(vec![Some(1), Some(2), Some(i128::MAX)]) + .with_precision_and_scale(38, 37) .unwrap(), + as_decimal128_array + ); + + eval_array_ansi_mode!( + Decimal256Array::from(vec![ + Some(i256::MINUS_ONE), + Some(i256::from(-2)), + Some(i256::MIN + i256::from(1)) + ]) + .with_precision_and_scale(18, 7) + .unwrap(), + Decimal256Array::from(vec![ + Some(i256::ONE), + Some(i256::from(2)), + Some(i256::MAX) + ]) + .with_precision_and_scale(18, 7) + .unwrap(), as_decimal256_array ); + + // decimal: arithmetic overflow + eval_array_ansi_mode!( + Decimal128Array::from(vec![Some(i128::MIN), None]) + .with_precision_and_scale(38, 37) + .unwrap() + ); + eval_array_ansi_mode!( + Decimal256Array::from(vec![Some(i256::MIN), None]) + .with_precision_and_scale(5, 2) + .unwrap() + ); } #[test] diff --git a/datafusion/spark/src/function/math/bin.rs b/datafusion/spark/src/function/math/bin.rs new file mode 100644 index 000000000000..6822025b782d --- /dev/null +++ b/datafusion/spark/src/function/math/bin.rs @@ -0,0 +1,111 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, AsArray, StringArray}; +use arrow::datatypes::{DataType, Field, FieldRef, Int64Type}; +use datafusion_common::types::{NativeType, logical_int64}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, internal_err}; +use datafusion_expr::{ + Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + TypeSignatureClass, Volatility, +}; +use datafusion_functions::utils::make_scalar_function; +use std::any::Any; +use std::sync::Arc; + +/// Spark-compatible `bin` expression +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkBin { + signature: Signature, +} + +impl Default for SparkBin { + fn default() -> Self { + Self::new() + } +} + +impl SparkBin { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![TypeSignature::Coercible(vec![Coercion::new_implicit( + TypeSignatureClass::Native(logical_int64()), + vec![TypeSignatureClass::Numeric], + NativeType::Int64, + )])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkBin { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "bin" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args( + &self, + args: datafusion_expr::ReturnFieldArgs, + ) -> Result { + Ok(Arc::new(Field::new( + self.name(), + DataType::Utf8, + args.arg_fields[0].is_nullable(), + ))) + } + + fn 
invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(spark_bin_inner, vec![])(&args.args) + } +} + +fn spark_bin_inner(arg: &[ArrayRef]) -> Result { + let [array] = take_function_args("bin", arg)?; + match &array.data_type() { + DataType::Int64 => { + let result: StringArray = array + .as_primitive::() + .iter() + .map(|opt| opt.map(spark_bin)) + .collect(); + Ok(Arc::new(result)) + } + data_type => { + internal_err!("bin does not support: {data_type}") + } + } +} + +fn spark_bin(value: i64) -> String { + format!("{value:b}") +} diff --git a/datafusion/spark/src/function/math/hex.rs b/datafusion/spark/src/function/math/hex.rs index ef62b08fb03d..06c77f37021b 100644 --- a/datafusion/spark/src/function/math/hex.rs +++ b/datafusion/spark/src/function/math/hex.rs @@ -16,9 +16,10 @@ // under the License. use std::any::Any; +use std::str::from_utf8_unchecked; use std::sync::Arc; -use arrow::array::{Array, StringArray}; +use arrow::array::{Array, ArrayRef, StringBuilder}; use arrow::datatypes::DataType; use arrow::{ array::{as_dictionary_array, as_largestring_array, as_string_array}, @@ -91,11 +92,13 @@ impl ScalarUDFImpl for SparkHex { &self.signature } - fn return_type( - &self, - _arg_types: &[DataType], - ) -> datafusion_common::Result { - Ok(DataType::Utf8) + fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + Ok(match &arg_types[0] { + DataType::Dictionary(key_type, _) => { + DataType::Dictionary(key_type.clone(), Box::new(DataType::Utf8)) + } + _ => DataType::Utf8, + }) } fn invoke_with_args( @@ -110,37 +113,85 @@ impl ScalarUDFImpl for SparkHex { } } -fn hex_int64(num: i64) -> String { - format!("{num:X}") -} - /// Hex encoding lookup tables for fast byte-to-hex conversion const HEX_CHARS_LOWER: &[u8; 16] = b"0123456789abcdef"; const HEX_CHARS_UPPER: &[u8; 16] = b"0123456789ABCDEF"; #[inline] -fn hex_encode>(data: T, lower_case: bool) -> String { - let bytes = data.as_ref(); - let mut s = 
String::with_capacity(bytes.len() * 2); - let hex_chars = if lower_case { +fn hex_int64(num: i64, buffer: &mut [u8; 16]) -> &[u8] { + if num == 0 { + return b"0"; + } + + let mut n = num as u64; + let mut i = 16; + while n != 0 { + i -= 1; + buffer[i] = HEX_CHARS_UPPER[(n & 0xF) as usize]; + n >>= 4; + } + &buffer[i..] +} + +/// Generic hex encoding for byte array types +fn hex_encode_bytes<'a, I, T>( + iter: I, + lowercase: bool, + len: usize, +) -> Result +where + I: Iterator>, + T: AsRef<[u8]> + 'a, +{ + let mut builder = StringBuilder::with_capacity(len, len * 64); + let mut buffer = Vec::with_capacity(64); + let hex_chars = if lowercase { HEX_CHARS_LOWER } else { HEX_CHARS_UPPER }; - for &b in bytes { - s.push(hex_chars[(b >> 4) as usize] as char); - s.push(hex_chars[(b & 0x0f) as usize] as char); + + for v in iter { + if let Some(b) = v { + buffer.clear(); + let bytes = b.as_ref(); + for &byte in bytes { + buffer.push(hex_chars[(byte >> 4) as usize]); + buffer.push(hex_chars[(byte & 0x0f) as usize]); + } + // SAFETY: buffer contains only ASCII hex digests, which are valid UTF-8 + unsafe { + builder.append_value(from_utf8_unchecked(&buffer)); + } + } else { + builder.append_null(); + } } - s + + Ok(Arc::new(builder.finish())) } -#[inline(always)] -fn hex_bytes>( - bytes: T, - lowercase: bool, -) -> Result { - let hex_string = hex_encode(bytes, lowercase); - Ok(hex_string) +/// Generic hex encoding for int64 type +fn hex_encode_int64( + iter: impl Iterator>, + len: usize, +) -> Result { + let mut builder = StringBuilder::with_capacity(len, len * 16); + + for v in iter { + if let Some(num) = v { + let mut temp = [0u8; 16]; + let slice = hex_int64(num, &mut temp); + // SAFETY: slice contains only ASCII hex digests, which are valid UTF-8 + unsafe { + builder.append_value(from_utf8_unchecked(slice)); + } + } else { + builder.append_null(); + } + } + + Ok(Arc::new(builder.finish())) } /// Spark-compatible `hex` function @@ -166,103 +217,109 @@ pub fn compute_hex( 
ColumnarValue::Array(array) => match array.data_type() { DataType::Int64 => { let array = as_int64_array(array)?; - - let hexed_array: StringArray = - array.iter().map(|v| v.map(hex_int64)).collect(); - - Ok(ColumnarValue::Array(Arc::new(hexed_array))) + Ok(ColumnarValue::Array(hex_encode_int64( + array.iter(), + array.len(), + )?)) } DataType::Utf8 => { let array = as_string_array(array); - - let hexed: StringArray = array - .iter() - .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) - .collect::>()?; - - Ok(ColumnarValue::Array(Arc::new(hexed))) + Ok(ColumnarValue::Array(hex_encode_bytes( + array.iter(), + lowercase, + array.len(), + )?)) } DataType::Utf8View => { let array = as_string_view_array(array)?; - - let hexed: StringArray = array - .iter() - .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) - .collect::>()?; - - Ok(ColumnarValue::Array(Arc::new(hexed))) + Ok(ColumnarValue::Array(hex_encode_bytes( + array.iter(), + lowercase, + array.len(), + )?)) } DataType::LargeUtf8 => { let array = as_largestring_array(array); - - let hexed: StringArray = array - .iter() - .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) - .collect::>()?; - - Ok(ColumnarValue::Array(Arc::new(hexed))) + Ok(ColumnarValue::Array(hex_encode_bytes( + array.iter(), + lowercase, + array.len(), + )?)) } DataType::Binary => { let array = as_binary_array(array)?; - - let hexed: StringArray = array - .iter() - .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) - .collect::>()?; - - Ok(ColumnarValue::Array(Arc::new(hexed))) + Ok(ColumnarValue::Array(hex_encode_bytes( + array.iter(), + lowercase, + array.len(), + )?)) } DataType::LargeBinary => { let array = as_large_binary_array(array)?; - - let hexed: StringArray = array - .iter() - .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) - .collect::>()?; - - Ok(ColumnarValue::Array(Arc::new(hexed))) + Ok(ColumnarValue::Array(hex_encode_bytes( + array.iter(), + lowercase, + array.len(), + )?)) } 
DataType::FixedSizeBinary(_) => { let array = as_fixed_size_binary_array(array)?; - - let hexed: StringArray = array - .iter() - .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) - .collect::>()?; - - Ok(ColumnarValue::Array(Arc::new(hexed))) + Ok(ColumnarValue::Array(hex_encode_bytes( + array.iter(), + lowercase, + array.len(), + )?)) } - DataType::Dictionary(_, value_type) => { - let dict = as_dictionary_array::(&array); + DataType::Dictionary(key_type, _) => { + if **key_type != DataType::Int32 { + return exec_err!( + "hex only supports Int32 dictionary keys, get: {}", + key_type + ); + } - let values = match **value_type { - DataType::Int64 => as_int64_array(dict.values())? - .iter() - .map(|v| v.map(hex_int64)) - .collect::>(), - DataType::Utf8 => as_string_array(dict.values()) - .iter() - .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) - .collect::>()?, - DataType::Binary => as_binary_array(dict.values())? - .iter() - .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) - .collect::>()?, - _ => exec_err!( - "hex got an unexpected argument type: {}", - array.data_type() - )?, + let dict = as_dictionary_array::(&array); + let dict_values = dict.values(); + + let encoded_values = match dict_values.data_type() { + DataType::Int64 => { + let arr = as_int64_array(dict_values)?; + hex_encode_int64(arr.iter(), arr.len())? + } + DataType::Utf8 => { + let arr = as_string_array(dict_values); + hex_encode_bytes(arr.iter(), lowercase, arr.len())? + } + DataType::LargeUtf8 => { + let arr = as_largestring_array(dict_values); + hex_encode_bytes(arr.iter(), lowercase, arr.len())? + } + DataType::Utf8View => { + let arr = as_string_view_array(dict_values)?; + hex_encode_bytes(arr.iter(), lowercase, arr.len())? + } + DataType::Binary => { + let arr = as_binary_array(dict_values)?; + hex_encode_bytes(arr.iter(), lowercase, arr.len())? 
+ } + DataType::LargeBinary => { + let arr = as_large_binary_array(dict_values)?; + hex_encode_bytes(arr.iter(), lowercase, arr.len())? + } + DataType::FixedSizeBinary(_) => { + let arr = as_fixed_size_binary_array(dict_values)?; + hex_encode_bytes(arr.iter(), lowercase, arr.len())? + } + _ => { + return exec_err!( + "hex got an unexpected argument type: {}", + dict_values.data_type() + ); + } }; - let new_values: Vec> = dict - .keys() - .iter() - .map(|key| key.map(|k| values[k as usize].clone()).unwrap_or(None)) - .collect(); - - let string_array_values = StringArray::from(new_values); - - Ok(ColumnarValue::Array(Arc::new(string_array_values))) + let new_dict = dict.with_values(encoded_values); + Ok(ColumnarValue::Array(Arc::new(new_dict))) } _ => exec_err!("hex got an unexpected argument type: {}", array.data_type()), }, @@ -272,16 +329,18 @@ pub fn compute_hex( #[cfg(test)] mod test { + use std::str::from_utf8_unchecked; use std::sync::Arc; - use arrow::array::{Int64Array, StringArray}; + use arrow::array::{DictionaryArray, Int32Array, Int64Array, StringArray}; use arrow::{ array::{ - BinaryDictionaryBuilder, PrimitiveDictionaryBuilder, StringBuilder, - StringDictionaryBuilder, as_string_array, + BinaryDictionaryBuilder, PrimitiveDictionaryBuilder, StringDictionaryBuilder, + as_string_array, }, datatypes::{Int32Type, Int64Type}, }; + use datafusion_common::cast::as_dictionary_array; use datafusion_expr::ColumnarValue; #[test] @@ -293,12 +352,12 @@ mod test { input_builder.append_value("rust"); let input = input_builder.finish(); - let mut string_builder = StringBuilder::new(); - string_builder.append_value("6869"); - string_builder.append_value("627965"); - string_builder.append_null(); - string_builder.append_value("72757374"); - let expected = string_builder.finish(); + let mut expected_builder = StringDictionaryBuilder::::new(); + expected_builder.append_value("6869"); + expected_builder.append_value("627965"); + expected_builder.append_null(); + 
expected_builder.append_value("72757374"); + let expected = expected_builder.finish(); let columnar_value = ColumnarValue::Array(Arc::new(input)); let result = super::spark_hex(&[columnar_value]).unwrap(); @@ -308,7 +367,7 @@ mod test { _ => panic!("Expected array"), }; - let result = as_string_array(&result); + let result = as_dictionary_array(&result).unwrap(); assert_eq!(result, &expected); } @@ -322,12 +381,12 @@ mod test { input_builder.append_value(3); let input = input_builder.finish(); - let mut string_builder = StringBuilder::new(); - string_builder.append_value("1"); - string_builder.append_value("2"); - string_builder.append_null(); - string_builder.append_value("3"); - let expected = string_builder.finish(); + let mut expected_builder = StringDictionaryBuilder::::new(); + expected_builder.append_value("1"); + expected_builder.append_value("2"); + expected_builder.append_null(); + expected_builder.append_value("3"); + let expected = expected_builder.finish(); let columnar_value = ColumnarValue::Array(Arc::new(input)); let result = super::spark_hex(&[columnar_value]).unwrap(); @@ -337,7 +396,7 @@ mod test { _ => panic!("Expected array"), }; - let result = as_string_array(&result); + let result = as_dictionary_array(&result).unwrap(); assert_eq!(result, &expected); } @@ -351,7 +410,7 @@ mod test { input_builder.append_value("3"); let input = input_builder.finish(); - let mut expected_builder = StringBuilder::new(); + let mut expected_builder = StringDictionaryBuilder::::new(); expected_builder.append_value("31"); expected_builder.append_value("6A"); expected_builder.append_null(); @@ -366,20 +425,24 @@ mod test { _ => panic!("Expected array"), }; - let result = as_string_array(&result); + let result = as_dictionary_array(&result).unwrap(); assert_eq!(result, &expected); } #[test] fn test_hex_int64() { - let num = 1234; - let hexed = super::hex_int64(num); - assert_eq!(hexed, "4D2".to_string()); + let test_cases = vec![(1234, "4D2"), (-1, 
"FFFFFFFFFFFFFFFF")]; + + for (num, expected) in test_cases { + let mut cache = [0u8; 16]; + let slice = super::hex_int64(num, &mut cache); - let num = -1; - let hexed = super::hex_int64(num); - assert_eq!(hexed, "FFFFFFFFFFFFFFFF".to_string()); + unsafe { + let result = from_utf8_unchecked(slice); + assert_eq!(expected, result); + } + } } #[test] @@ -403,4 +466,28 @@ mod test { assert_eq!(string_array, &expected_array); } + + #[test] + fn test_dict_values_null() { + let keys = Int32Array::from(vec![Some(0), None, Some(1)]); + let vals = Int64Array::from(vec![Some(32), None]); + // [32, null, null] + let dict = DictionaryArray::new(keys, Arc::new(vals)); + + let columnar_value = ColumnarValue::Array(Arc::new(dict)); + let result = super::spark_hex(&[columnar_value]).unwrap(); + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array"), + }; + + let result = as_dictionary_array(&result).unwrap(); + + let keys = Int32Array::from(vec![Some(0), None, Some(1)]); + let vals = StringArray::from(vec![Some("20"), None]); + let expected = DictionaryArray::new(keys, Arc::new(vals)); + + assert_eq!(&expected, result); + } } diff --git a/datafusion/spark/src/function/math/mod.rs b/datafusion/spark/src/function/math/mod.rs index 1422eb250d93..7f7d04e06b0b 100644 --- a/datafusion/spark/src/function/math/mod.rs +++ b/datafusion/spark/src/function/math/mod.rs @@ -16,12 +16,15 @@ // under the License. 
pub mod abs; +pub mod bin; pub mod expm1; pub mod factorial; pub mod hex; pub mod modulus; +pub mod negative; pub mod rint; pub mod trigonometry; +pub mod unhex; pub mod width_bucket; use datafusion_expr::ScalarUDF; @@ -35,9 +38,12 @@ make_udf_function!(hex::SparkHex, hex); make_udf_function!(modulus::SparkMod, modulus); make_udf_function!(modulus::SparkPmod, pmod); make_udf_function!(rint::SparkRint, rint); +make_udf_function!(unhex::SparkUnhex, unhex); make_udf_function!(width_bucket::SparkWidthBucket, width_bucket); make_udf_function!(trigonometry::SparkCsc, csc); make_udf_function!(trigonometry::SparkSec, sec); +make_udf_function!(negative::SparkNegative, negative); +make_udf_function!(bin::SparkBin, bin); pub mod expr_fn { use datafusion_functions::export_functions; @@ -57,9 +63,20 @@ pub mod expr_fn { "Returns the double value that is closest in value to the argument and is equal to a mathematical integer.", arg1 )); + export_functions!((unhex, "Converts hexadecimal string to binary.", arg1)); export_functions!((width_bucket, "Returns the bucket number into which the value of this expression would fall after being evaluated.", arg1 arg2 arg3 arg4)); export_functions!((csc, "Returns the cosecant of expr.", arg1)); export_functions!((sec, "Returns the secant of expr.", arg1)); + export_functions!(( + negative, + "Returns the negation of expr (unary minus).", + arg1 + )); + export_functions!(( + bin, + "Returns the string representation of the long value represented in binary.", + arg1 + )); } pub fn functions() -> Vec> { @@ -71,8 +88,11 @@ pub fn functions() -> Vec> { modulus(), pmod(), rint(), + unhex(), width_bucket(), csc(), sec(), + negative(), + bin(), ] } diff --git a/datafusion/spark/src/function/math/modulus.rs b/datafusion/spark/src/function/math/modulus.rs index 49657e2cb8ce..7a21aabbdf85 100644 --- a/datafusion/spark/src/function/math/modulus.rs +++ b/datafusion/spark/src/function/math/modulus.rs @@ -15,8 +15,13 @@ // specific language governing 
permissions and limitations // under the License. +use arrow::array::{Scalar, new_null_array}; use arrow::compute::kernels::numeric::add; -use arrow::compute::kernels::{cmp::lt, numeric::rem, zip::zip}; +use arrow::compute::kernels::{ + cmp::{eq, lt}, + numeric::rem, + zip::zip, +}; use arrow::datatypes::DataType; use datafusion_common::{Result, ScalarValue, assert_eq_or_internal_err}; use datafusion_expr::{ @@ -24,28 +29,61 @@ use datafusion_expr::{ }; use std::any::Any; +/// Attempts `rem(left, right)` with per-element divide-by-zero handling. +/// In ANSI mode, any zero divisor causes an error. +/// In legacy mode (ANSI off), positions where the divisor is zero return NULL +/// while other positions compute normally. +fn try_rem( + left: &arrow::array::ArrayRef, + right: &arrow::array::ArrayRef, + enable_ansi_mode: bool, +) -> Result { + match rem(left, right) { + Ok(result) => Ok(result), + Err(arrow::error::ArrowError::DivideByZero) if !enable_ansi_mode => { + // Integer rem fails when ANY divisor element is zero. + // Handle per-element: null out zero divisors + let zero = ScalarValue::new_zero(right.data_type())?.to_array()?; + let zero = Scalar::new(zero); + let null = Scalar::new(new_null_array(right.data_type(), 1)); + let is_zero = eq(right, &zero)?; + let safe_right = zip(&is_zero, &null, right)?; + Ok(rem(left, &safe_right)?) + } + Err(e) => Err(e.into()), + } +} + /// Spark-compatible `mod` function -/// This function directly uses Arrow's arithmetic_op function for modulo operations -pub fn spark_mod(args: &[ColumnarValue]) -> Result { +/// In ANSI mode, division by zero throws an error. +/// In legacy mode, division by zero returns NULL (Spark behavior). 
+pub fn spark_mod( + args: &[ColumnarValue], + enable_ansi_mode: bool, +) -> Result { assert_eq_or_internal_err!(args.len(), 2, "mod expects exactly two arguments"); let args = ColumnarValue::values_to_arrays(args)?; - let result = rem(&args[0], &args[1])?; + let result = try_rem(&args[0], &args[1], enable_ansi_mode)?; Ok(ColumnarValue::Array(result)) } /// Spark-compatible `pmod` function -/// This function directly uses Arrow's arithmetic_op function for modulo operations -pub fn spark_pmod(args: &[ColumnarValue]) -> Result { +/// In ANSI mode, division by zero throws an error. +/// In legacy mode, division by zero returns NULL (Spark behavior). +pub fn spark_pmod( + args: &[ColumnarValue], + enable_ansi_mode: bool, +) -> Result { assert_eq_or_internal_err!(args.len(), 2, "pmod expects exactly two arguments"); let args = ColumnarValue::values_to_arrays(args)?; let left = &args[0]; let right = &args[1]; let zero = ScalarValue::new_zero(left.data_type())?.to_array_of_size(left.len())?; - let result = rem(left, right)?; + let result = try_rem(left, right, enable_ansi_mode)?; let neg = lt(&result, &zero)?; let plus = zip(&neg, right, &zero)?; let result = add(&plus, &result)?; - let result = rem(&result, right)?; + let result = try_rem(&result, right, enable_ansi_mode)?; Ok(ColumnarValue::Array(result)) } @@ -95,7 +133,7 @@ impl ScalarUDFImpl for SparkMod { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - spark_mod(&args.args) + spark_mod(&args.args, args.config_options.execution.enable_ansi_mode) } } @@ -145,7 +183,7 @@ impl ScalarUDFImpl for SparkPmod { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - spark_pmod(&args.args) + spark_pmod(&args.args, args.config_options.execution.enable_ansi_mode) } } @@ -165,7 +203,7 @@ mod test { let left_value = ColumnarValue::Array(Arc::new(left)); let right_value = ColumnarValue::Array(Arc::new(right)); - let result = spark_mod(&[left_value, right_value]).unwrap(); + let result = 
spark_mod(&[left_value, right_value], false).unwrap(); if let ColumnarValue::Array(result_array) = result { let result_int32 = @@ -187,7 +225,7 @@ mod test { let left_value = ColumnarValue::Array(Arc::new(left)); let right_value = ColumnarValue::Array(Arc::new(right)); - let result = spark_mod(&[left_value, right_value]).unwrap(); + let result = spark_mod(&[left_value, right_value], false).unwrap(); if let ColumnarValue::Array(result_array) = result { let result_int64 = @@ -228,7 +266,7 @@ mod test { let left_value = ColumnarValue::Array(Arc::new(left)); let right_value = ColumnarValue::Array(Arc::new(right)); - let result = spark_mod(&[left_value, right_value]).unwrap(); + let result = spark_mod(&[left_value, right_value], false).unwrap(); if let ColumnarValue::Array(result_array) = result { let result_float64 = result_array @@ -284,7 +322,7 @@ mod test { let left_value = ColumnarValue::Array(Arc::new(left)); let right_value = ColumnarValue::Array(Arc::new(right)); - let result = spark_mod(&[left_value, right_value]).unwrap(); + let result = spark_mod(&[left_value, right_value], false).unwrap(); if let ColumnarValue::Array(result_array) = result { let result_float32 = result_array @@ -319,7 +357,7 @@ mod test { let left_value = ColumnarValue::Array(Arc::new(left)); - let result = spark_mod(&[left_value, right_value]).unwrap(); + let result = spark_mod(&[left_value, right_value], false).unwrap(); if let ColumnarValue::Array(result_array) = result { let result_int32 = @@ -337,20 +375,43 @@ mod test { let left = Int32Array::from(vec![Some(10)]); let left_value = ColumnarValue::Array(Arc::new(left)); - let result = spark_mod(&[left_value]); + let result = spark_mod(&[left_value], false); assert!(result.is_err()); } #[test] - fn test_mod_zero_division() { + fn test_mod_zero_division_legacy() { + // In legacy mode (ANSI off), division by zero returns NULL per-element + let left = Int32Array::from(vec![Some(10), Some(7), Some(15)]); + let right = 
Int32Array::from(vec![Some(0), Some(2), Some(4)]); + + let left_value = ColumnarValue::Array(Arc::new(left)); + let right_value = ColumnarValue::Array(Arc::new(right)); + + let result = spark_mod(&[left_value, right_value], false).unwrap(); + + if let ColumnarValue::Array(result_array) = result { + let result_int32 = + result_array.as_any().downcast_ref::().unwrap(); + assert!(result_int32.is_null(0)); // 10 % 0 = NULL + assert_eq!(result_int32.value(1), 1); // 7 % 2 = 1 + assert_eq!(result_int32.value(2), 3); // 15 % 4 = 3 + } else { + panic!("Expected array result"); + } + } + + #[test] + fn test_mod_zero_division_ansi() { + // In ANSI mode, division by zero should error let left = Int32Array::from(vec![Some(10), Some(7), Some(15)]); let right = Int32Array::from(vec![Some(0), Some(2), Some(4)]); let left_value = ColumnarValue::Array(Arc::new(left)); let right_value = ColumnarValue::Array(Arc::new(right)); - let result = spark_mod(&[left_value, right_value]); - assert!(result.is_err()); // Division by zero should error + let result = spark_mod(&[left_value, right_value], true); + assert!(result.is_err()); } // PMOD tests @@ -362,7 +423,7 @@ mod test { let left_value = ColumnarValue::Array(Arc::new(left)); let right_value = ColumnarValue::Array(Arc::new(right)); - let result = spark_pmod(&[left_value, right_value]).unwrap(); + let result = spark_pmod(&[left_value, right_value], false).unwrap(); if let ColumnarValue::Array(result_array) = result { let result_int32 = @@ -385,7 +446,7 @@ mod test { let left_value = ColumnarValue::Array(Arc::new(left)); let right_value = ColumnarValue::Array(Arc::new(right)); - let result = spark_pmod(&[left_value, right_value]).unwrap(); + let result = spark_pmod(&[left_value, right_value], false).unwrap(); if let ColumnarValue::Array(result_array) = result { let result_int64 = @@ -425,7 +486,7 @@ mod test { let left_value = ColumnarValue::Array(Arc::new(left)); let right_value = ColumnarValue::Array(Arc::new(right)); - let result = 
spark_pmod(&[left_value, right_value]).unwrap(); + let result = spark_pmod(&[left_value, right_value], false).unwrap(); if let ColumnarValue::Array(result_array) = result { let result_float64 = result_array @@ -476,7 +537,7 @@ mod test { let left_value = ColumnarValue::Array(Arc::new(left)); let right_value = ColumnarValue::Array(Arc::new(right)); - let result = spark_pmod(&[left_value, right_value]).unwrap(); + let result = spark_pmod(&[left_value, right_value], false).unwrap(); if let ColumnarValue::Array(result_array) = result { let result_float32 = result_array @@ -508,7 +569,7 @@ mod test { let left_value = ColumnarValue::Array(Arc::new(left)); - let result = spark_pmod(&[left_value, right_value]).unwrap(); + let result = spark_pmod(&[left_value, right_value], false).unwrap(); if let ColumnarValue::Array(result_array) = result { let result_int32 = @@ -527,20 +588,43 @@ mod test { let left = Int32Array::from(vec![Some(10)]); let left_value = ColumnarValue::Array(Arc::new(left)); - let result = spark_pmod(&[left_value]); + let result = spark_pmod(&[left_value], false); assert!(result.is_err()); } #[test] - fn test_pmod_zero_division() { + fn test_pmod_zero_division_legacy() { + // In legacy mode (ANSI off), division by zero returns NULL per-element let left = Int32Array::from(vec![Some(10), Some(-7), Some(15)]); let right = Int32Array::from(vec![Some(0), Some(0), Some(4)]); let left_value = ColumnarValue::Array(Arc::new(left)); let right_value = ColumnarValue::Array(Arc::new(right)); - let result = spark_pmod(&[left_value, right_value]); - assert!(result.is_err()); // Division by zero should error + let result = spark_pmod(&[left_value, right_value], false).unwrap(); + + if let ColumnarValue::Array(result_array) = result { + let result_int32 = + result_array.as_any().downcast_ref::().unwrap(); + assert!(result_int32.is_null(0)); // 10 pmod 0 = NULL + assert!(result_int32.is_null(1)); // -7 pmod 0 = NULL + assert_eq!(result_int32.value(2), 3); // 15 pmod 4 = 3 + 
} else { + panic!("Expected array result"); + } + } + + #[test] + fn test_pmod_zero_division_ansi() { + // In ANSI mode, division by zero should error + let left = Int32Array::from(vec![Some(10), Some(-7), Some(15)]); + let right = Int32Array::from(vec![Some(0), Some(0), Some(4)]); + + let left_value = ColumnarValue::Array(Arc::new(left)); + let right_value = ColumnarValue::Array(Arc::new(right)); + + let result = spark_pmod(&[left_value, right_value], true); + assert!(result.is_err()); } #[test] @@ -552,7 +636,7 @@ mod test { let left_value = ColumnarValue::Array(Arc::new(left)); let right_value = ColumnarValue::Array(Arc::new(right)); - let result = spark_pmod(&[left_value, right_value]).unwrap(); + let result = spark_pmod(&[left_value, right_value], false).unwrap(); if let ColumnarValue::Array(result_array) = result { let result_int32 = @@ -590,7 +674,7 @@ mod test { let left_value = ColumnarValue::Array(Arc::new(left)); let right_value = ColumnarValue::Array(Arc::new(right)); - let result = spark_pmod(&[left_value, right_value]).unwrap(); + let result = spark_pmod(&[left_value, right_value], false).unwrap(); if let ColumnarValue::Array(result_array) = result { let result_int32 = diff --git a/datafusion/spark/src/function/math/negative.rs b/datafusion/spark/src/function/math/negative.rs new file mode 100644 index 000000000000..2df71b709d8c --- /dev/null +++ b/datafusion/spark/src/function/math/negative.rs @@ -0,0 +1,477 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::types::*; +use arrow::array::*; +use arrow::datatypes::{DataType, IntervalDayTime, IntervalMonthDayNano, IntervalUnit}; +use bigdecimal::num_traits::WrappingNeg; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, ScalarValue, exec_err, not_impl_err}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + Volatility, +}; +use std::any::Any; +use std::sync::Arc; + +/// Spark-compatible `negative` expression +/// +/// +/// Returns the negation of input (equivalent to unary minus) +/// Returns NULL if input is NULL, returns NaN if input is NaN. +/// +/// ANSI mode support: +/// - When ANSI mode is disabled (`spark.sql.ansi.enabled=false`), negating the minimal +/// value of a signed integer wraps around. For example: negative(i32::MIN) returns +/// i32::MIN (wraps instead of error). +/// - When ANSI mode is enabled (`spark.sql.ansi.enabled=true`), overflow conditions +/// throw an ARITHMETIC_OVERFLOW error instead of wrapping. 
+/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkNegative { + signature: Signature, +} + +impl Default for SparkNegative { + fn default() -> Self { + Self::new() + } +} + +impl SparkNegative { + pub fn new() -> Self { + Self { + signature: Signature { + type_signature: TypeSignature::OneOf(vec![ + // Numeric types: signed integers, float, decimals + TypeSignature::Numeric(1), + // Interval types: YearMonth, DayTime, MonthDayNano + TypeSignature::Uniform( + 1, + vec![ + DataType::Interval(IntervalUnit::YearMonth), + DataType::Interval(IntervalUnit::DayTime), + DataType::Interval(IntervalUnit::MonthDayNano), + ], + ), + ]), + volatility: Volatility::Immutable, + parameter_names: None, + }, + } + } +} + +impl ScalarUDFImpl for SparkNegative { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "negative" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + spark_negative(&args.args, args.config_options.execution.enable_ansi_mode) + } +} + +/// Macro to implement negation for integer array types +macro_rules! impl_integer_array_negative { + ($array:expr, $type:ty, $type_name:expr, $enable_ansi_mode:expr) => {{ + let array = $array.as_primitive::<$type>(); + let result: PrimitiveArray<$type> = if $enable_ansi_mode { + array.try_unary(|x| { + x.checked_neg().ok_or_else(|| { + (exec_err!("{} overflow on negative({x})", $type_name) + as Result<(), _>) + .unwrap_err() + }) + })? + } else { + array.unary(|x| x.wrapping_neg()) + }; + Ok(ColumnarValue::Array(Arc::new(result))) + }}; +} + +/// Macro to implement negation for float array types +macro_rules! 
impl_float_array_negative { + ($array:expr, $type:ty) => {{ + let array = $array.as_primitive::<$type>(); + let result: PrimitiveArray<$type> = array.unary(|x| -x); + Ok(ColumnarValue::Array(Arc::new(result))) + }}; +} + +/// Macro to implement negation for decimal array types +macro_rules! impl_decimal_array_negative { + ($array:expr, $type:ty, $type_name:expr, $enable_ansi_mode:expr) => {{ + let array = $array.as_primitive::<$type>(); + let result: PrimitiveArray<$type> = if $enable_ansi_mode { + array + .try_unary(|x| { + x.checked_neg().ok_or_else(|| { + (exec_err!("{} overflow on negative({x})", $type_name) + as Result<(), _>) + .unwrap_err() + }) + })? + .with_data_type(array.data_type().clone()) + } else { + array.unary(|x| x.wrapping_neg()) + }; + Ok(ColumnarValue::Array(Arc::new(result))) + }}; +} + +/// Macro to implement negation for integer scalar types +macro_rules! impl_integer_scalar_negative { + ($v:expr, $type_name:expr, $variant:ident, $enable_ansi_mode:expr) => {{ + let result = if $enable_ansi_mode { + $v.checked_neg().ok_or_else(|| { + (exec_err!("{} overflow on negative({})", $type_name, $v) + as Result<(), _>) + .unwrap_err() + })? + } else { + $v.wrapping_neg() + }; + Ok(ColumnarValue::Scalar(ScalarValue::$variant(Some(result)))) + }}; +} + +/// Macro to implement negation for decimal scalar types +macro_rules! impl_decimal_scalar_negative { + ($v:expr, $precision:expr, $scale:expr, $type_name:expr, $variant:ident, $enable_ansi_mode:expr) => {{ + let result = if $enable_ansi_mode { + $v.checked_neg().ok_or_else(|| { + (exec_err!("{} overflow on negative({})", $type_name, $v) + as Result<(), _>) + .unwrap_err() + })? 
+ } else { + $v.wrapping_neg() + }; + Ok(ColumnarValue::Scalar(ScalarValue::$variant( + Some(result), + *$precision, + *$scale, + ))) + }}; +} + +/// Core implementation of Spark's negative function +fn spark_negative( + args: &[ColumnarValue], + enable_ansi_mode: bool, +) -> Result { + let [arg] = take_function_args("negative", args)?; + + match arg { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Null => Ok(arg.clone()), + + // Signed integers - use checked negation in ANSI mode, wrapping in legacy mode + DataType::Int8 => { + impl_integer_array_negative!(array, Int8Type, "Int8", enable_ansi_mode) + } + DataType::Int16 => { + impl_integer_array_negative!(array, Int16Type, "Int16", enable_ansi_mode) + } + DataType::Int32 => { + impl_integer_array_negative!(array, Int32Type, "Int32", enable_ansi_mode) + } + DataType::Int64 => { + impl_integer_array_negative!(array, Int64Type, "Int64", enable_ansi_mode) + } + + // Floating point - simple negation (no overflow possible) + DataType::Float16 => impl_float_array_negative!(array, Float16Type), + DataType::Float32 => impl_float_array_negative!(array, Float32Type), + DataType::Float64 => impl_float_array_negative!(array, Float64Type), + + // Decimal types - use checked negation in ANSI mode, wrapping in legacy mode + DataType::Decimal32(_, _) => impl_decimal_array_negative!( + array, + Decimal32Type, + "Decimal32", + enable_ansi_mode + ), + DataType::Decimal64(_, _) => impl_decimal_array_negative!( + array, + Decimal64Type, + "Decimal64", + enable_ansi_mode + ), + DataType::Decimal128(_, _) => impl_decimal_array_negative!( + array, + Decimal128Type, + "Decimal128", + enable_ansi_mode + ), + DataType::Decimal256(_, _) => impl_decimal_array_negative!( + array, + Decimal256Type, + "Decimal256", + enable_ansi_mode + ), + + // interval type - use checked negation in ANSI mode, wrapping in legacy mode + DataType::Interval(IntervalUnit::YearMonth) => { + impl_integer_array_negative!( + array, + 
IntervalYearMonthType, + "IntervalYearMonth", + enable_ansi_mode + ) + } + DataType::Interval(IntervalUnit::DayTime) => { + let array = array.as_primitive::(); + let result: PrimitiveArray = if enable_ansi_mode { + array.try_unary(|x| { + let days = x.days.checked_neg().ok_or_else(|| { + (exec_err!( + "IntervalDayTime overflow on negative (days: {})", + x.days + ) as Result<(), _>) + .unwrap_err() + })?; + let milliseconds = + x.milliseconds.checked_neg().ok_or_else(|| { + (exec_err!( + "IntervalDayTime overflow on negative (milliseconds: {})", + x.milliseconds + ) as Result<(), _>) + .unwrap_err() + })?; + Ok::<_, arrow::error::ArrowError>(IntervalDayTime { + days, + milliseconds, + }) + })? + } else { + array.unary(|x| IntervalDayTime { + days: x.days.wrapping_neg(), + milliseconds: x.milliseconds.wrapping_neg(), + }) + }; + Ok(ColumnarValue::Array(Arc::new(result))) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + let array = array.as_primitive::(); + let result: PrimitiveArray = if enable_ansi_mode + { + array.try_unary(|x| { + let months = x.months.checked_neg().ok_or_else(|| { + (exec_err!( + "IntervalMonthDayNano overflow on negative (months: {})", + x.months + ) as Result<(), _>) + .unwrap_err() + })?; + let days = x.days.checked_neg().ok_or_else(|| { + (exec_err!( + "IntervalMonthDayNano overflow on negative (days: {})", + x.days + ) as Result<(), _>) + .unwrap_err() + })?; + let nanoseconds = x.nanoseconds.checked_neg().ok_or_else(|| { + (exec_err!( + "IntervalMonthDayNano overflow on negative (nanoseconds: {})", + x.nanoseconds + ) as Result<(), _>) + .unwrap_err() + })?; + Ok::<_, arrow::error::ArrowError>(IntervalMonthDayNano { + months, + days, + nanoseconds, + }) + })? 
+ } else { + array.unary(|x| IntervalMonthDayNano { + months: x.months.wrapping_neg(), + days: x.days.wrapping_neg(), + nanoseconds: x.nanoseconds.wrapping_neg(), + }) + }; + Ok(ColumnarValue::Array(Arc::new(result))) + } + + dt => not_impl_err!("Not supported datatype for Spark negative(): {dt}"), + }, + ColumnarValue::Scalar(sv) => match sv { + ScalarValue::Null => Ok(arg.clone()), + _ if sv.is_null() => Ok(arg.clone()), + + // Signed integers - use checked negation in ANSI mode, wrapping in legacy mode + ScalarValue::Int8(Some(v)) => { + impl_integer_scalar_negative!(v, "Int8", Int8, enable_ansi_mode) + } + ScalarValue::Int16(Some(v)) => { + impl_integer_scalar_negative!(v, "Int16", Int16, enable_ansi_mode) + } + ScalarValue::Int32(Some(v)) => { + impl_integer_scalar_negative!(v, "Int32", Int32, enable_ansi_mode) + } + ScalarValue::Int64(Some(v)) => { + impl_integer_scalar_negative!(v, "Int64", Int64, enable_ansi_mode) + } + + // Floating point - simple negation + ScalarValue::Float16(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Float16(Some(-v)))) + } + ScalarValue::Float32(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(-v)))) + } + ScalarValue::Float64(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(-v)))) + } + + // Decimal types - use checked negation in ANSI mode, wrapping in legacy mode + ScalarValue::Decimal32(Some(v), precision, scale) => { + impl_decimal_scalar_negative!( + v, + precision, + scale, + "Decimal32", + Decimal32, + enable_ansi_mode + ) + } + ScalarValue::Decimal64(Some(v), precision, scale) => { + impl_decimal_scalar_negative!( + v, + precision, + scale, + "Decimal64", + Decimal64, + enable_ansi_mode + ) + } + ScalarValue::Decimal128(Some(v), precision, scale) => { + impl_decimal_scalar_negative!( + v, + precision, + scale, + "Decimal128", + Decimal128, + enable_ansi_mode + ) + } + ScalarValue::Decimal256(Some(v), precision, scale) => { + impl_decimal_scalar_negative!( + v, + precision, + 
scale, + "Decimal256", + Decimal256, + enable_ansi_mode + ) + } + + //interval type - use checked negation in ANSI mode, wrapping in legacy mode + ScalarValue::IntervalYearMonth(Some(v)) => { + impl_integer_scalar_negative!( + v, + "IntervalYearMonth", + IntervalYearMonth, + enable_ansi_mode + ) + } + ScalarValue::IntervalDayTime(Some(v)) => { + let result = if enable_ansi_mode { + let days = v.days.checked_neg().ok_or_else(|| { + (exec_err!( + "IntervalDayTime overflow on negative (days: {})", + v.days + ) as Result<(), _>) + .unwrap_err() + })?; + let milliseconds = v.milliseconds.checked_neg().ok_or_else(|| { + (exec_err!( + "IntervalDayTime overflow on negative (milliseconds: {})", + v.milliseconds + ) as Result<(), _>) + .unwrap_err() + })?; + IntervalDayTime { days, milliseconds } + } else { + IntervalDayTime { + days: v.days.wrapping_neg(), + milliseconds: v.milliseconds.wrapping_neg(), + } + }; + Ok(ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( + result, + )))) + } + ScalarValue::IntervalMonthDayNano(Some(v)) => { + let result = if enable_ansi_mode { + let months = v.months.checked_neg().ok_or_else(|| { + (exec_err!( + "IntervalMonthDayNano overflow on negative (months: {})", + v.months + ) as Result<(), _>) + .unwrap_err() + })?; + let days = v.days.checked_neg().ok_or_else(|| { + (exec_err!( + "IntervalMonthDayNano overflow on negative (days: {})", + v.days + ) as Result<(), _>) + .unwrap_err() + })?; + let nanoseconds = v.nanoseconds.checked_neg().ok_or_else(|| { + (exec_err!( + "IntervalMonthDayNano overflow on negative (nanoseconds: {})", + v.nanoseconds + ) as Result<(), _>) + .unwrap_err() + })?; + IntervalMonthDayNano { + months, + days, + nanoseconds, + } + } else { + IntervalMonthDayNano { + months: v.months.wrapping_neg(), + days: v.days.wrapping_neg(), + nanoseconds: v.nanoseconds.wrapping_neg(), + } + }; + Ok(ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano( + Some(result), + ))) + } + + dt => not_impl_err!("Not supported 
datatype for Spark negative(): {dt}"), + }, + } +} diff --git a/datafusion/spark/src/function/math/unhex.rs b/datafusion/spark/src/function/math/unhex.rs new file mode 100644 index 000000000000..dee532d818f8 --- /dev/null +++ b/datafusion/spark/src/function/math/unhex.rs @@ -0,0 +1,214 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::{Array, ArrayRef, BinaryBuilder}; +use arrow::datatypes::DataType; +use datafusion_common::cast::{ + as_large_string_array, as_string_array, as_string_view_array, +}; +use datafusion_common::types::logical_string; +use datafusion_common::utils::take_function_args; +use datafusion_common::{DataFusionError, Result, ScalarValue, exec_err}; +use datafusion_expr::{ + Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignatureClass, Volatility, +}; +use std::any::Any; +use std::sync::Arc; + +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkUnhex { + signature: Signature, +} + +impl Default for SparkUnhex { + fn default() -> Self { + Self::new() + } +} + +impl SparkUnhex { + pub fn new() -> Self { + let string = Coercion::new_exact(TypeSignatureClass::Native(logical_string())); + + Self { + signature: Signature::coercible(vec![string], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkUnhex { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "unhex" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Binary) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + spark_unhex(&args.args) + } +} + +#[inline] +fn hex_nibble(c: u8) -> Option { + match c { + b'0'..=b'9' => Some(c - b'0'), + b'a'..=b'f' => Some(c - b'a' + 10), + b'A'..=b'F' => Some(c - b'A' + 10), + _ => None, + } +} + +/// Decodes a hex-encoded byte slice into binary data. +/// Returns `true` if decoding succeeded, `false` if the input contains invalid hex characters. +fn unhex_common(bytes: &[u8], out: &mut Vec) -> bool { + if bytes.is_empty() { + return true; + } + + let mut i = 0usize; + + // If the hex string length is odd, implicitly left-pad with '0'. 
+ if (bytes.len() & 1) == 1 { + match hex_nibble(bytes[0]) { + // Equivalent to (0 << 4) | lo + Some(lo) => out.push(lo), + None => return false, + } + i = 1; + } + + while i + 1 < bytes.len() { + match (hex_nibble(bytes[i]), hex_nibble(bytes[i + 1])) { + (Some(hi), Some(lo)) => out.push((hi << 4) | lo), + _ => return false, + } + i += 2; + } + + true +} + +/// Converts an iterator of hex strings to a binary array. +fn unhex_array( + iter: I, + len: usize, + capacity: usize, +) -> Result +where + I: Iterator>, + T: AsRef, +{ + let mut builder = BinaryBuilder::with_capacity(len, capacity); + let mut buffer = Vec::new(); + + for v in iter { + if let Some(s) = v { + buffer.clear(); + buffer.reserve(s.as_ref().len().div_ceil(2)); + if unhex_common(s.as_ref().as_bytes(), &mut buffer) { + builder.append_value(&buffer); + } else { + builder.append_null(); + } + } else { + builder.append_null(); + } + } + + Ok(Arc::new(builder.finish())) +} + +/// Convert a single hex string to binary +fn unhex_scalar(s: &str) -> Option> { + let mut buffer = Vec::with_capacity(s.len().div_ceil(2)); + if unhex_common(s.as_bytes(), &mut buffer) { + Some(buffer) + } else { + None + } +} + +fn spark_unhex(args: &[ColumnarValue]) -> Result { + let [args] = take_function_args("unhex", args)?; + + match args { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Utf8 => { + let array = as_string_array(array)?; + let capacity = array.values().len().div_ceil(2); + Ok(ColumnarValue::Array(unhex_array( + array.iter(), + array.len(), + capacity, + )?)) + } + DataType::Utf8View => { + let array = as_string_view_array(array)?; + // Estimate capacity since StringViewArray data can be scattered or inlined. 
+ let capacity = array.len() * 32; + Ok(ColumnarValue::Array(unhex_array( + array.iter(), + array.len(), + capacity, + )?)) + } + DataType::LargeUtf8 => { + let array = as_large_string_array(array)?; + let capacity = array.values().len().div_ceil(2); + Ok(ColumnarValue::Array(unhex_array( + array.iter(), + array.len(), + capacity, + )?)) + } + _ => exec_err!( + "unhex only supports string argument, but got: {}", + array.data_type() + ), + }, + ColumnarValue::Scalar(sv) => match sv { + ScalarValue::Utf8(None) + | ScalarValue::Utf8View(None) + | ScalarValue::LargeUtf8(None) => { + Ok(ColumnarValue::Scalar(ScalarValue::Binary(None))) + } + ScalarValue::Utf8(Some(s)) + | ScalarValue::Utf8View(Some(s)) + | ScalarValue::LargeUtf8(Some(s)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Binary(unhex_scalar(s)))) + } + _ => { + exec_err!( + "unhex only supports string argument, but got: {}", + sv.data_type() + ) + } + }, + } +} diff --git a/datafusion/spark/src/function/math/width_bucket.rs b/datafusion/spark/src/function/math/width_bucket.rs index 8d748439ad80..905c10819790 100644 --- a/datafusion/spark/src/function/math/width_bucket.rs +++ b/datafusion/spark/src/function/math/width_bucket.rs @@ -26,11 +26,11 @@ use arrow::datatypes::DataType; use arrow::datatypes::DataType::{Duration, Float64, Int32, Interval}; use arrow::datatypes::IntervalUnit::{MonthDayNano, YearMonth}; use datafusion_common::cast::{ - as_duration_microsecond_array, as_float64_array, as_int32_array, + as_duration_microsecond_array, as_float64_array, as_int64_array, as_interval_mdn_array, as_interval_ym_array, }; use datafusion_common::types::{ - NativeType, logical_duration_microsecond, logical_float64, logical_int32, + NativeType, logical_duration_microsecond, logical_float64, logical_int64, logical_interval_mdn, logical_interval_year_month, }; use datafusion_common::{Result, exec_err, internal_err}; @@ -41,7 +41,7 @@ use datafusion_expr::{ }; use datafusion_functions::utils::make_scalar_function; -use 
arrow::array::{Int32Array, Int32Builder}; +use arrow::array::{Int32Array, Int32Builder, Int64Array}; use arrow::datatypes::TimeUnit::Microsecond; use datafusion_expr::Coercion; use datafusion_expr::Volatility::Immutable; @@ -75,9 +75,9 @@ impl SparkWidthBucket { let interval_mdn = Coercion::new_exact(TypeSignatureClass::Native(logical_interval_mdn())); let bucket = Coercion::new_implicit( - TypeSignatureClass::Native(logical_int32()), + TypeSignatureClass::Native(logical_int64()), vec![TypeSignatureClass::Integer], - NativeType::Int32, + NativeType::Int64, ); let type_signature = Signature::one_of( vec![ @@ -160,28 +160,28 @@ fn width_bucket_kern(args: &[ArrayRef]) -> Result { let v = as_float64_array(v)?; let min = as_float64_array(minv)?; let max = as_float64_array(maxv)?; - let n_bucket = as_int32_array(nb)?; + let n_bucket = as_int64_array(nb)?; Ok(Arc::new(width_bucket_float64(v, min, max, n_bucket))) } Duration(Microsecond) => { let v = as_duration_microsecond_array(v)?; let min = as_duration_microsecond_array(minv)?; let max = as_duration_microsecond_array(maxv)?; - let n_bucket = as_int32_array(nb)?; + let n_bucket = as_int64_array(nb)?; Ok(Arc::new(width_bucket_i64_as_float(v, min, max, n_bucket))) } Interval(YearMonth) => { let v = as_interval_ym_array(v)?; let min = as_interval_ym_array(minv)?; let max = as_interval_ym_array(maxv)?; - let n_bucket = as_int32_array(nb)?; + let n_bucket = as_int64_array(nb)?; Ok(Arc::new(width_bucket_i32_as_float(v, min, max, n_bucket))) } Interval(MonthDayNano) => { let v = as_interval_mdn_array(v)?; let min = as_interval_mdn_array(minv)?; let max = as_interval_mdn_array(maxv)?; - let n_bucket = as_int32_array(nb)?; + let n_bucket = as_int64_array(nb)?; Ok(Arc::new(width_bucket_interval_mdn_exact( v, min, max, n_bucket, ))) @@ -203,7 +203,7 @@ macro_rules! 
width_bucket_kernel_impl { v: &$arr_ty, min: &$arr_ty, max: &$arr_ty, - n_bucket: &Int32Array, + n_bucket: &Int64Array, ) -> Int32Array { let len = v.len(); let mut b = Int32Builder::with_capacity(len); @@ -223,6 +223,7 @@ macro_rules! width_bucket_kernel_impl { b.append_null(); continue; } + let next_bucket = (buckets + 1) as i32; if $check_nan { if !x.is_finite() || !l.is_finite() || !h.is_finite() { b.append_null(); @@ -237,11 +238,11 @@ macro_rules! width_bucket_kernel_impl { continue; } }; - if matches!(ord, std::cmp::Ordering::Equal) { + if ord == std::cmp::Ordering::Equal { b.append_null(); continue; } - let asc = matches!(ord, std::cmp::Ordering::Less); + let asc = ord == std::cmp::Ordering::Less; if asc { if x < l { @@ -249,7 +250,7 @@ macro_rules! width_bucket_kernel_impl { continue; } if x >= h { - b.append_value(buckets + 1); + b.append_value(next_bucket); continue; } } else { @@ -258,7 +259,7 @@ macro_rules! width_bucket_kernel_impl { continue; } if x <= h { - b.append_value(buckets + 1); + b.append_value(next_bucket); continue; } } @@ -272,8 +273,8 @@ macro_rules! 
width_bucket_kernel_impl { if bucket < 1 { bucket = 1; } - if bucket > buckets + 1 { - bucket = buckets + 1; + if bucket > next_bucket { + bucket = next_bucket; } b.append_value(bucket); @@ -309,7 +310,7 @@ pub(crate) fn width_bucket_interval_mdn_exact( v: &IntervalMonthDayNanoArray, lo: &IntervalMonthDayNanoArray, hi: &IntervalMonthDayNanoArray, - n: &Int32Array, + n: &Int64Array, ) -> Int32Array { let len = v.len(); let mut b = Int32Builder::with_capacity(len); @@ -324,6 +325,7 @@ pub(crate) fn width_bucket_interval_mdn_exact( b.append_null(); continue; } + let next_bucket = (buckets + 1) as i32; let x = v.value(i); let l = lo.value(i); @@ -349,7 +351,7 @@ pub(crate) fn width_bucket_interval_mdn_exact( continue; } if x_m >= h_m { - b.append_value(buckets + 1); + b.append_value(next_bucket); continue; } } else { @@ -358,7 +360,7 @@ pub(crate) fn width_bucket_interval_mdn_exact( continue; } if x_m <= h_m { - b.append_value(buckets + 1); + b.append_value(next_bucket); continue; } } @@ -373,8 +375,8 @@ pub(crate) fn width_bucket_interval_mdn_exact( if bucket < 1 { bucket = 1; } - if bucket > buckets + 1 { - bucket = buckets + 1; + if bucket > next_bucket { + bucket = next_bucket; } b.append_value(bucket); continue; @@ -400,7 +402,7 @@ pub(crate) fn width_bucket_interval_mdn_exact( continue; } if x_f >= h_f { - b.append_value(buckets + 1); + b.append_value(next_bucket); continue; } } else { @@ -409,7 +411,7 @@ pub(crate) fn width_bucket_interval_mdn_exact( continue; } if x_f <= h_f { - b.append_value(buckets + 1); + b.append_value(next_bucket); continue; } } @@ -424,8 +426,8 @@ pub(crate) fn width_bucket_interval_mdn_exact( if bucket < 1 { bucket = 1; } - if bucket > buckets + 1 { - bucket = buckets + 1; + if bucket > next_bucket { + bucket = next_bucket; } b.append_value(bucket); continue; @@ -443,15 +445,15 @@ mod tests { use std::sync::Arc; use arrow::array::{ - ArrayRef, DurationMicrosecondArray, Float64Array, Int32Array, + ArrayRef, DurationMicrosecondArray, 
Float64Array, Int32Array, Int64Array, IntervalYearMonthArray, }; use arrow::datatypes::IntervalMonthDayNano; // --- Helpers ------------------------------------------------------------- - fn i32_array_all(len: usize, val: i32) -> Arc { - Arc::new(Int32Array::from(vec![val; len])) + fn i64_array_all(len: usize, val: i64) -> Arc { + Arc::new(Int64Array::from(vec![val; len])) } fn f64_array(vals: &[f64]) -> Arc { @@ -489,7 +491,7 @@ mod tests { let v = f64_array(&[0.5, 1.0, 9.9, -1.0, 10.0]); let lo = f64_array(&[0.0, 0.0, 0.0, 0.0, 0.0]); let hi = f64_array(&[10.0, 10.0, 10.0, 10.0, 10.0]); - let n = i32_array_all(5, 10); + let n = i64_array_all(5, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -501,7 +503,7 @@ mod tests { let v = f64_array(&[9.9, 10.0, 0.0, -0.1, 10.1]); let lo = f64_array(&[10.0; 5]); let hi = f64_array(&[0.0; 5]); - let n = i32_array_all(5, 10); + let n = i64_array_all(5, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -513,7 +515,7 @@ mod tests { let v = f64_array(&[0.0, 9.999999999, 10.0]); let lo = f64_array(&[0.0; 3]); let hi = f64_array(&[10.0; 3]); - let n = i32_array_all(3, 10); + let n = i64_array_all(3, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -525,7 +527,7 @@ mod tests { let v = f64_array(&[10.0, 0.0, -0.000001]); let lo = f64_array(&[10.0; 3]); let hi = f64_array(&[0.0; 3]); - let n = i32_array_all(3, 10); + let n = i64_array_all(3, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -537,7 +539,7 @@ mod tests { let v = f64_array(&[1.0, 5.0, 9.0]); let lo = f64_array(&[0.0, 0.0, 0.0]); let hi = f64_array(&[10.0, 10.0, 10.0]); - let n = Arc::new(Int32Array::from(vec![0, -1, 10])); + let n = Arc::new(Int64Array::from(vec![0, -1, 10])); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); assert!(out.is_null(0)); @@ -547,7 
+549,7 @@ mod tests { let v = f64_array(&[1.0]); let lo = f64_array(&[5.0]); let hi = f64_array(&[5.0]); - let n = i32_array_all(1, 10); + let n = i64_array_all(1, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); assert!(out.is_null(0)); @@ -555,7 +557,7 @@ mod tests { let v = f64_array_opt(&[Some(f64::NAN)]); let lo = f64_array(&[0.0]); let hi = f64_array(&[10.0]); - let n = i32_array_all(1, 10); + let n = i64_array_all(1, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); assert!(out.is_null(0)); @@ -566,7 +568,7 @@ mod tests { let v = f64_array_opt(&[None, Some(1.0), Some(2.0), Some(3.0)]); let lo = f64_array(&[0.0; 4]); let hi = f64_array(&[10.0; 4]); - let n = i32_array_all(4, 10); + let n = i64_array_all(4, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -578,7 +580,7 @@ mod tests { let v = f64_array(&[1.0]); let lo = f64_array_opt(&[None]); let hi = f64_array(&[10.0]); - let n = i32_array_all(1, 10); + let n = i64_array_all(1, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); assert!(out.is_null(0)); @@ -591,7 +593,7 @@ mod tests { let v = dur_us_array(&[1_000_000, 0, -1]); let lo = dur_us_array(&[0, 0, 0]); let hi = dur_us_array(&[2_000_000, 2_000_000, 2_000_000]); - let n = i32_array_all(3, 2); + let n = i64_array_all(3, 2); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -603,7 +605,7 @@ mod tests { let v = dur_us_array(&[0]); let lo = dur_us_array(&[1]); let hi = dur_us_array(&[1]); - let n = i32_array_all(1, 10); + let n = i64_array_all(1, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); assert!(downcast_i32(&out).is_null(0)); } @@ -615,7 +617,7 @@ mod tests { let v = ym_array(&[0, 5, 11, 12, 13]); let lo = ym_array(&[0; 5]); let hi = ym_array(&[12; 5]); - let n = i32_array_all(5, 12); + let n = i64_array_all(5, 12); let out = 
width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -627,7 +629,7 @@ mod tests { let v = ym_array(&[11, 12, 0, -1, 13]); let lo = ym_array(&[12; 5]); let hi = ym_array(&[0; 5]); - let n = i32_array_all(5, 12); + let n = i64_array_all(5, 12); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -641,7 +643,7 @@ mod tests { let v = mdn_array(&[(0, 0, 0), (5, 0, 0), (11, 0, 0), (12, 0, 0), (13, 0, 0)]); let lo = mdn_array(&[(0, 0, 0); 5]); let hi = mdn_array(&[(12, 0, 0); 5]); - let n = i32_array_all(5, 12); + let n = i64_array_all(5, 12); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -653,7 +655,7 @@ mod tests { let v = mdn_array(&[(11, 0, 0), (12, 0, 0), (0, 0, 0), (-1, 0, 0), (13, 0, 0)]); let lo = mdn_array(&[(12, 0, 0); 5]); let hi = mdn_array(&[(0, 0, 0); 5]); - let n = i32_array_all(5, 12); + let n = i64_array_all(5, 12); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -673,7 +675,7 @@ mod tests { ]); let lo = mdn_array(&[(0, 0, 0); 6]); let hi = mdn_array(&[(0, 10, 0); 6]); - let n = i32_array_all(6, 10); + let n = i64_array_all(6, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -686,7 +688,7 @@ mod tests { let v = mdn_array(&[(0, 9, 0), (0, 10, 0), (0, 0, 0), (0, -1, 0), (0, 11, 0)]); let lo = mdn_array(&[(0, 10, 0); 5]); let hi = mdn_array(&[(0, 0, 0); 5]); - let n = i32_array_all(5, 10); + let n = i64_array_all(5, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -698,7 +700,7 @@ mod tests { let v = mdn_array(&[(0, 9, 1), (0, 10, 0), (0, 0, 0), (0, -1, 0), (0, 11, 0)]); let lo = mdn_array(&[(0, 10, 0); 5]); let hi = mdn_array(&[(0, 0, 0); 5]); - let n = i32_array_all(5, 10); + let n = i64_array_all(5, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -711,7 +713,7 @@ mod tests { let v = 
mdn_array(&[(0, 1, 0)]); let lo = mdn_array(&[(0, 0, 0)]); let hi = mdn_array(&[(1, 1, 0)]); - let n = i32_array_all(1, 4); + let n = i64_array_all(1, 4); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -723,7 +725,7 @@ mod tests { let v = mdn_array(&[(0, 0, 0)]); let lo = mdn_array(&[(1, 2, 3)]); let hi = mdn_array(&[(1, 2, 3)]); // lo == hi - let n = i32_array_all(1, 10); + let n = i64_array_all(1, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); assert!(downcast_i32(&out).is_null(0)); @@ -734,7 +736,7 @@ mod tests { let v = mdn_array(&[(0, 0, 0)]); let lo = mdn_array(&[(0, 0, 0)]); let hi = mdn_array(&[(0, 10, 0)]); - let n = Arc::new(Int32Array::from(vec![0])); // n <= 0 + let n = Arc::new(Int64Array::from(vec![0])); // n <= 0 let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); assert!(downcast_i32(&out).is_null(0)); @@ -748,7 +750,7 @@ mod tests { ])); let lo = mdn_array(&[(0, 0, 0), (0, 0, 0)]); let hi = mdn_array(&[(0, 10, 0), (0, 10, 0)]); - let n = i32_array_all(2, 10); + let n = i64_array_all(2, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -773,7 +775,7 @@ mod tests { let v: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); let lo = f64_array(&[0.0, 0.0, 0.0]); let hi = f64_array(&[10.0, 10.0, 10.0]); - let n = i32_array_all(3, 10); + let n = i64_array_all(3, 10); let err = width_bucket_kern(&[v, lo, hi, n]).unwrap_err(); let msg = format!("{err}"); diff --git a/datafusion/spark/src/function/mod.rs b/datafusion/spark/src/function/mod.rs index 3f4f94cfaaf8..d5dd60c3545a 100644 --- a/datafusion/spark/src/function/mod.rs +++ b/datafusion/spark/src/function/mod.rs @@ -33,6 +33,7 @@ pub mod lambda; pub mod map; pub mod math; pub mod misc; +mod null_utils; pub mod predicate; pub mod string; pub mod r#struct; diff --git a/datafusion/spark/src/function/null_utils.rs b/datafusion/spark/src/function/null_utils.rs new file mode 100644 index 
000000000000..b25dc07d0e52 --- /dev/null +++ b/datafusion/spark/src/function/null_utils.rs @@ -0,0 +1,122 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::Array; +use arrow::buffer::NullBuffer; +use arrow::datatypes::DataType; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::ColumnarValue; +use std::sync::Arc; + +pub(crate) enum NullMaskResolution { + /// Return NULL as the result (e.g., scalar inputs with at least one NULL) + ReturnNull, + /// No null mask needed (e.g., all scalar inputs are non-NULL) + NoMask, + /// Null mask to apply for arrays + Apply(NullBuffer), +} + +/// Compute NULL mask for the arguments using NullBuffer::union +pub(crate) fn compute_null_mask( + args: &[ColumnarValue], + number_rows: usize, +) -> Result { + // Check if all arguments are scalars + let all_scalars = args + .iter() + .all(|arg| matches!(arg, ColumnarValue::Scalar(_))); + + if all_scalars { + // For scalars, check if any is NULL + for arg in args { + if let ColumnarValue::Scalar(scalar) = arg + && scalar.is_null() + { + return Ok(NullMaskResolution::ReturnNull); + } + } + // No NULLs in scalars + Ok(NullMaskResolution::NoMask) + } else { + // For arrays, compute NULL mask for each row 
using NullBuffer::union + let array_len = args + .iter() + .find_map(|arg| match arg { + ColumnarValue::Array(array) => Some(array.len()), + _ => None, + }) + .unwrap_or(number_rows); + + // Convert all scalars to arrays for uniform processing + let arrays: Result> = args + .iter() + .map(|arg| match arg { + ColumnarValue::Array(array) => Ok(Arc::clone(array)), + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(array_len), + }) + .collect(); + let arrays = arrays?; + + // Use NullBuffer::union to combine all null buffers + let combined_nulls = arrays + .iter() + .map(|arr| arr.nulls()) + .fold(None, |acc, nulls| NullBuffer::union(acc.as_ref(), nulls)); + + match combined_nulls { + Some(nulls) => Ok(NullMaskResolution::Apply(nulls)), + None => Ok(NullMaskResolution::NoMask), + } + } +} + +/// Apply NULL mask to the result using NullBuffer::union +pub(crate) fn apply_null_mask( + result: ColumnarValue, + null_mask: NullMaskResolution, + return_type: &DataType, +) -> Result { + match (result, null_mask) { + // Scalar with ReturnNull mask means return NULL of the correct type + (ColumnarValue::Scalar(_), NullMaskResolution::ReturnNull) => { + Ok(ColumnarValue::Scalar(ScalarValue::try_from(return_type)?)) + } + // Scalar without mask, return as-is + (scalar @ ColumnarValue::Scalar(_), NullMaskResolution::NoMask) => Ok(scalar), + // Array with NULL mask - use NullBuffer::union to combine nulls + (ColumnarValue::Array(array), NullMaskResolution::Apply(null_mask)) => { + // Combine the result's existing nulls with our computed null mask + let combined_nulls = NullBuffer::union(array.nulls(), Some(&null_mask)); + + // Create new array with combined nulls + let new_array = array + .into_data() + .into_builder() + .nulls(combined_nulls) + .build()?; + + Ok(ColumnarValue::Array(Arc::new(arrow::array::make_array( + new_array, + )))) + } + // Array without NULL mask, return as-is + (array @ ColumnarValue::Array(_), NullMaskResolution::NoMask) => Ok(array), + // Edge 
cases that shouldn't happen in practice + (scalar, _) => Ok(scalar), + } +} diff --git a/datafusion/spark/src/function/string/base64.rs b/datafusion/spark/src/function/string/base64.rs new file mode 100644 index 000000000000..a171d4823b0f --- /dev/null +++ b/datafusion/spark/src/function/string/base64.rs @@ -0,0 +1,183 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::datatypes::DataType; +use datafusion_common::arrow::datatypes::{Field, FieldRef}; +use datafusion_common::types::{NativeType, logical_string}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, exec_err, internal_err}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{Coercion, Expr, ReturnFieldArgs, TypeSignatureClass, lit}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_functions::expr_fn::{decode, encode}; + +/// Apache Spark base64 uses padded base64 encoding. 
+/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkBase64 { + signature: Signature, +} + +impl Default for SparkBase64 { + fn default() -> Self { + Self::new() + } +} + +impl SparkBase64 { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![Coercion::new_implicit( + TypeSignatureClass::Binary, + vec![TypeSignatureClass::Native(logical_string())], + NativeType::Binary, + )], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkBase64 { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "base64" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_type should not be called for {}", self.name()) + } + + fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result { + let [bin] = take_function_args(self.name(), args.arg_fields)?; + let return_type = match bin.data_type() { + DataType::LargeBinary => DataType::LargeUtf8, + _ => DataType::Utf8, + }; + Ok(Arc::new(Field::new( + self.name(), + return_type, + bin.is_nullable(), + ))) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + exec_err!( + "invoke should not be called on a simplified {} function", + self.name() + ) + } + + fn simplify( + &self, + args: Vec, + _info: &SimplifyContext, + ) -> Result { + let [bin] = take_function_args(self.name(), args)?; + Ok(ExprSimplifyResult::Simplified(encode( + bin, + lit("base64pad"), + ))) + } +} + +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkUnBase64 { + signature: Signature, +} + +impl Default for SparkUnBase64 { + fn default() -> Self { + Self::new() + } +} + +impl SparkUnBase64 { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![Coercion::new_implicit( + TypeSignatureClass::Binary, + vec![TypeSignatureClass::Native(logical_string())], + NativeType::Binary, + )], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl 
for SparkUnBase64 { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "unbase64" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_type should not be called for {}", self.name()) + } + + fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result { + let [str] = take_function_args(self.name(), args.arg_fields)?; + let return_type = match str.data_type() { + DataType::LargeBinary => DataType::LargeBinary, + _ => DataType::Binary, + }; + Ok(Arc::new(Field::new( + self.name(), + return_type, + str.is_nullable(), + ))) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + exec_err!("{} should have been simplified", self.name()) + } + + fn simplify( + &self, + args: Vec, + _info: &SimplifyContext, + ) -> Result { + let [bin] = take_function_args(self.name(), args)?; + Ok(ExprSimplifyResult::Simplified(decode( + bin, + lit("base64pad"), + ))) + } +} diff --git a/datafusion/spark/src/function/string/concat.rs b/datafusion/spark/src/function/string/concat.rs index 8e97e591fc35..b2073690fc44 100644 --- a/datafusion/spark/src/function/string/concat.rs +++ b/datafusion/spark/src/function/string/concat.rs @@ -15,20 +15,21 @@ // specific language governing permissions and limitations // under the License. 
-use arrow::array::Array; -use arrow::buffer::NullBuffer; use arrow::datatypes::{DataType, Field}; use datafusion_common::arrow::datatypes::FieldRef; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::ReturnFieldArgs; use datafusion_expr::{ - ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, - Volatility, + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; use datafusion_functions::string::concat::ConcatFunc; use std::any::Any; use std::sync::Arc; +use crate::function::null_utils::{ + NullMaskResolution, apply_null_mask, compute_null_mask, +}; + /// Spark-compatible `concat` expression /// /// @@ -52,10 +53,7 @@ impl Default for SparkConcat { impl SparkConcat { pub fn new() -> Self { Self { - signature: Signature::one_of( - vec![TypeSignature::UserDefined, TypeSignature::Nullary], - Volatility::Immutable, - ), + signature: Signature::user_defined(Volatility::Immutable), } } } @@ -87,21 +85,22 @@ impl ScalarUDFImpl for SparkConcat { ) } fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result { + use DataType::*; + // Spark semantics: concat returns NULL if ANY input is NULL let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); - Ok(Arc::new(Field::new("concat", DataType::Utf8, nullable))) - } -} + // Determine return type: Utf8View > LargeUtf8 > Utf8 + let mut dt = &Utf8; + for field in args.arg_fields { + let data_type = field.data_type(); + if data_type == &Utf8View || (data_type == &LargeUtf8 && dt != &Utf8View) { + dt = data_type; + } + } -/// Represents the null state for Spark concat -enum NullMaskResolution { - /// Return NULL as the result (e.g., scalar inputs with at least one NULL) - ReturnNull, - /// No null mask needed (e.g., all scalar inputs are non-NULL) - NoMask, - /// Null mask to apply for arrays - Apply(NullBuffer), + Ok(Arc::new(Field::new("concat", dt.clone(), nullable))) + } } /// Concatenates strings, returning NULL if any input is NULL @@ -118,9 
+117,18 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result { // Handle zero-argument case: return empty string if arg_values.is_empty() { - return Ok(ColumnarValue::Scalar(ScalarValue::Utf8( - Some(String::new()), - ))); + let return_type = return_field.data_type(); + return match return_type { + DataType::Utf8View => Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + String::new(), + )))), + DataType::LargeUtf8 => Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8( + Some(String::new()), + ))), + _ => Ok(ColumnarValue::Scalar(ScalarValue::Utf8( + Some(String::new()), + ))), + }; } // Step 1: Check for NULL mask in incoming args @@ -128,11 +136,19 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result { // If all scalars and any is NULL, return NULL immediately if matches!(null_mask, NullMaskResolution::ReturnNull) { - return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); + let return_type = return_field.data_type(); + return match return_type { + DataType::Utf8View => Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(None))), + DataType::LargeUtf8 => { + Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(None))) + } + _ => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))), + }; } // Step 2: Delegate to DataFusion's concat let concat_func = ConcatFunc::new(); + let return_type = return_field.data_type().clone(); let func_args = ScalarFunctionArgs { args: arg_values, arg_fields, @@ -143,103 +159,14 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result { let result = concat_func.invoke_with_args(func_args)?; // Step 3: Apply NULL mask to result - apply_null_mask(result, null_mask) -} - -/// Compute NULL mask for the arguments using NullBuffer::union -fn compute_null_mask( - args: &[ColumnarValue], - number_rows: usize, -) -> Result { - // Check if all arguments are scalars - let all_scalars = args - .iter() - .all(|arg| matches!(arg, ColumnarValue::Scalar(_))); - - if all_scalars { - // For scalars, check if any is NULL - for arg in args { - if let 
ColumnarValue::Scalar(scalar) = arg - && scalar.is_null() - { - return Ok(NullMaskResolution::ReturnNull); - } - } - // No NULLs in scalars - Ok(NullMaskResolution::NoMask) - } else { - // For arrays, compute NULL mask for each row using NullBuffer::union - let array_len = args - .iter() - .find_map(|arg| match arg { - ColumnarValue::Array(array) => Some(array.len()), - _ => None, - }) - .unwrap_or(number_rows); - - // Convert all scalars to arrays for uniform processing - let arrays: Result> = args - .iter() - .map(|arg| match arg { - ColumnarValue::Array(array) => Ok(Arc::clone(array)), - ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(array_len), - }) - .collect(); - let arrays = arrays?; - - // Use NullBuffer::union to combine all null buffers - let combined_nulls = arrays - .iter() - .map(|arr| arr.nulls()) - .fold(None, |acc, nulls| NullBuffer::union(acc.as_ref(), nulls)); - - match combined_nulls { - Some(nulls) => Ok(NullMaskResolution::Apply(nulls)), - None => Ok(NullMaskResolution::NoMask), - } - } -} - -/// Apply NULL mask to the result using NullBuffer::union -fn apply_null_mask( - result: ColumnarValue, - null_mask: NullMaskResolution, -) -> Result { - match (result, null_mask) { - // Scalar with ReturnNull mask means return NULL - (ColumnarValue::Scalar(_), NullMaskResolution::ReturnNull) => { - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))) - } - // Scalar without mask, return as-is - (scalar @ ColumnarValue::Scalar(_), NullMaskResolution::NoMask) => Ok(scalar), - // Array with NULL mask - use NullBuffer::union to combine nulls - (ColumnarValue::Array(array), NullMaskResolution::Apply(null_mask)) => { - // Combine the result's existing nulls with our computed null mask - let combined_nulls = NullBuffer::union(array.nulls(), Some(&null_mask)); - - // Create new array with combined nulls - let new_array = array - .into_data() - .into_builder() - .nulls(combined_nulls) - .build()?; - - 
Ok(ColumnarValue::Array(Arc::new(arrow::array::make_array( - new_array, - )))) - } - // Array without NULL mask, return as-is - (array @ ColumnarValue::Array(_), NullMaskResolution::NoMask) => Ok(array), - // Edge cases that shouldn't happen in practice - (scalar, _) => Ok(scalar), - } + apply_null_mask(result, null_mask, &return_type) } #[cfg(test)] mod tests { use super::*; use crate::function::utils::test::test_scalar_function; - use arrow::array::StringArray; + use arrow::array::{Array, StringArray}; use arrow::datatypes::{DataType, Field}; use datafusion_common::Result; use datafusion_expr::ReturnFieldArgs; @@ -277,6 +204,7 @@ mod tests { ); Ok(()) } + #[test] fn test_spark_concat_return_field_non_nullable() -> Result<()> { let func = SparkConcat::new(); diff --git a/datafusion/spark/src/function/string/format_string.rs b/datafusion/spark/src/function/string/format_string.rs index 73de985109b7..3adf50889594 100644 --- a/datafusion/spark/src/function/string/format_string.rs +++ b/datafusion/spark/src/function/string/format_string.rs @@ -598,7 +598,7 @@ impl ConversionType { pub fn validate(&self, arg_type: &DataType) -> Result<()> { match self { ConversionType::BooleanLower | ConversionType::BooleanUpper => { - if !matches!(arg_type, DataType::Boolean) { + if *arg_type != DataType::Boolean { return exec_err!( "Invalid argument type for boolean conversion: {:?}", arg_type @@ -1431,7 +1431,7 @@ impl ConversionSpecifier { let value = "null".to_string(); self.format_string(string, &value) } - _ => exec_err!("Invalid scalar value: {:?}", value), + _ => exec_err!("Invalid scalar value: {value}"), } } diff --git a/datafusion/spark/src/function/string/mod.rs b/datafusion/spark/src/function/string/mod.rs index 369d381a9c35..8859beca7799 100644 --- a/datafusion/spark/src/function/string/mod.rs +++ b/datafusion/spark/src/function/string/mod.rs @@ -16,6 +16,7 @@ // under the License. 
pub mod ascii; +pub mod base64; pub mod char; pub mod concat; pub mod elt; @@ -25,12 +26,14 @@ pub mod length; pub mod like; pub mod luhn_check; pub mod space; +pub mod substring; use datafusion_expr::ScalarUDF; use datafusion_functions::make_udf_function; use std::sync::Arc; make_udf_function!(ascii::SparkAscii, ascii); +make_udf_function!(base64::SparkBase64, base64); make_udf_function!(char::CharFunc, char); make_udf_function!(concat::SparkConcat, concat); make_udf_function!(ilike::SparkILike, ilike); @@ -40,6 +43,8 @@ make_udf_function!(like::SparkLike, like); make_udf_function!(luhn_check::SparkLuhnCheck, luhn_check); make_udf_function!(format_string::FormatStringFunc, format_string); make_udf_function!(space::SparkSpace, space); +make_udf_function!(substring::SparkSubstring, substring); +make_udf_function!(base64::SparkUnBase64, unbase64); pub mod expr_fn { use datafusion_functions::export_functions; @@ -49,6 +54,11 @@ pub mod expr_fn { "Returns the ASCII code point of the first character of string.", arg1 )); + export_functions!(( + base64, + "Encodes the input binary `bin` into a base64 string.", + bin + )); export_functions!(( char, "Returns the ASCII character having the binary equivalent to col. 
If col is larger than 256 the result is equivalent to char(col % 256).", @@ -90,11 +100,22 @@ pub mod expr_fn { strfmt args )); export_functions!((space, "Returns a string consisting of n spaces.", arg1)); + export_functions!(( + substring, + "Returns the substring from string `str` starting at position `pos` with length `length.", + str pos length + )); + export_functions!(( + unbase64, + "Decodes the input string `str` from a base64 string into binary data.", + str + )); } pub fn functions() -> Vec> { vec![ ascii(), + base64(), char(), concat(), elt(), @@ -104,5 +125,7 @@ pub fn functions() -> Vec> { luhn_check(), format_string(), space(), + substring(), + unbase64(), ] } diff --git a/datafusion/spark/src/function/string/substring.rs b/datafusion/spark/src/function/string/substring.rs new file mode 100644 index 000000000000..524262b12f19 --- /dev/null +++ b/datafusion/spark/src/function/string/substring.rs @@ -0,0 +1,258 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::{ + Array, ArrayBuilder, ArrayRef, AsArray, GenericStringBuilder, Int64Array, + OffsetSizeTrait, StringArrayType, StringViewBuilder, +}; +use arrow::datatypes::DataType; +use datafusion_common::arrow::datatypes::{Field, FieldRef}; +use datafusion_common::cast::as_int64_array; +use datafusion_common::types::{ + NativeType, logical_int32, logical_int64, logical_string, +}; +use datafusion_common::{Result, exec_err}; +use datafusion_expr::{Coercion, ReturnFieldArgs, TypeSignatureClass}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + Volatility, +}; +use datafusion_functions::unicode::substr::{enable_ascii_fast_path, get_true_start_end}; +use datafusion_functions::utils::make_scalar_function; +use std::any::Any; +use std::sync::Arc; + +/// Spark-compatible `substring` expression +/// +/// +/// Returns the substring from string starting at position pos with length len. +/// Position is 1-indexed. If pos is negative, it counts from the end of the string. +/// Returns NULL if any input is NULL. 
+#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkSubstring { + signature: Signature, + aliases: Vec, +} + +impl Default for SparkSubstring { + fn default() -> Self { + Self::new() + } +} + +impl SparkSubstring { + pub fn new() -> Self { + let string = Coercion::new_exact(TypeSignatureClass::Native(logical_string())); + let int64 = Coercion::new_implicit( + TypeSignatureClass::Native(logical_int64()), + vec![TypeSignatureClass::Native(logical_int32())], + NativeType::Int64, + ); + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Coercible(vec![string.clone(), int64.clone()]), + TypeSignature::Coercible(vec![ + string.clone(), + int64.clone(), + int64.clone(), + ]), + ], + Volatility::Immutable, + ) + .with_parameter_names(vec![ + "str".to_string(), + "pos".to_string(), + "length".to_string(), + ]) + .expect("valid parameter names"), + aliases: vec![String::from("substr")], + } + } +} + +impl ScalarUDFImpl for SparkSubstring { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "substring" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(spark_substring, vec![])(&args.args) + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + datafusion_common::internal_err!( + "return_type should not be called for Spark substring" + ) + } + + fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result { + // Spark semantics: substring returns NULL if ANY input is NULL + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + + Ok(Arc::new(Field::new( + "substring", + args.arg_fields[0].data_type().clone(), + nullable, + ))) + } +} + +fn spark_substring(args: &[ArrayRef]) -> Result { + let start_array = as_int64_array(&args[1])?; + let length_array = if args.len() > 2 { + Some(as_int64_array(&args[2])?) 
+ } else { + None + }; + + match args[0].data_type() { + DataType::Utf8 => spark_substring_impl( + &args[0].as_string::(), + start_array, + length_array, + GenericStringBuilder::::new(), + ), + DataType::LargeUtf8 => spark_substring_impl( + &args[0].as_string::(), + start_array, + length_array, + GenericStringBuilder::::new(), + ), + DataType::Utf8View => spark_substring_impl( + &args[0].as_string_view(), + start_array, + length_array, + StringViewBuilder::new(), + ), + other => exec_err!( + "Unsupported data type {other:?} for function spark_substring, expected Utf8View, Utf8 or LargeUtf8." + ), + } +} + +/// Convert Spark's start position to DataFusion's 1-based start position. +/// +/// Spark semantics: +/// - Positive start: 1-based index from beginning +/// - Zero start: treated as 1 +/// - Negative start: counts from end of string +/// +/// Returns the converted 1-based start position for use with `get_true_start_end`. +#[inline] +fn spark_start_to_datafusion_start(start: i64, len: usize) -> i64 { + if start >= 0 { + start.max(1) + } else { + let len_i64 = i64::try_from(len).unwrap_or(i64::MAX); + let start = start.saturating_add(len_i64).saturating_add(1); + start.max(1) + } +} + +trait StringArrayBuilder: ArrayBuilder { + fn append_value(&mut self, val: &str); + fn append_null(&mut self); +} + +impl StringArrayBuilder for GenericStringBuilder { + fn append_value(&mut self, val: &str) { + GenericStringBuilder::append_value(self, val); + } + fn append_null(&mut self) { + GenericStringBuilder::append_null(self); + } +} + +impl StringArrayBuilder for StringViewBuilder { + fn append_value(&mut self, val: &str) { + StringViewBuilder::append_value(self, val); + } + fn append_null(&mut self) { + StringViewBuilder::append_null(self); + } +} + +fn spark_substring_impl<'a, V, B>( + string_array: &V, + start_array: &Int64Array, + length_array: Option<&Int64Array>, + mut builder: B, +) -> Result +where + V: StringArrayType<'a>, + B: StringArrayBuilder, +{ + let is_ascii 
= enable_ascii_fast_path(string_array, start_array, length_array); + + for i in 0..string_array.len() { + if string_array.is_null(i) || start_array.is_null(i) { + builder.append_null(); + continue; + } + + if let Some(len_arr) = length_array + && len_arr.is_null(i) + { + builder.append_null(); + continue; + } + + let string = string_array.value(i); + let start = start_array.value(i); + let len_opt = length_array.map(|arr| arr.value(i)); + + // Spark: negative length returns empty string + if let Some(len) = len_opt + && len < 0 + { + builder.append_value(""); + continue; + } + + let string_len = if is_ascii { + string.len() + } else { + string.chars().count() + }; + + let adjusted_start = spark_start_to_datafusion_start(start, string_len); + + let (byte_start, byte_end) = get_true_start_end( + string, + adjusted_start, + len_opt.map(|l| l as u64), + is_ascii, + ); + let substr = &string[byte_start..byte_end]; + builder.append_value(substr); + } + + Ok(builder.finish()) +} diff --git a/datafusion/spark/src/function/url/parse_url.rs b/datafusion/spark/src/function/url/parse_url.rs index e82ef28045a3..7beb02f7750f 100644 --- a/datafusion/spark/src/function/url/parse_url.rs +++ b/datafusion/spark/src/function/url/parse_url.rs @@ -217,7 +217,12 @@ pub fn spark_handled_parse_url( handler_err, ) } - _ => exec_err!("{} expects STRING arguments, got {:?}", "`parse_url`", args), + _ => exec_err!( + "`parse_url` expects STRING arguments, got ({}, {}, {})", + url.data_type(), + part.data_type(), + key.data_type() + ), } } else { // The 'key' argument is omitted, assume all values are null @@ -253,7 +258,11 @@ pub fn spark_handled_parse_url( handler_err, ) } - _ => exec_err!("{} expects STRING arguments, got {:?}", "`parse_url`", args), + _ => exec_err!( + "`parse_url` expects STRING arguments, got ({}, {})", + url.data_type(), + part.data_type() + ), } } } diff --git a/datafusion/spark/src/lib.rs b/datafusion/spark/src/lib.rs index aad3ceed68ce..9575f560b8d0 100644 --- 
a/datafusion/spark/src/lib.rs +++ b/datafusion/spark/src/lib.rs @@ -22,7 +22,6 @@ #![cfg_attr(docsrs, feature(doc_cfg))] // Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] -#![deny(clippy::allow_attributes)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] //! Spark Expression packages for [DataFusion]. @@ -92,9 +91,49 @@ //! let expr = sha2(col("my_data"), lit(256)); //! ``` //! +//! # Example: using the Spark expression planner +//! +//! The [`planner::SparkFunctionPlanner`] provides Spark-compatible expression +//! planning, such as mapping SQL `EXTRACT` expressions to Spark's `date_part` +//! function. To use it, register it with your session context: +//! +//! ```ignore +//! use std::sync::Arc; +//! use datafusion::prelude::SessionContext; +//! use datafusion_spark::planner::SparkFunctionPlanner; +//! +//! let mut ctx = SessionContext::new(); +//! // Register the Spark expression planner +//! ctx.register_expr_planner(Arc::new(SparkFunctionPlanner))?; +//! // Now EXTRACT expressions will use Spark semantics +//! let df = ctx.sql("SELECT EXTRACT(YEAR FROM timestamp_col) FROM my_table").await?; +//! ``` +//! //![`Expr`]: datafusion_expr::Expr +//! +//! # Example: enabling Apache Spark features with SessionStateBuilder +//! +//! The recommended way to enable Apache Spark compatibility is to use the +//! `SessionStateBuilderSpark` extension trait. This registers all +//! Apache Spark functions (scalar, aggregate, window, and table) as well as the Apache Spark +//! expression planner. +//! +//! Enable the `core` feature in your `Cargo.toml`: +//! ```toml +//! datafusion-spark = { version = "X", features = ["core"] } +//! ``` +//! +//! Then use the extension trait - see [`SessionStateBuilderSpark::with_spark_features`] +//! for an example. 
pub mod function; +pub mod planner; + +#[cfg(feature = "core")] +mod session_state; + +#[cfg(feature = "core")] +pub use session_state::SessionStateBuilderSpark; use datafusion_catalog::TableFunction; use datafusion_common::Result; diff --git a/datafusion/spark/src/planner.rs b/datafusion/spark/src/planner.rs new file mode 100644 index 000000000000..2dafbb1f9a57 --- /dev/null +++ b/datafusion/spark/src/planner.rs @@ -0,0 +1,43 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use datafusion_expr::Expr; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::planner::{ExprPlanner, PlannerResult}; + +#[derive(Default, Debug)] +pub struct SparkFunctionPlanner; + +impl ExprPlanner for SparkFunctionPlanner { + fn plan_extract( + &self, + args: Vec, + ) -> datafusion_common::Result>> { + Ok(PlannerResult::Planned(Expr::ScalarFunction( + ScalarFunction::new_udf(crate::function::datetime::date_part(), args), + ))) + } + + fn plan_substring( + &self, + args: Vec, + ) -> datafusion_common::Result>> { + Ok(PlannerResult::Planned(Expr::ScalarFunction( + ScalarFunction::new_udf(crate::function::string::substring(), args), + ))) + } +} diff --git a/datafusion/spark/src/session_state.rs b/datafusion/spark/src/session_state.rs new file mode 100644 index 000000000000..e39de3a5888e --- /dev/null +++ b/datafusion/spark/src/session_state.rs @@ -0,0 +1,111 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::collections::HashMap; +use std::sync::Arc; + +use datafusion::execution::SessionStateBuilder; + +use crate::planner::SparkFunctionPlanner; +use crate::{ + all_default_aggregate_functions, all_default_scalar_functions, + all_default_table_functions, all_default_window_functions, +}; + +/// Extension trait for adding Apache Spark features to [`SessionStateBuilder`]. +/// +/// This trait provides a convenient way to register all Apache Spark-compatible +/// functions and planners with a DataFusion session. +/// +/// # Example +/// +/// ```rust +/// use datafusion::execution::SessionStateBuilder; +/// use datafusion_spark::SessionStateBuilderSpark; +/// +/// // Create a SessionState with Apache Spark features enabled +/// // note: the order matters here, `with_spark_features` should be +/// // called after `with_default_features` to overwrite any existing functions +/// let state = SessionStateBuilder::new() +/// .with_default_features() +/// .with_spark_features() +/// .build(); +/// ``` +pub trait SessionStateBuilderSpark { + /// Adds all expr_planners, scalar, aggregate, window and table functions + /// compatible with Apache Spark. + /// + /// Note: This overwrites any previously registered items with the same name. + fn with_spark_features(self) -> Self; +} + +impl SessionStateBuilderSpark for SessionStateBuilder { + fn with_spark_features(mut self) -> Self { + self.expr_planners() + .get_or_insert_with(Vec::new) + // planners are evaluated in order of insertion. 
Push Apache Spark function planner to the front + // to take precedence over others + .insert(0, Arc::new(SparkFunctionPlanner)); + + self.scalar_functions() + .get_or_insert_with(Vec::new) + .extend(all_default_scalar_functions()); + + self.aggregate_functions() + .get_or_insert_with(Vec::new) + .extend(all_default_aggregate_functions()); + + self.window_functions() + .get_or_insert_with(Vec::new) + .extend(all_default_window_functions()); + + self.table_functions() + .get_or_insert_with(HashMap::new) + .extend( + all_default_table_functions() + .into_iter() + .map(|f| (f.name().to_string(), f)), + ); + + self + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_session_state_with_spark_features() { + let state = SessionStateBuilder::new().with_spark_features().build(); + + assert!( + state.scalar_functions().contains_key("sha2"), + "Apache Spark scalar function 'sha2' should be registered" + ); + + assert!( + state.aggregate_functions().contains_key("try_sum"), + "Apache Spark aggregate function 'try_sum' should be registered" + ); + + assert!( + !state.expr_planners().is_empty(), + "Apache Spark expr planners should be registered" + ); + } +} diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index a814292a3d71..b7338cb764d7 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -56,6 +56,7 @@ bigdecimal = { workspace = true } chrono = { workspace = true } datafusion-common = { workspace = true, features = ["sql"] } datafusion-expr = { workspace = true, features = ["sql"] } +datafusion-functions-nested = { workspace = true, features = ["sql"] } indexmap = { workspace = true } log = { workspace = true } recursive = { workspace = true, optional = true } diff --git a/datafusion/sql/src/expr/binary_op.rs b/datafusion/sql/src/expr/binary_op.rs index edad5bbc6daa..4e9025e02e0c 100644 --- a/datafusion/sql/src/expr/binary_op.rs +++ b/datafusion/sql/src/expr/binary_op.rs @@ -22,7 +22,7 @@ use 
sqlparser::ast::BinaryOperator; impl SqlToRel<'_, S> { pub(crate) fn parse_sql_binary_op(&self, op: &BinaryOperator) -> Result { - match *op { + match op { BinaryOperator::Gt => Ok(Operator::Gt), BinaryOperator::GtEq => Ok(Operator::GtEq), BinaryOperator::Lt => Ok(Operator::Lt), @@ -68,6 +68,7 @@ impl SqlToRel<'_, S> { BinaryOperator::Question => Ok(Operator::Question), BinaryOperator::QuestionAnd => Ok(Operator::QuestionAnd), BinaryOperator::QuestionPipe => Ok(Operator::QuestionPipe), + BinaryOperator::Custom(s) if s == ":" => Ok(Operator::Colon), _ => not_impl_err!("Unsupported binary operator: {:?}", op), } } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 641f3bb8dcad..c81575366fb3 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -122,7 +122,7 @@ impl FunctionArgs { null_treatment: null_treatment.map(|v| v.into()), distinct: false, within_group, - function_without_parentheses: matches!(args, FunctionArguments::None), + function_without_parentheses: args == FunctionArguments::None, }); }; diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index 4c23c7a818be..cca09df0db02 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -37,7 +37,7 @@ impl SqlToRel<'_, S> { planner_context: &mut PlannerContext, ) -> Result { let id_span = id.span; - if id.value.starts_with('@') { + if id.value.starts_with('@') && id.quote_style.is_none() { // TODO: figure out if ScalarVariables should be insensitive. 
let var_names = vec![id.value]; let field = self @@ -76,15 +76,16 @@ impl SqlToRel<'_, S> { } // Check the outer query schema - if let Some(outer) = planner_context.outer_query_schema() - && let Ok((qualifier, field)) = + for outer in planner_context.outer_schemas_iter() { + if let Ok((qualifier, field)) = outer.qualified_field_with_unqualified_name(normalize_ident.as_str()) - { - // Found an exact match on a qualified name in the outer plan schema, so this is an outer reference column - return Ok(Expr::OuterReferenceColumn( - Arc::clone(field), - Column::from((qualifier, field)), - )); + { + // Found an exact match on a qualified name in the outer plan schema, so this is an outer reference column + return Ok(Expr::OuterReferenceColumn( + Arc::clone(field), + Column::from((qualifier, field)), + )); + } } // Default case @@ -111,7 +112,7 @@ impl SqlToRel<'_, S> { .filter_map(|id| Span::try_from_sqlparser_span(id.span)), ); - if ids[0].value.starts_with('@') { + if ids[0].value.starts_with('@') && ids[0].quote_style.is_none() { let var_names: Vec<_> = ids .into_iter() .map(|id| self.ident_normalizer.normalize(id)) @@ -172,14 +173,14 @@ impl SqlToRel<'_, S> { not_impl_err!("compound identifier: {ids:?}") } else { // Check the outer_query_schema and try to find a match - if let Some(outer) = planner_context.outer_query_schema() { + for outer in planner_context.outer_schemas_iter() { let search_result = search_dfschema(&ids, outer); - match search_result { + let result = match search_result { // Found matching field with spare identifier(s) for nested field(s) in structure Some((field, qualifier, nested_names)) if !nested_names.is_empty() => { - // TODO: remove when can support nested identifiers for OuterReferenceColumn + // TODO: remove this when we have support for nested identifiers for OuterReferenceColumn not_impl_err!( "Nested identifiers are not yet supported for OuterReferenceColumn {}", Column::from((qualifier, field)) @@ -195,26 +196,20 @@ impl SqlToRel<'_, 
S> { )) } // Found no matching field, will return a default - None => { - let s = &ids[0..ids.len()]; - // safe unwrap as s can never be empty or exceed the bounds - let (relation, column_name) = - form_identifier(s).unwrap(); - Ok(Expr::Column(Column::new(relation, column_name))) - } - } - } else { - let s = &ids[0..ids.len()]; - // Safe unwrap as s can never be empty or exceed the bounds - let (relation, column_name) = form_identifier(s).unwrap(); - let mut column = Column::new(relation, column_name); - if self.options.collect_spans - && let Some(span) = ids_span - { - column.spans_mut().add_span(span); - } - Ok(Expr::Column(column)) + None => continue, + }; + return result; + } + // Safe unwrap as column name can never be empty or exceed the bounds + let (relation, column_name) = + form_identifier(&ids[0..ids.len()]).unwrap(); + let mut column = Column::new(relation, column_name); + if self.options.collect_spans + && let Some(span) = ids_span + { + column.spans_mut().add_span(span); } + Ok(Expr::Column(column)) } } } diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index fcd7d6376d21..7902eed1e692 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -22,8 +22,8 @@ use datafusion_expr::planner::{ use sqlparser::ast::{ AccessExpr, BinaryOperator, CastFormat, CastKind, CeilFloorKind, DataType as SQLDataType, DateTimeField, DictionaryField, Expr as SQLExpr, - ExprWithAlias as SQLExprWithAlias, MapEntry, StructField, Subscript, TrimWhereField, - TypedString, Value, ValueWithSpan, + ExprWithAlias as SQLExprWithAlias, JsonPath, MapEntry, StructField, Subscript, + TrimWhereField, TypedString, Value, ValueWithSpan, }; use datafusion_common::{ @@ -32,6 +32,7 @@ use datafusion_common::{ }; use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::expr::SetQuantifier; use datafusion_expr::expr::{InList, WildcardOptions}; use datafusion_expr::{ Between, BinaryExpr, Cast, Expr, ExprSchemable, GetFieldAccess, 
Like, Literal, @@ -39,6 +40,7 @@ use datafusion_expr::{ }; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; +use datafusion_functions_nested::expr_fn::array_has; mod binary_op; mod function; @@ -265,11 +267,16 @@ impl SqlToRel<'_, S> { planner_context, ), + SQLExpr::Cast { array: true, .. } => { + not_impl_err!("`CAST(... AS type ARRAY`) not supported") + } + SQLExpr::Cast { kind: CastKind::Cast | CastKind::DoubleColon, expr, data_type, format, + array: false, } => { self.sql_cast_to_expr(*expr, &data_type, format, schema, planner_context) } @@ -279,6 +286,7 @@ impl SqlToRel<'_, S> { expr, data_type, format, + array: false, } => { if let Some(format) = format { return not_impl_err!("CAST with format is not supported: {format}"); @@ -594,32 +602,44 @@ impl SqlToRel<'_, S> { // ANY/SOME are equivalent, this field specifies which the user // specified but it doesn't affect the plan so ignore the field is_some: _, - } => { - let mut binary_expr = RawBinaryExpr { - op: compare_op, - left: self.sql_expr_to_logical_expr( - *left, - schema, - planner_context, - )?, - right: self.sql_expr_to_logical_expr( - *right, - schema, - planner_context, - )?, - }; - for planner in self.context_provider.get_expr_planners() { - match planner.plan_any(binary_expr)? 
{ - PlannerResult::Planned(expr) => { - return Ok(expr); - } - PlannerResult::Original(expr) => { - binary_expr = expr; - } + } => match *right { + SQLExpr::Subquery(subquery) => self.parse_set_comparison_subquery( + *left, + *subquery, + &compare_op, + SetQuantifier::Any, + schema, + planner_context, + ), + _ => { + if compare_op != BinaryOperator::Eq { + plan_err!( + "Unsupported AnyOp: '{compare_op}', only '=' is supported" + ) + } else { + let left_expr = + self.sql_to_expr(*left, schema, planner_context)?; + let right_expr = + self.sql_to_expr(*right, schema, planner_context)?; + Ok(array_has(right_expr, left_expr)) } } - not_impl_err!("AnyOp not supported by ExprPlanner: {binary_expr:?}") - } + }, + SQLExpr::AllOp { + left, + compare_op, + right, + } => match *right { + SQLExpr::Subquery(subquery) => self.parse_set_comparison_subquery( + *left, + *subquery, + &compare_op, + SetQuantifier::All, + schema, + planner_context, + ), + _ => not_impl_err!("ALL only supports subquery comparison currently"), + }, #[expect(deprecated)] SQLExpr::Wildcard(_token) => Ok(Expr::Wildcard { qualifier: None, @@ -631,10 +651,36 @@ impl SqlToRel<'_, S> { options: Box::new(WildcardOptions::default()), }), SQLExpr::Tuple(values) => self.parse_tuple(schema, planner_context, values), + SQLExpr::JsonAccess { value, path } => { + self.parse_json_access(schema, planner_context, value, &path) + } _ => not_impl_err!("Unsupported ast node in sqltorel: {sql:?}"), } } + fn parse_json_access( + &self, + schema: &DFSchema, + planner_context: &mut PlannerContext, + value: Box, + path: &JsonPath, + ) -> Result { + let json_path = path.to_string(); + let json_path = if let Some(json_path) = json_path.strip_prefix(":") { + // sqlparser's JsonPath display adds an extra `:` at the beginning. 
+ json_path.to_owned() + } else { + json_path + }; + self.build_logical_expr( + BinaryOperator::Custom(":".to_owned()), + self.sql_to_expr(*value, schema, planner_context)?, + // pass json path as a string literal, let the impl parse it when needed. + Expr::Literal(ScalarValue::Utf8(Some(json_path)), None), + schema, + ) + } + /// Parses a struct(..) expression and plans it creation fn parse_struct( &self, diff --git a/datafusion/sql/src/expr/subquery.rs b/datafusion/sql/src/expr/subquery.rs index ec34ff3d5342..662c44f6f262 100644 --- a/datafusion/sql/src/expr/subquery.rs +++ b/datafusion/sql/src/expr/subquery.rs @@ -17,10 +17,10 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{DFSchema, Diagnostic, Result, Span, Spans, plan_err}; -use datafusion_expr::expr::{Exists, InSubquery}; +use datafusion_expr::expr::{Exists, InSubquery, SetComparison, SetQuantifier}; use datafusion_expr::{Expr, LogicalPlan, Subquery}; use sqlparser::ast::Expr as SQLExpr; -use sqlparser::ast::{Query, SelectItem, SetExpr}; +use sqlparser::ast::{BinaryOperator, Query, SelectItem, SetExpr}; use std::sync::Arc; impl SqlToRel<'_, S> { @@ -31,11 +31,10 @@ impl SqlToRel<'_, S> { input_schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - let old_outer_query_schema = - planner_context.set_outer_query_schema(Some(input_schema.clone().into())); + planner_context.append_outer_query_schema(input_schema.clone().into()); let sub_plan = self.query_to_plan(subquery, planner_context)?; let outer_ref_columns = sub_plan.all_out_ref_exprs(); - planner_context.set_outer_query_schema(old_outer_query_schema); + planner_context.pop_outer_query_schema(); Ok(Expr::Exists(Exists { subquery: Subquery { subquery: Arc::new(sub_plan), @@ -54,8 +53,7 @@ impl SqlToRel<'_, S> { input_schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - let old_outer_query_schema = - planner_context.set_outer_query_schema(Some(input_schema.clone().into())); + 
planner_context.append_outer_query_schema(Arc::new(input_schema.clone())); let mut spans = Spans::new(); if let SetExpr::Select(select) = &subquery.body.as_ref() { @@ -70,7 +68,7 @@ impl SqlToRel<'_, S> { let sub_plan = self.query_to_plan(subquery, planner_context)?; let outer_ref_columns = sub_plan.all_out_ref_exprs(); - planner_context.set_outer_query_schema(old_outer_query_schema); + planner_context.pop_outer_query_schema(); self.validate_single_column( &sub_plan, @@ -98,8 +96,7 @@ impl SqlToRel<'_, S> { input_schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - let old_outer_query_schema = - planner_context.set_outer_query_schema(Some(input_schema.clone().into())); + planner_context.append_outer_query_schema(Arc::new(input_schema.clone())); let mut spans = Spans::new(); if let SetExpr::Select(select) = subquery.body.as_ref() { for item in &select.projection { @@ -112,7 +109,7 @@ impl SqlToRel<'_, S> { } let sub_plan = self.query_to_plan(subquery, planner_context)?; let outer_ref_columns = sub_plan.all_out_ref_exprs(); - planner_context.set_outer_query_schema(old_outer_query_schema); + planner_context.pop_outer_query_schema(); self.validate_single_column( &sub_plan, @@ -162,4 +159,50 @@ impl SqlToRel<'_, S> { diagnostic.add_help(help_message, None); diagnostic } + + pub(super) fn parse_set_comparison_subquery( + &self, + left_expr: SQLExpr, + subquery: Query, + compare_op: &BinaryOperator, + quantifier: SetQuantifier, + input_schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result { + planner_context.append_outer_query_schema(Arc::new(input_schema.clone())); + + let mut spans = Spans::new(); + if let SetExpr::Select(select) = subquery.body.as_ref() { + for item in &select.projection { + if let SelectItem::ExprWithAlias { alias, .. 
} = item + && let Some(span) = Span::try_from_sqlparser_span(alias.span) + { + spans.add_span(span); + } + } + } + + let sub_plan = self.query_to_plan(subquery, planner_context)?; + let outer_ref_columns = sub_plan.all_out_ref_exprs(); + planner_context.pop_outer_query_schema(); + + self.validate_single_column( + &sub_plan, + &spans, + "Too many columns! The subquery should only return one column", + "Select only one column in the subquery", + )?; + + let expr_obj = self.sql_to_expr(left_expr, input_schema, planner_context)?; + Ok(Expr::SetComparison(SetComparison::new( + Box::new(expr_obj), + Subquery { + subquery: Arc::new(sub_plan), + outer_ref_columns, + spans, + }, + self.parse_sql_binary_op(compare_op)?, + quantifier, + ))) + } } diff --git a/datafusion/sql/src/lib.rs b/datafusion/sql/src/lib.rs index b21eb52920ab..7fef670933f9 100644 --- a/datafusion/sql/src/lib.rs +++ b/datafusion/sql/src/lib.rs @@ -23,7 +23,6 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] -#![deny(clippy::allow_attributes)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] //! This crate provides: diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs index 27db2b0f9757..1ecf90b7947c 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -363,28 +363,49 @@ const DEFAULT_DIALECT: GenericDialect = GenericDialect {}; /// # Ok(()) /// # } /// ``` -pub struct DFParserBuilder<'a> { - /// The SQL string to parse - sql: &'a str, +pub struct DFParserBuilder<'a, 'b> { + /// Parser input: either raw SQL or tokens + input: ParserInput<'a>, /// The Dialect to use (defaults to [`GenericDialect`] - dialect: &'a dyn Dialect, + dialect: &'b dyn Dialect, /// The recursion limit while parsing recursion_limit: usize, } -impl<'a> DFParserBuilder<'a> { +/// Describes a possible input for parser +pub enum ParserInput<'a> { + /// Raw SQL. 
Tokenization will be performed automatically as a + /// part of [`DFParserBuilder::build`] + Sql(&'a str), + /// Tokens + Tokens(Vec), +} + +impl<'a> From<&'a str> for ParserInput<'a> { + fn from(sql: &'a str) -> Self { + Self::Sql(sql) + } +} + +impl From> for ParserInput<'static> { + fn from(tokens: Vec) -> Self { + Self::Tokens(tokens) + } +} + +impl<'a, 'b> DFParserBuilder<'a, 'b> { /// Create a new parser builder for the specified tokens using the /// [`GenericDialect`]. - pub fn new(sql: &'a str) -> Self { + pub fn new(input: impl Into>) -> Self { Self { - sql, + input: input.into(), dialect: &DEFAULT_DIALECT, recursion_limit: DEFAULT_RECURSION_LIMIT, } } /// Adjust the parser builder's dialect. Defaults to [`GenericDialect`] - pub fn with_dialect(mut self, dialect: &'a dyn Dialect) -> Self { + pub fn with_dialect(mut self, dialect: &'b dyn Dialect) -> Self { self.dialect = dialect; self } @@ -395,12 +416,18 @@ impl<'a> DFParserBuilder<'a> { self } - pub fn build(self) -> Result, DataFusionError> { - let mut tokenizer = Tokenizer::new(self.dialect, self.sql); - // Convert TokenizerError -> ParserError - let tokens = tokenizer - .tokenize_with_location() - .map_err(ParserError::from)?; + /// Build resulting parser + pub fn build(self) -> Result, DataFusionError> { + let tokens = match self.input { + ParserInput::Tokens(tokens) => tokens, + ParserInput::Sql(sql) => { + let mut tokenizer = Tokenizer::new(self.dialect, sql); + // Convert TokenizerError -> ParserError + tokenizer + .tokenize_with_location() + .map_err(ParserError::from)? 
+ } + }; Ok(DFParser { parser: Parser::new(self.dialect) @@ -658,7 +685,7 @@ impl<'a> DFParser<'a> { } } } else { - let token = self.parser.next_token(); + let token = self.parser.peek_token(); if token == Token::EOF || token == Token::SemiColon { break; } else { @@ -1079,7 +1106,7 @@ impl<'a> DFParser<'a> { } } } else { - let token = self.parser.next_token(); + let token = self.parser.peek_token(); if token == Token::EOF || token == Token::SemiColon { break; } else { @@ -1162,7 +1189,7 @@ mod tests { BinaryOperator, DataType, ExactNumberInfo, Expr, Ident, ValueWithSpan, }; use sqlparser::dialect::SnowflakeDialect; - use sqlparser::tokenizer::Span; + use sqlparser::tokenizer::{Location, Span, Whitespace}; fn expect_parse_ok(sql: &str, expected: Statement) -> Result<(), DataFusionError> { let statements = DFParser::parse_sql(sql)?; @@ -2026,6 +2053,78 @@ mod tests { ); } + #[test] + fn test_multistatement() { + let sql = "COPY foo TO bar STORED AS CSV; \ + CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv'; \ + RESET var;"; + let statements = DFParser::parse_sql(sql).unwrap(); + assert_eq!( + statements, + vec![ + Statement::CopyTo(CopyToStatement { + source: object_name("foo"), + target: "bar".to_string(), + partitioned_by: vec![], + stored_as: Some("CSV".to_owned()), + options: vec![], + }), + { + let name = ObjectName::from(vec![Ident::from("t")]); + let display = None; + Statement::CreateExternalTable(CreateExternalTable { + name: name.clone(), + columns: vec![make_column_def("c1", DataType::Int(display))], + file_type: "CSV".to_string(), + location: "foo.csv".into(), + table_partition_cols: vec![], + order_exprs: vec![], + if_not_exists: false, + or_replace: false, + temporary: false, + unbounded: false, + options: vec![], + constraints: vec![], + }) + }, + { + let name = ObjectName::from(vec![Ident::from("var")]); + Statement::Reset(ResetStatement::Variable(name)) + } + ] + ); + } + + #[test] + fn test_custom_tokens() { + // Span mock. 
+ let span = Span { + start: Location { line: 0, column: 0 }, + end: Location { line: 0, column: 0 }, + }; + let tokens = vec![ + TokenWithSpan { + token: Token::make_keyword("SELECT"), + span, + }, + TokenWithSpan { + token: Token::Whitespace(Whitespace::Space), + span, + }, + TokenWithSpan { + token: Token::Placeholder("1".to_string()), + span, + }, + ]; + + let statements = DFParserBuilder::new(tokens) + .build() + .unwrap() + .parse_statements() + .unwrap(); + assert_eq!(statements.len(), 1); + } + fn expect_parse_expr_ok(sql: &str, expected: ExprWithAlias) { let expr = DFParser::parse_sql_into_expr(sql).unwrap(); assert_eq!(expr, expected, "actual:\n{expr:#?}"); diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index eb798b71e455..307f28e8ff9a 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -261,8 +261,10 @@ pub struct PlannerContext { /// Map of CTE name to logical plan of the WITH clause. /// Use `Arc` to allow cheap cloning ctes: HashMap>, - /// The query schema of the outer query plan, used to resolve the columns in subquery - outer_query_schema: Option, + + /// The queries schemas of outer query relations, used to resolve the outer referenced + /// columns in subquery (recursive aware) + outer_queries_schemas_stack: Vec, /// The joined schemas of all FROM clauses planned so far. When planning LATERAL /// FROM clauses, this should become a suffix of the `outer_query_schema`. 
outer_from_schema: Option, @@ -282,7 +284,7 @@ impl PlannerContext { Self { prepare_param_data_types: Arc::new(vec![]), ctes: HashMap::new(), - outer_query_schema: None, + outer_queries_schemas_stack: vec![], outer_from_schema: None, create_table_schema: None, } @@ -297,19 +299,42 @@ impl PlannerContext { self } - // Return a reference to the outer query's schema - pub fn outer_query_schema(&self) -> Option<&DFSchema> { - self.outer_query_schema.as_ref().map(|s| s.as_ref()) + /// Return the stack of outer relations' schemas, the outer most + /// relation are at the first entry + pub fn outer_queries_schemas(&self) -> &[DFSchemaRef] { + &self.outer_queries_schemas_stack + } + + /// Return an iterator of the subquery relations' schemas, innermost + /// relation is returned first. + /// + /// This order corresponds to the order of resolution when looking up column + /// references in subqueries, which start from the innermost relation and + /// then look up the outer relations one by one until a match is found or no + /// more outer relation exist. + /// + /// NOTE this is *REVERSED* order of [`Self::outer_queries_schemas`] + /// + /// This is useful to resolve the column reference in the subquery by + /// looking up the outer query schemas one by one. 
+ pub fn outer_schemas_iter(&self) -> impl Iterator { + self.outer_queries_schemas_stack.iter().rev() } /// Sets the outer query schema, returning the existing one, if /// any - pub fn set_outer_query_schema( - &mut self, - mut schema: Option, - ) -> Option { - std::mem::swap(&mut self.outer_query_schema, &mut schema); - schema + pub fn append_outer_query_schema(&mut self, schema: DFSchemaRef) { + self.outer_queries_schemas_stack.push(schema); + } + + /// The schema of the adjacent outer relation + pub fn latest_outer_query_schema(&self) -> Option<&DFSchemaRef> { + self.outer_queries_schemas_stack.last() + } + + /// Remove the schema of the adjacent outer relation + pub fn pop_outer_query_schema(&mut self) -> Option { + self.outer_queries_schemas_stack.pop() } pub fn set_table_schema( @@ -688,8 +713,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLDataType::Timestamp(precision, tz_info) if precision.is_none() || [0, 3, 6, 9].contains(&precision.unwrap()) => { - let tz = if matches!(tz_info, TimezoneInfo::Tz) - || matches!(tz_info, TimezoneInfo::WithTimeZone) + let tz = if *tz_info == TimezoneInfo::Tz + || *tz_info == TimezoneInfo::WithTimeZone { // Timestamp With Time Zone // INPUT : [SQLDataType] TimestampTz + [Config] Time Zone @@ -710,8 +735,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } SQLDataType::Date => Ok(DataType::Date32), SQLDataType::Time(None, tz_info) => { - if matches!(tz_info, TimezoneInfo::None) - || matches!(tz_info, TimezoneInfo::WithoutTimeZone) + if *tz_info == TimezoneInfo::None + || *tz_info == TimezoneInfo::WithoutTimeZone { Ok(DataType::Time64(TimeUnit::Nanosecond)) } else { @@ -823,7 +848,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { | SQLDataType::HugeInt | SQLDataType::UHugeInt | SQLDataType::UBigInt - | SQLDataType::TimestampNtz + | SQLDataType::TimestampNtz{..} | SQLDataType::NamedTable { .. 
} | SQLDataType::TsVector | SQLDataType::TsQuery diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index eba48a2401c3..1b7bb856a592 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -170,6 +170,7 @@ impl SqlToRel<'_, S> { name: alias, // Apply to all fields columns: vec![], + explicit: true, }, ), PipeOperator::Union { diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index 3115d8dfffbd..6558763ca4e4 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -93,7 +93,7 @@ impl SqlToRel<'_, S> { match self.create_extension_relation(relation, planner_context)? { RelationPlanning::Planned(planned) => planned, RelationPlanning::Original(original) => { - self.create_default_relation(original, planner_context)? + Box::new(self.create_default_relation(*original, planner_context)?) } }; @@ -112,7 +112,7 @@ impl SqlToRel<'_, S> { ) -> Result { let planners = self.context_provider.get_relation_planners(); if planners.is_empty() { - return Ok(RelationPlanning::Original(relation)); + return Ok(RelationPlanning::Original(Box::new(relation))); } let mut current_relation = relation; @@ -127,12 +127,12 @@ impl SqlToRel<'_, S> { return Ok(RelationPlanning::Planned(planned)); } RelationPlanning::Original(original) => { - current_relation = original; + current_relation = *original; } } } - Ok(RelationPlanning::Original(current_relation)) + Ok(RelationPlanning::Original(Box::new(current_relation))) } fn create_default_relation( @@ -262,9 +262,10 @@ impl SqlToRel<'_, S> { } => { let tbl_func_ref = self.object_name_to_table_reference(name)?; let schema = planner_context - .outer_query_schema() + .outer_queries_schemas() + .last() .cloned() - .unwrap_or_else(DFSchema::empty); + .unwrap_or_else(|| Arc::new(DFSchema::empty())); let func_args = args .into_iter() .map(|arg| match arg { @@ -310,20 +311,24 @@ impl SqlToRel<'_, S> { let old_from_schema = planner_context 
.set_outer_from_schema(None) .unwrap_or_else(|| Arc::new(DFSchema::empty())); - let new_query_schema = match planner_context.outer_query_schema() { - Some(old_query_schema) => { + let outer_query_schema = planner_context.pop_outer_query_schema(); + let new_query_schema = match outer_query_schema { + Some(ref old_query_schema) => { let mut new_query_schema = old_from_schema.as_ref().clone(); - new_query_schema.merge(old_query_schema); - Some(Arc::new(new_query_schema)) + new_query_schema.merge(old_query_schema.as_ref()); + Arc::new(new_query_schema) } - None => Some(Arc::clone(&old_from_schema)), + None => Arc::clone(&old_from_schema), }; - let old_query_schema = planner_context.set_outer_query_schema(new_query_schema); + planner_context.append_outer_query_schema(new_query_schema); let plan = self.create_relation(subquery, planner_context)?; let outer_ref_columns = plan.all_out_ref_exprs(); - planner_context.set_outer_query_schema(old_query_schema); + planner_context.pop_outer_query_schema(); + if let Some(schema) = outer_query_schema { + planner_context.append_outer_query_schema(schema); + } planner_context.set_outer_from_schema(Some(old_from_schema)); // We can omit the subquery wrapper if there are no columns diff --git a/datafusion/sql/src/resolve.rs b/datafusion/sql/src/resolve.rs index 148e886161fc..955dbb86602a 100644 --- a/datafusion/sql/src/resolve.rs +++ b/datafusion/sql/src/resolve.rs @@ -15,10 +15,12 @@ // specific language governing permissions and limitations // under the License. 
-use crate::TableReference; use std::collections::BTreeSet; use std::ops::ControlFlow; +use datafusion_common::{DataFusionError, Result}; + +use crate::TableReference; use crate::parser::{CopyToSource, CopyToStatement, Statement as DFStatement}; use crate::planner::object_name_to_table_reference; use sqlparser::ast::*; @@ -45,27 +47,40 @@ const INFORMATION_SCHEMA_TABLES: &[&str] = &[ PARAMETERS, ]; +// Collect table/CTE references as `TableReference`s and normalize them during traversal. +// This avoids a second normalization/conversion pass after visiting the AST. struct RelationVisitor { - relations: BTreeSet, - all_ctes: BTreeSet, - ctes_in_scope: Vec, + relations: BTreeSet, + all_ctes: BTreeSet, + ctes_in_scope: Vec, + enable_ident_normalization: bool, } impl RelationVisitor { /// Record the reference to `relation`, if it's not a CTE reference. - fn insert_relation(&mut self, relation: &ObjectName) { - if !self.relations.contains(relation) && !self.ctes_in_scope.contains(relation) { - self.relations.insert(relation.clone()); + fn insert_relation(&mut self, relation: &ObjectName) -> ControlFlow { + match object_name_to_table_reference( + relation.clone(), + self.enable_ident_normalization, + ) { + Ok(relation) => { + if !self.relations.contains(&relation) + && !self.ctes_in_scope.contains(&relation) + { + self.relations.insert(relation); + } + ControlFlow::Continue(()) + } + Err(e) => ControlFlow::Break(e), } } } impl Visitor for RelationVisitor { - type Break = (); + type Break = DataFusionError; - fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<()> { - self.insert_relation(relation); - ControlFlow::Continue(()) + fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow { + self.insert_relation(relation) } fn pre_visit_query(&mut self, q: &Query) -> ControlFlow { @@ -78,10 +93,16 @@ impl Visitor for RelationVisitor { if !with.recursive { // This is a bit hackish as the CTE will be visited again as part of visiting `q`, // 
but thankfully `insert_relation` is idempotent. - let _ = cte.visit(self); + cte.visit(self)?; + } + let cte_name = ObjectName::from(vec![cte.alias.name.clone()]); + match object_name_to_table_reference( + cte_name, + self.enable_ident_normalization, + ) { + Ok(cte_ref) => self.ctes_in_scope.push(cte_ref), + Err(e) => return ControlFlow::Break(e), } - self.ctes_in_scope - .push(ObjectName::from(vec![cte.alias.name.clone()])); } } ControlFlow::Continue(()) @@ -97,13 +118,13 @@ impl Visitor for RelationVisitor { ControlFlow::Continue(()) } - fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<()> { + fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow { if let Statement::ShowCreate { obj_type: ShowCreateObject::Table | ShowCreateObject::View, obj_name, } = statement { - self.insert_relation(obj_name) + self.insert_relation(obj_name)?; } // SHOW statements will later be rewritten into a SELECT from the information_schema @@ -120,35 +141,53 @@ impl Visitor for RelationVisitor { ); if requires_information_schema { for s in INFORMATION_SCHEMA_TABLES { - self.relations.insert(ObjectName::from(vec![ + // Information schema references are synthesized here, so convert directly. 
+ let obj = ObjectName::from(vec![ Ident::new(INFORMATION_SCHEMA), Ident::new(*s), - ])); + ]); + match object_name_to_table_reference(obj, self.enable_ident_normalization) + { + Ok(tbl_ref) => { + self.relations.insert(tbl_ref); + } + Err(e) => return ControlFlow::Break(e), + } } } ControlFlow::Continue(()) } } -fn visit_statement(statement: &DFStatement, visitor: &mut RelationVisitor) { +fn control_flow_to_result(flow: ControlFlow) -> Result<()> { + match flow { + ControlFlow::Continue(()) => Ok(()), + ControlFlow::Break(err) => Err(err), + } +} + +fn visit_statement(statement: &DFStatement, visitor: &mut RelationVisitor) -> Result<()> { match statement { DFStatement::Statement(s) => { - let _ = s.as_ref().visit(visitor); + control_flow_to_result(s.as_ref().visit(visitor))?; } DFStatement::CreateExternalTable(table) => { - visitor.relations.insert(table.name.clone()); + control_flow_to_result(visitor.insert_relation(&table.name))?; } DFStatement::CopyTo(CopyToStatement { source, .. }) => match source { CopyToSource::Relation(table_name) => { - visitor.insert_relation(table_name); + control_flow_to_result(visitor.insert_relation(table_name))?; } CopyToSource::Query(query) => { - let _ = query.visit(visitor); + control_flow_to_result(query.visit(visitor))?; } }, - DFStatement::Explain(explain) => visit_statement(&explain.statement, visitor), + DFStatement::Explain(explain) => { + visit_statement(&explain.statement, visitor)?; + } DFStatement::Reset(_) => {} } + Ok(()) } /// Collects all tables and views referenced in the SQL statement. CTEs are collected separately. 
@@ -188,26 +227,20 @@ fn visit_statement(statement: &DFStatement, visitor: &mut RelationVisitor) { pub fn resolve_table_references( statement: &crate::parser::Statement, enable_ident_normalization: bool, -) -> datafusion_common::Result<(Vec, Vec)> { +) -> Result<(Vec, Vec)> { let mut visitor = RelationVisitor { relations: BTreeSet::new(), all_ctes: BTreeSet::new(), ctes_in_scope: vec![], + enable_ident_normalization, }; - visit_statement(statement, &mut visitor); - - let table_refs = visitor - .relations - .into_iter() - .map(|x| object_name_to_table_reference(x, enable_ident_normalization)) - .collect::>()?; - let ctes = visitor - .all_ctes - .into_iter() - .map(|x| object_name_to_table_reference(x, enable_ident_normalization)) - .collect::>()?; - Ok((table_refs, ctes)) + visit_statement(statement, &mut visitor)?; + + Ok(( + visitor.relations.into_iter().collect(), + visitor.all_ctes.into_iter().collect(), + )) } #[cfg(test)] @@ -270,4 +303,57 @@ mod tests { assert_eq!(ctes.len(), 1); assert_eq!(ctes[0].to_string(), "nodes"); } + + #[test] + fn resolve_table_references_cte_with_quoted_reference() { + use crate::parser::DFParser; + + let query = r#"with barbaz as (select 1) select * from "barbaz""#; + let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap(); + let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap(); + assert_eq!(ctes.len(), 1); + assert_eq!(ctes[0].to_string(), "barbaz"); + // Quoted reference should still resolve to the CTE when normalization is on + assert_eq!(table_refs.len(), 0); + } + + #[test] + fn resolve_table_references_cte_with_quoted_reference_normalization_off() { + use crate::parser::DFParser; + + let query = r#"with barbaz as (select 1) select * from "barbaz""#; + let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap(); + let (table_refs, ctes) = resolve_table_references(&statement, false).unwrap(); + assert_eq!(ctes.len(), 1); + assert_eq!(ctes[0].to_string(), "barbaz"); + // Even 
with normalization off, quoted reference matches same-case CTE name + assert_eq!(table_refs.len(), 0); + } + + #[test] + fn resolve_table_references_cte_with_quoted_reference_uppercase_normalization_on() { + use crate::parser::DFParser; + + let query = r#"with FOObar as (select 1) select * from "FOObar""#; + let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap(); + let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap(); + // CTE name is normalized to lowercase, quoted reference preserves case, so they differ + assert_eq!(ctes.len(), 1); + assert_eq!(ctes[0].to_string(), "foobar"); + assert_eq!(table_refs.len(), 1); + assert_eq!(table_refs[0].to_string(), "FOObar"); + } + + #[test] + fn resolve_table_references_cte_with_quoted_reference_uppercase_normalization_off() { + use crate::parser::DFParser; + + let query = r#"with FOObar as (select 1) select * from "FOObar""#; + let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap(); + let (table_refs, ctes) = resolve_table_references(&statement, false).unwrap(); + // Without normalization, cases match exactly, so quoted reference resolves to the CTE + assert_eq!(ctes.len(), 1); + assert_eq!(ctes[0].to_string(), "FOObar"); + assert_eq!(table_refs.len(), 0); + } } diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 1d6ccde6be13..edf4b9ef79e8 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -29,7 +29,7 @@ use crate::utils::{ use datafusion_common::error::DataFusionErrorBuilder; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; -use datafusion_common::{Column, Result, not_impl_err, plan_err}; +use datafusion_common::{Column, DFSchema, Result, not_impl_err, plan_err}; use datafusion_common::{RecursionUnnestOption, UnnestOptions}; use datafusion_expr::expr::{Alias, PlannedReplaceSelectItem, WildcardOptions}; use datafusion_expr::expr_rewriter::{ @@ -361,6 +361,7 @@ impl SqlToRel<'_, S> { // Process 
distinct clause let plan = match select.distinct { None => Ok(plan), + Some(Distinct::All) => Ok(plan), Some(Distinct::Distinct) => { LogicalPlanBuilder::from(plan).distinct()?.build() } @@ -637,11 +638,6 @@ impl SqlToRel<'_, S> { match selection { Some(predicate_expr) => { let fallback_schemas = plan.fallback_normalize_schemas(); - let outer_query_schema = planner_context.outer_query_schema().cloned(); - let outer_query_schema_vec = outer_query_schema - .as_ref() - .map(|schema| vec![schema]) - .unwrap_or_else(Vec::new); let filter_expr = self.sql_to_expr(predicate_expr, plan.schema(), planner_context)?; @@ -657,9 +653,19 @@ impl SqlToRel<'_, S> { let mut using_columns = HashSet::new(); expr_to_columns(&filter_expr, &mut using_columns)?; + let mut schema_stack: Vec> = + vec![vec![plan.schema()], fallback_schemas]; + for sc in planner_context.outer_schemas_iter() { + schema_stack.push(vec![sc.as_ref()]); + } + let filter_expr = normalize_col_with_schemas_and_ambiguity_check( filter_expr, - &[&[plan.schema()], &fallback_schemas, &outer_query_schema_vec], + schema_stack + .iter() + .map(|sc| sc.as_slice()) + .collect::>() + .as_slice(), &[using_columns], )?; diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 1acbcc92dfe1..b91e38e53776 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -55,9 +55,10 @@ use datafusion_expr::{ TransactionIsolationLevel, TransactionStart, Volatility, WriteOp, cast, col, }; use sqlparser::ast::{ - self, BeginTransactionKind, IndexColumn, IndexType, NullsDistinctOption, OrderByExpr, - OrderByOptions, Set, ShowStatementIn, ShowStatementOptions, SqliteOnConflict, - TableObject, UpdateTableFromKind, ValueWithSpan, + self, BeginTransactionKind, CheckConstraint, ForeignKeyConstraint, IndexColumn, + IndexType, NullsDistinctOption, OrderByExpr, OrderByOptions, PrimaryKeyConstraint, + Set, ShowStatementIn, ShowStatementOptions, SqliteOnConflict, TableObject, + UniqueConstraint, 
Update, UpdateTableFromKind, ValueWithSpan, }; use sqlparser::ast::{ Assignment, AssignmentTarget, ColumnDef, CreateIndex, CreateTable, @@ -102,38 +103,24 @@ fn get_schema_name(schema_name: &SchemaName) -> String { /// Construct `TableConstraint`(s) for the given columns by iterating over /// `columns` and extracting individual inline constraint definitions. fn calc_inline_constraints_from_columns(columns: &[ColumnDef]) -> Vec { - let mut constraints = vec![]; + let mut constraints: Vec = vec![]; for column in columns { for ast::ColumnOptionDef { name, option } in &column.options { match option { - ast::ColumnOption::Unique { - is_primary: false, + ast::ColumnOption::Unique(UniqueConstraint { characteristics, - } => constraints.push(TableConstraint::Unique { + name, + index_name: _index_name, + index_type_display: _index_type_display, + index_type: _index_type, + columns: _column, + index_options: _index_options, + nulls_distinct: _nulls_distinct, + }) => constraints.push(TableConstraint::Unique(UniqueConstraint { name: name.clone(), - columns: vec![IndexColumn { - column: OrderByExpr { - expr: SQLExpr::Identifier(column.name.clone()), - options: OrderByOptions { - asc: None, - nulls_first: None, - }, - with_fill: None, - }, - operator_class: None, - }], - characteristics: *characteristics, index_name: None, index_type_display: ast::KeyOrIndexDisplay::None, index_type: None, - index_options: vec![], - nulls_distinct: NullsDistinctOption::None, - }), - ast::ColumnOption::Unique { - is_primary: true, - characteristics, - } => constraints.push(TableConstraint::PrimaryKey { - name: name.clone(), columns: vec![IndexColumn { column: OrderByExpr { expr: SQLExpr::Identifier(column.name.clone()), @@ -145,35 +132,69 @@ fn calc_inline_constraints_from_columns(columns: &[ColumnDef]) -> Vec { + constraints.push(TableConstraint::PrimaryKey(PrimaryKeyConstraint { + name: name.clone(), + index_name: None, + index_type: None, + columns: vec![IndexColumn { + column: OrderByExpr { + 
expr: SQLExpr::Identifier(column.name.clone()), + options: OrderByOptions { + asc: None, + nulls_first: None, + }, + with_fill: None, + }, + operator_class: None, + }], + index_options: vec![], + characteristics: *characteristics, + })) + } + ast::ColumnOption::ForeignKey(ForeignKeyConstraint { foreign_table, referred_columns, on_delete, on_update, characteristics, - } => constraints.push(TableConstraint::ForeignKey { - name: name.clone(), - columns: vec![], - foreign_table: foreign_table.clone(), - referred_columns: referred_columns.to_vec(), - on_delete: *on_delete, - on_update: *on_update, - characteristics: *characteristics, - index_name: None, - }), - ast::ColumnOption::Check(expr) => { - constraints.push(TableConstraint::Check { + name: _name, + index_name: _index_name, + columns: _columns, + match_kind: _match_kind, + }) => { + constraints.push(TableConstraint::ForeignKey(ForeignKeyConstraint { name: name.clone(), - expr: Box::new(expr.clone()), - enforced: None, - }) - } - // Other options are not constraint related. 
+ index_name: None, + columns: vec![], + foreign_table: foreign_table.clone(), + referred_columns: referred_columns.clone(), + on_delete: *on_delete, + on_update: *on_update, + match_kind: None, + characteristics: *characteristics, + })) + } + ast::ColumnOption::Check(CheckConstraint { + name, + expr, + enforced: _enforced, + }) => constraints.push(TableConstraint::Check(CheckConstraint { + name: name.clone(), + expr: expr.clone(), + enforced: None, + })), ast::ColumnOption::Default(_) | ast::ColumnOption::Null | ast::ColumnOption::NotNull @@ -191,7 +212,8 @@ fn calc_inline_constraints_from_columns(columns: &[ColumnDef]) -> Vec {} + | ast::ColumnOption::Collation(_) + | ast::ColumnOption::Invisible => {} } } } @@ -320,152 +342,160 @@ impl SqlToRel<'_, S> { refresh_mode, initialize, require_user, + partition_of, + for_values, }) => { if temporary { - return not_impl_err!("Temporary tables not supported")?; + return not_impl_err!("Temporary tables not supported"); } if external { - return not_impl_err!("External tables not supported")?; + return not_impl_err!("External tables not supported"); } if global.is_some() { - return not_impl_err!("Global tables not supported")?; + return not_impl_err!("Global tables not supported"); } if transient { - return not_impl_err!("Transient tables not supported")?; + return not_impl_err!("Transient tables not supported"); } if volatile { - return not_impl_err!("Volatile tables not supported")?; + return not_impl_err!("Volatile tables not supported"); } if hive_distribution != ast::HiveDistributionStyle::NONE { return not_impl_err!( "Hive distribution not supported: {hive_distribution:?}" - )?; + ); } - if !matches!( - hive_formats, - Some(ast::HiveFormat { - row_format: None, - serde_properties: None, - storage: None, - location: None, - }) - ) { - return not_impl_err!( - "Hive formats not supported: {hive_formats:?}" - )?; + if hive_formats.is_some() + && !matches!( + hive_formats, + Some(ast::HiveFormat { + row_format: None, + 
serde_properties: None, + storage: None, + location: None, + }) + ) + { + return not_impl_err!("Hive formats not supported: {hive_formats:?}"); } if file_format.is_some() { - return not_impl_err!("File format not supported")?; + return not_impl_err!("File format not supported"); } if location.is_some() { - return not_impl_err!("Location not supported")?; + return not_impl_err!("Location not supported"); } if without_rowid { - return not_impl_err!("Without rowid not supported")?; + return not_impl_err!("Without rowid not supported"); } if like.is_some() { - return not_impl_err!("Like not supported")?; + return not_impl_err!("Like not supported"); } if clone.is_some() { - return not_impl_err!("Clone not supported")?; + return not_impl_err!("Clone not supported"); } if comment.is_some() { - return not_impl_err!("Comment not supported")?; + return not_impl_err!("Comment not supported"); } if on_commit.is_some() { - return not_impl_err!("On commit not supported")?; + return not_impl_err!("On commit not supported"); } if on_cluster.is_some() { - return not_impl_err!("On cluster not supported")?; + return not_impl_err!("On cluster not supported"); } if primary_key.is_some() { - return not_impl_err!("Primary key not supported")?; + return not_impl_err!("Primary key not supported"); } if order_by.is_some() { - return not_impl_err!("Order by not supported")?; + return not_impl_err!("Order by not supported"); } if partition_by.is_some() { - return not_impl_err!("Partition by not supported")?; + return not_impl_err!("Partition by not supported"); } if cluster_by.is_some() { - return not_impl_err!("Cluster by not supported")?; + return not_impl_err!("Cluster by not supported"); } if clustered_by.is_some() { - return not_impl_err!("Clustered by not supported")?; + return not_impl_err!("Clustered by not supported"); } if strict { - return not_impl_err!("Strict not supported")?; + return not_impl_err!("Strict not supported"); } if copy_grants { - return not_impl_err!("Copy grants 
not supported")?; + return not_impl_err!("Copy grants not supported"); } if enable_schema_evolution.is_some() { - return not_impl_err!("Enable schema evolution not supported")?; + return not_impl_err!("Enable schema evolution not supported"); } if change_tracking.is_some() { - return not_impl_err!("Change tracking not supported")?; + return not_impl_err!("Change tracking not supported"); } if data_retention_time_in_days.is_some() { - return not_impl_err!("Data retention time in days not supported")?; + return not_impl_err!("Data retention time in days not supported"); } if max_data_extension_time_in_days.is_some() { return not_impl_err!( "Max data extension time in days not supported" - )?; + ); } if default_ddl_collation.is_some() { - return not_impl_err!("Default DDL collation not supported")?; + return not_impl_err!("Default DDL collation not supported"); } if with_aggregation_policy.is_some() { - return not_impl_err!("With aggregation policy not supported")?; + return not_impl_err!("With aggregation policy not supported"); } if with_row_access_policy.is_some() { - return not_impl_err!("With row access policy not supported")?; + return not_impl_err!("With row access policy not supported"); } if with_tags.is_some() { - return not_impl_err!("With tags not supported")?; + return not_impl_err!("With tags not supported"); } if iceberg { - return not_impl_err!("Iceberg not supported")?; + return not_impl_err!("Iceberg not supported"); } if external_volume.is_some() { - return not_impl_err!("External volume not supported")?; + return not_impl_err!("External volume not supported"); } if base_location.is_some() { - return not_impl_err!("Base location not supported")?; + return not_impl_err!("Base location not supported"); } if catalog.is_some() { - return not_impl_err!("Catalog not supported")?; + return not_impl_err!("Catalog not supported"); } if catalog_sync.is_some() { - return not_impl_err!("Catalog sync not supported")?; + return not_impl_err!("Catalog sync not 
supported"); } if storage_serialization_policy.is_some() { - return not_impl_err!("Storage serialization policy not supported")?; + return not_impl_err!("Storage serialization policy not supported"); } if inherits.is_some() { - return not_impl_err!("Table inheritance not supported")?; + return not_impl_err!("Table inheritance not supported"); } if dynamic { - return not_impl_err!("Dynamic tables not supported")?; + return not_impl_err!("Dynamic tables not supported"); } if version.is_some() { - return not_impl_err!("Version not supported")?; + return not_impl_err!("Version not supported"); } if target_lag.is_some() { - return not_impl_err!("Target lag not supported")?; + return not_impl_err!("Target lag not supported"); } if warehouse.is_some() { - return not_impl_err!("Warehouse not supported")?; + return not_impl_err!("Warehouse not supported"); } if refresh_mode.is_some() { - return not_impl_err!("Refresh mode not supported")?; + return not_impl_err!("Refresh mode not supported"); } if initialize.is_some() { - return not_impl_err!("Initialize not supported")?; + return not_impl_err!("Initialize not supported"); } if require_user { - return not_impl_err!("Require user not supported")?; + return not_impl_err!("Require user not supported"); + } + if partition_of.is_some() { + return not_impl_err!("PARTITION OF not supported"); + } + if for_values.is_some() { + return not_impl_err!("PARTITION OF .. FOR VALUES .. 
not supported"); } // Merge inline constraints and existing constraints let mut all_constraints = constraints; @@ -557,7 +587,7 @@ impl SqlToRel<'_, S> { } } } - Statement::CreateView { + Statement::CreateView(ast::CreateView { or_replace, materialized, name, @@ -574,7 +604,7 @@ impl SqlToRel<'_, S> { or_alter, secure, name_before_not_exists, - } => { + }) => { if materialized { return not_impl_err!("Materialized views not supported")?; } @@ -596,7 +626,7 @@ impl SqlToRel<'_, S> { // put the statement back together temporarily to get the SQL // string representation - let stmt = Statement::CreateView { + let stmt = Statement::CreateView(ast::CreateView { or_replace, materialized, name, @@ -613,16 +643,16 @@ impl SqlToRel<'_, S> { or_alter, secure, name_before_not_exists, - }; + }); let sql = stmt.to_string(); - let Statement::CreateView { + let Statement::CreateView(ast::CreateView { name, columns, query, or_replace, temporary, .. - } = stmt + }) = stmt else { return internal_err!("Unreachable code in create view"); }; @@ -965,6 +995,8 @@ impl SqlToRel<'_, S> { has_table_keyword, settings, format_clause, + insert_token: _, // record the location the `INSERT` token + optimizer_hint, }) => { let table_name = match table { TableObject::TableName(table_name) => table_name, @@ -1020,12 +1052,15 @@ impl SqlToRel<'_, S> { if format_clause.is_some() { plan_err!("Inserts with format clause not supported")?; } + if optimizer_hint.is_some() { + plan_err!("Optimizer hints not supported")?; + } // optional keywords don't change behavior let _ = into; let _ = has_table_keyword; self.insert_to_plan(table_name, columns, source, overwrite, replace_into) } - Statement::Update { + Statement::Update(Update { table, assignments, from, @@ -1033,7 +1068,9 @@ impl SqlToRel<'_, S> { returning, or, limit, - } => { + update_token: _, + optimizer_hint, + }) => { let from_clauses = from.map(|update_table_from_kind| match update_table_from_kind { UpdateTableFromKind::BeforeSet(from_clauses) => 
from_clauses, @@ -1041,9 +1078,18 @@ impl SqlToRel<'_, S> { }); // TODO: support multiple tables in UPDATE SET FROM if from_clauses.as_ref().is_some_and(|f| f.len() > 1) { - plan_err!("Multiple tables in UPDATE SET FROM not yet supported")?; + not_impl_err!( + "Multiple tables in UPDATE SET FROM not yet supported" + )?; } let update_from = from_clauses.and_then(|mut f| f.pop()); + + // UPDATE ... FROM is currently not working + // TODO fix https://github.com/apache/datafusion/issues/19950 + if update_from.is_some() { + return not_impl_err!("UPDATE ... FROM is not supported"); + } + if returning.is_some() { plan_err!("Update-returning clause not yet supported")?; } @@ -1053,6 +1099,9 @@ impl SqlToRel<'_, S> { if limit.is_some() { return not_impl_err!("Update-limit clause not supported")?; } + if optimizer_hint.is_some() { + plan_err!("Optimizer hints not supported")?; + } self.update_to_plan(table, &assignments, update_from, selection) } @@ -1064,6 +1113,8 @@ impl SqlToRel<'_, S> { from, order_by, limit, + delete_token: _, + optimizer_hint, }) => { if !tables.is_empty() { plan_err!("DELETE not supported")?; @@ -1081,12 +1132,12 @@ impl SqlToRel<'_, S> { plan_err!("Delete-order-by clause not yet supported")?; } - if limit.is_some() { - plan_err!("Delete-limit clause not yet supported")?; + if optimizer_hint.is_some() { + plan_err!("Optimizer hints not supported")?; } let table_name = self.get_delete_target(from)?; - self.delete_to_plan(&table_name, selection) + self.delete_to_plan(&table_name, selection, limit) } Statement::StartTransaction { @@ -1295,7 +1346,8 @@ impl SqlToRel<'_, S> { let function_body = match function_body { Some(r) => Some(self.sql_to_expr( match r { - ast::CreateFunctionBody::AsBeforeOptions(expr) => expr, + // `link_symbol` indicates if the primary expression contains the name of shared library file. 
+ ast::CreateFunctionBody::AsBeforeOptions{body: expr, link_symbol: _link_symbol} => expr, ast::CreateFunctionBody::AsAfterOptions(expr) => expr, ast::CreateFunctionBody::Return(expr) => expr, ast::CreateFunctionBody::AsBeginEnd(_) => { @@ -1338,11 +1390,11 @@ impl SqlToRel<'_, S> { Ok(LogicalPlan::Ddl(statement)) } - Statement::DropFunction { + Statement::DropFunction(ast::DropFunction { if_exists, func_desc, - .. - } => { + drop_behavior: _, + }) => { // According to postgresql documentation it can be only one function // specified in drop statement if let Some(desc) = func_desc.first() { @@ -1362,6 +1414,60 @@ impl SqlToRel<'_, S> { exec_err!("Function name not provided") } } + Statement::Truncate(ast::Truncate { + table_names, + partitions, + identity, + cascade, + on_cluster, + table, + if_exists, + }) => { + let _ = table; // Support TRUNCATE TABLE and TRUNCATE syntax + if table_names.len() != 1 { + return not_impl_err!( + "TRUNCATE with multiple tables is not supported" + ); + } + + let target = &table_names[0]; + if target.only { + return not_impl_err!("TRUNCATE with ONLY is not supported"); + } + if partitions.is_some() { + return not_impl_err!("TRUNCATE with PARTITION is not supported"); + } + if identity.is_some() { + return not_impl_err!( + "TRUNCATE with RESTART/CONTINUE IDENTITY is not supported" + ); + } + if cascade.is_some() { + return not_impl_err!( + "TRUNCATE with CASCADE/RESTRICT is not supported" + ); + } + if on_cluster.is_some() { + return not_impl_err!("TRUNCATE with ON CLUSTER is not supported"); + } + if if_exists { + return not_impl_err!("TRUNCATE .. with IF EXISTS is not supported"); + } + let table = self.object_name_to_table_reference(target.name.clone())?; + let source = self.context_provider.get_table_source(table.clone())?; + + // TRUNCATE does not operate on input rows. The EmptyRelation is a logical placeholder + // since the real operation is executed directly by the TableProvider's truncate() hook. 
+ Ok(LogicalPlan::Dml(DmlStatement::new( + table.clone(), + source, + WriteOp::Truncate, + Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: DFSchemaRef::new(DFSchema::empty()), + })), + ))) + } Statement::CreateIndex(CreateIndex { name, table_name, @@ -1716,8 +1822,17 @@ impl SqlToRel<'_, S> { let constraints = constraints .iter() .map(|c: &TableConstraint| match c { - TableConstraint::Unique { name, columns, .. } => { - let constraint_name = match name { + TableConstraint::Unique(UniqueConstraint { + name, + index_name: _, + index_type_display: _, + index_type: _, + columns, + index_options: _, + characteristics: _, + nulls_distinct: _, + }) => { + let constraint_name = match &name { Some(name) => &format!("unique constraint with name '{name}'"), None => "unique constraint", }; @@ -1729,7 +1844,14 @@ impl SqlToRel<'_, S> { )?; Ok(Constraint::Unique(indices)) } - TableConstraint::PrimaryKey { columns, .. } => { + TableConstraint::PrimaryKey(PrimaryKeyConstraint { + name: _, + index_name: _, + index_type: _, + columns, + index_options: _, + characteristics: _, + }) => { // Get primary key indices in the schema let indices = self.get_constraint_column_indices( df_schema, @@ -1978,6 +2100,7 @@ impl SqlToRel<'_, S> { &self, table_name: &ObjectName, predicate_expr: Option, + limit: Option, ) -> Result { // Do a table lookup to verify the table exists let table_ref = self.object_name_to_table_reference(table_name.clone())?; @@ -1991,7 +2114,7 @@ impl SqlToRel<'_, S> { .build()?; let mut planner_context = PlannerContext::new(); - let source = match predicate_expr { + let mut source = match predicate_expr { None => scan, Some(predicate_expr) => { let filter_expr = @@ -2008,6 +2131,14 @@ impl SqlToRel<'_, S> { } }; + if let Some(limit) = limit { + let empty_schema = DFSchema::empty(); + let limit = self.sql_to_expr(limit, &empty_schema, &mut planner_context)?; + source = LogicalPlanBuilder::from(source) + .limit_by_expr(None, 
Some(limit))? + .build()? + } + let plan = LogicalPlan::Dml(DmlStatement::new( table_ref, table_source, diff --git a/datafusion/sql/src/unparser/ast.rs b/datafusion/sql/src/unparser/ast.rs index ec78a42d6534..8446a44b07e3 100644 --- a/datafusion/sql/src/unparser/ast.rs +++ b/datafusion/sql/src/unparser/ast.rs @@ -315,7 +315,9 @@ impl SelectBuilder { } pub fn build(&self) -> Result { Ok(ast::Select { + optimizer_hint: None, distinct: self.distinct.clone(), + select_modifiers: None, top_before_distinct: false, top: self.top.clone(), projection: self.projection.clone().unwrap_or_default(), @@ -340,12 +342,12 @@ impl SelectBuilder { named_window: self.named_window.clone(), qualify: self.qualify.clone(), value_table_mode: self.value_table_mode, - connect_by: None, + connect_by: Vec::new(), window_before_qualify: false, prewhere: None, select_token: AttachedToken::empty(), flavor: match self.flavor { - Some(ref value) => value.clone(), + Some(ref value) => *value, None => return Err(Into::into(UninitializedFieldError::from("flavor"))), }, exclude: None, @@ -608,6 +610,7 @@ impl DerivedRelationBuilder { } }, alias: self.alias.clone(), + sample: None, }) } fn create_empty() -> Self { diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index 1a3e1a06db5f..31d2662cc4cc 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -372,6 +372,7 @@ impl PostgreSqlDialect { kind: ast::CastKind::Cast, expr: Box::new(expr.clone()), data_type: ast::DataType::Numeric(ast::ExactNumberInfo::None), + array: false, format: None, }; } diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 5746a568e712..b82ab24adef7 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -45,7 +45,7 @@ use datafusion_common::{ }; use datafusion_expr::{ Between, BinaryExpr, Case, Cast, Expr, GroupingSet, Like, Operator, TryCast, - expr::{Alias, Exists, 
InList, ScalarFunction, Sort, WindowFunction}, + expr::{Alias, Exists, InList, ScalarFunction, SetQuantifier, Sort, WindowFunction}, }; use sqlparser::ast::helpers::attached_token::AttachedToken; use sqlparser::tokenizer::Span; @@ -393,6 +393,33 @@ impl Unparser<'_> { negated: insubq.negated, }) } + Expr::SetComparison(set_cmp) => { + let left = Box::new(self.expr_to_sql_inner(set_cmp.expr.as_ref())?); + let sub_statement = + self.plan_to_sql(set_cmp.subquery.subquery.as_ref())?; + let sub_query = if let ast::Statement::Query(inner_query) = sub_statement + { + inner_query + } else { + return plan_err!( + "Subquery must be a Query, but found {sub_statement:?}" + ); + }; + let compare_op = self.op_to_sql(&set_cmp.op)?; + match set_cmp.quantifier { + SetQuantifier::Any => Ok(ast::Expr::AnyOp { + left, + compare_op, + right: Box::new(ast::Expr::Subquery(sub_query)), + is_some: false, + }), + SetQuantifier::All => Ok(ast::Expr::AllOp { + left, + compare_op, + right: Box::new(ast::Expr::Subquery(sub_query)), + }), + } + } Expr::Exists(Exists { subquery, negated }) => { let sub_statement = self.plan_to_sql(subquery.subquery.as_ref())?; let sub_query = if let ast::Statement::Query(inner_query) = sub_statement @@ -467,6 +494,7 @@ impl Unparser<'_> { kind: ast::CastKind::TryCast, expr: Box::new(inner_expr), data_type: self.arrow_dtype_to_ast_dtype(data_type)?, + array: false, format: None, }) } @@ -1066,6 +1094,7 @@ impl Unparser<'_> { Operator::Question => Ok(BinaryOperator::Question), Operator::QuestionAnd => Ok(BinaryOperator::QuestionAnd), Operator::QuestionPipe => Ok(BinaryOperator::QuestionPipe), + Operator::Colon => Ok(BinaryOperator::Custom(":".to_owned())), } } @@ -1118,6 +1147,7 @@ impl Unparser<'_> { kind: ast::CastKind::Cast, expr: Box::new(ast::Expr::value(SingleQuotedString(ts))), data_type: self.dialect.timestamp_cast_dtype(&time_unit, &None), + array: false, format: None, }) } @@ -1140,6 +1170,7 @@ impl Unparser<'_> { kind: ast::CastKind::Cast, expr: 
Box::new(ast::Expr::value(SingleQuotedString(time))), data_type: ast::DataType::Time(None, TimezoneInfo::None), + array: false, format: None, }) } @@ -1157,6 +1188,7 @@ impl Unparser<'_> { kind: ast::CastKind::Cast, expr: Box::new(inner_expr), data_type: self.arrow_dtype_to_ast_dtype(data_type)?, + array: false, format: None, }), }, @@ -1164,6 +1196,7 @@ impl Unparser<'_> { kind: ast::CastKind::Cast, expr: Box::new(inner_expr), data_type: self.arrow_dtype_to_ast_dtype(data_type)?, + array: false, format: None, }), } @@ -1305,6 +1338,7 @@ impl Unparser<'_> { date.to_string(), ))), data_type: ast::DataType::Date, + array: false, format: None, }) } @@ -1328,6 +1362,7 @@ impl Unparser<'_> { datetime.to_string(), ))), data_type: self.ast_type_for_date64_in_cast(), + array: false, format: None, }) } @@ -1414,6 +1449,7 @@ impl Unparser<'_> { ScalarValue::Map(_) => not_impl_err!("Unsupported scalar: {v:?}"), ScalarValue::Union(..) => not_impl_err!("Unsupported scalar: {v:?}"), ScalarValue::Dictionary(_k, v) => self.scalar_to_sql(v), + ScalarValue::RunEndEncoded(_, _, v) => self.scalar_to_sql(v), } } @@ -1763,6 +1799,9 @@ impl Unparser<'_> { not_impl_err!("Unsupported DataType: conversion: {data_type}") } DataType::Dictionary(_, val) => self.arrow_dtype_to_ast_dtype(val), + DataType::RunEndEncoded(_, val) => { + self.arrow_dtype_to_ast_dtype(val.data_type()) + } DataType::Decimal32(precision, scale) | DataType::Decimal64(precision, scale) | DataType::Decimal128(precision, scale) @@ -1784,9 +1823,6 @@ impl Unparser<'_> { DataType::Map(_, _) => { not_impl_err!("Unsupported DataType: conversion: {data_type}") } - DataType::RunEndEncoded(_, _) => { - not_impl_err!("Unsupported DataType: conversion: {data_type}") - } } } } @@ -2289,6 +2325,17 @@ mod tests { ), "'foo'", ), + ( + Expr::Literal( + ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int32, false).into(), + Field::new("values", DataType::Utf8, true).into(), + 
Box::new(ScalarValue::Utf8(Some("foo".into()))), + ), + None, + ), + "'foo'", + ), ( Expr::Literal( ScalarValue::List(Arc::new(ListArray::from_iter_primitive::< @@ -3158,6 +3205,22 @@ mod tests { Ok(()) } + #[test] + fn test_run_end_encoded_to_sql() -> Result<()> { + let dialect = CustomDialectBuilder::new().build(); + + let unparser = Unparser::new(&dialect); + + let ast_dtype = unparser.arrow_dtype_to_ast_dtype(&DataType::RunEndEncoded( + Field::new("run_ends", DataType::Int32, false).into(), + Field::new("values", DataType::Utf8, true).into(), + ))?; + + assert_eq!(ast_dtype, ast::DataType::Varchar(None)); + + Ok(()) + } + #[test] fn test_utf8_view_to_sql() -> Result<()> { let dialect = CustomDialectBuilder::new() diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 56bf887dbde4..ca8dfa431b4f 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -499,6 +499,17 @@ impl Unparser<'_> { ) } LogicalPlan::Sort(sort) => { + // Sort can be top-level plan for derived table + if select.already_projected() { + return self.derive_with_dialect_alias( + "derived_sort", + plan, + relation, + false, + vec![], + ); + } + let Some(query_ref) = query else { return internal_err!( "Sort operator only valid in a statement context." @@ -1395,6 +1406,7 @@ impl Unparser<'_> { ast::TableAlias { name: self.new_ident_quoted_if_needs(alias), columns, + explicit: true, } } diff --git a/datafusion/sql/src/unparser/rewrite.rs b/datafusion/sql/src/unparser/rewrite.rs index ec1b17cd28a9..e3b644f33f3b 100644 --- a/datafusion/sql/src/unparser/rewrite.rs +++ b/datafusion/sql/src/unparser/rewrite.rs @@ -223,7 +223,15 @@ pub(super) fn rewrite_plan_for_sort_on_non_projected_fields( let mut collects = p.expr.clone(); for sort in &sort.expr { - collects.push(sort.expr.clone()); + // Strip aliases from sort expressions so the comparison matches + // the inner Projection's raw expressions. 
The optimizer may add + // sort expressions to the inner Projection without aliases, while + // the Sort node's expressions carry aliases from the original plan. + let mut expr = sort.expr.clone(); + while let Expr::Alias(alias) = expr { + expr = *alias.expr; + } + collects.push(expr); } // Compare outer collects Expr::to_string with inner collected transformed values diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index af2e1c79427c..16ac353d4ba9 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -331,6 +331,8 @@ pub(crate) fn value_to_string(value: &Value) -> Option { Value::Number(_, _) | Value::Boolean(_) => Some(value.to_string()), Value::UnicodeStringLiteral(s) => Some(s.to_string()), Value::EscapedStringLiteral(s) => Some(s.to_string()), + Value::QuoteDelimitedStringLiteral(s) + | Value::NationalQuoteDelimitedStringLiteral(s) => Some(s.value.to_string()), Value::DoubleQuotedString(_) | Value::NationalStringLiteral(_) | Value::SingleQuotedByteStringLiteral(_) @@ -374,7 +376,7 @@ pub(crate) fn rewrite_recursive_unnests_bottom_up( pub const UNNEST_PLACEHOLDER: &str = "__unnest_placeholder"; /* -This is only usedful when used with transform down up +This is only useful when used with transform down up A full example of how the transformation works: */ struct RecursiveUnnestRewriter<'a> { @@ -406,6 +408,24 @@ impl RecursiveUnnestRewriter<'_> { .collect() } + /// Check if the current expression is at the root level for struct unnest purposes. + /// This is true if: + /// 1. The expression IS the root expression, OR + /// 2. The root expression is an Alias wrapping this expression + /// + /// This allows `unnest(struct_col) AS alias` to work, where the alias is simply + /// ignored for struct unnest (matching DuckDB behavior). 
+ fn is_at_struct_allowed_root(&self, expr: &Expr) -> bool { + if expr == self.root_expr { + return true; + } + // Allow struct unnest when root is an alias wrapping the unnest + if let Expr::Alias(Alias { expr: inner, .. }) = self.root_expr { + return inner.as_ref() == expr; + } + false + } + fn transform( &mut self, level: usize, @@ -478,7 +498,7 @@ impl TreeNodeRewriter for RecursiveUnnestRewriter<'_> { type Node = Expr; /// This downward traversal needs to keep track of: - /// - Whether or not some unnest expr has been visited from the top util the current node + /// - Whether or not some unnest expr has been visited from the top until the current node /// - If some unnest expr has been visited, maintain a stack of such information, this /// is used to detect if some recursive unnest expr exists (e.g **unnest(unnest(unnest(3d column))))** fn f_down(&mut self, expr: Expr) -> Result> { @@ -566,7 +586,8 @@ impl TreeNodeRewriter for RecursiveUnnestRewriter<'_> { // instead of unnest(struct_arr_col, depth = 2) let unnest_recursion = unnest_stack.len(); - let struct_allowed = (&expr == self.root_expr) && unnest_recursion == 1; + let struct_allowed = + self.is_at_struct_allowed_root(&expr) && unnest_recursion == 1; let mut transformed_exprs = self.transform( unnest_recursion, @@ -574,7 +595,9 @@ impl TreeNodeRewriter for RecursiveUnnestRewriter<'_> { inner_expr, struct_allowed, )?; - if struct_allowed { + // Only set transformed_root_exprs for struct unnest (which returns multiple expressions). + // For list unnest (single expression), we let the normal rewrite handle the alias. 
+ if struct_allowed && transformed_exprs.len() > 1 { self.transformed_root_exprs = Some(transformed_exprs.clone()); } return Ok(Transformed::new( diff --git a/datafusion/sql/src/values.rs b/datafusion/sql/src/values.rs index dd8957c95470..c8cdf1254f33 100644 --- a/datafusion/sql/src/values.rs +++ b/datafusion/sql/src/values.rs @@ -18,7 +18,7 @@ use std::sync::Arc; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use datafusion_common::{DFSchema, Result}; +use datafusion_common::{DFSchema, Result, not_impl_err}; use datafusion_expr::{LogicalPlan, LogicalPlanBuilder}; use sqlparser::ast::Values as SQLValues; @@ -31,7 +31,13 @@ impl SqlToRel<'_, S> { let SQLValues { explicit_row: _, rows, + value_keyword, } = values; + if value_keyword { + return not_impl_err!( + "`VALUE` keyword not supported. Did you mean `VALUES`?" + )?; + } let empty_schema = Arc::new(DFSchema::empty()); let values = rows diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 46a42ae534af..4c8ea3609068 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -286,7 +286,7 @@ fn roundtrip_crossjoin() -> Result<()> { plan_roundtrip, @r" Projection: j1.j1_id, j2.j2_string - Cross Join: + Cross Join: TableScan: j1 TableScan: j2 " @@ -1740,6 +1740,42 @@ fn test_sort_with_push_down_fetch() -> Result<()> { Ok(()) } +#[test] +fn test_sort_with_scalar_fn_and_push_down_fetch() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("search_phrase", DataType::Utf8, false), + Field::new("event_time", DataType::Utf8, false), + ]); + + let substr_udf = unicode::substr(); + + // Build a plan that mimics the DF52 optimizer output: + // Projection(search_phrase) → Sort(substr(event_time), fetch=10) + // → Projection(search_phrase, event_time) → Filter → TableScan + // This triggers a subquery because the outer projection differs from the inner one. 
+ // The ORDER BY scalar function must not reference the inner table qualifier. + let plan = table_scan(Some("t1"), &schema, None)? + .filter(col("search_phrase").not_eq(lit("")))? + .project(vec![col("search_phrase"), col("event_time")])? + .sort_with_limit( + vec![ + substr_udf + .call(vec![col("event_time"), lit(1), lit(5)]) + .sort(true, true), + ], + Some(10), + )? + .project(vec![col("search_phrase")])? + .build()?; + + let sql = plan_to_sql(&plan)?; + assert_snapshot!( + sql, + @"SELECT t1.search_phrase FROM (SELECT t1.search_phrase, t1.event_time FROM t1 WHERE (t1.search_phrase <> '') ORDER BY substr(t1.event_time, 1, 5) ASC NULLS FIRST LIMIT 10)" + ); + Ok(()) +} + #[test] fn test_join_with_table_scan_filters() -> Result<()> { let schema_left = Schema::new(vec![ @@ -1984,7 +2020,7 @@ fn test_complex_order_by_with_grouping() -> Result<()> { }, { assert_snapshot!( sql, - @r#"SELECT j1.j1_id, j1.j1_string, lochierarchy FROM (SELECT j1.j1_id, j1.j1_string, (grouping(j1.j1_id) + grouping(j1.j1_string)) AS lochierarchy, grouping(j1.j1_string), grouping(j1.j1_id) FROM j1 GROUP BY ROLLUP (j1.j1_id, j1.j1_string)) ORDER BY lochierarchy DESC NULLS FIRST, CASE WHEN (("grouping(j1.j1_id)" + "grouping(j1.j1_string)") = 0) THEN j1.j1_id END ASC NULLS LAST LIMIT 100"# + @"SELECT j1.j1_id, j1.j1_string, lochierarchy FROM (SELECT j1.j1_id, j1.j1_string, (grouping(j1.j1_id) + grouping(j1.j1_string)) AS lochierarchy, grouping(j1.j1_string), grouping(j1.j1_id) FROM j1 GROUP BY ROLLUP (j1.j1_id, j1.j1_string) ORDER BY lochierarchy DESC NULLS FIRST, CASE WHEN ((grouping(j1.j1_id) + grouping(j1.j1_string)) = 0) THEN j1.j1_id END ASC NULLS LAST) LIMIT 100" ); }); @@ -2821,3 +2857,39 @@ fn test_struct_expr3() { @r#"SELECT test.c1."metadata".product."name" FROM (SELECT {"metadata": {product: {"name": 'Product Name'}}} AS c1) AS test"# ); } + +#[test] +fn test_json_access_1() { + let statement = generate_round_trip_statement( + GenericDialect {}, + r#"SELECT j1_string:field FROM 
j1"#, + ); + assert_snapshot!( + statement, + @r#"SELECT (j1.j1_string : 'field') FROM j1"# + ); +} + +#[test] +fn test_json_access_2() { + let statement = generate_round_trip_statement( + GenericDialect {}, + r#"SELECT j1_string:field[0] FROM j1"#, + ); + assert_snapshot!( + statement, + @r#"SELECT (j1.j1_string : 'field[0]') FROM j1"# + ); +} + +#[test] +fn test_json_access_3() { + let statement = generate_round_trip_statement( + GenericDialect {}, + r#"SELECT j1_string:field.inner1['inner2'] FROM j1"#, + ); + assert_snapshot!( + statement, + @r#"SELECT (j1.j1_string : 'field.inner1[''inner2'']') FROM j1"# + ); +} diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs index 44dd7cec89cb..4b8667c3c0cb 100644 --- a/datafusion/sql/tests/common/mod.rs +++ b/datafusion/sql/tests/common/mod.rs @@ -161,12 +161,26 @@ impl ContextProvider for MockContextProvider { ])), "orders" => Ok(Schema::new(vec![ Field::new("order_id", DataType::UInt32, false), + Field::new("o_orderkey", DataType::UInt32, false), + Field::new("o_custkey", DataType::UInt32, false), + Field::new("o_orderstatus", DataType::Utf8, false), Field::new("customer_id", DataType::UInt32, false), + Field::new("o_totalprice", DataType::Decimal128(15, 2), false), Field::new("o_item_id", DataType::Utf8, false), Field::new("qty", DataType::Int32, false), Field::new("price", DataType::Float64, false), Field::new("delivered", DataType::Boolean, false), ])), + "customer" => Ok(Schema::new(vec![ + Field::new("c_custkey", DataType::UInt32, false), + Field::new("c_name", DataType::Utf8, false), + Field::new("c_address", DataType::Utf8, false), + Field::new("c_nationkey", DataType::UInt32, false), + Field::new("c_phone", DataType::Utf8, false), + Field::new("c_acctbal", DataType::Float64, false), + Field::new("c_mktsegment", DataType::Utf8, false), + Field::new("c_comment", DataType::Utf8, false), + ])), "array" => Ok(Schema::new(vec![ Field::new( "left", @@ -186,8 +200,10 @@ impl 
ContextProvider for MockContextProvider { ), ])), "lineitem" => Ok(Schema::new(vec![ + Field::new("l_orderkey", DataType::UInt32, false), Field::new("l_item_id", DataType::UInt32, false), Field::new("l_description", DataType::Utf8, false), + Field::new("l_extendedprice", DataType::Decimal128(15, 2), false), Field::new("price", DataType::Float64, false), ])), "aggregate_test_100" => Ok(Schema::new(vec![ @@ -227,6 +243,11 @@ impl ContextProvider for MockContextProvider { false, ), ])), + "@quoted_identifier_names_table" => Ok(Schema::new(vec![Field::new( + "@column", + DataType::UInt32, + false, + )])), _ => plan_err!("No table named: {} found", name.table()), }; @@ -244,8 +265,11 @@ impl ContextProvider for MockContextProvider { self.state.aggregate_functions.get(name).cloned() } - fn get_variable_type(&self, _: &[String]) -> Option { - unimplemented!() + fn get_variable_type(&self, variable_names: &[String]) -> Option { + match variable_names { + [var] if var == "@variable" => Some(DataType::Date32), + _ => unimplemented!(), + } } fn get_window_meta(&self, name: &str) -> Option> { diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 969d56afdae0..444bdae73ac2 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -995,15 +995,15 @@ fn select_nested_with_filters() { #[test] fn table_with_column_alias() { - let sql = "SELECT a, b, c - FROM lineitem l (a, b, c)"; + let sql = "SELECT a, b, c, d, e + FROM lineitem l (a, b, c, d, e)"; let plan = logical_plan(sql).unwrap(); assert_snapshot!( plan, @r" - Projection: l.a, l.b, l.c + Projection: l.a, l.b, l.c, l.d, l.e SubqueryAlias: l - Projection: lineitem.l_item_id AS a, lineitem.l_description AS b, lineitem.price AS c + Projection: lineitem.l_orderkey AS a, lineitem.l_item_id AS b, lineitem.l_description AS c, lineitem.l_extendedprice AS d, lineitem.price AS e TableScan: lineitem " ); @@ -1017,7 +1017,7 @@ fn 
table_with_column_alias_number_cols() { assert_snapshot!( err.strip_backtrace(), - @"Error during planning: Source table contains 3 columns but only 2 names given as column alias" + @"Error during planning: Source table contains 5 columns but only 2 names given as column alias" ); } @@ -1058,7 +1058,7 @@ fn natural_left_join() { plan, @r" Projection: a.l_item_id - Left Join: Using a.l_item_id = b.l_item_id, a.l_description = b.l_description, a.price = b.price + Left Join: Using a.l_orderkey = b.l_orderkey, a.l_item_id = b.l_item_id, a.l_description = b.l_description, a.l_extendedprice = b.l_extendedprice, a.price = b.price SubqueryAlias: a TableScan: lineitem SubqueryAlias: b @@ -1075,7 +1075,7 @@ fn natural_right_join() { plan, @r" Projection: a.l_item_id - Right Join: Using a.l_item_id = b.l_item_id, a.l_description = b.l_description, a.price = b.price + Right Join: Using a.l_orderkey = b.l_orderkey, a.l_item_id = b.l_item_id, a.l_description = b.l_description, a.l_extendedprice = b.l_extendedprice, a.price = b.price SubqueryAlias: a TableScan: lineitem SubqueryAlias: b @@ -3395,8 +3395,8 @@ fn cross_join_not_to_inner_join() { @r" Projection: person.id Filter: person.id = person.age - Cross Join: - Cross Join: + Cross Join: + Cross Join: TableScan: person TableScan: orders TableScan: lineitem @@ -3530,11 +3530,11 @@ fn exists_subquery_schema_outer_schema_overlap() { Subquery: Projection: person.first_name Filter: person.id = p2.id AND person.last_name = outer_ref(p.last_name) AND person.state = outer_ref(p.state) - Cross Join: + Cross Join: TableScan: person SubqueryAlias: p2 TableScan: person - Cross Join: + Cross Join: TableScan: person SubqueryAlias: p TableScan: person @@ -3619,10 +3619,10 @@ fn scalar_subquery_reference_outer_field() { Projection: count(*) Aggregate: groupBy=[[]], aggr=[[count(*)]] Filter: outer_ref(j2.j2_id) = j1.j1_id AND j1.j1_id = j3.j3_id - Cross Join: + Cross Join: TableScan: j1 TableScan: j3 - Cross Join: + Cross Join: TableScan: j1 
TableScan: j2 " @@ -4522,6 +4522,43 @@ fn test_parse_escaped_string_literal_value() { ); } +#[test] +fn test_parse_quoted_column_name_with_at_sign() { + let sql = r"SELECT `@column` FROM `@quoted_identifier_names_table`"; + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: @quoted_identifier_names_table.@column + TableScan: @quoted_identifier_names_table + "# + ); + + let sql = r"SELECT `@quoted_identifier_names_table`.`@column` FROM `@quoted_identifier_names_table`"; + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: @quoted_identifier_names_table.@column + TableScan: @quoted_identifier_names_table + "# + ); +} + +#[test] +fn test_variable_identifier() { + let sql = r"SELECT t_date32 FROM test WHERE t_date32 = @variable"; + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: test.t_date32 + Filter: test.t_date32 = @variable + TableScan: test + "# + ); +} + #[test] fn plan_create_index() { let sql = @@ -4764,7 +4801,11 @@ fn test_using_join_wildcard_schema() { // Only columns from one join side should be present let expected_fields = vec![ "o1.order_id".to_string(), + "o1.o_orderkey".to_string(), + "o1.o_custkey".to_string(), + "o1.o_orderstatus".to_string(), "o1.customer_id".to_string(), + "o1.o_totalprice".to_string(), "o1.o_item_id".to_string(), "o1.qty".to_string(), "o1.price".to_string(), @@ -4818,3 +4859,70 @@ fn test_using_join_wildcard_schema() { ] ); } + +#[test] +fn test_2_nested_lateral_join_with_the_deepest_join_referencing_the_outer_most_relation() +{ + let sql = "SELECT * FROM j1 j1_outer, LATERAL ( + SELECT * FROM j1 j1_inner, LATERAL ( + SELECT * FROM j2 WHERE j1_inner.j1_id = j2_id and j1_outer.j1_id=j2_id + ) as j2 +) as j2"; + + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: j1_outer.j1_id, j1_outer.j1_string, j2.j1_id, j2.j1_string, j2.j2_id, j2.j2_string + Cross Join: + SubqueryAlias: j1_outer 
+ TableScan: j1 + SubqueryAlias: j2 + Subquery: + Projection: j1_inner.j1_id, j1_inner.j1_string, j2.j2_id, j2.j2_string + Cross Join: + SubqueryAlias: j1_inner + TableScan: j1 + SubqueryAlias: j2 + Subquery: + Projection: j2.j2_id, j2.j2_string + Filter: outer_ref(j1_inner.j1_id) = j2.j2_id AND outer_ref(j1_outer.j1_id) = j2.j2_id + TableScan: j2 +"# + ); +} + +#[test] +fn test_correlated_recursive_scalar_subquery_with_level_3_scalar_subquery_referencing_level1_relation() + { + let sql = "select c_custkey from customer + where c_acctbal < ( + select sum(o_totalprice) from orders + where o_custkey = c_custkey + and o_totalprice < ( + select sum(l_extendedprice) as price from lineitem where l_orderkey = o_orderkey + and l_extendedprice < c_acctbal + ) + ) order by c_custkey"; + + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Sort: customer.c_custkey ASC NULLS LAST + Projection: customer.c_custkey + Filter: customer.c_acctbal < () + Subquery: + Projection: sum(orders.o_totalprice) + Aggregate: groupBy=[[]], aggr=[[sum(orders.o_totalprice)]] + Filter: orders.o_custkey = outer_ref(customer.c_custkey) AND orders.o_totalprice < () + Subquery: + Projection: sum(lineitem.l_extendedprice) AS price + Aggregate: groupBy=[[]], aggr=[[sum(lineitem.l_extendedprice)]] + Filter: lineitem.l_orderkey = outer_ref(orders.o_orderkey) AND lineitem.l_extendedprice < outer_ref(customer.c_acctbal) + TableScan: lineitem + TableScan: orders + TableScan: customer +"# + ); +} diff --git a/datafusion/sqllogictest/Cargo.toml b/datafusion/sqllogictest/Cargo.toml index a26a1d44225f..b00fbe466728 100644 --- a/datafusion/sqllogictest/Cargo.toml +++ b/datafusion/sqllogictest/Cargo.toml @@ -45,9 +45,9 @@ async-trait = { workspace = true } bigdecimal = { workspace = true } bytes = { workspace = true, optional = true } chrono = { workspace = true, optional = true } -clap = { version = "4.5.53", features = ["derive", "env"] } +clap = { version = "4.5.60", features = 
["derive", "env"] } datafusion = { workspace = true, default-features = true, features = ["avro"] } -datafusion-spark = { workspace = true, default-features = true } +datafusion-spark = { workspace = true, features = ["core"] } datafusion-substrait = { workspace = true, default-features = true } futures = { workspace = true } half = { workspace = true, default-features = true } @@ -55,18 +55,16 @@ indicatif = "0.18" itertools = { workspace = true } log = { workspace = true } object_store = { workspace = true } -postgres-protocol = { version = "0.6.7", optional = true } -postgres-types = { version = "0.2.11", features = ["derive", "with-chrono-0_4"], optional = true } -rust_decimal = { version = "1.38.0", features = ["tokio-pg"] } +postgres-types = { version = "0.2.12", features = ["derive", "with-chrono-0_4"], optional = true } # When updating the following dependency verify that sqlite test file regeneration works correctly # by running the regenerate_sqlite_files.sh script. -sqllogictest = "0.28.4" +sqllogictest = "0.29.1" sqlparser = { workspace = true } tempfile = { workspace = true } testcontainers-modules = { workspace = true, features = ["postgres"], optional = true } -thiserror = "2.0.17" +thiserror = "2.0.18" tokio = { workspace = true } -tokio-postgres = { version = "0.7.14", optional = true } +tokio-postgres = { version = "0.7.16", optional = true } [features] avro = ["datafusion/avro"] @@ -75,7 +73,6 @@ postgres = [ "bytes", "chrono", "postgres-types", - "postgres-protocol", "testcontainers-modules", "tokio-postgres", ] diff --git a/datafusion/sqllogictest/README.md b/datafusion/sqllogictest/README.md index 8768deee3d87..7d84ad23d590 100644 --- a/datafusion/sqllogictest/README.md +++ b/datafusion/sqllogictest/README.md @@ -70,6 +70,36 @@ cargo test --test sqllogictests -- ddl --complete RUST_LOG=debug cargo test --test sqllogictests -- ddl ``` +### Per-file timing summary + +The sqllogictest runner can emit deterministic per-file elapsed timings to help 
+identify slow test files. + +By default (`--timing-summary auto`), timing summary output is disabled in local +TTY runs and shows a top-slowest summary in non-TTY/CI runs. + +`--timing-top-n` / `SLT_TIMING_TOP_N` must be a positive integer (`>= 1`). + +```shell +# Show top 10 slowest files (good for CI) +cargo test --test sqllogictests -- --timing-summary top --timing-top-n 10 +``` + +```shell +# Show full per-file timing table +cargo test --test sqllogictests -- --timing-summary full +``` + +```shell +# Same controls via environment variables +SLT_TIMING_SUMMARY=top SLT_TIMING_TOP_N=15 cargo test --test sqllogictests +``` + +```shell +# Optional debug logging for per-task slow files (>30s), disabled by default +SLT_TIMING_DEBUG_SLOW_FILES=1 cargo test --test sqllogictests +``` + ## Cookbook: Adding Tests 1. Add queries diff --git a/datafusion/sqllogictest/bin/sqllogictests.rs b/datafusion/sqllogictest/bin/sqllogictests.rs index 8037532c09ac..e067f2488d81 100644 --- a/datafusion/sqllogictest/bin/sqllogictests.rs +++ b/datafusion/sqllogictest/bin/sqllogictests.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-use clap::Parser; +use clap::{ColorChoice, Parser, ValueEnum}; use datafusion::common::instant::Instant; use datafusion::common::utils::get_available_parallelism; use datafusion::common::{DataFusionError, Result, exec_datafusion_err, exec_err}; @@ -44,7 +44,12 @@ use datafusion::common::runtime::SpawnedTask; use futures::FutureExt; use std::ffi::OsStr; use std::fs; +use std::io::{IsTerminal, stderr, stdout}; use std::path::{Path, PathBuf}; +use std::str::FromStr; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::time::Duration; #[cfg(feature = "postgres")] mod postgres_container; @@ -54,6 +59,21 @@ const DATAFUSION_TESTING_TEST_DIRECTORY: &str = "../../datafusion-testing/data/" const PG_COMPAT_FILE_PREFIX: &str = "pg_compat_"; const SQLITE_PREFIX: &str = "sqlite"; const ERRS_PER_FILE_LIMIT: usize = 10; +const TIMING_DEBUG_SLOW_FILES_ENV: &str = "SLT_TIMING_DEBUG_SLOW_FILES"; + +#[derive(Clone, Copy, Debug, Eq, PartialEq, ValueEnum)] +enum TimingSummaryMode { + Auto, + Off, + Top, + Full, +} + +#[derive(Debug)] +struct FileTiming { + relative_path: PathBuf, + elapsed: Duration, +} pub fn main() -> Result<()> { tokio::runtime::Builder::new_multi_thread() @@ -96,6 +116,7 @@ async fn run_tests() -> Result<()> { env_logger::init(); let options: Options = Parser::parse(); + let timing_debug_slow_files = is_env_truthy(TIMING_DEBUG_SLOW_FILES_ENV); if options.list { // nextest parses stdout, so print messages to stderr eprintln!("NOTICE: --list option unsupported, quitting"); @@ -108,6 +129,13 @@ async fn run_tests() -> Result<()> { options.warn_on_ignored(); + // Print parallelism info for debugging CI performance + eprintln!( + "Running with {} test threads (available parallelism: {})", + options.test_threads, + get_available_parallelism() + ); + #[cfg(feature = "postgres")] initialize_postgres_container(&options).await?; @@ -123,6 +151,8 @@ async fn run_tests() -> Result<()> { .unwrap() .progress_chars("##-"); + let colored_output = 
options.is_colored(); + let start = Instant::now(); let test_files = read_test_files(&options)?; @@ -143,7 +173,11 @@ async fn run_tests() -> Result<()> { } let num_tests = test_files.len(); - let errors: Vec<_> = futures::stream::iter(test_files) + // For CI environments without TTY, print progress periodically + let is_ci = !stderr().is_terminal(); + let completed_count = Arc::new(AtomicUsize::new(0)); + + let file_results: Vec<_> = futures::stream::iter(test_files) .map(|test_file| { let validator = if options.include_sqlite && test_file.relative_path.starts_with(SQLITE_PREFIX) @@ -158,12 +192,14 @@ async fn run_tests() -> Result<()> { let filters = options.filters.clone(); let relative_path = test_file.relative_path.clone(); + let relative_path_for_timing = test_file.relative_path.clone(); let currently_running_sql_tracker = CurrentlyExecutingSqlTracker::new(); let currently_running_sql_tracker_clone = currently_running_sql_tracker.clone(); + let file_start = Instant::now(); SpawnedTask::spawn(async move { - match ( + let result = match ( options.postgres_runner, options.complete, options.substrait_round_trip, @@ -176,8 +212,9 @@ async fn run_tests() -> Result<()> { m_style_clone, filters.as_ref(), currently_running_sql_tracker_clone, + colored_output, ) - .await? + .await } (false, false, _) => { run_test_file( @@ -187,8 +224,9 @@ async fn run_tests() -> Result<()> { m_style_clone, filters.as_ref(), currently_running_sql_tracker_clone, + colored_output, ) - .await? + .await } (false, true, _) => { run_complete_file( @@ -198,7 +236,7 @@ async fn run_tests() -> Result<()> { m_style_clone, currently_running_sql_tracker_clone, ) - .await? + .await } (true, false, _) => { run_test_file_with_postgres( @@ -209,7 +247,7 @@ async fn run_tests() -> Result<()> { filters.as_ref(), currently_running_sql_tracker_clone, ) - .await? 
+ .await } (true, true, _) => { run_complete_file_with_postgres( @@ -219,20 +257,77 @@ async fn run_tests() -> Result<()> { m_style_clone, currently_running_sql_tracker_clone, ) - .await? + .await } + }; + + let elapsed = file_start.elapsed(); + if timing_debug_slow_files && elapsed.as_secs() > 30 { + eprintln!( + "Slow file: {} took {:.1}s", + relative_path_for_timing.display(), + elapsed.as_secs_f64() + ); } - Ok(()) as Result<()> + + (result, elapsed) }) .join() - .map(move |result| (result, relative_path, currently_running_sql_tracker)) + .map(move |result| { + let elapsed = match &result { + Ok((_, elapsed)) => *elapsed, + Err(_) => Duration::ZERO, + }; + + ( + result.map(|(thread_result, _)| thread_result), + relative_path, + currently_running_sql_tracker, + elapsed, + ) + }) }) // run up to num_cpus streams in parallel .buffer_unordered(options.test_threads) - .flat_map(|(result, test_file_path, current_sql)| { + .inspect({ + let completed_count = Arc::clone(&completed_count); + move |_| { + let completed = completed_count.fetch_add(1, Ordering::Relaxed) + 1; + // In CI (no TTY), print progress every 10% or every 50 files + if is_ci && (completed.is_multiple_of(50) || completed == num_tests) { + eprintln!( + "Progress: {}/{} files completed ({:.0}%)", + completed, + num_tests, + (completed as f64 / num_tests as f64) * 100.0 + ); + } + } + }) + .collect() + .await; + + let mut file_timings: Vec = file_results + .iter() + .map(|(_, path, _, elapsed)| FileTiming { + relative_path: path.clone(), + elapsed: *elapsed, + }) + .collect(); + + file_timings.sort_by(|a, b| { + b.elapsed + .cmp(&a.elapsed) + .then_with(|| a.relative_path.cmp(&b.relative_path)) + }); + + print_timing_summary(&options, &m, is_ci, &file_timings)?; + + let errors: Vec<_> = file_results + .into_iter() + .filter_map(|(result, test_file_path, current_sql, _)| { // Filter out any Ok() leaving only the DataFusionErrors - futures::stream::iter(match result { - // Tokio panic error + match result 
{ Err(e) => { let error = DataFusionError::External(Box::new(e)); let current_sql = current_sql.get_currently_running_sqls(); @@ -262,10 +357,9 @@ async fn run_tests() -> Result<()> { } } Ok(thread_result) => thread_result.err(), - }) + } }) - .collect() - .await; + .collect(); m.println(format!( "Completed {} test files in {}", @@ -287,6 +381,69 @@ async fn run_tests() -> Result<()> { } } +fn print_timing_summary( + options: &Options, + progress: &MultiProgress, + is_ci: bool, + file_timings: &[FileTiming], +) -> Result<()> { + let mode = options.timing_summary_mode(is_ci); + if mode == TimingSummaryMode::Off || file_timings.is_empty() { + return Ok(()); + } + + let top_n = options.timing_top_n; + debug_assert!(matches!( + mode, + TimingSummaryMode::Top | TimingSummaryMode::Full + )); + let count = if mode == TimingSummaryMode::Full { + file_timings.len() + } else { + top_n + }; + + progress.println("Per-file elapsed summary (deterministic):")?; + for (idx, timing) in file_timings.iter().take(count).enumerate() { + progress.println(format!( + "{:>3}. {:>8.3}s {}", + idx + 1, + timing.elapsed.as_secs_f64(), + timing.relative_path.display() + ))?; + } + + if mode != TimingSummaryMode::Full && file_timings.len() > count { + progress.println(format!( + "... 
{} more files omitted (use --timing-summary full to show all)", + file_timings.len() - count + ))?; + } + + Ok(()) +} + +fn is_env_truthy(name: &str) -> bool { + std::env::var_os(name) + .and_then(|value| value.into_string().ok()) + .is_some_and(|value| { + matches!( + value.trim().to_ascii_lowercase().as_str(), + "1" | "true" | "yes" | "on" + ) + }) +} + +fn parse_timing_top_n(arg: &str) -> std::result::Result { + let parsed = arg + .parse::() + .map_err(|error| format!("invalid value '{arg}': {error}"))?; + if parsed == 0 { + return Err("must be >= 1".to_string()); + } + Ok(parsed) +} + async fn run_test_file_substrait_round_trip( test_file: TestFile, validator: Validator, @@ -294,6 +451,7 @@ async fn run_test_file_substrait_round_trip( mp_style: ProgressStyle, filters: &[Filter], currently_executing_sql_tracker: CurrentlyExecutingSqlTracker, + colored_output: bool, ) -> Result<()> { let TestFile { path, @@ -323,7 +481,7 @@ async fn run_test_file_substrait_round_trip( runner.with_column_validator(strict_column_validator); runner.with_normalizer(value_normalizer); runner.with_validator(validator); - let res = run_file_in_runner(path, runner, filters).await; + let res = run_file_in_runner(path, runner, filters, colored_output).await; pb.finish_and_clear(); res } @@ -335,6 +493,7 @@ async fn run_test_file( mp_style: ProgressStyle, filters: &[Filter], currently_executing_sql_tracker: CurrentlyExecutingSqlTracker, + colored_output: bool, ) -> Result<()> { let TestFile { path, @@ -364,7 +523,7 @@ async fn run_test_file( runner.with_column_validator(strict_column_validator); runner.with_normalizer(value_normalizer); runner.with_validator(validator); - let result = run_file_in_runner(path, runner, filters).await; + let result = run_file_in_runner(path, runner, filters, colored_output).await; pb.finish_and_clear(); result } @@ -373,6 +532,7 @@ async fn run_file_in_runner>( path: PathBuf, mut runner: sqllogictest::Runner, filters: &[Filter], + colored_output: bool, ) -> 
Result<()> { let path = path.canonicalize()?; let records = @@ -386,7 +546,11 @@ async fn run_file_in_runner>( continue; } if let Err(err) = runner.run_async(record).await { - errs.push(format!("{err}")); + if colored_output { + errs.push(format!("{}", err.display(true))); + } else { + errs.push(format!("{err}")); + } } } @@ -479,7 +643,7 @@ async fn run_test_file_with_postgres( runner.with_column_validator(strict_column_validator); runner.with_normalizer(value_normalizer); runner.with_validator(validator); - let result = run_file_in_runner(path, runner, filters).await; + let result = run_file_in_runner(path, runner, filters, false).await; pb.finish_and_clear(); result } @@ -772,9 +936,48 @@ struct Options { default_value_t = get_available_parallelism() )] test_threads: usize, + + #[clap( + long, + env = "SLT_TIMING_SUMMARY", + value_enum, + default_value_t = TimingSummaryMode::Auto, + help = "Per-file timing summary mode: auto|off|top|full" + )] + timing_summary: TimingSummaryMode, + + #[clap( + long, + env = "SLT_TIMING_TOP_N", + default_value_t = 10, + value_parser = parse_timing_top_n, + help = "Number of files to show when timing summary mode is auto/top (must be >= 1)" + )] + timing_top_n: usize, + + #[clap( + long, + value_name = "MODE", + help = "Control colored output", + default_value_t = ColorChoice::Auto + )] + color: ColorChoice, } impl Options { + fn timing_summary_mode(&self, is_ci: bool) -> TimingSummaryMode { + match self.timing_summary { + TimingSummaryMode::Auto => { + if is_ci { + TimingSummaryMode::Top + } else { + TimingSummaryMode::Off + } + } + mode => mode, + } + } + /// Because this test can be run as a cargo test, commands like /// /// ```shell @@ -813,6 +1016,37 @@ impl Options { eprintln!("WARNING: Ignoring `--show-output` compatibility option"); } } + + /// Determine if colour output should be enabled, respecting --color, NO_COLOR, CARGO_TERM_COLOR, and terminal detection + fn is_colored(&self) -> bool { + // NO_COLOR takes precedence 
+ if std::env::var_os("NO_COLOR").is_some() { + return false; + } + + match self.color { + ColorChoice::Always => true, + ColorChoice::Never => false, + ColorChoice::Auto => { + // CARGO_TERM_COLOR takes precedence over auto-detection + let cargo_term_color = ::from_str( + &std::env::var("CARGO_TERM_COLOR") + .unwrap_or_else(|_| "auto".to_string()), + ) + .unwrap_or(ColorChoice::Auto); + match cargo_term_color { + ColorChoice::Always => true, + ColorChoice::Never => false, + ColorChoice::Auto => { + // Auto for both CLI argument and CARGO_TERM_COLOR, + // then use colors by default for non-dumb terminals + stdout().is_terminal() + && std::env::var("TERM").unwrap_or_default() != "dumb" + } + } + } + } + } } /// Performs scratch file check for all test files. diff --git a/datafusion/sqllogictest/src/engines/conversion.rs b/datafusion/sqllogictest/src/engines/conversion.rs index 633029a2def2..3e519042f4ee 100644 --- a/datafusion/sqllogictest/src/engines/conversion.rs +++ b/datafusion/sqllogictest/src/engines/conversion.rs @@ -18,7 +18,7 @@ use arrow::datatypes::{Decimal128Type, Decimal256Type, DecimalType, i256}; use bigdecimal::BigDecimal; use half::f16; -use rust_decimal::prelude::*; +use std::str::FromStr; /// Represents a constant for NULL string in your database. pub const NULL_STR: &str = "NULL"; @@ -115,8 +115,8 @@ pub(crate) fn decimal_256_to_str(value: i256, scale: i8) -> String { } #[cfg(feature = "postgres")] -pub(crate) fn decimal_to_str(value: Decimal) -> String { - big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap(), None) +pub(crate) fn decimal_to_str(value: BigDecimal) -> String { + big_decimal_to_str(value, None) } /// Converts a `BigDecimal` to its plain string representation, optionally rounding to a specified number of decimal places. 
diff --git a/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs b/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs index cb6410d857a8..bad9a1dd3fc4 100644 --- a/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs +++ b/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs @@ -186,7 +186,7 @@ macro_rules! get_row_value { /// /// Floating numbers are rounded to have a consistent representation with the Postgres runner. pub fn cell_to_string(col: &ArrayRef, row: usize, is_spark_path: bool) -> Result { - if !col.is_valid(row) { + if col.is_null(row) { // represent any null value with the string "NULL" Ok(NULL_STR.to_string()) } else { diff --git a/datafusion/sqllogictest/src/engines/postgres_engine/mod.rs b/datafusion/sqllogictest/src/engines/postgres_engine/mod.rs index b14886fedd61..c3f266dcd1b6 100644 --- a/datafusion/sqllogictest/src/engines/postgres_engine/mod.rs +++ b/datafusion/sqllogictest/src/engines/postgres_engine/mod.rs @@ -16,6 +16,7 @@ // under the License. 
use async_trait::async_trait; +use bigdecimal::BigDecimal; use bytes::Bytes; use datafusion::common::runtime::SpawnedTask; use futures::{SinkExt, StreamExt}; @@ -32,12 +33,8 @@ use crate::engines::output::{DFColumnType, DFOutput}; use chrono::{NaiveDate, NaiveDateTime, NaiveTime}; use indicatif::ProgressBar; use postgres_types::Type; -use rust_decimal::Decimal; use tokio::time::Instant; -use tokio_postgres::{Column, Row}; -use types::PgRegtype; - -mod types; +use tokio_postgres::{SimpleQueryMessage, SimpleQueryRow}; // default connect string, can be overridden by the `PG_URL` environment variable const PG_URI: &str = "postgresql://postgres@127.0.0.1/test"; @@ -299,8 +296,20 @@ impl sqllogictest::AsyncDB for Postgres { self.pb.inc(1); return Ok(DBOutput::StatementComplete(0)); } + // Use a prepared statement to get the output column types + let statement = self.get_client().prepare(sql).await?; + let types: Vec = statement + .columns() + .iter() + .map(|c| c.type_().clone()) + .collect(); + + // Run the actual query using the "simple query" protocol that returns all + // rows as text. Doing this avoids having to convert values from the binary + // format to strings, which is somewhat tricky for numeric types. + // See https://github.com/apache/datafusion/pull/19666#discussion_r2668090587 let start = Instant::now(); - let rows = self.get_client().query(sql, &[]).await?; + let messages = self.get_client().simple_query(sql).await?; let duration = start.elapsed(); if duration.gt(&Duration::from_millis(500)) { @@ -309,30 +318,16 @@ impl sqllogictest::AsyncDB for Postgres { self.pb.inc(1); - let types: Vec = if rows.is_empty() { - self.get_client() - .prepare(sql) - .await? 
- .columns() - .iter() - .map(|c| c.type_().clone()) - .collect() - } else { - rows[0] - .columns() - .iter() - .map(|c| c.type_().clone()) - .collect() - }; - self.currently_executing_sql_tracker.remove_sql(tracked_sql); + let rows = convert_rows(&types, &messages); + if rows.is_empty() && types.is_empty() { Ok(DBOutput::StatementComplete(0)) } else { Ok(DBOutput::Rows { types: convert_types(types), - rows: convert_rows(&rows), + rows, }) } } @@ -351,58 +346,68 @@ impl sqllogictest::AsyncDB for Postgres { } } -fn convert_rows(rows: &[Row]) -> Vec> { - rows.iter() +fn convert_rows(types: &[Type], messages: &[SimpleQueryMessage]) -> Vec> { + messages + .iter() + .filter_map(|message| match message { + SimpleQueryMessage::Row(row) => Some(row), + _ => None, + }) .map(|row| { - row.columns() + types .iter() .enumerate() - .map(|(idx, column)| cell_to_string(row, column, idx)) + .map(|(idx, column_type)| cell_to_string(row, column_type, idx)) .collect::>() }) .collect::>() } -macro_rules! make_string { - ($row:ident, $idx:ident, $t:ty) => {{ - let value: Option<$t> = $row.get($idx); - match value { - Some(value) => value.to_string(), - None => NULL_STR.to_string(), +fn cell_to_string(row: &SimpleQueryRow, column_type: &Type, idx: usize) -> String { + // simple_query returns text values, so we parse by Postgres type to keep + // normalization aligned with the DataFusion engine output. 
+ let value = row.get(idx); + match (column_type, value) { + (_, None) => NULL_STR.to_string(), + (&Type::CHAR, Some(value)) => value + .as_bytes() + .first() + .map(|byte| (*byte as i8).to_string()) + .unwrap_or_else(|| NULL_STR.to_string()), + (&Type::INT2, Some(value)) => value.parse::().unwrap().to_string(), + (&Type::INT4, Some(value)) => value.parse::().unwrap().to_string(), + (&Type::INT8, Some(value)) => value.parse::().unwrap().to_string(), + (&Type::NUMERIC, Some(value)) => { + decimal_to_str(BigDecimal::from_str(value).unwrap()) } - }}; - ($row:ident, $idx:ident, $t:ty, $convert:ident) => {{ - let value: Option<$t> = $row.get($idx); - match value { - Some(value) => $convert(value).to_string(), - None => NULL_STR.to_string(), + // Parse date/time strings explicitly to avoid locale-specific formatting. + (&Type::DATE, Some(value)) => NaiveDate::parse_from_str(value, "%Y-%m-%d") + .unwrap() + .to_string(), + (&Type::TIME, Some(value)) => NaiveTime::parse_from_str(value, "%H:%M:%S%.f") + .unwrap() + .to_string(), + (&Type::TIMESTAMP, Some(value)) => { + let parsed = NaiveDateTime::parse_from_str(value, "%Y-%m-%d %H:%M:%S%.f") + .or_else(|_| NaiveDateTime::parse_from_str(value, "%Y-%m-%dT%H:%M:%S%.f")) + .unwrap(); + format!("{parsed:?}") } - }}; -} - -fn cell_to_string(row: &Row, column: &Column, idx: usize) -> String { - match column.type_().clone() { - Type::CHAR => make_string!(row, idx, i8), - Type::INT2 => make_string!(row, idx, i16), - Type::INT4 => make_string!(row, idx, i32), - Type::INT8 => make_string!(row, idx, i64), - Type::NUMERIC => make_string!(row, idx, Decimal, decimal_to_str), - Type::DATE => make_string!(row, idx, NaiveDate), - Type::TIME => make_string!(row, idx, NaiveTime), - Type::TIMESTAMP => { - let value: Option = row.get(idx); - value - .map(|d| format!("{d:?}")) - .unwrap_or_else(|| "NULL".to_string()) + (&Type::BOOL, Some(value)) => { + let parsed = match value { + "t" | "true" | "TRUE" => true, + "f" | "false" | "FALSE" => false, 
+ _ => panic!("Unsupported boolean value: {value}"), + }; + bool_to_str(parsed) } - Type::BOOL => make_string!(row, idx, bool, bool_to_str), - Type::BPCHAR | Type::VARCHAR | Type::TEXT => { - make_string!(row, idx, &str, varchar_to_str) + (&Type::BPCHAR | &Type::VARCHAR | &Type::TEXT, Some(value)) => { + varchar_to_str(value) } - Type::FLOAT4 => make_string!(row, idx, f32, f32_to_str), - Type::FLOAT8 => make_string!(row, idx, f64, f64_to_str), - Type::REGTYPE => make_string!(row, idx, PgRegtype), - _ => unimplemented!("Unsupported type: {}", column.type_().name()), + (&Type::FLOAT4, Some(value)) => f32_to_str(value.parse::().unwrap()), + (&Type::FLOAT8, Some(value)) => f64_to_str(value.parse::().unwrap()), + (&Type::REGTYPE, Some(value)) => value.to_string(), + _ => unimplemented!("Unsupported type: {}", column_type.name()), } } diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs index 9ec085b41eec..8bd0cabcb05b 100644 --- a/datafusion/sqllogictest/src/test_context.rs +++ b/datafusion/sqllogictest/src/test_context.rs @@ -21,6 +21,7 @@ use std::fs::File; use std::io::Write; use std::path::Path; use std::sync::Arc; +use std::vec; use arrow::array::{ Array, ArrayRef, BinaryArray, Float64Array, Int32Array, LargeBinaryArray, @@ -30,7 +31,7 @@ use arrow::buffer::ScalarBuffer; use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit, UnionFields}; use arrow::record_batch::RecordBatch; use datafusion::catalog::{ - CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, Session, + CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider, Session, }; use datafusion::common::{DataFusionError, Result, not_impl_err}; use datafusion::functions::math::abs; @@ -45,6 +46,7 @@ use datafusion::{ datasource::{MemTable, TableProvider, TableType}, prelude::{CsvReadOptions, SessionContext}, }; +use datafusion_spark::SessionStateBuilderSpark; use crate::is_spark_path; use async_trait::async_trait; @@ 
-80,22 +82,26 @@ impl TestContext { // hardcode target partitions so plans are deterministic .with_target_partitions(4); let runtime = Arc::new(RuntimeEnv::default()); - let mut state = SessionStateBuilder::new() + + let mut state_builder = SessionStateBuilder::new() .with_config(config) .with_runtime_env(runtime) - .with_default_features() - .build(); + .with_default_features(); if is_spark_path(relative_path) { - info!("Registering Spark functions"); - datafusion_spark::register_all(&mut state) - .expect("Can not register Spark functions"); + state_builder = state_builder.with_spark_features(); } + let state = state_builder.build(); + let mut test_ctx = TestContext::new(SessionContext::new_with_state(state)); let file_name = relative_path.file_name().unwrap().to_str().unwrap(); match file_name { + "cte_quoted_reference.slt" => { + info!("Registering strict catalog provider for CTE tests"); + register_strict_orders_catalog(test_ctx.session_ctx()); + } "information_schema_table_types.slt" => { info!("Registering local temporary table"); register_temp_table(test_ctx.session_ctx()).await; @@ -171,6 +177,104 @@ impl TestContext { } } +// ============================================================================== +// Strict Catalog / Schema Provider (sqllogictest-only) +// ============================================================================== +// +// The goal of `cte_quoted_reference.slt` is to exercise end-to-end query planning +// while detecting *unexpected* catalog lookups. +// +// Specifically, if DataFusion incorrectly treats a CTE reference (e.g. `"barbaz"`) +// as a real table reference, the planner will attempt to resolve it through the +// schema provider. The types below deliberately `panic!` on any lookup other than +// the one table we expect (`orders`). +// +// This makes the "extra provider lookup" bug observable in an end-to-end test, +// rather than being silently ignored by default providers that return `Ok(None)` +// for unknown tables. 
+ +#[derive(Debug)] +struct StrictOrdersCatalog { + schema: Arc, +} + +impl CatalogProvider for StrictOrdersCatalog { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema_names(&self) -> Vec { + vec!["public".to_string()] + } + + fn schema(&self, name: &str) -> Option> { + (name == "public").then(|| Arc::clone(&self.schema)) + } +} + +#[derive(Debug)] +struct StrictOrdersSchema { + orders: Arc, +} + +#[async_trait] +impl SchemaProvider for StrictOrdersSchema { + fn as_any(&self) -> &dyn Any { + self + } + + fn table_names(&self) -> Vec { + vec!["orders".to_string()] + } + + async fn table( + &self, + name: &str, + ) -> Result>, DataFusionError> { + match name { + "orders" => Ok(Some(Arc::clone(&self.orders))), + other => panic!( + "unexpected table lookup: {other}. This maybe indicates a CTE reference was \ + incorrectly treated as a catalog table reference." + ), + } + } + + fn table_exist(&self, name: &str) -> bool { + name == "orders" + } +} + +fn register_strict_orders_catalog(ctx: &SessionContext) { + let schema = Arc::new(Schema::new(vec![Field::new( + "order_id", + DataType::Int32, + false, + )])); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(vec![1, 2]))], + ) + .expect("record batch should be valid"); + + let orders = + MemTable::try_new(schema, vec![vec![batch]]).expect("memtable should be valid"); + + let schema_provider: Arc = Arc::new(StrictOrdersSchema { + orders: Arc::new(orders), + }); + + // Override the default "datafusion" catalog for this test file so that any + // unexpected lookup is caught immediately. 
+ ctx.register_catalog( + "datafusion", + Arc::new(StrictOrdersCatalog { + schema: schema_provider, + }), + ); +} + #[cfg(feature = "avro")] pub async fn register_avro_tables(ctx: &mut TestContext) { use datafusion::prelude::AvroReadOptions; @@ -436,14 +540,15 @@ fn create_example_udf() -> ScalarUDF { fn register_union_table(ctx: &SessionContext) { let union = UnionArray::try_new( - UnionFields::new( + UnionFields::try_new( // typeids: 3 for int, 1 for string vec![3, 1], vec![ Field::new("int", DataType::Int32, false), Field::new("string", DataType::Utf8, false), ], - ), + ) + .unwrap(), ScalarBuffer::from(vec![3, 1, 3]), None, vec![ diff --git a/datafusion/sqllogictest/src/util.rs b/datafusion/sqllogictest/src/util.rs index 6a3d3944e4e8..b0cf32266ea3 100644 --- a/datafusion/sqllogictest/src/util.rs +++ b/datafusion/sqllogictest/src/util.rs @@ -44,7 +44,7 @@ pub fn setup_scratch_dir(name: &Path) -> Result<()> { /// Trailing whitespace from lines in SLT will typically be removed, but do not fail if it is not /// If particular test wants to cover trailing whitespace on a value, /// it should project additional non-whitespace column on the right. -#[allow(clippy::ptr_arg)] +#[expect(clippy::ptr_arg)] pub fn value_normalizer(s: &String) -> String { s.trim_end().to_string() } diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index f6ce68917e03..517467110fe6 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -379,6 +379,59 @@ select array_sort(c1), array_sort(c2) from ( statement ok drop table array_agg_distinct_list_table; +# Test array_agg with DISTINCT and IGNORE NULLS (regression test for issue #19735) +query ? 
+SELECT array_sort(ARRAY_AGG(DISTINCT x IGNORE NULLS)) as result +FROM (VALUES (1), (2), (NULL), (2), (NULL), (1)) AS t(x); +---- +[1, 2] + +# Test that non-DISTINCT aggregates also preserve IGNORE NULLS when mixed with DISTINCT +# This tests the two-phase aggregation rewrite in SingleDistinctToGroupBy +query I? +SELECT + COUNT(DISTINCT x) as distinct_count, + array_sort(ARRAY_AGG(y IGNORE NULLS)) as y_agg +FROM (VALUES + (1, 10), + (1, 20), + (2, 30), + (3, NULL), + (3, 40), + (NULL, 50) +) AS t(x, y) +---- +3 [10, 20, 30, 40, 50] + +# Test that FILTER clause is preserved in two-phase aggregation rewrite +query II +SELECT + COUNT(DISTINCT x) as distinct_count, + SUM(y) FILTER (WHERE y > 15) as filtered_sum +FROM (VALUES + (1, 10), + (1, 20), + (2, 5), + (2, 30), + (3, 25) +) AS t(x, y) +---- +3 75 + +# Test that ORDER BY is preserved in two-phase aggregation rewrite +query I? +SELECT + COUNT(DISTINCT x) as distinct_count, + ARRAY_AGG(y ORDER BY y DESC) as ordered_agg +FROM (VALUES + (1, 10), + (1, 30), + (2, 20), + (2, 40) +) AS t(x, y) +---- +2 [40, 30, 20, 10] + statement error This feature is not implemented: Calling array_agg: LIMIT not supported in function arguments: 1 SELECT array_agg(c13 LIMIT 1) FROM aggregate_test_100 @@ -518,6 +571,16 @@ SELECT covar(c2, c12) FROM aggregate_test_100 ---- -0.079969012479 +query R +SELECT covar_pop(arrow_cast(c2, 'Float16'), arrow_cast(c12, 'Float16')) FROM aggregate_test_100 +---- +-0.079163311005 + +query R +SELECT covar(arrow_cast(c2, 'Float16'), arrow_cast(c12, 'Float16')) FROM aggregate_test_100 +---- +-0.079962940409 + # single_row_query_covar_1 query R select covar_samp(sq.column1, sq.column2) from (values (1.1, 2.2)) as sq @@ -700,8 +763,10 @@ SELECT var(distinct c2) FROM aggregate_test_100 ---- 2.5 -statement error DataFusion error: This feature is not implemented: VAR\(DISTINCT\) aggregations are not available +query RR SELECT var(c2), var(distinct c2) FROM aggregate_test_100 +---- +1.886363636364 2.5 # 
csv_query_distinct_variance_population query R @@ -709,8 +774,10 @@ SELECT var_pop(distinct c2) FROM aggregate_test_100 ---- 2 -statement error DataFusion error: This feature is not implemented: VAR_POP\(DISTINCT\) aggregations are not available +query RR SELECT var_pop(c2), var_pop(distinct c2) FROM aggregate_test_100 +---- +1.8675 2 # csv_query_variance_5 query R @@ -1126,6 +1193,128 @@ ORDER BY tags, timestamp; 4 tag2 90 75 80 95 5 tag2 100 80 80 100 +########### +# Issue #19612: Test that percentile_cont produces correct results +# in window frame queries. Previously percentile_cont consumed its internal state +# during evaluate(), causing incorrect results when called multiple times. +########### + +# Test percentile_cont sliding window (same as median) +query ITRR +SELECT + timestamp, + tags, + value, + percentile_cont(value, 0.5) OVER ( + PARTITION BY tags + ORDER BY timestamp + ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING + ) AS value_percentile_50 +FROM median_window_test +ORDER BY tags, timestamp; +---- +1 tag1 10 15 +2 tag1 20 20 +3 tag1 30 30 +4 tag1 40 40 +5 tag1 50 45 +1 tag2 60 65 +2 tag2 70 70 +3 tag2 80 80 +4 tag2 90 90 +5 tag2 100 95 + +# Test percentile_cont non-sliding window +query ITRRRR +SELECT + timestamp, + tags, + value, + percentile_cont(value, 0.5) OVER ( + PARTITION BY tags + ORDER BY timestamp + ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + ) AS value_percentile_unbounded_preceding, + percentile_cont(value, 0.5) OVER ( + PARTITION BY tags + ORDER BY timestamp + ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + ) AS value_percentile_unbounded_both, + percentile_cont(value, 0.5) OVER ( + PARTITION BY tags + ORDER BY timestamp + ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING + ) AS value_percentile_unbounded_following +FROM median_window_test +ORDER BY tags, timestamp; +---- +1 tag1 10 10 30 30 +2 tag1 20 15 30 35 +3 tag1 30 20 30 40 +4 tag1 40 25 30 45 +5 tag1 50 30 30 50 +1 tag2 60 60 80 80 +2 tag2 70 65 80 85 +3 tag2 80 70 
80 90 +4 tag2 90 75 80 95 +5 tag2 100 80 80 100 + +# Test percentile_cont with different percentile values +query ITRRR +SELECT + timestamp, + tags, + value, + percentile_cont(value, 0.25) OVER ( + PARTITION BY tags + ORDER BY timestamp + ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + ) AS p25, + percentile_cont(value, 0.75) OVER ( + PARTITION BY tags + ORDER BY timestamp + ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + ) AS p75 +FROM median_window_test +ORDER BY tags, timestamp; +---- +1 tag1 10 10 10 +2 tag1 20 12.5 17.5 +3 tag1 30 15 25 +4 tag1 40 17.5 32.5 +5 tag1 50 20 40 +1 tag2 60 60 60 +2 tag2 70 62.5 67.5 +3 tag2 80 65 75 +4 tag2 90 67.5 82.5 +5 tag2 100 70 90 + + +# Test distinct median non-sliding window +query ITRR +SELECT + timestamp, + tags, + value, + median(DISTINCT value) OVER ( + PARTITION BY tags + ORDER BY timestamp + ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + ) AS distinct_median +FROM median_window_test +ORDER BY tags, timestamp; +---- +1 tag1 10 10 +2 tag1 20 15 +3 tag1 30 20 +4 tag1 40 25 +5 tag1 50 30 +1 tag2 60 60 +2 tag2 70 65 +3 tag2 80 70 +4 tag2 90 75 +5 tag2 100 80 + statement ok DROP TABLE median_window_test; @@ -1134,6 +1323,24 @@ select approx_median(arrow_cast(col_f32, 'Float16')), arrow_typeof(approx_median ---- 2.75 Float16 +# This shouldn't be NaN, see: +# https://github.com/apache/datafusion/issues/18945 +query RT +select + percentile_cont(0.5) within group (order by arrow_cast(col_f32, 'Float16')), + arrow_typeof(percentile_cont(0.5) within group (order by arrow_cast(col_f32, 'Float16'))) +from median_table; +---- +2.75 Float16 + +query RT +select + approx_percentile_cont(0.5) within group (order by arrow_cast(col_f32, 'Float16')), + arrow_typeof(approx_percentile_cont(0.5) within group (order by arrow_cast(col_f32, 'Float16'))) +from median_table; +---- +2.75 Float16 + query ?T select approx_median(NULL), arrow_typeof(approx_median(NULL)) from median_table; ---- @@ -1850,11 +2057,12 @@ statement ok INSERT 
INTO t1 VALUES (TRUE); # ISSUE: https://github.com/apache/datafusion/issues/12716 -# This test verifies that approx_percentile_cont_with_weight does not panic when given 'NaN' and returns 'inf' +# This test verifies that approx_percentile_cont_with_weight does not panic when given 'NaN' +# With weight=0, the data point does not contribute, so result is NULL query R SELECT approx_percentile_cont_with_weight(0, 0) WITHIN GROUP (ORDER BY 'NaN'::DOUBLE) FROM t1 WHERE t1.v1; ---- -Infinity +NULL statement ok DROP TABLE t1; @@ -2173,21 +2381,21 @@ e 115 query TI SELECT c1, approx_percentile_cont_with_weight(c2, 0.95) WITHIN GROUP (ORDER BY c3) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ---- -a 74 +a 65 b 68 -c 123 -d 124 -e 115 +c 122 +d 123 +e 110 # approx_percentile_cont_with_weight with centroids query TI SELECT c1, approx_percentile_cont_with_weight(c2, 0.95, 200) WITHIN GROUP (ORDER BY c3) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ---- -a 74 +a 65 b 68 -c 123 -d 124 -e 115 +c 122 +d 123 +e 110 # csv_query_sum_crossjoin query TTI @@ -5322,10 +5530,10 @@ as values statement ok create table t as select - arrow_cast(column1, 'Timestamp(Nanosecond, None)') as nanos, - arrow_cast(column1, 'Timestamp(Microsecond, None)') as micros, - arrow_cast(column1, 'Timestamp(Millisecond, None)') as millis, - arrow_cast(column1, 'Timestamp(Second, None)') as secs, + arrow_cast(column1, 'Timestamp(ns)') as nanos, + arrow_cast(column1, 'Timestamp(µs)') as micros, + arrow_cast(column1, 'Timestamp(ms)') as millis, + arrow_cast(column1, 'Timestamp(s)') as secs, arrow_cast(column1, 'Timestamp(Nanosecond, Some("UTC"))') as nanos_utc, arrow_cast(column1, 'Timestamp(Microsecond, Some("UTC"))') as micros_utc, arrow_cast(column1, 'Timestamp(Millisecond, Some("UTC"))') as millis_utc, @@ -5408,7 +5616,7 @@ SELECT tag, avg(nanos), avg(micros), avg(millis), avg(secs) FROM t GROUP BY tag # aggregate_duration_array_agg query T? 
-SELECT tag, array_agg(millis - arrow_cast(secs, 'Timestamp(Millisecond, None)')) FROM t GROUP BY tag ORDER BY tag; +SELECT tag, array_agg(millis - arrow_cast(secs, 'Timestamp(ms)')) FROM t GROUP BY tag ORDER BY tag; ---- X [0 days 0 hours 0 mins 0.011 secs, 0 days 0 hours 0 mins 0.123 secs] Y [NULL, 0 days 0 hours 0 mins 0.432 secs] @@ -6539,7 +6747,12 @@ from aggregate_test_100; ---- 0.051534002628 0.48427355347 100 0.001929150558 0.479274948239 0.508972509913 6.707779292571 9.234223721582 0.345678715695 - +query R +select + regr_slope(arrow_cast(c12, 'Float16'), arrow_cast(c11, 'Float16')) +from aggregate_test_100; +---- +0.051477733249 # regr_*() functions ignore NULLs query RRIRRRRRR @@ -7772,8 +7985,9 @@ logical_plan 02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] 03)----TableScan: t projection=[] physical_plan -01)ProjectionExec: expr=[2 as count(Int64(1)), 2 as count()] -02)--PlaceholderRowExec +01)ProjectionExec: expr=[count(Int64(1))@0 as count(Int64(1)), count(Int64(1))@0 as count()] +02)--ProjectionExec: expr=[2 as count(Int64(1))] +03)----PlaceholderRowExec query II select count(1), count(*) from t; @@ -7788,8 +8002,9 @@ logical_plan 02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] 03)----TableScan: t projection=[] physical_plan -01)ProjectionExec: expr=[2 as count(Int64(1)), 2 as count(*)] -02)--PlaceholderRowExec +01)ProjectionExec: expr=[count(Int64(1))@0 as count(Int64(1)), count(Int64(1))@0 as count(*)] +02)--ProjectionExec: expr=[2 as count(Int64(1))] +03)----PlaceholderRowExec query II select count(), count(*) from t; @@ -7804,8 +8019,9 @@ logical_plan 02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] 03)----TableScan: t projection=[] physical_plan -01)ProjectionExec: expr=[2 as count(), 2 as count(*)] -02)--PlaceholderRowExec +01)ProjectionExec: expr=[count(Int64(1))@0 as count(), count(Int64(1))@0 as count(*)] +02)--ProjectionExec: expr=[2 as count(Int64(1))] +03)----PlaceholderRowExec query TT explain select count(1) * 
count(2) from t; @@ -8246,3 +8462,325 @@ query R select percentile_cont(null, 0.5); ---- NULL + +# Test string_agg window frame behavior (fix for issue #19612) +statement ok +CREATE TABLE string_agg_window_test ( + id INT, + grp VARCHAR, + val VARCHAR +); + +statement ok +INSERT INTO string_agg_window_test (id, grp, val) VALUES +(1, 'A', 'a'), +(2, 'A', 'b'), +(3, 'A', 'c'), +(1, 'B', 'x'), +(2, 'B', 'y'), +(3, 'B', 'z'); + +# Test string_agg with ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW +# The function should maintain state correctly across multiple evaluate() calls +query ITT +SELECT + id, + grp, + string_agg(val, ',') OVER ( + PARTITION BY grp + ORDER BY id + ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + ) AS cumulative_string +FROM string_agg_window_test +ORDER BY grp, id; +---- +1 A a +2 A a,b +3 A a,b,c +1 B x +2 B x,y +3 B x,y,z + +statement ok +DROP TABLE string_agg_window_test; + +# Enable streaming aggregation by limiting partitions and ensuring sorted input +statement ok +set datafusion.execution.target_partitions = 1; + +# Setup data +statement ok +CREATE TABLE stream_test ( + g INT, + x DOUBLE, + y DOUBLE, + i INT, + b BOOLEAN, + s VARCHAR +) AS VALUES +(1, 1.0, 1.0, 1, true, 'a'), (1, 2.0, 2.0, 2, true, 'b'), +(2, 1.0, 5.0, 3, false, 'c'), (2, 2.0, 5.0, 4, true, 'd'), +(3, 1.0, 1.0, 7, false, 'e'), (3, 2.0, 2.0, 8, false, 'f'); + +# Test comprehensive aggregates with streaming +# This verifies that CORR and other aggregates work together in a streaming plan (ordering_mode=Sorted) + +# Basic Aggregates +query TT +EXPLAIN SELECT + g, + COUNT(*), + SUM(x), + AVG(x), + MEAN(x), + MIN(x), + MAX(y), + BIT_AND(i), + BIT_OR(i), + BIT_XOR(i), + BOOL_AND(b), + BOOL_OR(b), + MEDIAN(x), + GROUPING(g), + VAR(x), + VAR_SAMP(x), + VAR_POP(x), + VAR_SAMPLE(x), + VAR_POPULATION(x), + STDDEV(x), + STDDEV_SAMP(x), + STDDEV_POP(x) +FROM (SELECT * FROM stream_test ORDER BY g LIMIT 10000) +GROUP BY g +ORDER BY g; +---- +logical_plan +01)Sort: stream_test.g ASC 
NULLS LAST +02)--Projection: stream_test.g, count(Int64(1)) AS count(*), sum(stream_test.x), avg(stream_test.x), avg(stream_test.x) AS mean(stream_test.x), min(stream_test.x), max(stream_test.y), bit_and(stream_test.i), bit_or(stream_test.i), bit_xor(stream_test.i), bool_and(stream_test.b), bool_or(stream_test.b), median(stream_test.x), Int32(0) AS grouping(stream_test.g), var(stream_test.x), var(stream_test.x) AS var_samp(stream_test.x), var_pop(stream_test.x), var(stream_test.x) AS var_sample(stream_test.x), var_pop(stream_test.x) AS var_population(stream_test.x), stddev(stream_test.x), stddev(stream_test.x) AS stddev_samp(stream_test.x), stddev_pop(stream_test.x) +03)----Aggregate: groupBy=[[stream_test.g]], aggr=[[count(Int64(1)), sum(stream_test.x), avg(stream_test.x), min(stream_test.x), max(stream_test.y), bit_and(stream_test.i), bit_or(stream_test.i), bit_xor(stream_test.i), bool_and(stream_test.b), bool_or(stream_test.b), median(stream_test.x), var(stream_test.x), var_pop(stream_test.x), stddev(stream_test.x), stddev_pop(stream_test.x)]] +04)------Sort: stream_test.g ASC NULLS LAST, fetch=10000 +05)--------TableScan: stream_test projection=[g, x, y, i, b] +physical_plan +01)ProjectionExec: expr=[g@0 as g, count(Int64(1))@1 as count(*), sum(stream_test.x)@2 as sum(stream_test.x), avg(stream_test.x)@3 as avg(stream_test.x), avg(stream_test.x)@3 as mean(stream_test.x), min(stream_test.x)@4 as min(stream_test.x), max(stream_test.y)@5 as max(stream_test.y), bit_and(stream_test.i)@6 as bit_and(stream_test.i), bit_or(stream_test.i)@7 as bit_or(stream_test.i), bit_xor(stream_test.i)@8 as bit_xor(stream_test.i), bool_and(stream_test.b)@9 as bool_and(stream_test.b), bool_or(stream_test.b)@10 as bool_or(stream_test.b), median(stream_test.x)@11 as median(stream_test.x), 0 as grouping(stream_test.g), var(stream_test.x)@12 as var(stream_test.x), var(stream_test.x)@12 as var_samp(stream_test.x), var_pop(stream_test.x)@13 as var_pop(stream_test.x), var(stream_test.x)@12 
as var_sample(stream_test.x), var_pop(stream_test.x)@13 as var_population(stream_test.x), stddev(stream_test.x)@14 as stddev(stream_test.x), stddev(stream_test.x)@14 as stddev_samp(stream_test.x), stddev_pop(stream_test.x)@15 as stddev_pop(stream_test.x)] +02)--AggregateExec: mode=Single, gby=[g@0 as g], aggr=[count(Int64(1)), sum(stream_test.x), avg(stream_test.x), min(stream_test.x), max(stream_test.y), bit_and(stream_test.i), bit_or(stream_test.i), bit_xor(stream_test.i), bool_and(stream_test.b), bool_or(stream_test.b), median(stream_test.x), var(stream_test.x), var_pop(stream_test.x), stddev(stream_test.x), stddev_pop(stream_test.x)], ordering_mode=Sorted +03)----SortExec: TopK(fetch=10000), expr=[g@0 ASC NULLS LAST], preserve_partitioning=[false] +04)------DataSourceExec: partitions=1, partition_sizes=[1] + +query IIRRRRRIIIBBRIRRRRRRRR +SELECT + g, + COUNT(*), + SUM(x), + AVG(x), + MEAN(x), + MIN(x), + MAX(y), + BIT_AND(i), + BIT_OR(i), + BIT_XOR(i), + BOOL_AND(b), + BOOL_OR(b), + MEDIAN(x), + GROUPING(g), + VAR(x), + VAR_SAMP(x), + VAR_POP(x), + VAR_SAMPLE(x), + VAR_POPULATION(x), + STDDEV(x), + STDDEV_SAMP(x), + STDDEV_POP(x) +FROM (SELECT * FROM stream_test ORDER BY g LIMIT 10000) +GROUP BY g +ORDER BY g; +---- +1 2 3 1.5 1.5 1 2 0 3 3 true true 1.5 0 0.5 0.5 0.25 0.5 0.25 0.707106781187 0.707106781187 0.5 +2 2 3 1.5 1.5 1 5 0 7 7 false true 1.5 0 0.5 0.5 0.25 0.5 0.25 0.707106781187 0.707106781187 0.5 +3 2 3 1.5 1.5 1 2 0 15 15 false false 1.5 0 0.5 0.5 0.25 0.5 0.25 0.707106781187 0.707106781187 0.5 + +# Ordered Aggregates (by x) +query TT +EXPLAIN SELECT + g, + ARRAY_AGG(x ORDER BY x), + ARRAY_AGG(DISTINCT x ORDER BY x), + FIRST_VALUE(x ORDER BY x), + LAST_VALUE(x ORDER BY x), + NTH_VALUE(x, 1 ORDER BY x) +FROM (SELECT * FROM stream_test ORDER BY g LIMIT 10000) +GROUP BY g +ORDER BY g; +---- +logical_plan +01)Sort: stream_test.g ASC NULLS LAST +02)--Aggregate: groupBy=[[stream_test.g]], aggr=[[array_agg(stream_test.x) ORDER BY [stream_test.x ASC NULLS 
LAST], array_agg(DISTINCT stream_test.x) ORDER BY [stream_test.x ASC NULLS LAST], first_value(stream_test.x) ORDER BY [stream_test.x ASC NULLS LAST], last_value(stream_test.x) ORDER BY [stream_test.x ASC NULLS LAST], nth_value(stream_test.x, Int64(1)) ORDER BY [stream_test.x ASC NULLS LAST]]] +03)----Sort: stream_test.g ASC NULLS LAST, fetch=10000 +04)------TableScan: stream_test projection=[g, x] +physical_plan +01)AggregateExec: mode=Single, gby=[g@0 as g], aggr=[array_agg(stream_test.x) ORDER BY [stream_test.x ASC NULLS LAST], array_agg(DISTINCT stream_test.x) ORDER BY [stream_test.x ASC NULLS LAST], first_value(stream_test.x) ORDER BY [stream_test.x ASC NULLS LAST], last_value(stream_test.x) ORDER BY [stream_test.x ASC NULLS LAST], nth_value(stream_test.x,Int64(1)) ORDER BY [stream_test.x ASC NULLS LAST]], ordering_mode=Sorted +02)--SortExec: TopK(fetch=10000), expr=[g@0 ASC NULLS LAST, x@1 ASC NULLS LAST], preserve_partitioning=[false] +03)----DataSourceExec: partitions=1, partition_sizes=[1] + +query I??RRR +SELECT + g, + ARRAY_AGG(x ORDER BY x), + ARRAY_AGG(DISTINCT x ORDER BY x), + FIRST_VALUE(x ORDER BY x), + LAST_VALUE(x ORDER BY x), + NTH_VALUE(x, 1 ORDER BY x) +FROM (SELECT * FROM stream_test ORDER BY g LIMIT 10000) +GROUP BY g +ORDER BY g; +---- +1 [1.0, 2.0] [1.0, 2.0] 1 2 1 +2 [1.0, 2.0] [1.0, 2.0] 1 2 1 +3 [1.0, 2.0] [1.0, 2.0] 1 2 1 + +# Ordered Aggregates (by s) +query TT +EXPLAIN SELECT + g, + ARRAY_AGG(s ORDER BY s), + STRING_AGG(s, '|' ORDER BY s), + STRING_AGG(DISTINCT s, '|' ORDER BY s) +FROM (SELECT * FROM stream_test ORDER BY g LIMIT 10000) +GROUP BY g +ORDER BY g; +---- +logical_plan +01)Sort: stream_test.g ASC NULLS LAST +02)--Aggregate: groupBy=[[stream_test.g]], aggr=[[array_agg(stream_test.s) ORDER BY [stream_test.s ASC NULLS LAST], string_agg(stream_test.s, Utf8("|")) ORDER BY [stream_test.s ASC NULLS LAST], string_agg(DISTINCT stream_test.s, Utf8("|")) ORDER BY [stream_test.s ASC NULLS LAST]]] +03)----Sort: stream_test.g ASC NULLS 
LAST, fetch=10000 +04)------TableScan: stream_test projection=[g, s] +physical_plan +01)AggregateExec: mode=Single, gby=[g@0 as g], aggr=[array_agg(stream_test.s) ORDER BY [stream_test.s ASC NULLS LAST], string_agg(stream_test.s,Utf8("|")) ORDER BY [stream_test.s ASC NULLS LAST], string_agg(DISTINCT stream_test.s,Utf8("|")) ORDER BY [stream_test.s ASC NULLS LAST]], ordering_mode=Sorted +02)--SortExec: TopK(fetch=10000), expr=[g@0 ASC NULLS LAST, s@1 ASC NULLS LAST], preserve_partitioning=[false] +03)----DataSourceExec: partitions=1, partition_sizes=[1] + +query I?TT +SELECT + g, + ARRAY_AGG(s ORDER BY s), + STRING_AGG(s, '|' ORDER BY s), + STRING_AGG(DISTINCT s, '|' ORDER BY s) +FROM (SELECT * FROM stream_test ORDER BY g LIMIT 10000) +GROUP BY g +ORDER BY g; +---- +1 [a, b] a|b a|b +2 [c, d] c|d c|d +3 [e, f] e|f e|f + +# Statistical & Regression Aggregates +query TT +EXPLAIN SELECT + g, + CORR(x, y), + COVAR(x, y), + COVAR_SAMP(x, y), + COVAR_POP(x, y), + REGR_SXX(x, y), + REGR_SXY(x, y), + REGR_SYY(x, y), + REGR_AVGX(x, y), + REGR_AVGY(x, y), + REGR_COUNT(x, y), + REGR_SLOPE(x, y), + REGR_INTERCEPT(x, y), + REGR_R2(x, y) +FROM (SELECT * FROM stream_test ORDER BY g LIMIT 10000) +GROUP BY g +ORDER BY g; +---- +logical_plan +01)Sort: stream_test.g ASC NULLS LAST +02)--Projection: stream_test.g, corr(stream_test.x,stream_test.y), covar_samp(stream_test.x,stream_test.y) AS covar(stream_test.x,stream_test.y), covar_samp(stream_test.x,stream_test.y), covar_pop(stream_test.x,stream_test.y), regr_sxx(stream_test.x,stream_test.y), regr_sxy(stream_test.x,stream_test.y), regr_syy(stream_test.x,stream_test.y), regr_avgx(stream_test.x,stream_test.y), regr_avgy(stream_test.x,stream_test.y), regr_count(stream_test.x,stream_test.y), regr_slope(stream_test.x,stream_test.y), regr_intercept(stream_test.x,stream_test.y), regr_r2(stream_test.x,stream_test.y) +03)----Aggregate: groupBy=[[stream_test.g]], aggr=[[corr(stream_test.x, stream_test.y), covar_samp(stream_test.x, 
stream_test.y), covar_pop(stream_test.x, stream_test.y), regr_sxx(stream_test.x, stream_test.y), regr_sxy(stream_test.x, stream_test.y), regr_syy(stream_test.x, stream_test.y), regr_avgx(stream_test.x, stream_test.y), regr_avgy(stream_test.x, stream_test.y), regr_count(stream_test.x, stream_test.y), regr_slope(stream_test.x, stream_test.y), regr_intercept(stream_test.x, stream_test.y), regr_r2(stream_test.x, stream_test.y)]] +04)------Sort: stream_test.g ASC NULLS LAST, fetch=10000 +05)--------TableScan: stream_test projection=[g, x, y] +physical_plan +01)ProjectionExec: expr=[g@0 as g, corr(stream_test.x,stream_test.y)@1 as corr(stream_test.x,stream_test.y), covar_samp(stream_test.x,stream_test.y)@2 as covar(stream_test.x,stream_test.y), covar_samp(stream_test.x,stream_test.y)@2 as covar_samp(stream_test.x,stream_test.y), covar_pop(stream_test.x,stream_test.y)@3 as covar_pop(stream_test.x,stream_test.y), regr_sxx(stream_test.x,stream_test.y)@4 as regr_sxx(stream_test.x,stream_test.y), regr_sxy(stream_test.x,stream_test.y)@5 as regr_sxy(stream_test.x,stream_test.y), regr_syy(stream_test.x,stream_test.y)@6 as regr_syy(stream_test.x,stream_test.y), regr_avgx(stream_test.x,stream_test.y)@7 as regr_avgx(stream_test.x,stream_test.y), regr_avgy(stream_test.x,stream_test.y)@8 as regr_avgy(stream_test.x,stream_test.y), regr_count(stream_test.x,stream_test.y)@9 as regr_count(stream_test.x,stream_test.y), regr_slope(stream_test.x,stream_test.y)@10 as regr_slope(stream_test.x,stream_test.y), regr_intercept(stream_test.x,stream_test.y)@11 as regr_intercept(stream_test.x,stream_test.y), regr_r2(stream_test.x,stream_test.y)@12 as regr_r2(stream_test.x,stream_test.y)] +02)--AggregateExec: mode=Single, gby=[g@0 as g], aggr=[corr(stream_test.x,stream_test.y), covar_samp(stream_test.x,stream_test.y), covar_pop(stream_test.x,stream_test.y), regr_sxx(stream_test.x,stream_test.y), regr_sxy(stream_test.x,stream_test.y), regr_syy(stream_test.x,stream_test.y), 
regr_avgx(stream_test.x,stream_test.y), regr_avgy(stream_test.x,stream_test.y), regr_count(stream_test.x,stream_test.y), regr_slope(stream_test.x,stream_test.y), regr_intercept(stream_test.x,stream_test.y), regr_r2(stream_test.x,stream_test.y)], ordering_mode=Sorted +03)----SortExec: TopK(fetch=10000), expr=[g@0 ASC NULLS LAST], preserve_partitioning=[false] +04)------DataSourceExec: partitions=1, partition_sizes=[1] + +query IRRRRRRRRRIRRR +SELECT + g, + CORR(x, y), + COVAR(x, y), + COVAR_SAMP(x, y), + COVAR_POP(x, y), + REGR_SXX(x, y), + REGR_SXY(x, y), + REGR_SYY(x, y), + REGR_AVGX(x, y), + REGR_AVGY(x, y), + REGR_COUNT(x, y), + REGR_SLOPE(x, y), + REGR_INTERCEPT(x, y), + REGR_R2(x, y) +FROM (SELECT * FROM stream_test ORDER BY g LIMIT 10000) +GROUP BY g +ORDER BY g; +---- +1 1 0.5 0.5 0.25 0.5 0.5 0.5 1.5 1.5 2 1 0 1 +2 NULL 0 0 0 0 0 0.5 5 1.5 2 NULL NULL NULL +3 1 0.5 0.5 0.25 0.5 0.5 0.5 1.5 1.5 2 1 0 1 + +# Approximate and Ordered-Set Aggregates +query TT +EXPLAIN SELECT + g, + APPROX_DISTINCT(i), + APPROX_MEDIAN(x), + PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x), + QUANTILE_CONT(0.5) WITHIN GROUP (ORDER BY x), + APPROX_PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x), + APPROX_PERCENTILE_CONT_WITH_WEIGHT(1.0, 0.5) WITHIN GROUP (ORDER BY x), + PERCENTILE_CONT(x, 0.5), + APPROX_PERCENTILE_CONT(x, 0.5), + APPROX_PERCENTILE_CONT_WITH_WEIGHT(x, 1.0, 0.5) +FROM (SELECT * FROM stream_test ORDER BY g LIMIT 10000) +GROUP BY g +ORDER BY g; +---- +logical_plan +01)Sort: stream_test.g ASC NULLS LAST +02)--Projection: stream_test.g, approx_distinct(stream_test.i), approx_median(stream_test.x), percentile_cont(Float64(0.5)) WITHIN GROUP [stream_test.x ASC NULLS LAST], percentile_cont(Float64(0.5)) WITHIN GROUP [stream_test.x ASC NULLS LAST] AS quantile_cont(stream_test.x,Float64(0.5)), approx_percentile_cont(Float64(0.5)) WITHIN GROUP [stream_test.x ASC NULLS LAST], approx_percentile_cont_with_weight(Float64(1),Float64(0.5)) WITHIN GROUP [stream_test.x ASC NULLS LAST], 
percentile_cont(stream_test.x,Float64(0.5)), approx_percentile_cont(stream_test.x,Float64(0.5)), approx_percentile_cont_with_weight(stream_test.x,Float64(1),Float64(0.5)) +03)----Aggregate: groupBy=[[stream_test.g]], aggr=[[approx_distinct(stream_test.i), approx_median(stream_test.x), percentile_cont(stream_test.x, Float64(0.5)) ORDER BY [stream_test.x ASC NULLS LAST], approx_percentile_cont(stream_test.x, Float64(0.5)) ORDER BY [stream_test.x ASC NULLS LAST], approx_percentile_cont_with_weight(stream_test.x, Float64(1), Float64(0.5)) ORDER BY [stream_test.x ASC NULLS LAST], percentile_cont(stream_test.x, Float64(0.5)), approx_percentile_cont(stream_test.x, Float64(0.5)), approx_percentile_cont_with_weight(stream_test.x, Float64(1), Float64(0.5))]] +04)------Sort: stream_test.g ASC NULLS LAST, fetch=10000 +05)--------TableScan: stream_test projection=[g, x, i] +physical_plan +01)ProjectionExec: expr=[g@0 as g, approx_distinct(stream_test.i)@1 as approx_distinct(stream_test.i), approx_median(stream_test.x)@2 as approx_median(stream_test.x), percentile_cont(Float64(0.5)) WITHIN GROUP [stream_test.x ASC NULLS LAST]@3 as percentile_cont(Float64(0.5)) WITHIN GROUP [stream_test.x ASC NULLS LAST], percentile_cont(Float64(0.5)) WITHIN GROUP [stream_test.x ASC NULLS LAST]@3 as quantile_cont(stream_test.x,Float64(0.5)), approx_percentile_cont(Float64(0.5)) WITHIN GROUP [stream_test.x ASC NULLS LAST]@4 as approx_percentile_cont(Float64(0.5)) WITHIN GROUP [stream_test.x ASC NULLS LAST], approx_percentile_cont_with_weight(Float64(1),Float64(0.5)) WITHIN GROUP [stream_test.x ASC NULLS LAST]@5 as approx_percentile_cont_with_weight(Float64(1),Float64(0.5)) WITHIN GROUP [stream_test.x ASC NULLS LAST], percentile_cont(stream_test.x,Float64(0.5))@6 as percentile_cont(stream_test.x,Float64(0.5)), approx_percentile_cont(stream_test.x,Float64(0.5))@7 as approx_percentile_cont(stream_test.x,Float64(0.5)), approx_percentile_cont_with_weight(stream_test.x,Float64(1),Float64(0.5))@8 as 
approx_percentile_cont_with_weight(stream_test.x,Float64(1),Float64(0.5))] +02)--AggregateExec: mode=Single, gby=[g@0 as g], aggr=[approx_distinct(stream_test.i), approx_median(stream_test.x), percentile_cont(Float64(0.5)) WITHIN GROUP [stream_test.x ASC NULLS LAST], approx_percentile_cont(Float64(0.5)) WITHIN GROUP [stream_test.x ASC NULLS LAST], approx_percentile_cont_with_weight(Float64(1),Float64(0.5)) WITHIN GROUP [stream_test.x ASC NULLS LAST], percentile_cont(stream_test.x,Float64(0.5)), approx_percentile_cont(stream_test.x,Float64(0.5)), approx_percentile_cont_with_weight(stream_test.x,Float64(1),Float64(0.5))], ordering_mode=Sorted +03)----SortExec: TopK(fetch=10000), expr=[g@0 ASC NULLS LAST], preserve_partitioning=[false] +04)------DataSourceExec: partitions=1, partition_sizes=[1] + +query IIRRRRRRRR +SELECT + g, + APPROX_DISTINCT(i), + APPROX_MEDIAN(x), + PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x), + QUANTILE_CONT(0.5) WITHIN GROUP (ORDER BY x), + APPROX_PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x), + APPROX_PERCENTILE_CONT_WITH_WEIGHT(1.0, 0.5) WITHIN GROUP (ORDER BY x), + PERCENTILE_CONT(x, 0.5), + APPROX_PERCENTILE_CONT(x, 0.5), + APPROX_PERCENTILE_CONT_WITH_WEIGHT(x, 1.0, 0.5) +FROM (SELECT * FROM stream_test ORDER BY g LIMIT 10000) +GROUP BY g +ORDER BY g; +---- +1 2 1.5 1.5 1.5 1.5 1.5 1.5 1.5 1.5 +2 2 1.5 1.5 1.5 1.5 1.5 1.5 1.5 1.5 +3 2 1.5 1.5 1.5 1.5 1.5 1.5 1.5 1.5 + +statement ok +DROP TABLE stream_test; + +# Restore default target partitions +statement ok +set datafusion.execution.target_partitions = 4; diff --git a/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt b/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt index 0885a6a7d663..c16a6f442427 100644 --- a/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt +++ b/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt @@ -175,6 +175,21 @@ GROUP BY 1, 2 ORDER BY 1 LIMIT 5; -2117946883 d 1 0 0 0 -2098805236 c 1 0 0 0 +query IT???? 
+SELECT c5, c1, + ARRAY_AGG(c3), + ARRAY_AGG(CASE WHEN c1 = 'a' THEN c3 ELSE NULL END), + ARRAY_AGG(c3) FILTER (WHERE c1 = 'b'), + ARRAY_AGG(CASE WHEN c1 = 'a' THEN c3 ELSE NULL END) FILTER (WHERE c1 = 'b') +FROM aggregate_test_100 +GROUP BY 1, 2 ORDER BY 1 LIMIT 5; +---- +-2141999138 c [-2] [NULL] NULL NULL +-2141451704 a [-72] [-72] NULL NULL +-2138770630 b [63] [NULL] [63] [NULL] +-2117946883 d [-59] [NULL] NULL NULL +-2098805236 c [22] [NULL] NULL NULL + # Regression test for https://github.com/apache/datafusion/issues/11846 query TBBBB rowsort select v1, bool_or(v2), bool_and(v2), bool_or(v3), bool_and(v3) @@ -244,6 +259,19 @@ SELECT c2, count(c1), count(c5), count(c11) FROM aggregate_test_100 GROUP BY c2 4 23 23 23 5 14 14 14 +# Test array_agg; we sort the output to ensure deterministic results +query I?? +SELECT c2, + array_sort(array_agg(c5)), + array_sort(array_agg(c3) FILTER (WHERE c3 > 0)) +FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; +---- +1 [-1991133944, -1882293856, -1448995523, -1383162419, -1339586153, -1331533190, -1176490478, -1143802338, -928766616, -644225469, -335410409, 383352709, 431378678, 794623392, 994303988, 1171968280, 1188089983, 1213926989, 1325868318, 1413111008, 2106705285, 2143473091] [12, 29, 36, 38, 41, 54, 57, 70, 71, 83, 103, 120, 125] +2 [-2138770630, -1927628110, -1908480893, -1899175111, -1808210365, -1660426473, -1222533990, -1090239422, -1011669561, -800561771, -587831330, -537142430, -168758331, -108973366, 49866617, 370975815, 439738328, 715235348, 1354539333, 1593800404, 2033001162, 2053379412] [1, 29, 31, 45, 49, 52, 52, 63, 68, 93, 97, 113, 122] +3 [-2141999138, -2141451704, -2098805236, -1302295658, -903316089, -421042466, -382483011, -346989627, 141218956, 240273900, 397430452, 670497898, 912707948, 1299719633, 1337043149, 1436496767, 1489733240, 1738331255, 2030965207] [13, 13, 14, 17, 17, 22, 71, 73, 77, 97, 104, 112, 123] +4 [-1885422396, -1813935549, -1009656194, -673237643, -237425046, -4229382, 61035129, 
427197269, 434021400, 659422734, 702611616, 762932956, 852509237, 1282464673, 1423957796, 1544188174, 1579876740, 1902023838, 1991172974, 1993193190, 2047637360, 2051224722, 2064155045] [3, 5, 17, 30, 47, 55, 65, 73, 74, 96, 97, 102, 123] +5 [-2117946883, -842693467, -629486480, -467659022, -134213907, 41423756, 586844478, 623103518, 706441268, 1188285940, 1689098844, 1824882165, 1955646088, 2025611582] [36, 62, 64, 68, 118] + # Test min / max for int / float query IIIRR SELECT c2, min(c5), max(c5), min(c11), max(c11) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; @@ -389,19 +417,6 @@ c 2.666666666667 0.425241138254 d 2.444444444444 0.541519476308 e 3 0.505440263521 -# FIXME: add bool_and(v3) column when issue fixed -# ISSUE https://github.com/apache/datafusion/issues/11846 -query TBBB rowsort -select v1, bool_or(v2), bool_and(v2), bool_or(v3) -from aggregate_test_100_bool -group by v1 ----- -a true false true -b true false true -c true false false -d true false false -e true false NULL - query TBBB rowsort select v1, bool_or(v2) FILTER (WHERE v1 = 'a' OR v1 = 'c' OR v1 = 'e'), diff --git a/datafusion/sqllogictest/test_files/aggregates_topk.slt b/datafusion/sqllogictest/test_files/aggregates_topk.slt index 58abecfacfa8..19ead8965ed0 100644 --- a/datafusion/sqllogictest/test_files/aggregates_topk.slt +++ b/datafusion/sqllogictest/test_files/aggregates_topk.slt @@ -195,6 +195,70 @@ a -1 -1 NULL 0 0 a 1 1 +statement ok +CREATE TABLE string_topk(category varchar, val varchar) AS VALUES +('x', 'apple'), +('x', 'zebra'), +('y', 'banana'), +('y', 'apricot'), +('z', 'mango'); + +statement ok +CREATE VIEW string_topk_view AS +SELECT + arrow_cast(category, 'Utf8View') AS category, + arrow_cast(val, 'Utf8View') AS val +FROM + string_topk; + +query TT +select category, max(val) from string_topk group by category order by max(val) desc limit 2; +---- +x zebra +z mango + +query TT +explain select category, max(val) max_val from string_topk group by category order by max_val 
desc limit 2; +---- +logical_plan +01)Sort: max_val DESC NULLS FIRST, fetch=2 +02)--Projection: string_topk.category, max(string_topk.val) AS max_val +03)----Aggregate: groupBy=[[string_topk.category]], aggr=[[max(string_topk.val)]] +04)------TableScan: string_topk projection=[category, val] +physical_plan +01)SortPreservingMergeExec: [max_val@1 DESC], fetch=2 +02)--SortExec: TopK(fetch=2), expr=[max_val@1 DESC], preserve_partitioning=[true] +03)----ProjectionExec: expr=[category@0 as category, max(string_topk.val)@1 as max_val] +04)------AggregateExec: mode=FinalPartitioned, gby=[category@0 as category], aggr=[max(string_topk.val)], lim=[2] +05)--------RepartitionExec: partitioning=Hash([category@0], 4), input_partitions=1 +06)----------AggregateExec: mode=Partial, gby=[category@0 as category], aggr=[max(string_topk.val)], lim=[2] +07)------------DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +select category, max(val) from string_topk_view group by category order by max(val) desc limit 2; +---- +x zebra +z mango + +query TT +explain select category, max(val) max_val from string_topk_view group by category order by max_val desc limit 2; +---- +logical_plan +01)Sort: max_val DESC NULLS FIRST, fetch=2 +02)--Projection: string_topk_view.category, max(string_topk_view.val) AS max_val +03)----Aggregate: groupBy=[[string_topk_view.category]], aggr=[[max(string_topk_view.val)]] +04)------SubqueryAlias: string_topk_view +05)--------Projection: string_topk.category AS category, string_topk.val AS val +06)----------TableScan: string_topk projection=[category, val] +physical_plan +01)SortPreservingMergeExec: [max_val@1 DESC], fetch=2 +02)--SortExec: TopK(fetch=2), expr=[max_val@1 DESC], preserve_partitioning=[true] +03)----ProjectionExec: expr=[category@0 as category, max(string_topk_view.val)@1 as max_val] +04)------AggregateExec: mode=FinalPartitioned, gby=[category@0 as category], aggr=[max(string_topk_view.val)], lim=[2] +05)--------RepartitionExec: 
partitioning=Hash([category@0], 4), input_partitions=1 +06)----------AggregateExec: mode=Partial, gby=[category@0 as category], aggr=[max(string_topk_view.val)], lim=[2] +07)------------DataSourceExec: partitions=1, partition_sizes=[1] + query TII select trace_id, min(other), MIN(timestamp) from traces group by trace_id order by MIN(timestamp), MIN(other) limit 4; ---- @@ -203,6 +267,30 @@ a -1 -1 NULL 0 0 c 1 2 +# Regression tests for string max with ORDER BY ... LIMIT to ensure schema stability +query TT +select trace_id, max(trace_id) as max_trace from traces group by trace_id order by max_trace desc limit 2; +---- +c c +b b + +query TT +explain select trace_id, max(trace_id) as max_trace from traces group by trace_id order by max_trace desc limit 2; +---- +logical_plan +01)Sort: max_trace DESC NULLS FIRST, fetch=2 +02)--Projection: traces.trace_id, max(traces.trace_id) AS max_trace +03)----Aggregate: groupBy=[[traces.trace_id]], aggr=[[max(traces.trace_id)]] +04)------TableScan: traces projection=[trace_id] +physical_plan +01)SortPreservingMergeExec: [max_trace@1 DESC], fetch=2 +02)--SortExec: TopK(fetch=2), expr=[max_trace@1 DESC], preserve_partitioning=[true] +03)----ProjectionExec: expr=[trace_id@0 as trace_id, max(traces.trace_id)@1 as max_trace] +04)------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[max(traces.trace_id)], lim=[2] +05)--------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=1 +06)----------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[max(traces.trace_id)], lim=[2] +07)------------DataSourceExec: partitions=1, partition_sizes=[1] + # Setting to map varchar to utf8view, to test PR https://github.com/apache/datafusion/pull/15152 # Before the PR, the test case would not work because the Utf8View will not be supported by the TopK aggregation @@ -256,5 +344,123 @@ physical_plan 06)----------DataSourceExec: partitions=1, partition_sizes=[1] +## Test GROUP BY with ORDER BY on the 
same column (no aggregate functions) +statement ok +CREATE TABLE ids(id int, value int) AS VALUES +(1, 10), +(2, 20), +(3, 30), +(4, 40), +(1, 50), +(2, 60), +(5, 70); + +query TT +explain select id from ids group by id order by id desc limit 3; +---- +logical_plan +01)Sort: ids.id DESC NULLS FIRST, fetch=3 +02)--Aggregate: groupBy=[[ids.id]], aggr=[[]] +03)----TableScan: ids projection=[id] +physical_plan +01)SortPreservingMergeExec: [id@0 DESC], fetch=3 +02)--SortExec: TopK(fetch=3), expr=[id@0 DESC], preserve_partitioning=[true] +03)----AggregateExec: mode=FinalPartitioned, gby=[id@0 as id], aggr=[], lim=[3] +04)------RepartitionExec: partitioning=Hash([id@0], 4), input_partitions=1 +05)--------AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[], lim=[3] +06)----------DataSourceExec: partitions=1, partition_sizes=[1] + +query I +select id from ids group by id order by id desc limit 3; +---- +5 +4 +3 + +query TT +explain select id from ids group by id order by id asc limit 2; +---- +logical_plan +01)Sort: ids.id ASC NULLS LAST, fetch=2 +02)--Aggregate: groupBy=[[ids.id]], aggr=[[]] +03)----TableScan: ids projection=[id] +physical_plan +01)SortPreservingMergeExec: [id@0 ASC NULLS LAST], fetch=2 +02)--SortExec: TopK(fetch=2), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----AggregateExec: mode=FinalPartitioned, gby=[id@0 as id], aggr=[], lim=[2] +04)------RepartitionExec: partitioning=Hash([id@0], 4), input_partitions=1 +05)--------AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[], lim=[2] +06)----------DataSourceExec: partitions=1, partition_sizes=[1] + +query I +select id from ids group by id order by id asc limit 2; +---- +1 +2 + +# Test with larger limit than distinct values +query I +select id from ids group by id order by id desc limit 100; +---- +5 +4 +3 +2 +1 + +# Test with bigint group by +statement ok +CREATE TABLE values_table (value INT, category BIGINT) AS VALUES +(10, 100), +(20, 200), +(30, 300), +(40, 400), +(50, 500), +(20, 
200), +(10, 100), +(40, 400); + +query TT +explain select category from values_table group by category order by category desc limit 3; +---- +logical_plan +01)Sort: values_table.category DESC NULLS FIRST, fetch=3 +02)--Aggregate: groupBy=[[values_table.category]], aggr=[[]] +03)----TableScan: values_table projection=[category] +physical_plan +01)SortPreservingMergeExec: [category@0 DESC], fetch=3 +02)--SortExec: TopK(fetch=3), expr=[category@0 DESC], preserve_partitioning=[true] +03)----AggregateExec: mode=FinalPartitioned, gby=[category@0 as category], aggr=[], lim=[3] +04)------RepartitionExec: partitioning=Hash([category@0], 4), input_partitions=1 +05)--------AggregateExec: mode=Partial, gby=[category@0 as category], aggr=[], lim=[3] +06)----------DataSourceExec: partitions=1, partition_sizes=[1] + +query I +select category from values_table group by category order by category desc limit 3; +---- +500 +400 +300 + +# Test with integer group by +query I +select value from values_table group by value order by value asc limit 3; +---- +10 +20 +30 + +# Test DISTINCT semantics are preserved +query I +select count(*) from (select category from values_table group by category order by category desc limit 3); +---- +3 + +statement ok +drop table values_table; + +statement ok +drop table ids; + statement ok drop table traces; diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index c31f3d070235..45c6dd48996a 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -2577,6 +2577,31 @@ NULL NULL NULL NULL NULL NULL +# maintains inner nullability +query ?T +select array_sort(column1), arrow_typeof(array_sort(column1)) +from values + (arrow_cast([], 'List(non-null Int32)')), + (arrow_cast(NULL, 'List(non-null Int32)')), + (arrow_cast([1, 3, 5, -5], 'List(non-null Int32)')) +; +---- +[] List(non-null Int32) +NULL List(non-null Int32) +[-5, 1, 3, 5] List(non-null Int32) + 
+query ?T +select column1, arrow_typeof(column1) +from values (array_sort(arrow_cast([1, 3, 5, -5], 'LargeList(non-null Int32)'))); +---- +[-5, 1, 3, 5] LargeList(non-null Int32) + +query ?T +select column1, arrow_typeof(column1) +from values (array_sort(arrow_cast([1, 3, 5, -5], 'FixedSizeList(4 x non-null Int32)'))); +---- +[-5, 1, 3, 5] List(non-null Int32) + query ? select array_sort([struct('foo', 3), struct('foo', 1), struct('bar', 1)]) ---- @@ -3231,6 +3256,99 @@ drop table array_repeat_table; statement ok drop table large_array_repeat_table; +# array_repeat: arrays with NULL counts +statement ok +create table array_repeat_null_count_table +as values +(1, 2), +(2, null), +(3, 1), +(4, -1), +(null, null); + +query I? +select column1, array_repeat(column1, column2) from array_repeat_null_count_table; +---- +1 [1, 1] +2 NULL +3 [3] +4 [] +NULL NULL + +statement ok +drop table array_repeat_null_count_table + +# array_repeat: nested arrays with NULL counts +statement ok +create table array_repeat_nested_null_count_table +as values +([[1, 2], [3, 4]], 2), +([[5, 6], [7, 8]], null), +([[null, null], [9, 10]], 1), +(null, 3), +([[11, 12]], -1); + +query ?? +select column1, array_repeat(column1, column2) from array_repeat_nested_null_count_table; +---- +[[1, 2], [3, 4]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]] +[[5, 6], [7, 8]] NULL +[[NULL, NULL], [9, 10]] [[[NULL, NULL], [9, 10]]] +NULL [NULL, NULL, NULL] +[[11, 12]] [] + +statement ok +drop table array_repeat_nested_null_count_table + +# array_repeat edge cases: empty arrays +query ??? +select array_repeat([], 3), array_repeat([], 0), array_repeat([], null); +---- +[[], [], []] [] NULL + +query ?? 
+select array_repeat(null::int, 0), array_repeat(null::int, null); +---- +[] NULL + +# array_repeat LargeList with NULL count +statement ok +create table array_repeat_large_list_null_table +as values +(arrow_cast([1, 2, 3], 'LargeList(Int64)'), 2), +(arrow_cast([4, 5], 'LargeList(Int64)'), null), +(arrow_cast(null, 'LargeList(Int64)'), 3); + +query ?? +select column1, array_repeat(column1, column2) from array_repeat_large_list_null_table; +---- +[1, 2, 3] [[1, 2, 3], [1, 2, 3]] +[4, 5] NULL +NULL [NULL, NULL, NULL] + +statement ok +drop table array_repeat_large_list_null_table + +# array_repeat edge cases: LargeList nested with NULL count +statement ok +create table array_repeat_large_nested_null_table +as values +(arrow_cast([[1, 2], [3, 4]], 'LargeList(List(Int64))'), 2), +(arrow_cast([[5, 6], [7, 8]], 'LargeList(List(Int64))'), null), +(arrow_cast([[null, null]], 'LargeList(List(Int64))'), 1), +(null, 3); + +query ?? +select column1, array_repeat(column1, column2) from array_repeat_large_nested_null_table; +---- +[[1, 2], [3, 4]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]] +[[5, 6], [7, 8]] NULL +[[NULL, NULL]] [[[NULL, NULL]]] +NULL [NULL, NULL, NULL] + +statement ok +drop table array_repeat_large_nested_null_table + ## array_concat (aliases: `array_cat`, `list_concat`, `list_cat`) # test with empty array @@ -3762,6 +3880,111 @@ select array_position(arrow_cast(make_array([1, 2, 3], [4, 5, 6], [11, 12, 13]), NULL 6 4 NULL 1 NULL +# array_position with NULL element in haystack array (NULL = NULL semantics) +query III +select array_position([1, NULL, 3], arrow_cast(NULL, 'Int64')), array_position([NULL, 2, 3], arrow_cast(NULL, 'Int64')), array_position([1, 2, NULL], arrow_cast(NULL, 'Int64')); +---- +2 1 3 + +query I +select array_position(arrow_cast([1, NULL, 3], 'LargeList(Int64)'), arrow_cast(NULL, 'Int64')); +---- +2 + +# array_position with NULL element in array and start_from +query II +select array_position([NULL, 1, NULL, 2], arrow_cast(NULL, 'Int64'), 2), 
array_position([NULL, 1, NULL, 2], arrow_cast(NULL, 'Int64'), 1); +---- +3 1 + +# array_position with column array and scalar element +query IIII +select array_position(column1, 3), array_position(column1, 10), array_position(column1, 20), array_position(column1, 999) from arrays_values_without_nulls; +---- +3 10 NULL NULL +NULL NULL 10 NULL +NULL NULL NULL NULL +NULL NULL NULL NULL + +query II +select array_position(column1, 3), array_position(column1, 20) from large_arrays_values_without_nulls; +---- +3 NULL +NULL 10 +NULL NULL +NULL NULL + +query II +select array_position(column1, 3), array_position(column1, 20) from fixed_size_arrays_values_without_nulls; +---- +3 NULL +NULL 10 +NULL NULL +NULL NULL + +# array_position with column array, scalar element, and scalar start_from +query II +select array_position(column1, 3, 1), array_position(column1, 3, 4) from arrays_values_without_nulls; +---- +3 NULL +NULL NULL +NULL NULL +NULL NULL + +query II +select array_position(column1, 3, 1), array_position(column1, 3, 4) from large_arrays_values_without_nulls; +---- +3 NULL +NULL NULL +NULL NULL +NULL NULL + +# array_position with column array, scalar element, and column start_from +query I +select array_position(column1, 3, column3) from arrays_values_without_nulls; +---- +3 +NULL +NULL +NULL + +# array_position with scalar haystack, scalar element, and column start_from +query I +select array_position([1, 2, 1, 2], 2, column3) from arrays_values_without_nulls; +---- +2 +2 +4 +4 + +# array_position start_from boundary cases +query IIII +select array_position([1, 2, 3], 3, 3), array_position([1, 2, 3], 1, 2), array_position([1, 2, 3], 1, 1), array_position([1, 2, 3], 3, 4); +---- +3 NULL 1 NULL + +query II +select array_position([1, 2, 3], 3, 4), array_position([1], 1, 2); +---- +NULL NULL + +# array_position with empty array in various contexts +query II +select array_position(arrow_cast(make_array(), 'List(Int64)'), 1), array_position(arrow_cast(make_array(), 
'LargeList(Int64)'), 1); +---- +NULL NULL + +# FixedSizeList with start_from +query II +select array_position(arrow_cast([1, 2, 3, 1, 2], 'FixedSizeList(5, Int64)'), 1, 2), array_position(arrow_cast([1, 2, 3, 1, 2], 'FixedSizeList(5, Int64)'), 2, 4); +---- +4 5 + +query I +select array_position(arrow_cast(['a', 'b', 'c', 'b'], 'FixedSizeList(4, Utf8)'), 'b', 3); +---- +4 + ## array_positions (aliases: `list_positions`) query ? @@ -4747,10 +4970,11 @@ select array_union(arrow_cast([], 'LargeList(Int64)'), arrow_cast([], 'LargeList [] # array_union scalar function #7 -query ? -select array_union([[null]], []); ----- -[[]] +# re-enable when https://github.com/apache/arrow-rs/issues/9227 is fixed +# query ? +# select array_union([[null]], []); +# ---- +# [[]] query error DataFusion error: Error during planning: Failed to coerce arguments to satisfy a call to 'array_union' function: select array_union(arrow_cast([[null]], 'LargeList(List(Int64))'), arrow_cast([], 'LargeList(Int64)')); @@ -4770,12 +4994,12 @@ select array_union(arrow_cast([[null]], 'LargeList(List(Int64))'), arrow_cast([[ query ? select array_union(null, []); ---- -[] +NULL query ? select array_union(null, arrow_cast([], 'LargeList(Int64)')); ---- -[] +NULL # array_union scalar function #10 query ? @@ -4787,23 +5011,23 @@ NULL query ? select array_union([1, 1, 2, 2, 3, 3], null); ---- -[1, 2, 3] +NULL query ? select array_union(arrow_cast([1, 1, 2, 2, 3, 3], 'LargeList(Int64)'), null); ---- -[1, 2, 3] +NULL # array_union scalar function #12 query ? select array_union(null, [1, 1, 2, 2, 3, 3]); ---- -[1, 2, 3] +NULL query ? select array_union(null, arrow_cast([1, 1, 2, 2, 3, 3], 'LargeList(Int64)')); ---- -[1, 2, 3] +NULL # array_union scalar function #13 query ? @@ -4838,6 +5062,36 @@ NULL NULL NULL +query ? +select array_union(arrow_cast(null, 'List(Int64)'), [1, 2]); +---- +NULL + +query ? +select array_union([1, 2], arrow_cast(null, 'List(Int64)')); +---- +NULL + +query ? 
+select array_intersect(arrow_cast(null, 'List(Int64)'), [1, 2]); +---- +NULL + +query ? +select array_intersect([1, 2], arrow_cast(null, 'List(Int64)')); +---- +NULL + +query ? +select array_except(arrow_cast(null, 'List(Int64)'), [1, 2]); +---- +NULL + +query ? +select array_except([1, 2], arrow_cast(null, 'List(Int64)')); +---- +NULL + # list_to_string scalar function #4 (function alias `array_to_string`) query TTT select list_to_string(['h', 'e', 'l', 'l', 'o'], ','), list_to_string([1, 2, 3, 4, 5], '-'), list_to_string([1.0, 2.0, 3.0], '|'); @@ -4903,6 +5157,33 @@ select array_to_string(arrow_cast(make_array('h', NULL, NULL, NULL, 'o'), 'Fixed ---- h,-,-,-,o nil-2-nil-4-5 1|0|3 +# array_to_string float formatting: special values and longer decimals +query TTT +select + array_to_string(make_array(CAST('NaN' AS DOUBLE), CAST('Infinity' AS DOUBLE), CAST('-Infinity' AS DOUBLE), CAST('0.30000000000000004' AS DOUBLE), CAST('1.2345678901234567' AS DOUBLE)), '|'), + array_to_string(arrow_cast(make_array(CAST('NaN' AS DOUBLE), CAST('Infinity' AS DOUBLE), CAST('-Infinity' AS DOUBLE), CAST('0.30000000000000004' AS DOUBLE), CAST('1.2345678901234567' AS DOUBLE)), 'LargeList(Float64)'), '|'), + array_to_string(arrow_cast(make_array(CAST('NaN' AS DOUBLE), CAST('Infinity' AS DOUBLE), CAST('-Infinity' AS DOUBLE), CAST('0.30000000000000004' AS DOUBLE), CAST('1.2345678901234567' AS DOUBLE)), 'FixedSizeList(5, Float64)'), '|'); +---- +NaN|inf|-inf|0.30000000000000004|1.2345678901234567 NaN|inf|-inf|0.30000000000000004|1.2345678901234567 NaN|inf|-inf|0.30000000000000004|1.2345678901234567 + +# array_to_string float formatting: scientific-notation inputs +query T +select array_to_string( + make_array( + CAST('1E20' AS DOUBLE), + CAST('-1e+20' AS DOUBLE), + CAST('6.02214076e23' AS DOUBLE), + CAST('1.2345e6' AS DOUBLE), + CAST('1e-5' AS DOUBLE), + CAST('-1e-5' AS DOUBLE), + CAST('9.1093837015e-31' AS DOUBLE), + CAST('-2.5e-4' AS DOUBLE) + ), + '|' +); +---- 
+100000000000000000000|-100000000000000000000|602214076000000000000000|1234500|0.00001|-0.00001|0.00000000000000000000000000000091093837015|-0.00025 + query T select array_to_string(arrow_cast([arrow_cast([NULL, 'a'], 'FixedSizeList(2, Utf8)'), NULL], 'FixedSizeList(2, FixedSizeList(2, Utf8))'), ',', '-'); ---- @@ -4994,6 +5275,87 @@ NULL 1.2.3 51_52_*_54_55_56_57_58_59_60 1.2.3 61_62_63_64_65_66_67_68_69_70 1.2.3 +# array_to_string with per-row null_string column +statement ok +CREATE TABLE test_null_str_col AS VALUES + (make_array(1, NULL, 3), ',', 'N/A'), + (make_array(NULL, 5, NULL), ',', 'MISSING'), + (make_array(10, NULL, 12), '-', 'X'), + (make_array(20, NULL, 21), '-', NULL); + +query T +SELECT array_to_string(column1, column2, column3) FROM test_null_str_col; +---- +1,N/A,3 +MISSING,5,MISSING +10-X-12 +20-21 + +statement ok +DROP TABLE test_null_str_col; + +# array_to_string with decimal values +query T +select array_to_string(arrow_cast(make_array(1.5, NULL, 3.14), 'List(Decimal128(10, 2))'), ',', 'N'); +---- +1.50,N,3.14 + +# array_to_string with date values +query T +select array_to_string(arrow_cast(make_array('2024-01-15', '2024-06-30', '2024-12-25'), 'List(Date32)'), ','); +---- +2024-01-15,2024-06-30,2024-12-25 + +query T +select array_to_string(arrow_cast(make_array('2024-01-15', NULL, '2024-12-25'), 'List(Date32)'), ',', 'N'); +---- +2024-01-15,N,2024-12-25 + +# array_to_string with timestamp values +query T +select array_to_string(make_array(arrow_cast('2024-01-15T10:30:00', 'Timestamp(Second, None)'), arrow_cast('2024-06-30T15:45:00', 'Timestamp(Second, None)')), '|'); +---- +2024-01-15T10:30:00|2024-06-30T15:45:00 + +query T +select array_to_string(make_array(arrow_cast('2024-01-15T10:30:00', 'Timestamp(Millisecond, None)'), arrow_cast('2024-06-30T15:45:00', 'Timestamp(Millisecond, None)')), '|'); +---- +2024-01-15T10:30:00|2024-06-30T15:45:00 + +query T +select array_to_string(make_array(arrow_cast('2024-01-15T10:30:00', 
'Timestamp(Microsecond, None)'), arrow_cast('2024-06-30T15:45:00', 'Timestamp(Microsecond, None)')), '|'); +---- +2024-01-15T10:30:00|2024-06-30T15:45:00 + +query T +select array_to_string(make_array(arrow_cast('2024-01-15T10:30:00', 'Timestamp(Nanosecond, None)'), arrow_cast('2024-06-30T15:45:00', 'Timestamp(Nanosecond, None)')), '|'); +---- +2024-01-15T10:30:00|2024-06-30T15:45:00 + +# array_to_string with time values +query T +select array_to_string(make_array(arrow_cast('10:30:00', 'Time32(Second)'), arrow_cast('15:45:00', 'Time32(Second)')), ','); +---- +10:30:00,15:45:00 + +query T +select array_to_string(make_array(arrow_cast('10:30:00', 'Time64(Microsecond)'), arrow_cast('15:45:00', 'Time64(Microsecond)')), ','); +---- +10:30:00,15:45:00 + +# array_to_string with interval values +query T +select array_to_string(make_array(interval '1 year 2 months', interval '3 days 4 hours'), ','); +---- +14 mons,3 days 4 hours + +# array_to_string with duration values +query T +select array_to_string(make_array(arrow_cast(1000, 'Duration(Millisecond)'), arrow_cast(2000, 'Duration(Millisecond)')), ','); +---- +PT1S,PT2S + + ## cardinality # cardinality scalar function @@ -5032,12 +5394,17 @@ select cardinality(arrow_cast([[1, 2], [3, 4], [5, 6]], 'FixedSizeList(3, List(I query II select cardinality(make_array()), cardinality(make_array(make_array())) ---- -NULL 0 +0 0 + +query II +select cardinality([]), cardinality([]::int[]) as with_cast +---- +0 0 query II select cardinality(arrow_cast(make_array(), 'LargeList(Int64)')), cardinality(arrow_cast(make_array(make_array()), 'LargeList(List(Int64))')) ---- -NULL 0 +0 0 #TODO #https://github.com/apache/datafusion/issues/9158 @@ -5046,6 +5413,12 @@ NULL 0 #---- #NULL 0 +# cardinality of NULL arrays should return NULL +query II +select cardinality(NULL), cardinality(arrow_cast(NULL, 'LargeList(Int64)')) +---- +NULL NULL + # cardinality with columns query III select cardinality(column1), cardinality(column2), cardinality(column3) 
from arrays; @@ -5139,21 +5512,47 @@ select array_remove(make_array(1, null, 2), null), array_remove(make_array(1, null, 2, null), null); ---- -[1, 2] [1, 2, NULL] +NULL NULL query ?? select array_remove(arrow_cast(make_array(1, null, 2), 'LargeList(Int64)'), null), array_remove(arrow_cast(make_array(1, null, 2, null), 'LargeList(Int64)'), null); ---- -[1, 2] [1, 2, NULL] +NULL NULL query ?? select array_remove(arrow_cast(make_array(1, null, 2), 'FixedSizeList(3, Int64)'), null), array_remove(arrow_cast(make_array(1, null, 2, null), 'FixedSizeList(4, Int64)'), null); ---- -[1, 2] [1, 2, NULL] +NULL NULL + +# array_remove with null element from column +query ? +select array_remove(column1, column2) from (values + (make_array(1, 2, 3), 2), + (make_array(4, 5, 6), null), + (make_array(7, 8, 9), 8), + (null, 1) +) as t(column1, column2); +---- +[1, 3] +NULL +[7, 9] +NULL + +# array_remove with null element from column (LargeList) +query ? +select array_remove(column1, column2) from (values + (arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 2), + (arrow_cast(make_array(4, 5, 6), 'LargeList(Int64)'), null), + (arrow_cast(make_array(7, 8, 9), 'LargeList(Int64)'), 8) +) as t(column1, column2); +---- +[1, 3] +NULL +[7, 9] # array_remove scalar function #2 (element is list) query ?? @@ -5296,6 +5695,46 @@ select array_remove(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [1 ## array_remove_n (aliases: `list_remove_n`) +# array_remove_n with null element scalar +query ?? +select array_remove_n(make_array(1, 2, 2, 1, 1), NULL, 2), + array_remove_n(make_array(1, 2, 2, 1, 1), 2, 2); +---- +NULL [1, 1, 1] + +# array_remove_n with null element scalar (LargeList) +query ?? +select array_remove_n(arrow_cast(make_array(1, 2, 2, 1, 1), 'LargeList(Int64)'), NULL, 2), + array_remove_n(arrow_cast(make_array(1, 2, 2, 1, 1), 'LargeList(Int64)'), 2, 2); +---- +NULL [1, 1, 1] + +# array_remove_n with null element from column +query ? 
+select array_remove_n(column1, column2, column3) from (values + (make_array(1, 2, 2, 1, 1), 2, 2), + (make_array(3, 4, 4, 3, 3), null, 2), + (make_array(5, 6, 6, 5, 5), 6, 1), + (null, 1, 1) +) as t(column1, column2, column3); +---- +[1, 1, 1] +NULL +[5, 6, 5, 5] +NULL + +# array_remove_n with null element from column (LargeList) +query ? +select array_remove_n(column1, column2, column3) from (values + (arrow_cast(make_array(1, 2, 2, 1, 1), 'LargeList(Int64)'), 2, 2), + (arrow_cast(make_array(3, 4, 4, 3, 3), 'LargeList(Int64)'), null, 2), + (arrow_cast(make_array(5, 6, 6, 5, 5), 'LargeList(Int64)'), 6, 1) +) as t(column1, column2, column3); +---- +[1, 1, 1] +NULL +[5, 6, 5, 5] + # array_remove_n scalar function #1 query ??? select array_remove_n(make_array(1, 2, 2, 1, 1), 2, 2), array_remove_n(make_array(1.0, 2.0, 2.0, 1.0, 1.0), 1.0, 2), array_remove_n(make_array('h', 'e', 'l', 'l', 'o'), 'l', 3); @@ -5388,7 +5827,33 @@ select array_remove_n(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], query ? select array_remove_all(make_array(1, 2, 2, 1, 1), NULL); ---- -[1, 2, 2, 1, 1] +NULL + +# array_remove_all with null element from column +query ? +select array_remove_all(column1, column2) from (values + (make_array(1, 2, 2, 1, 1), 2), + (make_array(3, 4, 4, 3, 3), null), + (make_array(5, 6, 6, 5, 5), 6), + (null, 1) +) as t(column1, column2); +---- +[1, 1, 1] +NULL +[5, 5, 5] +NULL + +# array_remove_all with null element from column (LargeList) +query ? +select array_remove_all(column1, column2) from (values + (arrow_cast(make_array(1, 2, 2, 1, 1), 'LargeList(Int64)'), 2), + (arrow_cast(make_array(3, 4, 4, 3, 3), 'LargeList(Int64)'), null), + (arrow_cast(make_array(5, 6, 6, 5, 5), 'LargeList(Int64)'), 6) +) as t(column1, column2); +---- +[1, 1, 1] +NULL +[5, 5, 5] # array_remove_all scalar function #1 query ??? 
@@ -6457,10 +6922,9 @@ physical_plan 02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] -05)--------ProjectionExec: expr=[] -06)----------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]) -07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -08)--------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] +05)--------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]), projection=[] +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] query I with test AS (SELECT substr(md5(i::text)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) @@ -6485,10 +6949,9 @@ physical_plan 02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] -05)--------ProjectionExec: expr=[] -06)----------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]) -07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -08)--------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] +05)--------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]), projection=[] +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] query I with test AS (SELECT substr(md5(i::text)::text, 1, 32) as 
needle FROM generate_series(1, 100000) t(i)) @@ -6513,10 +6976,9 @@ physical_plan 02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] -05)--------ProjectionExec: expr=[] -06)----------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]) -07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -08)--------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] +05)--------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]), projection=[] +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] query I with test AS (SELECT substr(md5(i::text)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) @@ -6541,10 +7003,9 @@ physical_plan 02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] -05)--------ProjectionExec: expr=[] -06)----------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]) -07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -08)--------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] +05)--------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]), projection=[] +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] query I with test AS 
(SELECT substr(md5(i::text)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) @@ -6569,10 +7030,9 @@ physical_plan 02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] -05)--------ProjectionExec: expr=[] -06)----------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]) -07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -08)--------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] +05)--------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]), projection=[] +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] query I with test AS (SELECT substr(md5(i::text)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) @@ -6599,10 +7059,9 @@ physical_plan 02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] -05)--------ProjectionExec: expr=[] -06)----------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IS NOT NULL OR NULL -07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -08)--------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] +05)--------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IS NOT NULL OR NULL, projection=[] +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] # any operator query ? 
@@ -6689,7 +7148,7 @@ from array_distinct_table_2D; ---- [[1, 2], [3, 4], [5, 6]] [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]] -[NULL, [5, 6]] +[[5, 6], NULL] query ? select array_distinct(column1) @@ -6721,7 +7180,207 @@ from array_distinct_table_2D_fixed; ---- [[1, 2], [3, 4], [5, 6]] [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]] -[NULL, [5, 6]] +[[5, 6], NULL] + +## arrays_zip (aliases: `list_zip`) + +# Spark example: arrays_zip(array(1, 2, 3), array(2, 3, 4)) +query ? +select arrays_zip([1, 2, 3], [2, 3, 4]); +---- +[{c0: 1, c1: 2}, {c0: 2, c1: 3}, {c0: 3, c1: 4}] + +# Spark example: arrays_zip(array(1, 2), array(2, 3), array(3, 4)) +query ? +select arrays_zip([1, 2], [2, 3], [3, 4]); +---- +[{c0: 1, c1: 2, c2: 3}, {c0: 2, c1: 3, c2: 4}] + +# basic: two integer arrays of equal length +query ? +select arrays_zip([1, 2, 3], [10, 20, 30]); +---- +[{c0: 1, c1: 10}, {c0: 2, c1: 20}, {c0: 3, c1: 30}] + +# basic: two arrays with different element types (int + string) +query ? +select arrays_zip([1, 2, 3], ['a', 'b', 'c']); +---- +[{c0: 1, c1: a}, {c0: 2, c1: b}, {c0: 3, c1: c}] + +# three arrays of equal length +query ? +select arrays_zip([1, 2, 3], [10, 20, 30], [100, 200, 300]); +---- +[{c0: 1, c1: 10, c2: 100}, {c0: 2, c1: 20, c2: 200}, {c0: 3, c1: 30, c2: 300}] + +# four arrays of equal length +query ? +select arrays_zip([1], [2], [3], [4]); +---- +[{c0: 1, c1: 2, c2: 3, c3: 4}] + +# mixed element types: float + boolean +query ? +select arrays_zip([1.5, 2.5], [true, false]); +---- +[{c0: 1.5, c1: true}, {c0: 2.5, c1: false}] + +# different length arrays: shorter array padded with NULLs +query ? +select arrays_zip([1, 2], [3, 4, 5]); +---- +[{c0: 1, c1: 3}, {c0: 2, c1: 4}, {c0: NULL, c1: 5}] + +# different length arrays: first longer +query ? +select arrays_zip([1, 2, 3], [10]); +---- +[{c0: 1, c1: 10}, {c0: 2, c1: NULL}, {c0: 3, c1: NULL}] + +# different length: one single element, other three elements +query ? 
+select arrays_zip([1], ['a', 'b', 'c']); +---- +[{c0: 1, c1: a}, {c0: NULL, c1: b}, {c0: NULL, c1: c}] + +# empty arrays +query ? +select arrays_zip([], []); +---- +[] + +# one empty, one non-empty +query ? +select arrays_zip([], [1, 2, 3]); +---- +[{c0: NULL, c1: 1}, {c0: NULL, c1: 2}, {c0: NULL, c1: 3}] + +# NULL elements inside arrays +query ? +select arrays_zip([1, NULL, 3], ['a', 'b', 'c']); +---- +[{c0: 1, c1: a}, {c0: NULL, c1: b}, {c0: 3, c1: c}] + +# all NULL elements +query ? +select arrays_zip([NULL::int, NULL, NULL], [NULL::text, NULL, NULL]); +---- +[{c0: NULL, c1: NULL}, {c0: NULL, c1: NULL}, {c0: NULL, c1: NULL}] + +# both args are NULL (entire list null) +query ? +select arrays_zip(NULL::int[], NULL::int[]); +---- +NULL + +# one arg is NULL list, other is real array +query ? +select arrays_zip(NULL::int[], [1, 2, 3]); +---- +[{c0: NULL, c1: 1}, {c0: NULL, c1: 2}, {c0: NULL, c1: 3}] + +# real array + NULL list +query ? +select arrays_zip([1, 2], NULL::text[]); +---- +[{c0: 1, c1: NULL}, {c0: 2, c1: NULL}] + +# column-level test with multiple rows +query ? +select arrays_zip(a, b) from (values ([1, 2], [10, 20]), ([3, 4, 5], [30]), ([6], [60, 70])) as t(a, b); +---- +[{c0: 1, c1: 10}, {c0: 2, c1: 20}] +[{c0: 3, c1: 30}, {c0: 4, c1: NULL}, {c0: 5, c1: NULL}] +[{c0: 6, c1: 60}, {c0: NULL, c1: 70}] + +# column-level test with NULL rows +query ? +select arrays_zip(a, b) from (values ([1, 2], [10, 20]), (null, [30, 40]), ([5, 6], null)) as t(a, b); +---- +[{c0: 1, c1: 10}, {c0: 2, c1: 20}] +[{c0: NULL, c1: 30}, {c0: NULL, c1: 40}] +[{c0: 5, c1: NULL}, {c0: 6, c1: NULL}] + +# alias: list_zip +query ? +select list_zip([1, 2], [3, 4]); +---- +[{c0: 1, c1: 3}, {c0: 2, c1: 4}] + +# column test: total values equal (3 each) but per-row lengths differ +# a: [1] b: [10, 20] → row 0: a has 1, b has 2 +# a: [2, 3] b: [30] → row 1: a has 2, b has 1 +# total a values = 3, total b values = 3 (same!) but rows are misaligned +query ? 
+select arrays_zip(a, b) from (values ([1], [10, 20]), ([2, 3], [30])) as t(a, b); +---- +[{c0: 1, c1: 10}, {c0: NULL, c1: 20}] +[{c0: 2, c1: 30}, {c0: 3, c1: NULL}] + +# single element arrays +query ? +select arrays_zip([42], ['hello']); +---- +[{c0: 42, c1: hello}] + +# error: too few arguments +statement error +select arrays_zip([1, 2, 3]); + +# arrays_zip with LargeList inputs +query ? +select arrays_zip( + arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), + arrow_cast(make_array(10, 20, 30), 'LargeList(Int64)') +); +---- +[{c0: 1, c1: 10}, {c0: 2, c1: 20}, {c0: 3, c1: 30}] + +# arrays_zip with LargeList different lengths (padding) +query ? +select arrays_zip( + arrow_cast(make_array(1, 2), 'LargeList(Int64)'), + arrow_cast(make_array(10, 20, 30), 'LargeList(Int64)') +); +---- +[{c0: 1, c1: 10}, {c0: 2, c1: 20}, {c0: NULL, c1: 30}] + +# arrays_zip with FixedSizeList inputs +query ? +select arrays_zip( + arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)'), + arrow_cast(make_array(10, 20, 30), 'FixedSizeList(3, Int64)') +); +---- +[{c0: 1, c1: 10}, {c0: 2, c1: 20}, {c0: 3, c1: 30}] + +# arrays_zip mixing List and LargeList +query ? +select arrays_zip( + [1, 2, 3], + arrow_cast(make_array(10, 20, 30), 'LargeList(Int64)') +); +---- +[{c0: 1, c1: 10}, {c0: 2, c1: 20}, {c0: 3, c1: 30}] + +# arrays_zip mixing List and FixedSizeList with different lengths (padding) +query ? +select arrays_zip( + [1, 2, 3], + arrow_cast(make_array(10, 20), 'FixedSizeList(2, Int64)') +); +---- +[{c0: 1, c1: 10}, {c0: 2, c1: 20}, {c0: 3, c1: NULL}] + +# arrays_zip with LargeList and FixedSizeList mixed types +query ? +select arrays_zip( + arrow_cast(make_array(1, 2), 'LargeList(Int64)'), + arrow_cast(make_array('a', 'b'), 'FixedSizeList(2, Utf8)') +); +---- +[{c0: 1, c1: a}, {c0: 2, c1: b}] query ??? 
select array_intersect(column1, column2), @@ -6756,7 +7415,7 @@ select array_intersect(column1, column2), array_intersect(column5, column6) from array_intersect_table_1D_Boolean; ---- -[] [false, true] [false] +[] [true, false] [false] [false] [true] [true] query ??? @@ -6765,7 +7424,7 @@ select array_intersect(column1, column2), array_intersect(column5, column6) from large_array_intersect_table_1D_Boolean; ---- -[] [false, true] [false] +[] [true, false] [false] [false] [true] [true] query ??? @@ -6774,8 +7433,8 @@ select array_intersect(column1, column2), array_intersect(column5, column6) from array_intersect_table_1D_UTF8; ---- -[bc] [arrow, rust] [] -[] [arrow, datafusion, rust] [arrow, rust] +[bc] [rust, arrow] [] +[] [datafusion, rust, arrow] [rust, arrow] query ??? select array_intersect(column1, column2), @@ -6783,8 +7442,8 @@ select array_intersect(column1, column2), array_intersect(column5, column6) from large_array_intersect_table_1D_UTF8; ---- -[bc] [arrow, rust] [] -[] [arrow, datafusion, rust] [arrow, rust] +[bc] [rust, arrow] [] +[] [datafusion, rust, arrow] [rust, arrow] query ? select array_intersect(column1, column2) @@ -6888,27 +7547,27 @@ select array_intersect(arrow_cast([], 'LargeList(Int64)'), arrow_cast([], 'Large query ? select array_intersect([1, 1, 2, 2, 3, 3], null); ---- -[] +NULL query ? select array_intersect(arrow_cast([1, 1, 2, 2, 3, 3], 'LargeList(Int64)'), null); ---- -[] +NULL query ? select array_intersect(null, [1, 1, 2, 2, 3, 3]); ---- -[] +NULL query ? select array_intersect(null, arrow_cast([1, 1, 2, 2, 3, 3], 'LargeList(Int64)')); ---- -[] +NULL query ? select array_intersect([], null); ---- -[] +NULL query ? select array_intersect([[1,2,3]], [[]]); @@ -6923,17 +7582,17 @@ select array_intersect([[null]], [[]]); query ? select array_intersect(arrow_cast([], 'LargeList(Int64)'), null); ---- -[] +NULL query ? select array_intersect(null, []); ---- -[] +NULL query ? 
select array_intersect(null, arrow_cast([], 'LargeList(Int64)')); ---- -[] +NULL query ? select array_intersect(null, null); @@ -7189,12 +7848,12 @@ select generate_series('2021-01-01'::timestamp, '2021-01-01T15:00:00'::timestamp # Other timestamp types are coerced to nanosecond query ? -select generate_series(arrow_cast('2021-01-01'::timestamp, 'Timestamp(Second, None)'), '2021-01-01T15:00:00'::timestamp, INTERVAL '1' HOUR); +select generate_series(arrow_cast('2021-01-01'::timestamp, 'Timestamp(s)'), '2021-01-01T15:00:00'::timestamp, INTERVAL '1' HOUR); ---- [2021-01-01T00:00:00, 2021-01-01T01:00:00, 2021-01-01T02:00:00, 2021-01-01T03:00:00, 2021-01-01T04:00:00, 2021-01-01T05:00:00, 2021-01-01T06:00:00, 2021-01-01T07:00:00, 2021-01-01T08:00:00, 2021-01-01T09:00:00, 2021-01-01T10:00:00, 2021-01-01T11:00:00, 2021-01-01T12:00:00, 2021-01-01T13:00:00, 2021-01-01T14:00:00, 2021-01-01T15:00:00] query ? -select generate_series('2021-01-01'::timestamp, arrow_cast('2021-01-01T15:00:00'::timestamp, 'Timestamp(Microsecond, None)'), INTERVAL '1' HOUR); +select generate_series('2021-01-01'::timestamp, arrow_cast('2021-01-01T15:00:00'::timestamp, 'Timestamp(µs)'), INTERVAL '1' HOUR); ---- [2021-01-01T00:00:00, 2021-01-01T01:00:00, 2021-01-01T02:00:00, 2021-01-01T03:00:00, 2021-01-01T04:00:00, 2021-01-01T05:00:00, 2021-01-01T06:00:00, 2021-01-01T07:00:00, 2021-01-01T08:00:00, 2021-01-01T09:00:00, 2021-01-01T10:00:00, 2021-01-01T11:00:00, 2021-01-01T12:00:00, 2021-01-01T13:00:00, 2021-01-01T14:00:00, 2021-01-01T15:00:00] @@ -7476,7 +8135,7 @@ select array_except(column1, column2) from array_except_table; [2] [] NULL -[1, 2] +NULL NULL statement ok @@ -7497,7 +8156,7 @@ select array_except(column1, column2) from array_except_nested_list_table; ---- [[1, 2]] [[3]] -[[1, 2], [3]] +NULL NULL [] @@ -7536,7 +8195,7 @@ select array_except(column1, column2) from array_except_table_ut8; ---- [b, c] [a, bc] -[a, bc, def] +NULL NULL statement ok @@ -7558,7 +8217,7 @@ select 
array_except(column1, column2) from array_except_table_bool; [true] [true] [false] -[true, false] +NULL NULL statement ok @@ -7567,7 +8226,7 @@ drop table array_except_table_bool; query ? select array_except([], null); ---- -[] +NULL query ? select array_except([], []); diff --git a/datafusion/sqllogictest/test_files/arrow_typeof.slt b/datafusion/sqllogictest/test_files/arrow_typeof.slt index ee1f204664a1..0c69e8591c3a 100644 --- a/datafusion/sqllogictest/test_files/arrow_typeof.slt +++ b/datafusion/sqllogictest/test_files/arrow_typeof.slt @@ -95,7 +95,7 @@ SELECT arrow_cast('1', 'Int16') query error SELECT arrow_cast('1') -query error Expect TypeSignatureClass::Native\(LogicalType\(Native\(String\), String\)\) but received NativeType::Int64, DataType: Int64 +query error DataFusion error: Error during planning: Function 'arrow_cast' requires TypeSignatureClass::Native\(LogicalType\(Native\(String\), String\)\), but received Int64 \(DataType: Int64\) SELECT arrow_cast('1', 43) query error DataFusion error: Execution error: arrow_cast requires its second argument to be a non\-empty constant string @@ -123,10 +123,10 @@ SELECT arrow_typeof(arrow_cast('foo', 'Utf8View')) as col_utf8_view, arrow_typeof(arrow_cast('foo', 'Binary')) as col_binary, arrow_typeof(arrow_cast('foo', 'LargeBinary')) as col_large_binary, - arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Second, None)')) as col_ts_s, - arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Millisecond, None)')) as col_ts_ms, - arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Microsecond, None)')) as col_ts_us, - arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Nanosecond, None)')) as col_ts_ns, + arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(s)')) as col_ts_s, + arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 
'Timestamp(ms)')) as col_ts_ms, + arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(µs)')) as col_ts_us, + arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(ns)')) as col_ts_ns, arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Second, Some("+08:00"))')) as col_tstz_s, arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Millisecond, Some("+08:00"))')) as col_tstz_ms, arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Microsecond, Some("+08:00"))')) as col_tstz_us, @@ -242,10 +242,10 @@ drop table foo statement ok create table foo as select - arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Second, None)') as col_ts_s, - arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Millisecond, None)') as col_ts_ms, - arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Microsecond, None)') as col_ts_us, - arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Nanosecond, None)') as col_ts_ns + arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(s)') as col_ts_s, + arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(ms)') as col_ts_ms, + arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(µs)') as col_ts_us, + arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(ns)') as col_ts_ns ; ## Ensure each column in the table has the expected type diff --git a/datafusion/sqllogictest/test_files/async_udf.slt b/datafusion/sqllogictest/test_files/async_udf.slt index 31ca87c4354a..0708b59e519a 100644 --- a/datafusion/sqllogictest/test_files/async_udf.slt +++ b/datafusion/sqllogictest/test_files/async_udf.slt @@ -37,8 +37,7 @@ physical_plan 03)----AggregateExec: mode=Partial, gby=[], aggr=[min(async_abs(data.x))] 04)------RepartitionExec: partitioning=RoundRobinBatch(4), 
input_partitions=1 05)--------AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=async_abs(x@0))] -06)----------CoalesceBatchesExec: target_batch_size=8192 -07)------------DataSourceExec: partitions=1, partition_sizes=[1] +06)----------DataSourceExec: partitions=1, partition_sizes=[1] # Async udf can be used in aggregation with group by query I rowsort @@ -63,8 +62,7 @@ physical_plan 05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 06)----------ProjectionExec: expr=[__async_fn_0@1 as __common_expr_1] 07)------------AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=async_abs(x@0))] -08)--------------CoalesceBatchesExec: target_batch_size=8192 -09)----------------DataSourceExec: partitions=1, partition_sizes=[1] +08)--------------DataSourceExec: partitions=1, partition_sizes=[1] # Async udf can be used in filter query I @@ -82,8 +80,7 @@ physical_plan 01)FilterExec: __async_fn_0@1 < 5, projection=[x@0] 02)--RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 03)----AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=async_abs(x@0))] -04)------CoalesceBatchesExec: target_batch_size=8192 -05)--------DataSourceExec: partitions=1, partition_sizes=[1] +04)------DataSourceExec: partitions=1, partition_sizes=[1] # Async udf can be used in projection query I rowsort @@ -101,5 +98,4 @@ logical_plan physical_plan 01)ProjectionExec: expr=[__async_fn_0@1 as async_abs(data.x)] 02)--AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=async_abs(x@0))] -03)----CoalesceBatchesExec: target_batch_size=8192 -04)------DataSourceExec: partitions=1, partition_sizes=[1] +03)----DataSourceExec: partitions=1, partition_sizes=[1] diff --git a/datafusion/sqllogictest/test_files/binary.slt b/datafusion/sqllogictest/test_files/binary.slt index 1077c32e46f3..c4a21deeff26 100644 --- a/datafusion/sqllogictest/test_files/binary.slt +++ b/datafusion/sqllogictest/test_files/binary.slt @@ -311,3 +311,13 @@ Foo foo 
Foo foo NULL NULL NULL NULL Bar Bar Bar Bar FooBar fooBar FooBar fooBar + +# show helpful error msg when Binary type is used with string functions +query error DataFusion error: Error during planning: Function 'split_part' requires TypeSignatureClass::Native\(LogicalType\(Native\(String\), String\)\), but received Binary \(DataType: Binary\)\.\n\nHint: Binary types are not automatically coerced to String\. Use CAST\(column AS VARCHAR\) to convert Binary data to String\. +SELECT split_part(binary, '~', 2) FROM t WHERE binary IS NOT NULL LIMIT 1; + +# ensure the suggested CAST workaround works +query T +SELECT split_part(CAST(binary AS VARCHAR), 'o', 2) FROM t WHERE binary = X'466f6f'; +---- +(empty) diff --git a/datafusion/sqllogictest/test_files/case.slt b/datafusion/sqllogictest/test_files/case.slt index 074d216ac752..3953878ceb66 100644 --- a/datafusion/sqllogictest/test_files/case.slt +++ b/datafusion/sqllogictest/test_files/case.slt @@ -383,9 +383,10 @@ SELECT column2, column3, column4 FROM t; ---- {foo: a, xxx: b} {xxx: c, foo: d} {xxx: e} -# coerce structs with different field orders, -# (note the *value*s are from column2 but the field name is 'xxx', as the coerced -# type takes the field name from the last argument (column3) +# coerce structs with different field orders +# With name-based struct coercion, matching fields by name: +# column2={foo:a, xxx:b} unified with column3={xxx:c, foo:d} +# Result uses the THEN branch's field order (when executed): {xxx: b, foo: a} query ? SELECT case @@ -394,9 +395,10 @@ SELECT end FROM t; ---- -{xxx: a, foo: b} +{xxx: b, foo: a} # coerce structs with different field orders +# When ELSE branch executes, uses its field order: {xxx: c, foo: d} query ? 
SELECT case @@ -407,8 +409,9 @@ FROM t; ---- {xxx: c, foo: d} -# coerce structs with subset of fields -query error Failed to coerce then +# coerce structs with subset of fields - field count mismatch causes type coercion failure +# column3 has 2 fields but column4 has only 1 field +query error DataFusion error: type_coercion\ncaused by\nError during planning: Failed to coerce then .* and else .* to common types in CASE WHEN expression SELECT case when column1 > 0 then column3 @@ -618,6 +621,59 @@ a b c +query I +SELECT CASE WHEN d != 0 THEN n / d ELSE NULL END FROM (VALUES (1, 1), (1, 0), (1, -1)) t(n,d) +---- +1 +NULL +-1 + +query I +SELECT CASE WHEN d > 0 THEN n / d ELSE NULL END FROM (VALUES (1, 1), (1, 0), (1, -1)) t(n,d) +---- +1 +NULL +NULL + +query I +SELECT CASE WHEN d < 0 THEN n / d ELSE NULL END FROM (VALUES (1, 1), (1, 0), (1, -1)) t(n,d) +---- +NULL +NULL +-1 + +# single WHEN, no ELSE (absent) +query I +SELECT CASE WHEN a > 0 THEN b END +FROM (VALUES (1, 10), (0, 20)) AS t(a, b); +---- +10 +NULL + +# single WHEN, explicit ELSE NULL +query I +SELECT CASE WHEN a > 0 THEN b ELSE NULL END +FROM (VALUES (1, 10), (0, 20)) AS t(a, b); +---- +10 +NULL + +# fallible THEN expression should only be evaluated on true rows +query I +SELECT CASE WHEN a > 0 THEN 10 / a END +FROM (VALUES (1), (0)) AS t(a); +---- +10 +NULL + +# all-false path returns typed NULLs +query I +SELECT CASE WHEN a < 0 THEN b END +FROM (VALUES (1, 10), (2, 20)) AS t(a, b); +---- +NULL +NULL + # EvalMethod::WithExpression using subset of all selected columns in case expression query III SELECT CASE a1 WHEN 1 THEN a1 WHEN 2 THEN a2 WHEN 3 THEN b END, b, c diff --git a/datafusion/sqllogictest/test_files/clickbench.slt b/datafusion/sqllogictest/test_files/clickbench.slt index 4c60a4365ee2..42b7cfafdaa6 100644 --- a/datafusion/sqllogictest/test_files/clickbench.slt +++ b/datafusion/sqllogictest/test_files/clickbench.slt @@ -26,10 +26,28 @@ # COPY (SELECT * FROM 'hits.parquet' LIMIT 10) TO 
'clickbench_hits_10.parquet' (FORMAT PARQUET); statement ok -CREATE EXTERNAL TABLE hits +CREATE EXTERNAL TABLE hits_raw STORED AS PARQUET LOCATION '../core/tests/data/clickbench_hits_10.parquet'; +# ClickBench encodes EventDate as UInt16 days since epoch. +statement ok +CREATE VIEW hits AS +SELECT * EXCEPT ("EventDate"), + CAST(CAST("EventDate" AS INTEGER) AS DATE) AS "EventDate" +FROM hits_raw; + +# Verify EventDate transformation from UInt16 to DATE +query D +SELECT "EventDate" FROM hits LIMIT 1; +---- +2013-07-15 + +# Verify the raw value is still UInt16 in hits_raw +query I +SELECT "EventDate" FROM hits_raw LIMIT 1; +---- +15901 # queries.sql came from # https://github.com/ClickHouse/ClickBench/blob/8b9e3aa05ea18afa427f14909ddc678b8ef0d5e6/datafusion/queries.sql @@ -64,10 +82,10 @@ SELECT COUNT(DISTINCT "SearchPhrase") FROM hits; ---- 1 -query II +query DD SELECT MIN("EventDate"), MAX("EventDate") FROM hits; ---- -15901 15901 +2013-07-15 2013-07-15 query II SELECT "AdvEngineID", COUNT(*) FROM hits WHERE "AdvEngineID" <> 0 GROUP BY "AdvEngineID" ORDER BY COUNT(*) DESC; @@ -167,7 +185,8 @@ query TTTII SELECT "SearchPhrase", MIN("URL"), MIN("Title"), COUNT(*) AS c, COUNT(DISTINCT "UserID") FROM hits WHERE "Title" LIKE '%Google%' AND "URL" NOT LIKE '%.google.%' AND "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY c DESC LIMIT 10; ---- -query IITIIIIIIIIIITTIIIIIIIIIITIIITIIIITTIIITIIIIIIIIIITIIIIITIIIIIITIIIIIIIIIITTTTIIIIIIIITITTITTTTTTTTTTIIII +query IITIIIIIIIIITTIIIIIIIIIITIIITIIIITTIIITIIIIIIIIIITIIIIITIIIIIITIIIIIIIIIITTTTIIIIIIIITITTITTTTTTTTTTIIIID + SELECT * FROM hits WHERE "URL" LIKE '%google%' ORDER BY "EventTime" LIMIT 10; ---- @@ -262,7 +281,7 @@ query IIITTI SELECT "TraficSourceID", "SearchEngineID", "AdvEngineID", CASE WHEN ("SearchEngineID" = 0 AND "AdvEngineID" = 0) THEN "Referer" ELSE '' END AS Src, "URL" AS Dst, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND 
"IsRefresh" = 0 GROUP BY "TraficSourceID", "SearchEngineID", "AdvEngineID", Src, Dst ORDER BY PageViews DESC LIMIT 10 OFFSET 1000; ---- -query III +query IDI SELECT "URLHash", "EventDate", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "IsRefresh" = 0 AND "TraficSourceID" IN (-1, 6) AND "RefererHash" = 3594120000172545465 GROUP BY "URLHash", "EventDate" ORDER BY PageViews DESC LIMIT 10 OFFSET 100; ---- @@ -293,4 +312,7 @@ SELECT "BrowserCountry", COUNT(DISTINCT "SocialNetwork"), COUNT(DISTINCT "HitCo statement ok -drop table hits; +drop view hits; + +statement ok +drop table hits_raw; diff --git a/datafusion/sqllogictest/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt index 3dac92938772..4fd77be045c1 100644 --- a/datafusion/sqllogictest/test_files/cte.slt +++ b/datafusion/sqllogictest/test_files/cte.slt @@ -42,6 +42,63 @@ physical_plan statement error DataFusion error: Error during planning: WITH query name "a" specified more than once WITH a AS (SELECT 1), a AS (SELECT 2) SELECT * FROM a; +statement ok +CREATE TABLE orders AS VALUES (1), (2); + +########## +## CTE Reference Resolution +########## + +# These tests exercise CTE reference resolution with and without identifier +# normalization. The session is configured with a strict catalog/schema provider +# (see `datafusion/sqllogictest/src/test_context.rs`) that only provides the +# `orders` table and panics on any unexpected table lookup. +# +# This makes it observable if DataFusion incorrectly treats a CTE reference as a +# catalog lookup. +# +# Refs: https://github.com/apache/datafusion/issues/18932 +# +# NOTE: This test relies on a strict catalog/schema provider registered in +# `datafusion/sqllogictest/src/test_context.rs` that provides only the `orders` +# table and panics on unexpected lookups. 
+ +statement ok +set datafusion.sql_parser.enable_ident_normalization = true; + +query I +with barbaz as (select * from orders) select * from "barbaz"; +---- +1 +2 + +query I +with BarBaz as (select * from orders) select * from "barbaz"; +---- +1 +2 + +query I +with barbaz as (select * from orders) select * from barbaz; +---- +1 +2 + +statement ok +set datafusion.sql_parser.enable_ident_normalization = false; + +query I +with barbaz as (select * from orders) select * from "barbaz"; +---- +1 +2 + +query I +with barbaz as (select * from orders) select * from barbaz; +---- +1 +2 + # Test disabling recursive CTE statement ok set datafusion.execution.enable_recursive_ctes = false; @@ -996,7 +1053,7 @@ query TT explain WITH RECURSIVE numbers AS ( select 1 as n UNION ALL - select n + 1 FROM numbers WHERE N < 10 + select n + 1 FROM numbers WHERE n < 10 ) select * from numbers; ---- logical_plan @@ -1021,7 +1078,7 @@ query TT explain WITH RECURSIVE numbers AS ( select 1 as n UNION ALL - select n + 1 FROM numbers WHERE N < 10 + select n + 1 FROM numbers WHERE n < 10 ) select * from numbers; ---- logical_plan @@ -1160,5 +1217,5 @@ query error DataFusion error: This feature is not implemented: Recursive CTEs ar explain WITH RECURSIVE numbers AS ( select 1 as n UNION ALL - select n + 1 FROM numbers WHERE N < 10 + select n + 1 FROM numbers WHERE n < 10 ) select * from numbers; diff --git a/datafusion/sqllogictest/test_files/date_bin_errors.slt b/datafusion/sqllogictest/test_files/date_bin_errors.slt new file mode 100644 index 000000000000..b6cda471d7af --- /dev/null +++ b/datafusion/sqllogictest/test_files/date_bin_errors.slt @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Tests for DATE_BIN error handling with out-of-range values + +# Test case from issue #20219 - should return NULL instead of panicking +query P +select date_bin(interval '1637426858 months', to_timestamp_millis(1040292460), timestamp '1984-01-07 00:00:00'); +---- +NULL + +# Negative timestamp with month interval - should return NULL instead of panicking +query P +select date_bin(interval '1 month', to_timestamp_millis(-1040292460), timestamp '1984-01-07 00:00:00'); +---- +NULL + +# Large stride causing overflow - should return NULL +query P +select date_bin( + interval '1637426858 months', + timestamp '1969-12-31 00:00:00', + timestamp '1984-01-07 00:00:00' +); +---- +NULL + +# Another large stride test +query P +select date_bin( + interval '1637426858 months', + to_timestamp_millis(-1040292000), + timestamp '1984-01-07 00:00:00' +) as b; +---- +NULL + +# Test with 1900-01-01 timestamp +query P +select date_bin( + interval '1637426858 months', + to_timestamp_millis(-2208988800000), + timestamp '1984-01-07 00:00:00' +) as b; +---- +NULL \ No newline at end of file diff --git a/datafusion/sqllogictest/test_files/datetime/arith_date_date.slt b/datafusion/sqllogictest/test_files/datetime/arith_date_date.slt index f6e4aad78b27..8eb5cc176f36 100644 --- a/datafusion/sqllogictest/test_files/datetime/arith_date_date.slt +++ b/datafusion/sqllogictest/test_files/datetime/arith_date_date.slt @@ -1,16 
+1,15 @@ # date - date → integer # Subtract dates, producing the number of days elapsed # date '2001-10-01' - date '2001-09-28' → 3 +# This aligns with PostgreSQL, DuckDB, and MySQL behavior +# Resolved by: https://github.com/apache/datafusion/issues/19528 -# note that datafusion returns Duration whereas postgres returns an int -# Tracking issue: https://github.com/apache/datafusion/issues/19528 - -query ? +query I SELECT '2001-10-01'::date - '2001-09-28'::date ---- -3 days 0 hours 0 mins 0 secs +3 query T SELECT arrow_typeof('2001-10-01'::date - '2001-09-28'::date) ---- -Duration(s) +Int64 diff --git a/datafusion/sqllogictest/test_files/datetime/arith_date_time.slt b/datafusion/sqllogictest/test_files/datetime/arith_date_time.slt index bc796a51ff5a..8e85c8f90580 100644 --- a/datafusion/sqllogictest/test_files/datetime/arith_date_time.slt +++ b/datafusion/sqllogictest/test_files/datetime/arith_date_time.slt @@ -113,4 +113,3 @@ SELECT '2001-09-28'::date / '03:00'::time query error Invalid timestamp arithmetic operation SELECT '2001-09-28'::date % '03:00'::time - diff --git a/datafusion/sqllogictest/test_files/datetime/arith_timestamp_duration.slt b/datafusion/sqllogictest/test_files/datetime/arith_timestamp_duration.slt index 10381346f835..aeeebe73db70 100644 --- a/datafusion/sqllogictest/test_files/datetime/arith_timestamp_duration.slt +++ b/datafusion/sqllogictest/test_files/datetime/arith_timestamp_duration.slt @@ -144,4 +144,4 @@ query error Invalid timestamp arithmetic operation SELECT '2001-09-28T01:00:00'::timestamp % arrow_cast(12345, 'Duration(Second)'); query error Invalid timestamp arithmetic operation -SELECT '2001-09-28T01:00:00'::timestamp / arrow_cast(12345, 'Duration(Second)'); \ No newline at end of file +SELECT '2001-09-28T01:00:00'::timestamp / arrow_cast(12345, 'Duration(Second)'); diff --git a/datafusion/sqllogictest/test_files/datetime/date_part.slt b/datafusion/sqllogictest/test_files/datetime/date_part.slt index bee8602d80bd..79d6d8ac0509 
100644 --- a/datafusion/sqllogictest/test_files/datetime/date_part.slt +++ b/datafusion/sqllogictest/test_files/datetime/date_part.slt @@ -19,7 +19,7 @@ # for the same function). -## Begin tests fo rdate_part with columns and timestamp's with timezones +## Begin tests for date_part with columns and timestamp's with timezones # Source data table has # timestamps with millisecond (very common timestamp precision) and nanosecond (maximum precision) timestamps @@ -40,30 +40,32 @@ with t as (values ) SELECT -- nanoseconds, with no, utc, and local timezone - arrow_cast(column1, 'Timestamp(Nanosecond, None)') as ts_nano_no_tz, + arrow_cast(column1, 'Timestamp(ns)') as ts_nano_no_tz, + arrow_cast(column1, 'Timestamp(Nanosecond, None)') as ts_nano_no_tz_old_format, arrow_cast(column1, 'Timestamp(Nanosecond, Some("UTC"))') as ts_nano_utc, arrow_cast(column1, 'Timestamp(Nanosecond, Some("America/New_York"))') as ts_nano_eastern, -- milliseconds, with no, utc, and local timezone - arrow_cast(column1, 'Timestamp(Millisecond, None)') as ts_milli_no_tz, + arrow_cast(column1, 'Timestamp(ms)') as ts_milli_no_tz, + arrow_cast(column1, 'Timestamp(Millisecond, None)') as ts_milli_no_tz_old_format, arrow_cast(column1, 'Timestamp(Millisecond, Some("UTC"))') as ts_milli_utc, arrow_cast(column1, 'Timestamp(Millisecond, Some("America/New_York"))') as ts_milli_eastern FROM t; -query PPPPPP +query PPPPPPPP SELECT * FROM source_ts; ---- -2020-01-01T00:00:00 2020-01-01T00:00:00Z 2019-12-31T19:00:00-05:00 2020-01-01T00:00:00 2020-01-01T00:00:00Z 2019-12-31T19:00:00-05:00 -2021-01-01T00:00:00 2021-01-01T00:00:00Z 2020-12-31T19:00:00-05:00 2021-01-01T00:00:00 2021-01-01T00:00:00Z 2020-12-31T19:00:00-05:00 -2020-09-01T00:00:00 2020-09-01T00:00:00Z 2020-08-31T20:00:00-04:00 2020-09-01T00:00:00 2020-09-01T00:00:00Z 2020-08-31T20:00:00-04:00 -2020-01-25T00:00:00 2020-01-25T00:00:00Z 2020-01-24T19:00:00-05:00 2020-01-25T00:00:00 2020-01-25T00:00:00Z 2020-01-24T19:00:00-05:00 -2020-01-24T00:00:00 
2020-01-24T00:00:00Z 2020-01-23T19:00:00-05:00 2020-01-24T00:00:00 2020-01-24T00:00:00Z 2020-01-23T19:00:00-05:00 -2020-01-01T12:00:00 2020-01-01T12:00:00Z 2020-01-01T07:00:00-05:00 2020-01-01T12:00:00 2020-01-01T12:00:00Z 2020-01-01T07:00:00-05:00 -2020-01-01T00:30:00 2020-01-01T00:30:00Z 2019-12-31T19:30:00-05:00 2020-01-01T00:30:00 2020-01-01T00:30:00Z 2019-12-31T19:30:00-05:00 -2020-01-01T00:00:30 2020-01-01T00:00:30Z 2019-12-31T19:00:30-05:00 2020-01-01T00:00:30 2020-01-01T00:00:30Z 2019-12-31T19:00:30-05:00 -2020-01-01T00:00:00.123 2020-01-01T00:00:00.123Z 2019-12-31T19:00:00.123-05:00 2020-01-01T00:00:00.123 2020-01-01T00:00:00.123Z 2019-12-31T19:00:00.123-05:00 -2020-01-01T00:00:00.123456 2020-01-01T00:00:00.123456Z 2019-12-31T19:00:00.123456-05:00 2020-01-01T00:00:00.123 2020-01-01T00:00:00.123Z 2019-12-31T19:00:00.123-05:00 -2020-01-01T00:00:00.123456789 2020-01-01T00:00:00.123456789Z 2019-12-31T19:00:00.123456789-05:00 2020-01-01T00:00:00.123 2020-01-01T00:00:00.123Z 2019-12-31T19:00:00.123-05:00 +2020-01-01T00:00:00 2020-01-01T00:00:00 2020-01-01T00:00:00Z 2019-12-31T19:00:00-05:00 2020-01-01T00:00:00 2020-01-01T00:00:00 2020-01-01T00:00:00Z 2019-12-31T19:00:00-05:00 +2021-01-01T00:00:00 2021-01-01T00:00:00 2021-01-01T00:00:00Z 2020-12-31T19:00:00-05:00 2021-01-01T00:00:00 2021-01-01T00:00:00 2021-01-01T00:00:00Z 2020-12-31T19:00:00-05:00 +2020-09-01T00:00:00 2020-09-01T00:00:00 2020-09-01T00:00:00Z 2020-08-31T20:00:00-04:00 2020-09-01T00:00:00 2020-09-01T00:00:00 2020-09-01T00:00:00Z 2020-08-31T20:00:00-04:00 +2020-01-25T00:00:00 2020-01-25T00:00:00 2020-01-25T00:00:00Z 2020-01-24T19:00:00-05:00 2020-01-25T00:00:00 2020-01-25T00:00:00 2020-01-25T00:00:00Z 2020-01-24T19:00:00-05:00 +2020-01-24T00:00:00 2020-01-24T00:00:00 2020-01-24T00:00:00Z 2020-01-23T19:00:00-05:00 2020-01-24T00:00:00 2020-01-24T00:00:00 2020-01-24T00:00:00Z 2020-01-23T19:00:00-05:00 +2020-01-01T12:00:00 2020-01-01T12:00:00 2020-01-01T12:00:00Z 2020-01-01T07:00:00-05:00 
2020-01-01T12:00:00 2020-01-01T12:00:00 2020-01-01T12:00:00Z 2020-01-01T07:00:00-05:00 +2020-01-01T00:30:00 2020-01-01T00:30:00 2020-01-01T00:30:00Z 2019-12-31T19:30:00-05:00 2020-01-01T00:30:00 2020-01-01T00:30:00 2020-01-01T00:30:00Z 2019-12-31T19:30:00-05:00 +2020-01-01T00:00:30 2020-01-01T00:00:30 2020-01-01T00:00:30Z 2019-12-31T19:00:30-05:00 2020-01-01T00:00:30 2020-01-01T00:00:30 2020-01-01T00:00:30Z 2019-12-31T19:00:30-05:00 +2020-01-01T00:00:00.123 2020-01-01T00:00:00.123 2020-01-01T00:00:00.123Z 2019-12-31T19:00:00.123-05:00 2020-01-01T00:00:00.123 2020-01-01T00:00:00.123 2020-01-01T00:00:00.123Z 2019-12-31T19:00:00.123-05:00 +2020-01-01T00:00:00.123456 2020-01-01T00:00:00.123456 2020-01-01T00:00:00.123456Z 2019-12-31T19:00:00.123456-05:00 2020-01-01T00:00:00.123 2020-01-01T00:00:00.123 2020-01-01T00:00:00.123Z 2019-12-31T19:00:00.123-05:00 +2020-01-01T00:00:00.123456789 2020-01-01T00:00:00.123456789 2020-01-01T00:00:00.123456789Z 2019-12-31T19:00:00.123456789-05:00 2020-01-01T00:00:00.123 2020-01-01T00:00:00.123 2020-01-01T00:00:00.123Z 2019-12-31T19:00:00.123-05:00 # date_part (year) with columns and explicit timestamp query IIIIII @@ -81,6 +83,23 @@ SELECT date_part('year', ts_nano_no_tz), date_part('year', ts_nano_utc), date_pa 2020 2020 2019 2020 2020 2019 2020 2020 2019 2020 2020 2019 +# date_part (isoyear) with columns and explicit timestamp +query IIIIII +SELECT date_part('isoyear', ts_nano_no_tz), date_part('isoyear', ts_nano_utc), date_part('isoyear', ts_nano_eastern), date_part('isoyear', ts_milli_no_tz), date_part('isoyear', ts_milli_utc), date_part('isoyear', ts_milli_eastern) FROM source_ts; +---- +2020 2020 2020 2020 2020 2020 +2020 2020 2020 2020 2020 2020 +2020 2020 2020 2020 2020 2020 +2020 2020 2020 2020 2020 2020 +2020 2020 2020 2020 2020 2020 +2020 2020 2020 2020 2020 2020 +2020 2020 2020 2020 2020 2020 +2020 2020 2020 2020 2020 2020 +2020 2020 2020 2020 2020 2020 +2020 2020 2020 2020 2020 2020 +2020 2020 2020 2020 2020 2020 + + # 
date_part (month) query IIIIII SELECT date_part('month', ts_nano_no_tz), date_part('month', ts_nano_utc), date_part('month', ts_nano_eastern), date_part('month', ts_milli_no_tz), date_part('month', ts_milli_utc), date_part('month', ts_milli_eastern) FROM source_ts; @@ -228,6 +247,26 @@ SELECT EXTRACT('year' FROM timestamp '2020-09-08T12:00:00+00:00') ---- 2020 +query I +SELECT date_part('ISOYEAR', CAST('2000-01-01' AS DATE)) +---- +1999 + +query I +SELECT EXTRACT(isoyear FROM timestamp '2020-09-08T12:00:00+00:00') +---- +2020 + +query I +SELECT EXTRACT("isoyear" FROM timestamp '2020-09-08T12:00:00+00:00') +---- +2020 + +query I +SELECT EXTRACT('isoyear' FROM timestamp '2020-09-08T12:00:00+00:00') +---- +2020 + query I SELECT date_part('QUARTER', CAST('2000-01-01' AS DATE)) ---- @@ -865,9 +904,15 @@ SELECT extract(month from arrow_cast('20 months', 'Interval(YearMonth)')) ---- 8 +query error DataFusion error: Arrow error: Compute error: YearISO does not support: Interval\(YearMonth\) +SELECT extract(isoyear from arrow_cast('10 years', 'Interval(YearMonth)')) + query error DataFusion error: Arrow error: Compute error: Year does not support: Interval\(DayTime\) SELECT extract(year from arrow_cast('10 days', 'Interval(DayTime)')) +query error DataFusion error: Arrow error: Compute error: YearISO does not support: Interval\(DayTime\) +SELECT extract(isoyear from arrow_cast('10 days', 'Interval(DayTime)')) + query error DataFusion error: Arrow error: Compute error: Month does not support: Interval\(DayTime\) SELECT extract(month from arrow_cast('10 days', 'Interval(DayTime)')) @@ -936,6 +981,57 @@ SELECT extract(second from arrow_cast(NULL, 'Interval(MonthDayNano)')) ---- NULL +# extract epoch from intervals +query R +SELECT extract(epoch from interval '15 minutes') +---- +900 + +query R +SELECT extract(epoch from interval '1 hour') +---- +3600 + +query R +SELECT extract(epoch from interval '1 day') +---- +86400 + +query R +SELECT extract(epoch from interval '1 month') 
+---- +2592000 + +query R +SELECT extract(epoch from arrow_cast('3 days', 'Interval(DayTime)')) +---- +259200 + +query R +SELECT extract(epoch from arrow_cast('100 milliseconds', 'Interval(MonthDayNano)')) +---- +0.1 + +query R +SELECT extract(epoch from arrow_cast('500 microseconds', 'Interval(MonthDayNano)')) +---- +0.0005 + +query R +SELECT extract(epoch from arrow_cast('2500 nanoseconds', 'Interval(MonthDayNano)')) +---- +0.0000025 + +query R +SELECT extract(epoch from arrow_cast('1 month 2 days 500 milliseconds', 'Interval(MonthDayNano)')) +---- +2764800.5 + +query R +SELECT extract(epoch from arrow_cast('2 months', 'Interval(YearMonth)')) +---- +5184000 + statement ok create table t (id int, i interval) as values (0, interval '5 months 1 day 10 nanoseconds'), @@ -1011,6 +1107,9 @@ SELECT extract(month from arrow_cast(864000, 'Duration(Second)')) query error DataFusion error: Arrow error: Compute error: Year does not support: Duration\(s\) SELECT extract(year from arrow_cast(864000, 'Duration(Second)')) +query error DataFusion error: Arrow error: Compute error: YearISO does not support: Duration\(s\) +SELECT extract(isoyear from arrow_cast(864000, 'Duration(Second)')) + query I SELECT extract(day from arrow_cast(NULL, 'Duration(Second)')) ---- @@ -1023,6 +1122,11 @@ SELECT (date_part('year', now()) = EXTRACT(year FROM now())) ---- true +query B +SELECT (date_part('isoyear', now()) = EXTRACT(isoyear FROM now())) +---- +true + query B SELECT (date_part('quarter', now()) = EXTRACT(quarter FROM now())) ---- @@ -1090,3 +1194,563 @@ query I SELECT EXTRACT('isodow' FROM to_timestamp('2020-09-08T12:00:00+00:00')) ---- 1 + +## Preimage tests + +statement ok +create table t1(c DATE) as VALUES (NULL), ('1990-01-01'), ('2024-01-01'), ('2030-01-01'); + +# Simple optimizations, col on LHS + +query D +select c from t1 where extract(year from c) = 2024; +---- +2024-01-01 + +query D +select c from t1 where extract(year from c) <> 2024; +---- +1990-01-01 +2030-01-01 + +query D 
+select c from t1 where extract(year from c) > 2024; +---- +2030-01-01 + +query D +select c from t1 where extract(year from c) < 2024; +---- +1990-01-01 + +query D +select c from t1 where extract(year from c) >= 2024; +---- +2024-01-01 +2030-01-01 + +query D +select c from t1 where extract(year from c) <= 2024; +---- +1990-01-01 +2024-01-01 + +query D +select c from t1 where extract(year from c) is not distinct from 2024 +---- +2024-01-01 + +query D +select c from t1 where extract(year from c) is distinct from 2024 +---- +NULL +1990-01-01 +2030-01-01 + +# IN list optimization +query D +select c from t1 where extract(year from c) in (1990, 2024); +---- +1990-01-01 +2024-01-01 + +# NOT IN list optimization (NULL does not satisfy NOT IN) +query D +select c from t1 where extract(year from c) not in (1990, 2024); +---- +2030-01-01 + +# Check that date_part is not in the explain statements + +query TT +explain select c from t1 where extract (year from c) = 2024 +---- +logical_plan +01)Filter: t1.c >= Date32("2024-01-01") AND t1.c < Date32("2025-01-01") +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: c@0 >= 2024-01-01 AND c@0 < 2025-01-01 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (year from c) <> 2024 +---- +logical_plan +01)Filter: t1.c < Date32("2024-01-01") OR t1.c >= Date32("2025-01-01") +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: c@0 < 2024-01-01 OR c@0 >= 2025-01-01 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (year from c) > 2024 +---- +logical_plan +01)Filter: t1.c >= Date32("2025-01-01") +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: c@0 >= 2025-01-01 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (year from c) < 2024 +---- +logical_plan +01)Filter: t1.c < Date32("2024-01-01") +02)--TableScan: t1 projection=[c] 
+physical_plan +01)FilterExec: c@0 < 2024-01-01 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (year from c) >= 2024 +---- +logical_plan +01)Filter: t1.c >= Date32("2024-01-01") +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: c@0 >= 2024-01-01 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (year from c) <= 2024 +---- +logical_plan +01)Filter: t1.c < Date32("2025-01-01") +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: c@0 < 2025-01-01 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (year from c) is not distinct from 2024 +---- +logical_plan +01)Filter: t1.c IS NOT NULL AND t1.c >= Date32("2024-01-01") AND t1.c < Date32("2025-01-01") +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: c@0 IS NOT NULL AND c@0 >= 2024-01-01 AND c@0 < 2025-01-01 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (year from c) is distinct from 2024 +---- +logical_plan +01)Filter: t1.c < Date32("2024-01-01") OR t1.c >= Date32("2025-01-01") OR t1.c IS NULL +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: c@0 < 2024-01-01 OR c@0 >= 2025-01-01 OR c@0 IS NULL +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (year from c) in (1990, 2024) +---- +logical_plan +01)Filter: t1.c >= Date32("1990-01-01") AND t1.c < Date32("1991-01-01") OR t1.c >= Date32("2024-01-01") AND t1.c < Date32("2025-01-01") +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: c@0 >= 1990-01-01 AND c@0 < 1991-01-01 OR c@0 >= 2024-01-01 AND c@0 < 2025-01-01 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +# Simple optimizations, column on RHS + +query D +select c from t1 where 2024 = extract(year from c); +---- +2024-01-01 + +query D +select c 
from t1 where 2024 <> extract(year from c); +---- +1990-01-01 +2030-01-01 + +query D +select c from t1 where 2024 < extract(year from c); +---- +2030-01-01 + +query D +select c from t1 where 2024 > extract(year from c); +---- +1990-01-01 + +query D +select c from t1 where 2024 <= extract(year from c); +---- +2024-01-01 +2030-01-01 + +query D +select c from t1 where 2024 >= extract(year from c); +---- +1990-01-01 +2024-01-01 + +query D +select c from t1 where 2024 is not distinct from extract(year from c); +---- +2024-01-01 + +query D +select c from t1 where 2024 is distinct from extract(year from c); +---- +NULL +1990-01-01 +2030-01-01 + +# Check explain statements for optimizations for other interval types + +query TT +explain select c from t1 where extract (quarter from c) = 2024 +---- +logical_plan +01)Filter: date_part(Utf8("QUARTER"), t1.c) = Int32(2024) +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: date_part(QUARTER, c@0) = 2024 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (month from c) = 2024 +---- +logical_plan +01)Filter: date_part(Utf8("MONTH"), t1.c) = Int32(2024) +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: date_part(MONTH, c@0) = 2024 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (week from c) = 2024 +---- +logical_plan +01)Filter: date_part(Utf8("WEEK"), t1.c) = Int32(2024) +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: date_part(WEEK, c@0) = 2024 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (day from c) = 2024 +---- +logical_plan +01)Filter: date_part(Utf8("DAY"), t1.c) = Int32(2024) +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: date_part(DAY, c@0) = 2024 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (hour from c) = 2024 
+---- +logical_plan +01)Filter: date_part(Utf8("HOUR"), t1.c) = Int32(2024) +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: date_part(HOUR, c@0) = 2024 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (minute from c) = 2024 +---- +logical_plan +01)Filter: date_part(Utf8("MINUTE"), t1.c) = Int32(2024) +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: date_part(MINUTE, c@0) = 2024 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (second from c) = 2024 +---- +logical_plan +01)Filter: date_part(Utf8("SECOND"), t1.c) = Int32(2024) +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: date_part(SECOND, c@0) = 2024 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (millisecond from c) = 2024 +---- +logical_plan +01)Filter: date_part(Utf8("MILLISECOND"), t1.c) = Int32(2024) +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: date_part(MILLISECOND, c@0) = 2024 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (microsecond from c) = 2024 +---- +logical_plan +01)Filter: date_part(Utf8("MICROSECOND"), t1.c) = Int32(2024) +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: date_part(MICROSECOND, c@0) = 2024 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (nanosecond from c) = 2024 +---- +logical_plan +01)Filter: date_part(Utf8("NANOSECOND"), t1.c) = Int32(2024) +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: date_part(NANOSECOND, c@0) = 2024 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (dow from c) = 2024 +---- +logical_plan +01)Filter: date_part(Utf8("DOW"), t1.c) = Int32(2024) +02)--TableScan: t1 projection=[c] +physical_plan 
+01)FilterExec: date_part(DOW, c@0) = 2024 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (doy from c) = 2024 +---- +logical_plan +01)Filter: date_part(Utf8("DOY"), t1.c) = Int32(2024) +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: date_part(DOY, c@0) = 2024 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (epoch from c) = 2024 +---- +logical_plan +01)Filter: date_part(Utf8("EPOCH"), t1.c) = Float64(2024) +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: date_part(EPOCH, c@0) = 2024 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (isodow from c) = 2024 +---- +logical_plan +01)Filter: date_part(Utf8("ISODOW"), t1.c) = Int32(2024) +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: date_part(ISODOW, c@0) = 2024 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +# Simple optimize different datatypes + +statement ok +create table t2( + c1_date32 DATE, + c2_ts_sec timestamp, + c3_ts_mili timestamp, + c4_ts_micro timestamp, + c5_ts_nano timestamp +) as VALUES + (NULL, + NULL, + NULL, + NULL, + NULL), + ('1990-05-20', + '1990-05-20T00:00:10'::timestamp, + '1990-05-20T00:00:10.987'::timestamp, + '1990-05-20T00:00:10.987654'::timestamp, + '1990-05-20T00:00:10.987654321'::timestamp), + ('2024-01-01', + '2024-01-01T00:00:00'::timestamp, + '2024-01-01T00:00:00.123'::timestamp, + '2024-01-01T00:00:00.123456'::timestamp, + '2024-01-01T00:00:00.123456789'::timestamp), + ('2030-12-31', + '2030-12-31T23:59:59'::timestamp, + '2030-12-31T23:59:59.001'::timestamp, + '2030-12-31T23:59:59.001234'::timestamp, + '2030-12-31T23:59:59.001234567'::timestamp) +; + +query D +select c1_date32 from t2 where extract(year from c1_date32) = 2024; +---- +2024-01-01 + +query D +select c1_date32 from t2 where extract(year from c1_date32) <> 2024; +---- +1990-05-20 
+2030-12-31 + +query P +select c2_ts_sec from t2 where extract(year from c2_ts_sec) > 2024; +---- +2030-12-31T23:59:59 + +query P +select c3_ts_mili from t2 where extract(year from c3_ts_mili) < 2024; +---- +1990-05-20T00:00:10.987 + +query P +select c4_ts_micro from t2 where extract(year from c4_ts_micro) >= 2024; +---- +2024-01-01T00:00:00.123456 +2030-12-31T23:59:59.001234 + +query P +select c5_ts_nano from t2 where extract(year from c5_ts_nano) <= 2024; +---- +1990-05-20T00:00:10.987654321 +2024-01-01T00:00:00.123456789 + +query D +select c1_date32 from t2 where extract(year from c1_date32) is not distinct from 2024 +---- +2024-01-01 + +query D +select c1_date32 from t2 where extract(year from c1_date32) is distinct from 2024 +---- +NULL +1990-05-20 +2030-12-31 + +# Check that date_part is not in the explain statements for other datatypes + +query TT +explain select c1_date32 from t2 where extract (year from c1_date32) = 2024 +---- +logical_plan +01)Filter: t2.c1_date32 >= Date32("2024-01-01") AND t2.c1_date32 < Date32("2025-01-01") +02)--TableScan: t2 projection=[c1_date32] +physical_plan +01)FilterExec: c1_date32@0 >= 2024-01-01 AND c1_date32@0 < 2025-01-01 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c1_date32 from t2 where extract (year from c1_date32) <> 2024 +---- +logical_plan +01)Filter: t2.c1_date32 < Date32("2024-01-01") OR t2.c1_date32 >= Date32("2025-01-01") +02)--TableScan: t2 projection=[c1_date32] +physical_plan +01)FilterExec: c1_date32@0 < 2024-01-01 OR c1_date32@0 >= 2025-01-01 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c2_ts_sec from t2 where extract (year from c2_ts_sec) > 2024 +---- +logical_plan +01)Filter: t2.c2_ts_sec >= TimestampNanosecond(1735689600000000000, None) +02)--TableScan: t2 projection=[c2_ts_sec] +physical_plan +01)FilterExec: c2_ts_sec@0 >= 1735689600000000000 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select 
c3_ts_mili from t2 where extract (year from c3_ts_mili) < 2024 +---- +logical_plan +01)Filter: t2.c3_ts_mili < TimestampNanosecond(1704067200000000000, None) +02)--TableScan: t2 projection=[c3_ts_mili] +physical_plan +01)FilterExec: c3_ts_mili@0 < 1704067200000000000 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c4_ts_micro from t2 where extract (year from c4_ts_micro) >= 2024 +---- +logical_plan +01)Filter: t2.c4_ts_micro >= TimestampNanosecond(1704067200000000000, None) +02)--TableScan: t2 projection=[c4_ts_micro] +physical_plan +01)FilterExec: c4_ts_micro@0 >= 1704067200000000000 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c5_ts_nano from t2 where extract (year from c5_ts_nano) <= 2024 +---- +logical_plan +01)Filter: t2.c5_ts_nano < TimestampNanosecond(1735689600000000000, None) +02)--TableScan: t2 projection=[c5_ts_nano] +physical_plan +01)FilterExec: c5_ts_nano@0 < 1735689600000000000 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c1_date32 from t2 where extract (year from c1_date32) is not distinct from 2024 +---- +logical_plan +01)Filter: t2.c1_date32 IS NOT NULL AND t2.c1_date32 >= Date32("2024-01-01") AND t2.c1_date32 < Date32("2025-01-01") +02)--TableScan: t2 projection=[c1_date32] +physical_plan +01)FilterExec: c1_date32@0 IS NOT NULL AND c1_date32@0 >= 2024-01-01 AND c1_date32@0 < 2025-01-01 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c1_date32 from t2 where extract (year from c1_date32) is distinct from 2024 +---- +logical_plan +01)Filter: t2.c1_date32 < Date32("2024-01-01") OR t2.c1_date32 >= Date32("2025-01-01") OR t2.c1_date32 IS NULL +02)--TableScan: t2 projection=[c1_date32] +physical_plan +01)FilterExec: c1_date32@0 < 2024-01-01 OR c1_date32@0 >= 2025-01-01 OR c1_date32@0 IS NULL +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +# Preimage with timestamp with America/New_York timezone 
+ +statement ok +SET datafusion.execution.time_zone = 'America/New_York'; + +statement ok +create table t3( + c1_ts_tz timestamptz +) as VALUES + (NULL), + ('2024-01-01T04:59:59Z'::timestamptz), -- local 2023-12-31 23:59:59 -05 + ('2024-01-01T05:00:00Z'::timestamptz), -- local 2024-01-01 00:00:00 -05 + ('2025-01-01T04:59:59Z'::timestamptz), -- local 2024-12-31 23:59:59 -05 + ('2025-01-01T05:00:00Z'::timestamptz) -- local 2025-01-01 00:00:00 -05 +; + +query P +select c1_ts_tz +from t3 +where extract(year from c1_ts_tz) = 2024 +order by c1_ts_tz +---- +2024-01-01T00:00:00-05:00 +2024-12-31T23:59:59-05:00 + +query TT +explain select c1_ts_tz from t3 where extract(year from c1_ts_tz) = 2024 +---- +logical_plan +01)Filter: t3.c1_ts_tz >= TimestampNanosecond(1704085200000000000, Some("America/New_York")) AND t3.c1_ts_tz < TimestampNanosecond(1735707600000000000, Some("America/New_York")) +02)--TableScan: t3 projection=[c1_ts_tz] +physical_plan +01)FilterExec: c1_ts_tz@0 >= 1704085200000000000 AND c1_ts_tz@0 < 1735707600000000000 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +statement ok +RESET datafusion.execution.time_zone; + +# Test non-Int32 rhs argument + +query D +select c from t1 where extract(year from c) = cast(2024 as bigint); +---- +2024-01-01 + +query TT +explain select c from t1 where extract (year from c) = cast(2024 as bigint) +---- +logical_plan +01)Filter: t1.c >= Date32("2024-01-01") AND t1.c < Date32("2025-01-01") +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: c@0 >= 2024-01-01 AND c@0 < 2025-01-01 +02)--DataSourceExec: partitions=1, partition_sizes=[1] diff --git a/datafusion/sqllogictest/test_files/datetime/dates.slt b/datafusion/sqllogictest/test_files/datetime/dates.slt index 6ba34cfcac03..d2a7360b120c 100644 --- a/datafusion/sqllogictest/test_files/datetime/dates.slt +++ b/datafusion/sqllogictest/test_files/datetime/dates.slt @@ -94,13 +94,6 @@ caused by Error during planning: Cannot coerce arithmetic expression 
Timestamp(ns) + Utf8 to valid types -# DATE minus DATE -# https://github.com/apache/arrow-rs/issues/4383 -query ? -SELECT DATE '2023-04-09' - DATE '2023-04-02'; ----- -7 days 0 hours 0 mins 0 secs - # DATE minus Timestamp query ? SELECT DATE '2023-04-09' - '2000-01-01T00:00:00'::timestamp; @@ -113,17 +106,18 @@ SELECT '2023-01-01T00:00:00'::timestamp - DATE '2021-01-01'; ---- 730 days 0 hours 0 mins 0.000000000 secs -# NULL with DATE arithmetic should yield NULL -query ? +# NULL with DATE arithmetic should yield NULL (but Int64 type) +query I SELECT NULL - DATE '1984-02-28'; ---- NULL -query ? +query I SELECT DATE '1984-02-28' - NULL ---- NULL + # to_date_test statement ok create table to_date_t1(ts bigint) as VALUES diff --git a/datafusion/sqllogictest/test_files/datetime/timestamps.slt b/datafusion/sqllogictest/test_files/datetime/timestamps.slt index dbb924ef7aa6..c3d36b247b5a 100644 --- a/datafusion/sqllogictest/test_files/datetime/timestamps.slt +++ b/datafusion/sqllogictest/test_files/datetime/timestamps.slt @@ -19,10 +19,10 @@ ## Common timestamp data # # ts_data: Int64 nanoseconds -# ts_data_nanos: Timestamp(Nanosecond, None) -# ts_data_micros: Timestamp(Microsecond, None) -# ts_data_millis: Timestamp(Millisecond, None) -# ts_data_secs: Timestamp(Second, None) +# ts_data_nanos: Timestamp(ns) +# ts_data_micros: Timestamp(µs) +# ts_data_millis: Timestamp(ms) +# ts_data_secs: Timestamp(s) ########## # Create timestamp tables with different precisions but the same logical values @@ -34,16 +34,16 @@ create table ts_data(ts bigint, value int) as values (1599565349190855123, 3); statement ok -create table ts_data_nanos as select arrow_cast(ts, 'Timestamp(Nanosecond, None)') as ts, value from ts_data; +create table ts_data_nanos as select arrow_cast(ts, 'Timestamp(ns)') as ts, value from ts_data; statement ok -create table ts_data_micros as select arrow_cast(ts / 1000, 'Timestamp(Microsecond, None)') as ts, value from ts_data; +create table ts_data_micros as select 
arrow_cast(ts / 1000, 'Timestamp(µs)') as ts, value from ts_data; statement ok -create table ts_data_millis as select arrow_cast(ts / 1000000, 'Timestamp(Millisecond, None)') as ts, value from ts_data; +create table ts_data_millis as select arrow_cast(ts / 1000000, 'Timestamp(ms)') as ts, value from ts_data; statement ok -create table ts_data_secs as select arrow_cast(ts / 1000000000, 'Timestamp(Second, None)') as ts, value from ts_data; +create table ts_data_secs as select arrow_cast(ts / 1000000000, 'Timestamp(s)') as ts, value from ts_data; statement ok create table ts_data_micros_kolkata as select arrow_cast(ts / 1000, 'Timestamp(Microsecond, Some("Asia/Kolkata"))') as ts, value from ts_data; @@ -771,6 +771,18 @@ select to_timestamp_seconds(cast (1 as int)); ## test date_bin function ########## +# NULL stride should return NULL, not a planning error +query P +SELECT date_bin(NULL, TIMESTAMP '2023-01-01 12:30:00', TIMESTAMP '2023-01-01 12:00:00') +---- +NULL + +# NULL stride should return NULL, not a planning error +query P +SELECT date_bin(NULL, TIMESTAMP '2023-01-01 12:30:00') +---- +NULL + # invalid second arg type query error SELECT DATE_BIN(INTERVAL '0 second', 25, TIMESTAMP '1970-01-01T00:00:00Z') @@ -1579,13 +1591,13 @@ second 2020-09-08T13:42:29 # test date trunc on different timestamp scalar types and ensure they are consistent query P rowsort -SELECT DATE_TRUNC('second', arrow_cast(TIMESTAMP '2023-08-03 14:38:50Z', 'Timestamp(Second, None)')) as ts +SELECT DATE_TRUNC('second', arrow_cast(TIMESTAMP '2023-08-03 14:38:50Z', 'Timestamp(s)')) as ts UNION ALL -SELECT DATE_TRUNC('second', arrow_cast(TIMESTAMP '2023-08-03 14:38:50Z', 'Timestamp(Nanosecond, None)')) as ts +SELECT DATE_TRUNC('second', arrow_cast(TIMESTAMP '2023-08-03 14:38:50Z', 'Timestamp(ns)')) as ts UNION ALL -SELECT DATE_TRUNC('day', arrow_cast(TIMESTAMP '2023-08-03 14:38:50Z', 'Timestamp(Microsecond, None)')) as ts +SELECT DATE_TRUNC('day', arrow_cast(TIMESTAMP '2023-08-03 14:38:50Z', 
'Timestamp(µs)')) as ts UNION ALL -SELECT DATE_TRUNC('day', arrow_cast(TIMESTAMP '2023-08-03 14:38:50Z', 'Timestamp(Millisecond, None)')) as ts +SELECT DATE_TRUNC('day', arrow_cast(TIMESTAMP '2023-08-03 14:38:50Z', 'Timestamp(ms)')) as ts ---- 2023-08-03T00:00:00 2023-08-03T00:00:00 @@ -2376,6 +2388,59 @@ select arrow_typeof(date_trunc('microsecond', to_timestamp(61))) ---- Timestamp(ns) +########## +## date_trunc with Time types +########## + +# Truncate time to hour +query D +SELECT date_trunc('hour', TIME '14:30:45'); +---- +14:00:00 + +# Truncate time to minute +query D +SELECT date_trunc('minute', TIME '14:30:45'); +---- +14:30:00 + +# Truncate time to second (removes fractional seconds) +query D +SELECT date_trunc('second', TIME '14:30:45.123456789'); +---- +14:30:45 + +# Truncate time to millisecond +query D +SELECT date_trunc('millisecond', TIME '14:30:45.123456789'); +---- +14:30:45.123 + +# Truncate time to microsecond +query D +SELECT date_trunc('microsecond', TIME '14:30:45.123456789'); +---- +14:30:45.123456 + +# Return type should be Time64(ns) +query T +SELECT arrow_typeof(date_trunc('hour', TIME '14:30:45')); +---- +Time64(ns) + +# Error for granularities not valid for Time types +query error date_trunc does not support 'day' granularity for Time types +SELECT date_trunc('day', TIME '14:30:45'); + +query error date_trunc does not support 'week' granularity for Time types +SELECT date_trunc('week', TIME '14:30:45'); + +query error date_trunc does not support 'month' granularity for Time types +SELECT date_trunc('month', TIME '14:30:45'); + +query error date_trunc does not support 'year' granularity for Time types +SELECT date_trunc('year', TIME '14:30:45'); + # check date_bin query P SELECT date_bin(INTERVAL '1 day', time, '1970-01-01T00:00:00+05:00') FROM foo @@ -2653,7 +2718,7 @@ drop table ts_utf8_data ########## query B -select arrow_cast(now(), 'Date64') < arrow_cast('2022-02-02 02:02:02', 'Timestamp(Nanosecond, None)'); +select 
arrow_cast(now(), 'Date64') < arrow_cast('2022-02-02 02:02:02', 'Timestamp(ns)'); ---- false @@ -3011,7 +3076,7 @@ NULL query error DataFusion error: Error during planning: Function 'make_date' expects 3 arguments but received 1 select make_date(1); -query error Expect TypeSignatureClass::Native\(LogicalType\(Native\(Int32\), Int32\)\) but received NativeType::Interval\(MonthDayNano\), DataType: Interval\(MonthDayNano\) +query error DataFusion error: Error during planning: Function 'make_date' requires TypeSignatureClass::Native\(LogicalType\(Native\(Int32\), Int32\)\), but received Interval\(MonthDayNano\) \(DataType: Interval\(MonthDayNano\)\) select make_date(interval '1 day', '2001-05-21'::timestamp, '2001-05-21'::timestamp); ########## @@ -3284,7 +3349,7 @@ select make_time(22, '', 27); query error Cannot cast string '' to value of Int32 type select make_time(22, 1, ''); -query error Expect TypeSignatureClass::Native\(LogicalType\(Native\(Int32\), Int32\)\) but received NativeType::Float64, DataType: Float64 +query error DataFusion error: Error during planning: Function 'make_time' requires TypeSignatureClass::Native\(LogicalType\(Native\(Int32\), Int32\)\), but received Float64 \(DataType: Float64\) select make_time(arrow_cast(22, 'Float64'), 1, ''); ########## @@ -3587,7 +3652,7 @@ select to_char(arrow_cast(12344567890000, 'Time64(Nanosecond)'), '%H-%M-%S %f') 03-25-44 567890000 query T -select to_char(arrow_cast(TIMESTAMP '2023-08-03 14:38:50Z', 'Timestamp(Second, None)'), '%d-%m-%Y %H-%M-%S') +select to_char(arrow_cast(TIMESTAMP '2023-08-03 14:38:50Z', 'Timestamp(s)'), '%d-%m-%Y %H-%M-%S') ---- 03-08-2023 14-38-50 @@ -3611,10 +3676,10 @@ select to_char(arrow_cast(123456, 'Duration(Second)'), null); ---- NULL -query error DataFusion error: Execution error: Cast error: Format error +query error DataFusion error: Arrow error: Cast error: Format error SELECT to_char(timestamps, '%X%K') from formats; -query error DataFusion error: Execution error: Cast error: 
Format error +query error DataFusion error: Arrow error: Cast error: Format error SELECT to_char('2000-02-03'::date, '%X%K'); query T @@ -3661,6 +3726,21 @@ select to_char('2020-01-01 00:10:20.123'::timestamp at time zone 'America/New_Yo ---- 2020-01-01 00:10:20.123 +# Null values with array format +query T +SELECT to_char(column1, column2) +FROM (VALUES + (DATE '2020-09-01', '%Y-%m-%d'), + (NULL, '%Y-%m-%d'), + (DATE '2020-09-02', NULL), + (NULL, NULL) +); +---- +2020-09-01 +NULL +NULL +NULL + statement ok drop table formats; @@ -3679,7 +3759,7 @@ select to_unixtime(arrow_cast(to_timestamp('2023-01-14T01:01:30'), 'Timestamp(Se 1673638290 query I -select to_unixtime(arrow_cast(to_timestamp('2023-01-14T01:01:30'), 'Timestamp(Millisecond, None)')); +select to_unixtime(arrow_cast(to_timestamp('2023-01-14T01:01:30'), 'Timestamp(ms)')); ---- 1673658090 @@ -3899,7 +3979,7 @@ statement error select to_local_time('2024-04-01T00:00:20Z'::timestamp, 'some string'); # invalid argument data type -statement error DataFusion error: Error during planning: Internal error: Expect TypeSignatureClass::Timestamp but received NativeType::String, DataType: Utf8 +statement error DataFusion error: Error during planning: Function 'to_local_time' requires TypeSignatureClass::Timestamp, but received String \(DataType: Utf8\) select to_local_time('2024-04-01T00:00:20Z'); # invalid timezone @@ -4254,58 +4334,58 @@ SELECT CAST(CAST(one AS decimal(17,2)) AS timestamp(3)) AS a FROM (VALUES (1)) t 1970-01-01T00:00:00.001 query P -SELECT arrow_cast(CAST(1 AS decimal(17,2)), 'Timestamp(Nanosecond, None)') AS a UNION ALL -SELECT arrow_cast(CAST(one AS decimal(17,2)), 'Timestamp(Nanosecond, None)') AS a FROM (VALUES (1)) t(one); +SELECT arrow_cast(CAST(1 AS decimal(17,2)), 'Timestamp(ns)') AS a UNION ALL +SELECT arrow_cast(CAST(one AS decimal(17,2)), 'Timestamp(ns)') AS a FROM (VALUES (1)) t(one); ---- 1970-01-01T00:00:00.000000001 1970-01-01T00:00:00.000000001 query P -SELECT arrow_cast(CAST(1 AS 
decimal(17,2)), 'Timestamp(Microsecond, None)') AS a UNION ALL -SELECT arrow_cast(CAST(one AS decimal(17,2)), 'Timestamp(Microsecond, None)') AS a FROM (VALUES (1)) t(one); +SELECT arrow_cast(CAST(1 AS decimal(17,2)), 'Timestamp(µs)') AS a UNION ALL +SELECT arrow_cast(CAST(one AS decimal(17,2)), 'Timestamp(µs)') AS a FROM (VALUES (1)) t(one); ---- 1970-01-01T00:00:00.000001 1970-01-01T00:00:00.000001 query P -SELECT arrow_cast(CAST(1 AS decimal(17,2)), 'Timestamp(Millisecond, None)') AS a UNION ALL -SELECT arrow_cast(CAST(one AS decimal(17,2)), 'Timestamp(Millisecond, None)') AS a FROM (VALUES (1)) t(one); +SELECT arrow_cast(CAST(1 AS decimal(17,2)), 'Timestamp(ms)') AS a UNION ALL +SELECT arrow_cast(CAST(one AS decimal(17,2)), 'Timestamp(ms)') AS a FROM (VALUES (1)) t(one); ---- 1970-01-01T00:00:00.001 1970-01-01T00:00:00.001 query P -SELECT arrow_cast(CAST(1 AS decimal(17,2)), 'Timestamp(Second, None)') AS a UNION ALL -SELECT arrow_cast(CAST(one AS decimal(17,2)), 'Timestamp(Second, None)') AS a FROM (VALUES (1)) t(one); +SELECT arrow_cast(CAST(1 AS decimal(17,2)), 'Timestamp(s)') AS a UNION ALL +SELECT arrow_cast(CAST(one AS decimal(17,2)), 'Timestamp(s)') AS a FROM (VALUES (1)) t(one); ---- 1970-01-01T00:00:01 1970-01-01T00:00:01 query P -SELECT arrow_cast(CAST(1.123 AS decimal(17,3)), 'Timestamp(Nanosecond, None)') AS a UNION ALL -SELECT arrow_cast(CAST(one AS decimal(17,3)), 'Timestamp(Nanosecond, None)') AS a FROM (VALUES (1.123)) t(one); +SELECT arrow_cast(CAST(1.123 AS decimal(17,3)), 'Timestamp(ns)') AS a UNION ALL +SELECT arrow_cast(CAST(one AS decimal(17,3)), 'Timestamp(ns)') AS a FROM (VALUES (1.123)) t(one); ---- 1970-01-01T00:00:00.000000001 1970-01-01T00:00:00.000000001 query P -SELECT arrow_cast(CAST(1.123 AS decimal(17,3)), 'Timestamp(Microsecond, None)') AS a UNION ALL -SELECT arrow_cast(CAST(one AS decimal(17,3)), 'Timestamp(Microsecond, None)') AS a FROM (VALUES (1.123)) t(one); +SELECT arrow_cast(CAST(1.123 AS decimal(17,3)), 'Timestamp(µs)') 
AS a UNION ALL +SELECT arrow_cast(CAST(one AS decimal(17,3)), 'Timestamp(µs)') AS a FROM (VALUES (1.123)) t(one); ---- 1970-01-01T00:00:00.000001 1970-01-01T00:00:00.000001 query P -SELECT arrow_cast(CAST(1.123 AS decimal(17,3)), 'Timestamp(Millisecond, None)') AS a UNION ALL -SELECT arrow_cast(CAST(one AS decimal(17,3)), 'Timestamp(Millisecond, None)') AS a FROM (VALUES (1.123)) t(one); +SELECT arrow_cast(CAST(1.123 AS decimal(17,3)), 'Timestamp(ms)') AS a UNION ALL +SELECT arrow_cast(CAST(one AS decimal(17,3)), 'Timestamp(ms)') AS a FROM (VALUES (1.123)) t(one); ---- 1970-01-01T00:00:00.001 1970-01-01T00:00:00.001 query P -SELECT arrow_cast(CAST(1.123 AS decimal(17,3)), 'Timestamp(Second, None)') AS a UNION ALL -SELECT arrow_cast(CAST(one AS decimal(17,3)), 'Timestamp(Second, None)') AS a FROM (VALUES (1.123)) t(one); +SELECT arrow_cast(CAST(1.123 AS decimal(17,3)), 'Timestamp(s)') AS a UNION ALL +SELECT arrow_cast(CAST(one AS decimal(17,3)), 'Timestamp(s)') AS a FROM (VALUES (1.123)) t(one); ---- 1970-01-01T00:00:01 1970-01-01T00:00:01 @@ -4357,7 +4437,7 @@ FROM ts_data_micros_kolkata ## Casting between timestamp with and without timezone ########## -# Test casting from Timestamp(Nanosecond, Some("UTC")) to Timestamp(Nanosecond, None) +# Test casting from Timestamp(Nanosecond, Some("UTC")) to Timestamp(ns) # Verifies that the underlying nanosecond values are preserved when removing timezone # Verify input type @@ -4368,13 +4448,13 @@ Timestamp(ns, "UTC") # Verify output type after casting query T -SELECT arrow_typeof(arrow_cast(arrow_cast(1, 'Timestamp(Nanosecond, Some("UTC"))'), 'Timestamp(Nanosecond, None)')); +SELECT arrow_typeof(arrow_cast(arrow_cast(1, 'Timestamp(Nanosecond, Some("UTC"))'), 'Timestamp(ns)')); ---- Timestamp(ns) # Verify values are preserved when casting from timestamp with timezone to timestamp without timezone query P rowsort -SELECT arrow_cast(column1, 'Timestamp(Nanosecond, None)') +SELECT arrow_cast(column1, 'Timestamp(ns)') FROM 
(VALUES (arrow_cast(1, 'Timestamp(Nanosecond, Some("UTC"))')), (arrow_cast(2, 'Timestamp(Nanosecond, Some("UTC"))')), @@ -4389,18 +4469,18 @@ FROM (VALUES 1970-01-01T00:00:00.000000004 1970-01-01T00:00:00.000000005 -# Test casting from Timestamp(Nanosecond, None) to Timestamp(Nanosecond, Some("UTC")) +# Test casting from Timestamp(ns) to Timestamp(Nanosecond, Some("UTC")) # Verifies that the underlying nanosecond values are preserved when adding timezone # Verify input type query T -SELECT arrow_typeof(arrow_cast(1, 'Timestamp(Nanosecond, None)')); +SELECT arrow_typeof(arrow_cast(1, 'Timestamp(ns)')); ---- Timestamp(ns) # Verify output type after casting query T -SELECT arrow_typeof(arrow_cast(arrow_cast(1, 'Timestamp(Nanosecond, None)'), 'Timestamp(Nanosecond, Some("UTC"))')); +SELECT arrow_typeof(arrow_cast(arrow_cast(1, 'Timestamp(ns)'), 'Timestamp(Nanosecond, Some("UTC"))')); ---- Timestamp(ns, "UTC") @@ -4408,11 +4488,11 @@ Timestamp(ns, "UTC") query P rowsort SELECT arrow_cast(column1, 'Timestamp(Nanosecond, Some("UTC"))') FROM (VALUES - (arrow_cast(1, 'Timestamp(Nanosecond, None)')), - (arrow_cast(2, 'Timestamp(Nanosecond, None)')), - (arrow_cast(3, 'Timestamp(Nanosecond, None)')), - (arrow_cast(4, 'Timestamp(Nanosecond, None)')), - (arrow_cast(5, 'Timestamp(Nanosecond, None)')) + (arrow_cast(1, 'Timestamp(ns)')), + (arrow_cast(2, 'Timestamp(ns)')), + (arrow_cast(3, 'Timestamp(ns)')), + (arrow_cast(4, 'Timestamp(ns)')), + (arrow_cast(5, 'Timestamp(ns)')) ) t; ---- 1970-01-01T00:00:00.000000001Z @@ -4423,23 +4503,885 @@ FROM (VALUES ########## -## Common timestamp data +## to_timestamp functions with all numeric types ########## -statement ok -drop table ts_data +# Test to_timestamp with all integer types +# Int8 +query P +SELECT to_timestamp(arrow_cast(0, 'Int8')); +---- +1970-01-01T00:00:00 -statement ok -drop table ts_data_nanos +query P +SELECT to_timestamp(arrow_cast(100, 'Int8')); +---- +1970-01-01T00:01:40 -statement ok -drop table ts_data_micros +# 
Int16 +query P +SELECT to_timestamp(arrow_cast(0, 'Int16')); +---- +1970-01-01T00:00:00 -statement ok -drop table ts_data_millis +query P +SELECT to_timestamp(arrow_cast(1000, 'Int16')); +---- +1970-01-01T00:16:40 -statement ok -drop table ts_data_secs +# Int32 +query P +SELECT to_timestamp(arrow_cast(0, 'Int32')); +---- +1970-01-01T00:00:00 -statement ok -drop table ts_data_micros_kolkata +query P +SELECT to_timestamp(arrow_cast(86400, 'Int32')); +---- +1970-01-02T00:00:00 + +# Int64 +query P +SELECT to_timestamp(arrow_cast(0, 'Int64')); +---- +1970-01-01T00:00:00 + +query P +SELECT to_timestamp(arrow_cast(86400, 'Int64')); +---- +1970-01-02T00:00:00 + +# UInt8 +query P +SELECT to_timestamp(arrow_cast(0, 'UInt8')); +---- +1970-01-01T00:00:00 + +query P +SELECT to_timestamp(arrow_cast(100, 'UInt8')); +---- +1970-01-01T00:01:40 + +# UInt16 +query P +SELECT to_timestamp(arrow_cast(0, 'UInt16')); +---- +1970-01-01T00:00:00 + +query P +SELECT to_timestamp(arrow_cast(1000, 'UInt16')); +---- +1970-01-01T00:16:40 + +# UInt32 +query P +SELECT to_timestamp(arrow_cast(0, 'UInt32')); +---- +1970-01-01T00:00:00 + +query P +SELECT to_timestamp(arrow_cast(86400, 'UInt32')); +---- +1970-01-02T00:00:00 + +# UInt64 +query P +SELECT to_timestamp(arrow_cast(0, 'UInt64')); +---- +1970-01-01T00:00:00 + +query P +SELECT to_timestamp(arrow_cast(86400, 'UInt64')); +---- +1970-01-02T00:00:00 + +# Float16 +query P +SELECT to_timestamp(arrow_cast(0.0, 'Float16')); +---- +1970-01-01T00:00:00 + +query P +SELECT to_timestamp(arrow_cast(1.5, 'Float16')); +---- +1970-01-01T00:00:01.500 + +# Float32 +query P +SELECT to_timestamp(arrow_cast(0.0, 'Float32')); +---- +1970-01-01T00:00:00 + +query P +SELECT to_timestamp(arrow_cast(1.5, 'Float32')); +---- +1970-01-01T00:00:01.500 + +# Float64 +query P +SELECT to_timestamp(arrow_cast(0.0, 'Float64')); +---- +1970-01-01T00:00:00 + +query P +SELECT to_timestamp(arrow_cast(1.5, 'Float64')); +---- +1970-01-01T00:00:01.500 + +# Test to_timestamp_seconds with 
all integer types +# Int8 +query P +SELECT to_timestamp_seconds(arrow_cast(0, 'Int8')); +---- +1970-01-01T00:00:00 + +query P +SELECT to_timestamp_seconds(arrow_cast(100, 'Int8')); +---- +1970-01-01T00:01:40 + +# Int16 +query P +SELECT to_timestamp_seconds(arrow_cast(1000, 'Int16')); +---- +1970-01-01T00:16:40 + +# Int32 +query P +SELECT to_timestamp_seconds(arrow_cast(86400, 'Int32')); +---- +1970-01-02T00:00:00 + +# Int64 +query P +SELECT to_timestamp_seconds(arrow_cast(86400, 'Int64')); +---- +1970-01-02T00:00:00 + +# UInt8 +query P +SELECT to_timestamp_seconds(arrow_cast(100, 'UInt8')); +---- +1970-01-01T00:01:40 + +# UInt16 +query P +SELECT to_timestamp_seconds(arrow_cast(1000, 'UInt16')); +---- +1970-01-01T00:16:40 + +# UInt32 +query P +SELECT to_timestamp_seconds(arrow_cast(86400, 'UInt32')); +---- +1970-01-02T00:00:00 + +# UInt64 +query P +SELECT to_timestamp_seconds(arrow_cast(86400, 'UInt64')); +---- +1970-01-02T00:00:00 + +# Float16 +query P +SELECT to_timestamp_seconds(arrow_cast(1.9, 'Float16')); +---- +1970-01-01T00:00:01 + +# Float32 +query P +SELECT to_timestamp_seconds(arrow_cast(1.9, 'Float32')); +---- +1970-01-01T00:00:01 + +# Float64 +query P +SELECT to_timestamp_seconds(arrow_cast(1.9, 'Float64')); +---- +1970-01-01T00:00:01 + +# Test to_timestamp_millis with all integer types +# Int8 +query P +SELECT to_timestamp_millis(arrow_cast(0, 'Int8')); +---- +1970-01-01T00:00:00 + +query P +SELECT to_timestamp_millis(arrow_cast(100, 'Int8')); +---- +1970-01-01T00:00:00.100 + +# Int16 +query P +SELECT to_timestamp_millis(arrow_cast(1000, 'Int16')); +---- +1970-01-01T00:00:01 + +# Int32 +query P +SELECT to_timestamp_millis(arrow_cast(86400000, 'Int32')); +---- +1970-01-02T00:00:00 + +# Int64 +query P +SELECT to_timestamp_millis(arrow_cast(86400000, 'Int64')); +---- +1970-01-02T00:00:00 + +# UInt8 +query P +SELECT to_timestamp_millis(arrow_cast(100, 'UInt8')); +---- +1970-01-01T00:00:00.100 + +# UInt16 +query P +SELECT to_timestamp_millis(arrow_cast(1000, 
'UInt16')); +---- +1970-01-01T00:00:01 + +# UInt32 +query P +SELECT to_timestamp_millis(arrow_cast(86400000, 'UInt32')); +---- +1970-01-02T00:00:00 + +# UInt64 +query P +SELECT to_timestamp_millis(arrow_cast(86400000, 'UInt64')); +---- +1970-01-02T00:00:00 + +# Float16 +query P +SELECT to_timestamp_millis(arrow_cast(1000, 'Float16')); +---- +1970-01-01T00:00:01 + +# Float32 +query P +SELECT to_timestamp_millis(arrow_cast(1000.9, 'Float32')); +---- +1970-01-01T00:00:01 + +# Float64 +query P +SELECT to_timestamp_millis(arrow_cast(1000.9, 'Float64')); +---- +1970-01-01T00:00:01 + +# Test to_timestamp_micros with all integer types +# Int8 +query P +SELECT to_timestamp_micros(arrow_cast(0, 'Int8')); +---- +1970-01-01T00:00:00 + +query P +SELECT to_timestamp_micros(arrow_cast(100, 'Int8')); +---- +1970-01-01T00:00:00.000100 + +# Int16 +query P +SELECT to_timestamp_micros(arrow_cast(1000, 'Int16')); +---- +1970-01-01T00:00:00.001 + +# Int32 +query P +SELECT to_timestamp_micros(arrow_cast(1000000, 'Int32')); +---- +1970-01-01T00:00:01 + +# Int64 +query P +SELECT to_timestamp_micros(arrow_cast(86400000000, 'Int64')); +---- +1970-01-02T00:00:00 + +# UInt8 +query P +SELECT to_timestamp_micros(arrow_cast(100, 'UInt8')); +---- +1970-01-01T00:00:00.000100 + +# UInt16 +query P +SELECT to_timestamp_micros(arrow_cast(1000, 'UInt16')); +---- +1970-01-01T00:00:00.001 + +# UInt32 +query P +SELECT to_timestamp_micros(arrow_cast(1000000, 'UInt32')); +---- +1970-01-01T00:00:01 + +# UInt64 +query P +SELECT to_timestamp_micros(arrow_cast(1000000, 'UInt64')); +---- +1970-01-01T00:00:01 + +# Float16 +query P +SELECT to_timestamp_micros(arrow_cast(1000, 'Float16')); +---- +1970-01-01T00:00:00.001 + +# Float32 +query P +SELECT to_timestamp_micros(arrow_cast(1000000.9, 'Float32')); +---- +1970-01-01T00:00:01 + +# Float64 +query P +SELECT to_timestamp_micros(arrow_cast(1000000.9, 'Float64')); +---- +1970-01-01T00:00:01 + +# Test to_timestamp_nanos with all integer types +# Int8 +query P +SELECT 
to_timestamp_nanos(arrow_cast(0, 'Int8')); +---- +1970-01-01T00:00:00 + +query P +SELECT to_timestamp_nanos(arrow_cast(100, 'Int8')); +---- +1970-01-01T00:00:00.000000100 + +# Int16 +query P +SELECT to_timestamp_nanos(arrow_cast(1000, 'Int16')); +---- +1970-01-01T00:00:00.000001 + +# Int32 +query P +SELECT to_timestamp_nanos(arrow_cast(1000000000, 'Int32')); +---- +1970-01-01T00:00:01 + +# Int64 +query P +SELECT to_timestamp_nanos(arrow_cast(86400000000000, 'Int64')); +---- +1970-01-02T00:00:00 + +# UInt8 +query P +SELECT to_timestamp_nanos(arrow_cast(100, 'UInt8')); +---- +1970-01-01T00:00:00.000000100 + +# UInt16 +query P +SELECT to_timestamp_nanos(arrow_cast(1000, 'UInt16')); +---- +1970-01-01T00:00:00.000001 + +# UInt32 +query P +SELECT to_timestamp_nanos(arrow_cast(1000000000, 'UInt32')); +---- +1970-01-01T00:00:01 + +# UInt64 +query P +SELECT to_timestamp_nanos(arrow_cast(1000000000, 'UInt64')); +---- +1970-01-01T00:00:01 + +# Float16 +query P +SELECT to_timestamp_nanos(arrow_cast(1000, 'Float16')); +---- +1970-01-01T00:00:00.000001 + +# Float32 +query P +SELECT to_timestamp_nanos(arrow_cast(1000000000.9, 'Float32')); +---- +1970-01-01T00:00:01 + +# Float64 +query P +SELECT to_timestamp_nanos(arrow_cast(1000000000.9, 'Float64')); +---- +1970-01-01T00:00:01 + +# Verify arrow_typeof for all to_timestamp functions with various input types +query T +SELECT arrow_typeof(to_timestamp(arrow_cast(0, 'Int8'))); +---- +Timestamp(ns) + +query T +SELECT arrow_typeof(to_timestamp(arrow_cast(0, 'UInt64'))); +---- +Timestamp(ns) + +query T +SELECT arrow_typeof(to_timestamp(arrow_cast(0.0, 'Float32'))); +---- +Timestamp(ns) + +query T +SELECT arrow_typeof(to_timestamp_seconds(arrow_cast(0, 'Int8'))); +---- +Timestamp(s) + +query T +SELECT arrow_typeof(to_timestamp_seconds(arrow_cast(0, 'UInt64'))); +---- +Timestamp(s) + +query T +SELECT arrow_typeof(to_timestamp_seconds(arrow_cast(0.0, 'Float32'))); +---- +Timestamp(s) + +query T +SELECT 
arrow_typeof(to_timestamp_millis(arrow_cast(0, 'Int8'))); +---- +Timestamp(ms) + +query T +SELECT arrow_typeof(to_timestamp_millis(arrow_cast(0, 'UInt64'))); +---- +Timestamp(ms) + +query T +SELECT arrow_typeof(to_timestamp_millis(arrow_cast(0.0, 'Float32'))); +---- +Timestamp(ms) + +query T +SELECT arrow_typeof(to_timestamp_micros(arrow_cast(0, 'Int8'))); +---- +Timestamp(µs) + +query T +SELECT arrow_typeof(to_timestamp_micros(arrow_cast(0, 'UInt64'))); +---- +Timestamp(µs) + +query T +SELECT arrow_typeof(to_timestamp_micros(arrow_cast(0.0, 'Float32'))); +---- +Timestamp(µs) + +query T +SELECT arrow_typeof(to_timestamp_nanos(arrow_cast(0, 'Int8'))); +---- +Timestamp(ns) + +query T +SELECT arrow_typeof(to_timestamp_nanos(arrow_cast(0, 'UInt64'))); +---- +Timestamp(ns) + +query T +SELECT arrow_typeof(to_timestamp_nanos(arrow_cast(0.0, 'Float32'))); +---- +Timestamp(ns) + +# Test decimal type support for all to_timestamp functions +# Decimal32 +query P +SELECT to_timestamp(arrow_cast(1.5, 'Decimal32(5,1)')); +---- +1970-01-01T00:00:01.500 + +query P +SELECT to_timestamp_seconds(arrow_cast(86400, 'Decimal32(9,0)')); +---- +1970-01-02T00:00:00 + +query P +SELECT to_timestamp_millis(arrow_cast(1000, 'Decimal32(9,0)')); +---- +1970-01-01T00:00:01 + +query P +SELECT to_timestamp_micros(arrow_cast(1000000, 'Decimal32(9,0)')); +---- +1970-01-01T00:00:01 + +query P +SELECT to_timestamp_nanos(arrow_cast(1000000, 'Decimal32(9,0)')); +---- +1970-01-01T00:00:00.001 + +# Decimal64 +query P +SELECT to_timestamp(arrow_cast(1.5, 'Decimal64(10,1)')); +---- +1970-01-01T00:00:01.500 + +query P +SELECT to_timestamp_seconds(arrow_cast(86400, 'Decimal64(18,0)')); +---- +1970-01-02T00:00:00 + +query P +SELECT to_timestamp_millis(arrow_cast(86400000, 'Decimal64(18,0)')); +---- +1970-01-02T00:00:00 + +query P +SELECT to_timestamp_micros(arrow_cast(86400000000, 'Decimal64(18,0)')); +---- +1970-01-02T00:00:00 + +query P +SELECT to_timestamp_nanos(arrow_cast(86400000000000, 'Decimal64(18,0)')); 
+---- +1970-01-02T00:00:00 + +# Decimal128 +query P +SELECT to_timestamp(arrow_cast(1.5, 'Decimal128(10,1)')); +---- +1970-01-01T00:00:01.500 + +query P +SELECT to_timestamp_seconds(arrow_cast(86400, 'Decimal128(10,0)')); +---- +1970-01-02T00:00:00 + +query P +SELECT to_timestamp_millis(arrow_cast(86400000, 'Decimal128(15,0)')); +---- +1970-01-02T00:00:00 + +query P +SELECT to_timestamp_micros(arrow_cast(86400000000, 'Decimal128(15,0)')); +---- +1970-01-02T00:00:00 + +query P +SELECT to_timestamp_nanos(arrow_cast(86400000000000, 'Decimal128(20,0)')); +---- +1970-01-02T00:00:00 + +# Decimal256 +query P +SELECT to_timestamp(arrow_cast(1.5, 'Decimal256(10,1)')); +---- +1970-01-01T00:00:01.500 + +query P +SELECT to_timestamp_seconds(arrow_cast(86400, 'Decimal256(38,0)')); +---- +1970-01-02T00:00:00 + +query P +SELECT to_timestamp_millis(arrow_cast(86400000, 'Decimal256(38,0)')); +---- +1970-01-02T00:00:00 + +query P +SELECT to_timestamp_micros(arrow_cast(86400000000, 'Decimal256(38,0)')); +---- +1970-01-02T00:00:00 + +query P +SELECT to_timestamp_nanos(arrow_cast(86400000000000, 'Decimal256(38,0)')); +---- +1970-01-02T00:00:00 + +# Verify arrow_typeof for decimal inputs +query T +SELECT arrow_typeof(to_timestamp(arrow_cast(0, 'Decimal128(10,0)'))); +---- +Timestamp(ns) + +query T +SELECT arrow_typeof(to_timestamp_seconds(arrow_cast(0, 'Decimal128(10,0)'))); +---- +Timestamp(s) + +query T +SELECT arrow_typeof(to_timestamp_millis(arrow_cast(0, 'Decimal128(10,0)'))); +---- +Timestamp(ms) + +query T +SELECT arrow_typeof(to_timestamp_micros(arrow_cast(0, 'Decimal128(10,0)'))); +---- +Timestamp(µs) + +query T +SELECT arrow_typeof(to_timestamp_nanos(arrow_cast(0, 'Decimal128(10,0)'))); +---- +Timestamp(ns) + +# Test decimal array inputs for to_timestamp +statement ok +CREATE TABLE test_decimal_timestamps ( + d128 DECIMAL(20, 9), + d256 DECIMAL(40, 9) +) AS VALUES + (1.5, 1.5), + (86400.123456789, 86400.123456789), + (0.0, 0.0), + (NULL, NULL); + +query P +SELECT 
to_timestamp(d128) FROM test_decimal_timestamps ORDER BY d128 NULLS LAST; +---- +1970-01-01T00:00:00 +1970-01-01T00:00:01.500 +1970-01-02T00:00:00.123456789 +NULL + +query P +SELECT to_timestamp(d256) FROM test_decimal_timestamps ORDER BY d256 NULLS LAST; +---- +1970-01-01T00:00:00 +1970-01-01T00:00:01.500 +1970-01-02T00:00:00.123456789 +NULL + +statement ok +DROP TABLE test_decimal_timestamps; + +# Test negative values +# to_timestamp with negative seconds +# Int8 +query P +SELECT to_timestamp(arrow_cast(-1, 'Int8')); +---- +1969-12-31T23:59:59 + +# Int16 +query P +SELECT to_timestamp(arrow_cast(-1, 'Int16')); +---- +1969-12-31T23:59:59 + +# Int32 +query P +SELECT to_timestamp(arrow_cast(-86400, 'Int32')); +---- +1969-12-31T00:00:00 + +# Int64 +query P +SELECT to_timestamp(arrow_cast(-1, 'Int64')); +---- +1969-12-31T23:59:59 + +# Float64 +query P +SELECT to_timestamp(arrow_cast(-0.5, 'Float64')); +---- +1969-12-31T23:59:59.500 + +# to_timestamp_seconds with negative values +# Int8 +query P +SELECT to_timestamp_seconds(arrow_cast(-1, 'Int8')); +---- +1969-12-31T23:59:59 + +# Int16 +query P +SELECT to_timestamp_seconds(arrow_cast(-1, 'Int16')); +---- +1969-12-31T23:59:59 + +# Int32 +query P +SELECT to_timestamp_seconds(arrow_cast(-86400, 'Int32')); +---- +1969-12-31T00:00:00 + +# Int64 +query P +SELECT to_timestamp_seconds(arrow_cast(-1, 'Int64')); +---- +1969-12-31T23:59:59 + +# to_timestamp_millis with negative values +# Int8 +query P +SELECT to_timestamp_millis(arrow_cast(-1, 'Int8')); +---- +1969-12-31T23:59:59.999 + +# Int16 +query P +SELECT to_timestamp_millis(arrow_cast(-1, 'Int16')); +---- +1969-12-31T23:59:59.999 + +# Int32 +query P +SELECT to_timestamp_millis(arrow_cast(-1000, 'Int32')); +---- +1969-12-31T23:59:59 + +# Int64 +query P +SELECT to_timestamp_millis(arrow_cast(-1, 'Int64')); +---- +1969-12-31T23:59:59.999 + +# to_timestamp_micros with negative values +# Int8 +query P +SELECT to_timestamp_micros(arrow_cast(-1, 'Int8')); +---- 
+1969-12-31T23:59:59.999999 + +# Int16 +query P +SELECT to_timestamp_micros(arrow_cast(-1, 'Int16')); +---- +1969-12-31T23:59:59.999999 + +# Int32 +query P +SELECT to_timestamp_micros(arrow_cast(-1000000, 'Int32')); +---- +1969-12-31T23:59:59 + +# Int64 +query P +SELECT to_timestamp_micros(arrow_cast(-1, 'Int64')); +---- +1969-12-31T23:59:59.999999 + +# to_timestamp_nanos with negative values +# Int8 +query P +SELECT to_timestamp_nanos(arrow_cast(-1, 'Int8')); +---- +1969-12-31T23:59:59.999999999 + +# Int16 +query P +SELECT to_timestamp_nanos(arrow_cast(-1, 'Int16')); +---- +1969-12-31T23:59:59.999999999 + +# Int32 +query P +SELECT to_timestamp_nanos(arrow_cast(-1000000000, 'Int32')); +---- +1969-12-31T23:59:59 + +# Int64 +query P +SELECT to_timestamp_nanos(arrow_cast(-1000000000, 'Int64')); +---- +1969-12-31T23:59:59 + +query P +SELECT to_timestamp_nanos(arrow_cast(-1, 'Int64')); +---- +1969-12-31T23:59:59.999999999 + +# Test large unsigned values +query P +SELECT to_timestamp_seconds(arrow_cast(4294967295, 'UInt64')); +---- +2106-02-07T06:28:15 + +# Large UInt64 value for milliseconds +query P +SELECT to_timestamp_millis(arrow_cast(4294967295000, 'UInt64')); +---- +2106-02-07T06:28:15 + +# Test UInt64 value larger than i64::MAX (9223372036854775808 = i64::MAX + 1) +query error Cast error: Can't cast value 9223372036854775808 to type Int64 +SELECT to_timestamp_nanos(arrow_cast(9223372036854775808, 'UInt64')); + +# Test boundary values for to_timestamp +query P +SELECT to_timestamp(arrow_cast(9223372036, 'Int64')); +---- +2262-04-11T23:47:16 + +# Minimum value for to_timestamp +query P +SELECT to_timestamp(arrow_cast(-9223372036, 'Int64')); +---- +1677-09-21T00:12:44 + +# Overflow error when value exceeds valid range +query error Arithmetic overflow +SELECT to_timestamp(arrow_cast(9223372037, 'Int64')); + +# Float truncation behavior +query P +SELECT to_timestamp_seconds(arrow_cast(-1.9, 'Float64')); +---- +1969-12-31T23:59:59 + +query P +SELECT 
to_timestamp_millis(arrow_cast(-1.9, 'Float64')); +---- +1969-12-31T23:59:59.999 + + +########## +## Common timestamp data +########## + +statement ok +drop table ts_data + +statement ok +drop table ts_data_nanos + +statement ok +drop table ts_data_micros + +statement ok +drop table ts_data_millis + +statement ok +drop table ts_data_secs + +statement ok +drop table ts_data_micros_kolkata + +########## +## Test to_timestamp with scalar float inputs +########## + +statement ok +create table test_to_timestamp_scalar(id int, name varchar) as values + (1, 'foo'), + (2, 'bar'); + +query P +SELECT to_timestamp(123.5, name) FROM test_to_timestamp_scalar ORDER BY id; +---- +1970-01-01T00:02:03.500 +1970-01-01T00:02:03.500 + +query P +SELECT to_timestamp(456.789::float, name) FROM test_to_timestamp_scalar ORDER BY id; +---- +1970-01-01T00:07:36.789001464 +1970-01-01T00:07:36.789001464 + +query P +SELECT to_timestamp(arrow_cast(100.5, 'Float16'), name) FROM test_to_timestamp_scalar ORDER BY id; +---- +1970-01-01T00:01:40.500 +1970-01-01T00:01:40.500 + +statement ok +drop table test_to_timestamp_scalar diff --git a/datafusion/sqllogictest/test_files/decimal.slt b/datafusion/sqllogictest/test_files/decimal.slt index 9dd31427dcb4..eca2c88bb5f8 100644 --- a/datafusion/sqllogictest/test_files/decimal.slt +++ b/datafusion/sqllogictest/test_files/decimal.slt @@ -782,7 +782,7 @@ query TR select arrow_typeof(round(173975140545.855, 2)), round(173975140545.855, 2); ---- -Decimal128(15, 3) 173975140545.86 +Decimal128(15, 2) 173975140545.86 # smoke test for decimal parsing query RT @@ -868,9 +868,11 @@ select log(100000000000000000000000000000000000::decimal(76,0)); ---- 35 -# log(10^50) for decimal256 for a value larger than i128 -query error Arrow error: Not yet implemented: Log of Decimal256 larger than Decimal128 is not yet supported +# log(10^50) for decimal256 for a value larger than i128 (uses f64 fallback) +query R select 
log(100000000000000000000000000000000000000000000000000::decimal(76,0)); +---- +50 # log(10^35) for decimal128 with explicit base query R @@ -904,6 +906,12 @@ select log(2.0, 100000000000000000000000000000000000::decimal(38,0)); ---- 116 +# log with non-integer base (fallback to f64) +query R +select log(2.5, 100::decimal(38,0)); +---- +5.025883189464 + # null cases query R select log(null, 100); @@ -1087,8 +1095,17 @@ SELECT power(2, 100000000000) ---- Infinity -query error Arrow error: Arithmetic overflow: Unsupported exp value -SELECT power(2::decimal(38, 0), -5) +# Negative exponent now works (fallback to f64) +query RT +SELECT power(2::decimal(38, 0), -5), arrow_typeof(power(2::decimal(38, 0), -5)); +---- +0 Decimal128(38, 0) + +# Negative exponent with scale preserves decimal places +query RT +SELECT power(4::decimal(38, 5), -1), arrow_typeof(power(4::decimal(38, 5), -1)); +---- +0.25 Decimal128(38, 5) # Expected to have `16 Decimal128(38, 0)` # Due to type coericion, it becomes Float -> Float -> Float @@ -1108,20 +1125,23 @@ SELECT power(2.5, 4.0), arrow_typeof(power(2.5, 4.0)); ---- 39 Decimal128(2, 1) -query error Compute error: Cannot use non-integer exp +# Non-integer exponent now works (fallback to f64) +query RT SELECT power(2.5, 4.2), arrow_typeof(power(2.5, 4.2)); +---- +46.9 Decimal128(2, 1) -query error Compute error: Cannot use non-integer exp: NaN +query error Compute error: Cannot use non-finite exp: NaN SELECT power(2::decimal(38, 0), arrow_cast('NaN','Float64')) -query error Compute error: Cannot use non-integer exp: inf +query error Compute error: Cannot use non-finite exp: inf SELECT power(2::decimal(38, 0), arrow_cast('INF','Float64')) -# Floating above u32::max -query error Compute error: Cannot use non-integer exp +# Floating above u32::max now works (fallback to f64, returns infinity which is an error) +query error Arrow error: Arithmetic overflow: Result of 2\^5000000000.1 is not finite SELECT power(2::decimal(38, 0), 5000000000.1) -# 
Integer Above u32::max +# Integer Above u32::max - still goes through integer path which fails query error Arrow error: Arithmetic overflow: Unsupported exp value SELECT power(2::decimal(38, 0), 5000000000) diff --git a/datafusion/sqllogictest/test_files/delete.slt b/datafusion/sqllogictest/test_files/delete.slt index e86343b6bf5f..b01eb6f5e9ec 100644 --- a/datafusion/sqllogictest/test_files/delete.slt +++ b/datafusion/sqllogictest/test_files/delete.slt @@ -113,3 +113,30 @@ logical_plan 05)--------TableScan: t2 06)----TableScan: t1 physical_plan_error This feature is not implemented: Physical plan does not support logical expression InSubquery(InSubquery { expr: Column(Column { relation: Some(Bare { table: "t1" }), name: "a" }), subquery: , negated: false }) + + +# Delete with limit + +query TT +explain delete from t1 limit 10 +---- +logical_plan +01)Dml: op=[Delete] table=[t1] +02)--Limit: skip=0, fetch=10 +03)----TableScan: t1 +physical_plan +01)CooperativeExec +02)--DmlResultExec: rows_affected=0 + + +query TT +explain delete from t1 where a = 1 and b = '2' limit 10 +---- +logical_plan +01)Dml: op=[Delete] table=[t1] +02)--Limit: skip=0, fetch=10 +03)----Filter: CAST(t1.a AS Int64) = Int64(1) AND t1.b = CAST(Utf8("2") AS Utf8View) +04)------TableScan: t1 +physical_plan +01)CooperativeExec +02)--DmlResultExec: rows_affected=0 diff --git a/datafusion/sqllogictest/test_files/dictionary.slt b/datafusion/sqllogictest/test_files/dictionary.slt index b6098758a9e6..511061cf82f0 100644 --- a/datafusion/sqllogictest/test_files/dictionary.slt +++ b/datafusion/sqllogictest/test_files/dictionary.slt @@ -36,7 +36,7 @@ SELECT arrow_cast(column3, 'Utf8') as f2, arrow_cast(column4, 'Utf8') as f3, arrow_cast(column5, 'Float64') as f4, - arrow_cast(column6, 'Timestamp(Nanosecond, None)') as time + arrow_cast(column6, 'Timestamp(ns)') as time FROM ( VALUES -- equivalent to the following line protocol data @@ -111,7 +111,7 @@ SELECT arrow_cast(column1, 'Dictionary(Int32, Utf8)') as 
type, arrow_cast(column2, 'Dictionary(Int32, Utf8)') as tag_id, arrow_cast(column3, 'Float64') as f5, - arrow_cast(column4, 'Timestamp(Nanosecond, None)') as time + arrow_cast(column4, 'Timestamp(ns)') as time FROM ( VALUES -- equivalent to the following line protocol data diff --git a/datafusion/sqllogictest/test_files/dynamic_filter_pushdown_config.slt b/datafusion/sqllogictest/test_files/dynamic_filter_pushdown_config.slt index 3e403171e071..cbf9f81e425f 100644 --- a/datafusion/sqllogictest/test_files/dynamic_filter_pushdown_config.slt +++ b/datafusion/sqllogictest/test_files/dynamic_filter_pushdown_config.slt @@ -92,6 +92,30 @@ physical_plan 01)SortExec: TopK(fetch=3), expr=[value@1 DESC], preserve_partitioning=[false] 02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/test_data.parquet]]}, projection=[id, value, name], file_type=parquet, predicate=DynamicFilter [ empty ] +statement ok +set datafusion.explain.analyze_level = summary; + +query TT +EXPLAIN ANALYZE SELECT id, value AS v, value + id as name FROM test_parquet where value > 3 ORDER BY v DESC LIMIT 3; +---- +Plan with Metrics +01)SortPreservingMergeExec: [v@1 DESC], fetch=3, metrics=[output_rows=3, ] +02)--SortExec: TopK(fetch=3), expr=[v@1 DESC], preserve_partitioning=[true], filter=[v@1 IS NULL OR v@1 > 800], metrics=[output_rows=3, ] +03)----ProjectionExec: expr=[id@0 as id, value@1 as v, value@1 + id@0 as name], metrics=[output_rows=10, ] +04)------FilterExec: value@1 > 3, metrics=[output_rows=10, , selectivity=100% (10/10)] +05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1, metrics=[output_rows=10, ] +06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/test_data.parquet]]}, projection=[id, value], file_type=parquet, predicate=value@1 > 3 AND DynamicFilter [ value@1 IS NULL OR value@1 > 800 ], 
pruning_predicate=value_null_count@1 != row_count@2 AND value_max@0 > 3 AND (value_null_count@1 > 0 OR value_null_count@1 != row_count@2 AND value_max@0 > 800), required_guarantees=[], metrics=[output_rows=10, elapsed_compute=1ns, output_bytes=80.0 B, files_ranges_pruned_statistics=1 total → 1 matched, row_groups_pruned_statistics=1 total → 1 matched -> 1 fully matched, row_groups_pruned_bloom_filter=1 total → 1 matched, page_index_pages_pruned=1 total → 1 matched, limit_pruned_row_groups=0 total → 0 matched, bytes_scanned=210, metadata_load_time=, scan_efficiency_ratio=18% (210/1.15 K)] + +statement ok +set datafusion.explain.analyze_level = dev; + +query III +SELECT id, value AS v, value + id as name FROM test_parquet where value > 3 ORDER BY v DESC LIMIT 3; +---- +10 1000 1010 +9 900 909 +8 800 808 + # Disable TopK dynamic filter pushdown statement ok SET datafusion.optimizer.enable_topk_dynamic_filter_pushdown = false; @@ -106,6 +130,13 @@ physical_plan 01)SortExec: TopK(fetch=3), expr=[value@1 DESC], preserve_partitioning=[false] 02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/test_data.parquet]]}, projection=[id, value, name], file_type=parquet +query IIT +SELECT id, value AS v, name FROM (SELECT * FROM test_parquet UNION ALL SELECT * FROM test_parquet) ORDER BY v DESC LIMIT 3; +---- +10 1000 j +10 1000 j +9 900 i + # Re-enable for next tests statement ok SET datafusion.optimizer.enable_topk_dynamic_filter_pushdown = true; @@ -156,6 +187,197 @@ physical_plan statement ok SET datafusion.optimizer.enable_join_dynamic_filter_pushdown = true; +# Test 2b: Dynamic filter pushdown for non-inner join types +# LEFT JOIN: optimizer swaps to physical Right join (build=right_parquet, probe=left_parquet). +# Dynamic filter is NOT pushed because Right join needs all probe rows in output. 
+query TT +EXPLAIN SELECT l.*, r.info +FROM left_parquet l +LEFT JOIN right_parquet r ON l.id = r.id; +---- +logical_plan +01)Projection: l.id, l.data, r.info +02)--Left Join: l.id = r.id +03)----SubqueryAlias: l +04)------TableScan: left_parquet projection=[id, data] +05)----SubqueryAlias: r +06)------TableScan: right_parquet projection=[id, info] +physical_plan +01)ProjectionExec: expr=[id@1 as id, data@2 as data, info@0 as info] +02)--HashJoinExec: mode=CollectLeft, join_type=Right, on=[(id@0, id@0)], projection=[info@1, id@2, data@3] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/join_right.parquet]]}, projection=[id, info], file_type=parquet +04)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/join_left.parquet]]}, projection=[id, data], file_type=parquet + +# LEFT JOIN correctness: all left rows appear, unmatched right rows produce NULLs +query ITT +SELECT l.id, l.data, r.info +FROM left_parquet l +LEFT JOIN right_parquet r ON l.id = r.id +ORDER BY l.id; +---- +1 left1 right1 +2 left2 NULL +3 left3 right3 +4 left4 NULL +5 left5 right5 + +# RIGHT JOIN: optimizer swaps to physical Left join (build=right_parquet, probe=left_parquet). +# No self-generated dynamic filter (only Inner joins get that), but parent filters +# on the preserved (build) side can still push down. 
+query TT +EXPLAIN SELECT l.*, r.info +FROM left_parquet l +RIGHT JOIN right_parquet r ON l.id = r.id; +---- +logical_plan +01)Projection: l.id, l.data, r.info +02)--Right Join: l.id = r.id +03)----SubqueryAlias: l +04)------TableScan: left_parquet projection=[id, data] +05)----SubqueryAlias: r +06)------TableScan: right_parquet projection=[id, info] +physical_plan +01)ProjectionExec: expr=[id@1 as id, data@2 as data, info@0 as info] +02)--HashJoinExec: mode=CollectLeft, join_type=Left, on=[(id@0, id@0)], projection=[info@1, id@2, data@3] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/join_right.parquet]]}, projection=[id, info], file_type=parquet +04)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/join_left.parquet]]}, projection=[id, data], file_type=parquet + +# RIGHT JOIN correctness: all right rows appear, unmatched left rows produce NULLs +query ITT +SELECT l.id, l.data, r.info +FROM left_parquet l +RIGHT JOIN right_parquet r ON l.id = r.id +ORDER BY r.id; +---- +1 left1 right1 +3 left3 right3 +5 left5 right5 + +# FULL JOIN: dynamic filter should NOT be pushed (both sides must preserve all rows) +query TT +EXPLAIN SELECT l.id, r.id as rid, l.data, r.info +FROM left_parquet l +FULL JOIN right_parquet r ON l.id = r.id; +---- +logical_plan +01)Projection: l.id, r.id AS rid, l.data, r.info +02)--Full Join: l.id = r.id +03)----SubqueryAlias: l +04)------TableScan: left_parquet projection=[id, data] +05)----SubqueryAlias: r +06)------TableScan: right_parquet projection=[id, info] +physical_plan +01)ProjectionExec: expr=[id@2 as id, id@0 as rid, data@3 as data, info@1 as info] +02)--HashJoinExec: mode=CollectLeft, join_type=Full, on=[(id@0, id@0)] +03)----DataSourceExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/join_right.parquet]]}, projection=[id, info], file_type=parquet +04)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/join_left.parquet]]}, projection=[id, data], file_type=parquet + +# LEFT SEMI JOIN: optimizer swaps to RightSemi (build=right_parquet, probe=left_parquet). +# No self-generated dynamic filter (only Inner joins), but parent filters on +# the preserved (probe) side can push down. +query TT +EXPLAIN SELECT l.* +FROM left_parquet l +WHERE l.id IN (SELECT r.id FROM right_parquet r); +---- +logical_plan +01)LeftSemi Join: l.id = __correlated_sq_1.id +02)--SubqueryAlias: l +03)----TableScan: left_parquet projection=[id, data] +04)--SubqueryAlias: __correlated_sq_1 +05)----SubqueryAlias: r +06)------TableScan: right_parquet projection=[id] +physical_plan +01)HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(id@0, id@0)] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/join_right.parquet]]}, projection=[id], file_type=parquet +03)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/join_left.parquet]]}, projection=[id, data], file_type=parquet + +# LEFT ANTI JOIN: no self-generated dynamic filter, but parent filters can push +# to the preserved (left/build) side. 
+query TT +EXPLAIN SELECT l.* +FROM left_parquet l +WHERE l.id NOT IN (SELECT r.id FROM right_parquet r); +---- +logical_plan +01)LeftAnti Join: l.id = __correlated_sq_1.id +02)--SubqueryAlias: l +03)----TableScan: left_parquet projection=[id, data] +04)--SubqueryAlias: __correlated_sq_1 +05)----SubqueryAlias: r +06)------TableScan: right_parquet projection=[id] +physical_plan +01)HashJoinExec: mode=CollectLeft, join_type=LeftAnti, on=[(id@0, id@0)] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/join_left.parquet]]}, projection=[id, data], file_type=parquet +03)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/join_right.parquet]]}, projection=[id], file_type=parquet + +# Test 2c: Parent dynamic filter (from TopK) pushed through semi/anti joins +# Sort on the join key (id) so the TopK dynamic filter pushes to BOTH sides. + +# SEMI JOIN with TopK parent: TopK generates a dynamic filter on `id` (join key) +# that pushes through the RightSemi join to both the build and probe sides. 
+query TT +EXPLAIN SELECT l.* +FROM left_parquet l +WHERE l.id IN (SELECT r.id FROM right_parquet r) +ORDER BY l.id LIMIT 2; +---- +logical_plan +01)Sort: l.id ASC NULLS LAST, fetch=2 +02)--LeftSemi Join: l.id = __correlated_sq_1.id +03)----SubqueryAlias: l +04)------TableScan: left_parquet projection=[id, data] +05)----SubqueryAlias: __correlated_sq_1 +06)------SubqueryAlias: r +07)--------TableScan: right_parquet projection=[id] +physical_plan +01)SortExec: TopK(fetch=2), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(id@0, id@0)] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/join_right.parquet]]}, projection=[id], file_type=parquet, predicate=DynamicFilter [ empty ] +04)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/join_left.parquet]]}, projection=[id, data], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Correctness check +query IT +SELECT l.* +FROM left_parquet l +WHERE l.id IN (SELECT r.id FROM right_parquet r) +ORDER BY l.id LIMIT 2; +---- +1 left1 +3 left3 + +# ANTI JOIN with TopK parent: TopK generates a dynamic filter on `id` (join key) +# that pushes through the LeftAnti join to both the preserved and non-preserved sides. 
+query TT +EXPLAIN SELECT l.* +FROM left_parquet l +WHERE l.id NOT IN (SELECT r.id FROM right_parquet r) +ORDER BY l.id LIMIT 2; +---- +logical_plan +01)Sort: l.id ASC NULLS LAST, fetch=2 +02)--LeftAnti Join: l.id = __correlated_sq_1.id +03)----SubqueryAlias: l +04)------TableScan: left_parquet projection=[id, data] +05)----SubqueryAlias: __correlated_sq_1 +06)------SubqueryAlias: r +07)--------TableScan: right_parquet projection=[id] +physical_plan +01)SortExec: TopK(fetch=2), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--HashJoinExec: mode=CollectLeft, join_type=LeftAnti, on=[(id@0, id@0)] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/join_left.parquet]]}, projection=[id, data], file_type=parquet, predicate=DynamicFilter [ empty ] +04)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/join_right.parquet]]}, projection=[id], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Correctness check +query IT +SELECT l.* +FROM left_parquet l +WHERE l.id NOT IN (SELECT r.id FROM right_parquet r) +ORDER BY l.id LIMIT 2; +---- +2 left2 +4 left4 + # Test 3: Test independent control # Disable TopK, keep Join enabled @@ -257,6 +479,25 @@ physical_plan 04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/agg_data.parquet]]}, projection=[score], file_type=parquet, predicate=category@0 = alpha AND DynamicFilter [ empty ], pruning_predicate=category_null_count@2 != row_count@3 AND category_min@0 <= alpha AND alpha <= category_max@1, required_guarantees=[category in (alpha)] +# Test 4b: COUNT + MAX — DynamicFilter should NOT appear here in mixed aggregates + +query TT +EXPLAIN SELECT COUNT(*), MAX(score) FROM agg_parquet 
WHERE category = 'alpha'; +---- +logical_plan +01)Projection: count(Int64(1)) AS count(*), max(agg_parquet.score) +02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1)), max(agg_parquet.score)]] +03)----Projection: agg_parquet.score +04)------Filter: agg_parquet.category = Utf8View("alpha") +05)--------TableScan: agg_parquet projection=[category, score], partial_filters=[agg_parquet.category = Utf8View("alpha")] +physical_plan +01)ProjectionExec: expr=[count(Int64(1))@0 as count(*), max(agg_parquet.score)@1 as max(agg_parquet.score)] +02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1)), max(agg_parquet.score)] +03)----CoalescePartitionsExec +04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1)), max(agg_parquet.score)] +05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/agg_data.parquet]]}, projection=[score], file_type=parquet, predicate=category@0 = alpha, pruning_predicate=category_null_count@2 != row_count@3 AND category_min@0 <= alpha AND alpha <= category_max@1, required_guarantees=[category in (alpha)] + # Disable aggregate dynamic filters only statement ok SET datafusion.optimizer.enable_aggregate_dynamic_filter_pushdown = false; @@ -388,6 +629,97 @@ physical_plan 03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/join_right.parquet]]}, projection=[id, info], file_type=parquet 04)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/join_left.parquet]]}, projection=[id, data], file_type=parquet, predicate=DynamicFilter [ empty ] +# Test 6: Regression test for issue #20213 - dynamic filter applied to wrong table +# when subquery join has same column names on both sides. 
+# +# The bug: when an outer join pushes a DynamicFilter for column "k" through an +# inner join where both sides have a column named "k", the name-based routing +# incorrectly pushed the filter to BOTH sides instead of only the correct one. +# This caused wrong results (0 rows instead of expected). + +# Create tables with same column names (k, v) on both sides +statement ok +CREATE TABLE issue_20213_t1(k INT, v INT) AS +SELECT i as k, i as v FROM generate_series(1, 1000) t(i); + +statement ok +CREATE TABLE issue_20213_t2(k INT, v INT) AS +SELECT i + 100 as k, i as v FROM generate_series(1, 100) t(i); + +# Use small row groups to make statistics-based pruning more likely to manifest the bug +statement ok +SET datafusion.execution.parquet.max_row_group_size = 10; + +query I +COPY issue_20213_t1 TO 'test_files/scratch/dynamic_filter_pushdown_config/issue_20213_t1.parquet' STORED AS PARQUET; +---- +1000 + +query I +COPY issue_20213_t2 TO 'test_files/scratch/dynamic_filter_pushdown_config/issue_20213_t2.parquet' STORED AS PARQUET; +---- +100 + +# Reset row group size +statement ok +SET datafusion.execution.parquet.max_row_group_size = 1000000; + +statement ok +CREATE EXTERNAL TABLE t1_20213(k INT, v INT) +STORED AS PARQUET +LOCATION 'test_files/scratch/dynamic_filter_pushdown_config/issue_20213_t1.parquet'; + +statement ok +CREATE EXTERNAL TABLE t2_20213(k INT, v INT) +STORED AS PARQUET +LOCATION 'test_files/scratch/dynamic_filter_pushdown_config/issue_20213_t2.parquet'; + +# The query from issue #20213: subquery joins t1 and t2 on v, then outer +# join uses t2's k column. The dynamic filter on k from the outer join +# must only apply to t2 (k range 101-200), NOT to t1 (k range 1-1000). 
+query I +SELECT count(*) FROM ( + SELECT t2_20213.k as k, t1_20213.k as k2 + FROM t1_20213 + JOIN t2_20213 ON t1_20213.v = t2_20213.v +) a +JOIN t2_20213 b ON a.k = b.k +WHERE b.v < 10; +---- +9 + +# Also verify with SELECT * to catch row-level correctness +query IIII rowsort +SELECT * FROM ( + SELECT t2_20213.k as k, t1_20213.k as k2 + FROM t1_20213 + JOIN t2_20213 ON t1_20213.v = t2_20213.v +) a +JOIN t2_20213 b ON a.k = b.k +WHERE b.v < 10; +---- +101 1 101 1 +102 2 102 2 +103 3 103 3 +104 4 104 4 +105 5 105 5 +106 6 106 6 +107 7 107 7 +108 8 108 8 +109 9 109 9 + +statement ok +DROP TABLE issue_20213_t1; + +statement ok +DROP TABLE issue_20213_t2; + +statement ok +DROP TABLE t1_20213; + +statement ok +DROP TABLE t2_20213; + # Cleanup statement ok diff --git a/datafusion/sqllogictest/test_files/encoding.slt b/datafusion/sqllogictest/test_files/encoding.slt index ef91eade01e5..b04d5061825b 100644 --- a/datafusion/sqllogictest/test_files/encoding.slt +++ b/datafusion/sqllogictest/test_files/encoding.slt @@ -20,21 +20,41 @@ SELECT encode(arrow_cast('tom', 'Utf8View'),'base64'); ---- dG9t +query T +SELECT encode(arrow_cast('tommy', 'Utf8View'),'base64pad'); +---- +dG9tbXk= + query T SELECT arrow_cast(decode(arrow_cast('dG9t', 'Utf8View'),'base64'), 'Utf8'); ---- tom +query T +SELECT arrow_cast(decode(arrow_cast('dG9tbXk=', 'Utf8View'),'base64pad'), 'Utf8'); +---- +tommy + query T SELECT encode(arrow_cast('tom', 'BinaryView'),'base64'); ---- dG9t +query T +SELECT encode(arrow_cast('tommy', 'BinaryView'),'base64pad'); +---- +dG9tbXk= + query T SELECT arrow_cast(decode(arrow_cast('dG9t', 'BinaryView'),'base64'), 'Utf8'); ---- tom +query T +SELECT arrow_cast(decode(arrow_cast('dG9tbXk=', 'BinaryView'),'base64pad'), 'Utf8'); +---- +tommy + # test for hex digest query T select encode(digest('hello', 'sha256'), 'hex'); @@ -55,16 +75,16 @@ CREATE TABLE test( ; # errors -query error DataFusion error: Error during planning: Internal error: Expect TypeSignatureClass::Binary 
but received NativeType::Int64, DataType: Int64 +query error DataFusion error: Error during planning: Function 'encode' requires TypeSignatureClass::Binary, but received Int64 \(DataType: Int64\) select encode(12, 'hex'); -query error DataFusion error: Error during planning: Internal error: Expect TypeSignatureClass::Binary but received NativeType::Int64, DataType: Int64 +query error DataFusion error: Error during planning: Function 'decode' requires TypeSignatureClass::Binary, but received Int64 \(DataType: Int64\) select decode(12, 'hex'); -query error DataFusion error: Error during planning: There is no built\-in encoding named 'non_encoding', currently supported encodings are: base64, hex +query error DataFusion error: Error during planning: There is no built\-in encoding named 'non_encoding', currently supported encodings are: base64, base64pad, hex select encode('', 'non_encoding'); -query error DataFusion error: Error during planning: There is no built\-in encoding named 'non_encoding', currently supported encodings are: base64, hex +query error DataFusion error: Error during planning: There is no built\-in encoding named 'non_encoding', currently supported encodings are: base64, base64pad, hex select decode('', 'non_encoding'); query error DataFusion error: Execution error: Encoding must be a non-null string @@ -73,7 +93,7 @@ select decode('', null) from test; query error DataFusion error: This feature is not implemented: Encoding must be a scalar; array specified encoding is not yet supported select decode('', hex_field) from test; -query error DataFusion error: Error during planning: Internal error: Expect TypeSignatureClass::Integer but received NativeType::String, DataType: Utf8View +query error DataFusion error: Error during planning: Function 'to_hex' requires TypeSignatureClass::Integer, but received String \(DataType: Utf8View\) select to_hex(hex_field) from test; query error DataFusion error: Execution error: Failed to decode value using base64 @@ 
-124,11 +144,21 @@ select encode(bin_field, 'base64') FROM test WHERE num = 3; ---- j1DT9g6uNw3b+FyGIZxVEIo1AWU +query T +select encode(bin_field, 'base64pad') FROM test WHERE num = 3; +---- +j1DT9g6uNw3b+FyGIZxVEIo1AWU= + query B select decode(encode(bin_field, 'base64'), 'base64') = X'8f50d3f60eae370ddbf85c86219c55108a350165' FROM test WHERE num = 3; ---- true +query B +select decode(encode(bin_field, 'base64pad'), 'base64pad') = X'8f50d3f60eae370ddbf85c86219c55108a350165' FROM test WHERE num = 3; +---- +true + statement ok drop table test @@ -144,18 +174,20 @@ FROM VALUES ('Raphael', 'R'), (NULL, 'R'); -query TTTT +query TTTTTT SELECT encode(column1_utf8view, 'base64') AS column1_base64, + encode(column1_utf8view, 'base64pad') AS column1_base64pad, encode(column1_utf8view, 'hex') AS column1_hex, encode(column2_utf8view, 'base64') AS column2_base64, + encode(column2_utf8view, 'base64pad') AS column2_base64pad, encode(column2_utf8view, 'hex') AS column2_hex FROM test_utf8view; ---- -QW5kcmV3 416e64726577 WA 58 -WGlhbmdwZW5n 5869616e6770656e67 WGlhbmdwZW5n 5869616e6770656e67 -UmFwaGFlbA 5261706861656c Ug 52 -NULL NULL Ug 52 +QW5kcmV3 QW5kcmV3 416e64726577 WA WA== 58 +WGlhbmdwZW5n WGlhbmdwZW5n 5869616e6770656e67 WGlhbmdwZW5n WGlhbmdwZW5n 5869616e6770656e67 +UmFwaGFlbA UmFwaGFlbA== 5261706861656c Ug Ug== 52 +NULL NULL NULL Ug Ug== 52 query TTTTTT SELECT @@ -172,6 +204,22 @@ WGlhbmdwZW5n WGlhbmdwZW5n WGlhbmdwZW5n WGlhbmdwZW5n WGlhbmdwZW5n WGlhbmdwZW5n UmFwaGFlbA UmFwaGFlbA UmFwaGFlbA UmFwaGFlbA UmFwaGFlbA UmFwaGFlbA NULL NULL NULL NULL NULL NULL + +query TTTTTT +SELECT + encode(arrow_cast(column1_utf8view, 'Utf8'), 'base64pad'), + encode(arrow_cast(column1_utf8view, 'LargeUtf8'), 'base64pad'), + encode(arrow_cast(column1_utf8view, 'Utf8View'), 'base64pad'), + encode(arrow_cast(column1_utf8view, 'Binary'), 'base64pad'), + encode(arrow_cast(column1_utf8view, 'LargeBinary'), 'base64pad'), + encode(arrow_cast(column1_utf8view, 'BinaryView'), 'base64pad') +FROM 
test_utf8view; +---- +QW5kcmV3 QW5kcmV3 QW5kcmV3 QW5kcmV3 QW5kcmV3 QW5kcmV3 +WGlhbmdwZW5n WGlhbmdwZW5n WGlhbmdwZW5n WGlhbmdwZW5n WGlhbmdwZW5n WGlhbmdwZW5n +UmFwaGFlbA== UmFwaGFlbA== UmFwaGFlbA== UmFwaGFlbA== UmFwaGFlbA== UmFwaGFlbA== +NULL NULL NULL NULL NULL NULL + statement ok drop table test_utf8view @@ -180,26 +228,31 @@ statement ok CREATE TABLE test_fsb AS SELECT arrow_cast(X'0123456789ABCDEF', 'FixedSizeBinary(8)') as fsb_col; -query ?? +query ??? SELECT decode(encode(arrow_cast(X'0123456789abcdef', 'FixedSizeBinary(8)'), 'base64'), 'base64'), + decode(encode(arrow_cast(X'0123456789abcdef', 'FixedSizeBinary(8)'), 'base64pad'), 'base64pad'), decode(encode(arrow_cast(X'0123456789abcdef', 'FixedSizeBinary(8)'), 'hex'), 'hex'); ---- -0123456789abcdef 0123456789abcdef +0123456789abcdef 0123456789abcdef 0123456789abcdef -query ?? +query ??? SELECT decode(encode(column1, 'base64'), 'base64'), + decode(encode(column1, 'base64pad'), 'base64pad'), decode(encode(column1, 'hex'), 'hex') FROM values (arrow_cast(X'0123456789abcdef', 'FixedSizeBinary(8)')), (arrow_cast(X'ffffffffffffffff', 'FixedSizeBinary(8)')); ---- -0123456789abcdef 0123456789abcdef -ffffffffffffffff ffffffffffffffff +0123456789abcdef 0123456789abcdef 0123456789abcdef +ffffffffffffffff ffffffffffffffff ffffffffffffffff query error DataFusion error: Execution error: Failed to decode value using base64 select decode('invalid', 'base64'); +query error DataFusion error: Execution error: Failed to decode value using base64pad +select decode('invalid', 'base64pad'); + query error DataFusion error: Execution error: Failed to decode value using hex select decode('invalid', 'hex'); diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 9087aee56d97..c5907d497500 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -176,6 +176,7 @@ initial_logical_plan logical_plan after 
resolve_grouping_function SAME TEXT AS ABOVE logical_plan after type_coercion SAME TEXT AS ABOVE analyzed_logical_plan SAME TEXT AS ABOVE +logical_plan after rewrite_set_comparison SAME TEXT AS ABOVE logical_plan after optimize_unions SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE @@ -196,7 +197,10 @@ logical_plan after push_down_filter SAME TEXT AS ABOVE logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE +logical_plan after extract_leaf_expressions SAME TEXT AS ABOVE +logical_plan after push_down_leaf_projections SAME TEXT AS ABOVE logical_plan after optimize_projections TableScan: simple_explain_test projection=[a, b, c] +logical_plan after rewrite_set_comparison SAME TEXT AS ABOVE logical_plan after optimize_unions SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE @@ -217,6 +221,8 @@ logical_plan after push_down_filter SAME TEXT AS ABOVE logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE +logical_plan after extract_leaf_expressions SAME TEXT AS ABOVE +logical_plan after push_down_leaf_projections SAME TEXT AS ABOVE logical_plan after optimize_projections SAME TEXT AS ABOVE logical_plan TableScan: simple_explain_test projection=[a, b, c] initial_physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], file_type=csv, has_header=true @@ -234,7 +240,6 @@ physical_plan after CombinePartialFinalAggregate SAME TEXT AS ABOVE physical_plan after EnforceSorting SAME TEXT AS ABOVE 
physical_plan after OptimizeAggregateOrder SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE -physical_plan after coalesce_batches SAME TEXT AS ABOVE physical_plan after OutputRequirements DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], file_type=csv, has_header=true physical_plan after LimitAggregation SAME TEXT AS ABOVE physical_plan after LimitPushPastWindows SAME TEXT AS ABOVE @@ -298,8 +303,8 @@ initial_physical_plan 01)GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]: ScanBytes=Exact(32)),(Col[1]: ScanBytes=Inexact(24)),(Col[2]: ScanBytes=Exact(32)),(Col[3]: ScanBytes=Exact(32)),(Col[4]: ScanBytes=Exact(32)),(Col[5]: ScanBytes=Exact(64)),(Col[6]: ScanBytes=Exact(32)),(Col[7]: ScanBytes=Exact(64)),(Col[8]: ScanBytes=Inexact(88)),(Col[9]: ScanBytes=Inexact(49)),(Col[10]: ScanBytes=Exact(64))]] 02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]: ScanBytes=Exact(32)),(Col[1]: ScanBytes=Inexact(24)),(Col[2]: ScanBytes=Exact(32)),(Col[3]: ScanBytes=Exact(32)),(Col[4]: ScanBytes=Exact(32)),(Col[5]: ScanBytes=Exact(64)),(Col[6]: ScanBytes=Exact(32)),(Col[7]: ScanBytes=Exact(64)),(Col[8]: ScanBytes=Inexact(88)),(Col[9]: ScanBytes=Inexact(49)),(Col[10]: ScanBytes=Exact(64))]] initial_physical_plan_with_schema -01)GlobalLimitExec: skip=0, fetch=10, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:BinaryView;N, string_col:BinaryView;N, timestamp_col:Timestamp(Nanosecond, None);N] -02)--DataSourceExec: file_groups={1 group: 
[[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:BinaryView;N, string_col:BinaryView;N, timestamp_col:Timestamp(Nanosecond, None);N] +01)GlobalLimitExec: skip=0, fetch=10, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:BinaryView;N, string_col:BinaryView;N, timestamp_col:Timestamp(ns);N] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:BinaryView;N, string_col:BinaryView;N, timestamp_col:Timestamp(ns);N] physical_plan after OutputRequirements 01)OutputRequirementExec: order_by=[], dist_by=Unspecified, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]: ScanBytes=Exact(32)),(Col[1]: ScanBytes=Inexact(24)),(Col[2]: ScanBytes=Exact(32)),(Col[3]: ScanBytes=Exact(32)),(Col[4]: ScanBytes=Exact(32)),(Col[5]: ScanBytes=Exact(64)),(Col[6]: ScanBytes=Exact(32)),(Col[7]: ScanBytes=Exact(64)),(Col[8]: ScanBytes=Inexact(88)),(Col[9]: ScanBytes=Inexact(49)),(Col[10]: ScanBytes=Exact(64))]] 02)--GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]: ScanBytes=Exact(32)),(Col[1]: ScanBytes=Inexact(24)),(Col[2]: ScanBytes=Exact(32)),(Col[3]: 
ScanBytes=Exact(32)),(Col[4]: ScanBytes=Exact(32)),(Col[5]: ScanBytes=Exact(64)),(Col[6]: ScanBytes=Exact(32)),(Col[7]: ScanBytes=Exact(64)),(Col[8]: ScanBytes=Inexact(88)),(Col[9]: ScanBytes=Inexact(49)),(Col[10]: ScanBytes=Exact(64))]] @@ -313,7 +318,6 @@ physical_plan after CombinePartialFinalAggregate SAME TEXT AS ABOVE physical_plan after EnforceSorting SAME TEXT AS ABOVE physical_plan after OptimizeAggregateOrder SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE -physical_plan after coalesce_batches SAME TEXT AS ABOVE physical_plan after OutputRequirements 01)GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]: ScanBytes=Exact(32)),(Col[1]: ScanBytes=Inexact(24)),(Col[2]: ScanBytes=Exact(32)),(Col[3]: ScanBytes=Exact(32)),(Col[4]: ScanBytes=Exact(32)),(Col[5]: ScanBytes=Exact(64)),(Col[6]: ScanBytes=Exact(32)),(Col[7]: ScanBytes=Exact(64)),(Col[8]: ScanBytes=Inexact(88)),(Col[9]: ScanBytes=Inexact(49)),(Col[10]: ScanBytes=Exact(64))]] 02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]: ScanBytes=Exact(32)),(Col[1]: ScanBytes=Inexact(24)),(Col[2]: ScanBytes=Exact(32)),(Col[3]: ScanBytes=Exact(32)),(Col[4]: ScanBytes=Exact(32)),(Col[5]: ScanBytes=Exact(64)),(Col[6]: ScanBytes=Exact(32)),(Col[7]: ScanBytes=Exact(64)),(Col[8]: ScanBytes=Inexact(88)),(Col[9]: ScanBytes=Inexact(49)),(Col[10]: ScanBytes=Exact(64))]] @@ -326,7 +330,7 @@ physical_plan after EnsureCooperative SAME TEXT AS ABOVE physical_plan after FilterPushdown(Post) SAME TEXT AS ABOVE physical_plan after SanityCheckPlan SAME TEXT AS ABOVE physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, 
bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]: ScanBytes=Exact(32)),(Col[1]: ScanBytes=Inexact(24)),(Col[2]: ScanBytes=Exact(32)),(Col[3]: ScanBytes=Exact(32)),(Col[4]: ScanBytes=Exact(32)),(Col[5]: ScanBytes=Exact(64)),(Col[6]: ScanBytes=Exact(32)),(Col[7]: ScanBytes=Exact(64)),(Col[8]: ScanBytes=Inexact(88)),(Col[9]: ScanBytes=Inexact(49)),(Col[10]: ScanBytes=Exact(64))]] -physical_plan_with_schema DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:BinaryView;N, string_col:BinaryView;N, timestamp_col:Timestamp(Nanosecond, None);N] +physical_plan_with_schema DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:BinaryView;N, string_col:BinaryView;N, timestamp_col:Timestamp(ns);N] statement ok @@ -343,8 +347,8 @@ initial_physical_plan_with_stats 01)GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]: ScanBytes=Exact(32)),(Col[1]: ScanBytes=Inexact(24)),(Col[2]: ScanBytes=Exact(32)),(Col[3]: ScanBytes=Exact(32)),(Col[4]: ScanBytes=Exact(32)),(Col[5]: ScanBytes=Exact(64)),(Col[6]: 
ScanBytes=Exact(32)),(Col[7]: ScanBytes=Exact(64)),(Col[8]: ScanBytes=Inexact(88)),(Col[9]: ScanBytes=Inexact(49)),(Col[10]: ScanBytes=Exact(64))]] 02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]: ScanBytes=Exact(32)),(Col[1]: ScanBytes=Inexact(24)),(Col[2]: ScanBytes=Exact(32)),(Col[3]: ScanBytes=Exact(32)),(Col[4]: ScanBytes=Exact(32)),(Col[5]: ScanBytes=Exact(64)),(Col[6]: ScanBytes=Exact(32)),(Col[7]: ScanBytes=Exact(64)),(Col[8]: ScanBytes=Inexact(88)),(Col[9]: ScanBytes=Inexact(49)),(Col[10]: ScanBytes=Exact(64))]] initial_physical_plan_with_schema -01)GlobalLimitExec: skip=0, fetch=10, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:BinaryView;N, string_col:BinaryView;N, timestamp_col:Timestamp(Nanosecond, None);N] -02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:BinaryView;N, string_col:BinaryView;N, timestamp_col:Timestamp(Nanosecond, None);N] +01)GlobalLimitExec: skip=0, fetch=10, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:BinaryView;N, string_col:BinaryView;N, timestamp_col:Timestamp(ns);N] +02)--DataSourceExec: 
file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:BinaryView;N, string_col:BinaryView;N, timestamp_col:Timestamp(ns);N] physical_plan after OutputRequirements 01)OutputRequirementExec: order_by=[], dist_by=Unspecified 02)--GlobalLimitExec: skip=0, fetch=10 @@ -358,7 +362,6 @@ physical_plan after CombinePartialFinalAggregate SAME TEXT AS ABOVE physical_plan after EnforceSorting SAME TEXT AS ABOVE physical_plan after OptimizeAggregateOrder SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE -physical_plan after coalesce_batches SAME TEXT AS ABOVE physical_plan after OutputRequirements 01)GlobalLimitExec: skip=0, fetch=10 02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet @@ -372,7 +375,7 @@ physical_plan after FilterPushdown(Post) SAME TEXT AS ABOVE physical_plan after SanityCheckPlan SAME TEXT AS ABOVE physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet physical_plan_with_stats DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, 
timestamp_col], limit=10, file_type=parquet, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]: ScanBytes=Exact(32)),(Col[1]: ScanBytes=Inexact(24)),(Col[2]: ScanBytes=Exact(32)),(Col[3]: ScanBytes=Exact(32)),(Col[4]: ScanBytes=Exact(32)),(Col[5]: ScanBytes=Exact(64)),(Col[6]: ScanBytes=Exact(32)),(Col[7]: ScanBytes=Exact(64)),(Col[8]: ScanBytes=Inexact(88)),(Col[9]: ScanBytes=Inexact(49)),(Col[10]: ScanBytes=Exact(64))]] -physical_plan_with_schema DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:BinaryView;N, string_col:BinaryView;N, timestamp_col:Timestamp(Nanosecond, None);N] +physical_plan_with_schema DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:BinaryView;N, string_col:BinaryView;N, timestamp_col:Timestamp(ns);N] statement ok @@ -538,6 +541,7 @@ initial_logical_plan logical_plan after resolve_grouping_function SAME TEXT AS ABOVE logical_plan after type_coercion SAME TEXT AS ABOVE analyzed_logical_plan SAME TEXT AS ABOVE +logical_plan after rewrite_set_comparison SAME TEXT AS ABOVE logical_plan after optimize_unions SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE @@ 
-558,7 +562,10 @@ logical_plan after push_down_filter SAME TEXT AS ABOVE logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE +logical_plan after extract_leaf_expressions SAME TEXT AS ABOVE +logical_plan after push_down_leaf_projections SAME TEXT AS ABOVE logical_plan after optimize_projections TableScan: simple_explain_test projection=[a, b, c] +logical_plan after rewrite_set_comparison SAME TEXT AS ABOVE logical_plan after optimize_unions SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE @@ -579,6 +586,8 @@ logical_plan after push_down_filter SAME TEXT AS ABOVE logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE +logical_plan after extract_leaf_expressions SAME TEXT AS ABOVE +logical_plan after push_down_leaf_projections SAME TEXT AS ABOVE logical_plan after optimize_projections SAME TEXT AS ABOVE logical_plan TableScan: simple_explain_test projection=[a, b, c] initial_physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], file_type=csv, has_header=true @@ -596,7 +605,6 @@ physical_plan after CombinePartialFinalAggregate SAME TEXT AS ABOVE physical_plan after EnforceSorting SAME TEXT AS ABOVE physical_plan after OptimizeAggregateOrder SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE -physical_plan after coalesce_batches SAME TEXT AS ABOVE physical_plan after OutputRequirements DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], file_type=csv, has_header=true physical_plan after 
LimitAggregation SAME TEXT AS ABOVE physical_plan after LimitPushPastWindows SAME TEXT AS ABOVE diff --git a/datafusion/sqllogictest/test_files/explain_tree.slt b/datafusion/sqllogictest/test_files/explain_tree.slt index 9215ce87e3be..3a183a735743 100644 --- a/datafusion/sqllogictest/test_files/explain_tree.slt +++ b/datafusion/sqllogictest/test_files/explain_tree.slt @@ -268,7 +268,7 @@ physical_plan 06)┌─────────────┴─────────────┐ 07)│ DataSourceExec │ 08)│ -------------------- │ -09)│ bytes: 1040 │ +09)│ bytes: 1024 │ 10)│ format: memory │ 11)│ rows: 2 │ 12)└───────────────────────────┘ @@ -345,7 +345,7 @@ physical_plan 15)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ 16)│ DataSourceExec ││ ProjectionExec │ 17)│ -------------------- ││ -------------------- │ -18)│ bytes: 520 ││ date_col: date_col │ +18)│ bytes: 512 ││ date_col: date_col │ 19)│ format: memory ││ int_col: int_col │ 20)│ rows: 1 ││ │ 21)│ ││ string_col: │ @@ -592,7 +592,7 @@ physical_plan 07)┌─────────────┴─────────────┐ 08)│ DataSourceExec │ 09)│ -------------------- │ -10)│ bytes: 520 │ +10)│ bytes: 512 │ 11)│ format: memory │ 12)│ rows: 1 │ 13)└───────────────────────────┘ @@ -954,7 +954,7 @@ physical_plan 13)┌─────────────┴─────────────┐ 14)│ DataSourceExec │ 15)│ -------------------- │ -16)│ bytes: 520 │ +16)│ bytes: 512 │ 17)│ format: memory │ 18)│ rows: 1 │ 19)└───────────────────────────┘ @@ -1305,7 +1305,7 @@ physical_plan 42)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ 43)│ DataSourceExec ││ DataSourceExec │ 44)│ -------------------- ││ -------------------- │ -45)│ bytes: 296 ││ bytes: 288 │ +45)│ bytes: 288 ││ bytes: 280 │ 46)│ format: memory ││ format: memory │ 47)│ rows: 1 ││ rows: 1 │ 48)└───────────────────────────┘└───────────────────────────┘ @@ -1324,14 +1324,14 @@ physical_plan 04)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ 05)│ DataSourceExec ││ ProjectionExec │ 06)│ -------------------- ││ -------------------- │ -07)│ bytes: 296 ││ id: 
CAST(id AS Int32) │ +07)│ bytes: 288 ││ id: CAST(id AS Int32) │ 08)│ format: memory ││ name: name │ 09)│ rows: 1 ││ │ 10)└───────────────────────────┘└─────────────┬─────────────┘ 11)-----------------------------┌─────────────┴─────────────┐ 12)-----------------------------│ DataSourceExec │ 13)-----------------------------│ -------------------- │ -14)-----------------------------│ bytes: 288 │ +14)-----------------------------│ bytes: 280 │ 15)-----------------------------│ format: memory │ 16)-----------------------------│ rows: 1 │ 17)-----------------------------└───────────────────────────┘ diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index cec9b63675a6..6d19d1436e1c 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -60,7 +60,7 @@ SELECT isnan(NULL), iszero(NULL) ---- -NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL +NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 NULL NULL NULL # test_array_cast_invalid_timezone_will_panic statement error Parser error: Invalid timezone "Foo": failed to parse timezone @@ -432,6 +432,16 @@ SELECT chr(CAST(0 AS int)) statement error DataFusion error: Execution error: invalid Unicode scalar value: 9223372036854775807 SELECT chr(CAST(9223372036854775807 AS bigint)) +statement error DataFusion error: Execution error: invalid Unicode scalar value: 1114112 +SELECT chr(CAST(1114112 AS bigint)) + +statement error DataFusion error: Execution error: invalid Unicode scalar value: -1 +SELECT chr(CAST(-1 AS bigint)) + +# surrogate code point (invalid scalar value) +statement error DataFusion error: Execution error: invalid Unicode scalar value: 55297 +SELECT chr(CAST(55297 AS bigint)) + query 
T SELECT concat('a','b','c') ---- @@ -494,6 +504,25 @@ abc statement ok drop table foo +# concat_ws with a Utf8View column as separator +statement ok +create table test_concat_ws_sep (sep varchar, val1 varchar, val2 varchar) as values (',', 'foo', 'bar'), ('|', 'a', 'b'); + +query T +SELECT concat_ws(arrow_cast(sep, 'Utf8View'), val1, val2) FROM test_concat_ws_sep ORDER BY val1 +---- +a|b +foo,bar + +query T +SELECT concat_ws(arrow_cast(sep, 'LargeUtf8'), val1, val2) FROM test_concat_ws_sep ORDER BY val1 +---- +a|b +foo,bar + +statement ok +drop table test_concat_ws_sep + query T SELECT initcap('') ---- @@ -589,7 +618,7 @@ select repeat('-1.2', arrow_cast(3, 'Int32')); ---- -1.2-1.2-1.2 -query error DataFusion error: Error during planning: Internal error: Expect TypeSignatureClass::Native\(LogicalType\(Native\(Int64\), Int64\)\) but received NativeType::Float64, DataType: Float64 +query error DataFusion error: Error during planning: Function 'repeat' requires TypeSignatureClass::Native\(LogicalType\(Native\(Int64\), Int64\)\), but received Float64 \(DataType: Float64\) select repeat('-1.2', 3.2); query T @@ -672,6 +701,26 @@ SELECT split_part('abc~@~def~@~ghi', '~@~', -100) ---- (empty) +query T +SELECT split_part('a,b', '', 1) +---- +a,b + +query T +SELECT split_part('a,b', '', -1) +---- +a,b + +query T +SELECT split_part('a,b', '', 2) +---- +(empty) + +query T +SELECT split_part('a,b', '', -2) +---- +(empty) + statement error DataFusion error: Execution error: field position must not be zero SELECT split_part('abc~@~def~@~ghi', '~@~', 0) @@ -715,6 +764,27 @@ SELECT to_hex(CAST(NULL AS int)) ---- NULL +query T +SELECT to_hex(0) +---- +0 + +# negative values (two's complement encoding) +query T +SELECT to_hex(-1) +---- +ffffffffffffffff + +query T +SELECT to_hex(CAST(-1 AS INT)) +---- +ffffffffffffffff + +query T +SELECT to_hex(CAST(255 AS TINYINT UNSIGNED)) +---- +ff + query T SELECT trim(' tom ') ---- diff --git 
a/datafusion/sqllogictest/test_files/floor_preimage.slt b/datafusion/sqllogictest/test_files/floor_preimage.slt new file mode 100644 index 000000000000..93302b3d7a2f --- /dev/null +++ b/datafusion/sqllogictest/test_files/floor_preimage.slt @@ -0,0 +1,308 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +## Floor Preimage Tests +## +## Tests for floor function preimage optimization: +## floor(col) = N transforms to col >= N AND col < N + 1 +## +## Uses representative types only (Float64, Int32, Decimal128). +## Unit tests cover all type variants. 
+########## + +# Setup: Single table with representative types +statement ok +CREATE TABLE test_data ( + id INT, + float_val DOUBLE, + int_val INT, + decimal_val DECIMAL(10,2) +) AS VALUES + (1, 5.3, 100, 100.00), + (2, 5.7, 101, 100.50), + (3, 6.0, 102, 101.00), + (4, 6.5, -5, 101.99), + (5, 7.0, 0, 102.00), + (6, NULL, NULL, NULL); + +########## +## Data Correctness Tests +########## + +# Float64: floor(x) = 5 matches values in [5.0, 6.0) +query I rowsort +SELECT id FROM test_data WHERE floor(float_val) = arrow_cast(5, 'Float64'); +---- +1 +2 + +# Int32: floor(x) = 100 matches values in [100, 101) +query I rowsort +SELECT id FROM test_data WHERE floor(int_val) = 100; +---- +1 + +# Decimal128: floor(x) = 100 matches values in [100.00, 101.00) +query I rowsort +SELECT id FROM test_data WHERE floor(decimal_val) = arrow_cast(100, 'Decimal128(10,2)'); +---- +1 +2 + +# Negative value: floor(x) = -5 matches values in [-5, -4) +query I rowsort +SELECT id FROM test_data WHERE floor(int_val) = -5; +---- +4 + +# Zero value: floor(x) = 0 matches values in [0, 1) +query I rowsort +SELECT id FROM test_data WHERE floor(int_val) = 0; +---- +5 + +# Column on RHS (same result as LHS) +query I rowsort +SELECT id FROM test_data WHERE arrow_cast(5, 'Float64') = floor(float_val); +---- +1 +2 + +# IS NOT DISTINCT FROM (excludes NULLs) +query I rowsort +SELECT id FROM test_data WHERE floor(float_val) IS NOT DISTINCT FROM arrow_cast(5, 'Float64'); +---- +1 +2 + +# IS DISTINCT FROM (includes NULLs) +query I rowsort +SELECT id FROM test_data WHERE floor(float_val) IS DISTINCT FROM arrow_cast(5, 'Float64'); +---- +3 +4 +5 +6 + +# Non-integer literal (empty result - floor returns integers) +query I rowsort +SELECT id FROM test_data WHERE floor(float_val) = arrow_cast(5.5, 'Float64'); +---- + +# IN list: floor(x) IN (5, 7) matches [5.0, 6.0) and [7.0, 8.0) +query I rowsort +SELECT id FROM test_data WHERE floor(float_val) IN (arrow_cast(5, 'Float64'), arrow_cast(7, 'Float64')); +---- +1 +2 +5 
+ +# NOT IN list: floor(x) NOT IN (5, 7) excludes matching ranges and NULLs +query I rowsort +SELECT id FROM test_data WHERE floor(float_val) NOT IN (arrow_cast(5, 'Float64'), arrow_cast(7, 'Float64')); +---- +3 +4 + +########## +## EXPLAIN Tests - Plan Optimization +########## + +statement ok +set datafusion.explain.logical_plan_only = true; + +# 1. Basic: Float64 - floor(col) = N transforms to col >= N AND col < N+1 +query TT +EXPLAIN SELECT * FROM test_data WHERE floor(float_val) = arrow_cast(5, 'Float64'); +---- +logical_plan +01)Filter: test_data.float_val >= Float64(5) AND test_data.float_val < Float64(6) +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +# 2. Basic: Int32 - transformed (coerced to Float64) +query TT +EXPLAIN SELECT * FROM test_data WHERE floor(int_val) = 100; +---- +logical_plan +01)Projection: test_data.id, test_data.float_val, test_data.int_val, test_data.decimal_val +02)--Filter: __common_expr_3 >= Float64(100) AND __common_expr_3 < Float64(101) +03)----Projection: CAST(test_data.int_val AS Float64) AS __common_expr_3, test_data.id, test_data.float_val, test_data.int_val, test_data.decimal_val +04)------TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +# 3. Basic: Decimal128 - same transformation +query TT +EXPLAIN SELECT * FROM test_data WHERE floor(decimal_val) = arrow_cast(100, 'Decimal128(10,2)'); +---- +logical_plan +01)Filter: test_data.decimal_val >= Decimal128(Some(10000),10,2) AND test_data.decimal_val < Decimal128(Some(10100),10,2) +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +# 4. Column on RHS - same transformation +query TT +EXPLAIN SELECT * FROM test_data WHERE arrow_cast(5, 'Float64') = floor(float_val); +---- +logical_plan +01)Filter: test_data.float_val >= Float64(5) AND test_data.float_val < Float64(6) +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +# 5. 
IS NOT DISTINCT FROM - adds IS NOT NULL +query TT +EXPLAIN SELECT * FROM test_data WHERE floor(float_val) IS NOT DISTINCT FROM arrow_cast(5, 'Float64'); +---- +logical_plan +01)Filter: test_data.float_val IS NOT NULL AND test_data.float_val >= Float64(5) AND test_data.float_val < Float64(6) +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +# 6. IS DISTINCT FROM - includes NULL check +query TT +EXPLAIN SELECT * FROM test_data WHERE floor(float_val) IS DISTINCT FROM arrow_cast(5, 'Float64'); +---- +logical_plan +01)Filter: test_data.float_val < Float64(5) OR test_data.float_val >= Float64(6) OR test_data.float_val IS NULL +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +# 7. Non-optimizable: non-integer literal (original predicate preserved) +query TT +EXPLAIN SELECT * FROM test_data WHERE floor(float_val) = arrow_cast(5.5, 'Float64'); +---- +logical_plan +01)Filter: floor(test_data.float_val) = Float64(5.5) +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +# 8. Non-optimizable: extreme float literal (2^53) where n+1 loses precision, so preimage returns None +query TT +EXPLAIN SELECT * FROM test_data WHERE floor(float_val) = 9007199254740992; +---- +logical_plan +01)Filter: floor(test_data.float_val) = Float64(9007199254740992) +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +# 9. 
IN list: each list item is rewritten with preimage and OR-ed together +query TT +EXPLAIN SELECT * FROM test_data WHERE floor(float_val) IN (arrow_cast(5, 'Float64'), arrow_cast(7, 'Float64')); +---- +logical_plan +01)Filter: test_data.float_val >= Float64(5) AND test_data.float_val < Float64(6) OR test_data.float_val >= Float64(7) AND test_data.float_val < Float64(8) +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +# Data correctness: floor(col) = 2^53 returns no rows (no value in test_data has floor exactly 2^53) +query I rowsort +SELECT id FROM test_data WHERE floor(float_val) = 9007199254740992; +---- + +########## +## Other Comparison Operators +## +## The preimage framework automatically handles all comparison operators: +## floor(x) <> N -> x < N OR x >= N+1 +## floor(x) > N -> x >= N+1 +## floor(x) < N -> x < N +## floor(x) >= N -> x >= N +## floor(x) <= N -> x < N+1 +########## + +# Data correctness tests for other operators + +# Not equals: floor(x) <> 5 matches values outside [5.0, 6.0) +query I rowsort +SELECT id FROM test_data WHERE floor(float_val) <> arrow_cast(5, 'Float64'); +---- +3 +4 +5 + +# Greater than: floor(x) > 5 matches values in [6.0, inf) +query I rowsort +SELECT id FROM test_data WHERE floor(float_val) > arrow_cast(5, 'Float64'); +---- +3 +4 +5 + +# Less than: floor(x) < 6 matches values in (-inf, 6.0) +query I rowsort +SELECT id FROM test_data WHERE floor(float_val) < arrow_cast(6, 'Float64'); +---- +1 +2 + +# Greater than or equal: floor(x) >= 5 matches values in [5.0, inf) +query I rowsort +SELECT id FROM test_data WHERE floor(float_val) >= arrow_cast(5, 'Float64'); +---- +1 +2 +3 +4 +5 + +# Less than or equal: floor(x) <= 5 matches values in (-inf, 6.0) +query I rowsort +SELECT id FROM test_data WHERE floor(float_val) <= arrow_cast(5, 'Float64'); +---- +1 +2 + +# EXPLAIN tests showing optimized transformations + +# Not equals: floor(x) <> 5 -> x < 5 OR x >= 6 +query TT +EXPLAIN SELECT * FROM test_data 
WHERE floor(float_val) <> arrow_cast(5, 'Float64'); +---- +logical_plan +01)Filter: test_data.float_val < Float64(5) OR test_data.float_val >= Float64(6) +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +# Greater than: floor(x) > 5 -> x >= 6 +query TT +EXPLAIN SELECT * FROM test_data WHERE floor(float_val) > arrow_cast(5, 'Float64'); +---- +logical_plan +01)Filter: test_data.float_val >= Float64(6) +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +# Less than: floor(x) < 6 -> x < 6 +query TT +EXPLAIN SELECT * FROM test_data WHERE floor(float_val) < arrow_cast(6, 'Float64'); +---- +logical_plan +01)Filter: test_data.float_val < Float64(6) +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +# Greater than or equal: floor(x) >= 5 -> x >= 5 +query TT +EXPLAIN SELECT * FROM test_data WHERE floor(float_val) >= arrow_cast(5, 'Float64'); +---- +logical_plan +01)Filter: test_data.float_val >= Float64(5) +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +# Less than or equal: floor(x) <= 5 -> x < 6 +query TT +EXPLAIN SELECT * FROM test_data WHERE floor(float_val) <= arrow_cast(5, 'Float64'); +---- +logical_plan +01)Filter: test_data.float_val < Float64(6) +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +########## +## Cleanup +########## + +statement ok +DROP TABLE test_data; diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index 6c87d618c727..5a43d18e2387 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -239,6 +239,11 @@ SELECT translate('12345', '143', NULL) ---- NULL +query T +SELECT translate(arrow_cast('12345', 'LargeUtf8'), '143', 'ax') +---- +a2x5 + statement ok CREATE TABLE test( c1 VARCHAR @@ -536,6 +541,15 @@ SELECT trim(arrow_cast(' foo ', 'Dictionary(Int32, Utf8)')) ---- foo +# Verify that 
trim, ltrim, and rtrim only strip spaces by default, +# not other whitespace characters (tabs, newlines, etc.) +query III +SELECT length(trim(chr(9) || 'foo' || chr(10))), + length(ltrim(chr(9) || 'foo')), + length(rtrim('foo' || chr(10))) +---- +5 4 4 + query I SELECT bit_length('foo') ---- diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index cd1ed2bc0cac..294841552a66 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -4329,9 +4329,9 @@ physical_plan 01)SortPreservingMergeExec: [months@0 DESC], fetch=5 02)--SortExec: TopK(fetch=5), expr=[months@0 DESC], preserve_partitioning=[true] 03)----ProjectionExec: expr=[date_part(Utf8("MONTH"),csv_with_timestamps.ts)@0 as months] -04)------AggregateExec: mode=FinalPartitioned, gby=[date_part(Utf8("MONTH"),csv_with_timestamps.ts)@0 as date_part(Utf8("MONTH"),csv_with_timestamps.ts)], aggr=[] +04)------AggregateExec: mode=FinalPartitioned, gby=[date_part(Utf8("MONTH"),csv_with_timestamps.ts)@0 as date_part(Utf8("MONTH"),csv_with_timestamps.ts)], aggr=[], lim=[5] 05)--------RepartitionExec: partitioning=Hash([date_part(Utf8("MONTH"),csv_with_timestamps.ts)@0], 8), input_partitions=8 -06)----------AggregateExec: mode=Partial, gby=[date_part(MONTH, ts@0) as date_part(Utf8("MONTH"),csv_with_timestamps.ts)], aggr=[] +06)----------AggregateExec: mode=Partial, gby=[date_part(MONTH, ts@0) as date_part(Utf8("MONTH"),csv_with_timestamps.ts)], aggr=[], lim=[5] 07)------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true 08)--------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/timestamps.csv]]}, projection=[ts], output_ordering=[ts@0 DESC], file_type=csv, has_header=false @@ -5478,7 +5478,7 @@ create table source as values ; statement ok -create view t as select column1 as a, arrow_cast(column2, 'Timestamp(Nanosecond, 
None)') as b from source; +create view t as select column1 as a, arrow_cast(column2, 'Timestamp(ns)') as b from source; query IPI select a, b, count(*) from t group by a, b order by a, b; diff --git a/datafusion/sqllogictest/test_files/grouping_set_repartition.slt b/datafusion/sqllogictest/test_files/grouping_set_repartition.slt new file mode 100644 index 000000000000..16ab90651c8b --- /dev/null +++ b/datafusion/sqllogictest/test_files/grouping_set_repartition.slt @@ -0,0 +1,246 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +# Tests for ROLLUP/CUBE/GROUPING SETS with multiple partitions +# +# This tests the fix for https://github.com/apache/datafusion/issues/19849 +# where ROLLUP queries produced incorrect results with multiple partitions +# because subset partitioning satisfaction was incorrectly applied. +# +# The bug manifests when: +# 1. UNION ALL of subqueries each with hash-partitioned aggregates +# 2. Outer ROLLUP groups by more columns than inner hash partitioning +# 3. InterleaveExec preserves the inner hash partitioning +# 4. 
Optimizer incorrectly uses subset satisfaction, skipping necessary repartition +# +# The fix ensures that when hash partitioning includes __grouping_id, +# subset satisfaction is disabled and proper RepartitionExec is inserted. +########## + +########## +# SETUP: Create partitioned parquet files to simulate distributed data +########## + +statement ok +set datafusion.execution.target_partitions = 4; + +statement ok +set datafusion.optimizer.repartition_aggregations = true; + +# Create partition 1 +statement ok +COPY (SELECT column1 as channel, column2 as brand, column3 as amount FROM (VALUES + ('store', 'nike', 100), + ('store', 'nike', 200), + ('store', 'adidas', 150) +)) +TO 'test_files/scratch/grouping_set_repartition/part=1/data.parquet' +STORED AS PARQUET; + +# Create partition 2 +statement ok +COPY (SELECT column1 as channel, column2 as brand, column3 as amount FROM (VALUES + ('store', 'adidas', 250), + ('web', 'nike', 300), + ('web', 'nike', 400) +)) +TO 'test_files/scratch/grouping_set_repartition/part=2/data.parquet' +STORED AS PARQUET; + +# Create partition 3 +statement ok +COPY (SELECT column1 as channel, column2 as brand, column3 as amount FROM (VALUES + ('web', 'adidas', 350), + ('web', 'adidas', 450), + ('catalog', 'nike', 500) +)) +TO 'test_files/scratch/grouping_set_repartition/part=3/data.parquet' +STORED AS PARQUET; + +# Create partition 4 +statement ok +COPY (SELECT column1 as channel, column2 as brand, column3 as amount FROM (VALUES + ('catalog', 'nike', 600), + ('catalog', 'adidas', 550), + ('catalog', 'adidas', 650) +)) +TO 'test_files/scratch/grouping_set_repartition/part=4/data.parquet' +STORED AS PARQUET; + +# Create external table pointing to the partitioned data +statement ok +CREATE EXTERNAL TABLE sales (channel VARCHAR, brand VARCHAR, amount INT) +STORED AS PARQUET +PARTITIONED BY (part INT) +LOCATION 'test_files/scratch/grouping_set_repartition/'; + +########## +# TEST 1: UNION ALL + ROLLUP pattern (similar to TPC-DS q14) +# This query 
pattern triggers the subset satisfaction bug because: +# - Each UNION ALL branch has hash partitioning on (brand) +# - The outer ROLLUP requires hash partitioning on (channel, brand, __grouping_id) +# - Without the fix, subset satisfaction incorrectly skips repartition +# +# Verify the physical plan includes RepartitionExec with __grouping_id +########## + +query TT +EXPLAIN SELECT channel, brand, SUM(total) as grand_total +FROM ( + SELECT 'store' as channel, brand, SUM(amount) as total + FROM sales WHERE channel = 'store' + GROUP BY brand + UNION ALL + SELECT 'web' as channel, brand, SUM(amount) as total + FROM sales WHERE channel = 'web' + GROUP BY brand + UNION ALL + SELECT 'catalog' as channel, brand, SUM(amount) as total + FROM sales WHERE channel = 'catalog' + GROUP BY brand +) sub +GROUP BY ROLLUP(channel, brand) +ORDER BY channel NULLS FIRST, brand NULLS FIRST; +---- +logical_plan +01)Sort: sub.channel ASC NULLS FIRST, sub.brand ASC NULLS FIRST +02)--Projection: sub.channel, sub.brand, sum(sub.total) AS grand_total +03)----Aggregate: groupBy=[[ROLLUP (sub.channel, sub.brand)]], aggr=[[sum(sub.total)]] +04)------SubqueryAlias: sub +05)--------Union +06)----------Projection: Utf8("store") AS channel, sales.brand, sum(sales.amount) AS total +07)------------Aggregate: groupBy=[[sales.brand]], aggr=[[sum(CAST(sales.amount AS Int64))]] +08)--------------Projection: sales.brand, sales.amount +09)----------------Filter: sales.channel = Utf8View("store") +10)------------------TableScan: sales projection=[channel, brand, amount], partial_filters=[sales.channel = Utf8View("store")] +11)----------Projection: Utf8("web") AS channel, sales.brand, sum(sales.amount) AS total +12)------------Aggregate: groupBy=[[sales.brand]], aggr=[[sum(CAST(sales.amount AS Int64))]] +13)--------------Projection: sales.brand, sales.amount +14)----------------Filter: sales.channel = Utf8View("web") +15)------------------TableScan: sales projection=[channel, brand, amount], 
partial_filters=[sales.channel = Utf8View("web")] +16)----------Projection: Utf8("catalog") AS channel, sales.brand, sum(sales.amount) AS total +17)------------Aggregate: groupBy=[[sales.brand]], aggr=[[sum(CAST(sales.amount AS Int64))]] +18)--------------Projection: sales.brand, sales.amount +19)----------------Filter: sales.channel = Utf8View("catalog") +20)------------------TableScan: sales projection=[channel, brand, amount], partial_filters=[sales.channel = Utf8View("catalog")] +physical_plan +01)SortPreservingMergeExec: [channel@0 ASC, brand@1 ASC] +02)--SortExec: expr=[channel@0 ASC, brand@1 ASC], preserve_partitioning=[true] +03)----ProjectionExec: expr=[channel@0 as channel, brand@1 as brand, sum(sub.total)@3 as grand_total] +04)------AggregateExec: mode=FinalPartitioned, gby=[channel@0 as channel, brand@1 as brand, __grouping_id@2 as __grouping_id], aggr=[sum(sub.total)] +05)--------RepartitionExec: partitioning=Hash([channel@0, brand@1, __grouping_id@2], 4), input_partitions=4 +06)----------AggregateExec: mode=Partial, gby=[(NULL as channel, NULL as brand), (channel@0 as channel, NULL as brand), (channel@0 as channel, brand@1 as brand)], aggr=[sum(sub.total)] +07)------------InterleaveExec +08)--------------ProjectionExec: expr=[store as channel, brand@0 as brand, sum(sales.amount)@1 as total] +09)----------------AggregateExec: mode=FinalPartitioned, gby=[brand@0 as brand], aggr=[sum(sales.amount)] +10)------------------RepartitionExec: partitioning=Hash([brand@0], 4), input_partitions=4 +11)--------------------AggregateExec: mode=Partial, gby=[brand@0 as brand], aggr=[sum(sales.amount)] +12)----------------------FilterExec: channel@0 = store, projection=[brand@1, amount@2] +13)------------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=1/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=2/data.parquet], 
[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=3/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=4/data.parquet]]}, projection=[channel, brand, amount], file_type=parquet, predicate=channel@0 = store, pruning_predicate=channel_null_count@2 != row_count@3 AND channel_min@0 <= store AND store <= channel_max@1, required_guarantees=[channel in (store)] +14)--------------ProjectionExec: expr=[web as channel, brand@0 as brand, sum(sales.amount)@1 as total] +15)----------------AggregateExec: mode=FinalPartitioned, gby=[brand@0 as brand], aggr=[sum(sales.amount)] +16)------------------RepartitionExec: partitioning=Hash([brand@0], 4), input_partitions=4 +17)--------------------AggregateExec: mode=Partial, gby=[brand@0 as brand], aggr=[sum(sales.amount)] +18)----------------------FilterExec: channel@0 = web, projection=[brand@1, amount@2] +19)------------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=1/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=2/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=3/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=4/data.parquet]]}, projection=[channel, brand, amount], file_type=parquet, predicate=channel@0 = web, pruning_predicate=channel_null_count@2 != row_count@3 AND channel_min@0 <= web AND web <= channel_max@1, required_guarantees=[channel in (web)] +20)--------------ProjectionExec: expr=[catalog as channel, brand@0 as brand, sum(sales.amount)@1 as total] +21)----------------AggregateExec: mode=FinalPartitioned, gby=[brand@0 as brand], aggr=[sum(sales.amount)] +22)------------------RepartitionExec: partitioning=Hash([brand@0], 4), input_partitions=4 
+23)--------------------AggregateExec: mode=Partial, gby=[brand@0 as brand], aggr=[sum(sales.amount)] +24)----------------------FilterExec: channel@0 = catalog, projection=[brand@1, amount@2] +25)------------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=1/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=2/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=3/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=4/data.parquet]]}, projection=[channel, brand, amount], file_type=parquet, predicate=channel@0 = catalog, pruning_predicate=channel_null_count@2 != row_count@3 AND channel_min@0 <= catalog AND catalog <= channel_max@1, required_guarantees=[channel in (catalog)] + +query TTI rowsort +SELECT channel, brand, SUM(total) as grand_total +FROM ( + SELECT 'store' as channel, brand, SUM(amount) as total + FROM sales WHERE channel = 'store' + GROUP BY brand + UNION ALL + SELECT 'web' as channel, brand, SUM(amount) as total + FROM sales WHERE channel = 'web' + GROUP BY brand + UNION ALL + SELECT 'catalog' as channel, brand, SUM(amount) as total + FROM sales WHERE channel = 'catalog' + GROUP BY brand +) sub +GROUP BY ROLLUP(channel, brand) +ORDER BY channel NULLS FIRST, brand NULLS FIRST; +---- +NULL NULL 4500 +catalog NULL 2300 +catalog adidas 1200 +catalog nike 1100 +store NULL 700 +store adidas 400 +store nike 300 +web NULL 1500 +web adidas 800 +web nike 700 + +########## +# TEST 2: Simple ROLLUP (baseline test) +########## + +query TTI rowsort +SELECT channel, brand, SUM(amount) as total +FROM sales +GROUP BY ROLLUP(channel, brand) +ORDER BY channel NULLS FIRST, brand NULLS FIRST; +---- +NULL NULL 4500 +catalog NULL 2300 +catalog adidas 1200 +catalog nike 1100 +store NULL 700 +store adidas 400 +store nike 300 +web NULL 
1500 +web adidas 800 +web nike 700 + +########## +# TEST 3: Verify CUBE also works correctly +########## + +query TTI rowsort +SELECT channel, brand, SUM(amount) as total +FROM sales +GROUP BY CUBE(channel, brand) +ORDER BY channel NULLS FIRST, brand NULLS FIRST; +---- +NULL NULL 4500 +NULL adidas 2400 +NULL nike 2100 +catalog NULL 2300 +catalog adidas 1200 +catalog nike 1100 +store NULL 700 +store adidas 400 +store nike 300 +web NULL 1500 +web adidas 800 +web nike 700 + +########## +# CLEANUP +########## + +statement ok +DROP TABLE sales; diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index 18f72cb9f779..b61ceecb24fc 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -260,6 +260,8 @@ datafusion.execution.parquet.statistics_enabled page datafusion.execution.parquet.statistics_truncate_length 64 datafusion.execution.parquet.write_batch_size 1024 datafusion.execution.parquet.writer_version 1.0 +datafusion.execution.perfect_hash_join_min_key_density 0.15 +datafusion.execution.perfect_hash_join_small_build_threshold 1024 datafusion.execution.planning_concurrency 13 datafusion.execution.skip_partial_aggregation_probe_ratio_threshold 0.8 datafusion.execution.skip_partial_aggregation_probe_rows_threshold 100000 @@ -295,6 +297,7 @@ datafusion.optimizer.enable_aggregate_dynamic_filter_pushdown true datafusion.optimizer.enable_distinct_aggregation_soft_limit true datafusion.optimizer.enable_dynamic_filter_pushdown true datafusion.optimizer.enable_join_dynamic_filter_pushdown true +datafusion.optimizer.enable_leaf_expression_pushdown true datafusion.optimizer.enable_piecewise_merge_join false datafusion.optimizer.enable_round_robin_repartition true datafusion.optimizer.enable_sort_pushdown true @@ -371,7 +374,7 @@ datafusion.execution.parquet.bloom_filter_on_read true (reading) Use any availab 
datafusion.execution.parquet.bloom_filter_on_write false (writing) Write bloom filters for all columns when creating parquet files datafusion.execution.parquet.coerce_int96 NULL (reading) If true, parquet reader will read columns of physical type int96 as originating from a different resolution than nanosecond. This is useful for reading data from systems like Spark which stores microsecond resolution timestamps in an int96 allowing it to write values with a larger date range than 64-bit timestamps with nanosecond resolution. datafusion.execution.parquet.column_index_truncate_length 64 (writing) Sets column index truncate length -datafusion.execution.parquet.compression zstd(3) (writing) Sets default parquet compression codec. Valid values are: uncompressed, snappy, gzip(level), lzo, brotli(level), lz4, zstd(level), and lz4_raw. These values are not case sensitive. If NULL, uses default parquet writer setting Note that this default setting is not the same as the default parquet writer setting. +datafusion.execution.parquet.compression zstd(3) (writing) Sets default parquet compression codec. Valid values are: uncompressed, snappy, gzip(level), brotli(level), lz4, zstd(level), and lz4_raw. These values are not case sensitive. If NULL, uses default parquet writer setting Note that this default setting is not the same as the default parquet writer setting. datafusion.execution.parquet.created_by datafusion (writing) Sets "created by" property datafusion.execution.parquet.data_page_row_count_limit 20000 (writing) Sets best effort maximum number of rows in data page datafusion.execution.parquet.data_pagesize_limit 1048576 (writing) Sets best effort maximum size of data page in bytes @@ -393,8 +396,10 @@ datafusion.execution.parquet.skip_arrow_metadata false (writing) Skip encoding t datafusion.execution.parquet.skip_metadata true (reading) If true, the parquet reader skip the optional embedded metadata that may be in the file Schema. 
This setting can help avoid schema conflicts when querying multiple parquet files with schemas containing compatible types but different metadata datafusion.execution.parquet.statistics_enabled page (writing) Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting datafusion.execution.parquet.statistics_truncate_length 64 (writing) Sets statistics truncate length. If NULL, uses default parquet writer setting -datafusion.execution.parquet.write_batch_size 1024 (writing) Sets write_batch_size in bytes +datafusion.execution.parquet.write_batch_size 1024 (writing) Sets write_batch_size in rows datafusion.execution.parquet.writer_version 1.0 (writing) Sets parquet writer version valid values are "1.0" and "2.0" +datafusion.execution.perfect_hash_join_min_key_density 0.15 The minimum required density of join keys on the build side to consider a perfect hash join (see `HashJoinExec` for more details). Density is calculated as: `(number of rows) / (max_key - min_key + 1)`. A perfect hash join may be used if the actual key density > this value. Currently only supports cases where build_side.num_rows() < u32::MAX. Support for build_side.num_rows() >= u32::MAX will be added in the future. +datafusion.execution.perfect_hash_join_small_build_threshold 1024 A perfect hash join (see `HashJoinExec` for more details) will be considered if the range of keys (max - min) on the build side is < this threshold. This provides a fast path for joins with very small key ranges, bypassing the density check. Currently only supports cases where build_side.num_rows() < u32::MAX. Support for build_side.num_rows() >= u32::MAX will be added in the future. datafusion.execution.planning_concurrency 13 Fan-out during initial physical planning. This is mostly use to plan `UNION` children in parallel. 
Defaults to the number of CPU cores on the system datafusion.execution.skip_partial_aggregation_probe_ratio_threshold 0.8 Aggregation ratio (number of distinct groups / number of input rows) threshold for skipping partial aggregation. If the value is greater then partial aggregation will skip aggregation for further input datafusion.execution.skip_partial_aggregation_probe_rows_threshold 100000 Number of input rows partial aggregation partition should process, before aggregation ratio check and trying to switch to skipping aggregation mode @@ -430,6 +435,7 @@ datafusion.optimizer.enable_aggregate_dynamic_filter_pushdown true When set to t datafusion.optimizer.enable_distinct_aggregation_soft_limit true When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. datafusion.optimizer.enable_dynamic_filter_pushdown true When set to true attempts to push down dynamic filters generated by operators (TopK, Join & Aggregate) into the file scan phase. For example, for a query such as `SELECT * FROM t ORDER BY timestamp DESC LIMIT 10`, the optimizer will attempt to push down the current top 10 timestamps that the TopK operator references into the file scans. This means that if we already have 10 timestamps in the year 2025 any files that only have timestamps in the year 2024 can be skipped / pruned at various stages in the scan. The config will suppress `enable_join_dynamic_filter_pushdown`, `enable_topk_dynamic_filter_pushdown` & `enable_aggregate_dynamic_filter_pushdown` So if you disable `enable_topk_dynamic_filter_pushdown`, then enable `enable_dynamic_filter_pushdown`, the `enable_topk_dynamic_filter_pushdown` will be overridden. datafusion.optimizer.enable_join_dynamic_filter_pushdown true When set to true, the optimizer will attempt to push down Join dynamic filters into the file scan phase. 
+datafusion.optimizer.enable_leaf_expression_pushdown true When set to true, the optimizer will extract leaf expressions (such as `get_field`) from filter/sort/join nodes into projections closer to the leaf table scans, and push those projections down towards the leaf nodes. datafusion.optimizer.enable_piecewise_merge_join false When set to true, piecewise merge join is enabled. PiecewiseMergeJoin is currently experimental. Physical planner will opt for PiecewiseMergeJoin when there is only one range filter. datafusion.optimizer.enable_round_robin_repartition true When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores datafusion.optimizer.enable_sort_pushdown true Enable sort pushdown optimization. When enabled, attempts to push sort requirements down to data sources that can natively handle them (e.g., by reversing file/row group read order). Returns **inexact ordering**: Sort operator is kept for correctness, but optimized input enables early termination for TopK queries (ORDER BY ... LIMIT N), providing significant speedup. Memory: No additional overhead (only changes read order). Future: Will add option to detect perfectly sorted data and eliminate Sort completely. Default: true @@ -793,14 +799,11 @@ string_agg String AGGREGATE query TTTTTTTBTTTT rowsort select * from information_schema.routines where routine_name = 'date_trunc' OR routine_name = 'string_agg' OR routine_name = 'rank' ORDER BY routine_name ---- -datafusion public date_trunc datafusion public date_trunc FUNCTION true Timestamp(Microsecond, None) SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -datafusion public date_trunc datafusion public date_trunc FUNCTION true Timestamp(Microsecond, Some("+TZ")) SCALAR Truncates a timestamp value to a specified precision. 
date_trunc(precision, expression) -datafusion public date_trunc datafusion public date_trunc FUNCTION true Timestamp(Millisecond, None) SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -datafusion public date_trunc datafusion public date_trunc FUNCTION true Timestamp(Millisecond, Some("+TZ")) SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -datafusion public date_trunc datafusion public date_trunc FUNCTION true Timestamp(Nanosecond, None) SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -datafusion public date_trunc datafusion public date_trunc FUNCTION true Timestamp(Nanosecond, Some("+TZ")) SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -datafusion public date_trunc datafusion public date_trunc FUNCTION true Timestamp(Second, None) SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -datafusion public date_trunc datafusion public date_trunc FUNCTION true Timestamp(Second, Some("+TZ")) SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) +datafusion public date_trunc datafusion public date_trunc FUNCTION true Date SCALAR Truncates a timestamp or time value to a specified precision. date_trunc(precision, expression) +datafusion public date_trunc datafusion public date_trunc FUNCTION true String SCALAR Truncates a timestamp or time value to a specified precision. date_trunc(precision, expression) +datafusion public date_trunc datafusion public date_trunc FUNCTION true Time(ns) SCALAR Truncates a timestamp or time value to a specified precision. date_trunc(precision, expression) +datafusion public date_trunc datafusion public date_trunc FUNCTION true Timestamp(ns) SCALAR Truncates a timestamp or time value to a specified precision. 
date_trunc(precision, expression) +datafusion public date_trunc datafusion public date_trunc FUNCTION true Timestamp(ns, "+TZ") SCALAR Truncates a timestamp or time value to a specified precision. date_trunc(precision, expression) datafusion public rank datafusion public rank FUNCTION true NULL WINDOW Returns the rank of the current row within its partition, allowing gaps between ranks. This function provides a ranking similar to `row_number`, but skips ranks for identical values. rank() datafusion public string_agg datafusion public string_agg FUNCTION true String AGGREGATE Concatenates the values of string expressions and places separator values between them. If ordering is required, strings are concatenated in the specified order. This aggregation function can only mix DISTINCT and ORDER BY if the ordering expression is exactly the same as the first argument expression. string_agg([DISTINCT] expression, delimiter [ORDER BY expression]) @@ -813,30 +816,21 @@ false query TTTITTTTBI select * from information_schema.parameters where specific_name = 'date_trunc' OR specific_name = 'string_agg' OR specific_name = 'rank' ORDER BY specific_name, rid, data_type; ---- +datafusion public date_trunc 1 OUT NULL Date NULL false 0 +datafusion public date_trunc 2 IN expression Date NULL false 0 datafusion public date_trunc 1 IN precision String NULL false 0 -datafusion public date_trunc 2 IN expression Timestamp(Microsecond, None) NULL false 0 -datafusion public date_trunc 1 OUT NULL Timestamp(Microsecond, None) NULL false 0 datafusion public date_trunc 1 IN precision String NULL false 1 -datafusion public date_trunc 2 IN expression Timestamp(Microsecond, Some("+TZ")) NULL false 1 -datafusion public date_trunc 1 OUT NULL Timestamp(Microsecond, Some("+TZ")) NULL false 1 +datafusion public date_trunc 2 IN expression String NULL false 1 +datafusion public date_trunc 1 OUT NULL String NULL false 1 datafusion public date_trunc 1 IN precision String NULL false 2 -datafusion public 
date_trunc 2 IN expression Timestamp(Millisecond, None) NULL false 2 -datafusion public date_trunc 1 OUT NULL Timestamp(Millisecond, None) NULL false 2 +datafusion public date_trunc 2 IN expression Time(ns) NULL false 2 +datafusion public date_trunc 1 OUT NULL Time(ns) NULL false 2 datafusion public date_trunc 1 IN precision String NULL false 3 -datafusion public date_trunc 2 IN expression Timestamp(Millisecond, Some("+TZ")) NULL false 3 -datafusion public date_trunc 1 OUT NULL Timestamp(Millisecond, Some("+TZ")) NULL false 3 +datafusion public date_trunc 2 IN expression Timestamp(ns) NULL false 3 +datafusion public date_trunc 1 OUT NULL Timestamp(ns) NULL false 3 datafusion public date_trunc 1 IN precision String NULL false 4 -datafusion public date_trunc 2 IN expression Timestamp(Nanosecond, None) NULL false 4 -datafusion public date_trunc 1 OUT NULL Timestamp(Nanosecond, None) NULL false 4 -datafusion public date_trunc 1 IN precision String NULL false 5 -datafusion public date_trunc 2 IN expression Timestamp(Nanosecond, Some("+TZ")) NULL false 5 -datafusion public date_trunc 1 OUT NULL Timestamp(Nanosecond, Some("+TZ")) NULL false 5 -datafusion public date_trunc 1 IN precision String NULL false 6 -datafusion public date_trunc 2 IN expression Timestamp(Second, None) NULL false 6 -datafusion public date_trunc 1 OUT NULL Timestamp(Second, None) NULL false 6 -datafusion public date_trunc 1 IN precision String NULL false 7 -datafusion public date_trunc 2 IN expression Timestamp(Second, Some("+TZ")) NULL false 7 -datafusion public date_trunc 1 OUT NULL Timestamp(Second, Some("+TZ")) NULL false 7 +datafusion public date_trunc 2 IN expression Timestamp(ns, "+TZ") NULL false 4 +datafusion public date_trunc 1 OUT NULL Timestamp(ns, "+TZ") NULL false 4 datafusion public string_agg 2 IN delimiter Null NULL false 0 datafusion public string_agg 1 IN expression String NULL false 0 datafusion public string_agg 1 OUT NULL String NULL false 0 @@ -862,14 +856,11 @@ repeat String 1 
OUT 0 query TT??TTT rowsort show functions like 'date_trunc'; ---- -date_trunc Timestamp(Microsecond, None) [precision, expression] [String, Timestamp(Microsecond, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Microsecond, Some("+TZ")) [precision, expression] [String, Timestamp(Microsecond, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Millisecond, None) [precision, expression] [String, Timestamp(Millisecond, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Millisecond, Some("+TZ")) [precision, expression] [String, Timestamp(Millisecond, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Nanosecond, None) [precision, expression] [String, Timestamp(Nanosecond, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Nanosecond, Some("+TZ")) [precision, expression] [String, Timestamp(Nanosecond, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Second, None) [precision, expression] [String, Timestamp(Second, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Second, Some("+TZ")) [precision, expression] [String, Timestamp(Second, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) +date_trunc Date [precision, expression] [String, Date] SCALAR Truncates a timestamp or time value to a specified precision. date_trunc(precision, expression) +date_trunc String [precision, expression] [String, String] SCALAR Truncates a timestamp or time value to a specified precision. 
date_trunc(precision, expression) +date_trunc Time(ns) [precision, expression] [String, Time(ns)] SCALAR Truncates a timestamp or time value to a specified precision. date_trunc(precision, expression) +date_trunc Timestamp(ns) [precision, expression] [String, Timestamp(ns)] SCALAR Truncates a timestamp or time value to a specified precision. date_trunc(precision, expression) +date_trunc Timestamp(ns, "+TZ") [precision, expression] [String, Timestamp(ns, "+TZ")] SCALAR Truncates a timestamp or time value to a specified precision. date_trunc(precision, expression) statement ok show functions diff --git a/datafusion/sqllogictest/test_files/insert.slt b/datafusion/sqllogictest/test_files/insert.slt index 8ef2596f18e3..e7b9e77dfef5 100644 --- a/datafusion/sqllogictest/test_files/insert.slt +++ b/datafusion/sqllogictest/test_files/insert.slt @@ -165,7 +165,7 @@ ORDER BY c1 ---- logical_plan 01)Dml: op=[Insert Into] table=[table_without_values] -02)--Projection: a1 AS a1, a2 AS a2 +02)--Projection: a1, a2 03)----Sort: aggregate_test_100.c1 ASC NULLS LAST 04)------Projection: sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS a1, count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS a2, aggregate_test_100.c1 05)--------WindowAggr: windowExpr=[[sum(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] diff --git a/datafusion/sqllogictest/test_files/join.slt.part 
b/datafusion/sqllogictest/test_files/join.slt.part index 5d111374ac8c..c0a838c97d55 100644 --- a/datafusion/sqllogictest/test_files/join.slt.part +++ b/datafusion/sqllogictest/test_files/join.slt.part @@ -973,19 +973,19 @@ ON e.emp_id = d.emp_id WHERE ((dept_name != 'Engineering' AND e.name = 'Alice') OR (name != 'Alice' AND e.name = 'Carol')); ---- logical_plan -01)Filter: d.dept_name != Utf8View("Engineering") AND e.name = Utf8View("Alice") OR e.name != Utf8View("Alice") AND e.name = Utf8View("Carol") +01)Filter: d.dept_name != Utf8View("Engineering") AND e.name = Utf8View("Alice") OR e.name = Utf8View("Carol") 02)--Projection: e.emp_id, e.name, d.dept_name 03)----Left Join: e.emp_id = d.emp_id 04)------SubqueryAlias: e -05)--------Filter: employees.name = Utf8View("Alice") OR employees.name != Utf8View("Alice") AND employees.name = Utf8View("Carol") +05)--------Filter: employees.name = Utf8View("Alice") OR employees.name = Utf8View("Carol") 06)----------TableScan: employees projection=[emp_id, name] 07)------SubqueryAlias: d 08)--------TableScan: department projection=[emp_id, dept_name] physical_plan -01)FilterExec: dept_name@2 != Engineering AND name@1 = Alice OR name@1 != Alice AND name@1 = Carol +01)FilterExec: dept_name@2 != Engineering AND name@1 = Alice OR name@1 = Carol 02)--RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 03)----HashJoinExec: mode=CollectLeft, join_type=Left, on=[(emp_id@0, emp_id@0)], projection=[emp_id@0, name@1, dept_name@3] -04)------FilterExec: name@1 = Alice OR name@1 != Alice AND name@1 = Carol +04)------FilterExec: name@1 = Alice OR name@1 = Carol 05)--------DataSourceExec: partitions=1, partition_sizes=[1] 06)------DataSourceExec: partitions=1, partition_sizes=[1] diff --git a/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt b/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt index c16b3528aa7a..59f3d8285af4 100644 --- 
a/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt +++ b/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt @@ -55,7 +55,7 @@ logical_plan 07)--------TableScan: annotated_data projection=[a, c] physical_plan 01)SortPreservingMergeExec: [a@0 ASC NULLS LAST], fetch=5 -02)--HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(c@0, c@1)], projection=[a@1] +02)--HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(c@0, c@1)], projection=[a@1], fetch=5 03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c], file_type=csv, has_header=true 04)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1, maintains_sort_order=true 05)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c], output_ordering=[a@0 ASC NULLS LAST], file_type=csv, has_header=true @@ -88,18 +88,22 @@ logical_plan 02)--Projection: t2.a AS a2, t2.b 03)----RightSemi Join: t1.d = t2.d, t1.c = t2.c 04)------SubqueryAlias: t1 -05)--------TableScan: annotated_data projection=[c, d] -06)------SubqueryAlias: t2 -07)--------Filter: annotated_data.d = Int32(3) -08)----------TableScan: annotated_data projection=[a, b, c, d], partial_filters=[annotated_data.d = Int32(3)] +05)--------Filter: annotated_data.d = Int32(3) +06)----------TableScan: annotated_data projection=[c, d], partial_filters=[annotated_data.d = Int32(3)] +07)------SubqueryAlias: t2 +08)--------Filter: annotated_data.d = Int32(3) +09)----------TableScan: annotated_data projection=[a, b, c, d], partial_filters=[annotated_data.d = Int32(3)] physical_plan 01)SortPreservingMergeExec: [a2@0 ASC NULLS LAST, b@1 ASC NULLS LAST], fetch=10 02)--ProjectionExec: expr=[a@0 as a2, b@1 as b] -03)----HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(d@1, d@3), (c@0, c@2)], projection=[a@0, b@1] -04)------DataSourceExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c, d], file_type=csv, has_header=true -05)------FilterExec: d@3 = 3 -06)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1, maintains_sort_order=true -07)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], file_type=csv, has_header=true +03)----HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(d@1, d@3), (c@0, c@2)], projection=[a@0, b@1], fetch=10 +04)------CoalescePartitionsExec +05)--------FilterExec: d@1 = 3 +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c, d], file_type=csv, has_header=true +08)------FilterExec: d@3 = 3 +09)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1, maintains_sort_order=true +10)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], file_type=csv, has_header=true # preserve_right_semi_join query II nosort diff --git a/datafusion/sqllogictest/test_files/join_is_not_distinct_from.slt b/datafusion/sqllogictest/test_files/join_is_not_distinct_from.slt index 8246f489c446..2bab89c99eae 100644 --- a/datafusion/sqllogictest/test_files/join_is_not_distinct_from.slt +++ b/datafusion/sqllogictest/test_files/join_is_not_distinct_from.slt @@ -291,6 +291,36 @@ JOIN t4 ON (t3.val1 IS NOT DISTINCT FROM t4.val1) AND (t3.val2 IS NOT DISTINCT F 2 2 NULL NULL 200 200 3 3 30 30 NULL NULL +# Test mixed: 1 Eq key + multiple IS NOT DISTINCT FROM keys. 
+# The optimizer unconditionally favours Eq keys (see extract_equijoin_predicate.rs, +# "Only convert when there are NO equijoin predicates, to be conservative"). +# All IS NOT DISTINCT FROM predicates should be demoted to filter, even when they outnumber the Eq key. +query TT +EXPLAIN SELECT t3.id AS t3_id, t4.id AS t4_id, t3.val1, t4.val1, t3.val2, t4.val2 +FROM t3 +JOIN t4 ON (t3.id = t4.id) AND (t3.val1 IS NOT DISTINCT FROM t4.val1) AND (t3.val2 IS NOT DISTINCT FROM t4.val2) +---- +logical_plan +01)Projection: t3.id AS t3_id, t4.id AS t4_id, t3.val1, t4.val1, t3.val2, t4.val2 +02)--Inner Join: t3.id = t4.id Filter: t3.val1 IS NOT DISTINCT FROM t4.val1 AND t3.val2 IS NOT DISTINCT FROM t4.val2 +03)----TableScan: t3 projection=[id, val1, val2] +04)----TableScan: t4 projection=[id, val1, val2] +physical_plan +01)ProjectionExec: expr=[id@0 as t3_id, id@3 as t4_id, val1@1 as val1, val1@4 as val1, val2@2 as val2, val2@5 as val2] +02)--HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(id@0, id@0)], filter=val1@0 IS NOT DISTINCT FROM val1@2 AND val2@1 IS NOT DISTINCT FROM val2@3 +03)----DataSourceExec: partitions=1, partition_sizes=[1] +04)----DataSourceExec: partitions=1, partition_sizes=[1] + +# Verify correct results: all 3 rows should match (including NULL=NULL via IS NOT DISTINCT FROM in filter) +query IIIIII rowsort +SELECT t3.id AS t3_id, t4.id AS t4_id, t3.val1, t4.val1, t3.val2, t4.val2 +FROM t3 +JOIN t4 ON (t3.id = t4.id) AND (t3.val1 IS NOT DISTINCT FROM t4.val1) AND (t3.val2 IS NOT DISTINCT FROM t4.val2) +---- +1 1 10 10 100 100 +2 2 NULL NULL 200 200 +3 3 30 30 NULL NULL + statement ok drop table t0; diff --git a/datafusion/sqllogictest/test_files/join_limit_pushdown.slt b/datafusion/sqllogictest/test_files/join_limit_pushdown.slt new file mode 100644 index 000000000000..6bb23c1b4c24 --- /dev/null +++ b/datafusion/sqllogictest/test_files/join_limit_pushdown.slt @@ -0,0 +1,269 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more 
contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Tests for limit pushdown into joins + +# need to use a single partition for deterministic results +statement ok +set datafusion.execution.target_partitions = 1; + +statement ok +set datafusion.explain.logical_plan_only = false; + +statement ok +set datafusion.optimizer.prefer_hash_join = true; + +# Create test tables +statement ok +CREATE TABLE t1 (a INT, b VARCHAR) AS VALUES + (1, 'one'), + (2, 'two'), + (3, 'three'), + (4, 'four'), + (5, 'five'); + +statement ok +CREATE TABLE t2 (x INT, y VARCHAR) AS VALUES + (1, 'alpha'), + (2, 'beta'), + (3, 'gamma'), + (6, 'delta'), + (7, 'epsilon'); + +query TT +EXPLAIN SELECT t1.a, t2.x FROM t1 INNER JOIN t2 ON t1.a = t2.x LIMIT 2; +---- +logical_plan +01)Limit: skip=0, fetch=2 +02)--Inner Join: t1.a = t2.x +03)----TableScan: t1 projection=[a] +04)----TableScan: t2 projection=[x] +physical_plan +01)HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, x@0)], fetch=2 +02)--DataSourceExec: partitions=1, partition_sizes=[1] +03)--DataSourceExec: partitions=1, partition_sizes=[1] + +query II +SELECT t1.a, t2.x FROM t1 INNER JOIN t2 ON t1.a = t2.x LIMIT 2; +---- +1 1 +2 2 + +# Right join is converted to Left join with projection - fetch pushdown is supported +query TT +EXPLAIN SELECT t1.a, t2.x FROM t1 RIGHT 
JOIN t2 ON t1.a = t2.x LIMIT 3; +---- +logical_plan +01)Limit: skip=0, fetch=3 +02)--Right Join: t1.a = t2.x +03)----TableScan: t1 projection=[a] +04)----Limit: skip=0, fetch=3 +05)------TableScan: t2 projection=[x], fetch=3 +physical_plan +01)ProjectionExec: expr=[a@1 as a, x@0 as x] +02)--HashJoinExec: mode=CollectLeft, join_type=Left, on=[(x@0, a@0)], fetch=3 +03)----DataSourceExec: partitions=1, partition_sizes=[1], fetch=3 +04)----DataSourceExec: partitions=1, partition_sizes=[1] + +query II +SELECT t1.a, t2.x FROM t1 RIGHT JOIN t2 ON t1.a = t2.x LIMIT 3; +---- +1 1 +2 2 +3 3 + +# Left join supports fetch pushdown +query TT +EXPLAIN SELECT t1.a, t2.x FROM t1 LEFT JOIN t2 ON t1.a = t2.x LIMIT 3; +---- +logical_plan +01)Limit: skip=0, fetch=3 +02)--Left Join: t1.a = t2.x +03)----Limit: skip=0, fetch=3 +04)------TableScan: t1 projection=[a], fetch=3 +05)----TableScan: t2 projection=[x] +physical_plan +01)HashJoinExec: mode=CollectLeft, join_type=Left, on=[(a@0, x@0)], fetch=3 +02)--DataSourceExec: partitions=1, partition_sizes=[1], fetch=3 +03)--DataSourceExec: partitions=1, partition_sizes=[1] + +query II +SELECT t1.a, t2.x FROM t1 LEFT JOIN t2 ON t1.a = t2.x LIMIT 3; +---- +1 1 +2 2 +3 3 + + +# Full join supports fetch pushdown +query TT +EXPLAIN SELECT t1.a, t2.x FROM t1 FULL OUTER JOIN t2 ON t1.a = t2.x LIMIT 4; +---- +logical_plan +01)Limit: skip=0, fetch=4 +02)--Full Join: t1.a = t2.x +03)----TableScan: t1 projection=[a] +04)----TableScan: t2 projection=[x] +physical_plan +01)HashJoinExec: mode=CollectLeft, join_type=Full, on=[(a@0, x@0)], fetch=4 +02)--DataSourceExec: partitions=1, partition_sizes=[1] +03)--DataSourceExec: partitions=1, partition_sizes=[1] + +# Note: FULL OUTER JOIN order is not deterministic, so we just check count +query I +SELECT COUNT(*) FROM (SELECT t1.a, t2.x FROM t1 FULL OUTER JOIN t2 ON t1.a = t2.x LIMIT 4); +---- +4 + +# EXISTS becomes left semi join - fetch pushdown is supported +query TT +EXPLAIN SELECT t2.x FROM t2 WHERE EXISTS 
(SELECT 1 FROM t1 WHERE t1.a = t2.x) LIMIT 2; +---- +logical_plan +01)Limit: skip=0, fetch=2 +02)--LeftSemi Join: t2.x = __correlated_sq_1.a +03)----TableScan: t2 projection=[x] +04)----SubqueryAlias: __correlated_sq_1 +05)------TableScan: t1 projection=[a] +physical_plan +01)HashJoinExec: mode=CollectLeft, join_type=LeftSemi, on=[(x@0, a@0)], fetch=2 +02)--DataSourceExec: partitions=1, partition_sizes=[1] +03)--DataSourceExec: partitions=1, partition_sizes=[1] + +query I +SELECT t2.x FROM t2 WHERE EXISTS (SELECT 1 FROM t1 WHERE t1.a = t2.x) LIMIT 2; +---- +1 +2 + +# NOT EXISTS becomes LeftAnti - fetch pushdown is supported +query TT +EXPLAIN SELECT t2.x FROM t2 WHERE NOT EXISTS (SELECT 1 FROM t1 WHERE t1.a = t2.x) LIMIT 1; +---- +logical_plan +01)Limit: skip=0, fetch=1 +02)--LeftAnti Join: t2.x = __correlated_sq_1.a +03)----TableScan: t2 projection=[x] +04)----SubqueryAlias: __correlated_sq_1 +05)------TableScan: t1 projection=[a] +physical_plan +01)HashJoinExec: mode=CollectLeft, join_type=LeftAnti, on=[(x@0, a@0)], fetch=1 +02)--DataSourceExec: partitions=1, partition_sizes=[1] +03)--DataSourceExec: partitions=1, partition_sizes=[1] + +query I +SELECT t2.x FROM t2 WHERE NOT EXISTS (SELECT 1 FROM t1 WHERE t1.a = t2.x) LIMIT 1; +---- +6 + +# Inner join should push +query TT +EXPLAIN SELECT t1.a, t2.x FROM t1 INNER JOIN t2 ON t1.a = t2.x LIMIT 1 OFFSET 1; +---- +logical_plan +01)Limit: skip=1, fetch=1 +02)--Inner Join: t1.a = t2.x +03)----TableScan: t1 projection=[a] +04)----TableScan: t2 projection=[x] +physical_plan +01)GlobalLimitExec: skip=1, fetch=1 +02)--HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, x@0)], fetch=2 +03)----DataSourceExec: partitions=1, partition_sizes=[1] +04)----DataSourceExec: partitions=1, partition_sizes=[1] + +query II +SELECT t1.a, t2.x FROM t1 INNER JOIN t2 ON t1.a = t2.x LIMIT 1 OFFSET 1; +---- +2 2 + +query TT +EXPLAIN SELECT t1.a, t2.x FROM t1 INNER JOIN t2 ON t1.a = t2.x LIMIT 0; +---- +logical_plan EmptyRelation: 
rows=0 +physical_plan EmptyExec + +query II +SELECT t1.a, t2.x FROM t1 INNER JOIN t2 ON t1.a = t2.x LIMIT 0; +---- + +statement ok +CREATE TABLE t3 (p INT, q VARCHAR) AS VALUES + (1, 'foo'), + (2, 'bar'), + (3, 'baz'); + +query TT +EXPLAIN SELECT t1.a, t2.x, t3.p +FROM t1 +INNER JOIN t2 ON t1.a = t2.x +INNER JOIN t3 ON t2.x = t3.p +LIMIT 2; +---- +logical_plan +01)Limit: skip=0, fetch=2 +02)--Inner Join: t2.x = t3.p +03)----Inner Join: t1.a = t2.x +04)------TableScan: t1 projection=[a] +05)------TableScan: t2 projection=[x] +06)----TableScan: t3 projection=[p] +physical_plan +01)ProjectionExec: expr=[a@1 as a, x@2 as x, p@0 as p] +02)--HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(p@0, x@1)], fetch=2 +03)----DataSourceExec: partitions=1, partition_sizes=[1] +04)----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, x@0)] +05)------DataSourceExec: partitions=1, partition_sizes=[1] +06)------DataSourceExec: partitions=1, partition_sizes=[1] + +query III +SELECT t1.a, t2.x, t3.p +FROM t1 +INNER JOIN t2 ON t1.a = t2.x +INNER JOIN t3 ON t2.x = t3.p +LIMIT 2; +---- +1 1 1 +2 2 2 + +# Try larger limit +query TT +EXPLAIN SELECT t1.a, t2.x FROM t1 INNER JOIN t2 ON t1.a = t2.x LIMIT 100; +---- +logical_plan +01)Limit: skip=0, fetch=100 +02)--Inner Join: t1.a = t2.x +03)----TableScan: t1 projection=[a] +04)----TableScan: t2 projection=[x] +physical_plan +01)HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, x@0)], fetch=100 +02)--DataSourceExec: partitions=1, partition_sizes=[1] +03)--DataSourceExec: partitions=1, partition_sizes=[1] + +query II +SELECT t1.a, t2.x FROM t1 INNER JOIN t2 ON t1.a = t2.x LIMIT 100; +---- +1 1 +2 2 +3 3 + +statement ok +DROP TABLE t1; + +statement ok +DROP TABLE t2; + +statement ok +DROP TABLE t3; diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index 38037ede21db..228918c3855f 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ 
b/datafusion/sqllogictest/test_files/joins.slt @@ -57,15 +57,15 @@ statement ok CREATE TABLE join_t3(s3 struct) AS VALUES (NULL), - (struct(1)), - (struct(2)); + ({id: 1}), + ({id: 2}); statement ok CREATE TABLE join_t4(s4 struct) AS VALUES (NULL), - (struct(2)), - (struct(3)); + ({id: 2}), + ({id: 3}); # Left semi anti join @@ -146,10 +146,10 @@ AS VALUES statement ok CREATE TABLE test_timestamps_table as SELECT - arrow_cast(ts::timestamp::bigint, 'Timestamp(Nanosecond, None)') as nanos, - arrow_cast(ts::timestamp::bigint / 1000, 'Timestamp(Microsecond, None)') as micros, - arrow_cast(ts::timestamp::bigint / 1000000, 'Timestamp(Millisecond, None)') as millis, - arrow_cast(ts::timestamp::bigint / 1000000000, 'Timestamp(Second, None)') as secs, + arrow_cast(ts::timestamp::bigint, 'Timestamp(ns)') as nanos, + arrow_cast(ts::timestamp::bigint / 1000, 'Timestamp(µs)') as micros, + arrow_cast(ts::timestamp::bigint / 1000000, 'Timestamp(ms)') as millis, + arrow_cast(ts::timestamp::bigint / 1000000000, 'Timestamp(s)') as secs, names FROM test_timestamps_table_source; @@ -2085,7 +2085,7 @@ SELECT join_t1.t1_id, join_t2.t2_id FROM (select t1_id from join_t1 where join_t1.t1_id > 22) as join_t1 RIGHT JOIN (select t2_id from join_t2 where join_t2.t2_id > 11) as join_t2 ON join_t1.t1_id < join_t2.t2_id -ORDER BY 1, 2 +ORDER BY 1, 2 ---- 33 44 33 55 @@ -3516,7 +3516,6 @@ AS VALUES query IT SELECT t1_id, t1_name FROM join_test_left WHERE t1_id NOT IN (SELECT t2_id FROM join_test_right) ORDER BY t1_id; ---- -NULL e #### # join_partitioned_test @@ -3955,7 +3954,7 @@ query TT explain select t1_id, t1_name, i from join_t1 t1 cross join lateral (select * from unnest(generate_series(1, t1_int))) as series(i); ---- logical_plan -01)Cross Join: +01)Cross Join: 02)--SubqueryAlias: t1 03)----TableScan: join_t1 projection=[t1_id, t1_name] 04)--SubqueryAlias: series @@ -4162,10 +4161,9 @@ logical_plan 03)----TableScan: t0 projection=[c1, c2] 04)----TableScan: t1 projection=[c1, c2, c3] 
physical_plan -01)GlobalLimitExec: skip=0, fetch=2 -02)--HashJoinExec: mode=CollectLeft, join_type=Full, on=[(c1@0, c1@0)] -03)----DataSourceExec: partitions=1, partition_sizes=[2] -04)----DataSourceExec: partitions=1, partition_sizes=[2] +01)HashJoinExec: mode=CollectLeft, join_type=Full, on=[(c1@0, c1@0)], fetch=2 +02)--DataSourceExec: partitions=1, partition_sizes=[2] +03)--DataSourceExec: partitions=1, partition_sizes=[2] ## Test join.on.is_empty() && join.filter.is_some() -> single filter now a PWMJ query TT @@ -4192,10 +4190,9 @@ logical_plan 03)----TableScan: t0 projection=[c1, c2] 04)----TableScan: t1 projection=[c1, c2, c3] physical_plan -01)GlobalLimitExec: skip=0, fetch=2 -02)--HashJoinExec: mode=CollectLeft, join_type=Full, on=[(c1@0, c1@0)], filter=c2@0 >= c2@1 -03)----DataSourceExec: partitions=1, partition_sizes=[2] -04)----DataSourceExec: partitions=1, partition_sizes=[2] +01)HashJoinExec: mode=CollectLeft, join_type=Full, on=[(c1@0, c1@0)], filter=c2@0 >= c2@1, fetch=2 +02)--DataSourceExec: partitions=1, partition_sizes=[2] +03)--DataSourceExec: partitions=1, partition_sizes=[2] ## Add more test cases for join limit pushdown statement ok @@ -4246,6 +4243,7 @@ select * from t1 LEFT JOIN t2 ON t1.a = t2.b LIMIT 2; 1 1 # can only push down to t1 (preserved side) +# limit pushdown supported for left join - both to join and probe side query TT explain select * from t1 LEFT JOIN t2 ON t1.a = t2.b LIMIT 2; ---- @@ -4256,10 +4254,9 @@ logical_plan 04)------TableScan: t1 projection=[a], fetch=2 05)----TableScan: t2 projection=[b] physical_plan -01)GlobalLimitExec: skip=0, fetch=2 -02)--HashJoinExec: mode=CollectLeft, join_type=Left, on=[(a@0, b@0)] -03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/joins/t1.csv]]}, projection=[a], limit=2, file_type=csv, has_header=true -04)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/joins/t2.csv]]}, 
projection=[b], file_type=csv, has_header=true +01)HashJoinExec: mode=CollectLeft, join_type=Left, on=[(a@0, b@0)], fetch=2 +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/joins/t1.csv]]}, projection=[a], limit=2, file_type=csv, has_header=true +03)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/joins/t2.csv]]}, projection=[b], file_type=csv, has_header=true ###### ## RIGHT JOIN w/ LIMIT @@ -4290,10 +4287,9 @@ logical_plan 04)----Limit: skip=0, fetch=2 05)------TableScan: t2 projection=[b], fetch=2 physical_plan -01)GlobalLimitExec: skip=0, fetch=2 -02)--HashJoinExec: mode=CollectLeft, join_type=Right, on=[(a@0, b@0)] -03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/joins/t1.csv]]}, projection=[a], file_type=csv, has_header=true -04)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/joins/t2.csv]]}, projection=[b], limit=2, file_type=csv, has_header=true +01)HashJoinExec: mode=CollectLeft, join_type=Right, on=[(a@0, b@0)], fetch=2 +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/joins/t1.csv]]}, projection=[a], file_type=csv, has_header=true +03)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/joins/t2.csv]]}, projection=[b], limit=2, file_type=csv, has_header=true ###### ## FULL JOIN w/ LIMIT @@ -4317,7 +4313,7 @@ select * from t1 FULL JOIN t2 ON t1.a = t2.b LIMIT 2; 4 4 -# can't push limit for full outer join +# full outer join supports fetch pushdown query TT explain select * from t1 FULL JOIN t2 ON t1.a = t2.b LIMIT 2; ---- @@ -4327,10 +4323,9 @@ logical_plan 03)----TableScan: t1 projection=[a] 04)----TableScan: t2 projection=[b] physical_plan -01)GlobalLimitExec: skip=0, fetch=2 -02)--HashJoinExec: mode=CollectLeft, join_type=Full, on=[(a@0, 
b@0)] -03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/joins/t1.csv]]}, projection=[a], file_type=csv, has_header=true -04)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/joins/t2.csv]]}, projection=[b], file_type=csv, has_header=true +01)HashJoinExec: mode=CollectLeft, join_type=Full, on=[(a@0, b@0)], fetch=2 +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/joins/t1.csv]]}, projection=[a], file_type=csv, has_header=true +03)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/joins/t2.csv]]}, projection=[b], file_type=csv, has_header=true statement ok drop table t1; @@ -4368,10 +4363,9 @@ logical_plan physical_plan 01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)] 02)--AggregateExec: mode=Single, gby=[], aggr=[count(Int64(1))] -03)----ProjectionExec: expr=[] -04)------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(binary_col@0, binary_col@0)] -05)--------DataSourceExec: partitions=1, partition_sizes=[1] -06)--------DataSourceExec: partitions=1, partition_sizes=[1] +03)----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(binary_col@0, binary_col@0)], projection=[] +04)------DataSourceExec: partitions=1, partition_sizes=[1] +05)------DataSourceExec: partitions=1, partition_sizes=[1] # Test hash join sort push down # Issue: https://github.com/apache/datafusion/issues/13559 @@ -4533,7 +4527,7 @@ query TT explain SELECT * FROM person a NATURAL JOIN lineitem b; ---- logical_plan -01)Cross Join: +01)Cross Join: 02)--SubqueryAlias: a 03)----TableScan: person projection=[id, age, state] 04)--SubqueryAlias: b @@ -4579,7 +4573,7 @@ query TT explain SELECT j1_string, j2_string FROM j1, LATERAL (SELECT * FROM j2 WHERE j1_id < j2_id) AS j2; ---- logical_plan -01)Cross Join: +01)Cross Join: 02)--TableScan: j1 projection=[j1_string] 03)--SubqueryAlias: 
j2 04)----Projection: j2.j2_string @@ -4592,7 +4586,7 @@ query TT explain SELECT * FROM j1 JOIN (j2 JOIN j3 ON(j2_id = j3_id - 2)) ON(j1_id = j2_id), LATERAL (SELECT * FROM j3 WHERE j3_string = j2_string) as j4 ---- logical_plan -01)Cross Join: +01)Cross Join: 02)--Inner Join: CAST(j2.j2_id AS Int64) = CAST(j3.j3_id AS Int64) - Int64(2) 03)----Inner Join: j1.j1_id = j2.j2_id 04)------TableScan: j1 projection=[j1_string, j1_id] @@ -4608,11 +4602,11 @@ query TT explain SELECT * FROM j1, LATERAL (SELECT * FROM j1, LATERAL (SELECT * FROM j2 WHERE j1_id = j2_id) as j2) as j2; ---- logical_plan -01)Cross Join: +01)Cross Join: 02)--TableScan: j1 projection=[j1_string, j1_id] 03)--SubqueryAlias: j2 04)----Subquery: -05)------Cross Join: +05)------Cross Join: 06)--------TableScan: j1 projection=[j1_string, j1_id] 07)--------SubqueryAlias: j2 08)----------Subquery: @@ -4624,7 +4618,7 @@ query TT explain SELECT j1_string, j2_string FROM j1 LEFT JOIN LATERAL (SELECT * FROM j2 WHERE j1_id < j2_id) AS j2 ON(true); ---- logical_plan -01)Left Join: +01)Left Join: 02)--TableScan: j1 projection=[j1_string] 03)--SubqueryAlias: j2 04)----Projection: j2.j2_string @@ -4637,9 +4631,9 @@ query TT explain SELECT * FROM j1, (j2 LEFT JOIN LATERAL (SELECT * FROM j3 WHERE j1_id + j2_id = j3_id) AS j3 ON(true)); ---- logical_plan -01)Cross Join: +01)Cross Join: 02)--TableScan: j1 projection=[j1_string, j1_id] -03)--Left Join: +03)--Left Join: 04)----TableScan: j2 projection=[j2_string, j2_id] 05)----SubqueryAlias: j3 06)------Subquery: @@ -4651,7 +4645,7 @@ query TT explain SELECT * FROM j1, LATERAL (SELECT 1) AS j2; ---- logical_plan -01)Cross Join: +01)Cross Join: 02)--TableScan: j1 projection=[j1_string, j1_id] 03)--SubqueryAlias: j2 04)----Projection: Int64(1) @@ -4993,7 +4987,7 @@ FULL JOIN t2 ON k1 = k2 # LEFT MARK JOIN query TT -EXPLAIN +EXPLAIN SELECT * FROM t2 WHERE k2 > 0 @@ -5050,9 +5044,10 @@ WHERE k1 < 0 ---- physical_plan 01)HashJoinExec: mode=CollectLeft, join_type=RightAnti, 
on=[(k2@0, k1@0)] -02)--DataSourceExec: partitions=1, partition_sizes=[0] -03)--FilterExec: k1@0 < 0 -04)----DataSourceExec: partitions=1, partition_sizes=[10000] +02)--FilterExec: k2@0 < 0 +03)----DataSourceExec: partitions=1, partition_sizes=[0] +04)--FilterExec: k1@0 < 0 +05)----DataSourceExec: partitions=1, partition_sizes=[10000] query II SELECT * @@ -5067,14 +5062,14 @@ CREATE OR REPLACE TABLE t1(b INT, c INT, d INT); statement ok INSERT INTO t1 VALUES - (10, 5, 3), - ( 1, 7, 8), - ( 2, 9, 7), - ( 3, 8,10), - ( 5, 6, 6), - ( 0, 4, 9), - ( 4, 8, 7), - (100,6, 5); + (10, 5, 3), + ( 1, 7, 8), + ( 2, 9, 7), + ( 3, 8,10), + ( 5, 6, 6), + ( 0, 4, 9), + ( 4, 8, 7), + (100,6, 5); query I rowsort SELECT c @@ -5198,3 +5193,171 @@ DROP TABLE t1_c; statement ok DROP TABLE t2_c; + +# Reproducer of https://github.com/apache/datafusion/issues/19067 +statement count 0 +set datafusion.explain.physical_plan_only = true; + +# Setup Left Table with FixedSizeBinary(4) +statement count 0 +CREATE TABLE issue_19067_left AS +SELECT + column1 as id, + arrow_cast(decode(column2, 'hex'), 'FixedSizeBinary(4)') as join_key +FROM (VALUES + (1, 'AAAAAAAA'), + (2, 'BBBBBBBB'), + (3, 'CCCCCCCC') +); + +# Setup Right Table with FixedSizeBinary(4) +statement count 0 +CREATE TABLE issue_19067_right AS +SELECT + arrow_cast(decode(column1, 'hex'), 'FixedSizeBinary(4)') as join_key, + column2 as value +FROM (VALUES + ('AAAAAAAA', 1000), + ('BBBBBBBB', 2000) +); + +# Perform Left Join. Third row should contain NULL in `right_key`. 
+query I??I +SELECT + l.id, + l.join_key as left_key, + r.join_key as right_key, + r.value +FROM issue_19067_left l +LEFT JOIN issue_19067_right r ON l.join_key = r.join_key +ORDER BY l.id; +---- +1 aaaaaaaa aaaaaaaa 1000 +2 bbbbbbbb bbbbbbbb 2000 +3 cccccccc NULL NULL + +# Ensure usage of HashJoinExec +query TT +EXPLAIN +SELECT + l.id, + l.join_key as left_key, + r.join_key as right_key, + r.value +FROM issue_19067_left l +LEFT JOIN issue_19067_right r ON l.join_key = r.join_key +ORDER BY l.id; +---- +physical_plan +01)SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--ProjectionExec: expr=[id@2 as id, join_key@3 as left_key, join_key@0 as right_key, value@1 as value] +03)----HashJoinExec: mode=CollectLeft, join_type=Right, on=[(join_key@0, join_key@1)] +04)------DataSourceExec: partitions=1, partition_sizes=[1] +05)------DataSourceExec: partitions=1, partition_sizes=[1] + +statement count 0 +set datafusion.explain.physical_plan_only = false; + +statement count 0 +DROP TABLE issue_19067_left; + +statement count 0 +DROP TABLE issue_19067_right; + +# Test that empty projections pushed into joins produce correct row counts at runtime. +# When count(1) is used over a RIGHT/FULL JOIN, the optimizer embeds an empty projection +# (projection=[]) into the HashJoinExec. This validates that the runtime batch construction +# handles zero-column output correctly, preserving the correct number of rows. 
+ +statement ok +CREATE TABLE empty_proj_left AS VALUES (1, 'a'), (2, 'b'), (3, 'c'); + +statement ok +CREATE TABLE empty_proj_right AS VALUES (1, 'x'), (2, 'y'), (4, 'z'); + +query I +SELECT count(1) FROM empty_proj_left RIGHT JOIN empty_proj_right ON empty_proj_left.column1 = empty_proj_right.column1; +---- +3 + +query I +SELECT count(1) FROM empty_proj_left FULL JOIN empty_proj_right ON empty_proj_left.column1 = empty_proj_right.column1; +---- +4 + +statement count 0 +DROP TABLE empty_proj_left; + +statement count 0 +DROP TABLE empty_proj_right; + +# Issue #20437: HashJoin panic with dictionary-encoded columns in multi-key joins +# https://github.com/apache/datafusion/issues/20437 + +statement ok +CREATE TABLE issue_20437_small AS +SELECT id, arrow_cast(region, 'Dictionary(Int32, Utf8)') AS region +FROM (VALUES (1, 'west'), (2, 'west')) AS t(id, region); + +statement ok +CREATE TABLE issue_20437_large AS +SELECT id, region, value +FROM (VALUES (1, 'west', 100), (2, 'west', 200), (3, 'east', 300)) AS t(id, region, value); + +query ITI +SELECT s.id, s.region, l.value +FROM issue_20437_small s +JOIN issue_20437_large l ON s.id = l.id AND s.region = l.region +ORDER BY s.id; +---- +1 west 100 +2 west 200 + +statement count 0 +DROP TABLE issue_20437_small; + +statement count 0 +DROP TABLE issue_20437_large; + +# Test count(*) with right semi/anti joins returns correct row counts +# issue: https://github.com/apache/datafusion/issues/20669 + +statement ok +CREATE TABLE t1 (k INT, v INT); + +statement ok +CREATE TABLE t2 (k INT, v INT); + +statement ok +INSERT INTO t1 SELECT i AS k, i AS v FROM generate_series(1, 100) t(i); + +statement ok +INSERT INTO t2 VALUES (1, 1); + +query I +WITH t AS ( + SELECT * + FROM t1 + LEFT ANTI JOIN t2 ON t1.k = t2.k +) +SELECT count(*) +FROM t; +---- +99 + +query I +WITH t AS ( + SELECT * + FROM t1 + LEFT SEMI JOIN t2 ON t1.k = t2.k +) +SELECT count(*) +FROM t; +---- +1 + +statement count 0 +DROP TABLE t1; + +statement count 0 +DROP TABLE 
t2; diff --git a/datafusion/sqllogictest/test_files/json.slt b/datafusion/sqllogictest/test_files/json.slt index b46b8c49d662..60bec4213db0 100644 --- a/datafusion/sqllogictest/test_files/json.slt +++ b/datafusion/sqllogictest/test_files/json.slt @@ -146,3 +146,31 @@ EXPLAIN SELECT id FROM json_partitioned_test WHERE part = 2 ---- logical_plan TableScan: json_partitioned_test projection=[id], full_filters=[json_partitioned_test.part = Int32(2)] physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/partitioned_table_json/part=2/data.json]]}, projection=[id], file_type=json + +########## +## JSON Array Format Tests +########## + +# Test reading JSON array format file with newline_delimited=false +statement ok +CREATE EXTERNAL TABLE json_array_test +STORED AS JSON +OPTIONS ('format.newline_delimited' 'false') +LOCATION '../core/tests/data/json_array.json'; + +query IT rowsort +SELECT a, b FROM json_array_test +---- +1 hello +2 world +3 test + +statement ok +DROP TABLE json_array_test; + +# Test that reading JSON array format WITHOUT newline_delimited option fails +# (default is newline_delimited=true which can't parse array format correctly) +statement error Not valid JSON +CREATE EXTERNAL TABLE json_array_as_ndjson +STORED AS JSON +LOCATION '../core/tests/data/json_array.json'; diff --git a/datafusion/sqllogictest/test_files/limit.slt b/datafusion/sqllogictest/test_files/limit.slt index 524304546d56..f5ec26d304d4 100644 --- a/datafusion/sqllogictest/test_files/limit.slt +++ b/datafusion/sqllogictest/test_files/limit.slt @@ -679,19 +679,19 @@ ON t1.b = t2.b ORDER BY t1.b desc, c desc, c2 desc; ---- 3 98 96 -3 98 89 +3 98 87 3 98 82 3 98 79 3 97 96 -3 97 89 +3 97 87 3 97 82 3 97 79 3 96 96 -3 96 89 +3 96 87 3 96 82 3 96 79 3 95 96 -3 95 89 +3 95 87 3 95 82 3 95 79 @@ -706,8 +706,8 @@ ON t1.b = t2.b ORDER BY t1.b desc, c desc, c2 desc OFFSET 3 LIMIT 2; ---- -3 99 82 -3 99 79 +3 98 79 +3 97 96 statement ok drop table 
ordered_table; @@ -869,6 +869,45 @@ limit 1000; statement ok DROP TABLE test_limit_with_partitions; +# Tests for filter pushdown behavior with Sort + LIMIT (fetch). + +statement ok +CREATE TABLE t(id INT, value INT) AS VALUES +(1, 100), +(2, 200), +(3, 300), +(4, 400), +(5, 500); + +# Take the 3 smallest values (100, 200, 300), then filter value > 200. +query II +SELECT * FROM (SELECT * FROM t ORDER BY value LIMIT 3) sub WHERE sub.value > 200; +---- +3 300 + +# Take the 3 largest values (500, 400, 300), then filter value < 400. +query II +SELECT * FROM (SELECT * FROM t ORDER BY value DESC LIMIT 3) sub WHERE sub.value < 400; +---- +3 300 + +# The filter stays above the sort+fetch in the plan. +query TT +EXPLAIN SELECT * FROM (SELECT * FROM t ORDER BY value LIMIT 3) sub WHERE sub.value > 200; +---- +logical_plan +01)SubqueryAlias: sub +02)--Filter: t.value > Int32(200) +03)----Sort: t.value ASC NULLS LAST, fetch=3 +04)------TableScan: t projection=[id, value] +physical_plan +01)FilterExec: value@1 > 200 +02)--SortExec: TopK(fetch=3), expr=[value@1 ASC NULLS LAST], preserve_partitioning=[false] +03)----DataSourceExec: partitions=1, partition_sizes=[1] + +statement ok +DROP TABLE t; + # Tear down src_table table: statement ok DROP TABLE src_table; diff --git a/datafusion/sqllogictest/test_files/limit_pruning.slt b/datafusion/sqllogictest/test_files/limit_pruning.slt new file mode 100644 index 000000000000..72672b707d4f --- /dev/null +++ b/datafusion/sqllogictest/test_files/limit_pruning.slt @@ -0,0 +1,81 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +statement ok +set datafusion.execution.parquet.pushdown_filters = true; + + +statement ok +CREATE TABLE tracking_data AS VALUES +-- ***** Row Group 0 ***** + ('Anow Vole', 7), + ('Brown Bear', 133), + ('Gray Wolf', 82), +-- ***** Row Group 1 ***** + ('Lynx', 71), + ('Red Fox', 40), + ('Alpine Bat', 6), +-- ***** Row Group 2 ***** + ('Nlpine Ibex', 101), + ('Nlpine Goat', 76), + ('Nlpine Sheep', 83), +-- ***** Row Group 3 ***** + ('Europ. Mole', 4), + ('Polecat', 16), + ('Alpine Ibex', 97); + +statement ok +COPY (SELECT column1 as species, column2 as s FROM tracking_data) +TO 'test_files/scratch/limit_pruning/data.parquet' +STORED AS PARQUET +OPTIONS ( + 'format.max_row_group_size' '3' +); + +statement ok +drop table tracking_data; + +statement ok +CREATE EXTERNAL TABLE tracking_data +STORED AS PARQUET +LOCATION 'test_files/scratch/limit_pruning/data.parquet'; + + +statement ok +set datafusion.explain.analyze_level = summary; + +# row_groups_pruned_statistics=4 total → 3 matched -> 1 fully matched +# limit_pruned_row_groups=2 total → 0 matched +query TT +explain analyze select * from tracking_data where species > 'M' AND s >= 50 limit 3; +---- +Plan with Metrics DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/limit_pruning/data.parquet]]}, projection=[species, s], limit=3, file_type=parquet, predicate=species@0 > M AND s@1 >= 50, pruning_predicate=species_null_count@1 != row_count@2 AND species_max@0 > M AND s_null_count@4 != row_count@2 AND s_max@3 >= 50, required_guarantees=[], metrics=[output_rows=3, 
elapsed_compute=, output_bytes=, files_ranges_pruned_statistics=1 total → 1 matched, row_groups_pruned_statistics=4 total → 3 matched -> 1 fully matched, row_groups_pruned_bloom_filter=3 total → 3 matched, page_index_pages_pruned=2 total → 2 matched, limit_pruned_row_groups=2 total → 0 matched, bytes_scanned=, metadata_load_time=, scan_efficiency_ratio= (171/2.35 K)] + +# limit_pruned_row_groups=0 total → 0 matched +# because of order by, scan needs to preserve sort, so limit pruning is disabled +query TT +explain analyze select * from tracking_data where species > 'M' AND s >= 50 order by species limit 3; +---- +Plan with Metrics +01)SortExec: TopK(fetch=3), expr=[species@0 ASC NULLS LAST], preserve_partitioning=[false], filter=[species@0 < Nlpine Sheep], metrics=[output_rows=3, elapsed_compute=, output_bytes=] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/limit_pruning/data.parquet]]}, projection=[species, s], file_type=parquet, predicate=species@0 > M AND s@1 >= 50 AND DynamicFilter [ species@0 < Nlpine Sheep ], pruning_predicate=species_null_count@1 != row_count@2 AND species_max@0 > M AND s_null_count@4 != row_count@2 AND s_max@3 >= 50 AND species_null_count@1 != row_count@2 AND species_min@5 < Nlpine Sheep, required_guarantees=[], metrics=[output_rows=3, elapsed_compute=, output_bytes=, files_ranges_pruned_statistics=1 total → 1 matched, row_groups_pruned_statistics=4 total → 3 matched -> 1 fully matched, row_groups_pruned_bloom_filter=3 total → 3 matched, page_index_pages_pruned=6 total → 6 matched, limit_pruned_row_groups=0 total → 0 matched, bytes_scanned=, metadata_load_time=, scan_efficiency_ratio= (521/2.35 K)] + +statement ok +drop table tracking_data; + +statement ok +reset datafusion.explain.analyze_level; diff --git a/datafusion/sqllogictest/test_files/limit_single_row_batches.slt b/datafusion/sqllogictest/test_files/limit_single_row_batches.slt new file mode 100644 index 
000000000000..9f626816e214 --- /dev/null +++ b/datafusion/sqllogictest/test_files/limit_single_row_batches.slt @@ -0,0 +1,22 @@ + +# minimize batch size to 1 in order to trigger different code paths +statement ok +set datafusion.execution.batch_size = '1'; + +# ---- +# tests with target partition set to 1 +# ---- +statement ok +set datafusion.execution.target_partitions = '1'; + + +statement ok +CREATE TABLE filter_limit (i INT) as values (1), (2); + +query I +SELECT COUNT(*) FROM (SELECT i FROM filter_limit WHERE i <> 0 LIMIT 1); +---- +1 + +statement ok +DROP TABLE filter_limit; diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index 71a969c75159..2227466fdf25 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -111,12 +111,44 @@ SELECT isnan(1.0::FLOAT), isnan('NaN'::FLOAT), isnan(-'NaN'::FLOAT), isnan(NULL: ---- false true true NULL +# isnan: non-float numeric inputs are never NaN +query BBBB +SELECT isnan(1::INT), isnan(0::INT), isnan(NULL::INT), isnan(123::BIGINT) +---- +false false NULL false + +query BBBB +SELECT isnan(1::INT UNSIGNED), isnan(0::INT UNSIGNED), isnan(NULL::INT UNSIGNED), isnan(255::TINYINT UNSIGNED) +---- +false false NULL false + +query BBBB +SELECT isnan(1::DECIMAL(10,2)), isnan(0::DECIMAL(10,2)), isnan(NULL::DECIMAL(10,2)), isnan(-1::DECIMAL(10,2)) +---- +false false NULL false + # iszero query BBBB SELECT iszero(1.0), iszero(0.0), iszero(-0.0), iszero(NULL) ---- false true true NULL +# iszero: integers / unsigned / decimals +query BBBB +SELECT iszero(1::INT), iszero(0::INT), iszero(NULL::INT), iszero(-1::INT) +---- +false true NULL false + +query BBBB +SELECT iszero(1::INT UNSIGNED), iszero(0::INT UNSIGNED), iszero(NULL::INT UNSIGNED), iszero(255::TINYINT UNSIGNED) +---- +false true NULL false + +query BBBB +SELECT iszero(1::DECIMAL(10,2)), iszero(0::DECIMAL(10,2)), iszero(NULL::DECIMAL(10,2)), iszero(-1::DECIMAL(10,2)) +---- 
+false true NULL false + # abs: empty argument statement error SELECT abs(); diff --git a/datafusion/sqllogictest/test_files/metadata.slt b/datafusion/sqllogictest/test_files/metadata.slt index 41a511b5fa09..6ed461debb3b 100644 --- a/datafusion/sqllogictest/test_files/metadata.slt +++ b/datafusion/sqllogictest/test_files/metadata.slt @@ -24,6 +24,22 @@ ## in the test harness as there is no way to define schema ## with metadata in SQL. +query ITTPT +select * from table_with_metadata; +---- +1 NULL NULL 2020-09-08T13:42:29.190855123 no_foo +NULL bar l_bar 2020-09-08T13:42:29.190855123 no_bar +3 baz l_baz 2020-09-08T13:42:29.190855123 no_baz + +query TTT +describe table_with_metadata; +---- +id Int32 YES +name Utf8 YES +l_name Utf8 YES +ts Timestamp(ns) NO +nonnull_name Utf8 NO + query IT select id, name from table_with_metadata; ---- @@ -235,6 +251,28 @@ order by 1 asc nulls last; 3 1 NULL 1 +# Reproducer for https://github.com/apache/datafusion/issues/18337 +# this query should not get an internal error +query TI +SELECT + 'foo' AS name, + COUNT( + CASE + WHEN prev_value = 'no_bar' AND value = 'no_baz' THEN 1 + ELSE NULL + END + ) AS count_rises +FROM + ( + SELECT + nonnull_name as value, + LAG(nonnull_name) OVER (ORDER BY ts) AS prev_value + FROM + table_with_metadata +); +---- +foo 1 + # Regression test: first_value should preserve metadata query IT select first_value(id order by id asc nulls last), arrow_metadata(first_value(id order by id asc nulls last), 'metadata_key') diff --git a/datafusion/sqllogictest/test_files/null_aware_anti_join.slt b/datafusion/sqllogictest/test_files/null_aware_anti_join.slt new file mode 100644 index 000000000000..5907a85a9b92 --- /dev/null +++ b/datafusion/sqllogictest/test_files/null_aware_anti_join.slt @@ -0,0 +1,453 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +############# +## Null-Aware Anti Join Tests +## Tests for automatic null-aware semantics in NOT IN subqueries +############# + +statement ok +CREATE TABLE outer_table(id INT, value TEXT) AS VALUES +(1, 'a'), +(2, 'b'), +(3, 'c'), +(4, 'd'), +(NULL, 'e'); + +statement ok +CREATE TABLE inner_table_no_null(id INT, value TEXT) AS VALUES +(2, 'x'), +(4, 'y'); + +statement ok +CREATE TABLE inner_table_with_null(id INT, value TEXT) AS VALUES +(2, 'x'), +(NULL, 'y'); + +############# +## Test 1: NOT IN with no NULLs - should behave like regular anti join +############# + +query IT rowsort +SELECT * FROM outer_table WHERE id NOT IN (SELECT id FROM inner_table_no_null); +---- +1 a +3 c + +# Verify the plan uses LeftAnti join +query TT +EXPLAIN SELECT * FROM outer_table WHERE id NOT IN (SELECT id FROM inner_table_no_null); +---- +logical_plan +01)LeftAnti Join: outer_table.id = __correlated_sq_1.id +02)--TableScan: outer_table projection=[id, value] +03)--SubqueryAlias: __correlated_sq_1 +04)----TableScan: inner_table_no_null projection=[id] +physical_plan +01)HashJoinExec: mode=CollectLeft, join_type=LeftAnti, on=[(id@0, id@0)] +02)--DataSourceExec: partitions=1, partition_sizes=[1] +03)--DataSourceExec: partitions=1, partition_sizes=[1] + +############# +## Test 2: NOT IN with NULL in subquery - 
should return 0 rows (null-aware semantics) +############# + +query IT rowsort +SELECT * FROM outer_table WHERE id NOT IN (SELECT id FROM inner_table_with_null); +---- + +# Verify the result is empty even though there are rows in outer_table +# that don't match the non-NULL value (2) in the subquery. +# This is correct null-aware behavior: if subquery contains NULL, result is unknown. + +############# +## Test 3: NOT IN with NULL in outer table but not in subquery +## NULL rows from outer should not appear in output +############# + +query IT rowsort +SELECT * FROM outer_table WHERE id NOT IN (SELECT id FROM inner_table_no_null) AND id IS NOT NULL; +---- +1 a +3 c + +############# +## Test 4: Test with all NULL subquery +############# + +statement ok +CREATE TABLE all_null_table(id INT) AS VALUES (NULL), (NULL); + +query IT rowsort +SELECT * FROM outer_table WHERE id NOT IN (SELECT id FROM all_null_table); +---- + +############# +## Test 5: Test with empty subquery - should return all rows +############# + +statement ok +CREATE TABLE empty_table(id INT, value TEXT); + +query IT rowsort +SELECT * FROM outer_table WHERE id NOT IN (SELECT id FROM empty_table); +---- +1 a +2 b +3 c +4 d +NULL e + +############# +## Test 6: NOT IN with complex expression +############# + +query IT rowsort +SELECT * FROM outer_table WHERE id + 1 NOT IN (SELECT id FROM inner_table_no_null); +---- +2 b +4 d + +############# +## Test 7: NOT IN with complex expression and NULL in subquery +############# + +query IT rowsort +SELECT * FROM outer_table WHERE id + 1 NOT IN (SELECT id FROM inner_table_with_null); +---- + +############# +## Test 8: Multiple NOT IN conditions (AND) +############# + +statement ok +CREATE TABLE inner_table2(id INT) AS VALUES (1), (3); + +query IT rowsort +SELECT * FROM outer_table +WHERE id NOT IN (SELECT id FROM inner_table_no_null) + AND id NOT IN (SELECT id FROM inner_table2); +---- + +############# +## Test 9: Multiple NOT IN conditions (OR) +############# + +# 
KNOWN LIMITATION: Mark joins used for OR conditions don't support null-aware semantics. +# The NULL row is incorrectly returned here. According to SQL semantics: +# - NULL NOT IN (2, 4) = UNKNOWN +# - NULL NOT IN (1, 3) = UNKNOWN +# - UNKNOWN OR UNKNOWN = UNKNOWN (should be filtered out) +# But mark joins treat NULL keys as non-matching (FALSE), so: +# - NULL mark column = FALSE +# - NOT FALSE OR NOT FALSE = TRUE OR TRUE = TRUE (incorrectly included) +# TODO: Implement null-aware support for mark joins to fix this + +query IT rowsort +SELECT * FROM outer_table +WHERE id NOT IN (SELECT id FROM inner_table_no_null) + OR id NOT IN (SELECT id FROM inner_table2); +---- +1 a +2 b +3 c +4 d +NULL e + +############# +## Test 10: NOT IN with WHERE clause in subquery +############# + +query IT rowsort +SELECT * FROM outer_table +WHERE id NOT IN (SELECT id FROM inner_table_with_null WHERE value = 'x'); +---- +1 a +3 c +4 d + +# Note: The NULL row from inner_table_with_null is filtered out by WHERE clause, +# so this behaves like regular anti join (not null-aware) + +############# +## Test 11: Verify NULL-aware flag is set for LeftAnti joins +############# + +# Check that the physical plan shows null-aware anti join +# Note: The exact format may vary, but we should see LeftAnti join type +query TT +EXPLAIN SELECT * FROM outer_table WHERE id NOT IN (SELECT id FROM inner_table_with_null); +---- +logical_plan +01)LeftAnti Join: outer_table.id = __correlated_sq_1.id +02)--TableScan: outer_table projection=[id, value] +03)--SubqueryAlias: __correlated_sq_1 +04)----TableScan: inner_table_with_null projection=[id] +physical_plan +01)HashJoinExec: mode=CollectLeft, join_type=LeftAnti, on=[(id@0, id@0)] +02)--DataSourceExec: partitions=1, partition_sizes=[1] +03)--DataSourceExec: partitions=1, partition_sizes=[1] + +############# +## Test 12: Correlated NOT IN subquery with NULL +############# + +statement ok +CREATE TABLE orders(order_id INT, customer_id INT) AS VALUES +(1, 100), +(2, 
200), +(3, 300); + +statement ok +CREATE TABLE payments(payment_id INT, order_id INT) AS VALUES +(1, 1), +(2, NULL); + +# Find orders that don't have payments +# Should return empty because there's a NULL in payments.order_id +query I rowsort +SELECT order_id FROM orders +WHERE order_id NOT IN (SELECT order_id FROM payments); +---- + +############# +## Test 13: NOT IN with DISTINCT in subquery +############# + +statement ok +CREATE TABLE duplicates_with_null(id INT) AS VALUES +(2), +(2), +(NULL), +(NULL); + +query IT rowsort +SELECT * FROM outer_table +WHERE id NOT IN (SELECT DISTINCT id FROM duplicates_with_null); +---- + +############# +## Test 14: NOT EXISTS vs NOT IN - Demonstrating the difference +############# + +# NOT EXISTS should NOT use null-aware semantics +# It uses two-valued logic (TRUE/FALSE), not three-valued logic (TRUE/FALSE/UNKNOWN) + +# Setup tables for comparison +statement ok +CREATE TABLE customers(id INT, name TEXT) AS VALUES +(1, 'Alice'), +(2, 'Bob'), +(3, 'Charlie'), +(NULL, 'Dave'); + +statement ok +CREATE TABLE banned(id INT) AS VALUES +(2), +(NULL); + +# Test 14a: NOT IN with NULL in subquery - Returns EMPTY (null-aware) +query IT rowsort +SELECT * FROM customers WHERE id NOT IN (SELECT id FROM banned); +---- + +# Test 14b: NOT EXISTS with NULL in subquery - Returns rows (NOT null-aware) +# This should return (1, 'Alice'), (3, 'Charlie'), (NULL, 'Dave') +# Because NOT EXISTS uses two-valued logic: NULL = NULL is FALSE, so no match found +query IT rowsort +SELECT * FROM customers c +WHERE NOT EXISTS (SELECT 1 FROM banned b WHERE c.id = b.id); +---- +1 Alice +3 Charlie +NULL Dave + +# Test 14c: Verify with EXPLAIN that NOT EXISTS doesn't use null-aware +query TT +EXPLAIN SELECT * FROM customers c +WHERE NOT EXISTS (SELECT 1 FROM banned b WHERE c.id = b.id); +---- +logical_plan +01)LeftAnti Join: c.id = __correlated_sq_1.id +02)--SubqueryAlias: c +03)----TableScan: customers projection=[id, name] +04)--SubqueryAlias: __correlated_sq_1 
+05)----SubqueryAlias: b +06)------TableScan: banned projection=[id] +physical_plan +01)HashJoinExec: mode=CollectLeft, join_type=RightAnti, on=[(id@0, id@0)] +02)--DataSourceExec: partitions=1, partition_sizes=[1] +03)--DataSourceExec: partitions=1, partition_sizes=[1] + +############# +## Test 15: NOT EXISTS - No NULLs +############# + +statement ok +CREATE TABLE active_customers(id INT) AS VALUES (1), (3); + +# Should return only Bob (id=2) and Dave (id=NULL) +query IT rowsort +SELECT * FROM customers c +WHERE NOT EXISTS (SELECT 1 FROM active_customers a WHERE c.id = a.id); +---- +2 Bob +NULL Dave + +############# +## Test 16: NOT EXISTS - Correlated subquery +############# + +statement ok +CREATE TABLE orders_test(order_id INT, customer_id INT) AS VALUES +(1, 100), +(2, 200), +(3, NULL); + +statement ok +CREATE TABLE customers_test(customer_id INT, name TEXT) AS VALUES +(100, 'Alice'), +(200, 'Bob'), +(300, 'Charlie'), +(NULL, 'Unknown'); + +# Find customers with no orders +# Should return Charlie (300) and Unknown (NULL) +query IT rowsort +SELECT * FROM customers_test c +WHERE NOT EXISTS ( + SELECT 1 FROM orders_test o WHERE o.customer_id = c.customer_id +); +---- +300 Charlie +NULL Unknown + +############# +## Test 17: NOT EXISTS with all NULL subquery +############# + +statement ok +CREATE TABLE all_null_banned(id INT) AS VALUES (NULL), (NULL); + +# NOT EXISTS should return all rows because NULL = NULL is FALSE (no matches) +query IT rowsort +SELECT * FROM customers c +WHERE NOT EXISTS (SELECT 1 FROM all_null_banned b WHERE c.id = b.id); +---- +1 Alice +2 Bob +3 Charlie +NULL Dave + +# Compare with NOT IN which returns empty +query IT rowsort +SELECT * FROM customers WHERE id NOT IN (SELECT id FROM all_null_banned); +---- + +############# +## Test 18: Nested NOT EXISTS and NOT IN +############# + +# NOT EXISTS outside, NOT IN inside - should work correctly +query IT rowsort +SELECT * FROM customers c +WHERE NOT EXISTS ( + SELECT 1 FROM banned b + WHERE c.id 
= b.id + AND b.id NOT IN (SELECT id FROM active_customers) +); +---- +1 Alice +3 Charlie +NULL Dave + +############# +## Test from GitHub issue #10583 +## Tests NOT IN with NULL in subquery result - should return empty result +############# + +statement ok +CREATE TABLE test_table(c1 INT, c2 INT) AS VALUES +(1, 1), +(2, 2), +(3, 3), +(4, NULL), +(NULL, 0); + +# When subquery contains NULL, NOT IN should return empty result +# because NULL NOT IN (values including NULL) is UNKNOWN for all rows +query II rowsort +SELECT * FROM test_table WHERE (c1 NOT IN (SELECT c2 FROM test_table)) = true; +---- + +# NOTE: The correlated subquery version from issue #10583: +# SELECT * FROM test_table t1 WHERE c1 NOT IN (SELECT c2 FROM test_table t2 WHERE t1.c1 = t2.c1) +# is not yet supported because it creates a multi-column join (correlation + NOT IN condition). +# This is a known limitation - currently only supports single column null-aware anti joins. +# This will be addressed in next Phase (multi-column support). 
+ +############# +## Cleanup +############# + +statement ok +DROP TABLE test_table; + +statement ok +DROP TABLE outer_table; + +statement ok +DROP TABLE inner_table_no_null; + +statement ok +DROP TABLE inner_table_with_null; + +statement ok +DROP TABLE all_null_table; + +statement ok +DROP TABLE empty_table; + +statement ok +DROP TABLE inner_table2; + +statement ok +DROP TABLE orders; + +statement ok +DROP TABLE payments; + +statement ok +DROP TABLE duplicates_with_null; + +statement ok +DROP TABLE customers; + +statement ok +DROP TABLE banned; + +statement ok +DROP TABLE active_customers; + +statement ok +DROP TABLE orders_test; + +statement ok +DROP TABLE customers_test; + +statement ok +DROP TABLE all_null_banned; diff --git a/datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt b/datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt index 8bb79d576990..85f954935713 100644 --- a/datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt +++ b/datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt @@ -563,3 +563,380 @@ ORDER BY start_timestamp, trace_id LIMIT 1; ---- 2024-10-01T00:00:00 + +### +# Array function predicate pushdown tests +# These tests verify that array_has, array_has_all, and array_has_any predicates +# are correctly pushed down to the DataSourceExec node +### + +# Create test data with array columns +statement ok +COPY ( + SELECT 1 as id, ['rust', 'performance'] as tags + UNION ALL + SELECT 2 as id, ['python', 'javascript'] as tags + UNION ALL + SELECT 3 as id, ['rust', 'webassembly'] as tags +) +TO 'test_files/scratch/parquet_filter_pushdown/array_data/data.parquet'; + +statement ok +CREATE EXTERNAL TABLE array_test STORED AS PARQUET LOCATION 'test_files/scratch/parquet_filter_pushdown/array_data/'; + +statement ok +SET datafusion.execution.parquet.pushdown_filters = true; + +# Test array_has predicate pushdown +query I? 
+SELECT id, tags FROM array_test WHERE array_has(tags, 'rust') ORDER BY id; +---- +1 [rust, performance] +3 [rust, webassembly] + +query TT +EXPLAIN SELECT id, tags FROM array_test WHERE array_has(tags, 'rust') ORDER BY id; +---- +logical_plan +01)Sort: array_test.id ASC NULLS LAST +02)--Filter: array_has(array_test.tags, Utf8("rust")) +03)----TableScan: array_test projection=[id, tags], partial_filters=[array_has(array_test.tags, Utf8("rust"))] +physical_plan +01)SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/array_data/data.parquet]]}, projection=[id, tags], file_type=parquet, predicate=array_has(tags@1, rust) + +# Test array_has_all predicate pushdown +query I? +SELECT id, tags FROM array_test WHERE array_has_all(tags, ['rust', 'performance']) ORDER BY id; +---- +1 [rust, performance] + +query TT +EXPLAIN SELECT id, tags FROM array_test WHERE array_has_all(tags, ['rust', 'performance']) ORDER BY id; +---- +logical_plan +01)Sort: array_test.id ASC NULLS LAST +02)--Filter: array_has_all(array_test.tags, List([rust, performance])) +03)----TableScan: array_test projection=[id, tags], partial_filters=[array_has_all(array_test.tags, List([rust, performance]))] +physical_plan +01)SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/array_data/data.parquet]]}, projection=[id, tags], file_type=parquet, predicate=array_has_all(tags@1, [rust, performance]) + +# Test array_has_any predicate pushdown +query I? 
+SELECT id, tags FROM array_test WHERE array_has_any(tags, ['python', 'go']) ORDER BY id; +---- +2 [python, javascript] + +query TT +EXPLAIN SELECT id, tags FROM array_test WHERE array_has_any(tags, ['python', 'go']) ORDER BY id; +---- +logical_plan +01)Sort: array_test.id ASC NULLS LAST +02)--Filter: array_has_any(array_test.tags, List([python, go])) +03)----TableScan: array_test projection=[id, tags], partial_filters=[array_has_any(array_test.tags, List([python, go]))] +physical_plan +01)SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/array_data/data.parquet]]}, projection=[id, tags], file_type=parquet, predicate=array_has_any(tags@1, [python, go]) + +# Test complex predicate with OR +query I? +SELECT id, tags FROM array_test WHERE array_has_all(tags, ['rust']) OR array_has_any(tags, ['python', 'go']) ORDER BY id; +---- +1 [rust, performance] +2 [python, javascript] +3 [rust, webassembly] + +query TT +EXPLAIN SELECT id, tags FROM array_test WHERE array_has_all(tags, ['rust']) OR array_has_any(tags, ['python', 'go']) ORDER BY id; +---- +logical_plan +01)Sort: array_test.id ASC NULLS LAST +02)--Filter: array_has_all(array_test.tags, List([rust])) OR array_has_any(array_test.tags, List([python, go])) +03)----TableScan: array_test projection=[id, tags], partial_filters=[array_has_all(array_test.tags, List([rust])) OR array_has_any(array_test.tags, List([python, go]))] +physical_plan +01)SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/array_data/data.parquet]]}, projection=[id, tags], file_type=parquet, predicate=array_has_all(tags@1, [rust]) OR array_has_any(tags@1, [python, go]) + +# Test array function with other predicates +query I? 
+SELECT id, tags FROM array_test WHERE id > 1 AND array_has(tags, 'rust') ORDER BY id; +---- +3 [rust, webassembly] + +query TT +EXPLAIN SELECT id, tags FROM array_test WHERE id > 1 AND array_has(tags, 'rust') ORDER BY id; +---- +logical_plan +01)Sort: array_test.id ASC NULLS LAST +02)--Filter: array_test.id > Int64(1) AND array_has(array_test.tags, Utf8("rust")) +03)----TableScan: array_test projection=[id, tags], partial_filters=[array_test.id > Int64(1), array_has(array_test.tags, Utf8("rust"))] +physical_plan +01)SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/array_data/data.parquet]]}, projection=[id, tags], file_type=parquet, predicate=id@0 > 1 AND array_has(tags@1, rust), pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 1, required_guarantees=[] + +### +# Test filter pushdown through UNION with mixed support +# This tests the case where one child supports filter pushdown (parquet) and one doesn't (memory table) +### + +# enable filter pushdown +statement ok +set datafusion.execution.parquet.pushdown_filters = true; + +statement ok +set datafusion.optimizer.max_passes = 0; + +# Create memory table with matching schema (a: VARCHAR, b: BIGINT) +statement ok +CREATE TABLE t_union_mem(a VARCHAR, b BIGINT) AS VALUES ('qux', 4), ('quux', 5); + +# Create parquet table with matching schema +statement ok +CREATE EXTERNAL TABLE t_union_parquet(a VARCHAR, b BIGINT) STORED AS PARQUET +LOCATION 'test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet'; + +# Query results combining memory table and Parquet with filter +query I rowsort +SELECT b FROM ( + SELECT a, b FROM t_union_mem + UNION ALL + SELECT a, b FROM t_union_parquet +) WHERE b > 2; +---- +3 +4 +5 +50 + +# Explain the union query - filter should be pushed to parquet but not memory table +query TT +EXPLAIN SELECT b FROM ( + SELECT a, b FROM 
t_union_mem + UNION ALL + SELECT a, b FROM t_union_parquet +) WHERE b > 2; +---- +logical_plan +01)Projection: b +02)--Filter: b > Int64(2) +03)----Union +04)------Projection: t_union_mem.a, t_union_mem.b +05)--------TableScan: t_union_mem +06)------Projection: t_union_parquet.a, t_union_parquet.b +07)--------TableScan: t_union_parquet +physical_plan +01)UnionExec +02)--FilterExec: b@0 > 2 +03)----DataSourceExec: partitions=1, partition_sizes=[1] +04)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet]]}, projection=[b], file_type=parquet, predicate=b@1 > 2, pruning_predicate=b_null_count@1 != row_count@2 AND b_max@0 > 2, required_guarantees=[] + +# Clean up union test tables +statement ok +DROP TABLE t_union_mem; + +statement ok +DROP TABLE t_union_parquet; + +# Cleanup settings +statement ok +set datafusion.optimizer.max_passes = 3; + +statement ok +set datafusion.execution.parquet.pushdown_filters = false; + + +# Regression test for https://github.com/apache/datafusion/issues/20696 +# Multi-column INNER JOIN with dictionary fails +# when parquet pushdown filters are enabled. 
+ + +statement ok +COPY ( + SELECT + to_timestamp_nanos(time_ns) AS time, + arrow_cast(state, 'Dictionary(Int32, Utf8)') AS state, + arrow_cast(city, 'Dictionary(Int32, Utf8)') AS city, + temp + FROM ( + VALUES + (200, 'CA', 'LA', 90.0), + (250, 'MA', 'Boston', 72.4), + (100, 'MA', 'Boston', 70.4), + (350, 'CA', 'LA', 90.0) + ) AS t(time_ns, state, city, temp) +) +TO 'test_files/scratch/parquet_filter_pushdown/issue_20696/h2o/data.parquet'; + +statement ok +COPY ( + SELECT + to_timestamp_nanos(time_ns) AS time, + arrow_cast(state, 'Dictionary(Int32, Utf8)') AS state, + arrow_cast(city, 'Dictionary(Int32, Utf8)') AS city, + temp, + reading + FROM ( + VALUES + (250, 'MA', 'Boston', 53.4, 51.0), + (100, 'MA', 'Boston', 50.4, 50.0) + ) AS t(time_ns, state, city, temp, reading) +) +TO 'test_files/scratch/parquet_filter_pushdown/issue_20696/o2/data.parquet'; + +statement ok +CREATE EXTERNAL TABLE h2o_parquet_20696 STORED AS PARQUET +LOCATION 'test_files/scratch/parquet_filter_pushdown/issue_20696/h2o/'; + +statement ok +CREATE EXTERNAL TABLE o2_parquet_20696 STORED AS PARQUET +LOCATION 'test_files/scratch/parquet_filter_pushdown/issue_20696/o2/'; + +# Query should work both with and without filters +statement ok +set datafusion.execution.parquet.pushdown_filters = false; + +query RRR +SELECT + h2o_parquet_20696.temp AS h2o_temp, + o2_parquet_20696.temp AS o2_temp, + o2_parquet_20696.reading +FROM h2o_parquet_20696 +INNER JOIN o2_parquet_20696 + ON h2o_parquet_20696.time = o2_parquet_20696.time + AND h2o_parquet_20696.state = o2_parquet_20696.state + AND h2o_parquet_20696.city = o2_parquet_20696.city +WHERE h2o_parquet_20696.time >= '1970-01-01T00:00:00.000000050Z' + AND h2o_parquet_20696.time <= '1970-01-01T00:00:00.000000300Z'; +---- +72.4 53.4 51 +70.4 50.4 50 + +statement ok +set datafusion.execution.parquet.pushdown_filters = true; + +query RRR +SELECT + h2o_parquet_20696.temp AS h2o_temp, + o2_parquet_20696.temp AS o2_temp, + o2_parquet_20696.reading +FROM 
h2o_parquet_20696 +INNER JOIN o2_parquet_20696 + ON h2o_parquet_20696.time = o2_parquet_20696.time + AND h2o_parquet_20696.state = o2_parquet_20696.state + AND h2o_parquet_20696.city = o2_parquet_20696.city +WHERE h2o_parquet_20696.time >= '1970-01-01T00:00:00.000000050Z' + AND h2o_parquet_20696.time <= '1970-01-01T00:00:00.000000300Z'; +---- +72.4 53.4 51 +70.4 50.4 50 + +# Cleanup +statement ok +DROP TABLE h2o_parquet_20696; + +statement ok +DROP TABLE o2_parquet_20696; + +# Cleanup settings +statement ok +set datafusion.execution.parquet.pushdown_filters = false; + +########## +# Regression test: filter pushdown with Struct columns in schema +# +# When a schema has Struct columns, Arrow field indices diverge from Parquet +# leaf indices (Struct children become separate leaves). A filter on a +# primitive column *after* a Struct must use the correct Parquet leaf index. +# +# Schema: +# Arrow: col_a=0 struct_col=1 col_b=2 +# Parquet: col_a=0 struct_col.x=1 struct_col.y=2 col_b=3 +########## + +statement ok +set datafusion.execution.parquet.pushdown_filters = true; + +statement ok +COPY ( + SELECT + column1 as col_a, + column2 as struct_col, + column3 as col_b + FROM VALUES + (1, {x: 10, y: 100}, 'aaa'), + (2, {x: 20, y: 200}, 'target'), + (3, {x: 30, y: 300}, 'zzz') +) TO 'test_files/scratch/parquet_filter_pushdown/struct_filter.parquet' +STORED AS PARQUET; + +statement ok +CREATE EXTERNAL TABLE t_struct_filter +STORED AS PARQUET +LOCATION 'test_files/scratch/parquet_filter_pushdown/struct_filter.parquet'; + +# Filter on col_b (the primitive column after the struct). +# Before the fix, this returned 0 rows because the filter read struct_col.y +# (Parquet leaf 2) instead of col_b (Parquet leaf 3). 
+query IT +SELECT col_a, col_b FROM t_struct_filter WHERE col_b = 'target'; +---- +2 target + +# Clean up +statement ok +set datafusion.execution.parquet.pushdown_filters = false; + +statement ok +DROP TABLE t_struct_filter; + +########## +# Regression test for https://github.com/apache/datafusion/issues/20937 +# +# Dynamic filter pushdown fails when joining VALUES against +# Dictionary-encoded Parquet columns. The InListExpr's ArrayStaticFilter +# unwraps the needle Dictionary but not the stored in_array, causing a +# make_comparator(Utf8, Dictionary) type mismatch. +########## + +statement ok +set datafusion.execution.parquet.pushdown_filters = true; + +statement ok +set datafusion.execution.parquet.reorder_filters = true; + +statement ok +COPY ( + SELECT + arrow_cast(chr(65 + (row_num % 26)), 'Dictionary(Int32, Utf8)') as tag1, + row_num * 1.0 as value + FROM (SELECT unnest(range(0, 10000)) as row_num) +) TO 'test_files/scratch/parquet_filter_pushdown/dict_filter_bug.parquet'; + +statement ok +CREATE EXTERNAL TABLE dict_filter_bug +STORED AS PARQUET +LOCATION 'test_files/scratch/parquet_filter_pushdown/dict_filter_bug.parquet'; + +query TR +SELECT t.tag1, t.value +FROM dict_filter_bug t +JOIN (VALUES ('A'), ('B')) AS v(c1) +ON t.tag1 = v.c1 +ORDER BY t.tag1, t.value +LIMIT 4; +---- +A 0 +A 26 +A 52 +A 78 + +# Cleanup +statement ok +set datafusion.execution.parquet.pushdown_filters = false; + +statement ok +set datafusion.execution.parquet.reorder_filters = false; + +statement ok +DROP TABLE dict_filter_bug; diff --git a/datafusion/sqllogictest/test_files/parquet_sorted_statistics.slt b/datafusion/sqllogictest/test_files/parquet_sorted_statistics.slt index 5a559bdb9483..fd3a40ca1707 100644 --- a/datafusion/sqllogictest/test_files/parquet_sorted_statistics.slt +++ b/datafusion/sqllogictest/test_files/parquet_sorted_statistics.slt @@ -274,4 +274,4 @@ logical_plan 02)--TableScan: test_table projection=[constant_col] physical_plan 01)SortPreservingMergeExec: 
[constant_col@0 ASC NULLS LAST] -02)--DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=A/0.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=B/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=C/2.parquet]]}, projection=[constant_col], file_type=parquet +02)--DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=A/0.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=B/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=C/2.parquet]]}, projection=[constant_col], output_ordering=[constant_col@0 ASC NULLS LAST], file_type=parquet diff --git a/datafusion/sqllogictest/test_files/preserve_file_partitioning.slt b/datafusion/sqllogictest/test_files/preserve_file_partitioning.slt index 34c5fd97b51f..297094fab16e 100644 --- a/datafusion/sqllogictest/test_files/preserve_file_partitioning.slt +++ b/datafusion/sqllogictest/test_files/preserve_file_partitioning.slt @@ -101,6 +101,29 @@ STORED AS PARQUET; ---- 4 +# Create hive-partitioned dimension table (3 partitions matching fact_table) +# For testing Partitioned joins with matching partition counts +query I +COPY (SELECT 'dev' as env, 'log' as service) +TO 'test_files/scratch/preserve_file_partitioning/dimension_partitioned/d_dkey=A/data.parquet' +STORED AS PARQUET; +---- +1 + +query I +COPY (SELECT 'prod' as env, 'log' as service) +TO 'test_files/scratch/preserve_file_partitioning/dimension_partitioned/d_dkey=B/data.parquet' +STORED AS PARQUET; +---- +1 + +query I +COPY (SELECT 'prod' as env, 'log' as service) +TO 
'test_files/scratch/preserve_file_partitioning/dimension_partitioned/d_dkey=C/data.parquet' +STORED AS PARQUET; +---- +1 + # Create high-cardinality fact table (5 partitions > 3 target_partitions) # For testing partition merging with consistent hashing query I @@ -173,6 +196,13 @@ CREATE EXTERNAL TABLE dimension_table (d_dkey STRING, env STRING, service STRING STORED AS PARQUET LOCATION 'test_files/scratch/preserve_file_partitioning/dimension/'; +# Hive-partitioned dimension table (3 partitions matching fact_table for Partitioned join tests) +statement ok +CREATE EXTERNAL TABLE dimension_table_partitioned (env STRING, service STRING) +STORED AS PARQUET +PARTITIONED BY (d_dkey STRING) +LOCATION 'test_files/scratch/preserve_file_partitioning/dimension_partitioned/'; + # 'High'-cardinality fact table (5 partitions > 3 target_partitions) statement ok CREATE EXTERNAL TABLE high_cardinality_table (timestamp TIMESTAMP, value DOUBLE) @@ -579,6 +609,101 @@ C 1 300 D 1 400 E 1 500 +########## +# TEST 11: Partitioned Join with Matching Partition Counts - Without Optimization +# fact_table (3 partitions) joins dimension_table_partitioned (3 partitions) +# Shows RepartitionExec added when preserve_file_partitions is disabled +########## + +statement ok +set datafusion.optimizer.preserve_file_partitions = 0; + +# Force Partitioned join mode (not CollectLeft) +statement ok +set datafusion.optimizer.hash_join_single_partition_threshold = 0; + +statement ok +set datafusion.optimizer.hash_join_single_partition_threshold_rows = 0; + +query TT +EXPLAIN SELECT f.f_dkey, d.env, sum(f.value) +FROM fact_table f +INNER JOIN dimension_table_partitioned d ON f.f_dkey = d.d_dkey +GROUP BY f.f_dkey, d.env; +---- +logical_plan +01)Aggregate: groupBy=[[f.f_dkey, d.env]], aggr=[[sum(f.value)]] +02)--Projection: f.value, f.f_dkey, d.env +03)----Inner Join: f.f_dkey = d.d_dkey +04)------SubqueryAlias: f +05)--------TableScan: fact_table projection=[value, f_dkey] +06)------SubqueryAlias: d 
+07)--------TableScan: dimension_table_partitioned projection=[env, d_dkey] +physical_plan +01)AggregateExec: mode=FinalPartitioned, gby=[f_dkey@0 as f_dkey, env@1 as env], aggr=[sum(f.value)] +02)--RepartitionExec: partitioning=Hash([f_dkey@0, env@1], 3), input_partitions=3 +03)----AggregateExec: mode=Partial, gby=[f_dkey@1 as f_dkey, env@2 as env], aggr=[sum(f.value)] +04)------ProjectionExec: expr=[value@1 as value, f_dkey@2 as f_dkey, env@0 as env] +05)--------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(d_dkey@1, f_dkey@1)], projection=[env@0, value@2, f_dkey@3] +06)----------RepartitionExec: partitioning=Hash([d_dkey@1], 3), input_partitions=3 +07)------------DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/dimension_partitioned/d_dkey=A/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/dimension_partitioned/d_dkey=B/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/dimension_partitioned/d_dkey=C/data.parquet]]}, projection=[env, d_dkey], file_type=parquet +08)----------RepartitionExec: partitioning=Hash([f_dkey@1], 3), input_partitions=3 +09)------------DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/fact/f_dkey=A/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/fact/f_dkey=B/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/fact/f_dkey=C/data.parquet]]}, projection=[value, f_dkey], file_type=parquet, predicate=DynamicFilter [ empty ] + +query TTR rowsort +SELECT f.f_dkey, d.env, sum(f.value) +FROM fact_table f +INNER JOIN dimension_table_partitioned d ON f.f_dkey = d.d_dkey +GROUP BY f.f_dkey, d.env; +---- +A dev 772.4 +B prod 614.4 +C prod 2017.6 + +########## +# TEST 12: Partitioned Join with 
Matching Partition Counts - With Optimization +# Both tables have 3 partitions matching target_partitions=3 +# No RepartitionExec needed for join - partitions already satisfy the requirement +# Dynamic filter pushdown is disabled in this mode because preserve_file_partitions +# reports Hash partitioning for Hive-style file groups, which are not hash-routed. +########## + +statement ok +set datafusion.optimizer.preserve_file_partitions = 1; + +query TT +EXPLAIN SELECT f.f_dkey, d.env, sum(f.value) +FROM fact_table f +INNER JOIN dimension_table_partitioned d ON f.f_dkey = d.d_dkey +GROUP BY f.f_dkey, d.env; +---- +logical_plan +01)Aggregate: groupBy=[[f.f_dkey, d.env]], aggr=[[sum(f.value)]] +02)--Projection: f.value, f.f_dkey, d.env +03)----Inner Join: f.f_dkey = d.d_dkey +04)------SubqueryAlias: f +05)--------TableScan: fact_table projection=[value, f_dkey] +06)------SubqueryAlias: d +07)--------TableScan: dimension_table_partitioned projection=[env, d_dkey] +physical_plan +01)AggregateExec: mode=FinalPartitioned, gby=[f_dkey@0 as f_dkey, env@1 as env], aggr=[sum(f.value)] +02)--RepartitionExec: partitioning=Hash([f_dkey@0, env@1], 3), input_partitions=3 +03)----AggregateExec: mode=Partial, gby=[f_dkey@1 as f_dkey, env@2 as env], aggr=[sum(f.value)] +04)------ProjectionExec: expr=[value@1 as value, f_dkey@2 as f_dkey, env@0 as env] +05)--------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(d_dkey@1, f_dkey@1)], projection=[env@0, value@2, f_dkey@3] +06)----------DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/dimension_partitioned/d_dkey=A/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/dimension_partitioned/d_dkey=B/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/dimension_partitioned/d_dkey=C/data.parquet]]}, projection=[env, d_dkey], file_type=parquet 
+07)----------DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/fact/f_dkey=A/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/fact/f_dkey=B/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/fact/f_dkey=C/data.parquet]]}, projection=[value, f_dkey], file_type=parquet + +query TTR rowsort +SELECT f.f_dkey, d.env, sum(f.value) +FROM fact_table f +INNER JOIN dimension_table_partitioned d ON f.f_dkey = d.d_dkey +GROUP BY f.f_dkey, d.env; +---- +A dev 772.4 +B prod 614.4 +C prod 2017.6 + ########## # CLEANUP ########## @@ -592,5 +717,8 @@ DROP TABLE fact_table_ordered; statement ok DROP TABLE dimension_table; +statement ok +DROP TABLE dimension_table_partitioned; + statement ok DROP TABLE high_cardinality_table; diff --git a/datafusion/sqllogictest/test_files/projection.slt b/datafusion/sqllogictest/test_files/projection.slt index 5a4411233424..e18114bc51ca 100644 --- a/datafusion/sqllogictest/test_files/projection.slt +++ b/datafusion/sqllogictest/test_files/projection.slt @@ -167,12 +167,12 @@ set datafusion.explain.logical_plan_only = false # project cast dictionary query T -SELECT - CASE +SELECT + CASE WHEN cpu_load_short.host IS NULL THEN '' ELSE cpu_load_short.host END AS host -FROM +FROM cpu_load_short; ---- host1 @@ -275,7 +275,6 @@ logical_plan 02)--Filter: t1.a > Int64(1) 03)----TableScan: t1 projection=[a], partial_filters=[t1.a > Int64(1)] physical_plan -01)ProjectionExec: expr=[] -02)--FilterExec: a@0 > 1 -03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -04)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection/17513.parquet]]}, projection=[a], file_type=parquet, predicate=a@0 > 1, pruning_predicate=a_null_count@1 != row_count@2 AND a_max@0 > 1, required_guarantees=[] +01)FilterExec: a@0 > 1, 
projection=[] +02)--RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection/17513.parquet]]}, projection=[a], file_type=parquet, predicate=a@0 > 1, pruning_predicate=a_null_count@1 != row_count@2 AND a_max@0 > 1, required_guarantees=[] diff --git a/datafusion/sqllogictest/test_files/projection_pushdown.slt b/datafusion/sqllogictest/test_files/projection_pushdown.slt new file mode 100644 index 000000000000..dbb77b33c21b --- /dev/null +++ b/datafusion/sqllogictest/test_files/projection_pushdown.slt @@ -0,0 +1,1992 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +########## +# Tests for projection pushdown behavior with get_field expressions +# +# This file tests the ExtractTrivialProjections optimizer rule and +# physical projection pushdown for: +# - get_field expressions (struct field access like s['foo']) +# - Pushdown through Filter, Sort, and TopK operators +# - Multi-partition scenarios with SortPreservingMergeExec +########## + +##################### +# Section 1: Setup - Single Partition Tests +##################### + +# Set target_partitions = 1 for deterministic plan output +statement ok +SET datafusion.execution.target_partitions = 1; + +# Create parquet file with struct column containing value and label fields +statement ok +COPY ( + SELECT + column1 as id, + column2 as s + FROM VALUES + (1, {value: 100, label: 'alpha'}), + (2, {value: 200, label: 'beta'}), + (3, {value: 150, label: 'gamma'}), + (4, {value: 300, label: 'delta'}), + (5, {value: 250, label: 'epsilon'}) +) TO 'test_files/scratch/projection_pushdown/simple.parquet' +STORED AS PARQUET; + +# Create table for simple struct tests +statement ok +CREATE EXTERNAL TABLE simple_struct STORED AS PARQUET +LOCATION 'test_files/scratch/projection_pushdown/simple.parquet'; + +# Create parquet file with nested struct column +statement ok +COPY ( + SELECT + column1 as id, + column2 as nested + FROM VALUES + (1, {outer: {inner: 10, name: 'one'}, extra: 'x'}), + (2, {outer: {inner: 20, name: 'two'}, extra: 'y'}), + (3, {outer: {inner: 30, name: 'three'}, extra: 'z'}) +) TO 'test_files/scratch/projection_pushdown/nested.parquet' +STORED AS PARQUET; + +# Create table for nested struct tests +statement ok +CREATE EXTERNAL TABLE nested_struct STORED AS PARQUET +LOCATION 'test_files/scratch/projection_pushdown/nested.parquet'; + +# Create parquet file with nullable struct column +statement ok +COPY ( + SELECT + column1 as id, + column2 as s + FROM VALUES + (1, {value: 100, label: 'alpha'}), + (2, NULL), + (3, {value: 150, label: 'gamma'}), + (4, NULL), + (5, {value: 
250, label: 'epsilon'}) +) TO 'test_files/scratch/projection_pushdown/nullable.parquet' +STORED AS PARQUET; + +# Create table for nullable struct tests +statement ok +CREATE EXTERNAL TABLE nullable_struct STORED AS PARQUET +LOCATION 'test_files/scratch/projection_pushdown/nullable.parquet'; + + +##################### +# Section 2: Basic get_field Pushdown (Projection above scan) +##################### + +### +# Test 2.1: Simple s['value'] +### + +query TT +EXPLAIN SELECT id, s['value'] FROM simple_struct; +---- +logical_plan +01)Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")) +02)--TableScan: simple_struct projection=[id, s] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) as simple_struct.s[value]], file_type=parquet + +# Verify correctness +query II +SELECT id, s['value'] FROM simple_struct ORDER BY id; +---- +1 100 +2 200 +3 150 +4 300 +5 250 + +query TT +EXPLAIN SELECT s['label'] FROM simple_struct; +---- +logical_plan +01)Projection: get_field(simple_struct.s, Utf8("label")) +02)--TableScan: simple_struct projection=[s] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, label) as simple_struct.s[label]], file_type=parquet + +# Verify correctness +query T +SELECT s['label'] FROM simple_struct ORDER BY s['label']; +---- +alpha +beta +delta +epsilon +gamma + +### +# Test 2.2: Multiple get_field expressions +### + +query TT +EXPLAIN SELECT id, s['value'], s['label'] FROM simple_struct; +---- +logical_plan +01)Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")), get_field(simple_struct.s, Utf8("label")) +02)--TableScan: simple_struct projection=[id, s] +physical_plan DataSourceExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) as simple_struct.s[value], get_field(s@1, label) as simple_struct.s[label]], file_type=parquet + +# Verify correctness +query IIT +SELECT id, s['value'], s['label'] FROM simple_struct ORDER BY id; +---- +1 100 alpha +2 200 beta +3 150 gamma +4 300 delta +5 250 epsilon + +### +# Test 2.3: Nested s['outer']['inner'] +### + +query TT +EXPLAIN SELECT id, nested['outer']['inner'] FROM nested_struct; +---- +logical_plan +01)Projection: nested_struct.id, get_field(nested_struct.nested, Utf8("outer"), Utf8("inner")) +02)--TableScan: nested_struct projection=[id, nested] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/nested.parquet]]}, projection=[id, get_field(nested@1, outer, inner) as nested_struct.nested[outer][inner]], file_type=parquet + +# Verify correctness +query II +SELECT id, nested['outer']['inner'] FROM nested_struct ORDER BY id; +---- +1 10 +2 20 +3 30 + +### +# Test 2.4: s['value'] + 1 +### + +query TT +EXPLAIN SELECT id, s['value'] + 1 FROM simple_struct; +---- +logical_plan +01)Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")) + Int64(1) +02)--TableScan: simple_struct projection=[id, s] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) + 1 as simple_struct.s[value] + Int64(1)], file_type=parquet + +# Verify correctness +query II +SELECT id, s['value'] + 1 FROM simple_struct ORDER BY id; +---- +1 101 +2 201 +3 151 +4 301 +5 251 + +### +# Test 2.5: s['label'] || '_suffix' +### + +query TT +EXPLAIN SELECT id, s['label'] || '_suffix' FROM simple_struct; +---- +logical_plan +01)Projection: simple_struct.id, get_field(simple_struct.s, Utf8("label")) || Utf8("_suffix") +02)--TableScan: 
simple_struct projection=[id, s] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, label) || _suffix as simple_struct.s[label] || Utf8("_suffix")], file_type=parquet + +# Verify correctness +query IT +SELECT id, s['label'] || '_suffix' FROM simple_struct ORDER BY id; +---- +1 alpha_suffix +2 beta_suffix +3 gamma_suffix +4 delta_suffix +5 epsilon_suffix + + +##################### +# Section 3: Projection Pushdown Through FilterExec +##################### + +### +# Test 3.1: Simple get_field through Filter +### + +query TT +EXPLAIN SELECT id, s['value'] FROM simple_struct WHERE id > 2; +---- +logical_plan +01)Projection: simple_struct.id, __datafusion_extracted_1 AS simple_struct.s[value] +02)--Filter: simple_struct.id > Int64(2) +03)----Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +04)------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(2)] +physical_plan +01)ProjectionExec: expr=[id@1 as id, __datafusion_extracted_1@0 as simple_struct.s[value]] +02)--FilterExec: id@1 > 2 +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet, predicate=id@0 > 2, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 2, required_guarantees=[] + +# Verify correctness +query II +SELECT id, s['value'] FROM simple_struct WHERE id > 2 ORDER BY id; +---- +3 150 +4 300 +5 250 + +### +# Test 3.2: s['value'] + 1 through Filter +### + +query TT +EXPLAIN SELECT id, s['value'] + 1 FROM simple_struct WHERE id > 2; +---- +logical_plan +01)Projection: simple_struct.id, __datafusion_extracted_1 + Int64(1) AS simple_struct.s[value] + Int64(1) +02)--Filter: simple_struct.id > Int64(2) 
+03)----Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +04)------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(2)] +physical_plan +01)ProjectionExec: expr=[id@1 as id, __datafusion_extracted_1@0 + 1 as simple_struct.s[value] + Int64(1)] +02)--FilterExec: id@1 > 2 +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet, predicate=id@0 > 2, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 2, required_guarantees=[] + +# Verify correctness +query II +SELECT id, s['value'] + 1 FROM simple_struct WHERE id > 2 ORDER BY id; +---- +3 151 +4 301 +5 251 + +### +# Test 3.3: Filter on get_field expression +### + +query TT +EXPLAIN SELECT id, s['label'] FROM simple_struct WHERE s['value'] > 150; +---- +logical_plan +01)Projection: simple_struct.id, __datafusion_extracted_2 AS simple_struct.s[label] +02)--Filter: __datafusion_extracted_1 > Int64(150) +03)----Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id, get_field(simple_struct.s, Utf8("label")) AS __datafusion_extracted_2 +04)------TableScan: simple_struct projection=[id, s], partial_filters=[get_field(simple_struct.s, Utf8("value")) > Int64(150)] +physical_plan +01)ProjectionExec: expr=[id@0 as id, __datafusion_extracted_2@1 as simple_struct.s[label]] +02)--FilterExec: __datafusion_extracted_1@0 > 150, projection=[id@1, __datafusion_extracted_2@2] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id, get_field(s@1, label) as __datafusion_extracted_2], file_type=parquet + +# Verify correctness +query IT +SELECT id, s['label'] FROM 
simple_struct WHERE s['value'] > 150 ORDER BY id; +---- +2 beta +4 delta +5 epsilon + + +##################### +# Section 4: Projection Pushdown Through SortExec (no LIMIT) +##################### + +### +# Test 4.1: Simple get_field through Sort +### + +query TT +EXPLAIN SELECT id, s['value'] FROM simple_struct ORDER BY id; +---- +logical_plan +01)Sort: simple_struct.id ASC NULLS LAST +02)--Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")) +03)----TableScan: simple_struct projection=[id, s] +physical_plan +01)SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) as simple_struct.s[value]], file_type=parquet + +# Verify correctness +query II +SELECT id, s['value'] FROM simple_struct ORDER BY id; +---- +1 100 +2 200 +3 150 +4 300 +5 250 + +### +# Test 4.2: s['value'] + 1 through Sort - split projection +### + +query TT +EXPLAIN SELECT id, s['value'] + 1 FROM simple_struct ORDER BY id; +---- +logical_plan +01)Sort: simple_struct.id ASC NULLS LAST +02)--Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")) + Int64(1) +03)----TableScan: simple_struct projection=[id, s] +physical_plan +01)SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) + 1 as simple_struct.s[value] + Int64(1)], file_type=parquet + +# Verify correctness +query II +SELECT id, s['value'] + 1 FROM simple_struct ORDER BY id; +---- +1 101 +2 201 +3 151 +4 301 +5 251 + +### +# Test 4.3: Sort by get_field expression +### + +query TT +EXPLAIN SELECT id, s['value'] FROM simple_struct ORDER BY s['value']; +---- +logical_plan +01)Sort: simple_struct.s[value] ASC NULLS LAST 
+02)--Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")) +03)----TableScan: simple_struct projection=[id, s] +physical_plan +01)SortExec: expr=[simple_struct.s[value]@1 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) as simple_struct.s[value]], file_type=parquet + +# Verify correctness +query II +SELECT id, s['value'] FROM simple_struct ORDER BY s['value']; +---- +1 100 +3 150 +2 200 +5 250 +4 300 + +### +# Test 4.4: Projection with duplicate column through Sort +# The projection expands the number of columns from 3 to 4 by introducing `col_b_dup` +### + +statement ok +COPY ( + SELECT + column1 as col_a, + column2 as col_b, + column3 as col_c + FROM VALUES + (1, 2, 3), + (4, 5, 6), + (7, 8, 9) +) TO 'test_files/scratch/projection_pushdown/three_cols.parquet' +STORED AS PARQUET; + +statement ok +CREATE EXTERNAL TABLE three_cols STORED AS PARQUET +LOCATION 'test_files/scratch/projection_pushdown/three_cols.parquet'; + +query TT +EXPLAIN SELECT col_a, col_b, col_c, col_b as col_b_dup FROM three_cols ORDER BY col_a; +---- +logical_plan +01)Sort: three_cols.col_a ASC NULLS LAST +02)--Projection: three_cols.col_a, three_cols.col_b, three_cols.col_c, three_cols.col_b AS col_b_dup +03)----TableScan: three_cols projection=[col_a, col_b, col_c] +physical_plan +01)SortExec: expr=[col_a@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/three_cols.parquet]]}, projection=[col_a, col_b, col_c, col_b@1 as col_b_dup], file_type=parquet + +# Verify correctness +query IIII +SELECT col_a, col_b, col_c, col_b as col_b_dup FROM three_cols ORDER BY col_a DESC; +---- +7 8 9 8 +4 5 6 5 +1 2 3 2 + +statement ok +DROP TABLE three_cols; + + +##################### +# Section 
5: Projection Pushdown Through TopK (ORDER BY + LIMIT) +##################### + +### +# Test 5.1: Simple get_field through TopK +### + +query TT +EXPLAIN SELECT id, s['value'] FROM simple_struct ORDER BY id LIMIT 3; +---- +logical_plan +01)Sort: simple_struct.id ASC NULLS LAST, fetch=3 +02)--Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")) +03)----TableScan: simple_struct projection=[id, s] +physical_plan +01)SortExec: TopK(fetch=3), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) as simple_struct.s[value]], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness +query II +SELECT id, s['value'] FROM simple_struct ORDER BY id LIMIT 3; +---- +1 100 +2 200 +3 150 + +### +# Test 5.2: s['value'] + 1 through TopK +### + +query TT +EXPLAIN SELECT id, s['value'] + 1 FROM simple_struct ORDER BY id LIMIT 3; +---- +logical_plan +01)Sort: simple_struct.id ASC NULLS LAST, fetch=3 +02)--Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")) + Int64(1) +03)----TableScan: simple_struct projection=[id, s] +physical_plan +01)SortExec: TopK(fetch=3), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) + 1 as simple_struct.s[value] + Int64(1)], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness +query II +SELECT id, s['value'] + 1 FROM simple_struct ORDER BY id LIMIT 3; +---- +1 101 +2 201 +3 151 + +### +# Test 5.3: Multiple get_field through TopK +### + +query TT +EXPLAIN SELECT id, s['value'], s['label'] FROM simple_struct ORDER BY id LIMIT 3; +---- +logical_plan +01)Sort: simple_struct.id ASC NULLS LAST, fetch=3 
+02)--Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")), get_field(simple_struct.s, Utf8("label")) +03)----TableScan: simple_struct projection=[id, s] +physical_plan +01)SortExec: TopK(fetch=3), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) as simple_struct.s[value], get_field(s@1, label) as simple_struct.s[label]], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness +query IIT +SELECT id, s['value'], s['label'] FROM simple_struct ORDER BY id LIMIT 3; +---- +1 100 alpha +2 200 beta +3 150 gamma + +### +# Test 5.4: Nested get_field through TopK +### + +query TT +EXPLAIN SELECT id, nested['outer']['inner'] FROM nested_struct ORDER BY id LIMIT 2; +---- +logical_plan +01)Sort: nested_struct.id ASC NULLS LAST, fetch=2 +02)--Projection: nested_struct.id, get_field(nested_struct.nested, Utf8("outer"), Utf8("inner")) +03)----TableScan: nested_struct projection=[id, nested] +physical_plan +01)SortExec: TopK(fetch=2), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/nested.parquet]]}, projection=[id, get_field(nested@1, outer, inner) as nested_struct.nested[outer][inner]], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness +query II +SELECT id, nested['outer']['inner'] FROM nested_struct ORDER BY id LIMIT 2; +---- +1 10 +2 20 + +### +# Test 5.5: String concat through TopK +### + +query TT +EXPLAIN SELECT id, s['label'] || '_suffix' FROM simple_struct ORDER BY id LIMIT 3; +---- +logical_plan +01)Sort: simple_struct.id ASC NULLS LAST, fetch=3 +02)--Projection: simple_struct.id, get_field(simple_struct.s, Utf8("label")) || Utf8("_suffix") +03)----TableScan: simple_struct projection=[id, 
s] +physical_plan +01)SortExec: TopK(fetch=3), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, label) || _suffix as simple_struct.s[label] || Utf8("_suffix")], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness +query IT +SELECT id, s['label'] || '_suffix' FROM simple_struct ORDER BY id LIMIT 3; +---- +1 alpha_suffix +2 beta_suffix +3 gamma_suffix + + +##################### +# Section 6: Combined Operators (Filter + Sort/TopK) +##################### + +### +# Test 6.1: Filter + Sort + get_field +### + +query TT +EXPLAIN SELECT id, s['value'] FROM simple_struct WHERE id > 1 ORDER BY s['value']; +---- +logical_plan +01)Sort: simple_struct.s[value] ASC NULLS LAST +02)--Projection: simple_struct.id, __datafusion_extracted_1 AS simple_struct.s[value] +03)----Filter: simple_struct.id > Int64(1) +04)------Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +05)--------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(1)] +physical_plan +01)SortExec: expr=[simple_struct.s[value]@1 ASC NULLS LAST], preserve_partitioning=[false] +02)--ProjectionExec: expr=[id@1 as id, __datafusion_extracted_1@0 as simple_struct.s[value]] +03)----FilterExec: id@1 > 1 +04)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet, predicate=id@0 > 1, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 1, required_guarantees=[] + +# Verify correctness +query II +SELECT id, s['value'] FROM simple_struct WHERE id > 1 ORDER BY s['value']; +---- +3 150 +2 200 +5 250 +4 300 + +### +# Test 6.2: Filter + TopK + get_field +### + 
+query TT +EXPLAIN SELECT id, s['value'] FROM simple_struct WHERE id > 1 ORDER BY s['value'] LIMIT 2; +---- +logical_plan +01)Sort: simple_struct.s[value] ASC NULLS LAST, fetch=2 +02)--Projection: simple_struct.id, __datafusion_extracted_1 AS simple_struct.s[value] +03)----Filter: simple_struct.id > Int64(1) +04)------Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +05)--------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(1)] +physical_plan +01)SortExec: TopK(fetch=2), expr=[simple_struct.s[value]@1 ASC NULLS LAST], preserve_partitioning=[false] +02)--ProjectionExec: expr=[id@1 as id, __datafusion_extracted_1@0 as simple_struct.s[value]] +03)----FilterExec: id@1 > 1 +04)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet, predicate=id@0 > 1, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 1, required_guarantees=[] + +# Verify correctness +query II +SELECT id, s['value'] FROM simple_struct WHERE id > 1 ORDER BY s['value'] LIMIT 2; +---- +3 150 +2 200 + +### +# Test 6.3: Filter + TopK + get_field with arithmetic +### + +query TT +EXPLAIN SELECT id, s['value'] + 1 FROM simple_struct WHERE id > 1 ORDER BY id LIMIT 2; +---- +logical_plan +01)Sort: simple_struct.id ASC NULLS LAST, fetch=2 +02)--Projection: simple_struct.id, __datafusion_extracted_1 + Int64(1) AS simple_struct.s[value] + Int64(1) +03)----Filter: simple_struct.id > Int64(1) +04)------Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +05)--------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(1)] +physical_plan +01)SortExec: TopK(fetch=2), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--ProjectionExec: expr=[id@1 as id, 
__datafusion_extracted_1@0 + 1 as simple_struct.s[value] + Int64(1)] +03)----FilterExec: id@1 > 1 +04)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet, predicate=id@0 > 1 AND DynamicFilter [ empty ], pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 1, required_guarantees=[] + +# Verify correctness +query II +SELECT id, s['value'] + 1 FROM simple_struct WHERE id > 1 ORDER BY id LIMIT 2; +---- +2 201 +3 151 + + +##################### +# Section 7: Multi-Partition Tests +##################### + +# Set target_partitions = 4 for parallel execution +statement ok +SET datafusion.execution.target_partitions = 4; + +# Create 5 parquet files (more than partitions) for parallel tests +statement ok +COPY (SELECT 1 as id, {value: 100, label: 'alpha'} as s) +TO 'test_files/scratch/projection_pushdown/multi/part1.parquet' +STORED AS PARQUET; + +statement ok +COPY (SELECT 2 as id, {value: 200, label: 'beta'} as s) +TO 'test_files/scratch/projection_pushdown/multi/part2.parquet' +STORED AS PARQUET; + +statement ok +COPY (SELECT 3 as id, {value: 150, label: 'gamma'} as s) +TO 'test_files/scratch/projection_pushdown/multi/part3.parquet' +STORED AS PARQUET; + +statement ok +COPY (SELECT 4 as id, {value: 300, label: 'delta'} as s) +TO 'test_files/scratch/projection_pushdown/multi/part4.parquet' +STORED AS PARQUET; + +statement ok +COPY (SELECT 5 as id, {value: 250, label: 'epsilon'} as s) +TO 'test_files/scratch/projection_pushdown/multi/part5.parquet' +STORED AS PARQUET; + +# Create table from multiple parquet files +statement ok +CREATE EXTERNAL TABLE multi_struct STORED AS PARQUET +LOCATION 'test_files/scratch/projection_pushdown/multi/'; + +### +# Test 7.1: Multi-partition Sort with get_field +### + +query TT +EXPLAIN SELECT id, s['value'] FROM multi_struct ORDER BY id; +---- +logical_plan 
+01)Sort: multi_struct.id ASC NULLS LAST +02)--Projection: multi_struct.id, get_field(multi_struct.s, Utf8("value")) +03)----TableScan: multi_struct projection=[id, s] +physical_plan +01)SortPreservingMergeExec: [id@0 ASC NULLS LAST] +02)--SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part1.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part2.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part3.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part4.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part5.parquet]]}, projection=[id, get_field(s@1, value) as multi_struct.s[value]], file_type=parquet + +# Verify correctness +query II +SELECT id, s['value'] FROM multi_struct ORDER BY id; +---- +1 100 +2 200 +3 150 +4 300 +5 250 + +### +# Test 7.2: Multi-partition TopK with get_field +### + +query TT +EXPLAIN SELECT id, s['value'] FROM multi_struct ORDER BY id LIMIT 3; +---- +logical_plan +01)Sort: multi_struct.id ASC NULLS LAST, fetch=3 +02)--Projection: multi_struct.id, get_field(multi_struct.s, Utf8("value")) +03)----TableScan: multi_struct projection=[id, s] +physical_plan +01)SortPreservingMergeExec: [id@0 ASC NULLS LAST], fetch=3 +02)--SortExec: TopK(fetch=3), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part1.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part2.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part3.parquet, 
WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part4.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part5.parquet]]}, projection=[id, get_field(s@1, value) as multi_struct.s[value]], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness +query II +SELECT id, s['value'] FROM multi_struct ORDER BY id LIMIT 3; +---- +1 100 +2 200 +3 150 + +### +# Test 7.3: Multi-partition TopK with arithmetic (non-trivial stays above merge) +### + +query TT +EXPLAIN SELECT id, s['value'] + 1 FROM multi_struct ORDER BY id LIMIT 3; +---- +logical_plan +01)Sort: multi_struct.id ASC NULLS LAST, fetch=3 +02)--Projection: multi_struct.id, get_field(multi_struct.s, Utf8("value")) + Int64(1) +03)----TableScan: multi_struct projection=[id, s] +physical_plan +01)SortPreservingMergeExec: [id@0 ASC NULLS LAST], fetch=3 +02)--SortExec: TopK(fetch=3), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part1.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part2.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part3.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part4.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part5.parquet]]}, projection=[id, get_field(s@1, value) + 1 as multi_struct.s[value] + Int64(1)], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness +query II +SELECT id, s['value'] + 1 FROM multi_struct ORDER BY id LIMIT 3; +---- +1 101 +2 201 +3 151 + +### +# Test 7.4: Multi-partition Filter with get_field +### + +query TT +EXPLAIN SELECT id, s['value'] FROM multi_struct WHERE id > 2 ORDER BY id; +---- +logical_plan +01)Sort: 
multi_struct.id ASC NULLS LAST +02)--Projection: multi_struct.id, __datafusion_extracted_1 AS multi_struct.s[value] +03)----Filter: multi_struct.id > Int64(2) +04)------Projection: get_field(multi_struct.s, Utf8("value")) AS __datafusion_extracted_1, multi_struct.id +05)--------TableScan: multi_struct projection=[id, s], partial_filters=[multi_struct.id > Int64(2)] +physical_plan +01)SortPreservingMergeExec: [id@0 ASC NULLS LAST] +02)--SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----ProjectionExec: expr=[id@1 as id, __datafusion_extracted_1@0 as multi_struct.s[value]] +04)------FilterExec: id@1 > 2 +05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=3 +06)----------DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part1.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part2.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part3.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part4.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part5.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet, predicate=id@0 > 2, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 2, required_guarantees=[] + +# Verify correctness +query II +SELECT id, s['value'] FROM multi_struct WHERE id > 2 ORDER BY id; +---- +3 150 +4 300 +5 250 + +### +# Test 7.5: Aggregation with get_field (CoalescePartitions) +### + +query TT +EXPLAIN SELECT s['label'], SUM(s['value']) FROM multi_struct GROUP BY s['label']; +---- +logical_plan +01)Projection: __datafusion_extracted_1 AS multi_struct.s[label], sum(__datafusion_extracted_2) AS sum(multi_struct.s[value]) +02)--Aggregate: groupBy=[[__datafusion_extracted_1]], 
aggr=[[sum(__datafusion_extracted_2)]] +03)----Projection: get_field(multi_struct.s, Utf8("label")) AS __datafusion_extracted_1, get_field(multi_struct.s, Utf8("value")) AS __datafusion_extracted_2 +04)------TableScan: multi_struct projection=[s] +physical_plan +01)ProjectionExec: expr=[__datafusion_extracted_1@0 as multi_struct.s[label], sum(__datafusion_extracted_2)@1 as sum(multi_struct.s[value])] +02)--AggregateExec: mode=FinalPartitioned, gby=[__datafusion_extracted_1@0 as __datafusion_extracted_1], aggr=[sum(__datafusion_extracted_2)] +03)----RepartitionExec: partitioning=Hash([__datafusion_extracted_1@0], 4), input_partitions=3 +04)------AggregateExec: mode=Partial, gby=[__datafusion_extracted_1@0 as __datafusion_extracted_1], aggr=[sum(__datafusion_extracted_2)] +05)--------DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part1.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part2.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part3.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part4.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part5.parquet]]}, projection=[get_field(s@1, label) as __datafusion_extracted_1, get_field(s@1, value) as __datafusion_extracted_2], file_type=parquet + +# Verify correctness +query TI +SELECT s['label'], SUM(s['value']) FROM multi_struct GROUP BY s['label'] ORDER BY s['label']; +---- +alpha 100 +beta 200 +delta 300 +epsilon 250 +gamma 150 + + +##################### +# Section 8: Edge Cases +##################### + +# Reset to single partition for edge case tests +statement ok +SET datafusion.execution.target_partitions = 1; + +### +# Test 8.1: get_field on nullable struct column +### + +query TT +EXPLAIN SELECT id, s['value'] FROM nullable_struct; +---- +logical_plan 
+01)Projection: nullable_struct.id, get_field(nullable_struct.s, Utf8("value")) +02)--TableScan: nullable_struct projection=[id, s] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/nullable.parquet]]}, projection=[id, get_field(s@1, value) as nullable_struct.s[value]], file_type=parquet + +# Verify correctness (NULL struct returns NULL field) +query II +SELECT id, s['value'] FROM nullable_struct ORDER BY id; +---- +1 100 +2 NULL +3 150 +4 NULL +5 250 + +### +# Test 8.2: get_field returning NULL values +### + +query TT +EXPLAIN SELECT id, s['label'] FROM nullable_struct WHERE s['value'] IS NOT NULL; +---- +logical_plan +01)Projection: nullable_struct.id, __datafusion_extracted_2 AS nullable_struct.s[label] +02)--Filter: __datafusion_extracted_1 IS NOT NULL +03)----Projection: get_field(nullable_struct.s, Utf8("value")) AS __datafusion_extracted_1, nullable_struct.id, get_field(nullable_struct.s, Utf8("label")) AS __datafusion_extracted_2 +04)------TableScan: nullable_struct projection=[id, s], partial_filters=[get_field(nullable_struct.s, Utf8("value")) IS NOT NULL] +physical_plan +01)ProjectionExec: expr=[id@0 as id, __datafusion_extracted_2@1 as nullable_struct.s[label]] +02)--FilterExec: __datafusion_extracted_1@0 IS NOT NULL, projection=[id@1, __datafusion_extracted_2@2] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/nullable.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id, get_field(s@1, label) as __datafusion_extracted_2], file_type=parquet + +# Verify correctness +query IT +SELECT id, s['label'] FROM nullable_struct WHERE s['value'] IS NOT NULL ORDER BY id; +---- +1 alpha +3 gamma +5 epsilon + +### +# Test 8.3: Mixed trivial and non-trivial in same projection +### + +query TT +EXPLAIN SELECT id, s['value'], s['value'] + 10, s['label'] FROM simple_struct ORDER BY id 
LIMIT 3; +---- +logical_plan +01)Sort: simple_struct.id ASC NULLS LAST, fetch=3 +02)--Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")), get_field(simple_struct.s, Utf8("value")) + Int64(10), get_field(simple_struct.s, Utf8("label")) +03)----TableScan: simple_struct projection=[id, s] +physical_plan +01)SortExec: TopK(fetch=3), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) as simple_struct.s[value], get_field(s@1, value) + 10 as simple_struct.s[value] + Int64(10), get_field(s@1, label) as simple_struct.s[label]], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness +query IIIT +SELECT id, s['value'], s['value'] + 10, s['label'] FROM simple_struct ORDER BY id LIMIT 3; +---- +1 100 110 alpha +2 200 210 beta +3 150 160 gamma + +### +# Test 8.4: Literal projection through TopK +### + +query TT +EXPLAIN SELECT id, 42 as constant FROM simple_struct ORDER BY id LIMIT 3; +---- +logical_plan +01)Sort: simple_struct.id ASC NULLS LAST, fetch=3 +02)--Projection: simple_struct.id, Int64(42) AS constant +03)----TableScan: simple_struct projection=[id] +physical_plan +01)SortExec: TopK(fetch=3), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, 42 as constant], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness +query II +SELECT id, 42 as constant FROM simple_struct ORDER BY id LIMIT 3; +---- +1 42 +2 42 +3 42 + +### +# Test 8.5: Simple column through TopK (baseline comparison) +### + +query TT +EXPLAIN SELECT id FROM simple_struct ORDER BY id LIMIT 3; +---- +logical_plan +01)Sort: simple_struct.id ASC NULLS LAST, fetch=3 +02)--TableScan: simple_struct 
projection=[id] +physical_plan +01)SortExec: TopK(fetch=3), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness +query I +SELECT id FROM simple_struct ORDER BY id LIMIT 3; +---- +1 +2 +3 + + +##################### +# Section 9: Coverage Tests - Edge Cases for Uncovered Code Paths +##################### + +### +# Test 9.1: TopK with computed projection +### + +query TT +EXPLAIN SELECT id, id + 100 as computed FROM simple_struct ORDER BY id LIMIT 3; +---- +logical_plan +01)Sort: simple_struct.id ASC NULLS LAST, fetch=3 +02)--Projection: simple_struct.id, simple_struct.id + Int64(100) AS computed +03)----TableScan: simple_struct projection=[id] +physical_plan +01)SortExec: TopK(fetch=3), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, id@0 + 100 as computed], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness +query II +SELECT id, id + 100 as computed FROM simple_struct ORDER BY id LIMIT 3; +---- +1 101 +2 102 +3 103 + +### +# Test 9.2: Duplicate get_field expressions (same expression referenced twice) +# Common subexpression elimination happens in the logical plan, and the physical +# plan extracts the shared get_field for efficient computation +### + +query TT +EXPLAIN SELECT (id + s['value']) * (id + s['value']) as id_and_value FROM simple_struct WHERE id > 2; +---- +logical_plan +01)Projection: __common_expr_1 * __common_expr_1 AS id_and_value +02)--Projection: simple_struct.id + __datafusion_extracted_2 AS __common_expr_1 +03)----Filter: simple_struct.id > Int64(2) +04)------Projection: get_field(simple_struct.s, Utf8("value")) AS 
__datafusion_extracted_2, simple_struct.id +05)--------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(2)] +physical_plan +01)ProjectionExec: expr=[__common_expr_1@0 * __common_expr_1@0 as id_and_value] +02)--ProjectionExec: expr=[id@1 + __datafusion_extracted_2@0 as __common_expr_1] +03)----FilterExec: id@1 > 2 +04)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_2, id], file_type=parquet, predicate=id@0 > 2, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 2, required_guarantees=[] + + +query TT +EXPLAIN SELECT s['value'] + s['value'] as doubled FROM simple_struct WHERE id > 2; +---- +logical_plan +01)Projection: __datafusion_extracted_1 + __datafusion_extracted_1 AS doubled +02)--Filter: simple_struct.id > Int64(2) +03)----Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +04)------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(2)] +physical_plan +01)ProjectionExec: expr=[__datafusion_extracted_1@0 + __datafusion_extracted_1@0 as doubled] +02)--FilterExec: id@1 > 2, projection=[__datafusion_extracted_1@0] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet, predicate=id@0 > 2, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 2, required_guarantees=[] + +# Verify correctness +query I +SELECT s['value'] + s['value'] as doubled FROM simple_struct WHERE id > 2 ORDER BY doubled; +---- +300 +500 +600 + +### +# Test 9.3: Projection with only get_field expressions through Filter +### + +query TT +EXPLAIN SELECT s['value'], s['label'] FROM simple_struct WHERE id > 2; +---- +logical_plan 
+01)Projection: __datafusion_extracted_1 AS simple_struct.s[value], __datafusion_extracted_2 AS simple_struct.s[label] +02)--Filter: simple_struct.id > Int64(2) +03)----Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, get_field(simple_struct.s, Utf8("label")) AS __datafusion_extracted_2, simple_struct.id +04)------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(2)] +physical_plan +01)ProjectionExec: expr=[__datafusion_extracted_1@0 as simple_struct.s[value], __datafusion_extracted_2@1 as simple_struct.s[label]] +02)--FilterExec: id@2 > 2, projection=[__datafusion_extracted_1@0, __datafusion_extracted_2@1] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, get_field(s@1, label) as __datafusion_extracted_2, id], file_type=parquet, predicate=id@0 > 2, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 2, required_guarantees=[] + +# Verify correctness +query IT +SELECT s['value'], s['label'] FROM simple_struct WHERE id > 2 ORDER BY s['value']; +---- +150 gamma +250 epsilon +300 delta + +### +# Test 9.4: Mixed column reference with get_field in expression through TopK +# Tests column remapping in finalize_outer_exprs when outer expr references both extracted and original columns +### + +query TT +EXPLAIN SELECT id, s['value'] + id as combined FROM simple_struct ORDER BY id LIMIT 3; +---- +logical_plan +01)Sort: simple_struct.id ASC NULLS LAST, fetch=3 +02)--Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")) + simple_struct.id AS combined +03)----TableScan: simple_struct projection=[id, s] +physical_plan +01)SortExec: TopK(fetch=3), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) + id@0 as combined], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness +query II +SELECT id, s['value'] + id as combined FROM simple_struct ORDER BY id LIMIT 3; +---- +1 101 +2 202 +3 153 + +### +# Test 9.5: Multiple get_field from same struct in expression through Filter +# Tests extraction when base struct is shared across multiple get_field calls +### + +query TT +EXPLAIN SELECT s['value'] * 2 + length(s['label']) as score FROM simple_struct WHERE id > 1; +---- +logical_plan +01)Projection: __datafusion_extracted_1 * Int64(2) + CAST(character_length(__datafusion_extracted_2) AS Int64) AS score +02)--Filter: simple_struct.id > Int64(1) +03)----Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, get_field(simple_struct.s, Utf8("label")) AS __datafusion_extracted_2, simple_struct.id +04)------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(1)] +physical_plan +01)ProjectionExec: expr=[__datafusion_extracted_1@0 * 2 + CAST(character_length(__datafusion_extracted_2@1) AS Int64) as score] +02)--FilterExec: id@2 > 1, projection=[__datafusion_extracted_1@0, __datafusion_extracted_2@1] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, get_field(s@1, label) as __datafusion_extracted_2, id], file_type=parquet, predicate=id@0 > 1, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 1, required_guarantees=[] + +# Verify correctness +query I +SELECT s['value'] * 2 + length(s['label']) as score FROM simple_struct WHERE id > 1 ORDER BY score; +---- +305 +404 +507 +605 + + +##################### +# Section 10: Literal with get_field Expressions +##################### + +### +# Test 
10.1: Literal constant + get_field in same projection +# Tests projection with both trivial (literal) and get_field expressions +### + +query TT +EXPLAIN SELECT id, 42 as answer, s['label'] FROM simple_struct ORDER BY id LIMIT 2; +---- +logical_plan +01)Sort: simple_struct.id ASC NULLS LAST, fetch=2 +02)--Projection: simple_struct.id, Int64(42) AS answer, get_field(simple_struct.s, Utf8("label")) +03)----TableScan: simple_struct projection=[id, s] +physical_plan +01)SortExec: TopK(fetch=2), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, 42 as answer, get_field(s@1, label) as simple_struct.s[label]], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness +query IIT +SELECT id, 42 as answer, s['label'] FROM simple_struct ORDER BY id LIMIT 2; +---- +1 42 alpha +2 42 beta + +### +# Test 10.2: Multiple non-trivial get_field expressions together +# Tests arithmetic on one field and string concat on another in same projection +### + +query TT +EXPLAIN SELECT id, s['value'] + 100, s['label'] || '_test' FROM simple_struct ORDER BY id LIMIT 2; +---- +logical_plan +01)Sort: simple_struct.id ASC NULLS LAST, fetch=2 +02)--Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")) + Int64(100), get_field(simple_struct.s, Utf8("label")) || Utf8("_test") +03)----TableScan: simple_struct projection=[id, s] +physical_plan +01)SortExec: TopK(fetch=2), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) + 100 as simple_struct.s[value] + Int64(100), get_field(s@1, label) || _test as simple_struct.s[label] || Utf8("_test")], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness 
+query IIT +SELECT id, s['value'] + 100, s['label'] || '_test' FROM simple_struct ORDER BY id LIMIT 2; +---- +1 200 alpha_test +2 300 beta_test + +##################### +# Section 11: FilterExec Projection Pushdown - Handling Predicate Column Requirements +##################### + +query TT +EXPLAIN SELECT id, s['value'] FROM simple_struct WHERE id > 1; +---- +logical_plan +01)Projection: simple_struct.id, __datafusion_extracted_1 AS simple_struct.s[value] +02)--Filter: simple_struct.id > Int64(1) +03)----Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +04)------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(1)] +physical_plan +01)ProjectionExec: expr=[id@1 as id, __datafusion_extracted_1@0 as simple_struct.s[value]] +02)--FilterExec: id@1 > 1 +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet, predicate=id@0 > 1, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 1, required_guarantees=[] + +# Verify correctness +query II +SELECT id, s['value'] FROM simple_struct WHERE id > 1 ORDER BY id LIMIT 2; +---- +2 200 +3 150 + +query TT +EXPLAIN SELECT s['value'] FROM simple_struct WHERE id > 1 AND (id < 4 OR id = 5); +---- +logical_plan +01)Projection: __datafusion_extracted_1 AS simple_struct.s[value] +02)--Filter: simple_struct.id > Int64(1) AND (simple_struct.id < Int64(4) OR simple_struct.id = Int64(5)) +03)----Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +04)------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(1), simple_struct.id < Int64(4) OR simple_struct.id = Int64(5)] +physical_plan +01)ProjectionExec: expr=[__datafusion_extracted_1@0 as simple_struct.s[value]] +02)--FilterExec: id@1 > 1 
AND (id@1 < 4 OR id@1 = 5), projection=[__datafusion_extracted_1@0] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet, predicate=id@0 > 1 AND (id@0 < 4 OR id@0 = 5), pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 1 AND (id_null_count@1 != row_count@2 AND id_min@3 < 4 OR id_null_count@1 != row_count@2 AND id_min@3 <= 5 AND 5 <= id_max@0), required_guarantees=[] + +# Verify correctness - should return rows where (id > 1) AND ((id < 4) OR (id = 5)) +# That's: id=2,3 (1 < id < 4) plus id=5 +query I +SELECT s['value'] FROM simple_struct WHERE id > 1 AND (id < 4 OR id = 5) ORDER BY s['value']; +---- +150 +200 +250 + +query TT +EXPLAIN SELECT s['value'] FROM simple_struct WHERE id > 1 AND id < 5; +---- +logical_plan +01)Projection: __datafusion_extracted_1 AS simple_struct.s[value] +02)--Filter: simple_struct.id > Int64(1) AND simple_struct.id < Int64(5) +03)----Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +04)------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(1), simple_struct.id < Int64(5)] +physical_plan +01)ProjectionExec: expr=[__datafusion_extracted_1@0 as simple_struct.s[value]] +02)--FilterExec: id@1 > 1 AND id@1 < 5, projection=[__datafusion_extracted_1@0] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet, predicate=id@0 > 1 AND id@0 < 5, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 1 AND id_null_count@1 != row_count@2 AND id_min@3 < 5, required_guarantees=[] + +# Verify correctness - should return rows where 1 < id < 5 (id=2,3,4) +query I +SELECT s['value'] FROM simple_struct WHERE id > 1 AND id < 5 ORDER BY s['value']; +---- +150 +200 +300 + +query TT
+EXPLAIN SELECT s['value'], s['label'], id FROM simple_struct WHERE id > 1; +---- +logical_plan +01)Projection: __datafusion_extracted_1 AS simple_struct.s[value], __datafusion_extracted_2 AS simple_struct.s[label], simple_struct.id +02)--Filter: simple_struct.id > Int64(1) +03)----Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, get_field(simple_struct.s, Utf8("label")) AS __datafusion_extracted_2, simple_struct.id +04)------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(1)] +physical_plan +01)ProjectionExec: expr=[__datafusion_extracted_1@0 as simple_struct.s[value], __datafusion_extracted_2@1 as simple_struct.s[label], id@2 as id] +02)--FilterExec: id@2 > 1 +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, get_field(s@1, label) as __datafusion_extracted_2, id], file_type=parquet, predicate=id@0 > 1, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 1, required_guarantees=[] + +# Verify correctness - note that id is now at index 2 in the augmented projection +query ITI +SELECT s['value'], s['label'], id FROM simple_struct WHERE id > 1 ORDER BY id LIMIT 3; +---- +200 beta 2 +150 gamma 3 +300 delta 4 + +query TT +EXPLAIN SELECT s['value'] FROM simple_struct WHERE length(s['label']) > 4; +---- +logical_plan +01)Projection: __datafusion_extracted_2 AS simple_struct.s[value] +02)--Filter: character_length(__datafusion_extracted_1) > Int32(4) +03)----Projection: get_field(simple_struct.s, Utf8("label")) AS __datafusion_extracted_1, get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_2 +04)------TableScan: simple_struct projection=[s], partial_filters=[character_length(get_field(simple_struct.s, Utf8("label"))) > Int32(4)] +physical_plan +01)ProjectionExec: expr=[__datafusion_extracted_2@0 as 
simple_struct.s[value]] +02)--FilterExec: character_length(__datafusion_extracted_1@0) > 4, projection=[__datafusion_extracted_2@1] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, label) as __datafusion_extracted_1, get_field(s@1, value) as __datafusion_extracted_2], file_type=parquet + +# Verify correctness - filter on rows where label length > 4 (all have length 5, except 'one' has 3) +# Wait, from the data: alpha(5), beta(4), gamma(5), delta(5), epsilon(7) +# So: alpha, gamma, delta, epsilon (not beta which has 4 characters) +query I +SELECT s['value'] FROM simple_struct WHERE length(s['label']) > 4 ORDER BY s['value']; +---- +100 +150 +250 +300 + +##################### +# Section 11a: ProjectionExec on top of a SortExec with missing Sort Columns +##################### + +### +# Test 11a.1: Sort by dropped column +# Selects only id, drops s entirely, but sorts by s['value'] +### + +query TT +EXPLAIN SELECT id FROM simple_struct ORDER BY s['value']; +---- +logical_plan +01)Projection: simple_struct.id +02)--Sort: __datafusion_extracted_1 ASC NULLS LAST +03)----Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1 +04)------TableScan: simple_struct projection=[id, s] +physical_plan +01)ProjectionExec: expr=[id@0 as id] +02)--SortExec: expr=[__datafusion_extracted_1@1 ASC NULLS LAST], preserve_partitioning=[false] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) as __datafusion_extracted_1], file_type=parquet + +# Verify correctness +query I +SELECT id FROM simple_struct ORDER BY s['value']; +---- +1 +3 +2 +5 +4 + +### +# Test 11a.2: Multiple sort columns with partial selection +# Selects only id and s['value'], but sorts by id and s['label'] +# One sort column 
(s['label']) is not selected but needed for ordering +### + +query TT +EXPLAIN SELECT id, s['value'] FROM simple_struct ORDER BY id, s['label']; +---- +logical_plan +01)Projection: simple_struct.id, simple_struct.s[value] +02)--Sort: simple_struct.id ASC NULLS LAST, __datafusion_extracted_1 ASC NULLS LAST +03)----Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")), get_field(simple_struct.s, Utf8("label")) AS __datafusion_extracted_1 +04)------TableScan: simple_struct projection=[id, s] +physical_plan +01)ProjectionExec: expr=[id@0 as id, simple_struct.s[value]@1 as simple_struct.s[value]] +02)--SortExec: expr=[id@0 ASC NULLS LAST, __datafusion_extracted_1@2 ASC NULLS LAST], preserve_partitioning=[false] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) as simple_struct.s[value], get_field(s@1, label) as __datafusion_extracted_1], file_type=parquet + +# Verify correctness +query II +SELECT id, s['value'] FROM simple_struct ORDER BY id, s['label']; +---- +1 100 +2 200 +3 150 +4 300 +5 250 + + +### +# Test 11a.3: TopK with dropped sort column +# Same as test 11a.1 but with LIMIT +### + +query TT +EXPLAIN SELECT id FROM simple_struct ORDER BY s['value'] LIMIT 2; +---- +logical_plan +01)Projection: simple_struct.id +02)--Sort: __datafusion_extracted_1 ASC NULLS LAST, fetch=2 +03)----Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1 +04)------TableScan: simple_struct projection=[id, s] +physical_plan +01)ProjectionExec: expr=[id@0 as id] +02)--SortExec: TopK(fetch=2), expr=[__datafusion_extracted_1@1 ASC NULLS LAST], preserve_partitioning=[false] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) as __datafusion_extracted_1], 
file_type=parquet + +# Verify correctness +query I +SELECT id FROM simple_struct ORDER BY s['value'] LIMIT 2; +---- +1 +3 + +### +# Test 11a.4: Sort by derived expression with dropped column +# Projects only id, sorts by s['value'] * 2 (derived expression) +# Sort column is computed but requires base columns not in projection +### + +query TT +EXPLAIN SELECT id FROM simple_struct ORDER BY s['value'] * 2; +---- +logical_plan +01)Projection: simple_struct.id +02)--Sort: __datafusion_extracted_1 * Int64(2) ASC NULLS LAST +03)----Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1 +04)------TableScan: simple_struct projection=[id, s] +physical_plan +01)ProjectionExec: expr=[id@0 as id] +02)--SortExec: expr=[__datafusion_extracted_1@1 * 2 ASC NULLS LAST], preserve_partitioning=[false] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) as __datafusion_extracted_1], file_type=parquet + +# Verify correctness +query I +SELECT id FROM simple_struct ORDER BY s['value'] * 2; +---- +1 +3 +2 +5 +4 + +### +# Test 11a.5: All sort columns selected +# All columns needed for sorting are included in projection +### + +query TT +EXPLAIN SELECT id, s['value'] FROM simple_struct ORDER BY id, s['value']; +---- +logical_plan +01)Sort: simple_struct.id ASC NULLS LAST, simple_struct.s[value] ASC NULLS LAST +02)--Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")) +03)----TableScan: simple_struct projection=[id, s] +physical_plan +01)SortExec: expr=[id@0 ASC NULLS LAST, simple_struct.s[value]@1 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) as simple_struct.s[value]], file_type=parquet + +# Verify correctness +query II 
+SELECT id, s['value'] FROM simple_struct ORDER BY id, s['value']; +---- +1 100 +2 200 +3 150 +4 300 +5 250 + +##################### +# Section 12: Join Tests - get_field Extraction from Join Nodes +##################### + +# Create a second table for join tests +statement ok +COPY ( + SELECT + column1 as id, + column2 as s + FROM VALUES + (1, {role: 'admin', level: 10}), + (2, {role: 'user', level: 5}), + (3, {role: 'guest', level: 1}), + (4, {role: 'admin', level: 8}), + (5, {role: 'user', level: 3}) +) TO 'test_files/scratch/projection_pushdown/join_right.parquet' +STORED AS PARQUET; + +statement ok +CREATE EXTERNAL TABLE join_right STORED AS PARQUET +LOCATION 'test_files/scratch/projection_pushdown/join_right.parquet'; + +### +# Test 12.1: Join with get_field in equijoin condition +# Tests extraction from join ON clause - get_field on each side routed appropriately +### + +query TT +EXPLAIN SELECT simple_struct.id, join_right.id +FROM simple_struct +INNER JOIN join_right ON simple_struct.s['value'] = join_right.s['level'] * 10; +---- +logical_plan +01)Projection: simple_struct.id, join_right.id +02)--Inner Join: __datafusion_extracted_1 = __datafusion_extracted_2 * Int64(10) +03)----Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +04)------TableScan: simple_struct projection=[id, s] +05)----Projection: get_field(join_right.s, Utf8("level")) AS __datafusion_extracted_2, join_right.id +06)------TableScan: join_right projection=[id, s] +physical_plan +01)HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(__datafusion_extracted_1@0, __datafusion_extracted_2 * Int64(10)@2)], projection=[id@1, id@3] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet +03)--DataSourceExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/join_right.parquet]]}, projection=[get_field(s@1, level) as __datafusion_extracted_2, id, get_field(s@1, level) * 10 as __datafusion_extracted_2 * Int64(10)], file_type=parquet + +# Verify correctness - value = level * 10 +# simple_struct: (1,100), (2,200), (3,150), (4,300), (5,250) +# join_right: (1,10), (2,5), (3,1), (4,8), (5,3) +# Matches: simple_struct.value=100 matches join_right.level*10=100 (level=10, id=1) +query II +SELECT simple_struct.id, join_right.id +FROM simple_struct +INNER JOIN join_right ON simple_struct.s['value'] = join_right.s['level'] * 10 +ORDER BY simple_struct.id; +---- +1 1 + +### +# Test 12.2: Join with get_field in non-equi filter +# Tests extraction from join filter expression - left side only +### + +query TT +EXPLAIN SELECT simple_struct.id, join_right.id +FROM simple_struct +INNER JOIN join_right ON simple_struct.id = join_right.id +WHERE simple_struct.s['value'] > 150; +---- +logical_plan +01)Inner Join: simple_struct.id = join_right.id +02)--Projection: simple_struct.id +03)----Filter: __datafusion_extracted_1 > Int64(150) +04)------Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +05)--------TableScan: simple_struct projection=[id, s], partial_filters=[get_field(simple_struct.s, Utf8("value")) > Int64(150)] +06)--TableScan: join_right projection=[id] +physical_plan +01)HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(id@0, id@0)] +02)--FilterExec: __datafusion_extracted_1@0 > 150, projection=[id@1] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet +04)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/join_right.parquet]]}, projection=[id], 
file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness - id matches and value > 150 +query II +SELECT simple_struct.id, join_right.id +FROM simple_struct +INNER JOIN join_right ON simple_struct.id = join_right.id +WHERE simple_struct.s['value'] > 150 +ORDER BY simple_struct.id; +---- +2 2 +4 4 +5 5 + +### +# Test 12.3: Join with get_field from both sides in filter +# Tests extraction routing to both left and right inputs +### + +query TT +EXPLAIN SELECT simple_struct.id, join_right.id +FROM simple_struct +INNER JOIN join_right ON simple_struct.id = join_right.id +WHERE simple_struct.s['value'] > 100 AND join_right.s['level'] > 3; +---- +logical_plan +01)Inner Join: simple_struct.id = join_right.id +02)--Projection: simple_struct.id +03)----Filter: __datafusion_extracted_1 > Int64(100) +04)------Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +05)--------TableScan: simple_struct projection=[id, s], partial_filters=[get_field(simple_struct.s, Utf8("value")) > Int64(100)] +06)--Projection: join_right.id +07)----Filter: __datafusion_extracted_2 > Int64(3) +08)------Projection: get_field(join_right.s, Utf8("level")) AS __datafusion_extracted_2, join_right.id +09)--------TableScan: join_right projection=[id, s], partial_filters=[get_field(join_right.s, Utf8("level")) > Int64(3)] +physical_plan +01)HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(id@0, id@0)] +02)--FilterExec: __datafusion_extracted_1@0 > 100, projection=[id@1] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet +04)--FilterExec: __datafusion_extracted_2@0 > 3, projection=[id@1] +05)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/join_right.parquet]]}, projection=[get_field(s@1, 
level) as __datafusion_extracted_2, id], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness - id matches, value > 100, and level > 3 +# Matching ids where value > 100: 2(200), 3(150), 4(300), 5(250) +# Of those, level > 3: 2(5), 4(8), 5(3) -> only 2 and 4 +query II +SELECT simple_struct.id, join_right.id +FROM simple_struct +INNER JOIN join_right ON simple_struct.id = join_right.id +WHERE simple_struct.s['value'] > 100 AND join_right.s['level'] > 3 +ORDER BY simple_struct.id; +---- +2 2 +4 4 + +### +# Test 12.4: Join with get_field in SELECT projection +# Tests that get_field in output columns pushes down through the join +### + +query TT +EXPLAIN SELECT simple_struct.id, simple_struct.s['label'], join_right.s['role'] +FROM simple_struct +INNER JOIN join_right ON simple_struct.id = join_right.id; +---- +logical_plan +01)Projection: simple_struct.id, __datafusion_extracted_1 AS simple_struct.s[label], __datafusion_extracted_2 AS join_right.s[role] +02)--Inner Join: simple_struct.id = join_right.id +03)----Projection: get_field(simple_struct.s, Utf8("label")) AS __datafusion_extracted_1, simple_struct.id +04)------TableScan: simple_struct projection=[id, s] +05)----Projection: get_field(join_right.s, Utf8("role")) AS __datafusion_extracted_2, join_right.id +06)------TableScan: join_right projection=[id, s] +physical_plan +01)ProjectionExec: expr=[id@1 as id, __datafusion_extracted_1@0 as simple_struct.s[label], __datafusion_extracted_2@2 as join_right.s[role]] +02)--HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(id@1, id@1)], projection=[__datafusion_extracted_1@0, id@1, __datafusion_extracted_2@2] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, label) as __datafusion_extracted_1, id], file_type=parquet +04)----DataSourceExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/join_right.parquet]]}, projection=[get_field(s@1, role) as __datafusion_extracted_2, id], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness +query ITT +SELECT simple_struct.id, simple_struct.s['label'], join_right.s['role'] +FROM simple_struct +INNER JOIN join_right ON simple_struct.id = join_right.id +ORDER BY simple_struct.id; +---- +1 alpha admin +2 beta user +3 gamma guest +4 delta admin +5 epsilon user + +### +# Test 12.5: Join without get_field (baseline - no extraction needed) +# Verifies no unnecessary projections are added when there's nothing to extract +### + +query TT +EXPLAIN SELECT simple_struct.id, join_right.id +FROM simple_struct +INNER JOIN join_right ON simple_struct.id = join_right.id; +---- +logical_plan +01)Inner Join: simple_struct.id = join_right.id +02)--TableScan: simple_struct projection=[id] +03)--TableScan: join_right projection=[id] +physical_plan +01)HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(id@0, id@0)] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id], file_type=parquet +03)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/join_right.parquet]]}, projection=[id], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness +query II +SELECT simple_struct.id, join_right.id +FROM simple_struct +INNER JOIN join_right ON simple_struct.id = join_right.id +ORDER BY simple_struct.id; +---- +1 1 +2 2 +3 3 +4 4 +5 5 + +### +# Test 12.6: Left Join with get_field extraction +# Tests extraction works correctly with outer joins +### + +query TT +EXPLAIN SELECT simple_struct.id, simple_struct.s['value'], join_right.s['level'] +FROM simple_struct +LEFT JOIN join_right ON simple_struct.id = join_right.id AND join_right.s['level'] > 5; +---- 
+logical_plan +01)Projection: simple_struct.id, __datafusion_extracted_2 AS simple_struct.s[value], __datafusion_extracted_3 AS join_right.s[level] +02)--Left Join: simple_struct.id = join_right.id +03)----Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_2, simple_struct.id +04)------TableScan: simple_struct projection=[id, s] +05)----Projection: join_right.id, __datafusion_extracted_3 +06)------Filter: __datafusion_extracted_1 > Int64(5) +07)--------Projection: get_field(join_right.s, Utf8("level")) AS __datafusion_extracted_1, join_right.id, get_field(join_right.s, Utf8("level")) AS __datafusion_extracted_3 +08)----------TableScan: join_right projection=[id, s], partial_filters=[get_field(join_right.s, Utf8("level")) > Int64(5)] +physical_plan +01)ProjectionExec: expr=[id@1 as id, __datafusion_extracted_2@0 as simple_struct.s[value], __datafusion_extracted_3@2 as join_right.s[level]] +02)--HashJoinExec: mode=CollectLeft, join_type=Left, on=[(id@1, id@0)], projection=[__datafusion_extracted_2@0, id@1, __datafusion_extracted_3@3] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_2, id], file_type=parquet +04)----FilterExec: __datafusion_extracted_1@0 > 5, projection=[id@1, __datafusion_extracted_3@2] +05)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/join_right.parquet]]}, projection=[get_field(s@1, level) as __datafusion_extracted_1, id, get_field(s@1, level) as __datafusion_extracted_3], file_type=parquet + +# Verify correctness - left join with level > 5 condition +# Only join_right rows with level > 5 are matched: id=1 (level=10), id=4 (level=8) +query III +SELECT simple_struct.id, simple_struct.s['value'], join_right.s['level'] +FROM simple_struct +LEFT JOIN join_right ON simple_struct.id = 
join_right.id AND join_right.s['level'] > 5 +ORDER BY simple_struct.id; +---- +1 100 10 +2 200 NULL +3 150 NULL +4 300 8 +5 250 NULL + +##################### +# Section 13: RepartitionExec tests +##################### + +# Set target partitions to 32 -> this forces a RepartitionExec +statement ok +SET datafusion.execution.target_partitions = 32; + +query TT +EXPLAIN SELECT s['value'] FROM simple_struct WHERE id > 2; +---- +logical_plan +01)Projection: __datafusion_extracted_1 AS simple_struct.s[value] +02)--Filter: simple_struct.id > Int64(2) +03)----Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +04)------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(2)] +physical_plan +01)ProjectionExec: expr=[__datafusion_extracted_1@0 as simple_struct.s[value]] +02)--FilterExec: id@1 > 2, projection=[__datafusion_extracted_1@0] +03)----RepartitionExec: partitioning=RoundRobinBatch(32), input_partitions=1 +04)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet, predicate=id@0 > 2, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 2, required_guarantees=[] + +##################### +# Section 14: SubqueryAlias tests +##################### + +# Reset target partitions +statement ok +SET datafusion.execution.target_partitions = 1; + +# get_field pushdown through subquery alias with filter +query TT +EXPLAIN SELECT t.s['value'] FROM (SELECT * FROM simple_struct) t WHERE t.id > 2; +---- +logical_plan +01)Projection: __datafusion_extracted_1 AS t.s[value] +02)--SubqueryAlias: t +03)----Projection: __datafusion_extracted_1 +04)------Filter: simple_struct.id > Int64(2) +05)--------Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +06)----------TableScan: 
simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(2)] +physical_plan +01)ProjectionExec: expr=[__datafusion_extracted_1@0 as t.s[value]] +02)--FilterExec: id@1 > 2, projection=[__datafusion_extracted_1@0] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet, predicate=id@0 > 2, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 2, required_guarantees=[] + +# Verify correctness +query I +SELECT t.s['value'] FROM (SELECT * FROM simple_struct) t WHERE t.id > 2 ORDER BY t.id; +---- +150 +300 +250 + +# Multiple get_field through subquery alias with sort +query TT +EXPLAIN SELECT t.s['value'], t.s['label'] FROM (SELECT * FROM simple_struct) t ORDER BY t.s['value']; +---- +logical_plan +01)Sort: t.s[value] ASC NULLS LAST +02)--Projection: __datafusion_extracted_1 AS t.s[value], __datafusion_extracted_2 AS t.s[label] +03)----SubqueryAlias: t +04)------Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, get_field(simple_struct.s, Utf8("label")) AS __datafusion_extracted_2 +05)--------TableScan: simple_struct projection=[s] +physical_plan +01)SortExec: expr=[t.s[value]@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as t.s[value], get_field(s@1, label) as t.s[label]], file_type=parquet + +# Verify correctness +query IT +SELECT t.s['value'], t.s['label'] FROM (SELECT * FROM simple_struct) t ORDER BY t.s['value']; +---- +100 alpha +150 gamma +200 beta +250 epsilon +300 delta + +# Nested subquery aliases +query TT +EXPLAIN SELECT u.s['value'] FROM (SELECT * FROM (SELECT * FROM simple_struct) t) u WHERE u.id > 2; +---- +logical_plan +01)Projection: 
__datafusion_extracted_1 AS u.s[value] +02)--SubqueryAlias: u +03)----SubqueryAlias: t +04)------Projection: __datafusion_extracted_1 +05)--------Filter: simple_struct.id > Int64(2) +06)----------Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +07)------------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(2)] +physical_plan +01)ProjectionExec: expr=[__datafusion_extracted_1@0 as u.s[value]] +02)--FilterExec: id@1 > 2, projection=[__datafusion_extracted_1@0] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet, predicate=id@0 > 2, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 2, required_guarantees=[] + +# Verify correctness +query I +SELECT u.s['value'] FROM (SELECT * FROM (SELECT * FROM simple_struct) t) u WHERE u.id > 2 ORDER BY u.id; +---- +150 +300 +250 + +# get_field in filter through subquery alias +query TT +EXPLAIN SELECT t.id FROM (SELECT * FROM simple_struct) t WHERE t.s['value'] > 200; +---- +logical_plan +01)SubqueryAlias: t +02)--Projection: simple_struct.id +03)----Filter: __datafusion_extracted_1 > Int64(200) +04)------Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +05)--------TableScan: simple_struct projection=[id, s], partial_filters=[get_field(simple_struct.s, Utf8("value")) > Int64(200)] +physical_plan +01)FilterExec: __datafusion_extracted_1@0 > 200, projection=[id@1] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet + +# Verify correctness +query I +SELECT t.id FROM (SELECT * FROM simple_struct) t WHERE t.s['value'] > 200 ORDER BY 
t.id; +---- +4 +5 + +##################### +# Section 15: UNION ALL tests +##################### + +# get_field on UNION ALL result +query TT +EXPLAIN SELECT s['value'] FROM ( + SELECT s FROM simple_struct WHERE id <= 3 + UNION ALL + SELECT s FROM simple_struct WHERE id > 3 +) t; +---- +logical_plan +01)Projection: __datafusion_extracted_1 AS t.s[value] +02)--SubqueryAlias: t +03)----Union +04)------Projection: __datafusion_extracted_1 +05)--------Filter: simple_struct.id <= Int64(3) +06)----------Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +07)------------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id <= Int64(3)] +08)------Projection: __datafusion_extracted_1 +09)--------Filter: simple_struct.id > Int64(3) +10)----------Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +11)------------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(3)] +physical_plan +01)ProjectionExec: expr=[__datafusion_extracted_1@0 as t.s[value]] +02)--UnionExec +03)----FilterExec: id@1 <= 3, projection=[__datafusion_extracted_1@0] +04)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet, predicate=id@0 <= 3, pruning_predicate=id_null_count@1 != row_count@2 AND id_min@0 <= 3, required_guarantees=[] +05)----FilterExec: id@1 > 3, projection=[__datafusion_extracted_1@0] +06)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet, predicate=id@0 > 3, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 3, required_guarantees=[] + +# Verify correctness +query I +SELECT 
s['value'] FROM ( + SELECT s FROM simple_struct WHERE id <= 3 + UNION ALL + SELECT s FROM simple_struct WHERE id > 3 +) t ORDER BY s['value']; +---- +100 +150 +200 +250 +300 + +# Multiple get_field on UNION ALL with ORDER BY +query TT +EXPLAIN SELECT s['value'], s['label'] FROM ( + SELECT s FROM simple_struct WHERE id <= 3 + UNION ALL + SELECT s FROM simple_struct WHERE id > 3 +) t ORDER BY s['value']; +---- +logical_plan +01)Sort: t.s[value] ASC NULLS LAST +02)--Projection: __datafusion_extracted_1 AS t.s[value], __datafusion_extracted_2 AS t.s[label] +03)----SubqueryAlias: t +04)------Union +05)--------Projection: __datafusion_extracted_1, __datafusion_extracted_2 +06)----------Filter: simple_struct.id <= Int64(3) +07)------------Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, get_field(simple_struct.s, Utf8("label")) AS __datafusion_extracted_2, simple_struct.id +08)--------------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id <= Int64(3)] +09)--------Projection: __datafusion_extracted_1, __datafusion_extracted_2 +10)----------Filter: simple_struct.id > Int64(3) +11)------------Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, get_field(simple_struct.s, Utf8("label")) AS __datafusion_extracted_2, simple_struct.id +12)--------------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(3)] +physical_plan +01)SortPreservingMergeExec: [t.s[value]@0 ASC NULLS LAST] +02)--SortExec: expr=[t.s[value]@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----ProjectionExec: expr=[__datafusion_extracted_1@0 as t.s[value], __datafusion_extracted_2@1 as t.s[label]] +04)------UnionExec +05)--------FilterExec: id@2 <= 3, projection=[__datafusion_extracted_1@0, __datafusion_extracted_2@1] +06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, 
projection=[get_field(s@1, value) as __datafusion_extracted_1, get_field(s@1, label) as __datafusion_extracted_2, id], file_type=parquet, predicate=id@0 <= 3, pruning_predicate=id_null_count@1 != row_count@2 AND id_min@0 <= 3, required_guarantees=[] +07)--------FilterExec: id@2 > 3, projection=[__datafusion_extracted_1@0, __datafusion_extracted_2@1] +08)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, get_field(s@1, label) as __datafusion_extracted_2, id], file_type=parquet, predicate=id@0 > 3, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 3, required_guarantees=[] + +# Verify correctness +query IT +SELECT s['value'], s['label'] FROM ( + SELECT s FROM simple_struct WHERE id <= 3 + UNION ALL + SELECT s FROM simple_struct WHERE id > 3 +) t ORDER BY s['value']; +---- +100 alpha +150 gamma +200 beta +250 epsilon +300 delta + +##################### +# Section 16: Aggregate / Join edge-case tests +# Translated from unit tests in extract_leaf_expressions.rs +##################### + +### +# Test 16.1: Projection with get_field above Aggregate +# Aggregate blocks pushdown, so the get_field stays in the top projection. 
+# (mirrors test_projection_with_leaf_expr_above_aggregate) +### + +query TT +EXPLAIN SELECT s['label'] IS NOT NULL AS has_label, COUNT(1) +FROM simple_struct GROUP BY s; +---- +logical_plan +01)Projection: get_field(simple_struct.s, Utf8("label")) IS NOT NULL AS has_label, count(Int64(1)) +02)--Aggregate: groupBy=[[simple_struct.s]], aggr=[[count(Int64(1))]] +03)----TableScan: simple_struct projection=[s] +physical_plan +01)ProjectionExec: expr=[get_field(s@0, label) IS NOT NULL as has_label, count(Int64(1))@1 as count(Int64(1))] +02)--AggregateExec: mode=Single, gby=[s@0 as s], aggr=[count(Int64(1))] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[s], file_type=parquet + +# Verify correctness - all labels are non-null +query BI +SELECT s['label'] IS NOT NULL AS has_label, COUNT(1) +FROM simple_struct GROUP BY s ORDER BY COUNT(1); +---- +true 1 +true 1 +true 1 +true 1 +true 1 + +### +# Test 16.2: Join with get_field filter on qualified right side +# The get_field on join_right.s['role'] must be routed to the right input only. 
+# (mirrors test_extract_from_join_qualified_right_side) +### + +query TT +EXPLAIN +SELECT s.s['value'], j.s['role'] +FROM join_right j +INNER JOIN simple_struct s ON s.id = j.id +WHERE s.s['value'] > j.s['level']; +---- +logical_plan +01)Projection: __datafusion_extracted_3 AS s.s[value], __datafusion_extracted_4 AS j.s[role] +02)--Inner Join: j.id = s.id Filter: __datafusion_extracted_1 > __datafusion_extracted_2 +03)----SubqueryAlias: j +04)------Projection: get_field(join_right.s, Utf8("level")) AS __datafusion_extracted_2, get_field(join_right.s, Utf8("role")) AS __datafusion_extracted_4, join_right.id +05)--------TableScan: join_right projection=[id, s] +06)----SubqueryAlias: s +07)------Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_3, simple_struct.id +08)--------TableScan: simple_struct projection=[id, s] +physical_plan +01)ProjectionExec: expr=[__datafusion_extracted_3@1 as s.s[value], __datafusion_extracted_4@0 as j.s[role]] +02)--HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(id@2, id@2)], filter=__datafusion_extracted_1@1 > __datafusion_extracted_2@0, projection=[__datafusion_extracted_4@1, __datafusion_extracted_3@4] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/join_right.parquet]]}, projection=[get_field(s@1, level) as __datafusion_extracted_2, get_field(s@1, role) as __datafusion_extracted_4, id], file_type=parquet +04)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, get_field(s@1, value) as __datafusion_extracted_3, id], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness - only admin roles match (ids 1 and 4) +query II +SELECT simple_struct.id, join_right.id +FROM simple_struct 
+INNER JOIN join_right + ON simple_struct.id = join_right.id + AND join_right.s['role'] = 'admin' +ORDER BY simple_struct.id; +---- +1 1 +4 4 + +### +# Test 16.3: Join with cross-input get_field comparison in WHERE +# get_field from each side is extracted and routed to its respective input independently. +# (mirrors test_extract_from_join_cross_input_expression) +### + +query TT +EXPLAIN SELECT simple_struct.id, join_right.id +FROM simple_struct +INNER JOIN join_right ON simple_struct.id = join_right.id +WHERE simple_struct.s['value'] > join_right.s['level']; +---- +logical_plan +01)Projection: simple_struct.id, join_right.id +02)--Inner Join: simple_struct.id = join_right.id Filter: __datafusion_extracted_1 > __datafusion_extracted_2 +03)----Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +04)------TableScan: simple_struct projection=[id, s] +05)----Projection: get_field(join_right.s, Utf8("level")) AS __datafusion_extracted_2, join_right.id +06)------TableScan: join_right projection=[id, s] +physical_plan +01)HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(id@1, id@1)], filter=__datafusion_extracted_1@0 > __datafusion_extracted_2@1, projection=[id@1, id@3] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet +03)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/join_right.parquet]]}, projection=[get_field(s@1, level) as __datafusion_extracted_2, id], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness - all rows match since value >> level for all ids +# simple_struct: (1,100), (2,200), (3,150), (4,300), (5,250) +# join_right: (1,10), (2,5), (3,1), (4,8), (5,3) +query II +SELECT simple_struct.id, join_right.id +FROM simple_struct +INNER JOIN 
join_right ON simple_struct.id = join_right.id +WHERE simple_struct.s['value'] > join_right.s['level'] +ORDER BY simple_struct.id; +---- +1 1 +2 2 +3 3 +4 4 +5 5 + +# ========================================================================= +# Regression: user-provided __datafusion_extracted aliases must not +# collide with optimizer-generated ones +# (https://github.com/apache/datafusion/issues/20430) +# ========================================================================= + +statement ok +COPY ( select {f1: 1, f2: 2} as s +) TO 'test_files/scratch/projection_pushdown/test.parquet' +STORED AS PARQUET; + +statement ok +CREATE EXTERNAL TABLE t +STORED AS PARQUET +LOCATION 'test_files/scratch/projection_pushdown/test.parquet'; + +# Verify that the user-provided __datafusion_extracted_2 alias is preserved +# and the optimizer skips to _3 and _4 for its generated aliases. +query TT +EXPLAIN SELECT + get_field(s, 'f1') AS __datafusion_extracted_2 +FROM t +WHERE COALESCE(get_field(s, 'f1'), get_field(s, 'f2')) = 1; +---- +logical_plan +01)Projection: __datafusion_extracted_2 +02)--Filter: CASE WHEN __datafusion_extracted_3 IS NOT NULL THEN __datafusion_extracted_3 ELSE __datafusion_extracted_4 END = Int64(1) +03)----Projection: get_field(t.s, Utf8("f1")) AS __datafusion_extracted_3, get_field(t.s, Utf8("f2")) AS __datafusion_extracted_4, get_field(t.s, Utf8("f1")) AS __datafusion_extracted_2 +04)------TableScan: t projection=[s], partial_filters=[CASE WHEN get_field(t.s, Utf8("f1")) IS NOT NULL THEN get_field(t.s, Utf8("f1")) ELSE get_field(t.s, Utf8("f2")) END = Int64(1)] +physical_plan +01)FilterExec: CASE WHEN __datafusion_extracted_3@0 IS NOT NULL THEN __datafusion_extracted_3@0 ELSE __datafusion_extracted_4@1 END = 1, projection=[__datafusion_extracted_2@2] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/test.parquet]]}, projection=[get_field(s@0, f1) as __datafusion_extracted_3, 
get_field(s@0, f2) as __datafusion_extracted_4, get_field(s@0, f1) as __datafusion_extracted_2], file_type=parquet + +query I +SELECT + get_field(s, 'f1') AS __datafusion_extracted_2 +FROM t +WHERE COALESCE(get_field(s, 'f1'), get_field(s, 'f2')) = 1; +---- +1 diff --git a/datafusion/sqllogictest/test_files/push_down_filter.slt b/datafusion/sqllogictest/test_files/push_down_filter.slt deleted file mode 100644 index 4353f805c848..000000000000 --- a/datafusion/sqllogictest/test_files/push_down_filter.slt +++ /dev/null @@ -1,490 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -# Test push down filter - -statement ok -set datafusion.explain.physical_plan_only = true; - -statement ok -CREATE TABLE IF NOT EXISTS v AS VALUES(1,[1,2,3]),(2,[3,4,5]); - -query I -select uc2 from (select unnest(column2) as uc2, column1 from v) where column1 = 2; ----- -3 -4 -5 - -# test push down filter for unnest with filter on non-unnest column -# filter plan is pushed down into projection plan -query TT -explain select uc2 from (select unnest(column2) as uc2, column1 from v) where column1 = 2; ----- -physical_plan -01)ProjectionExec: expr=[__unnest_placeholder(v.column2,depth=1)@0 as uc2] -02)--UnnestExec -03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -04)------ProjectionExec: expr=[column2@0 as __unnest_placeholder(v.column2)] -05)--------FilterExec: column1@0 = 2, projection=[column2@1] -06)----------DataSourceExec: partitions=1, partition_sizes=[1] - -query I -select uc2 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3; ----- -4 -5 - -# test push down filter for unnest with filter on unnest column -query TT -explain select uc2 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3; ----- -physical_plan -01)ProjectionExec: expr=[__unnest_placeholder(v.column2,depth=1)@0 as uc2] -02)--FilterExec: __unnest_placeholder(v.column2,depth=1)@0 > 3 -03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -04)------UnnestExec -05)--------ProjectionExec: expr=[column2@0 as __unnest_placeholder(v.column2)] -06)----------DataSourceExec: partitions=1, partition_sizes=[1] - -query II -select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3 AND column1 = 2; ----- -4 2 -5 2 - -# Could push the filter (column1 = 2) down below unnest -query TT -explain select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3 AND column1 = 2; ----- -physical_plan -01)ProjectionExec: expr=[__unnest_placeholder(v.column2,depth=1)@0 as uc2, column1@1 as 
column1] -02)--FilterExec: __unnest_placeholder(v.column2,depth=1)@0 > 3 -03)----UnnestExec -04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -05)--------ProjectionExec: expr=[column2@1 as __unnest_placeholder(v.column2), column1@0 as column1] -06)----------FilterExec: column1@0 = 2 -07)------------DataSourceExec: partitions=1, partition_sizes=[1] - -query II -select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3 OR column1 = 2; ----- -3 2 -4 2 -5 2 - -# only non-unnest filter in AND clause could be pushed down -query TT -explain select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3 OR column1 = 2; ----- -physical_plan -01)ProjectionExec: expr=[__unnest_placeholder(v.column2,depth=1)@0 as uc2, column1@1 as column1] -02)--FilterExec: __unnest_placeholder(v.column2,depth=1)@0 > 3 OR column1@1 = 2 -03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -04)------UnnestExec -05)--------ProjectionExec: expr=[column2@1 as __unnest_placeholder(v.column2), column1@0 as column1] -06)----------DataSourceExec: partitions=1, partition_sizes=[1] - -statement ok -drop table v; - -# test with unnest struct, should not push down filter -statement ok -CREATE TABLE d AS VALUES(1,[named_struct('a', 1, 'b', 2)]),(2,[named_struct('a', 3, 'b', 4), named_struct('a', 5, 'b', 6)]); - -query I? 
-select * from (select column1, unnest(column2) as o from d) where o['a'] = 1; ----- -1 {a: 1, b: 2} - -query TT -explain select * from (select column1, unnest(column2) as o from d) where o['a'] = 1; ----- -physical_plan -01)ProjectionExec: expr=[column1@0 as column1, __unnest_placeholder(d.column2,depth=1)@1 as o] -02)--FilterExec: get_field(__unnest_placeholder(d.column2,depth=1)@1, a) = 1 -03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -04)------UnnestExec -05)--------ProjectionExec: expr=[column1@0 as column1, column2@1 as __unnest_placeholder(d.column2)] -06)----------DataSourceExec: partitions=1, partition_sizes=[1] - -statement ok -drop table d; - -statement ok -CREATE TABLE d AS VALUES (named_struct('a', 1, 'b', 2)), (named_struct('a', 3, 'b', 4)), (named_struct('a', 5, 'b', 6)); - -query II -select * from (select unnest(column1) from d) where "__unnest_placeholder(d.column1).b" > 5; ----- -5 6 - -query TT -explain select * from (select unnest(column1) from d) where "__unnest_placeholder(d.column1).b" > 5; ----- -physical_plan -01)FilterExec: __unnest_placeholder(d.column1).b@1 > 5 -02)--RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -03)----UnnestExec -04)------ProjectionExec: expr=[column1@0 as __unnest_placeholder(d.column1)] -05)--------DataSourceExec: partitions=1, partition_sizes=[1] - -statement ok -drop table d; - -# Test push down filter with limit for parquet -statement ok -set datafusion.execution.parquet.pushdown_filters = true; - -# this one is also required to make DF skip second file due to "sufficient" amount of rows -statement ok -set datafusion.execution.collect_statistics = true; - -# Create a table as a data source -statement ok -CREATE TABLE src_table ( - part_key INT, - value INT -) AS VALUES(1, 0), (1, 1), (1, 100), (2, 0), (2, 2), (2, 2), (2, 100), (3, 4), (3, 5), (3, 6); - - -# There will be more than 2 records filtered from the table to check that `limit 1` actually applied. 
-# Setup 3 files, i.e., as many as there are partitions: - -# File 1: -query I -COPY (SELECT * FROM src_table where part_key = 1) -TO 'test_files/scratch/push_down_filter/test_filter_with_limit/part-0.parquet' -STORED AS PARQUET; ----- -3 - -# File 2: -query I -COPY (SELECT * FROM src_table where part_key = 2) -TO 'test_files/scratch/push_down_filter/test_filter_with_limit/part-1.parquet' -STORED AS PARQUET; ----- -4 - -# File 3: -query I -COPY (SELECT * FROM src_table where part_key = 3) -TO 'test_files/scratch/push_down_filter/test_filter_with_limit/part-2.parquet' -STORED AS PARQUET; ----- -3 - -statement ok -CREATE EXTERNAL TABLE test_filter_with_limit -( - part_key INT, - value INT -) -STORED AS PARQUET -LOCATION 'test_files/scratch/push_down_filter/test_filter_with_limit/'; - -query TT -explain select * from test_filter_with_limit where value = 2 limit 1; ----- -physical_plan -01)CoalescePartitionsExec: fetch=1 -02)--DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/test_filter_with_limit/part-0.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/test_filter_with_limit/part-1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/test_filter_with_limit/part-2.parquet]]}, projection=[part_key, value], limit=1, file_type=parquet, predicate=value@1 = 2, pruning_predicate=value_null_count@2 != row_count@3 AND value_min@0 <= 2 AND 2 <= value_max@1, required_guarantees=[value in (2)] - -query II -select * from test_filter_with_limit where value = 2 limit 1; ----- -2 2 - - -# Tear down test_filter_with_limit table: -statement ok -DROP TABLE test_filter_with_limit; - -# Tear down src_table table: -statement ok -DROP TABLE src_table; - - -query I -COPY (VALUES (1), (2), (3), (4), (5), (6), (7), (8), (9), (10)) -TO 'test_files/scratch/push_down_filter/t.parquet' -STORED AS PARQUET; ----- -10 - -statement ok -CREATE EXTERNAL TABLE t -( - a 
INT -) -STORED AS PARQUET -LOCATION 'test_files/scratch/push_down_filter/t.parquet'; - - -# The predicate should not have a column cast when the value is a valid i32 -query TT -explain select a from t where a = '100'; ----- -physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/t.parquet]]}, projection=[a], file_type=parquet, predicate=a@0 = 100, pruning_predicate=a_null_count@2 != row_count@3 AND a_min@0 <= 100 AND 100 <= a_max@1, required_guarantees=[a in (100)] - -# The predicate should not have a column cast when the value is a valid i32 -query TT -explain select a from t where a != '100'; ----- -physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/t.parquet]]}, projection=[a], file_type=parquet, predicate=a@0 != 100, pruning_predicate=a_null_count@2 != row_count@3 AND (a_min@0 != 100 OR 100 != a_max@1), required_guarantees=[a not in (100)] - -# The predicate should still have the column cast when the value is a NOT valid i32 -query TT -explain select a from t where a = '99999999999'; ----- -physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/t.parquet]]}, projection=[a], file_type=parquet, predicate=CAST(a@0 AS Utf8) = 99999999999 - -# The predicate should still have the column cast when the value is a NOT valid i32 -query TT -explain select a from t where a = '99.99'; ----- -physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/t.parquet]]}, projection=[a], file_type=parquet, predicate=CAST(a@0 AS Utf8) = 99.99 - -# The predicate should still have the column cast when the value is a NOT valid i32 -query TT -explain select a from t where a = ''; ----- -physical_plan DataSourceExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/t.parquet]]}, projection=[a], file_type=parquet, predicate=CAST(a@0 AS Utf8) = - -# The predicate should not have a column cast when the operator is = or != and the literal can be round-trip casted without losing information. -query TT -explain select a from t where cast(a as string) = '100'; ----- -physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/t.parquet]]}, projection=[a], file_type=parquet, predicate=a@0 = 100, pruning_predicate=a_null_count@2 != row_count@3 AND a_min@0 <= 100 AND 100 <= a_max@1, required_guarantees=[a in (100)] - -# The predicate should still have the column cast when the literal alters its string representation after round-trip casting (leading zero lost). -query TT -explain select a from t where CAST(a AS string) = '0123'; ----- -physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/t.parquet]]}, projection=[a], file_type=parquet, predicate=CAST(a@0 AS Utf8View) = 0123 - - -# Test dynamic filter pushdown with swapped join inputs (issue #17196) -# Create tables with different sizes to force join input swapping -statement ok -copy (select i as k from generate_series(1, 100) t(i)) to 'test_files/scratch/push_down_filter/small_table.parquet'; - -statement ok -copy (select i as k, i as v from generate_series(1, 1000) t(i)) to 'test_files/scratch/push_down_filter/large_table.parquet'; - -statement ok -create external table small_table stored as parquet location 'test_files/scratch/push_down_filter/small_table.parquet'; - -statement ok -create external table large_table stored as parquet location 'test_files/scratch/push_down_filter/large_table.parquet'; - -# Test that dynamic filter is applied to the correct table after join input swapping -# The small_table should be the build side, large_table should be the probe 
side with dynamic filter -query TT -explain select * from small_table join large_table on small_table.k = large_table.k where large_table.v >= 50; ----- -physical_plan -01)HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(k@0, k@0)] -02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/small_table.parquet]]}, projection=[k], file_type=parquet -03)--RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -04)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/large_table.parquet]]}, projection=[k, v], file_type=parquet, predicate=v@1 >= 50 AND DynamicFilter [ empty ], pruning_predicate=v_null_count@1 != row_count@2 AND v_max@0 >= 50, required_guarantees=[] - -statement ok -drop table small_table; - -statement ok -drop table large_table; - -statement ok -drop table t; - -# Regression test for https://github.com/apache/datafusion/issues/17188 -query I -COPY (select i as k from generate_series(1, 10000000) as t(i)) -TO 'test_files/scratch/push_down_filter/t1.parquet' -STORED AS PARQUET; ----- -10000000 - -query I -COPY (select i as k, i as v from generate_series(1, 10000000) as t(i)) -TO 'test_files/scratch/push_down_filter/t2.parquet' -STORED AS PARQUET; ----- -10000000 - -statement ok -create external table t1 stored as parquet location 'test_files/scratch/push_down_filter/t1.parquet'; - -statement ok -create external table t2 stored as parquet location 'test_files/scratch/push_down_filter/t2.parquet'; - -# The failure before https://github.com/apache/datafusion/pull/17197 was non-deterministic and random -# So we'll run the same query a couple of times just to have more certainty it's fixed -# Sorry about the spam in this slt test... 
- -query III rowsort -select * -from t1 -join t2 on t1.k = t2.k -where v = 1 or v = 10000000 -order by t1.k, t2.v; ----- -1 1 1 -10000000 10000000 10000000 - -query III rowsort -select * -from t1 -join t2 on t1.k = t2.k -where v = 1 or v = 10000000 -order by t1.k, t2.v; ----- -1 1 1 -10000000 10000000 10000000 - -query III rowsort -select * -from t1 -join t2 on t1.k = t2.k -where v = 1 or v = 10000000 -order by t1.k, t2.v; ----- -1 1 1 -10000000 10000000 10000000 - -query III rowsort -select * -from t1 -join t2 on t1.k = t2.k -where v = 1 or v = 10000000 -order by t1.k, t2.v; ----- -1 1 1 -10000000 10000000 10000000 - -query III rowsort -select * -from t1 -join t2 on t1.k = t2.k -where v = 1 or v = 10000000 -order by t1.k, t2.v; ----- -1 1 1 -10000000 10000000 10000000 - -# Regression test for https://github.com/apache/datafusion/issues/17512 - -query I -COPY ( - SELECT arrow_cast('2025-01-01T00:00:00Z'::timestamptz, 'Timestamp(Microsecond, Some("UTC"))') AS start_timestamp -) -TO 'test_files/scratch/push_down_filter/17512.parquet' -STORED AS PARQUET; ----- -1 - -statement ok -CREATE EXTERNAL TABLE records STORED AS PARQUET LOCATION 'test_files/scratch/push_down_filter/17512.parquet'; - -query I -SELECT 1 -FROM ( - SELECT start_timestamp - FROM records - WHERE start_timestamp <= '2025-01-01T00:00:00Z'::timestamptz -) AS t -WHERE t.start_timestamp::time < '00:00:01'::time; ----- -1 - -# Test aggregate dynamic filter pushdown -# Note: most of the test coverage lives in `datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs` -# , to compare dynamic filter content easier. Here the tests are simple end-to-end -# exercises. 
- -statement ok -set datafusion.explain.format = 'indent'; - -statement ok -set datafusion.explain.physical_plan_only = true; - -statement ok -set datafusion.execution.target_partitions = 2; - -statement ok -set datafusion.execution.parquet.pushdown_filters = true; - -statement ok -set datafusion.optimizer.enable_dynamic_filter_pushdown = true; - -statement ok -set datafusion.optimizer.enable_aggregate_dynamic_filter_pushdown = true; - -statement ok -create external table agg_dyn_test stored as parquet location '../core/tests/data/test_statistics_per_partition'; - -# Expect dynamic filter available inside data source -query TT -explain select max(id) from agg_dyn_test where id > 1; ----- -physical_plan -01)AggregateExec: mode=Final, gby=[], aggr=[max(agg_dyn_test.id)] -02)--CoalescePartitionsExec -03)----AggregateExec: mode=Partial, gby=[], aggr=[max(agg_dyn_test.id)] -04)------DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-01/j5fUeSDQo22oPyPU.parquet, WORKSPACE_ROOT/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-02/j5fUeSDQo22oPyPU.parquet], [WORKSPACE_ROOT/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-03/j5fUeSDQo22oPyPU.parquet, WORKSPACE_ROOT/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-04/j5fUeSDQo22oPyPU.parquet]]}, projection=[id], file_type=parquet, predicate=id@0 > 1 AND DynamicFilter [ empty ], pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 1, required_guarantees=[] - -query I -select max(id) from agg_dyn_test where id > 1; ----- -4 - -# Expect dynamic filter available inside data source -query TT -explain select max(id) from agg_dyn_test where (id+1) > 1; ----- -physical_plan -01)AggregateExec: mode=Final, gby=[], aggr=[max(agg_dyn_test.id)] -02)--CoalescePartitionsExec -03)----AggregateExec: mode=Partial, gby=[], aggr=[max(agg_dyn_test.id)] -04)------DataSourceExec: file_groups={2 groups: 
[[WORKSPACE_ROOT/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-01/j5fUeSDQo22oPyPU.parquet, WORKSPACE_ROOT/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-02/j5fUeSDQo22oPyPU.parquet], [WORKSPACE_ROOT/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-03/j5fUeSDQo22oPyPU.parquet, WORKSPACE_ROOT/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-04/j5fUeSDQo22oPyPU.parquet]]}, projection=[id], file_type=parquet, predicate=CAST(id@0 AS Int64) + 1 > 1 AND DynamicFilter [ empty ] - -# Expect dynamic filter available inside data source -query TT -explain select max(id), min(id) from agg_dyn_test where id < 10; ----- -physical_plan -01)AggregateExec: mode=Final, gby=[], aggr=[max(agg_dyn_test.id), min(agg_dyn_test.id)] -02)--CoalescePartitionsExec -03)----AggregateExec: mode=Partial, gby=[], aggr=[max(agg_dyn_test.id), min(agg_dyn_test.id)] -04)------DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-01/j5fUeSDQo22oPyPU.parquet, WORKSPACE_ROOT/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-02/j5fUeSDQo22oPyPU.parquet], [WORKSPACE_ROOT/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-03/j5fUeSDQo22oPyPU.parquet, WORKSPACE_ROOT/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-04/j5fUeSDQo22oPyPU.parquet]]}, projection=[id], file_type=parquet, predicate=id@0 < 10 AND DynamicFilter [ empty ], pruning_predicate=id_null_count@1 != row_count@2 AND id_min@0 < 10, required_guarantees=[] - -# Dynamic filter should not be available for grouping sets -query TT -explain select max(id) from agg_dyn_test where id < 10 -group by grouping sets ((), (id)) ----- -physical_plan -01)ProjectionExec: expr=[max(agg_dyn_test.id)@2 as max(agg_dyn_test.id)] -02)--AggregateExec: mode=FinalPartitioned, gby=[id@0 as id, __grouping_id@1 as __grouping_id], 
aggr=[max(agg_dyn_test.id)] -03)----RepartitionExec: partitioning=Hash([id@0, __grouping_id@1], 2), input_partitions=2 -04)------AggregateExec: mode=Partial, gby=[(NULL as id), (id@0 as id)], aggr=[max(agg_dyn_test.id)] -05)--------DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-01/j5fUeSDQo22oPyPU.parquet, WORKSPACE_ROOT/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-02/j5fUeSDQo22oPyPU.parquet], [WORKSPACE_ROOT/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-03/j5fUeSDQo22oPyPU.parquet, WORKSPACE_ROOT/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-04/j5fUeSDQo22oPyPU.parquet]]}, projection=[id], file_type=parquet, predicate=id@0 < 10, pruning_predicate=id_null_count@1 != row_count@2 AND id_min@0 < 10, required_guarantees=[] - -statement ok -drop table agg_dyn_test; diff --git a/datafusion/sqllogictest/test_files/push_down_filter_outer_joins.slt b/datafusion/sqllogictest/test_files/push_down_filter_outer_joins.slt new file mode 100644 index 000000000000..2e5f7c317fd4 --- /dev/null +++ b/datafusion/sqllogictest/test_files/push_down_filter_outer_joins.slt @@ -0,0 +1,264 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +# Test push down filter + +# check LEFT/RIGHT joins with filter pushdown to both relations (when possible) + +statement ok +create table t1(k int, v int); + +statement ok +create table t2(k int, v int); + +statement ok +insert into t1 values + (1, 10), + (2, 20), + (3, 30), + (null, 40), + (50, null), + (null, null); + +statement ok +insert into t2 values + (1, 11), + (2, 21), + (2, 22), + (null, 41), + (51, null), + (null, null); + +statement ok +set datafusion.explain.physical_plan_only = false; + +statement ok +set datafusion.explain.logical_plan_only = true; + + +# left join + filter on join key -> pushed +query TT +explain select * from t1 left join t2 on t1.k = t2.k where t1.k > 1; +---- +logical_plan +01)Left Join: t1.k = t2.k +02)--Filter: t1.k > Int32(1) +03)----TableScan: t1 projection=[k, v] +04)--Filter: t2.k > Int32(1) +05)----TableScan: t2 projection=[k, v] + +query IIII rowsort +select * from t1 left join t2 on t1.k = t2.k where t1.k > 1; +---- +2 20 2 21 +2 20 2 22 +3 30 NULL NULL +50 NULL NULL NULL + +# left join + filter on another column -> not pushed +query TT +explain select * from t1 left join t2 on t1.k = t2.k where t1.v > 1; +---- +logical_plan +01)Left Join: t1.k = t2.k +02)--Filter: t1.v > Int32(1) +03)----TableScan: t1 projection=[k, v] +04)--TableScan: t2 projection=[k, v] + +query IIII rowsort +select * from t1 left join t2 on t1.k = t2.k where t1.v > 1; +---- +1 10 1 11 +2 20 2 21 +2 20 2 22 +3 30 NULL NULL +NULL 40 NULL NULL + +# left join + or + filter on another column -> not pushed +query TT +explain select * from t1 left join t2 on t1.k = t2.k where t1.k > 3 or t1.v > 20; +---- +logical_plan +01)Left Join: t1.k = t2.k +02)--Filter: t1.k > Int32(3) OR t1.v > Int32(20) +03)----TableScan: t1 projection=[k, v] +04)--TableScan: t2 projection=[k, v] + +query IIII rowsort +select * from t1 left join t2 on t1.k = t2.k where t1.k > 3 
or t1.v > 20; +---- +3 30 NULL NULL +50 NULL NULL NULL +NULL 40 NULL NULL + + +# right join + filter on join key -> pushed +query TT +explain select * from t1 right join t2 on t1.k = t2.k where t1.k > 1; +---- +logical_plan +01)Inner Join: t1.k = t2.k +02)--Filter: t1.k > Int32(1) +03)----TableScan: t1 projection=[k, v] +04)--Filter: t2.k > Int32(1) +05)----TableScan: t2 projection=[k, v] + +query IIII rowsort +select * from t1 right join t2 on t1.k = t2.k where t1.k > 1; +---- +2 20 2 21 +2 20 2 22 + +# right join + filter on another column -> not pushed +query TT +explain select * from t1 right join t2 on t1.k = t2.k where t1.v > 1; +---- +logical_plan +01)Inner Join: t1.k = t2.k +02)--Filter: t1.v > Int32(1) +03)----TableScan: t1 projection=[k, v] +04)--TableScan: t2 projection=[k, v] + +query IIII rowsort +select * from t1 right join t2 on t1.k = t2.k where t1.v > 1; +---- +1 10 1 11 +2 20 2 21 +2 20 2 22 + +# right join + or + filter on another column -> not pushed +query TT +explain select * from t1 right join t2 on t1.k = t2.k where t1.k > 3 or t1.v > 20; +---- +logical_plan +01)Inner Join: t1.k = t2.k +02)--Filter: t1.k > Int32(3) OR t1.v > Int32(20) +03)----TableScan: t1 projection=[k, v] +04)--TableScan: t2 projection=[k, v] + +query IIII rowsort +select * from t1 right join t2 on t1.k = t2.k where t1.k > 3 or t1.v > 20; +---- + + +# left anti join + filter on join key -> pushed +query TT +explain select * from t1 left anti join t2 on t1.k = t2.k where t1.k > 1; +---- +logical_plan +01)LeftAnti Join: t1.k = t2.k +02)--Filter: t1.k > Int32(1) +03)----TableScan: t1 projection=[k, v] +04)--Filter: t2.k > Int32(1) +05)----TableScan: t2 projection=[k] + +query II rowsort +select * from t1 left anti join t2 on t1.k = t2.k where t1.k > 1; +---- +3 30 +50 NULL + +# left anti join + filter on another column -> not pushed +query TT +explain select * from t1 left anti join t2 on t1.k = t2.k where t1.v > 1; +---- +logical_plan +01)LeftAnti Join: t1.k = t2.k 
+02)--Filter: t1.v > Int32(1) +03)----TableScan: t1 projection=[k, v] +04)--TableScan: t2 projection=[k] + +query II rowsort +select * from t1 left anti join t2 on t1.k = t2.k where t1.v > 1; +---- +3 30 +NULL 40 + +# left anti join + or + filter on another column -> not pushed +query TT +explain select * from t1 left anti join t2 on t1.k = t2.k where t1.k > 3 or t1.v > 20; +---- +logical_plan +01)LeftAnti Join: t1.k = t2.k +02)--Filter: t1.k > Int32(3) OR t1.v > Int32(20) +03)----TableScan: t1 projection=[k, v] +04)--TableScan: t2 projection=[k] + +query II rowsort +select * from t1 left anti join t2 on t1.k = t2.k where t1.k > 3 or t1.v > 20; +---- +3 30 +50 NULL +NULL 40 + + +# right anti join + filter on join key -> pushed +query TT +explain select * from t1 right anti join t2 on t1.k = t2.k where t2.k > 1; +---- +logical_plan +01)RightAnti Join: t1.k = t2.k +02)--Filter: t1.k > Int32(1) +03)----TableScan: t1 projection=[k] +04)--Filter: t2.k > Int32(1) +05)----TableScan: t2 projection=[k, v] + +query II rowsort +select * from t1 right anti join t2 on t1.k = t2.k where t2.k > 1; +---- +51 NULL + +# right anti join + filter on another column -> not pushed +query TT +explain select * from t1 right anti join t2 on t1.k = t2.k where t2.v > 1; +---- +logical_plan +01)RightAnti Join: t1.k = t2.k +02)--TableScan: t1 projection=[k] +03)--Filter: t2.v > Int32(1) +04)----TableScan: t2 projection=[k, v] + +query II rowsort +select * from t1 right anti join t2 on t1.k = t2.k where t2.v > 1; +---- +NULL 41 + +# right anti join + or + filter on another column -> not pushed +query TT +explain select * from t1 right anti join t2 on t1.k = t2.k where t2.k > 3 or t2.v > 20; +---- +logical_plan +01)RightAnti Join: t1.k = t2.k +02)--TableScan: t1 projection=[k] +03)--Filter: t2.k > Int32(3) OR t2.v > Int32(20) +04)----TableScan: t2 projection=[k, v] + +query II rowsort +select * from t1 right anti join t2 on t1.k = t2.k where t2.k > 3 or t2.v > 20; +---- +51 NULL +NULL 41 + + 
+statement ok +set datafusion.explain.logical_plan_only = false; + +statement ok +drop table t1; + +statement ok +drop table t2; diff --git a/datafusion/sqllogictest/test_files/push_down_filter_parquet.slt b/datafusion/sqllogictest/test_files/push_down_filter_parquet.slt new file mode 100644 index 000000000000..e1c83c8c330d --- /dev/null +++ b/datafusion/sqllogictest/test_files/push_down_filter_parquet.slt @@ -0,0 +1,188 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Test push down filter + +statement ok +set datafusion.explain.physical_plan_only = true; + +# Test push down filter with limit for parquet +statement ok +set datafusion.execution.parquet.pushdown_filters = true; + +# this one is also required to make DF skip second file due to "sufficient" amount of rows +statement ok +set datafusion.execution.collect_statistics = true; + +# Create a table as a data source +statement ok +CREATE TABLE src_table ( + part_key INT, + value INT +) AS VALUES(1, 0), (1, 1), (1, 100), (2, 0), (2, 2), (2, 2), (2, 100), (3, 4), (3, 5), (3, 6); + + +# There will be more than 2 records filtered from the table to check that `limit 1` actually applied. 
+# Setup 3 files, i.e., as many as there are partitions: + +# File 1: +query I +COPY (SELECT * FROM src_table where part_key = 1) +TO 'test_files/scratch/push_down_filter_parquet/test_filter_with_limit/part-0.parquet' +STORED AS PARQUET; +---- +3 + +# File 2: +query I +COPY (SELECT * FROM src_table where part_key = 2) +TO 'test_files/scratch/push_down_filter_parquet/test_filter_with_limit/part-1.parquet' +STORED AS PARQUET; +---- +4 + +# File 3: +query I +COPY (SELECT * FROM src_table where part_key = 3) +TO 'test_files/scratch/push_down_filter_parquet/test_filter_with_limit/part-2.parquet' +STORED AS PARQUET; +---- +3 + +statement ok +CREATE EXTERNAL TABLE test_filter_with_limit +( + part_key INT, + value INT +) +STORED AS PARQUET +LOCATION 'test_files/scratch/push_down_filter_parquet/test_filter_with_limit/'; + +query TT +explain select * from test_filter_with_limit where value = 2 limit 1; +---- +physical_plan +01)CoalescePartitionsExec: fetch=1 +02)--DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter_parquet/test_filter_with_limit/part-0.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter_parquet/test_filter_with_limit/part-1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter_parquet/test_filter_with_limit/part-2.parquet]]}, projection=[part_key, value], limit=1, file_type=parquet, predicate=value@1 = 2, pruning_predicate=value_null_count@2 != row_count@3 AND value_min@0 <= 2 AND 2 <= value_max@1, required_guarantees=[value in (2)] + +query II +select * from test_filter_with_limit where value = 2 limit 1; +---- +2 2 + + +# Tear down test_filter_with_limit table: +statement ok +DROP TABLE test_filter_with_limit; + +# Tear down src_table table: +statement ok +DROP TABLE src_table; + + +query I +COPY (VALUES (1), (2), (3), (4), (5), (6), (7), (8), (9), (10)) +TO 'test_files/scratch/push_down_filter_parquet/t.parquet' +STORED AS 
PARQUET; +---- +10 + +statement ok +CREATE EXTERNAL TABLE t +( + a INT +) +STORED AS PARQUET +LOCATION 'test_files/scratch/push_down_filter_parquet/t.parquet'; + + +# The predicate should not have a column cast when the value is a valid i32 +query TT +explain select a from t where a = '100'; +---- +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter_parquet/t.parquet]]}, projection=[a], file_type=parquet, predicate=a@0 = 100, pruning_predicate=a_null_count@2 != row_count@3 AND a_min@0 <= 100 AND 100 <= a_max@1, required_guarantees=[a in (100)] + +# The predicate should not have a column cast when the value is a valid i32 +query TT +explain select a from t where a != '100'; +---- +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter_parquet/t.parquet]]}, projection=[a], file_type=parquet, predicate=a@0 != 100, pruning_predicate=a_null_count@2 != row_count@3 AND (a_min@0 != 100 OR 100 != a_max@1), required_guarantees=[a not in (100)] + +# The predicate should still have the column cast when the value is a NOT valid i32 +query TT +explain select a from t where a = '99999999999'; +---- +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter_parquet/t.parquet]]}, projection=[a], file_type=parquet, predicate=CAST(a@0 AS Utf8) = 99999999999 + +# The predicate should still have the column cast when the value is a NOT valid i32 +query TT +explain select a from t where a = '99.99'; +---- +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter_parquet/t.parquet]]}, projection=[a], file_type=parquet, predicate=CAST(a@0 AS Utf8) = 99.99 + +# The predicate should still have the column cast when the value is a NOT valid i32 +query TT +explain select a from t where a = ''; +---- 
+physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter_parquet/t.parquet]]}, projection=[a], file_type=parquet, predicate=CAST(a@0 AS Utf8) = + +# The predicate should not have a column cast when the operator is = or != and the literal can be round-trip casted without losing information. +query TT +explain select a from t where cast(a as string) = '100'; +---- +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter_parquet/t.parquet]]}, projection=[a], file_type=parquet, predicate=a@0 = 100, pruning_predicate=a_null_count@2 != row_count@3 AND a_min@0 <= 100 AND 100 <= a_max@1, required_guarantees=[a in (100)] + +# The predicate should still have the column cast when the literal alters its string representation after round-trip casting (leading zero lost). +query TT +explain select a from t where CAST(a AS string) = '0123'; +---- +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter_parquet/t.parquet]]}, projection=[a], file_type=parquet, predicate=CAST(a@0 AS Utf8View) = 0123 + + +# Test dynamic filter pushdown with swapped join inputs (issue #17196) +# Create tables with different sizes to force join input swapping +statement ok +copy (select i as k from generate_series(1, 100) t(i)) to 'test_files/scratch/push_down_filter_parquet/small_table.parquet'; + +statement ok +copy (select i as k, i as v from generate_series(1, 1000) t(i)) to 'test_files/scratch/push_down_filter_parquet/large_table.parquet'; + +statement ok +create external table small_table stored as parquet location 'test_files/scratch/push_down_filter_parquet/small_table.parquet'; + +statement ok +create external table large_table stored as parquet location 'test_files/scratch/push_down_filter_parquet/large_table.parquet'; + +# Test that dynamic filter is applied to the correct 
table after join input swapping +# The small_table should be the build side, large_table should be the probe side with dynamic filter +query TT +explain select * from small_table join large_table on small_table.k = large_table.k where large_table.v >= 50; +---- +physical_plan +01)HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(k@0, k@0)] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter_parquet/small_table.parquet]]}, projection=[k], file_type=parquet +03)--RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +04)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter_parquet/large_table.parquet]]}, projection=[k, v], file_type=parquet, predicate=v@1 >= 50 AND DynamicFilter [ empty ], pruning_predicate=v_null_count@1 != row_count@2 AND v_max@0 >= 50, required_guarantees=[] + +statement ok +drop table small_table; + +statement ok +drop table large_table; + +statement ok +drop table t; diff --git a/datafusion/sqllogictest/test_files/push_down_filter_regression.slt b/datafusion/sqllogictest/test_files/push_down_filter_regression.slt new file mode 100644 index 000000000000..ca4a30fa96c3 --- /dev/null +++ b/datafusion/sqllogictest/test_files/push_down_filter_regression.slt @@ -0,0 +1,200 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Test push down filter + +# Regression test for https://github.com/apache/datafusion/issues/17188 +query I +COPY (select i as k from generate_series(1, 10000000) as t(i)) +TO 'test_files/scratch/push_down_filter_regression/t1.parquet' +STORED AS PARQUET; +---- +10000000 + +query I +COPY (select i as k, i as v from generate_series(1, 10000000) as t(i)) +TO 'test_files/scratch/push_down_filter_regression/t2.parquet' +STORED AS PARQUET; +---- +10000000 + +statement ok +create external table t1 stored as parquet location 'test_files/scratch/push_down_filter_regression/t1.parquet'; + +statement ok +create external table t2 stored as parquet location 'test_files/scratch/push_down_filter_regression/t2.parquet'; + +# The failure before https://github.com/apache/datafusion/pull/17197 was non-deterministic and random +# So we'll run the same query a couple of times just to have more certainty it's fixed +# Sorry about the spam in this slt test... 
+ +query III rowsort +select * +from t1 +join t2 on t1.k = t2.k +where v = 1 or v = 10000000 +order by t1.k, t2.v; +---- +1 1 1 +10000000 10000000 10000000 + +query III rowsort +select * +from t1 +join t2 on t1.k = t2.k +where v = 1 or v = 10000000 +order by t1.k, t2.v; +---- +1 1 1 +10000000 10000000 10000000 + +query III rowsort +select * +from t1 +join t2 on t1.k = t2.k +where v = 1 or v = 10000000 +order by t1.k, t2.v; +---- +1 1 1 +10000000 10000000 10000000 + +query III rowsort +select * +from t1 +join t2 on t1.k = t2.k +where v = 1 or v = 10000000 +order by t1.k, t2.v; +---- +1 1 1 +10000000 10000000 10000000 + +query III rowsort +select * +from t1 +join t2 on t1.k = t2.k +where v = 1 or v = 10000000 +order by t1.k, t2.v; +---- +1 1 1 +10000000 10000000 10000000 + +# Regression test for https://github.com/apache/datafusion/issues/17512 + +query I +COPY ( + SELECT arrow_cast('2025-01-01T00:00:00Z'::timestamptz, 'Timestamp(Microsecond, Some("UTC"))') AS start_timestamp +) +TO 'test_files/scratch/push_down_filter_regression/17512.parquet' +STORED AS PARQUET; +---- +1 + +statement ok +CREATE EXTERNAL TABLE records STORED AS PARQUET LOCATION 'test_files/scratch/push_down_filter_regression/17512.parquet'; + +query I +SELECT 1 +FROM ( + SELECT start_timestamp + FROM records + WHERE start_timestamp <= '2025-01-01T00:00:00Z'::timestamptz +) AS t +WHERE t.start_timestamp::time < '00:00:01'::time; +---- +1 + +# Test aggregate dynamic filter pushdown +# Note: most of the test coverage lives in `datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs` +# , to compare dynamic filter content easier. Here the tests are simple end-to-end +# exercises. 
+ +statement ok +set datafusion.explain.format = 'indent'; + +statement ok +set datafusion.explain.physical_plan_only = true; + +statement ok +set datafusion.execution.target_partitions = 2; + +statement ok +set datafusion.execution.parquet.pushdown_filters = true; + +statement ok +set datafusion.optimizer.enable_dynamic_filter_pushdown = true; + +statement ok +set datafusion.optimizer.enable_aggregate_dynamic_filter_pushdown = true; + +statement ok +create external table agg_dyn_test stored as parquet location '../core/tests/data/test_statistics_per_partition'; + +# Expect dynamic filter available inside data source +query TT +explain select max(id) from agg_dyn_test where id > 1; +---- +physical_plan +01)AggregateExec: mode=Final, gby=[], aggr=[max(agg_dyn_test.id)] +02)--CoalescePartitionsExec +03)----AggregateExec: mode=Partial, gby=[], aggr=[max(agg_dyn_test.id)] +04)------DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-01/j5fUeSDQo22oPyPU.parquet, WORKSPACE_ROOT/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-02/j5fUeSDQo22oPyPU.parquet], [WORKSPACE_ROOT/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-03/j5fUeSDQo22oPyPU.parquet, WORKSPACE_ROOT/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-04/j5fUeSDQo22oPyPU.parquet]]}, projection=[id], file_type=parquet, predicate=id@0 > 1 AND DynamicFilter [ empty ], pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 1, required_guarantees=[] + +query I +select max(id) from agg_dyn_test where id > 1; +---- +4 + +# Expect dynamic filter available inside data source +query TT +explain select max(id) from agg_dyn_test where (id+1) > 1; +---- +physical_plan +01)AggregateExec: mode=Final, gby=[], aggr=[max(agg_dyn_test.id)] +02)--CoalescePartitionsExec +03)----AggregateExec: mode=Partial, gby=[], aggr=[max(agg_dyn_test.id)] +04)------DataSourceExec: file_groups={2 groups: 
[[WORKSPACE_ROOT/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-01/j5fUeSDQo22oPyPU.parquet, WORKSPACE_ROOT/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-02/j5fUeSDQo22oPyPU.parquet], [WORKSPACE_ROOT/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-03/j5fUeSDQo22oPyPU.parquet, WORKSPACE_ROOT/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-04/j5fUeSDQo22oPyPU.parquet]]}, projection=[id], file_type=parquet, predicate=CAST(id@0 AS Int64) + 1 > 1 AND DynamicFilter [ empty ] + +# Expect dynamic filter available inside data source +query TT +explain select max(id), min(id) from agg_dyn_test where id < 10; +---- +physical_plan +01)AggregateExec: mode=Final, gby=[], aggr=[max(agg_dyn_test.id), min(agg_dyn_test.id)] +02)--CoalescePartitionsExec +03)----AggregateExec: mode=Partial, gby=[], aggr=[max(agg_dyn_test.id), min(agg_dyn_test.id)] +04)------DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-01/j5fUeSDQo22oPyPU.parquet, WORKSPACE_ROOT/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-02/j5fUeSDQo22oPyPU.parquet], [WORKSPACE_ROOT/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-03/j5fUeSDQo22oPyPU.parquet, WORKSPACE_ROOT/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-04/j5fUeSDQo22oPyPU.parquet]]}, projection=[id], file_type=parquet, predicate=id@0 < 10 AND DynamicFilter [ empty ], pruning_predicate=id_null_count@1 != row_count@2 AND id_min@0 < 10, required_guarantees=[] + +# Dynamic filter should not be available for grouping sets +query TT +explain select max(id) from agg_dyn_test where id < 10 +group by grouping sets ((), (id)) +---- +physical_plan +01)ProjectionExec: expr=[max(agg_dyn_test.id)@2 as max(agg_dyn_test.id)] +02)--AggregateExec: mode=FinalPartitioned, gby=[id@0 as id, __grouping_id@1 as __grouping_id], 
aggr=[max(agg_dyn_test.id)] +03)----RepartitionExec: partitioning=Hash([id@0, __grouping_id@1], 2), input_partitions=2 +04)------AggregateExec: mode=Partial, gby=[(NULL as id), (id@0 as id)], aggr=[max(agg_dyn_test.id)] +05)--------DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-01/j5fUeSDQo22oPyPU.parquet, WORKSPACE_ROOT/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-02/j5fUeSDQo22oPyPU.parquet], [WORKSPACE_ROOT/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-03/j5fUeSDQo22oPyPU.parquet, WORKSPACE_ROOT/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-04/j5fUeSDQo22oPyPU.parquet]]}, projection=[id], file_type=parquet, predicate=id@0 < 10, pruning_predicate=id_null_count@1 != row_count@2 AND id_min@0 < 10, required_guarantees=[] + +statement ok +drop table agg_dyn_test; + +statement ok +drop table t1; + +statement ok +drop table t2; diff --git a/datafusion/sqllogictest/test_files/push_down_filter_unnest.slt b/datafusion/sqllogictest/test_files/push_down_filter_unnest.slt new file mode 100644 index 000000000000..58fe24e2e2cc --- /dev/null +++ b/datafusion/sqllogictest/test_files/push_down_filter_unnest.slt @@ -0,0 +1,148 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +# Test push down filter + +statement ok +set datafusion.explain.physical_plan_only = true; + +statement ok +CREATE TABLE IF NOT EXISTS v AS VALUES(1,[1,2,3]),(2,[3,4,5]); + +query I +select uc2 from (select unnest(column2) as uc2, column1 from v) where column1 = 2; +---- +3 +4 +5 + +# test push down filter for unnest with filter on non-unnest column +# filter plan is pushed down into projection plan +query TT +explain select uc2 from (select unnest(column2) as uc2, column1 from v) where column1 = 2; +---- +physical_plan +01)ProjectionExec: expr=[__unnest_placeholder(v.column2,depth=1)@0 as uc2] +02)--UnnestExec +03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +04)------ProjectionExec: expr=[column2@0 as __unnest_placeholder(v.column2)] +05)--------FilterExec: column1@0 = 2, projection=[column2@1] +06)----------DataSourceExec: partitions=1, partition_sizes=[1] + +query I +select uc2 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3; +---- +4 +5 + +# test push down filter for unnest with filter on unnest column +query TT +explain select uc2 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3; +---- +physical_plan +01)ProjectionExec: expr=[__unnest_placeholder(v.column2,depth=1)@0 as uc2] +02)--FilterExec: __unnest_placeholder(v.column2,depth=1)@0 > 3 +03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +04)------UnnestExec +05)--------ProjectionExec: expr=[column2@0 as __unnest_placeholder(v.column2)] +06)----------DataSourceExec: partitions=1, partition_sizes=[1] + +query II +select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3 AND column1 = 2; +---- +4 2 +5 2 + +# Could push the filter (column1 = 2) down below unnest +query TT +explain select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3 AND column1 = 2; +---- 
+physical_plan +01)ProjectionExec: expr=[__unnest_placeholder(v.column2,depth=1)@0 as uc2, column1@1 as column1] +02)--FilterExec: __unnest_placeholder(v.column2,depth=1)@0 > 3 +03)----UnnestExec +04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +05)--------ProjectionExec: expr=[column2@1 as __unnest_placeholder(v.column2), column1@0 as column1] +06)----------FilterExec: column1@0 = 2 +07)------------DataSourceExec: partitions=1, partition_sizes=[1] + +query II +select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3 OR column1 = 2; +---- +3 2 +4 2 +5 2 + +# only non-unnest filter in AND clause could be pushed down +query TT +explain select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3 OR column1 = 2; +---- +physical_plan +01)ProjectionExec: expr=[__unnest_placeholder(v.column2,depth=1)@0 as uc2, column1@1 as column1] +02)--FilterExec: __unnest_placeholder(v.column2,depth=1)@0 > 3 OR column1@1 = 2 +03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +04)------UnnestExec +05)--------ProjectionExec: expr=[column2@1 as __unnest_placeholder(v.column2), column1@0 as column1] +06)----------DataSourceExec: partitions=1, partition_sizes=[1] + +statement ok +drop table v; + +# test with unnest struct, should not push down filter +statement ok +CREATE TABLE d AS VALUES(1,[named_struct('a', 1, 'b', 2)]),(2,[named_struct('a', 3, 'b', 4), named_struct('a', 5, 'b', 6)]); + +query I? 
+select * from (select column1, unnest(column2) as o from d) where o['a'] = 1; +---- +1 {a: 1, b: 2} + +query TT +explain select * from (select column1, unnest(column2) as o from d) where o['a'] = 1; +---- +physical_plan +01)ProjectionExec: expr=[column1@0 as column1, __unnest_placeholder(d.column2,depth=1)@1 as o] +02)--FilterExec: __datafusion_extracted_1@0 = 1, projection=[column1@1, __unnest_placeholder(d.column2,depth=1)@2] +03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +04)------ProjectionExec: expr=[get_field(__unnest_placeholder(d.column2,depth=1)@1, a) as __datafusion_extracted_1, column1@0 as column1, __unnest_placeholder(d.column2,depth=1)@1 as __unnest_placeholder(d.column2,depth=1)] +05)--------UnnestExec +06)----------ProjectionExec: expr=[column1@0 as column1, column2@1 as __unnest_placeholder(d.column2)] +07)------------DataSourceExec: partitions=1, partition_sizes=[1] + +statement ok +drop table d; + +statement ok +CREATE TABLE d AS VALUES (named_struct('a', 1, 'b', 2)), (named_struct('a', 3, 'b', 4)), (named_struct('a', 5, 'b', 6)); + +query II +select * from (select unnest(column1) from d) where "__unnest_placeholder(d.column1).b" > 5; +---- +5 6 + +query TT +explain select * from (select unnest(column1) from d) where "__unnest_placeholder(d.column1).b" > 5; +---- +physical_plan +01)FilterExec: __unnest_placeholder(d.column1).b@1 > 5 +02)--RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +03)----UnnestExec +04)------ProjectionExec: expr=[column1@0 as __unnest_placeholder(d.column1)] +05)--------DataSourceExec: partitions=1, partition_sizes=[1] + +statement ok +drop table d; diff --git a/datafusion/sqllogictest/test_files/regexp/regexp_like.slt b/datafusion/sqllogictest/test_files/regexp/regexp_like.slt index 6f2d5a873c1b..2b304c8de1a3 100644 --- a/datafusion/sqllogictest/test_files/regexp/regexp_like.slt +++ b/datafusion/sqllogictest/test_files/regexp/regexp_like.slt @@ -334,5 +334,10 @@ true true 
false false false false +query TT +select * from regexp_test where regexp_like('f', regexp_replace((('v\r') like ('f_*sP6H1*')), '339629555', '-1459539013')); +---- + + statement ok drop table if exists dict_table; diff --git a/datafusion/sqllogictest/test_files/run_end_encoded.slt b/datafusion/sqllogictest/test_files/run_end_encoded.slt new file mode 100644 index 000000000000..1f0a9b4eb3fd --- /dev/null +++ b/datafusion/sqllogictest/test_files/run_end_encoded.slt @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# Tests for Run-End Encoded (REE) array support in aggregations +# This tests that REE arrays can be used as GROUP BY keys (requires proper hashing support) + +# Create a table with REE-encoded sensor IDs using arrow_cast +# First create primitive arrays, then cast to REE in a second step +statement ok +CREATE TABLE sensor_readings AS +WITH raw_data AS ( + SELECT * FROM ( + VALUES + ('sensor_A', 22), + ('sensor_A', 23), + ('sensor_B', 20), + ('sensor_A', 24) + ) AS t(sensor_id, temperature) +) +SELECT + arrow_cast(sensor_id, 'RunEndEncoded("run_ends": non-null Int32, "values": Utf8)') AS sensor_id, + temperature +FROM raw_data; + +# Test basic aggregation with REE column as GROUP BY key +query ?RI rowsort +SELECT + sensor_id, + AVG(temperature) AS avg_temp, + COUNT(*) AS reading_count +FROM sensor_readings +GROUP BY sensor_id; +---- +sensor_A 23 3 +sensor_B 20 1 + +# Test DISTINCT with REE column +query ? rowsort +SELECT DISTINCT sensor_id +FROM sensor_readings; +---- +sensor_A +sensor_B diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 7be7de5a4def..681540a29d37 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -765,11 +765,11 @@ select nanvl(null, 64); ---- NULL -# nanvl scalar nulls #1 +# nanvl scalar nulls #1 - x is not NaN, so return x even if y is NULL query R rowsort select nanvl(2, null); ---- -NULL +2 # nanvl scalar nulls #2 query R rowsort @@ -923,7 +923,7 @@ select round(a), round(b), round(c) from small_floats; # round with too large # max Int32 is 2147483647 -query error Arrow error: Cast error: Can't cast value 2147483648 to type Int32 +query error round decimal_places 2147483648 is out of supported i32 range select round(3.14, 2147483648); # with array @@ -931,11 +931,12 @@ query error Arrow error: Cast error: Can't cast value 2147483649 to type Int32 select round(column1, column2) from values (3.14, 2), (3.14, 3), (3.14, 
2147483649); # round decimal should not cast to float +# scale reduces to match decimal_places query TR select arrow_typeof(round('173975140545.855'::decimal(38,10), 2)), round('173975140545.855'::decimal(38,10), 2); ---- -Decimal128(38, 10) 173975140545.86 +Decimal128(38, 2) 173975140545.86 # round decimal ties away from zero query RRRR @@ -951,15 +952,74 @@ query TR select arrow_typeof(round('12345.55'::decimal(10,2), -1)), round('12345.55'::decimal(10,2), -1); ---- -Decimal128(10, 2) 12350 +Decimal128(10, 0) 12350 + +# round decimal scale 0 negative places (carry can require extra precision) +query TR +select arrow_typeof(round('99'::decimal(2,0), -1)), + round('99'::decimal(2,0), -1); +---- +Decimal128(3, 0) 100 # round decimal256 keeps decimals query TR select arrow_typeof(round('1234.5678'::decimal(50,4), 2)), round('1234.5678'::decimal(50,4), 2); ---- -Decimal256(50, 4) 1234.57 +Decimal256(50, 2) 1234.57 + +# round decimal with carry-over (reduce scale) +# Scale reduces from 1 to 0, allowing extra digit for carry-over +query TRRR +select arrow_typeof(round('999.9'::decimal(4,1))), + round('999.9'::decimal(4,1)), + round('-999.9'::decimal(4,1)), + round('99.99'::decimal(4,2)); +---- +Decimal128(4, 0) 1000 -1000 100 + +# round decimal with carry-over and non-literal decimal_places (increase precision) +# Scale can't be reduced when decimal_places isn't a constant, so precision increases. 
+query TR +select arrow_typeof(round(val, dp)), round(val, dp) +from (values (cast('999.9' as decimal(4,1)), 0)) as t(val, dp); +---- +Decimal128(5, 1) 1000 + +# round decimal at max precision now works (scale reduction handles overflow) +query TR +select arrow_typeof(round('9999999999999999999999999999999999999.9'::decimal(38,1))), + round('9999999999999999999999999999999999999.9'::decimal(38,1)); +---- +Decimal128(38, 0) 10000000000000000000000000000000000000 + +# round decimal at max precision with non-literal decimal_places can overflow +query error Decimal overflow: rounded value exceeds precision 38 +select round(val, dp) +from (values (cast('9999999999999999999999999999999999999.9' as decimal(38,1)), 0)) as t(val, dp); + +# round decimal with negative scale +query TRRR +select arrow_typeof(round(cast(500 as decimal(10,-2)), -3)), + round(cast(500 as decimal(10,-2)), -3), + round(cast(400 as decimal(10,-2)), -3), + round(cast(-500 as decimal(10,-2)), -3); +---- +Decimal128(10, -3) 1000 0 -1000 + +# round decimal with negative scale and carry-over +query TR +select arrow_typeof(round(cast(999999999900 as decimal(10,-2)), -3)), + round(cast(999999999900 as decimal(10,-2)), -3); +---- +Decimal128(10, -3) 1000000000000 +# round decimal with very small decimal_places (i32::MIN) should not error +query TR +select arrow_typeof(round('123.45'::decimal(5,2), -2147483648)), + round('123.45'::decimal(5,2), -2147483648); +---- +Decimal128(5, 0) 0 ## signum @@ -1165,7 +1225,7 @@ from small_floats; ---- 0.447 0.4 0.447 0.707 0.7 0.707 -0.837 0.8 0.837 +0.836 0.8 0.836 1 1 1 ## bitwise and @@ -1311,6 +1371,14 @@ select a << b, c << d, e << f from signed_integers; 33554432 123 10485760 NULL NULL NULL +## bitwise operations should reject non-integer types + +query error DataFusion error: Error during planning: Cannot infer common type for bitwise operation Float32 & Float32 +select arrow_cast(1, 'Float32') & arrow_cast(2, 'Float32'); + +query error DataFusion error: Error 
during planning: Cannot infer common type for bitwise operation Date32 & Date32 +select arrow_cast(1, 'Date32') & arrow_cast(2, 'Date32'); + statement ok drop table unsigned_integers; @@ -1993,10 +2061,10 @@ query TT EXPLAIN SELECT letter, letter = LEFT(letter2, 1) FROM simple_string; ---- logical_plan -01)Projection: simple_string.letter, simple_string.letter = CAST(left(simple_string.letter2, Int64(1)) AS Utf8View) +01)Projection: simple_string.letter, simple_string.letter = left(simple_string.letter2, Int64(1)) 02)--TableScan: simple_string projection=[letter, letter2] physical_plan -01)ProjectionExec: expr=[letter@0 as letter, letter@0 = CAST(left(letter2@1, 1) AS Utf8View) as simple_string.letter = left(simple_string.letter2,Int64(1))] +01)ProjectionExec: expr=[letter@0 as letter, letter@0 = left(letter2@1, 1) as simple_string.letter = left(simple_string.letter2,Int64(1))] 02)--DataSourceExec: partitions=1, partition_sizes=[1] query TB @@ -2010,8 +2078,8 @@ D false # test string_temporal_coercion query BBBBBBBBBB select - arrow_cast(to_timestamp('2020-01-01 01:01:11.1234567890Z'), 'Timestamp(Second, None)') == '2020-01-01T01:01:11', - arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Second, None)') == arrow_cast('2020-01-02T01:01:11', 'LargeUtf8'), + arrow_cast(to_timestamp('2020-01-01 01:01:11.1234567890Z'), 'Timestamp(s)') == '2020-01-01T01:01:11', + arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(s)') == arrow_cast('2020-01-02T01:01:11', 'LargeUtf8'), arrow_cast(to_timestamp('2020-01-03 01:01:11.1234567890Z'), 'Time32(Second)') == '01:01:11', arrow_cast(to_timestamp('2020-01-04 01:01:11.1234567890Z'), 'Time32(Second)') == arrow_cast('01:01:11', 'LargeUtf8'), arrow_cast(to_timestamp('2020-01-05 01:01:11.1234567890Z'), 'Time64(Microsecond)') == '01:01:11.123456', @@ -2069,7 +2137,7 @@ select position('' in '') ---- 1 -query error DataFusion error: Error during planning: Internal error: Expect 
TypeSignatureClass::Native\(LogicalType\(Native\(String\), String\)\) but received NativeType::Int64, DataType: Int64 +query error DataFusion error: Error during planning: Function 'strpos' requires TypeSignatureClass::Native\(LogicalType\(Native\(String\), String\)\), but received Int64 \(DataType: Int64\) select position(1 in 1) query I diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt index 490df4b72d17..553ccb74dedb 100644 --- a/datafusion/sqllogictest/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -820,7 +820,7 @@ SELECT ALL c1 FROM aggregate_simple order by c1 0.00005 0.00005 -# select distinct +# SELECT DISTINCT query RRB rowsort SELECT DISTINCT * FROM aggregate_simple ---- @@ -830,6 +830,31 @@ SELECT DISTINCT * FROM aggregate_simple 0.00004 0.000000000004 false 0.00005 0.000000000005 true +# select ALL (inverse of distinct) +query RRB rowsort +SELECT ALL * FROM aggregate_simple; +---- +0.00001 0.000000000001 true +0.00002 0.000000000002 false +0.00002 0.000000000002 false +0.00003 0.000000000003 true +0.00003 0.000000000003 true +0.00003 0.000000000003 true +0.00004 0.000000000004 false +0.00004 0.000000000004 false +0.00004 0.000000000004 false +0.00004 0.000000000004 false +0.00005 0.000000000005 true +0.00005 0.000000000005 true +0.00005 0.000000000005 true +0.00005 0.000000000005 true +0.00005 0.000000000005 true + + +# select distinct all ( +query error DataFusion error: SQL error: ParserError\("Cannot specify DISTINCT then ALL at Line: 1, Column: 8"\) +SELECT DISTINCT ALL * FROM aggregate_simple + # select distinct with projection and order by query R SELECT DISTINCT c1 FROM aggregate_simple order by c1 @@ -1926,3 +1951,12 @@ select "current_time" is not null from t_with_current_time; true false true + +# https://github.com/apache/datafusion/issues/20215 +statement count 0 +CREATE TABLE t0; + +query I +SELECT COUNT(*) FROM t0 AS tt0 WHERE (4==(3/0)); +---- +0 diff 
--git a/datafusion/sqllogictest/test_files/set_variable.slt b/datafusion/sqllogictest/test_files/set_variable.slt index c444128b18f4..7be353f0573c 100644 --- a/datafusion/sqllogictest/test_files/set_variable.slt +++ b/datafusion/sqllogictest/test_files/set_variable.slt @@ -447,3 +447,6 @@ datafusion.runtime.max_temp_directory_size datafusion.runtime.memory_limit datafusion.runtime.metadata_cache_limit datafusion.runtime.temp_directory + +statement error DataFusion error: Error during planning: Unsupported value Null +SET datafusion.runtime.memory_limit = NULL diff --git a/datafusion/sqllogictest/test_files/simplify_expr.slt b/datafusion/sqllogictest/test_files/simplify_expr.slt index d8c25ab25e8e..99fc9900ef61 100644 --- a/datafusion/sqllogictest/test_files/simplify_expr.slt +++ b/datafusion/sqllogictest/test_files/simplify_expr.slt @@ -113,3 +113,21 @@ logical_plan physical_plan 01)ProjectionExec: expr=[[{x:100}] as a] 02)--PlaceholderRowExec + +# Simplify expr = L1 AND expr != L2 to expr = L1 when L1 != L2 +query TT +EXPLAIN SELECT + v = 1 AND v != 0 as opt1, + v = 2 AND v != 2 as noopt1, + v != 3 AND v = 4 as opt2, + v != 5 AND v = 5 as noopt2 +FROM (VALUES (0), (1), (2)) t(v) +---- +logical_plan +01)Projection: t.v = Int64(1) AS opt1, t.v = Int64(2) AND t.v != Int64(2) AS noopt1, t.v = Int64(4) AS opt2, t.v != Int64(5) AND t.v = Int64(5) AS noopt2 +02)--SubqueryAlias: t +03)----Projection: column1 AS v +04)------Values: (Int64(0)), (Int64(1)), (Int64(2)) +physical_plan +01)ProjectionExec: expr=[column1@0 = 1 as opt1, column1@0 = 2 AND column1@0 != 2 as noopt1, column1@0 = 4 as opt2, column1@0 != 5 AND column1@0 = 5 as noopt2] +02)--DataSourceExec: partitions=1, partition_sizes=[1] diff --git a/datafusion/sqllogictest/test_files/sort_pushdown.slt b/datafusion/sqllogictest/test_files/sort_pushdown.slt index 58d9915a24be..99f26b66d458 100644 --- a/datafusion/sqllogictest/test_files/sort_pushdown.slt +++ b/datafusion/sqllogictest/test_files/sort_pushdown.slt @@ 
-851,7 +851,749 @@ LIMIT 3; 5 4 2 -3 +# Test 3.7: Aggregate ORDER BY expression should keep SortExec +# Source pattern declared on parquet scan: [x ASC, y ASC]. +# Requested pattern in ORDER BY: [x ASC, CAST(y AS BIGINT) % 2 ASC]. +# Example for x=1 input y order 1,2,3 gives bucket order 1,0,1, which does not +# match requested bucket ASC order. SortExec is required above AggregateExec. +statement ok +SET datafusion.execution.target_partitions = 1; + +statement ok +CREATE TABLE agg_expr_data(x INT, y INT, v INT) AS VALUES +(1, 1, 10), +(1, 2, 20), +(1, 3, 30), +(2, 1, 40), +(2, 2, 50), +(2, 3, 60); + +query I +COPY (SELECT * FROM agg_expr_data ORDER BY x, y) +TO 'test_files/scratch/sort_pushdown/agg_expr_sorted.parquet'; +---- +6 + +statement ok +CREATE EXTERNAL TABLE agg_expr_parquet(x INT, y INT, v INT) +STORED AS PARQUET +LOCATION 'test_files/scratch/sort_pushdown/agg_expr_sorted.parquet' +WITH ORDER (x ASC, y ASC); + +query TT +EXPLAIN SELECT + x, + CAST(y AS BIGINT) % 2, + SUM(v) +FROM agg_expr_parquet +GROUP BY x, CAST(y AS BIGINT) % 2 +ORDER BY x, CAST(y AS BIGINT) % 2; +---- +logical_plan +01)Sort: agg_expr_parquet.x ASC NULLS LAST, agg_expr_parquet.y % Int64(2) ASC NULLS LAST +02)--Aggregate: groupBy=[[agg_expr_parquet.x, CAST(agg_expr_parquet.y AS Int64) % Int64(2)]], aggr=[[sum(CAST(agg_expr_parquet.v AS Int64))]] +03)----TableScan: agg_expr_parquet projection=[x, y, v] +physical_plan +01)SortExec: expr=[x@0 ASC NULLS LAST, agg_expr_parquet.y % Int64(2)@1 ASC NULLS LAST], preserve_partitioning=[false] +02)--AggregateExec: mode=Single, gby=[x@0 as x, CAST(y@1 AS Int64) % 2 as agg_expr_parquet.y % Int64(2)], aggr=[sum(agg_expr_parquet.v)], ordering_mode=PartiallySorted([0]) +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/agg_expr_sorted.parquet]]}, projection=[x, y, v], output_ordering=[x@0 ASC NULLS LAST, y@1 ASC NULLS LAST], file_type=parquet + +# Expected output pattern from ORDER 
BY [x, bucket]: +# rows grouped by x, and within each x bucket appears as 0 then 1. +query III +SELECT + x, + CAST(y AS BIGINT) % 2, + SUM(v) +FROM agg_expr_parquet +GROUP BY x, CAST(y AS BIGINT) % 2 +ORDER BY x, CAST(y AS BIGINT) % 2; +---- +1 0 20 +1 1 40 +2 0 50 +2 1 100 + +# Test 3.8: Aggregate ORDER BY monotonic expression can push down (no SortExec) +query TT +EXPLAIN SELECT + x, + CAST(y AS BIGINT), + SUM(v) +FROM agg_expr_parquet +GROUP BY x, CAST(y AS BIGINT) +ORDER BY x, CAST(y AS BIGINT); +---- +logical_plan +01)Sort: agg_expr_parquet.x ASC NULLS LAST, agg_expr_parquet.y ASC NULLS LAST +02)--Aggregate: groupBy=[[agg_expr_parquet.x, CAST(agg_expr_parquet.y AS Int64)]], aggr=[[sum(CAST(agg_expr_parquet.v AS Int64))]] +03)----TableScan: agg_expr_parquet projection=[x, y, v] +physical_plan +01)AggregateExec: mode=Single, gby=[x@0 as x, CAST(y@1 AS Int64) as agg_expr_parquet.y], aggr=[sum(agg_expr_parquet.v)], ordering_mode=Sorted +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/agg_expr_sorted.parquet]]}, projection=[x, y, v], output_ordering=[x@0 ASC NULLS LAST, y@1 ASC NULLS LAST], file_type=parquet + +query III +SELECT + x, + CAST(y AS BIGINT), + SUM(v) +FROM agg_expr_parquet +GROUP BY x, CAST(y AS BIGINT) +ORDER BY x, CAST(y AS BIGINT); +---- +1 1 10 +1 2 20 +1 3 30 +2 1 40 +2 2 50 +2 3 60 + +# Test 3.9: Aggregate ORDER BY aggregate output should keep SortExec +query TT +EXPLAIN SELECT x, SUM(v) +FROM agg_expr_parquet +GROUP BY x +ORDER BY SUM(v); +---- +logical_plan +01)Sort: sum(agg_expr_parquet.v) ASC NULLS LAST +02)--Aggregate: groupBy=[[agg_expr_parquet.x]], aggr=[[sum(CAST(agg_expr_parquet.v AS Int64))]] +03)----TableScan: agg_expr_parquet projection=[x, v] +physical_plan +01)SortExec: expr=[sum(agg_expr_parquet.v)@1 ASC NULLS LAST], preserve_partitioning=[false] +02)--AggregateExec: mode=Single, gby=[x@0 as x], aggr=[sum(agg_expr_parquet.v)], ordering_mode=Sorted 
+03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/agg_expr_sorted.parquet]]}, projection=[x, v], output_ordering=[x@0 ASC NULLS LAST], file_type=parquet + +query II +SELECT x, SUM(v) +FROM agg_expr_parquet +GROUP BY x +ORDER BY SUM(v); +---- +1 60 +2 150 + +# Test 3.10: Aggregate with non-preserved input order should keep SortExec +# v is not part of the order by +query TT +EXPLAIN SELECT v, SUM(y) +FROM agg_expr_parquet +GROUP BY v +ORDER BY v; +---- +logical_plan +01)Sort: agg_expr_parquet.v ASC NULLS LAST +02)--Aggregate: groupBy=[[agg_expr_parquet.v]], aggr=[[sum(CAST(agg_expr_parquet.y AS Int64))]] +03)----TableScan: agg_expr_parquet projection=[y, v] +physical_plan +01)SortExec: expr=[v@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--AggregateExec: mode=Single, gby=[v@1 as v], aggr=[sum(agg_expr_parquet.y)] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/agg_expr_sorted.parquet]]}, projection=[y, v], file_type=parquet + +query II +SELECT v, SUM(y) +FROM agg_expr_parquet +GROUP BY v +ORDER BY v; +---- +10 1 +20 2 +30 3 +40 1 +50 2 +60 3 + +# Test 3.11: Aggregate ORDER BY non-column expression (unsatisfied) keeps SortExec +# (though note in theory DataFusion could figure out that data sorted by x will also be sorted by x+1) +query TT +EXPLAIN SELECT x, SUM(v) +FROM agg_expr_parquet +GROUP BY x +ORDER BY x + 1 DESC; +---- +logical_plan +01)Sort: CAST(agg_expr_parquet.x AS Int64) + Int64(1) DESC NULLS FIRST +02)--Aggregate: groupBy=[[agg_expr_parquet.x]], aggr=[[sum(CAST(agg_expr_parquet.v AS Int64))]] +03)----TableScan: agg_expr_parquet projection=[x, v] +physical_plan +01)SortExec: expr=[CAST(x@0 AS Int64) + 1 DESC], preserve_partitioning=[false] +02)--AggregateExec: mode=Single, gby=[x@0 as x], aggr=[sum(agg_expr_parquet.v)], ordering_mode=Sorted +03)----DataSourceExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/agg_expr_sorted.parquet]]}, projection=[x, v], output_ordering=[x@0 ASC NULLS LAST], file_type=parquet
+
+query II
+SELECT x, SUM(v)
+FROM agg_expr_parquet
+GROUP BY x
+ORDER BY x + 1 DESC;
+----
+2 150
+1 60
+
+# Test 3.12: Aggregate ORDER BY non-column expression (unsatisfied) keeps SortExec
+# (though note in theory DataFusion could figure out that data sorted by x will also be sorted by 2 * x)
+query TT
+EXPLAIN SELECT x, SUM(v)
+FROM agg_expr_parquet
+GROUP BY x
+ORDER BY 2 * x ASC;
+----
+logical_plan
+01)Sort: Int64(2) * CAST(agg_expr_parquet.x AS Int64) ASC NULLS LAST
+02)--Aggregate: groupBy=[[agg_expr_parquet.x]], aggr=[[sum(CAST(agg_expr_parquet.v AS Int64))]]
+03)----TableScan: agg_expr_parquet projection=[x, v]
+physical_plan
+01)SortExec: expr=[2 * CAST(x@0 AS Int64) ASC NULLS LAST], preserve_partitioning=[false]
+02)--AggregateExec: mode=Single, gby=[x@0 as x], aggr=[sum(agg_expr_parquet.v)], ordering_mode=Sorted
+03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/agg_expr_sorted.parquet]]}, projection=[x, v], output_ordering=[x@0 ASC NULLS LAST], file_type=parquet
+
+query II
+SELECT x, SUM(v)
+FROM agg_expr_parquet
+GROUP BY x
+ORDER BY 2 * x ASC;
+----
+1 60
+2 150
+
+# Test 4: Reversed filesystem order with inferred ordering
+# Create 3 parquet files with non-overlapping id ranges, named so filesystem
+# order is OPPOSITE to data order. Each file is internally sorted by id ASC.
+# Force target_partitions=1 so all files end up in one file group, which is
+# where the inter-file ordering bug manifests.
+# Without inter-file validation, the optimizer would incorrectly trust the
+# inferred ordering and remove SortExec.
+ +# Save current target_partitions and set to 1 to force single file group +statement ok +SET datafusion.execution.target_partitions = 1; + +statement ok +CREATE TABLE reversed_high(id INT, value INT) AS VALUES (7, 700), (8, 800), (9, 900); + +statement ok +CREATE TABLE reversed_mid(id INT, value INT) AS VALUES (4, 400), (5, 500), (6, 600); + +statement ok +CREATE TABLE reversed_low(id INT, value INT) AS VALUES (1, 100), (2, 200), (3, 300); + +query I +COPY (SELECT * FROM reversed_high ORDER BY id ASC) +TO 'test_files/scratch/sort_pushdown/reversed/a_high.parquet'; +---- +3 + +query I +COPY (SELECT * FROM reversed_mid ORDER BY id ASC) +TO 'test_files/scratch/sort_pushdown/reversed/b_mid.parquet'; +---- +3 + +query I +COPY (SELECT * FROM reversed_low ORDER BY id ASC) +TO 'test_files/scratch/sort_pushdown/reversed/c_low.parquet'; +---- +3 + +# External table with NO "WITH ORDER" — relies on inferred ordering from parquet metadata +statement ok +CREATE EXTERNAL TABLE reversed_parquet(id INT, value INT) +STORED AS PARQUET +LOCATION 'test_files/scratch/sort_pushdown/reversed/'; + +# Test 4.1: SortExec must be present because files are not in inter-file order +query TT +EXPLAIN SELECT * FROM reversed_parquet ORDER BY id ASC; +---- +logical_plan +01)Sort: reversed_parquet.id ASC NULLS LAST +02)--TableScan: reversed_parquet projection=[id, value] +physical_plan +01)SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/reversed/a_high.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/reversed/b_mid.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/reversed/c_low.parquet]]}, projection=[id, value], file_type=parquet + +# Test 4.2: Results must be correct +query II +SELECT * FROM reversed_parquet ORDER BY id ASC; +---- +1 100 +2 200 +3 300 +4 400 +5 500 +6 600 +7 700 +8 800 +9 900 + 
+# Test 5: Overlapping files with inferred ordering +# Create files with overlapping id ranges + +statement ok +CREATE TABLE overlap_x(id INT, value INT) AS VALUES (1, 100), (3, 300), (5, 500); + +statement ok +CREATE TABLE overlap_y(id INT, value INT) AS VALUES (2, 200), (4, 400), (6, 600); + +query I +COPY (SELECT * FROM overlap_x ORDER BY id ASC) +TO 'test_files/scratch/sort_pushdown/overlap/file_x.parquet'; +---- +3 + +query I +COPY (SELECT * FROM overlap_y ORDER BY id ASC) +TO 'test_files/scratch/sort_pushdown/overlap/file_y.parquet'; +---- +3 + +statement ok +CREATE EXTERNAL TABLE overlap_parquet(id INT, value INT) +STORED AS PARQUET +LOCATION 'test_files/scratch/sort_pushdown/overlap/'; + +# Test 5.1: SortExec must be present because files have overlapping ranges +query TT +EXPLAIN SELECT * FROM overlap_parquet ORDER BY id ASC; +---- +logical_plan +01)Sort: overlap_parquet.id ASC NULLS LAST +02)--TableScan: overlap_parquet projection=[id, value] +physical_plan +01)SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/overlap/file_x.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/overlap/file_y.parquet]]}, projection=[id, value], file_type=parquet + +# Test 5.2: Results must be correct +query II +SELECT * FROM overlap_parquet ORDER BY id ASC; +---- +1 100 +2 200 +3 300 +4 400 +5 500 +6 600 + +# Test 6: WITH ORDER + reversed filesystem order +# Same file setup as Test 4 but explicitly declaring ordering via WITH ORDER. +# Even with WITH ORDER, the optimizer should detect that inter-file order is wrong +# and keep SortExec. 
+ +statement ok +CREATE EXTERNAL TABLE reversed_with_order_parquet(id INT, value INT) +STORED AS PARQUET +LOCATION 'test_files/scratch/sort_pushdown/reversed/' +WITH ORDER (id ASC); + +# Test 6.1: SortExec must be present despite WITH ORDER +query TT +EXPLAIN SELECT * FROM reversed_with_order_parquet ORDER BY id ASC; +---- +logical_plan +01)Sort: reversed_with_order_parquet.id ASC NULLS LAST +02)--TableScan: reversed_with_order_parquet projection=[id, value] +physical_plan +01)SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/reversed/a_high.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/reversed/b_mid.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/reversed/c_low.parquet]]}, projection=[id, value], file_type=parquet + +# Test 6.2: Results must be correct +query II +SELECT * FROM reversed_with_order_parquet ORDER BY id ASC; +---- +1 100 +2 200 +3 300 +4 400 +5 500 +6 600 +7 700 +8 800 +9 900 + +# Test 7: Correctly ordered multi-file single group (positive case) +# Files are in CORRECT inter-file order within a single group. +# The validation should PASS and SortExec should be eliminated. 
+ +statement ok +CREATE TABLE correct_low(id INT, value INT) AS VALUES (1, 100), (2, 200), (3, 300); + +statement ok +CREATE TABLE correct_mid(id INT, value INT) AS VALUES (4, 400), (5, 500), (6, 600); + +statement ok +CREATE TABLE correct_high(id INT, value INT) AS VALUES (7, 700), (8, 800), (9, 900); + +query I +COPY (SELECT * FROM correct_low ORDER BY id ASC) +TO 'test_files/scratch/sort_pushdown/correct/a_low.parquet'; +---- +3 + +query I +COPY (SELECT * FROM correct_mid ORDER BY id ASC) +TO 'test_files/scratch/sort_pushdown/correct/b_mid.parquet'; +---- +3 + +query I +COPY (SELECT * FROM correct_high ORDER BY id ASC) +TO 'test_files/scratch/sort_pushdown/correct/c_high.parquet'; +---- +3 + +statement ok +CREATE EXTERNAL TABLE correct_parquet(id INT, value INT) +STORED AS PARQUET +LOCATION 'test_files/scratch/sort_pushdown/correct/' +WITH ORDER (id ASC); + +# Test 7.1: SortExec should be ELIMINATED — files are in correct inter-file order +query TT +EXPLAIN SELECT * FROM correct_parquet ORDER BY id ASC; +---- +logical_plan +01)Sort: correct_parquet.id ASC NULLS LAST +02)--TableScan: correct_parquet projection=[id, value] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/correct/a_low.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/correct/b_mid.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/correct/c_high.parquet]]}, projection=[id, value], output_ordering=[id@0 ASC NULLS LAST], file_type=parquet + +# Test 7.2: Results must be correct +query II +SELECT * FROM correct_parquet ORDER BY id ASC; +---- +1 100 +2 200 +3 300 +4 400 +5 500 +6 600 +7 700 +8 800 +9 900 + +# Test 7.3: DESC query on correctly ordered ASC files should still use SortExec +# Note: reverse_row_groups=true reverses the file list in the plan display +query TT +EXPLAIN SELECT * FROM correct_parquet ORDER BY id DESC; +---- +logical_plan +01)Sort: 
correct_parquet.id DESC NULLS FIRST +02)--TableScan: correct_parquet projection=[id, value] +physical_plan +01)SortExec: expr=[id@0 DESC], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/correct/c_high.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/correct/b_mid.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/correct/a_low.parquet]]}, projection=[id, value], file_type=parquet, reverse_row_groups=true + +query II +SELECT * FROM correct_parquet ORDER BY id DESC; +---- +9 900 +8 800 +7 700 +6 600 +5 500 +4 400 +3 300 +2 200 +1 100 + +# Test 8: DESC ordering with files in wrong inter-file DESC order +# Create files internally sorted by id DESC, but named so filesystem order +# is WRONG for DESC ordering (low values first in filesystem order). + +statement ok +CREATE TABLE desc_low(id INT, value INT) AS VALUES (3, 300), (2, 200), (1, 100); + +statement ok +CREATE TABLE desc_high(id INT, value INT) AS VALUES (9, 900), (8, 800), (7, 700); + +query I +COPY (SELECT * FROM desc_low ORDER BY id DESC) +TO 'test_files/scratch/sort_pushdown/desc_reversed/a_low.parquet'; +---- +3 + +query I +COPY (SELECT * FROM desc_high ORDER BY id DESC) +TO 'test_files/scratch/sort_pushdown/desc_reversed/b_high.parquet'; +---- +3 + +statement ok +CREATE EXTERNAL TABLE desc_reversed_parquet(id INT, value INT) +STORED AS PARQUET +LOCATION 'test_files/scratch/sort_pushdown/desc_reversed/' +WITH ORDER (id DESC); + +# Test 8.1: SortExec must be present — files are in wrong inter-file DESC order +# (a_low has 1-3, b_high has 7-9; for DESC, b_high should come first) +query TT +EXPLAIN SELECT * FROM desc_reversed_parquet ORDER BY id DESC; +---- +logical_plan +01)Sort: desc_reversed_parquet.id DESC NULLS FIRST +02)--TableScan: desc_reversed_parquet projection=[id, value] +physical_plan +01)SortExec: expr=[id@0 DESC], 
preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/desc_reversed/a_low.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/desc_reversed/b_high.parquet]]}, projection=[id, value], file_type=parquet + +# Test 8.2: Results must be correct +query II +SELECT * FROM desc_reversed_parquet ORDER BY id DESC; +---- +9 900 +8 800 +7 700 +3 300 +2 200 +1 100 + +# Test 9: Multi-column sort key validation +# Files have (category, id) ordering. Files share a boundary value on category='B' +# so column-level min/max statistics overlap on the primary key column. +# The validation conservatively rejects this because column-level stats can't +# precisely represent row-level boundaries for multi-column keys. + +statement ok +CREATE TABLE multi_col_a(category VARCHAR, id INT, value INT) AS VALUES +('A', 1, 10), ('A', 2, 20), ('B', 1, 30); + +statement ok +CREATE TABLE multi_col_b(category VARCHAR, id INT, value INT) AS VALUES +('B', 2, 40), ('C', 1, 50), ('C', 2, 60); + +query I +COPY (SELECT * FROM multi_col_a ORDER BY category ASC, id ASC) +TO 'test_files/scratch/sort_pushdown/multi_col/a_first.parquet'; +---- +3 + +query I +COPY (SELECT * FROM multi_col_b ORDER BY category ASC, id ASC) +TO 'test_files/scratch/sort_pushdown/multi_col/b_second.parquet'; +---- +3 + +statement ok +CREATE EXTERNAL TABLE multi_col_parquet(category VARCHAR, id INT, value INT) +STORED AS PARQUET +LOCATION 'test_files/scratch/sort_pushdown/multi_col/' +WITH ORDER (category ASC, id ASC); + +# Test 9.1: SortExec is present — validation conservatively rejects because +# column-level stats overlap on category='B' across both files +query TT +EXPLAIN SELECT * FROM multi_col_parquet ORDER BY category ASC, id ASC; +---- +logical_plan +01)Sort: multi_col_parquet.category ASC NULLS LAST, multi_col_parquet.id ASC NULLS LAST +02)--TableScan: multi_col_parquet projection=[category, id, value] 
+physical_plan +01)SortExec: expr=[category@0 ASC NULLS LAST, id@1 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/multi_col/a_first.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/multi_col/b_second.parquet]]}, projection=[category, id, value], file_type=parquet + +# Test 9.2: Results must be correct +query TII +SELECT * FROM multi_col_parquet ORDER BY category ASC, id ASC; +---- +A 1 10 +A 2 20 +B 1 30 +B 2 40 +C 1 50 +C 2 60 + +# Test 9.3: Multi-column sort with non-overlapping primary key across files +# When files don't overlap on the primary column, validation succeeds. + +statement ok +CREATE TABLE multi_col_x(category VARCHAR, id INT, value INT) AS VALUES +('A', 1, 10), ('A', 2, 20); + +statement ok +CREATE TABLE multi_col_y(category VARCHAR, id INT, value INT) AS VALUES +('B', 1, 30), ('B', 2, 40); + +query I +COPY (SELECT * FROM multi_col_x ORDER BY category ASC, id ASC) +TO 'test_files/scratch/sort_pushdown/multi_col_clean/x_first.parquet'; +---- +2 + +query I +COPY (SELECT * FROM multi_col_y ORDER BY category ASC, id ASC) +TO 'test_files/scratch/sort_pushdown/multi_col_clean/y_second.parquet'; +---- +2 + +statement ok +CREATE EXTERNAL TABLE multi_col_clean_parquet(category VARCHAR, id INT, value INT) +STORED AS PARQUET +LOCATION 'test_files/scratch/sort_pushdown/multi_col_clean/' +WITH ORDER (category ASC, id ASC); + +# Test 9.3a: SortExec should be eliminated — non-overlapping primary column +query TT +EXPLAIN SELECT * FROM multi_col_clean_parquet ORDER BY category ASC, id ASC; +---- +logical_plan +01)Sort: multi_col_clean_parquet.category ASC NULLS LAST, multi_col_clean_parquet.id ASC NULLS LAST +02)--TableScan: multi_col_clean_parquet projection=[category, id, value] +physical_plan DataSourceExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/multi_col_clean/x_first.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/multi_col_clean/y_second.parquet]]}, projection=[category, id, value], output_ordering=[category@0 ASC NULLS LAST, id@1 ASC NULLS LAST], file_type=parquet + +# Test 9.3b: Results must be correct +query TII +SELECT * FROM multi_col_clean_parquet ORDER BY category ASC, id ASC; +---- +A 1 10 +A 2 20 +B 1 30 +B 2 40 + +# Test 10: Correctly ordered files WITH ORDER (positive counterpart to Test 6) +# Files in correct_parquet are in correct ASC order — WITH ORDER should pass validation +# and SortExec should be eliminated. + +statement ok +CREATE EXTERNAL TABLE correct_with_order_parquet(id INT, value INT) +STORED AS PARQUET +LOCATION 'test_files/scratch/sort_pushdown/correct/' +WITH ORDER (id ASC); + +# Test 10.1: SortExec should be ELIMINATED — files are in correct order +query TT +EXPLAIN SELECT * FROM correct_with_order_parquet ORDER BY id ASC; +---- +logical_plan +01)Sort: correct_with_order_parquet.id ASC NULLS LAST +02)--TableScan: correct_with_order_parquet projection=[id, value] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/correct/a_low.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/correct/b_mid.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/correct/c_high.parquet]]}, projection=[id, value], output_ordering=[id@0 ASC NULLS LAST], file_type=parquet + +# Test 10.2: Results must be correct +query II +SELECT * FROM correct_with_order_parquet ORDER BY id ASC; +---- +1 100 +2 200 +3 300 +4 400 +5 500 +6 600 +7 700 +8 800 +9 900 + +# Test 11: Multiple file groups (target_partitions > 1) — each group has one file +# When files are spread across separate partitions (one file per group), each +# partition is trivially sorted and 
SortPreservingMergeExec handles the merge. + +# Restore higher target_partitions so files go into separate groups +statement ok +SET datafusion.execution.target_partitions = 4; + +statement ok +CREATE EXTERNAL TABLE multi_partition_parquet(id INT, value INT) +STORED AS PARQUET +LOCATION 'test_files/scratch/sort_pushdown/reversed/' +WITH ORDER (id ASC); + +# Test 11.1: With separate partitions, each file is trivially sorted. +# SortPreservingMergeExec merges, no SortExec needed per-partition. +query TT +EXPLAIN SELECT * FROM multi_partition_parquet ORDER BY id ASC; +---- +logical_plan +01)Sort: multi_partition_parquet.id ASC NULLS LAST +02)--TableScan: multi_partition_parquet projection=[id, value] +physical_plan +01)SortPreservingMergeExec: [id@0 ASC NULLS LAST] +02)--DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/reversed/a_high.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/reversed/b_mid.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/reversed/c_low.parquet]]}, projection=[id, value], output_ordering=[id@0 ASC NULLS LAST], file_type=parquet + +# Test 11.2: Results must be correct +query II +SELECT * FROM multi_partition_parquet ORDER BY id ASC; +---- +1 100 +2 200 +3 300 +4 400 +5 500 +6 600 +7 700 +8 800 +9 900 + +# Restore target_partitions to 2 for remaining cleanup +statement ok +SET datafusion.execution.target_partitions = 2; + # Cleanup +statement ok +DROP TABLE reversed_high; + +statement ok +DROP TABLE reversed_mid; + +statement ok +DROP TABLE reversed_low; + +statement ok +DROP TABLE reversed_parquet; + +statement ok +DROP TABLE overlap_x; + +statement ok +DROP TABLE overlap_y; + +statement ok +DROP TABLE overlap_parquet; + +statement ok +DROP TABLE reversed_with_order_parquet; + +statement ok +DROP TABLE correct_low; + +statement ok +DROP TABLE correct_mid; + +statement ok +DROP TABLE correct_high; + +statement ok 
+DROP TABLE correct_parquet; + +statement ok +DROP TABLE desc_low; + +statement ok +DROP TABLE desc_high; + +statement ok +DROP TABLE desc_reversed_parquet; + +statement ok +DROP TABLE multi_col_a; + +statement ok +DROP TABLE multi_col_b; + +statement ok +DROP TABLE multi_col_parquet; + +statement ok +DROP TABLE multi_col_x; + +statement ok +DROP TABLE multi_col_y; + +statement ok +DROP TABLE multi_col_clean_parquet; + +statement ok +DROP TABLE correct_with_order_parquet; + +statement ok +DROP TABLE multi_partition_parquet; + statement ok DROP TABLE timestamp_data; @@ -882,5 +1624,11 @@ DROP TABLE signed_data; statement ok DROP TABLE signed_parquet; +statement ok +DROP TABLE agg_expr_data; + +statement ok +DROP TABLE agg_expr_parquet; + statement ok SET datafusion.optimizer.enable_sort_pushdown = true; diff --git a/datafusion/sqllogictest/test_files/spark/README.md b/datafusion/sqllogictest/test_files/spark/README.md index cffd28009889..e61001c6e42e 100644 --- a/datafusion/sqllogictest/test_files/spark/README.md +++ b/datafusion/sqllogictest/test_files/spark/README.md @@ -39,6 +39,18 @@ When testing Spark functions: - Test cases should only contain `SELECT` statements with the function being tested - Add explicit casts to input values to ensure the correct data type is used (e.g., `0::INT`) - Explicit casting is necessary because DataFusion and Spark do not infer data types in the same way +- If the Spark built-in function under test behaves differently in ANSI SQL mode, please wrap your test cases like this example: + +```sql +statement ok +set datafusion.execution.enable_ansi_mode = true; + +# Functions under test +select abs((-128)::TINYINT) + +statement ok +set datafusion.execution.enable_ansi_mode = false; +``` ### Finding Test Cases diff --git a/datafusion/sqllogictest/test_files/spark/aggregate/collect.slt b/datafusion/sqllogictest/test_files/spark/aggregate/collect.slt new file mode 100644 index 000000000000..2bd80e2e1328 --- /dev/null +++ 
b/datafusion/sqllogictest/test_files/spark/aggregate/collect.slt @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +query ? +SELECT collect_list(a) FROM (VALUES (1), (2), (3)) AS t(a); +---- +[1, 2, 3] + +query ? +SELECT collect_list(a) FROM (VALUES (1), (2), (2), (3), (1)) AS t(a); +---- +[1, 2, 2, 3, 1] + +query ? +SELECT collect_list(a) FROM (VALUES (1), (NULL), (3)) AS t(a); +---- +[1, 3] + +query ? +SELECT collect_list(a) FROM (VALUES (CAST(NULL AS INT)), (NULL), (NULL)) AS t(a); +---- +[] + +query I? +SELECT g, collect_list(a) +FROM (VALUES (1, 10), (1, 20), (2, 30), (2, 30), (1, 10)) AS t(g, a) +GROUP BY g +ORDER BY g; +---- +1 [10, 20, 10] +2 [30, 30] + +query I? +SELECT g, collect_list(a) +FROM (VALUES (1, 10), (1, NULL), (2, 20), (2, NULL)) AS t(g, a) +GROUP BY g +ORDER BY g; +---- +1 [10] +2 [20] + +# we need to wrap collect_set with array_sort to have consistent outputs +query ? +SELECT array_sort(collect_set(a)) FROM (VALUES (1), (2), (3)) AS t(a); +---- +[1, 2, 3] + +query ? +SELECT array_sort(collect_set(a)) FROM (VALUES (1), (2), (2), (3), (1)) AS t(a); +---- +[1, 2, 3] + +query ? +SELECT array_sort(collect_set(a)) FROM (VALUES (1), (NULL), (3)) AS t(a); +---- +[1, 3] + +query ? 
+SELECT array_sort(collect_set(a)) FROM (VALUES (CAST(NULL AS INT)), (NULL), (NULL)) AS t(a); +---- +[] + +query I? +SELECT g, array_sort(collect_set(a)) +FROM (VALUES (1, 10), (1, 20), (2, 30), (2, 30), (1, 10)) AS t(g, a) +GROUP BY g +ORDER BY g; +---- +1 [10, 20] +2 [30] + +query I? +SELECT g, array_sort(collect_set(a)) +FROM (VALUES (1, 10), (1, NULL), (1, NULL), (2, 20), (2, NULL)) AS t(g, a) +GROUP BY g +ORDER BY g; +---- +1 [10] +2 [20] diff --git a/datafusion/sqllogictest/test_files/spark/array/array_contains.slt b/datafusion/sqllogictest/test_files/spark/array/array_contains.slt new file mode 100644 index 000000000000..db9ac6b122e3 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/array/array_contains.slt @@ -0,0 +1,140 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Tests for Spark-compatible array_contains function. +# Spark semantics: if element is found -> true; if not found and array has nulls -> null; if not found and no nulls -> false. 
+ +### +### Scalar tests +### + +# Element found in array +query B +SELECT array_contains(array(1, 2, 3), 2); +---- +true + +# Element not found, no nulls in array +query B +SELECT array_contains(array(1, 2, 3), 4); +---- +false + +# Element not found, array has null elements -> null +query B +SELECT array_contains(array(1, NULL, 3), 2); +---- +NULL + +# Element found, array has null elements -> true (nulls don't matter) +query B +SELECT array_contains(array(1, NULL, 3), 1); +---- +true + +# Element found at the end, array has null elements -> true +query B +SELECT array_contains(array(1, NULL, 3), 3); +---- +true + +# Null array -> null +query B +SELECT array_contains(NULL, 1); +---- +NULL + +# Null element -> null +query B +SELECT array_contains(array(1, 2, 3), NULL); +---- +NULL + +# Empty array, element not found -> false +query B +SELECT array_contains(array(), 1); +---- +false + +# Array with only nulls, element not found -> null +query B +SELECT array_contains(array(NULL, NULL), 1); +---- +NULL + +# String array, element found +query B +SELECT array_contains(array('a', 'b', 'c'), 'b'); +---- +true + +# String array, element not found, no nulls +query B +SELECT array_contains(array('a', 'b', 'c'), 'd'); +---- +false + +# String array, element not found, has null +query B +SELECT array_contains(array('a', NULL, 'c'), 'd'); +---- +NULL + +### +### Columnar tests with a table +### + +statement ok +CREATE TABLE test_arrays AS VALUES + (1, make_array(1, 2, 3), 10), + (2, make_array(4, NULL, 6), 5), + (3, make_array(7, 8, 9), 10), + (4, NULL, 1), + (5, make_array(10, NULL, NULL), 10); + +# Column needle against column array +query IBB +SELECT column1, + array_contains(column2, column3), + array_contains(column2, 10) +FROM test_arrays +ORDER BY column1; +---- +1 false false +2 NULL NULL +3 false false +4 NULL NULL +5 true true + +statement ok +DROP TABLE test_arrays; + +### +### Nested array tests +### + +# Nested array element found +query B +SELECT 
array_contains(array(array(1, 2), array(3, 4)), array(3, 4)); +---- +true + +# Nested array element not found, no nulls +query B +SELECT array_contains(array(array(1, 2), array(3, 4)), array(5, 6)); +---- +false diff --git a/datafusion/sqllogictest/test_files/spark/array/array_repeat.slt b/datafusion/sqllogictest/test_files/spark/array/array_repeat.slt index 544c39608f33..19181aae0fc5 100644 --- a/datafusion/sqllogictest/test_files/spark/array/array_repeat.slt +++ b/datafusion/sqllogictest/test_files/spark/array/array_repeat.slt @@ -15,13 +15,90 @@ # specific language governing permissions and limitations # under the License. -# This file was originally created by a porting script from: -# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function -# This file is part of the implementation of the datafusion-spark function library. -# For more information, please see: -# https://github.com/apache/datafusion/issues/15914 - -## Original Query: SELECT array_repeat('123', 2); -## PySpark 3.5.5 Result: {'array_repeat(123, 2)': ['123', '123'], 'typeof(array_repeat(123, 2))': 'array', 'typeof(123)': 'string', 'typeof(2)': 'int'} -#query -#SELECT array_repeat('123'::string, 2::int); + +query ? +SELECT array_repeat('123', 2); +---- +[123, 123] + +query ? +SELECT array_repeat('123', 0); +---- +[] + +query ? +SELECT array_repeat('123', -1); +---- +[] + +query ? +SELECT array_repeat('123', CAST('2' AS INT)); +---- +[123, 123] + +query ? +SELECT array_repeat(123, 3); +---- +[123, 123, 123] + +query ? +SELECT array_repeat('2001-09-28T01:00:00'::timestamp, 2); +---- +[2001-09-28T01:00:00, 2001-09-28T01:00:00] + +query ? +SELECT array_repeat(array_repeat('123', CAST('2' AS INT)), CAST('3' AS INT)); +---- +[[123, 123], [123, 123], [123, 123]] + +query ? +SELECT array_repeat(['123'], 2); +---- +[[123], [123]] + +query ? +SELECT array_repeat(NULL, 2); +---- +NULL + +query ? 
+SELECT array_repeat([NULL], 2); +---- +[[NULL], [NULL]] + +query ? +SELECT array_repeat(['123', NULL], 2); +---- +[[123, NULL], [123, NULL]] + +query ? +SELECT array_repeat('123', CAST(NULL AS INT)); +---- +NULL + +query ? +SELECT array_repeat(column1, column2) +FROM VALUES +('123', 2), +('123', 0), +('123', -1), +(NULL, 1), +('123', NULL); +---- +[123, 123] +[] +[] +NULL +NULL + + +query ? +SELECT array_repeat(column1, column2) +FROM VALUES +(['123'], 2), +([], 2), +([NULL], 2); +---- +[[123], [123]] +[[], []] +[[NULL], [NULL]] diff --git a/datafusion/sqllogictest/test_files/spark/array/shuffle.slt b/datafusion/sqllogictest/test_files/spark/array/shuffle.slt index 35aad58144c9..01d319b619da 100644 --- a/datafusion/sqllogictest/test_files/spark/array/shuffle.slt +++ b/datafusion/sqllogictest/test_files/spark/array/shuffle.slt @@ -87,6 +87,36 @@ SELECT shuffle(column1, 1) FROM test_shuffle_fixed_size; [9, NULL, 8] NULL +query ? +SELECT shuffle(['2001-09-28T01:00:00'::timestamp, '2001-08-28T01:00:00'::timestamp, '2001-07-28T01:00:00'::timestamp, '2001-06-28T01:00:00'::timestamp, '2001-05-28T01:00:00'::timestamp], 1); +---- +[2001-09-28T01:00:00, 2001-06-28T01:00:00, 2001-07-28T01:00:00, 2001-08-28T01:00:00, 2001-05-28T01:00:00] + +query ? +SELECT shuffle(shuffle([1, 20, NULL, 3, 100, NULL, 98, 99], 1), 1); +---- +[1, 99, NULL, 98, 100, NULL, 3, 20] + +query ? +SELECT shuffle([' ', NULL, 'abc'], 1); +---- +[ , NULL, abc] + +query ? +SELECT shuffle([1, 2, 3, 4], CAST('2' AS INT)); +---- +[1, 4, 2, 3] + +query ? +SELECT shuffle(['ab'], NULL); +---- +[ab] + +query ? 
+SELECT shuffle(shuffle([3, 3], NULL), NULL); +---- +[3, 3] + # Clean up statement ok DROP TABLE test_shuffle_list_types; diff --git a/datafusion/sqllogictest/test_files/spark/array/slice.slt b/datafusion/sqllogictest/test_files/spark/array/slice.slt new file mode 100644 index 000000000000..4aba076aba6b --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/array/slice.slt @@ -0,0 +1,116 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +query ? +SELECT slice([], 2, 2); +---- +[] + +query ? +SELECT slice([1, 2, 3, 4], 2, 2); +---- +[2, 3] + +query ? +SELECT slice([1, 2, 3, 4], 1, 100); +---- +[1, 2, 3, 4] + +query ? +SELECT slice([1, 2, 3, 4], -2, 2); +---- +[3, 4] + +query ? +SELECT slice([1, 2, 3, 4], 100, 2); +---- +[] + +query ? +SELECT slice([1, 2, 3, 4], -200, 2); +---- +[] + +query error DataFusion error: Execution error: Length must be non-negative, but got -2 +SELECT slice([1, 2, 3, 4], 2, -2); + +query error DataFusion error: Execution error: Length must be non-negative, but got -2 +SELECT slice([1, 2, 3, 4], -2, -2); + +query error DataFusion error: Execution error: Start index must not be zero +SELECT slice([1, 2, 3, 4], 0, -2); + +query ? 
+SELECT slice([NULL, NULL, NULL, NULL, NULL], 2, 2); +---- +[NULL, NULL] + +query ? +SELECT slice(arrow_cast(NULL, 'FixedSizeList(1, Int64)'), 2, 2); +---- +NULL + +query ? +SELECT slice([1, 2, 3, 4], NULL, 2); +---- +NULL + +query ? +SELECT slice([1, 2, 3, 4], 2, NULL); +---- +NULL + + +query ? +SELECT slice(column1, column2, column3) +FROM VALUES +([1, 2, 3, 4], 2, 2), +([1, 2, 3, 4], 1, 100), +([1, 2, 3, 4], -2, 2), +([], 2, 2), +([1, 2, 3, 4], 100, 2), +([1, 2, 3, 4], -200, 2), +([NULL, NULL, NULL, NULL, NULL], 2, 2), +(arrow_cast(NULL, 'FixedSizeList(1, Int64)'), 2, 2), +([1, 2, 3, 4], NULL, 2), +([1, 2, 3, 4], 2, NULL); +---- +[2, 3] +[1, 2, 3, 4] +[3, 4] +[] +[] +[] +[NULL, NULL] +NULL +NULL +NULL + +query ? +SELECT slice(['2001-09-28T01:00:00'::timestamp, '2001-08-28T01:00:00'::timestamp, '2001-07-28T01:00:00'::timestamp, '2001-06-28T01:00:00'::timestamp, '2001-05-28T01:00:00'::timestamp], 1, 3); +---- +[2001-09-28T01:00:00, 2001-08-28T01:00:00, 2001-07-28T01:00:00] + +query ? +SELECT slice(slice([1, 2, 3, 4], 1, 3), 1, 2); +---- +[1, 2] + +query ? +SELECT slice([1, 2, 3, 4], CAST('2' AS INT), 4); +---- +[2, 3, 4] diff --git a/datafusion/sqllogictest/test_files/spark/bitmap/bitmap_bit_position.slt b/datafusion/sqllogictest/test_files/spark/bitmap/bitmap_bit_position.slt new file mode 100644 index 000000000000..4af3193a5db3 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/bitmap/bitmap_bit_position.slt @@ -0,0 +1,112 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +query I +SELECT bitmap_bit_position(arrow_cast(1, 'Int8')); +---- +0 + +query I +SELECT bitmap_bit_position(arrow_cast(3, 'Int8')); +---- +2 + +query I +SELECT bitmap_bit_position(arrow_cast(7, 'Int8')); +---- +6 + +query I +SELECT bitmap_bit_position(arrow_cast(15, 'Int8')); +---- +14 + +query I +SELECT bitmap_bit_position(arrow_cast(-1, 'Int8')); +---- +1 + +query I +SELECT bitmap_bit_position(arrow_cast(256, 'Int16')); +---- +255 + +query I +SELECT bitmap_bit_position(arrow_cast(1024, 'Int16')); +---- +1023 + +query I +SELECT bitmap_bit_position(arrow_cast(-32768, 'Int16')); +---- +0 + +query I +SELECT bitmap_bit_position(arrow_cast(16384, 'Int16')); +---- +16383 + +query I +SELECT bitmap_bit_position(arrow_cast(-1, 'Int16')); +---- +1 + +query I +SELECT bitmap_bit_position(arrow_cast(65536, 'Int32')); +---- +32767 + +query I +SELECT bitmap_bit_position(arrow_cast(1048576, 'Int32')); +---- +32767 + +query I +SELECT bitmap_bit_position(arrow_cast(-2147483648, 'Int32')); +---- +0 + +query I +SELECT bitmap_bit_position(arrow_cast(1073741824, 'Int32')); +---- +32767 + +query I +SELECT bitmap_bit_position(arrow_cast(-1, 'Int32')); +---- +1 + +query I +SELECT bitmap_bit_position(arrow_cast(4294967296, 'Int64')); +---- +32767 + +query I +SELECT bitmap_bit_position(arrow_cast(-1, 'Int64')); +---- +1 + +query I +SELECT bitmap_bit_position(arrow_cast(-9223372036854775808, 'Int64')); +---- +0 + +query I +SELECT bitmap_bit_position(arrow_cast(9223372036854775807, 'Int64')); +---- +32766 diff --git 
a/datafusion/sqllogictest/test_files/spark/bitmap/bitmap_bucket_number.slt b/datafusion/sqllogictest/test_files/spark/bitmap/bitmap_bucket_number.slt new file mode 100644 index 000000000000..2a6e190b31ea --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/bitmap/bitmap_bucket_number.slt @@ -0,0 +1,122 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ + +query I +SELECT bitmap_bucket_number(arrow_cast(1, 'Int8')); +---- +1 + +query I +SELECT bitmap_bucket_number(arrow_cast(127, 'Int8')); +---- +1 + +query I +SELECT bitmap_bucket_number(arrow_cast(-1, 'Int8')); +---- +0 + +query I +SELECT bitmap_bucket_number(arrow_cast(-64, 'Int8')); +---- +0 + +query I +SELECT bitmap_bucket_number(arrow_cast(-65, 'Int8')); +---- +0 + +query I +SELECT bitmap_bucket_number(arrow_cast(1, 'Int16')); +---- +1 + +query I +SELECT bitmap_bucket_number(arrow_cast(257, 'Int16')); +---- +1 + +query I +SELECT bitmap_bucket_number(arrow_cast(32767, 'Int16')); +---- +1 + +query I +SELECT bitmap_bucket_number(arrow_cast(-1, 'Int16')); +---- +0 + +query I +SELECT bitmap_bucket_number(arrow_cast(-256, 'Int16')); +---- +0 + +query I +SELECT bitmap_bucket_number(arrow_cast(1, 'Int32')); +---- +1 + +query I +SELECT bitmap_bucket_number(arrow_cast(65537, 'Int32')); +---- +3 + +query I +SELECT bitmap_bucket_number(arrow_cast(2147483647, 'Int32')); +---- +65536 + +query I +SELECT bitmap_bucket_number(arrow_cast(-1, 'Int32')); +---- +0 + +query I +SELECT bitmap_bucket_number(arrow_cast(-65536, 'Int32')); +---- +-2 + +query I +SELECT bitmap_bucket_number(arrow_cast(1, 'Int64')); +---- +1 + +query I +SELECT bitmap_bucket_number(arrow_cast(4294967297, 'Int64')); +---- +131073 + +query I +SELECT bitmap_bucket_number(arrow_cast(9223372036854775807, 'Int64')); +---- +281474976710656 + +query I +SELECT bitmap_bucket_number(arrow_cast(-1, 'Int64')); +---- +0 + +query I +SELECT bitmap_bucket_number(arrow_cast(-4294967296, 'Int64')); +---- +-131072 + +query I +SELECT bitmap_bucket_number(arrow_cast(-9223372036854775808, 'Int64')); +---- +-281474976710656 diff --git a/datafusion/sqllogictest/test_files/spark/collection/size.slt b/datafusion/sqllogictest/test_files/spark/collection/size.slt new file mode 100644 index 000000000000..106760eebfe4 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/collection/size.slt @@ -0,0 +1,131 @@ +# Licensed to the 
Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT size(array(1, 2, 3)); +## PySpark 3.5.5 Result: {'size(array(1, 2, 3))': 3} + +# Basic array +query I +SELECT size(make_array(1, 2, 3)); +---- +3 + +# Nested array +query I +SELECT size(make_array(make_array(1, 2), make_array(3, 4, 5))); +---- +2 + +# LargeList tests +query I +SELECT size(arrow_cast(make_array(1, 2, 3), 'LargeList(Int32)')); +---- +3 + +query I +SELECT size(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)')); +---- +5 + +# FixedSizeList tests +query I +SELECT size(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int32)')); +---- +3 + +query I +SELECT size(arrow_cast(make_array(1, 2, 3, 4), 'FixedSizeList(4, Int32)')); +---- +4 + +# Map size tests +query I +SELECT size(map(make_array('a', 'b', 'c'), make_array(1, 2, 3))); +---- +3 + +query I +SELECT size(map(make_array('a'), make_array(1))); +---- +1 + +# Empty array +query I +SELECT size(arrow_cast(make_array(), 'List(Int32)')); +---- +0 + + +# Array with NULL elements (size counts elements including NULLs) +query I +SELECT size(make_array(1, NULL, 3)); +---- +3 + +# NULL array returns -1 (Spark behavior) +query I +SELECT size(NULL::int[]); +---- +-1 + + +# Empty map +query I +SELECT size(map(arrow_cast(make_array(), 'List(Utf8)'), arrow_cast(make_array(), 'List(Int32)'))); +---- +0 + +# String array +query I +SELECT size(make_array('hello', 'world')); +---- +2 + +# Boolean array +query I +SELECT size(make_array(true, false, true)); +---- +3 + +# Float array +query I +SELECT size(make_array(1.5, 2.5, 3.5, 4.5)); +---- +4 + +# Array column tests (with NULL values) +query I +SELECT size(column1) FROM VALUES ([1]), ([1,2]), ([]), (NULL); +---- +1 +2 +0 +-1 + +# Map column tests (with NULL values) +query I +SELECT size(column1) FROM VALUES (map(['a'], [1])), (map(['a','b'], [1,2])), (NULL); +---- +1 +2 +-1 diff --git 
a/datafusion/sqllogictest/test_files/spark/datetime/add_months.slt b/datafusion/sqllogictest/test_files/spark/datetime/add_months.slt index cae9b21dd476..55a493ffefe2 100644 --- a/datafusion/sqllogictest/test_files/spark/datetime/add_months.slt +++ b/datafusion/sqllogictest/test_files/spark/datetime/add_months.slt @@ -15,13 +15,45 @@ # specific language governing permissions and limitations # under the License. -# This file was originally created by a porting script from: -# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function -# This file is part of the implementation of the datafusion-spark function library. -# For more information, please see: -# https://github.com/apache/datafusion/issues/15914 - -## Original Query: SELECT add_months('2016-08-31', 1); -## PySpark 3.5.5 Result: {'add_months(2016-08-31, 1)': datetime.date(2016, 9, 30), 'typeof(add_months(2016-08-31, 1))': 'date', 'typeof(2016-08-31)': 'string', 'typeof(1)': 'int'} -#query -#SELECT add_months('2016-08-31'::string, 1::int); +query D +SELECT add_months('2016-07-30'::date, 1::int); +---- +2016-08-30 + +query D +SELECT add_months('2016-07-30'::date, 0::int); +---- +2016-07-30 + +query D +SELECT add_months('2016-07-30'::date, 10000::int); +---- +2849-11-30 + +# Test integer overflow +# TODO: Enable with next arrow upgrade (>=58.0.0) +# query D +# SELECT add_months('2016-07-30'::date, 2147483647::int); +# ---- +# NULL + +query D +SELECT add_months('2016-07-30'::date, -5::int); +---- +2016-02-29 + +# Test with NULL values +query D +SELECT add_months(NULL::date, 1::int); +---- +NULL + +query D +SELECT add_months('2016-07-30'::date, NULL::int); +---- +NULL + +query D +SELECT add_months(NULL::date, NULL::int); +---- +NULL diff --git a/datafusion/sqllogictest/test_files/spark/datetime/date_add.slt b/datafusion/sqllogictest/test_files/spark/datetime/date_add.slt index a2ac7cf2edb1..cb407a645369 100644 --- 
a/datafusion/sqllogictest/test_files/spark/datetime/date_add.slt +++ b/datafusion/sqllogictest/test_files/spark/datetime/date_add.slt @@ -41,7 +41,7 @@ SELECT date_add('2016-07-30'::date, arrow_cast(1, 'Int8')); 2016-07-31 query D -SELECT date_sub('2016-07-30'::date, 0::int); +SELECT date_add('2016-07-30'::date, 0::int); ---- 2016-07-30 @@ -51,20 +51,15 @@ SELECT date_add('2016-07-30'::date, 2147483647::int)::int; -2147466637 query I -SELECT date_sub('1969-01-01'::date, 2147483647::int)::int; +SELECT date_add('1969-01-01'::date, 2147483647::int)::int; ---- -2147483284 +2147483282 query D SELECT date_add('2016-07-30'::date, 100000::int); ---- 2290-05-15 -query D -SELECT date_sub('2016-07-30'::date, 100000::int); ----- -1742-10-15 - # Test with negative day values (should subtract days) query D SELECT date_add('2016-07-30'::date, -5::int); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/date_diff.slt b/datafusion/sqllogictest/test_files/spark/datetime/date_diff.slt index c5871ab41e18..b0952d6a4351 100644 --- a/datafusion/sqllogictest/test_files/spark/datetime/date_diff.slt +++ b/datafusion/sqllogictest/test_files/spark/datetime/date_diff.slt @@ -15,18 +15,138 @@ # specific language governing permissions and limitations # under the License. -# This file was originally created by a porting script from: -# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function -# This file is part of the implementation of the datafusion-spark function library. 
-# For more information, please see: -# https://github.com/apache/datafusion/issues/15914 - -## Original Query: SELECT date_diff('2009-07-30', '2009-07-31'); -## PySpark 3.5.5 Result: {'date_diff(2009-07-30, 2009-07-31)': -1, 'typeof(date_diff(2009-07-30, 2009-07-31))': 'int', 'typeof(2009-07-30)': 'string', 'typeof(2009-07-31)': 'string'} -#query -#SELECT date_diff('2009-07-30'::string, '2009-07-31'::string); - -## Original Query: SELECT date_diff('2009-07-31', '2009-07-30'); -## PySpark 3.5.5 Result: {'date_diff(2009-07-31, 2009-07-30)': 1, 'typeof(date_diff(2009-07-31, 2009-07-30))': 'int', 'typeof(2009-07-31)': 'string', 'typeof(2009-07-30)': 'string'} -#query -#SELECT date_diff('2009-07-31'::string, '2009-07-30'::string); +# date input +query I +SELECT date_diff('2009-07-30'::date, '2009-07-31'::date); +---- +-1 + +query I +SELECT date_diff('2009-07-31'::date, '2009-07-30'::date); +---- +1 + +query I +SELECT date_diff('2009-07-31'::string, '2009-07-30'::date); +---- +1 + +query I +SELECT date_diff('2009-07-31'::timestamp, '2009-07-30'::date); +---- +1 + +# Date64 input +query I +SELECT date_diff(arrow_cast('2009-07-31', 'Date64'), arrow_cast('2009-07-30', 'Date64')); +---- +1 + +query I +SELECT date_diff(arrow_cast('2009-07-30', 'Date64'), arrow_cast('2009-07-31', 'Date64')); +---- +-1 + +# Mixed Date32 and Date64 input +query I +SELECT date_diff('2009-07-31'::date, arrow_cast('2009-07-30', 'Date64')); +---- +1 + +query I +SELECT date_diff(arrow_cast('2009-07-31', 'Date64'), '2009-07-30'::date); +---- +1 + + +# Same date returns 0 +query I +SELECT date_diff('2009-07-30'::date, '2009-07-30'::date); +---- +0 + +# Large difference +query I +SELECT date_diff('2020-01-01'::date, '1970-01-01'::date); +---- +18262 + +# timestamp input +query I +SELECT date_diff('2009-07-30 12:34:56'::timestamp, '2009-07-31 23:45:01'::timestamp); +---- +-1 + +query I +SELECT date_diff('2009-07-31 23:45:01'::timestamp, '2009-07-30 12:34:56'::timestamp); +---- +1 + +query I +SELECT 
date_diff('2009-07-31 23:45:01'::string, '2009-07-30 12:34:56'::timestamp); +---- +1 + +# string input +query I +SELECT date_diff('2009-07-30', '2009-07-31'); +---- +-1 + +query I +SELECT date_diff('2009-07-31', '2009-07-30'); +---- +1 + +# NULL handling +query I +SELECT date_diff(NULL::date, '2009-07-30'::date); +---- +NULL + +query I +SELECT date_diff('2009-07-31'::date, NULL::date); +---- +NULL + +query I +SELECT date_diff(NULL::date, NULL::date); +---- +NULL + +query I +SELECT date_diff(column1, column2) +FROM VALUES +('2009-07-30'::date, '2009-07-31'::date), +('2009-07-31'::date, '2009-07-30'::date), +(NULL::date, '2009-07-30'::date), +('2009-07-31'::date, NULL::date), +(NULL::date, NULL::date); +---- +-1 +1 +NULL +NULL +NULL + + +# Alias datediff +query I +SELECT datediff('2009-07-30'::date, '2009-07-31'::date); +---- +-1 + +query I +SELECT datediff(column1, column2) +FROM VALUES +('2009-07-30'::date, '2009-07-31'::date), +('2009-07-31'::date, '2009-07-30'::date), +(NULL::date, '2009-07-30'::date), +('2009-07-31'::date, NULL::date), +(NULL::date, NULL::date); +---- +-1 +1 +NULL +NULL +NULL diff --git a/datafusion/sqllogictest/test_files/spark/datetime/date_part.slt b/datafusion/sqllogictest/test_files/spark/datetime/date_part.slt index cd3271cdc7df..48216bd55169 100644 --- a/datafusion/sqllogictest/test_files/spark/datetime/date_part.slt +++ b/datafusion/sqllogictest/test_files/spark/datetime/date_part.slt @@ -15,48 +15,262 @@ # specific language governing permissions and limitations # under the License. -# This file was originally created by a porting script from: -# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function -# This file is part of the implementation of the datafusion-spark function library. 
-# For more information, please see: -# https://github.com/apache/datafusion/issues/15914 - -## Original Query: SELECT date_part('MINUTE', INTERVAL '123 23:55:59.002001' DAY TO SECOND); -## PySpark 3.5.5 Result: {"date_part(MINUTE, INTERVAL '123 23:55:59.002001' DAY TO SECOND)": 55, "typeof(date_part(MINUTE, INTERVAL '123 23:55:59.002001' DAY TO SECOND))": 'tinyint', 'typeof(MINUTE)': 'string', "typeof(INTERVAL '123 23:55:59.002001' DAY TO SECOND)": 'interval day to second'} -#query -#SELECT date_part('MINUTE'::string, INTERVAL '123 23:55:59.002001' DAY TO SECOND::interval day to second); - -## Original Query: SELECT date_part('MONTH', INTERVAL '2021-11' YEAR TO MONTH); -## PySpark 3.5.5 Result: {"date_part(MONTH, INTERVAL '2021-11' YEAR TO MONTH)": 11, "typeof(date_part(MONTH, INTERVAL '2021-11' YEAR TO MONTH))": 'tinyint', 'typeof(MONTH)': 'string', "typeof(INTERVAL '2021-11' YEAR TO MONTH)": 'interval year to month'} -#query -#SELECT date_part('MONTH'::string, INTERVAL '2021-11' YEAR TO MONTH::interval year to month); - -## Original Query: SELECT date_part('SECONDS', timestamp'2019-10-01 00:00:01.000001'); -## PySpark 3.5.5 Result: {"date_part(SECONDS, TIMESTAMP '2019-10-01 00:00:01.000001')": Decimal('1.000001'), "typeof(date_part(SECONDS, TIMESTAMP '2019-10-01 00:00:01.000001'))": 'decimal(8,6)', 'typeof(SECONDS)': 'string', "typeof(TIMESTAMP '2019-10-01 00:00:01.000001')": 'timestamp'} -#query -#SELECT date_part('SECONDS'::string, TIMESTAMP '2019-10-01 00:00:01.000001'::timestamp); - -## Original Query: SELECT date_part('YEAR', TIMESTAMP '2019-08-12 01:00:00.123456'); -## PySpark 3.5.5 Result: {"date_part(YEAR, TIMESTAMP '2019-08-12 01:00:00.123456')": 2019, "typeof(date_part(YEAR, TIMESTAMP '2019-08-12 01:00:00.123456'))": 'int', 'typeof(YEAR)': 'string', "typeof(TIMESTAMP '2019-08-12 01:00:00.123456')": 'timestamp'} -#query -#SELECT date_part('YEAR'::string, TIMESTAMP '2019-08-12 01:00:00.123456'::timestamp); - -## Original Query: SELECT date_part('days', 
interval 5 days 3 hours 7 minutes); -## PySpark 3.5.5 Result: {"date_part(days, INTERVAL '5 03:07' DAY TO MINUTE)": 5, "typeof(date_part(days, INTERVAL '5 03:07' DAY TO MINUTE))": 'int', 'typeof(days)': 'string', "typeof(INTERVAL '5 03:07' DAY TO MINUTE)": 'interval day to minute'} -#query -#SELECT date_part('days'::string, INTERVAL '5 03:07' DAY TO MINUTE::interval day to minute); - -## Original Query: SELECT date_part('doy', DATE'2019-08-12'); -## PySpark 3.5.5 Result: {"date_part(doy, DATE '2019-08-12')": 224, "typeof(date_part(doy, DATE '2019-08-12'))": 'int', 'typeof(doy)': 'string', "typeof(DATE '2019-08-12')": 'date'} -#query -#SELECT date_part('doy'::string, DATE '2019-08-12'::date); - -## Original Query: SELECT date_part('seconds', interval 5 hours 30 seconds 1 milliseconds 1 microseconds); -## PySpark 3.5.5 Result: {"date_part(seconds, INTERVAL '05:00:30.001001' HOUR TO SECOND)": Decimal('30.001001'), "typeof(date_part(seconds, INTERVAL '05:00:30.001001' HOUR TO SECOND))": 'decimal(8,6)', 'typeof(seconds)': 'string', "typeof(INTERVAL '05:00:30.001001' HOUR TO SECOND)": 'interval hour to second'} -#query -#SELECT date_part('seconds'::string, INTERVAL '05:00:30.001001' HOUR TO SECOND::interval hour to second); - -## Original Query: SELECT date_part('week', timestamp'2019-08-12 01:00:00.123456'); -## PySpark 3.5.5 Result: {"date_part(week, TIMESTAMP '2019-08-12 01:00:00.123456')": 33, "typeof(date_part(week, TIMESTAMP '2019-08-12 01:00:00.123456'))": 'int', 'typeof(week)': 'string', "typeof(TIMESTAMP '2019-08-12 01:00:00.123456')": 'timestamp'} -#query -#SELECT date_part('week'::string, TIMESTAMP '2019-08-12 01:00:00.123456'::timestamp); +# YEAR +query I +SELECT date_part('YEAR'::string, '2000-01-01'::date); +---- +2000 + +query I +SELECT date_part('YEARS'::string, '2000-01-01'::date); +---- +2000 + +query I +SELECT date_part('Y'::string, '2000-01-01'::date); +---- +2000 + +query I +SELECT date_part('YR'::string, '2000-01-01'::date); +---- +2000 + +query I 
+SELECT date_part('YRS'::string, '2000-01-01'::date); +---- +2000 + +# YEAROFWEEK +query I +SELECT date_part('YEAROFWEEK'::string, '2000-01-01'::date); +---- +1999 + +# QUARTER +query I +SELECT date_part('QUARTER'::string, '2000-01-01'::date); +---- +1 + +query I +SELECT date_part('QTR'::string, '2000-01-01'::date); +---- +1 + +# MONTH +query I +SELECT date_part('MONTH'::string, '2000-01-01'::date); +---- +1 + +query I +SELECT date_part('MON'::string, '2000-01-01'::date); +---- +1 + +query I +SELECT date_part('MONS'::string, '2000-01-01'::date); +---- +1 + +query I +SELECT date_part('MONTHS'::string, '2000-01-01'::date); +---- +1 + +# WEEK +query I +SELECT date_part('WEEK'::string, '2000-01-01'::date); +---- +52 + +query I +SELECT date_part('WEEKS'::string, '2000-01-01'::date); +---- +52 + +query I +SELECT date_part('W'::string, '2000-01-01'::date); +---- +52 + +# DAYS +query I +SELECT date_part('DAY'::string, '2000-01-01'::date); +---- +1 + +query I +SELECT date_part('D'::string, '2000-01-01'::date); +---- +1 + +query I +SELECT date_part('DAYS'::string, '2000-01-01'::date); +---- +1 + +# DAYOFWEEK +query I +SELECT date_part('DAYOFWEEK'::string, '2000-01-01'::date); +---- +7 + +query I +SELECT date_part('DOW'::string, '2000-01-01'::date); +---- +7 + +# DAYOFWEEK_ISO +query I +SELECT date_part('DAYOFWEEK_ISO'::string, '2000-01-01'::date); +---- +6 + +query I +SELECT date_part('DOW_ISO'::string, '2000-01-01'::date); +---- +6 + +# DOY +query I +SELECT date_part('DOY'::string, '2000-01-01'::date); +---- +1 + +# HOUR +query I +SELECT date_part('HOUR'::string, '2000-01-01 12:30:45'::timestamp); +---- +12 + +query I +SELECT date_part('H'::string, '2000-01-01 12:30:45'::timestamp); +---- +12 + +query I +SELECT date_part('HOURS'::string, '2000-01-01 12:30:45'::timestamp); +---- +12 + +query I +SELECT date_part('HR'::string, '2000-01-01 12:30:45'::timestamp); +---- +12 + +query I +SELECT date_part('HRS'::string, '2000-01-01 12:30:45'::timestamp); +---- +12 + +# MINUTE +query 
I +SELECT date_part('MINUTE'::string, '2000-01-01 12:30:45'::timestamp); +---- +30 + +query I +SELECT date_part('M'::string, '2000-01-01 12:30:45'::timestamp); +---- +30 + +query I +SELECT date_part('MIN'::string, '2000-01-01 12:30:45'::timestamp); +---- +30 + +query I +SELECT date_part('MINS'::string, '2000-01-01 12:30:45'::timestamp); +---- +30 + +query I +SELECT date_part('MINUTES'::string, '2000-01-01 12:30:45'::timestamp); +---- +30 + +# SECOND +query I +SELECT date_part('SECOND'::string, '2000-01-01 12:30:45'::timestamp); +---- +45 + +query I +SELECT date_part('S'::string, '2000-01-01 12:30:45'::timestamp); +---- +45 + +query I +SELECT date_part('SEC'::string, '2000-01-01 12:30:45'::timestamp); +---- +45 + +query I +SELECT date_part('SECONDS'::string, '2000-01-01 12:30:45'::timestamp); +---- +45 + +query I +SELECT date_part('SECS'::string, '2000-01-01 12:30:45'::timestamp); +---- +45 + +# NULL input +query I +SELECT date_part('year'::string, NULL::timestamp); +---- +NULL + +query error Internal error: First argument of `DATE_PART` must be non-null scalar Utf8 +SELECT date_part(NULL::string, '2000-01-01'::date); + +# Invalid part +query error DataFusion error: Execution error: Date part 'test' not supported +SELECT date_part('test'::string, '2000-01-01'::date); + +query I +SELECT date_part('year', column1) +FROM VALUES +('2022-03-15'::date), +('1999-12-31'::date), +('2000-01-01'::date), +(NULL::date); +---- +2022 +1999 +2000 +NULL + +query I +SELECT date_part('minutes', column1) +FROM VALUES +('2022-03-15 12:30:45'::timestamp), +('1999-12-31 12:32:45'::timestamp), +('2000-01-01 12:00:45'::timestamp), +(NULL::timestamp); +---- +30 +32 +0 +NULL + +# alias datepart +query I +SELECT datepart('YEAR'::string, '2000-01-01'::date); +---- +2000 + +query I +SELECT datepart('year', column1) +FROM VALUES +('2022-03-15'::date), +('1999-12-31'::date), +('2000-01-01'::date), +(NULL::date); +---- +2022 +1999 +2000 +NULL diff --git 
a/datafusion/sqllogictest/test_files/spark/datetime/date_sub.slt b/datafusion/sqllogictest/test_files/spark/datetime/date_sub.slt index cb5e77c3b4f1..bf36ebd867d1 100644 --- a/datafusion/sqllogictest/test_files/spark/datetime/date_sub.slt +++ b/datafusion/sqllogictest/test_files/spark/datetime/date_sub.slt @@ -45,6 +45,16 @@ SELECT date_sub('2016-07-30'::date, 0::int); ---- 2016-07-30 +query I +SELECT date_sub('1969-01-01'::date, 2147483647::int)::int; +---- +2147483284 + +query D +SELECT date_sub('2016-07-30'::date, 100000::int); +---- +1742-10-15 + # Test with negative day values (should add days) query D SELECT date_sub('2016-07-30'::date, -1::int); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/date_trunc.slt b/datafusion/sqllogictest/test_files/spark/datetime/date_trunc.slt index 8a15254e6795..7fc1583bb931 100644 --- a/datafusion/sqllogictest/test_files/spark/datetime/date_trunc.slt +++ b/datafusion/sqllogictest/test_files/spark/datetime/date_trunc.slt @@ -15,33 +15,150 @@ # specific language governing permissions and limitations # under the License. -# This file was originally created by a porting script from: -# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function -# This file is part of the implementation of the datafusion-spark function library. 
-# For more information, please see: -# https://github.com/apache/datafusion/issues/15914 - -## Original Query: SELECT date_trunc('DD', '2015-03-05T09:32:05.359'); -## PySpark 3.5.5 Result: {'date_trunc(DD, 2015-03-05T09:32:05.359)': datetime.datetime(2015, 3, 5, 0, 0), 'typeof(date_trunc(DD, 2015-03-05T09:32:05.359))': 'timestamp', 'typeof(DD)': 'string', 'typeof(2015-03-05T09:32:05.359)': 'string'} -#query -#SELECT date_trunc('DD'::string, '2015-03-05T09:32:05.359'::string); - -## Original Query: SELECT date_trunc('HOUR', '2015-03-05T09:32:05.359'); -## PySpark 3.5.5 Result: {'date_trunc(HOUR, 2015-03-05T09:32:05.359)': datetime.datetime(2015, 3, 5, 9, 0), 'typeof(date_trunc(HOUR, 2015-03-05T09:32:05.359))': 'timestamp', 'typeof(HOUR)': 'string', 'typeof(2015-03-05T09:32:05.359)': 'string'} -#query -#SELECT date_trunc('HOUR'::string, '2015-03-05T09:32:05.359'::string); - -## Original Query: SELECT date_trunc('MILLISECOND', '2015-03-05T09:32:05.123456'); -## PySpark 3.5.5 Result: {'date_trunc(MILLISECOND, 2015-03-05T09:32:05.123456)': datetime.datetime(2015, 3, 5, 9, 32, 5, 123000), 'typeof(date_trunc(MILLISECOND, 2015-03-05T09:32:05.123456))': 'timestamp', 'typeof(MILLISECOND)': 'string', 'typeof(2015-03-05T09:32:05.123456)': 'string'} -#query -#SELECT date_trunc('MILLISECOND'::string, '2015-03-05T09:32:05.123456'::string); - -## Original Query: SELECT date_trunc('MM', '2015-03-05T09:32:05.359'); -## PySpark 3.5.5 Result: {'date_trunc(MM, 2015-03-05T09:32:05.359)': datetime.datetime(2015, 3, 1, 0, 0), 'typeof(date_trunc(MM, 2015-03-05T09:32:05.359))': 'timestamp', 'typeof(MM)': 'string', 'typeof(2015-03-05T09:32:05.359)': 'string'} -#query -#SELECT date_trunc('MM'::string, '2015-03-05T09:32:05.359'::string); - -## Original Query: SELECT date_trunc('YEAR', '2015-03-05T09:32:05.359'); -## PySpark 3.5.5 Result: {'date_trunc(YEAR, 2015-03-05T09:32:05.359)': datetime.datetime(2015, 1, 1, 0, 0), 'typeof(date_trunc(YEAR, 2015-03-05T09:32:05.359))': 'timestamp', 
'typeof(YEAR)': 'string', 'typeof(2015-03-05T09:32:05.359)': 'string'} -#query -#SELECT date_trunc('YEAR'::string, '2015-03-05T09:32:05.359'::string); +# YEAR - truncate to first date of year, time zeroed +query P +SELECT date_trunc('YEAR', '2015-03-05T09:32:05.123456'::timestamp); +---- +2015-01-01T00:00:00 + +query P +SELECT date_trunc('YYYY', '2015-03-05T09:32:05.123456'::timestamp); +---- +2015-01-01T00:00:00 + +query P +SELECT date_trunc('YY', '2015-03-05T09:32:05.123456'::timestamp); +---- +2015-01-01T00:00:00 + +# QUARTER - truncate to first date of quarter, time zeroed +query P +SELECT date_trunc('QUARTER', '2015-05-05T09:32:05.123456'::timestamp); +---- +2015-04-01T00:00:00 + +# MONTH - truncate to first date of month, time zeroed +query P +SELECT date_trunc('MONTH', '2015-03-05T09:32:05.123456'::timestamp); +---- +2015-03-01T00:00:00 + +query P +SELECT date_trunc('MM', '2015-03-05T09:32:05.123456'::timestamp); +---- +2015-03-01T00:00:00 + +query P +SELECT date_trunc('MON', '2015-03-05T09:32:05.123456'::timestamp); +---- +2015-03-01T00:00:00 + +# WEEK - truncate to Monday of the week, time zeroed +query P +SELECT date_trunc('WEEK', '2015-03-05T09:32:05.123456'::timestamp); +---- +2015-03-02T00:00:00 + +# DAY - zero out time part +query P +SELECT date_trunc('DAY', '2015-03-05T09:32:05.123456'::timestamp); +---- +2015-03-05T00:00:00 + +query P +SELECT date_trunc('DD', '2015-03-05T09:32:05.123456'::timestamp); +---- +2015-03-05T00:00:00 + +# HOUR - zero out minute and second with fraction +query P +SELECT date_trunc('HOUR', '2015-03-05T09:32:05.123456'::timestamp); +---- +2015-03-05T09:00:00 + +# MINUTE - zero out second with fraction +query P +SELECT date_trunc('MINUTE', '2015-03-05T09:32:05.123456'::timestamp); +---- +2015-03-05T09:32:00 + +# SECOND - zero out fraction +query P +SELECT date_trunc('SECOND', '2015-03-05T09:32:05.123456'::timestamp); +---- +2015-03-05T09:32:05 + +# MILLISECOND - zero out microseconds +query P +SELECT date_trunc('MILLISECOND', 
'2015-03-05T09:32:05.123456'::timestamp); +---- +2015-03-05T09:32:05.123 + +# MICROSECOND - everything remains +query P +SELECT date_trunc('MICROSECOND', '2015-03-05T09:32:05.123456'::timestamp); +---- +2015-03-05T09:32:05.123456 + +query P +SELECT date_trunc('YEAR', column1) +FROM VALUES +('2015-03-05T09:32:05.123456'::timestamp), +('2020-11-15T22:45:30.654321'::timestamp), +('1999-07-20T14:20:10.000001'::timestamp), +(NULL::timestamp); +---- +2015-01-01T00:00:00 +2020-01-01T00:00:00 +1999-01-01T00:00:00 +NULL + +# String input +query P +SELECT date_trunc('YEAR', '2015-03-05T09:32:05.123456'); +---- +2015-01-01T00:00:00 + +# Null handling +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: First argument of `DATE_TRUNC` must be non-null scalar Utf8 +SELECT date_trunc(NULL, '2015-03-05T09:32:05.123456'); + +query P +SELECT date_trunc('YEAR', NULL::timestamp); +---- +NULL + +# incorrect format +query error DataFusion error: Execution error: Unsupported date_trunc granularity: 'test'. 
Supported values are: microsecond, millisecond, second, minute, hour, day, week, month, quarter, year +SELECT date_trunc('test', '2015-03-05T09:32:05.123456'); + +# Timezone handling - Spark-compatible behavior +# Spark converts timestamps to session timezone before truncating for coarse granularities + +query P +SELECT date_trunc('DAY', arrow_cast(timestamp '2024-07-15T03:30:00', 'Timestamp(Microsecond, Some("UTC"))')); +---- +2024-07-15T00:00:00Z + +query P +SELECT date_trunc('DAY', arrow_cast(timestamp '2024-07-15T03:30:00', 'Timestamp(Microsecond, None)')); +---- +2024-07-15T00:00:00 + +statement ok +SET datafusion.execution.time_zone = 'America/New_York'; + +# This timestamp is 03:30 UTC = 23:30 EDT (previous day) on July 14 +# With session timezone, truncation happens in America/New_York timezone +query P +SELECT date_trunc('DAY', arrow_cast(timestamp '2024-07-15T03:30:00', 'Timestamp(Microsecond, Some("UTC"))')); +---- +2024-07-14T00:00:00Z + +query P +SELECT date_trunc('DAY', arrow_cast(timestamp '2024-07-15T03:30:00', 'Timestamp(Microsecond, None)')); +---- +2024-07-15T00:00:00 + +statement ok +RESET datafusion.execution.time_zone; diff --git a/datafusion/sqllogictest/test_files/spark/datetime/datediff.slt b/datafusion/sqllogictest/test_files/spark/datetime/datediff.slt deleted file mode 100644 index 223e2c313ae8..000000000000 --- a/datafusion/sqllogictest/test_files/spark/datetime/datediff.slt +++ /dev/null @@ -1,32 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# This file was originally created by a porting script from: -# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function -# This file is part of the implementation of the datafusion-spark function library. -# For more information, please see: -# https://github.com/apache/datafusion/issues/15914 - -## Original Query: SELECT datediff('2009-07-30', '2009-07-31'); -## PySpark 3.5.5 Result: {'datediff(2009-07-30, 2009-07-31)': -1, 'typeof(datediff(2009-07-30, 2009-07-31))': 'int', 'typeof(2009-07-30)': 'string', 'typeof(2009-07-31)': 'string'} -#query -#SELECT datediff('2009-07-30'::string, '2009-07-31'::string); - -## Original Query: SELECT datediff('2009-07-31', '2009-07-30'); -## PySpark 3.5.5 Result: {'datediff(2009-07-31, 2009-07-30)': 1, 'typeof(datediff(2009-07-31, 2009-07-30))': 'int', 'typeof(2009-07-31)': 'string', 'typeof(2009-07-30)': 'string'} -#query -#SELECT datediff('2009-07-31'::string, '2009-07-30'::string); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/datepart.slt b/datafusion/sqllogictest/test_files/spark/datetime/datepart.slt deleted file mode 100644 index b2dd0089c282..000000000000 --- a/datafusion/sqllogictest/test_files/spark/datetime/datepart.slt +++ /dev/null @@ -1,62 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# This file was originally created by a porting script from: -# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function -# This file is part of the implementation of the datafusion-spark function library. -# For more information, please see: -# https://github.com/apache/datafusion/issues/15914 - -## Original Query: SELECT datepart('MINUTE', INTERVAL '123 23:55:59.002001' DAY TO SECOND); -## PySpark 3.5.5 Result: {"datepart(MINUTE FROM INTERVAL '123 23:55:59.002001' DAY TO SECOND)": 55, "typeof(datepart(MINUTE FROM INTERVAL '123 23:55:59.002001' DAY TO SECOND))": 'tinyint', 'typeof(MINUTE)': 'string', "typeof(INTERVAL '123 23:55:59.002001' DAY TO SECOND)": 'interval day to second'} -#query -#SELECT datepart('MINUTE'::string, INTERVAL '123 23:55:59.002001' DAY TO SECOND::interval day to second); - -## Original Query: SELECT datepart('MONTH', INTERVAL '2021-11' YEAR TO MONTH); -## PySpark 3.5.5 Result: {"datepart(MONTH FROM INTERVAL '2021-11' YEAR TO MONTH)": 11, "typeof(datepart(MONTH FROM INTERVAL '2021-11' YEAR TO MONTH))": 'tinyint', 'typeof(MONTH)': 'string', "typeof(INTERVAL '2021-11' YEAR TO MONTH)": 'interval year to month'} -#query -#SELECT datepart('MONTH'::string, INTERVAL '2021-11' YEAR TO MONTH::interval year to month); - -## Original Query: SELECT datepart('SECONDS', timestamp'2019-10-01 00:00:01.000001'); -## 
PySpark 3.5.5 Result: {"datepart(SECONDS FROM TIMESTAMP '2019-10-01 00:00:01.000001')": Decimal('1.000001'), "typeof(datepart(SECONDS FROM TIMESTAMP '2019-10-01 00:00:01.000001'))": 'decimal(8,6)', 'typeof(SECONDS)': 'string', "typeof(TIMESTAMP '2019-10-01 00:00:01.000001')": 'timestamp'} -#query -#SELECT datepart('SECONDS'::string, TIMESTAMP '2019-10-01 00:00:01.000001'::timestamp); - -## Original Query: SELECT datepart('YEAR', TIMESTAMP '2019-08-12 01:00:00.123456'); -## PySpark 3.5.5 Result: {"datepart(YEAR FROM TIMESTAMP '2019-08-12 01:00:00.123456')": 2019, "typeof(datepart(YEAR FROM TIMESTAMP '2019-08-12 01:00:00.123456'))": 'int', 'typeof(YEAR)': 'string', "typeof(TIMESTAMP '2019-08-12 01:00:00.123456')": 'timestamp'} -#query -#SELECT datepart('YEAR'::string, TIMESTAMP '2019-08-12 01:00:00.123456'::timestamp); - -## Original Query: SELECT datepart('days', interval 5 days 3 hours 7 minutes); -## PySpark 3.5.5 Result: {"datepart(days FROM INTERVAL '5 03:07' DAY TO MINUTE)": 5, "typeof(datepart(days FROM INTERVAL '5 03:07' DAY TO MINUTE))": 'int', 'typeof(days)': 'string', "typeof(INTERVAL '5 03:07' DAY TO MINUTE)": 'interval day to minute'} -#query -#SELECT datepart('days'::string, INTERVAL '5 03:07' DAY TO MINUTE::interval day to minute); - -## Original Query: SELECT datepart('doy', DATE'2019-08-12'); -## PySpark 3.5.5 Result: {"datepart(doy FROM DATE '2019-08-12')": 224, "typeof(datepart(doy FROM DATE '2019-08-12'))": 'int', 'typeof(doy)': 'string', "typeof(DATE '2019-08-12')": 'date'} -#query -#SELECT datepart('doy'::string, DATE '2019-08-12'::date); - -## Original Query: SELECT datepart('seconds', interval 5 hours 30 seconds 1 milliseconds 1 microseconds); -## PySpark 3.5.5 Result: {"datepart(seconds FROM INTERVAL '05:00:30.001001' HOUR TO SECOND)": Decimal('30.001001'), "typeof(datepart(seconds FROM INTERVAL '05:00:30.001001' HOUR TO SECOND))": 'decimal(8,6)', 'typeof(seconds)': 'string', "typeof(INTERVAL '05:00:30.001001' HOUR TO SECOND)": 'interval hour 
to second'} -#query -#SELECT datepart('seconds'::string, INTERVAL '05:00:30.001001' HOUR TO SECOND::interval hour to second); - -## Original Query: SELECT datepart('week', timestamp'2019-08-12 01:00:00.123456'); -## PySpark 3.5.5 Result: {"datepart(week FROM TIMESTAMP '2019-08-12 01:00:00.123456')": 33, "typeof(datepart(week FROM TIMESTAMP '2019-08-12 01:00:00.123456'))": 'int', 'typeof(week)': 'string', "typeof(TIMESTAMP '2019-08-12 01:00:00.123456')": 'timestamp'} -#query -#SELECT datepart('week'::string, TIMESTAMP '2019-08-12 01:00:00.123456'::timestamp); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/from_utc_timestamp.slt b/datafusion/sqllogictest/test_files/spark/datetime/from_utc_timestamp.slt new file mode 100644 index 000000000000..5a39bda0a651 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/from_utc_timestamp.slt @@ -0,0 +1,156 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# String inputs +query P +SELECT from_utc_timestamp('2016-08-31'::string, 'UTC'::string); +---- +2016-08-31T00:00:00 + +query P +SELECT from_utc_timestamp('2016-08-31'::string, 'Asia/Seoul'::string); +---- +2016-08-31T09:00:00 + +query P +SELECT from_utc_timestamp('2016-08-31'::string, 'America/New_York'::string); +---- +2016-08-30T20:00:00 + +# String inputs with offsets +query P +SELECT from_utc_timestamp('2018-03-13T06:18:23+02:00'::string, 'UTC'::string); +---- +2018-03-13T04:18:23 + +query P +SELECT from_utc_timestamp('2018-03-13T06:18:23+02:00'::string, 'Asia/Seoul'::string); +---- +2018-03-13T13:18:23 + +query P +SELECT from_utc_timestamp('2018-03-13T06:18:23+02:00'::string, 'America/New_York'::string); +---- +2018-03-13T00:18:23 + +# Timestamp inputs +query P +SELECT from_utc_timestamp('2018-03-13T06:18:23+02:00'::timestamp, 'UTC'::string); +---- +2018-03-13T04:18:23 + +query P +SELECT from_utc_timestamp('2018-03-13T06:18:23+02:00'::timestamp, 'Asia/Seoul'::string); +---- +2018-03-13T13:18:23 + +query P +SELECT from_utc_timestamp('2018-03-13T06:18:23+02:00'::timestamp, 'America/New_York'::string); +---- +2018-03-13T00:18:23 + +# Null inputs +query P +SELECT from_utc_timestamp(NULL::string, 'Asia/Seoul'::string); +---- +NULL + +query P +SELECT from_utc_timestamp(NULL::timestamp, 'Asia/Seoul'::string); +---- +NULL + +query P +SELECT from_utc_timestamp('2016-08-31'::string, NULL::string); +---- +NULL + +query P +SELECT from_utc_timestamp(column1, column2) +FROM VALUES +('2016-08-31'::string, 'Asia/Seoul'::string), +('2018-03-13T06:18:23+02:00'::string, 'Asia/Seoul'::string), +('2016-08-31'::string, 'UTC'::string), +('2018-03-13T06:18:23+02:00'::string, 'UTC'::string), +('2016-08-31'::string, 'America/New_York'::string), +('2018-03-13T06:18:23+02:00'::string, 'America/New_York'::string), +(NULL::string, 'Asia/Seoul'::string), +('2016-08-31'::string, NULL::string); +---- +2016-08-31T09:00:00 +2018-03-13T13:18:23 +2016-08-31T00:00:00 +2018-03-13T04:18:23 
+2016-08-30T20:00:00 +2018-03-13T00:18:23 +NULL +NULL + +query P +SELECT from_utc_timestamp(column1, column2) +FROM VALUES +('2016-08-31'::timestamp, 'Asia/Seoul'::string), +('2018-03-13T06:18:23+02:00'::timestamp, 'Asia/Seoul'::string), +('2016-08-31'::timestamp, 'UTC'::string), +('2018-03-13T06:18:23+02:00'::timestamp, 'UTC'::string), +('2016-08-31'::timestamp, 'America/New_York'::string), +('2018-03-13T06:18:23+02:00'::timestamp, 'America/New_York'::string), +(NULL::timestamp, 'Asia/Seoul'::string), +('2018-03-13T06:18:23+00:00'::timestamp, NULL::string); +---- +2016-08-31T09:00:00 +2018-03-13T13:18:23 +2016-08-31T00:00:00 +2018-03-13T04:18:23 +2016-08-30T20:00:00 +2018-03-13T00:18:23 +NULL +NULL + +query P +SELECT from_utc_timestamp(arrow_cast(column1, 'Timestamp(Microsecond, Some("Asia/Seoul"))'), column2) +FROM VALUES +('2016-08-31'::timestamp, 'Asia/Seoul'::string), +('2018-03-13T06:18:23+02:00'::timestamp, 'Asia/Seoul'::string), +('2016-08-31'::timestamp, 'UTC'::string), +('2018-03-13T06:18:23+02:00'::timestamp, 'UTC'::string), +('2016-08-31'::timestamp, 'America/New_York'::string), +('2018-03-13T06:18:23+02:00'::timestamp, 'America/New_York'::string), +(NULL::timestamp, 'Asia/Seoul'::string), +('2018-03-13T06:18:23+00:00'::timestamp, NULL::string); +---- +2016-08-31T09:00:00+09:00 +2018-03-13T13:18:23+09:00 +2016-08-31T00:00:00+09:00 +2018-03-13T04:18:23+09:00 +2016-08-30T20:00:00+09:00 +2018-03-13T00:18:23+09:00 +NULL +NULL + + +# DST edge cases +query P +SELECT from_utc_timestamp('2020-03-31T13:40:00'::timestamp, 'America/New_York'::string); +---- +2020-03-31T09:40:00 + + +query P +SELECT from_utc_timestamp('2020-11-04T14:06:40'::timestamp, 'America/New_York'::string); +---- +2020-11-04T09:06:40 diff --git a/datafusion/sqllogictest/test_files/spark/datetime/make_interval.slt b/datafusion/sqllogictest/test_files/spark/datetime/make_interval.slt index d6c5199b87b7..a796094979d9 100644 --- 
a/datafusion/sqllogictest/test_files/spark/datetime/make_interval.slt +++ b/datafusion/sqllogictest/test_files/spark/datetime/make_interval.slt @@ -90,21 +90,15 @@ SELECT make_interval(0, 0, 0, 0, 2147483647, 1, 0.0); ---- NULL -# Intervals being rendered as empty string, see issue: -# https://github.com/apache/datafusion/issues/17455 -# We expect something like 0.00 secs with query ? query T SELECT make_interval(0, 0, 0, 0, 0, 0, 0.0) || ''; ---- -(empty) +0 secs -# Intervals being rendered as empty string, see issue: -# https://github.com/apache/datafusion/issues/17455 -# We expect something like 0.00 secs with query ? query T SELECT make_interval() || ''; ---- -(empty) +0 secs query ? SELECT INTERVAL '1' SECOND AS iv; diff --git a/datafusion/sqllogictest/test_files/spark/datetime/time_trunc.slt b/datafusion/sqllogictest/test_files/spark/datetime/time_trunc.slt new file mode 100644 index 000000000000..35ffa483bb06 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/time_trunc.slt @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# HOUR - zero out minute and second with fraction +query D +SELECT time_trunc('HOUR', '09:32:05.123456'::time); +---- +09:00:00 + +# MINUTE - zero out second with fraction +query D +SELECT time_trunc('MINUTE', '09:32:05.123456'::time); +---- +09:32:00 + +# SECOND - zero out fraction +query D +SELECT time_trunc('SECOND', '09:32:05.123456'::time); +---- +09:32:05 + +# MILLISECOND - zero out microseconds +query D +SELECT time_trunc('MILLISECOND', '09:32:05.123456'::time); +---- +09:32:05.123 + +# MICROSECOND - everything remains +query D +SELECT time_trunc('MICROSECOND', '09:32:05.123456'::time); +---- +09:32:05.123456 + +query D +SELECT time_trunc('HOUR', column1) +FROM VALUES +('09:32:05.123456'::time), +('22:45:30.654321'::time), +('14:20:10.000001'::time), +(NULL::time); +---- +09:00:00 +22:00:00 +14:00:00 +NULL + + +# Null handling +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: First argument of `TIME_TRUNC` must be non-null scalar Utf8 +SELECT time_trunc(NULL, '09:32:05.123456'::time); + +query D +SELECT time_trunc('HOUR', NULL::time); +---- +NULL + +# incorrect format +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: The format argument of `TIME_TRUNC` must be one of: hour, minute, second, millisecond, microsecond +SELECT time_trunc('test', '09:32:05.123456'::time); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/to_utc_timestamp.slt b/datafusion/sqllogictest/test_files/spark/datetime/to_utc_timestamp.slt index 24693016be1a..086716e5bcd0 100644 --- a/datafusion/sqllogictest/test_files/spark/datetime/to_utc_timestamp.slt +++ b/datafusion/sqllogictest/test_files/spark/datetime/to_utc_timestamp.slt @@ -15,13 +15,143 @@ # specific language governing permissions and limitations # under the License. 
-# This file was originally created by a porting script from: -# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function -# This file is part of the implementation of the datafusion-spark function library. -# For more information, please see: -# https://github.com/apache/datafusion/issues/15914 - -## Original Query: SELECT to_utc_timestamp('2016-08-31', 'Asia/Seoul'); -## PySpark 3.5.5 Result: {'to_utc_timestamp(2016-08-31, Asia/Seoul)': datetime.datetime(2016, 8, 30, 15, 0), 'typeof(to_utc_timestamp(2016-08-31, Asia/Seoul))': 'timestamp', 'typeof(2016-08-31)': 'string', 'typeof(Asia/Seoul)': 'string'} -#query -#SELECT to_utc_timestamp('2016-08-31'::string, 'Asia/Seoul'::string); + +# String inputs +query P +SELECT to_utc_timestamp('2016-08-31'::string, 'UTC'::string); +---- +2016-08-31T00:00:00 + +query P +SELECT to_utc_timestamp('2016-08-31'::string, 'Asia/Seoul'::string); +---- +2016-08-30T15:00:00 + +query P +SELECT to_utc_timestamp('2016-08-31'::string, 'America/New_York'::string); +---- +2016-08-31T04:00:00 + +# String inputs with offsets +query P +SELECT to_utc_timestamp('2018-03-13T06:18:23+02:00'::string, 'UTC'::string); +---- +2018-03-13T04:18:23 + +query P +SELECT to_utc_timestamp('2018-03-13T06:18:23+02:00'::string, 'Asia/Seoul'::string); +---- +2018-03-12T19:18:23 + +query P +SELECT to_utc_timestamp('2018-03-13T06:18:23+02:00'::string, 'America/New_York'::string); +---- +2018-03-13T08:18:23 + +# Timestamp inputs +query P +SELECT to_utc_timestamp('2018-03-13T06:18:23+02:00'::timestamp, 'UTC'::string); +---- +2018-03-13T04:18:23 + +query P +SELECT to_utc_timestamp('2018-03-13T06:18:23+02:00'::timestamp, 'Asia/Seoul'::string); +---- +2018-03-12T19:18:23 + +query P +SELECT to_utc_timestamp('2018-03-13T06:18:23+02:00'::timestamp, 'America/New_York'::string); +---- +2018-03-13T08:18:23 + +# Null inputs +query P +SELECT to_utc_timestamp(NULL::string, 'Asia/Seoul'::string); +---- +NULL + 
+query P +SELECT to_utc_timestamp(NULL::timestamp, 'Asia/Seoul'::string); +---- +NULL + +query P +SELECT to_utc_timestamp('2016-08-31'::string, NULL::string); +---- +NULL + +query P +SELECT to_utc_timestamp(column1, column2) +FROM VALUES +('2016-08-31'::string, 'Asia/Seoul'::string), +('2018-03-13T06:18:23+02:00'::string, 'Asia/Seoul'::string), +('2016-08-31'::string, 'UTC'::string), +('2018-03-13T06:18:23+02:00'::string, 'UTC'::string), +('2016-08-31'::string, 'America/New_York'::string), +('2018-03-13T06:18:23+02:00'::string, 'America/New_York'::string), +(NULL::string, 'Asia/Seoul'::string), +('2016-08-31'::string, NULL::string); +---- +2016-08-30T15:00:00 +2018-03-12T19:18:23 +2016-08-31T00:00:00 +2018-03-13T04:18:23 +2016-08-31T04:00:00 +2018-03-13T08:18:23 +NULL +NULL + +query P +SELECT to_utc_timestamp(column1, column2) +FROM VALUES +('2016-08-31'::timestamp, 'Asia/Seoul'::string), +('2018-03-13T06:18:23+02:00'::timestamp, 'Asia/Seoul'::string), +('2016-08-31'::timestamp, 'UTC'::string), +('2018-03-13T06:18:23+02:00'::timestamp, 'UTC'::string), +('2016-08-31'::timestamp, 'America/New_York'::string), +('2018-03-13T06:18:23+02:00'::timestamp, 'America/New_York'::string), +(NULL::timestamp, 'Asia/Seoul'::string), +('2018-03-13T06:18:23+00:00'::timestamp, NULL::string); +---- +2016-08-30T15:00:00 +2018-03-12T19:18:23 +2016-08-31T00:00:00 +2018-03-13T04:18:23 +2016-08-31T04:00:00 +2018-03-13T08:18:23 +NULL +NULL + +query P +SELECT to_utc_timestamp(arrow_cast(column1, 'Timestamp(Microsecond, Some("Asia/Seoul"))'), column2) +FROM VALUES +('2016-08-31'::timestamp, 'Asia/Seoul'::string), +('2018-03-13T06:18:23+02:00'::timestamp, 'Asia/Seoul'::string), +('2016-08-31'::timestamp, 'UTC'::string), +('2018-03-13T06:18:23+02:00'::timestamp, 'UTC'::string), +('2016-08-31'::timestamp, 'America/New_York'::string), +('2018-03-13T06:18:23+02:00'::timestamp, 'America/New_York'::string), +(NULL::timestamp, 'Asia/Seoul'::string), +('2018-03-13T06:18:23+00:00'::timestamp, 
NULL::string); +---- +2016-08-30T15:00:00+09:00 +2018-03-12T19:18:23+09:00 +2016-08-31T00:00:00+09:00 +2018-03-13T04:18:23+09:00 +2016-08-31T04:00:00+09:00 +2018-03-13T08:18:23+09:00 +NULL +NULL + + +# DST edge cases +query P +SELECT to_utc_timestamp('2020-03-31T13:40:00'::timestamp, 'America/New_York'::string); +---- +2020-03-31T17:40:00 + + +query P +SELECT to_utc_timestamp('2020-11-04T14:06:40'::timestamp, 'America/New_York'::string); +---- +2020-11-04T19:06:40 diff --git a/datafusion/sqllogictest/test_files/spark/datetime/trunc.slt b/datafusion/sqllogictest/test_files/spark/datetime/trunc.slt index a502e2f7f7b0..aa26d7bd0ef0 100644 --- a/datafusion/sqllogictest/test_files/spark/datetime/trunc.slt +++ b/datafusion/sqllogictest/test_files/spark/datetime/trunc.slt @@ -15,28 +15,78 @@ # specific language governing permissions and limitations # under the License. -# This file was originally created by a porting script from: -# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function -# This file is part of the implementation of the datafusion-spark function library. 
-# For more information, please see: -# https://github.com/apache/datafusion/issues/15914 - -## Original Query: SELECT trunc('2009-02-12', 'MM'); -## PySpark 3.5.5 Result: {'trunc(2009-02-12, MM)': datetime.date(2009, 2, 1), 'typeof(trunc(2009-02-12, MM))': 'date', 'typeof(2009-02-12)': 'string', 'typeof(MM)': 'string'} -#query -#SELECT trunc('2009-02-12'::string, 'MM'::string); - -## Original Query: SELECT trunc('2015-10-27', 'YEAR'); -## PySpark 3.5.5 Result: {'trunc(2015-10-27, YEAR)': datetime.date(2015, 1, 1), 'typeof(trunc(2015-10-27, YEAR))': 'date', 'typeof(2015-10-27)': 'string', 'typeof(YEAR)': 'string'} -#query -#SELECT trunc('2015-10-27'::string, 'YEAR'::string); - -## Original Query: SELECT trunc('2019-08-04', 'quarter'); -## PySpark 3.5.5 Result: {'trunc(2019-08-04, quarter)': datetime.date(2019, 7, 1), 'typeof(trunc(2019-08-04, quarter))': 'date', 'typeof(2019-08-04)': 'string', 'typeof(quarter)': 'string'} -#query -#SELECT trunc('2019-08-04'::string, 'quarter'::string); - -## Original Query: SELECT trunc('2019-08-04', 'week'); -## PySpark 3.5.5 Result: {'trunc(2019-08-04, week)': datetime.date(2019, 7, 29), 'typeof(trunc(2019-08-04, week))': 'date', 'typeof(2019-08-04)': 'string', 'typeof(week)': 'string'} -#query -#SELECT trunc('2019-08-04'::string, 'week'::string); +# YEAR - truncate to first date of year +query D +SELECT trunc('2009-02-12'::date, 'YEAR'::string); +---- +2009-01-01 + +query D +SELECT trunc('2009-02-12'::date, 'YYYY'::string); +---- +2009-01-01 + +query D +SELECT trunc('2009-02-12'::date, 'YY'::string); +---- +2009-01-01 + +# QUARTER - truncate to first date of quarter +query D +SELECT trunc('2009-02-12'::date, 'QUARTER'::string); +---- +2009-01-01 + +# MONTH - truncate to first date of month +query D +SELECT trunc('2009-02-12'::date, 'MONTH'::string); +---- +2009-02-01 + +query D +SELECT trunc('2009-02-12'::date, 'MM'::string); +---- +2009-02-01 + +query D +SELECT trunc('2009-02-12'::date, 'MON'::string); +---- +2009-02-01 + +# 
WEEK - truncate to Monday of the week +query D +SELECT trunc('2009-02-12'::date, 'WEEK'::string); +---- +2009-02-09 + +# string input +query D +SELECT trunc('2009-02-12'::string, 'YEAR'::string); +---- +2009-01-01 + +query D +SELECT trunc(column1, 'YEAR'::string) +FROM VALUES +('2009-02-12'::date), +('2000-02-12'::date), +('2042-02-12'::date), +(NULL::date); +---- +2009-01-01 +2000-01-01 +2042-01-01 +NULL + +# Null handling +query D +SELECT trunc(NULL::date, 'YEAR'::string); +---- +NULL + +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Second argument of `TRUNC` must be non-null scalar Utf8 +SELECT trunc('2009-02-12'::date, NULL::string); + +# incorrect format +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: The format argument of `TRUNC` must be one of: year, yy, yyyy, month, mm, mon, day, week, quarter. +SELECT trunc('2009-02-12'::date, 'test'::string); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/unix.slt b/datafusion/sqllogictest/test_files/spark/datetime/unix.slt new file mode 100644 index 000000000000..d7441f487d03 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/unix.slt @@ -0,0 +1,134 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +# Unix Date tests + +query I +SELECT unix_date('1970-01-02'::date); +---- +1 + +query I +SELECT unix_date('1900-01-02'::date); +---- +-25566 + + +query I +SELECT unix_date(arrow_cast('1970-01-02', 'Date64')); +---- +1 + +query I +SELECT unix_date(NULL::date); +---- +NULL + +query error Function 'unix_date' requires TypeSignatureClass::Native\(LogicalType\(Native\(Date\), Date\)\), but received String \(DataType: Utf8View\) +SELECT unix_date('1970-01-02'::string); + +# Unix Micro Tests + +query I +SELECT unix_micros('1970-01-01 00:00:01Z'::timestamp); +---- +1000000 + +query I +SELECT unix_micros('1900-01-01 00:00:01Z'::timestamp); +---- +-2208988799000000 + +query I +SELECT unix_micros(arrow_cast('1970-01-01 00:00:01+02:00', 'Timestamp(Microsecond, None)')); +---- +-7199000000 + +query I +SELECT unix_micros(arrow_cast('1970-01-01 00:00:01Z', 'Timestamp(Second, None)')); +---- +1000000 + +query I +SELECT unix_micros(NULL::timestamp); +---- +NULL + +query error Function 'unix_micros' requires TypeSignatureClass::Timestamp, but received String \(DataType: Utf8View\) +SELECT unix_micros('1970-01-01 00:00:01Z'::string); + + +# Unix Millis Tests + +query I +SELECT unix_millis('1970-01-01 00:00:01Z'::timestamp); +---- +1000 + +query I +SELECT unix_millis('1900-01-01 00:00:01Z'::timestamp); +---- +-2208988799000 + +query I +SELECT unix_millis(arrow_cast('1970-01-01 00:00:01+02:00', 'Timestamp(Microsecond, None)')); +---- +-7199000 + +query I +SELECT unix_millis(arrow_cast('1970-01-01 00:00:01Z', 'Timestamp(Second, None)')); +---- +1000 + +query I +SELECT unix_millis(NULL::timestamp); +---- +NULL + +query error Function 'unix_millis' requires TypeSignatureClass::Timestamp, but received String \(DataType: Utf8View\) +SELECT unix_millis('1970-01-01 00:00:01Z'::string); + + +# Unix Seconds Tests + +query I +SELECT unix_seconds('1970-01-01 00:00:01Z'::timestamp); +---- +1 
+ +query I +SELECT unix_seconds('1900-01-01 00:00:01Z'::timestamp); +---- +-2208988799 + +query I +SELECT unix_seconds(arrow_cast('1970-01-01 00:00:01+02:00', 'Timestamp(Microsecond, None)')); +---- +-7199 + +query I +SELECT unix_seconds(arrow_cast('1970-01-01 00:00:01Z', 'Timestamp(Second, None)')); +---- +1 + +query I +SELECT unix_seconds(NULL::timestamp); +---- +NULL + +query error Function 'unix_seconds' requires TypeSignatureClass::Timestamp, but received String \(DataType: Utf8View\) +SELECT unix_seconds('1970-01-01 00:00:01Z'::string); diff --git a/datafusion/sqllogictest/test_files/spark/hash/crc32.slt b/datafusion/sqllogictest/test_files/spark/hash/crc32.slt index 6fbeb11fb9a3..df5588c75837 100644 --- a/datafusion/sqllogictest/test_files/spark/hash/crc32.slt +++ b/datafusion/sqllogictest/test_files/spark/hash/crc32.slt @@ -81,7 +81,7 @@ SELECT crc32(arrow_cast('Spark', 'BinaryView')); ---- 1557323817 -# Upstream arrow-rs issue: https://github.com/apache/arrow-rs/issues/8841 -# This should succeed after we receive the fix -query error Arrow error: Compute error: Internal Error: Cannot cast BinaryView to BinaryArray of expected type +query I select crc32(arrow_cast(null, 'Dictionary(Int32, Utf8)')) +---- +NULL diff --git a/datafusion/sqllogictest/test_files/spark/hash/sha2.slt b/datafusion/sqllogictest/test_files/spark/hash/sha2.slt index 7690a38773b0..07f70947fe92 100644 --- a/datafusion/sqllogictest/test_files/spark/hash/sha2.slt +++ b/datafusion/sqllogictest/test_files/spark/hash/sha2.slt @@ -75,3 +75,58 @@ SELECT sha2(expr, bit_length) FROM VALUES ('foo',0::INT), ('bar',224::INT), ('ba 967004d25de4abc1bd6a7c9a216254a5ac0733e8ad96dc9f1ea0fad9619da7c32d654ec8ad8ba2f9b5728fed6633bd91 8c6be9ed448a34883a13a13f4ead4aefa036b67dcda59020c01e57ea075ea8a4792d428f2c6fd0c09d1c49994d6c22789336e062188df29572ed07e7f9779c52 NULL + +# All string types +query T +SELECT sha2(arrow_cast('foo', 'Utf8'), bit_length) FROM VALUES (224::INT), (256::INT) AS t(bit_length); +---- 
+0808f64e60d58979fcb676c96ec938270dea42445aeefcd3a4e6f8db +2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae + +query T +SELECT sha2(arrow_cast('foo', 'LargeUtf8'), bit_length) FROM VALUES (224::INT), (256::INT) AS t(bit_length); +---- +0808f64e60d58979fcb676c96ec938270dea42445aeefcd3a4e6f8db +2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae + +query T +SELECT sha2(arrow_cast('foo', 'Utf8View'), bit_length) FROM VALUES (224::INT), (256::INT) AS t(bit_length); +---- +0808f64e60d58979fcb676c96ec938270dea42445aeefcd3a4e6f8db +2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae + +# All binary types +query T +SELECT sha2(arrow_cast('foo', 'Binary'), bit_length) FROM VALUES (224::INT), (256::INT) AS t(bit_length); +---- +0808f64e60d58979fcb676c96ec938270dea42445aeefcd3a4e6f8db +2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae + +query T +SELECT sha2(arrow_cast('foo', 'LargeBinary'), bit_length) FROM VALUES (224::INT), (256::INT) AS t(bit_length); +---- +0808f64e60d58979fcb676c96ec938270dea42445aeefcd3a4e6f8db +2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae + +query T +SELECT sha2(arrow_cast('foo', 'BinaryView'), bit_length) FROM VALUES (224::INT), (256::INT) AS t(bit_length); +---- +0808f64e60d58979fcb676c96ec938270dea42445aeefcd3a4e6f8db +2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae + + +# Null cases +query T +select sha2(null, 0); +---- +NULL + +query T +select sha2('a', null); +---- +NULL + +query T +select sha2('a', null::int); +---- +NULL diff --git a/datafusion/sqllogictest/test_files/spark/json/json_tuple.slt b/datafusion/sqllogictest/test_files/spark/json/json_tuple.slt new file mode 100644 index 000000000000..c0c424946709 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/json/json_tuple.slt @@ -0,0 +1,154 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Tests for Spark-compatible json_tuple function +# https://spark.apache.org/docs/latest/api/sql/index.html#json_tuple +# +# Test cases derived from Spark JsonExpressionsSuite: +# https://github.com/apache/spark/blob/master/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala + +# Scalar: hive key 1 +query ? +SELECT json_tuple('{"f1":"value1","f2":"value2","f3":3,"f5":5.23}'::STRING, 'f1'::STRING, 'f2'::STRING, 'f3'::STRING, 'f4'::STRING, 'f5'::STRING); +---- +{c0: value1, c1: value2, c2: 3, c3: NULL, c4: 5.23} + +# Scalar: hive key 2 +query ? +SELECT json_tuple('{"f1":"value12","f3":"value3","f2":2,"f4":4.01}'::STRING, 'f1'::STRING, 'f2'::STRING, 'f3'::STRING, 'f4'::STRING, 'f5'::STRING); +---- +{c0: value12, c1: 2, c2: value3, c3: 4.01, c4: NULL} + +# Scalar: hive key 3 +query ? +SELECT json_tuple('{"f1":"value13","f4":"value44","f3":"value33","f2":2,"f5":5.01}'::STRING, 'f1'::STRING, 'f2'::STRING, 'f3'::STRING, 'f4'::STRING, 'f5'::STRING); +---- +{c0: value13, c1: 2, c2: value33, c3: value44, c4: 5.01} + +# Scalar: null JSON input +query ? +SELECT json_tuple(NULL::STRING, 'f1'::STRING, 'f2'::STRING, 'f3'::STRING, 'f4'::STRING, 'f5'::STRING); +---- +NULL + +# Scalar: null and empty values +query ? 
+SELECT json_tuple('{"f1":"","f5":null}'::STRING, 'f1'::STRING, 'f2'::STRING, 'f3'::STRING, 'f4'::STRING, 'f5'::STRING); +---- +{c0: , c1: NULL, c2: NULL, c3: NULL, c4: NULL} + +# Scalar: invalid JSON (array) +query ? +SELECT json_tuple('[invalid JSON string]'::STRING, 'f1'::STRING); +---- +NULL + +# Scalar: invalid JSON (start only) +query ? +SELECT json_tuple('{'::STRING, 'f1'::STRING); +---- +NULL + +# Scalar: invalid JSON (no closing brace) +query ? +SELECT json_tuple('{"foo":"bar"'::STRING, 'f1'::STRING); +---- +NULL + +# Scalar: invalid JSON (backslash) +query ? +SELECT json_tuple('\'::STRING, 'f1'::STRING); +---- +NULL + +# Scalar: invalid JSON (quoted string, not an object) +query ? +SELECT json_tuple('"quote'::STRING, '"quote'::STRING); +---- +NULL + +# Scalar: empty JSON object +query ? +SELECT json_tuple('{}'::STRING, 'a'::STRING); +---- +{c0: NULL} + +# Array: multi-row test +query ? +SELECT json_tuple(col, 'f1'::STRING, 'f2'::STRING) FROM (VALUES + ('{"f1":"a","f2":"b"}'::STRING), + (NULL::STRING), + ('{"f1":"c"}'::STRING), + ('invalid'::STRING) +) AS t(col); +---- +{c0: a, c1: b} +NULL +{c0: c, c1: NULL} +NULL + +# Array: SPARK-21677 null field key +query ? +SELECT json_tuple(col1, col2, col3, col4) FROM (VALUES + ('{"f1":1,"f2":2}'::STRING, 'f1'::STRING, NULL::STRING, 'f2'::STRING) +) AS t(col1, col2, col3, col4); +---- +{c0: 1, c1: NULL, c2: 2} + +# Array: SPARK-21804 repeated field +query ? +SELECT json_tuple(col1, col2, col3, col4) FROM (VALUES + ('{"f1":1,"f2":2}'::STRING, 'f1'::STRING, NULL::STRING, 'f1'::STRING) +) AS t(col1, col2, col3, col4); +---- +{c0: 1, c1: NULL, c2: 1} + +# Edge case: both json and field key are null +query ? +SELECT json_tuple(NULL::STRING, NULL::STRING); +---- +NULL + +# Edge case: empty string json and empty string key +query ? +SELECT json_tuple(''::STRING, ''::STRING); +---- +NULL + +# Edge case: mixed upper/lower case keys +query ? 
+SELECT json_tuple('{"Name":"Alice","name":"bob","NAME":"Charlie"}'::STRING, 'Name'::STRING, 'name'::STRING, 'NAME'::STRING); +---- +{c0: Alice, c1: bob, c2: Charlie} + +# Edge case: UTF-8 Chinese characters +query ? +SELECT json_tuple('{"姓名":"小明","城市":"台北"}'::STRING, '姓名'::STRING, '城市'::STRING); +---- +{c0: 小明, c1: 台北} + +# Edge case: UTF-8 Cyrillic characters +query ? +SELECT json_tuple('{"имя":"Иван","город":"Москва"}'::STRING, 'имя'::STRING, 'город'::STRING); +---- +{c0: Иван, c1: Москва} + +# Verify return type with arrow_typeof +query T +SELECT arrow_typeof(json_tuple('{"a":1}'::STRING, 'a'::STRING)); +---- +Struct("c0": Utf8) diff --git a/datafusion/sqllogictest/test_files/spark/map/str_to_map.slt b/datafusion/sqllogictest/test_files/spark/map/str_to_map.slt new file mode 100644 index 000000000000..30d1672aef0a --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/map/str_to_map.slt @@ -0,0 +1,114 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
+# Tests for Spark-compatible str_to_map function
+# https://spark.apache.org/docs/latest/api/sql/index.html#str_to_map
+#
+# Test cases derived from Spark test("StringToMap"):
+# https://github.com/apache/spark/blob/v4.0.0/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala#L525-L618
+
+# s0: Basic test with default delimiters
+query ?
+SELECT str_to_map('a:1,b:2,c:3');
+----
+{a: 1, b: 2, c: 3}
+
+# s1: Preserve spaces in values
+query ?
+SELECT str_to_map('a: ,b:2');
+----
+{a: , b: 2}
+
+# s2: Custom key-value delimiter '='
+query ?
+SELECT str_to_map('a=1,b=2,c=3', ',', '=');
+----
+{a: 1, b: 2, c: 3}
+
+# s3: Empty string returns map with empty key and NULL value
+query ?
+SELECT str_to_map('', ',', '=');
+----
+{: NULL}
+
+# s4: Custom pair delimiter '_'
+query ?
+SELECT str_to_map('a:1_b:2_c:3', '_', ':');
+----
+{a: 1, b: 2, c: 3}
+
+# s5: Single key without value returns NULL value
+query ?
+SELECT str_to_map('a');
+----
+{a: NULL}
+
+# s6: Custom delimiters '&' and '='
+query ?
+SELECT str_to_map('a=1&b=2&c=3', '&', '=');
+----
+{a: 1, b: 2, c: 3}
+
+# Duplicate keys: EXCEPTION policy (Spark 3.0+ default)
+# TODO: Add LAST_WIN policy tests when the spark.sql.mapKeyDedupPolicy
+# config is supported
+statement error Duplicate map key
+SELECT str_to_map('a:1,b:2,a:3');
+
+# Additional tests (DataFusion-specific)
+
+# NULL input returns NULL
+query ?
+SELECT str_to_map(NULL, ',', ':');
+----
+NULL
+
+# Explicit 3-arg form
+query ?
+SELECT str_to_map('a:1,b:2,c:3', ',', ':');
+----
+{a: 1, b: 2, c: 3}
+
+# Missing key-value delimiter results in NULL value
+query ?
+SELECT str_to_map('a,b:2', ',', ':');
+----
+{a: NULL, b: 2}
+
+# Multi-row test
+query ?
+SELECT str_to_map(col) FROM (VALUES ('a:1,b:2'), ('x:9'), (NULL)) AS t(col);
+----
+{a: 1, b: 2}
+{x: 9}
+NULL
+
+# Multi-row with custom delimiter
+query ?
+SELECT str_to_map(col, ',', '=') FROM (VALUES ('a=1,b=2'), ('x=9'), (NULL)) AS t(col); +---- +{a: 1, b: 2} +{x: 9} +NULL + +# Per-row delimiters: each row can have different delimiters +query ? +SELECT str_to_map(col1, col2, col3) FROM (VALUES ('a=1,b=2', ',', '='), ('x#9', ',', '#'), (NULL, ',', '=')) AS t(col1, col2, col3); +---- +{a: 1, b: 2} +{x: 9} +NULL \ No newline at end of file diff --git a/datafusion/sqllogictest/test_files/spark/math/abs.slt b/datafusion/sqllogictest/test_files/spark/math/abs.slt index 19ca902ea3de..94092caab985 100644 --- a/datafusion/sqllogictest/test_files/spark/math/abs.slt +++ b/datafusion/sqllogictest/test_files/spark/math/abs.slt @@ -24,71 +24,187 @@ ## Original Query: SELECT abs(-1); ## PySpark 3.5.5 Result: {'abs(-1)': 1, 'typeof(abs(-1))': 'int', 'typeof(-1)': 'int'} -# abs: signed int and NULL +# Scalar input +## Scalar input: signed int and NULL query IIIIR SELECT abs(-127::TINYINT), abs(-32767::SMALLINT), abs(-2147483647::INT), abs(-9223372036854775807::BIGINT), abs(NULL); ---- 127 32767 2147483647 9223372036854775807 NULL - -# See https://github.com/apache/datafusion/issues/18794 for operator precedence -# abs: signed int minimal values +## Scalar input: signed int minimal values +## See https://github.com/apache/datafusion/issues/18794 for operator precedence query IIII -select abs((-128)::TINYINT), abs((-32768)::SMALLINT), abs((-2147483648)::INT), abs((-9223372036854775808)::BIGINT) +select abs((-128)::TINYINT), abs((-32768)::SMALLINT), abs((-2147483648)::INT), abs((-9223372036854775808)::BIGINT); ---- -128 -32768 -2147483648 -9223372036854775808 -# abs: floats, NULL, NaN, -0, infinity, -infinity +## Scalar input: Spark ANSI mode, signed int minimal values +statement ok +set datafusion.execution.enable_ansi_mode = true; + +query error DataFusion error: Arrow error: Compute error: Int8 overflow on abs\(\-128\) +select abs((-128)::TINYINT); + +query error DataFusion error: Arrow error: Compute error: Int16 overflow on 
abs\(\-32768\) +select abs((-32768)::SMALLINT); + +query error DataFusion error: Arrow error: Compute error: Int32 overflow on abs\(\-2147483648\) +select abs((-2147483648)::INT); + +query error DataFusion error: Arrow error: Compute error: Int64 overflow on abs\(\-9223372036854775808\) +select abs((-9223372036854775808)::BIGINT); + +statement ok +set datafusion.execution.enable_ansi_mode = false; + +## Scalar input: float, NULL, NaN, -0, infinity, -infinity query RRRRRRRRRRRR -SELECT abs(-1.0::FLOAT), abs(0.::FLOAT), abs(-0.::FLOAT), abs(-0::FLOAT), abs(NULL::FLOAT), abs('NaN'::FLOAT), abs('inf'::FLOAT), abs('+inf'::FLOAT), abs('-inf'::FLOAT), abs('infinity'::FLOAT), abs('+infinity'::FLOAT), abs('-infinity'::FLOAT) +SELECT abs(-1.0::FLOAT), abs(0.::FLOAT), abs(-0.::FLOAT), abs(-0::FLOAT), abs(NULL::FLOAT), abs('NaN'::FLOAT), abs('inf'::FLOAT), abs('+inf'::FLOAT), abs('-inf'::FLOAT), abs('infinity'::FLOAT), abs('+infinity'::FLOAT), abs('-infinity'::FLOAT); ---- 1 0 0 0 NULL NaN Infinity Infinity Infinity Infinity Infinity Infinity -# abs: doubles, NULL, NaN, -0, infinity, -infinity +## Scalar input: double, NULL, NaN, -0, infinity, -infinity query RRRRRRRRRRRR -SELECT abs(-1.0::DOUBLE), abs(0.::DOUBLE), abs(-0.::DOUBLE), abs(-0::DOUBLE), abs(NULL::DOUBLE), abs('NaN'::DOUBLE), abs('inf'::DOUBLE), abs('+inf'::DOUBLE), abs('-inf'::DOUBLE), abs('infinity'::DOUBLE), abs('+infinity'::DOUBLE), abs('-infinity'::DOUBLE) +SELECT abs(-1.0::DOUBLE), abs(0.::DOUBLE), abs(-0.::DOUBLE), abs(-0::DOUBLE), abs(NULL::DOUBLE), abs('NaN'::DOUBLE), abs('inf'::DOUBLE), abs('+inf'::DOUBLE), abs('-inf'::DOUBLE), abs('infinity'::DOUBLE), abs('+infinity'::DOUBLE), abs('-infinity'::DOUBLE); ---- 1 0 0 0 NULL NaN Infinity Infinity Infinity Infinity Infinity Infinity -# abs: decimal128 and decimal256 -statement ok -CREATE TABLE test_nullable_decimal( - c1 DECIMAL(10, 2), /* Decimal128 */ - c2 DECIMAL(38, 10), /* Decimal128 with max precision */ - c3 DECIMAL(40, 2), /* Decimal256 */ - c4 
DECIMAL(76, 10) /* Decimal256 with max precision */ - ) AS VALUES - (0, 0, 0, 0), - (NULL, NULL, NULL, NULL); +## Scalar input: decimal128 +query RRR +SELECT abs(('-99999999.99')::DECIMAL(10, 2)), abs(0::DECIMAL(10, 2)), abs(NULL::DECIMAL(10, 2)); +---- +99999999.99 0 NULL + +query RRR +SELECT abs(('-9999999999999999999999999999.9999999999')::DECIMAL(38, 10)), abs(0::DECIMAL(38, 10)), abs(NULL::DECIMAL(38, 10)); +---- +9999999999999999999999999999.9999999999 0 NULL + +## Scalar input: decimal256 +query RRR +SELECT abs(('-99999999999999999999999999999999999999.99')::DECIMAL(40, 2)), abs(0::DECIMAL(40, 2)), abs(NULL::DECIMAL(40, 2)); +---- +99999999999999999999999999999999999999.99 0 NULL + +query RRR +SELECT abs(('-999999999999999999999999999999999999999999999999999999999999999999.9999999999')::DECIMAL(76, 10)), abs(0::DECIMAL(76, 10)), abs(NULL::DECIMAL(76, 10)); +---- +999999999999999999999999999999999999999999999999999999999999999999.9999999999 0 NULL + + +# Array input +## Array input: signed int, signed int minimal values and NULL +query I +SELECT abs(a) FROM (VALUES (-127::TINYINT), ((-128)::TINYINT), (NULL)) AS t(a); +---- +127 +-128 +NULL + +query I +select abs(a) FROM (VALUES (-32767::SMALLINT), ((-32768)::SMALLINT), (NULL)) AS t(a); +---- +32767 +-32768 +NULL + +query I +select abs(a) FROM (VALUES (-2147483647::INT), ((-2147483648)::INT), (NULL)) AS t(a); +---- +2147483647 +-2147483648 +NULL query I -INSERT into test_nullable_decimal values - ( - -99999999.99, - '-9999999999999999999999999999.9999999999', - '-99999999999999999999999999999999999999.99', - '-999999999999999999999999999999999999999999999999999999999999999999.9999999999' - ), - ( - 99999999.99, - '9999999999999999999999999999.9999999999', - '99999999999999999999999999999999999999.99', - '999999999999999999999999999999999999999999999999999999999999999999.9999999999' - ) ----- -2 - -query RRRR rowsort -SELECT abs(c1), abs(c2), abs(c3), abs(c4) FROM test_nullable_decimal ----- -0 0 0 0 
-99999999.99 9999999999999999999999999999.9999999999 99999999999999999999999999999999999999.99 999999999999999999999999999999999999999999999999999999999999999999.9999999999 -99999999.99 9999999999999999999999999999.9999999999 99999999999999999999999999999999999999.99 999999999999999999999999999999999999999999999999999999999999999999.9999999999 -NULL NULL NULL NULL +select abs(a) FROM (VALUES (-9223372036854775807::BIGINT), ((-9223372036854775808)::BIGINT), (NULL)) AS t(a); +---- +9223372036854775807 +-9223372036854775808 +NULL + +## Array Input: Spark ANSI mode, signed int minimal values +statement ok +set datafusion.execution.enable_ansi_mode = true; + +query error DataFusion error: Arrow error: Compute error: Int8Array overflow on abs\(\-128\) +SELECT abs(a) FROM (VALUES (-127::TINYINT), ((-128)::TINYINT)) AS t(a); + +query error DataFusion error: Arrow error: Compute error: Int16Array overflow on abs\(\-32768\) +select abs(a) FROM (VALUES (-32767::SMALLINT), ((-32768)::SMALLINT)) AS t(a); +query error DataFusion error: Arrow error: Compute error: Int32Array overflow on abs\(\-2147483648\) +select abs(a) FROM (VALUES (-2147483647::INT), ((-2147483648)::INT)) AS t(a); + +query error DataFusion error: Arrow error: Compute error: Int64Array overflow on abs\(\-9223372036854775808\) +select abs(a) FROM (VALUES (-9223372036854775807::BIGINT), ((-9223372036854775808)::BIGINT)) AS t(a); statement ok -drop table test_nullable_decimal +set datafusion.execution.enable_ansi_mode = false; + +## Array input: float, NULL, NaN, -0, infinity, -infinity +query R +SELECT abs(a) FROM (VALUES (-1.0::FLOAT), (0.::FLOAT), (-0.::FLOAT), (-0::FLOAT), (NULL::FLOAT), ('NaN'::FLOAT), ('inf'::FLOAT), ('+inf'::FLOAT), ('-inf'::FLOAT), ('infinity'::FLOAT), ('+infinity'::FLOAT), ('-infinity'::FLOAT)) AS t(a); +---- +1 +0 +0 +0 +NULL +NaN +Infinity +Infinity +Infinity +Infinity +Infinity +Infinity + + +## Array input: double, NULL, NaN, -0, infinity, -infinity +query R +SELECT abs(a) FROM 
(VALUES (-1.0::DOUBLE), (0.::DOUBLE), (-0.::DOUBLE), (-0::DOUBLE), (NULL::DOUBLE), ('NaN'::DOUBLE), ('inf'::DOUBLE), ('+inf'::DOUBLE), ('-inf'::DOUBLE), ('infinity'::DOUBLE), ('+infinity'::DOUBLE), ('-infinity'::DOUBLE)) AS t(a); +---- +1 +0 +0 +0 +NULL +NaN +Infinity +Infinity +Infinity +Infinity +Infinity +Infinity + +## Array input: decimal128 +query R +SELECT abs(a) FROM (VALUES (('-99999999.99')::DECIMAL(10, 2)), (0::DECIMAL(10, 2)), (NULL::DECIMAL(10, 2))) AS t(a); +---- +99999999.99 +0 +NULL + +query R +SELECT abs(a) FROM (VALUES (('-9999999999999999999999999999.9999999999')::DECIMAL(38, 10)), (0::DECIMAL(38, 10)), (NULL::DECIMAL(38, 10))) AS t(a); +---- +9999999999999999999999999999.9999999999 +0 +NULL + +## Array input: decimal256 +query R +SELECT abs(a) FROM (VALUES (('-99999999999999999999999999999999999999.99')::DECIMAL(40, 2)), (0::DECIMAL(40, 2)), (NULL::DECIMAL(40, 2))) AS t(a); +---- +99999999999999999999999999999999999999.99 +0 +NULL + +query R +SELECT abs(a) FROM (VALUES (('-999999999999999999999999999999999999999999999999999999999999999999.9999999999')::DECIMAL(76, 10)), (0::DECIMAL(76, 10)), (NULL::DECIMAL(76, 10))) AS t(a); +---- +999999999999999999999999999999999999999999999999999999999999999999.9999999999 +0 +NULL ## Original Query: SELECT abs(INTERVAL -'1-1' YEAR TO MONTH); ## PySpark 3.5.5 Result: {"abs(INTERVAL '-1-1' YEAR TO MONTH)": 13, "typeof(abs(INTERVAL '-1-1' YEAR TO MONTH))": 'interval year to month', "typeof(INTERVAL '-1-1' YEAR TO MONTH)": 'interval year to month'} diff --git a/datafusion/sqllogictest/test_files/spark/math/bin.slt b/datafusion/sqllogictest/test_files/spark/math/bin.slt index 1fa24e6cda6b..b2e2aadde44b 100644 --- a/datafusion/sqllogictest/test_files/spark/math/bin.slt +++ b/datafusion/sqllogictest/test_files/spark/math/bin.slt @@ -15,23 +15,62 @@ # specific language governing permissions and limitations # under the License. 
-# This file was originally created by a porting script from: -# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function -# This file is part of the implementation of the datafusion-spark function library. -# For more information, please see: -# https://github.com/apache/datafusion/issues/15914 - -## Original Query: SELECT bin(-13); -## PySpark 3.5.5 Result: {'bin(-13)': '1111111111111111111111111111111111111111111111111111111111110011', 'typeof(bin(-13))': 'string', 'typeof(-13)': 'int'} -#query -#SELECT bin(-13::int); - -## Original Query: SELECT bin(13); -## PySpark 3.5.5 Result: {'bin(13)': '1101', 'typeof(bin(13))': 'string', 'typeof(13)': 'int'} -#query -#SELECT bin(13::int); - -## Original Query: SELECT bin(13.3); -## PySpark 3.5.5 Result: {'bin(13.3)': '1101', 'typeof(bin(13.3))': 'string', 'typeof(13.3)': 'decimal(3,1)'} -#query -#SELECT bin(13.3::decimal(3,1)); +query T +SELECT bin(arrow_cast(NULL, 'Int8')); +---- +NULL + +query T +SELECT bin(arrow_cast(0, 'Int8')); +---- +0 + +query T +SELECT bin(arrow_cast(13, 'Int8')); +---- +1101 + +query T +SELECT bin(arrow_cast(13.36, 'Float16')); +---- +1101 + +query T +SELECT bin(13.3::decimal(3,1)); +---- +1101 + +query T +SELECT bin(arrow_cast(-13, 'Int8')); +---- +1111111111111111111111111111111111111111111111111111111111110011 + +query T +SELECT bin(arrow_cast(256, 'Int16')); +---- +100000000 + +query T +SELECT bin(arrow_cast(-32768, 'Int16')); +---- +1111111111111111111111111111111111111111111111111000000000000000 + +query T +SELECT bin(arrow_cast(-2147483648, 'Int32')); +---- +1111111111111111111111111111111110000000000000000000000000000000 + +query T +SELECT bin(arrow_cast(1073741824, 'Int32')); +---- +1000000000000000000000000000000 + +query T +SELECT bin(arrow_cast(-9223372036854775808, 'Int64')); +---- +1000000000000000000000000000000000000000000000000000000000000000 + +query T +SELECT bin(arrow_cast(9223372036854775807, 'Int64')); 
+---- +111111111111111111111111111111111111111111111111111111111111111 diff --git a/datafusion/sqllogictest/test_files/spark/math/hex.slt b/datafusion/sqllogictest/test_files/spark/math/hex.slt index 05c9fb3f31b2..17e9ff432890 100644 --- a/datafusion/sqllogictest/test_files/spark/math/hex.slt +++ b/datafusion/sqllogictest/test_files/spark/math/hex.slt @@ -63,3 +63,23 @@ query T SELECT hex(arrow_cast('test', 'LargeBinary')) as lar_b; ---- 74657374 + +statement ok +CREATE TABLE t_dict_binary AS +SELECT arrow_cast(column1, 'Dictionary(Int32, Binary)') as dict_col +FROM VALUES ('foo'), ('bar'), ('foo'), (NULL), ('baz'), ('bar'); + +query T +SELECT hex(dict_col) FROM t_dict_binary; +---- +666F6F +626172 +666F6F +NULL +62617A +626172 + +query T +SELECT arrow_typeof(hex(dict_col)) FROM t_dict_binary LIMIT 1; +---- +Dictionary(Int32, Utf8) diff --git a/datafusion/sqllogictest/test_files/spark/math/mod.slt b/datafusion/sqllogictest/test_files/spark/math/mod.slt index 2780b3e1053d..68c0f59f4812 100644 --- a/datafusion/sqllogictest/test_files/spark/math/mod.slt +++ b/datafusion/sqllogictest/test_files/spark/math/mod.slt @@ -144,6 +144,35 @@ SELECT MOD(10.0::decimal(3,1), 3.0::decimal(2,1)) as mod_decimal_2; ---- 1 +# Division by zero returns NULL in legacy mode (ANSI off) +query I +SELECT MOD(10::int, 0::int) as mod_div_zero_1; +---- +NULL + +query I +SELECT MOD(-7::int, 0::int) as mod_div_zero_2; +---- +NULL + +query R +SELECT MOD(10.5::float8, 0.0::float8) as mod_div_zero_float; +---- +NaN + +# Division by zero errors in ANSI mode +statement ok +set datafusion.execution.enable_ansi_mode = true; + +statement error DataFusion error: Arrow error: Divide by zero error +SELECT MOD(10::int, 0::int); + +statement error DataFusion error: Arrow error: Divide by zero error +SELECT MOD(-7::int, 0::int); + +statement ok +set datafusion.execution.enable_ansi_mode = false; + # Edge cases query I SELECT MOD(0::int, 5::int) as mod_zero_1; diff --git 
a/datafusion/sqllogictest/test_files/spark/math/negative.slt b/datafusion/sqllogictest/test_files/spark/math/negative.slt index aa8e558e9895..40bfaf791fe8 100644 --- a/datafusion/sqllogictest/test_files/spark/math/negative.slt +++ b/datafusion/sqllogictest/test_files/spark/math/negative.slt @@ -23,5 +23,309 @@ ## Original Query: SELECT negative(1); ## PySpark 3.5.5 Result: {'negative(1)': -1, 'typeof(negative(1))': 'int', 'typeof(1)': 'int'} -#query -#SELECT negative(1::int); + +# Test negative with integer +query I +SELECT negative(1::int); +---- +-1 + +# Test negative with positive integer +query I +SELECT negative(42::int); +---- +-42 + +# Test negative with negative integer +query I +SELECT negative(-10::int); +---- +10 + +# Test negative with zero +query I +SELECT negative(0::int); +---- +0 + +# Test negative with bigint +query I +SELECT negative(9223372036854775807::bigint); +---- +-9223372036854775807 + +# Test negative with negative bigint +query I +SELECT negative(-100::bigint); +---- +100 + +# Test negative with smallint +query I +SELECT negative(32767::smallint); +---- +-32767 + +# Test negative with float +query R +SELECT negative(3.14::float); +---- +-3.14 + +# Test negative with negative float +query R +SELECT negative(-2.5::float); +---- +2.5 + +# Test negative with double +query R +SELECT negative(3.14159265358979::double); +---- +-3.14159265358979 + +# Test negative with negative double +query R +SELECT negative(-1.5::double); +---- +1.5 + +# Test negative with decimal +query R +SELECT negative(123.456::decimal(10,3)); +---- +-123.456 + +# Test negative with negative decimal +query R +SELECT negative(-99.99::decimal(10,2)); +---- +99.99 + +# Test negative with NULL +query I +SELECT negative(NULL::int); +---- +NULL + +# Test negative with column values +statement ok +CREATE TABLE test_negative (id int, value int) AS VALUES (1, 10), (2, -20), (3, 0), (4, NULL); + +query II rowsort +SELECT id, negative(value) FROM test_negative; +---- +1 -10 +2 20 +3 
0 +4 NULL + +statement ok +DROP TABLE test_negative; + +# Test negative in expressions +query I +SELECT negative(5) + 3; +---- +-2 + +# Test nested negative +query I +SELECT negative(negative(7)); +---- +7 + +# Test negative with large numbers +query R +SELECT negative(1234567890.123456::double); +---- +-1234567890.123456 + +# Test wrap-around: negative of minimum int (should wrap to same value) +# Using table to avoid constant folding overflow during optimization +statement ok +CREATE TABLE min_values_int AS VALUES (-2147483648); + +query I +SELECT negative(column1::int) FROM min_values_int; +---- +-2147483648 + +statement ok +DROP TABLE min_values_int; + +# Test wrap-around: negative of minimum bigint (should wrap to same value) +statement ok +CREATE TABLE min_values_bigint AS VALUES (-9223372036854775808); + +query I +SELECT negative(column1::bigint) FROM min_values_bigint; +---- +-9223372036854775808 + +statement ok +DROP TABLE min_values_bigint; + +# Test wrap-around: negative of minimum smallint (should wrap to same value) +statement ok +CREATE TABLE min_values_smallint AS VALUES (-32768); + +query I +SELECT negative(column1::smallint) FROM min_values_smallint; +---- +-32768 + +statement ok +DROP TABLE min_values_smallint; + +# Test wrap-around: negative of minimum tinyint (should wrap to same value) +statement ok +CREATE TABLE min_values_tinyint AS VALUES (-128); + +query I +SELECT negative(column1::tinyint) FROM min_values_tinyint; +---- +-128 + +statement ok +DROP TABLE min_values_tinyint; + +# Test overflow: negative of positive infinity (float) +query R +SELECT negative('Infinity'::float); +---- +-Infinity + +# Test overflow: negative of negative infinity (float) +query R +SELECT negative('-Infinity'::float); +---- +Infinity + +# Test overflow: negative of positive infinity (double) +query R +SELECT negative('Infinity'::double); +---- +-Infinity + +# Test overflow: negative of negative infinity (double) +query R +SELECT negative('-Infinity'::double); 
+---- +Infinity + +# Test overflow: negative of NaN (float) +query R +SELECT negative('NaN'::float); +---- +NaN + +# Test overflow: negative of NaN (double) +query R +SELECT negative('NaN'::double); +---- +NaN + +# Test overflow: negative of maximum float value +query R +SELECT negative(3.4028235e38::float); +---- +-340282350000000000000000000000000000000 + +# Test overflow: negative of minimum float value +query R +SELECT negative(-3.4028235e38::float); +---- +340282350000000000000000000000000000000 + +# Test overflow: negative of maximum double value +query R +SELECT negative(1.7976931348623157e308::double); +---- +-179769313486231570000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 + +# Test overflow: negative of minimum double value +query R +SELECT negative(-1.7976931348623157e308::double); +---- +179769313486231570000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 + +# Test negative with CalendarIntervalType (IntervalMonthDayNano) +# Spark make_interval creates CalendarInterval +query ? +SELECT negative(make_interval(1, 2, 3, 4, 5, 6, 7.5)); +---- +-14 mons -25 days -5 hours -6 mins -7.500000000 secs + +# Test negative with negative CalendarIntervalType +query ? +SELECT negative(make_interval(-2, -5, -1, -10, -3, -30, -15.25)); +---- +29 mons 17 days 3 hours 30 mins 15.250000000 secs + +# Test negative with CalendarInterval from table +statement ok +CREATE TABLE interval_test AS VALUES + (make_interval(1, 2, 0, 5, 0, 0, 0.0)), + (make_interval(-3, -1, 0, -2, 0, 0, 0.0)); + +query ? 
rowsort +SELECT negative(column1) FROM interval_test; +---- +-14 mons -5 days +37 mons 2 days + +statement ok +DROP TABLE interval_test; + +## ANSI mode tests: overflow detection +statement ok +set datafusion.execution.enable_ansi_mode = true; + +# Test ANSI mode: negative of minimum values should error (overflow) +query error DataFusion error: Execution error: Int8 overflow on negative\(\-128\) +SELECT negative((-128)::tinyint); + +query error DataFusion error: Execution error: Int16 overflow on negative\(\-32768\) +SELECT negative((-32768)::smallint); + +query error DataFusion error: Execution error: Int32 overflow on negative\(\-2147483648\) +SELECT negative((-2147483648)::int); + +query error DataFusion error: Execution error: Int64 overflow on negative\(\-9223372036854775808\) +SELECT negative((-9223372036854775808)::bigint); + +# Test ANSI mode: negative of (MIN+1) should succeed (boundary test) +query I +SELECT negative((-127)::tinyint); +---- +127 + +query I +SELECT negative((-32767)::smallint); +---- +32767 + +query I +SELECT negative((-2147483647)::int); +---- +2147483647 + +query I +SELECT negative((-9223372036854775807)::bigint); +---- +9223372036854775807 + +# Test ANSI mode: array with MIN value should error +statement ok +CREATE TABLE min_values_ansi AS VALUES (-2147483648); + +query error DataFusion error: Execution error: Int32 overflow on negative\(\-2147483648\) +SELECT negative(column1::int) FROM min_values_ansi; + +statement ok +DROP TABLE min_values_ansi; + +# Reset ANSI mode to false +statement ok +set datafusion.execution.enable_ansi_mode = false; diff --git a/datafusion/sqllogictest/test_files/spark/math/pmod.slt b/datafusion/sqllogictest/test_files/spark/math/pmod.slt index cf273c2d78f5..aa4a197ba470 100644 --- a/datafusion/sqllogictest/test_files/spark/math/pmod.slt +++ b/datafusion/sqllogictest/test_files/spark/math/pmod.slt @@ -64,8 +64,28 @@ SELECT pmod(0::int, 5::int) as pmod_zero_1; ---- 0 -statement error DataFusion error: Arrow 
error: Divide by zero error +query I SELECT pmod(10::int, 0::int) as pmod_zero_2; +---- +NULL + +query I +SELECT pmod(-7::int, 0::int) as pmod_zero_3; +---- +NULL + +# Division by zero errors in ANSI mode +statement ok +set datafusion.execution.enable_ansi_mode = true; + +statement error DataFusion error: Arrow error: Divide by zero error +SELECT pmod(10::int, 0::int); + +statement error DataFusion error: Arrow error: Divide by zero error +SELECT pmod(-7::int, 0::int); + +statement ok +set datafusion.execution.enable_ansi_mode = false; # PMOD tests with NULL values query I diff --git a/datafusion/sqllogictest/test_files/spark/math/unhex.slt b/datafusion/sqllogictest/test_files/spark/math/unhex.slt new file mode 100644 index 000000000000..051d8826c8a6 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/unhex.slt @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Basic hex string +query ? +SELECT unhex('537061726B2053514C'); +---- +537061726b2053514c + +query T +SELECT arrow_cast(unhex('537061726B2053514C'), 'Utf8'); +---- +Spark SQL + +# Lowercase hex +query ? 
+SELECT unhex('616263'); +---- +616263 + +query T +SELECT arrow_cast(unhex('616263'), 'Utf8'); +---- +abc + +# Odd length hex (left pad with 0) +query ? +SELECT unhex(a) FROM VALUES ('1A2B3'), ('1'), ('ABC'), ('123') AS t(a); +---- +01a2b3 +01 +0abc +0123 + +# Null input +query ? +SELECT unhex(NULL); +---- +NULL + +# Invalid hex characters +query ? +SELECT unhex('GGHH'); +---- +NULL + +# Empty hex string +query T +SELECT arrow_cast(unhex(''), 'Utf8'); +---- +(empty) + +# Array with mixed case +query ? +SELECT unhex(a) FROM VALUES ('4a4B4c'), ('F'), ('A'), ('AbCdEf'), ('123abc'), ('41 42'), ('00'), ('FF') AS t(a); +---- +4a4b4c +0f +0a +abcdef +123abc +NULL +00 +ff + +# LargeUtf8 type +statement ok +CREATE TABLE t_large_utf8 AS VALUES (arrow_cast('414243', 'LargeUtf8')), (NULL); + +query ? +SELECT unhex(column1) FROM t_large_utf8; +---- +414243 +NULL + +# Utf8View type +statement ok +CREATE TABLE t_utf8view AS VALUES (arrow_cast('414243', 'Utf8View')), (NULL); + +query ? +SELECT unhex(column1) FROM t_utf8view; +---- +414243 +NULL diff --git a/datafusion/sqllogictest/test_files/spark/string/base64.slt b/datafusion/sqllogictest/test_files/spark/string/base64.slt index 66edbe844215..03b488de0ee9 100644 --- a/datafusion/sqllogictest/test_files/spark/string/base64.slt +++ b/datafusion/sqllogictest/test_files/spark/string/base64.slt @@ -15,18 +15,101 @@ # specific language governing permissions and limitations # under the License. -# This file was originally created by a porting script from: -# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function -# This file is part of the implementation of the datafusion-spark function library. 
-# For more information, please see: -# https://github.com/apache/datafusion/issues/15914 - -## Original Query: SELECT base64('Spark SQL'); -## PySpark 3.5.5 Result: {'base64(Spark SQL)': 'U3BhcmsgU1FM', 'typeof(base64(Spark SQL))': 'string', 'typeof(Spark SQL)': 'string'} -#query -#SELECT base64('Spark SQL'::string); - -## Original Query: SELECT base64(x'537061726b2053514c'); -## PySpark 3.5.5 Result: {"base64(X'537061726B2053514C')": 'U3BhcmsgU1FM', "typeof(base64(X'537061726B2053514C'))": 'string', "typeof(X'537061726B2053514C')": 'binary'} -#query -#SELECT base64(X'537061726B2053514C'::binary); +query T +SELECT base64('Spark SQL'::string); +---- +U3BhcmsgU1FM + +query T +SELECT base64('Spark SQ'::string); +---- +U3BhcmsgU1E= + +query T +SELECT base64('Spark S'::string); +---- +U3BhcmsgUw== + +query T +SELECT base64('Spark SQL'::bytea); +---- +U3BhcmsgU1FM + +query T +SELECT base64(NULL::string); +---- +NULL + +query T +SELECT base64(NULL::bytea); +---- +NULL + +query T +SELECT base64(column1) +FROM VALUES +('Spark SQL'::bytea), +('Spark SQ'::bytea), +('Spark S'::bytea), +(NULL::bytea); +---- +U3BhcmsgU1FM +U3BhcmsgU1E= +U3BhcmsgUw== +NULL + +query error Function 'base64' requires TypeSignatureClass::Binary, but received Int32 \(DataType: Int32\) +SELECT base64(12::integer); + + +query T +SELECT arrow_cast(unbase64('U3BhcmsgU1FM'::string), 'Utf8'); +---- +Spark SQL + +query T +SELECT arrow_cast(unbase64('U3BhcmsgU1E='::string), 'Utf8'); +---- +Spark SQ + +query T +SELECT arrow_cast(unbase64('U3BhcmsgUw=='::string), 'Utf8'); +---- +Spark S + +query T +SELECT arrow_cast(unbase64('U3BhcmsgU1FM'::bytea), 'Utf8'); +---- +Spark SQL + +query ? +SELECT unbase64(NULL::string); +---- +NULL + +query ? 
+SELECT unbase64(NULL::bytea); +---- +NULL + +query T +SELECT arrow_cast(unbase64(column1), 'Utf8') +FROM VALUES +('U3BhcmsgU1FM'::string), +('U3BhcmsgU1E='::string), +('U3BhcmsgUw=='::string), +(NULL::string); +---- +Spark SQL +Spark SQ +Spark S +NULL + +query error Failed to decode value using base64 +SELECT unbase64('123'::string); + +query error Failed to decode value using base64 +SELECT unbase64('123'::bytea); + +query error Function 'unbase64' requires TypeSignatureClass::Binary, but received Int32 \(DataType: Int32\) +SELECT unbase64(12::integer); diff --git a/datafusion/sqllogictest/test_files/spark/string/concat.slt b/datafusion/sqllogictest/test_files/spark/string/concat.slt index 258cb829d7d4..97e7b57f7d06 100644 --- a/datafusion/sqllogictest/test_files/spark/string/concat.slt +++ b/datafusion/sqllogictest/test_files/spark/string/concat.slt @@ -20,6 +20,12 @@ SELECT concat('Spark', 'SQL'); ---- SparkSQL +# Test two Utf8View inputs: value and return type +query TT +SELECT concat(arrow_cast('Spark', 'Utf8View'), arrow_cast('SQL', 'Utf8View')), arrow_typeof(concat(arrow_cast('Spark', 'Utf8View'), arrow_cast('SQL', 'Utf8View'))); +---- +SparkSQL Utf8View + query T SELECT concat('Spark', 'SQL', NULL); ---- @@ -46,3 +52,21 @@ SELECT concat(a, b, c) from (select 'a' a, 'b' b, 'c' c union all select null a, ---- abc NULL + +# Test mixed types: Utf8View + Utf8 +query TT +SELECT concat(arrow_cast('hello', 'Utf8View'), ' world'), arrow_typeof(concat(arrow_cast('hello', 'Utf8View'), ' world')); +---- +hello world Utf8View + +# Test Utf8 + LargeUtf8 => return type LargeUtf8 +query TT +SELECT concat('a', arrow_cast('b', 'LargeUtf8')), arrow_typeof(concat('a', arrow_cast('b', 'LargeUtf8'))); +---- +ab LargeUtf8 + +# Test all three types mixed together +query TT +SELECT concat('a', arrow_cast('b', 'LargeUtf8'), arrow_cast('c', 'Utf8View')), arrow_typeof(concat('a', arrow_cast('b', 'LargeUtf8'), arrow_cast('c', 'Utf8View'))); +---- +abc Utf8View diff --git 
a/datafusion/sqllogictest/test_files/spark/string/format_string.slt b/datafusion/sqllogictest/test_files/spark/string/format_string.slt index 048863ebfbed..8ba3cfc951cd 100644 --- a/datafusion/sqllogictest/test_files/spark/string/format_string.slt +++ b/datafusion/sqllogictest/test_files/spark/string/format_string.slt @@ -931,13 +931,13 @@ Char: NULL ## NULL with timestamp format using arrow_cast query T -SELECT format_string('Hour: %tH', arrow_cast(NULL, 'Timestamp(Nanosecond, None)')); +SELECT format_string('Hour: %tH', arrow_cast(NULL, 'Timestamp(ns)')); ---- Hour: null ## NULL with timestamp format using arrow_cast query T -SELECT format_string('Month: %tB', arrow_cast(NULL, 'Timestamp(Nanosecond, None)')); +SELECT format_string('Month: %tB', arrow_cast(NULL, 'Timestamp(ns)')); ---- Month: null @@ -967,25 +967,25 @@ Month: null ## NULL with timestamp format using arrow_cast query T -SELECT format_string('Month: %tB', arrow_cast(NULL, 'Timestamp(Second, None)')); +SELECT format_string('Month: %tB', arrow_cast(NULL, 'Timestamp(s)')); ---- Month: null ## NULL with timestamp format using arrow_cast query T -SELECT format_string('Month: %tB', arrow_cast(NULL, 'Timestamp(Millisecond, None)')); +SELECT format_string('Month: %tB', arrow_cast(NULL, 'Timestamp(ms)')); ---- Month: null ## NULL with timestamp format using arrow_cast query T -SELECT format_string('Month: %tB', arrow_cast(NULL, 'Timestamp(Microsecond, None)')); +SELECT format_string('Month: %tB', arrow_cast(NULL, 'Timestamp(µs)')); ---- Month: null ## NULL with timestamp format using arrow_cast query T -SELECT format_string('Month: %tB', arrow_cast(NULL, 'Timestamp(Nanosecond, None)')); +SELECT format_string('Month: %tB', arrow_cast(NULL, 'Timestamp(ns)')); ---- Month: null @@ -1051,7 +1051,7 @@ Value: null ## NULL Timestamp with string format using arrow_cast query T -SELECT format_string('Value: %s', arrow_cast(NULL, 'Timestamp(Nanosecond, None)')); +SELECT format_string('Value: %s', arrow_cast(NULL, 
'Timestamp(ns)')); ---- Value: null @@ -1717,49 +1717,49 @@ String: 52245000000000 ## TimestampSecond with time formats query T -SELECT format_string('Year: %tY', arrow_cast(1703512245, 'Timestamp(Second, None)')); +SELECT format_string('Year: %tY', arrow_cast(1703512245, 'Timestamp(s)')); ---- Year: 2023 query T -SELECT format_string('Month: %tm', arrow_cast(1703512245, 'Timestamp(Second, None)')); +SELECT format_string('Month: %tm', arrow_cast(1703512245, 'Timestamp(s)')); ---- Month: 12 query T -SELECT format_string('String: %s', arrow_cast(1703512245, 'Timestamp(Second, None)')); +SELECT format_string('String: %s', arrow_cast(1703512245, 'Timestamp(s)')); ---- String: 1703512245 query T -SELECT format_string('String: %S', arrow_cast(1703512245, 'Timestamp(Second, None)')); +SELECT format_string('String: %S', arrow_cast(1703512245, 'Timestamp(s)')); ---- String: 1703512245 ## TimestampMillisecond with time formats query T -SELECT format_string('ISO Date: %tF', arrow_cast(1703512245000, 'Timestamp(Millisecond, None)')); +SELECT format_string('ISO Date: %tF', arrow_cast(1703512245000, 'Timestamp(ms)')); ---- ISO Date: 2023-12-25 query T -SELECT format_string('String: %s', arrow_cast(1703512245000, 'Timestamp(Millisecond, None)')); +SELECT format_string('String: %s', arrow_cast(1703512245000, 'Timestamp(ms)')); ---- String: 1703512245000 ## TimestampMicrosecond with time formats query T -SELECT format_string('Date: %tD', arrow_cast(1703512245000000, 'Timestamp(Microsecond, None)')); +SELECT format_string('Date: %tD', arrow_cast(1703512245000000, 'Timestamp(µs)')); ---- Date: 12/25/23 query T -SELECT format_string('String: %s', arrow_cast(1703512245000000, 'Timestamp(Microsecond, None)')); +SELECT format_string('String: %s', arrow_cast(1703512245000000, 'Timestamp(µs)')); ---- String: 1703512245000000 query T -SELECT format_string('String: %s', arrow_cast('2020-01-02 01:01:11.1234567890Z', 'Timestamp(Nanosecond, None)')); +SELECT format_string('String: %s', 
arrow_cast('2020-01-02 01:01:11.1234567890Z', 'Timestamp(ns)')); ---- String: 1577926871123456789 diff --git a/datafusion/sqllogictest/test_files/spark/string/substr.slt b/datafusion/sqllogictest/test_files/spark/string/substr.slt deleted file mode 100644 index 0942bdd86a4e..000000000000 --- a/datafusion/sqllogictest/test_files/spark/string/substr.slt +++ /dev/null @@ -1,37 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# This file was originally created by a porting script from: -# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function -# This file is part of the implementation of the datafusion-spark function library. 
-# For more information, please see: -# https://github.com/apache/datafusion/issues/15914 - -## Original Query: SELECT substr('Spark SQL', -3); -## PySpark 3.5.5 Result: {'substr(Spark SQL, -3, 2147483647)': 'SQL', 'typeof(substr(Spark SQL, -3, 2147483647))': 'string', 'typeof(Spark SQL)': 'string', 'typeof(-3)': 'int'} -#query -#SELECT substr('Spark SQL'::string, -3::int); - -## Original Query: SELECT substr('Spark SQL', 5); -## PySpark 3.5.5 Result: {'substr(Spark SQL, 5, 2147483647)': 'k SQL', 'typeof(substr(Spark SQL, 5, 2147483647))': 'string', 'typeof(Spark SQL)': 'string', 'typeof(5)': 'int'} -#query -#SELECT substr('Spark SQL'::string, 5::int); - -## Original Query: SELECT substr('Spark SQL', 5, 1); -## PySpark 3.5.5 Result: {'substr(Spark SQL, 5, 1)': 'k', 'typeof(substr(Spark SQL, 5, 1))': 'string', 'typeof(Spark SQL)': 'string', 'typeof(5)': 'int', 'typeof(1)': 'int'} -#query -#SELECT substr('Spark SQL'::string, 5::int, 1::int); diff --git a/datafusion/sqllogictest/test_files/spark/string/substring.slt b/datafusion/sqllogictest/test_files/spark/string/substring.slt index 847ce4b6d473..5bf2fdf2fb95 100644 --- a/datafusion/sqllogictest/test_files/spark/string/substring.slt +++ b/datafusion/sqllogictest/test_files/spark/string/substring.slt @@ -15,23 +15,189 @@ # specific language governing permissions and limitations # under the License. -# This file was originally created by a porting script from: -# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function -# This file is part of the implementation of the datafusion-spark function library. 
-# For more information, please see: -# https://github.com/apache/datafusion/issues/15914 - -## Original Query: SELECT substring('Spark SQL', -3); -## PySpark 3.5.5 Result: {'substring(Spark SQL, -3, 2147483647)': 'SQL', 'typeof(substring(Spark SQL, -3, 2147483647))': 'string', 'typeof(Spark SQL)': 'string', 'typeof(-3)': 'int'} -#query -#SELECT substring('Spark SQL'::string, -3::int); - -## Original Query: SELECT substring('Spark SQL', 5); -## PySpark 3.5.5 Result: {'substring(Spark SQL, 5, 2147483647)': 'k SQL', 'typeof(substring(Spark SQL, 5, 2147483647))': 'string', 'typeof(Spark SQL)': 'string', 'typeof(5)': 'int'} -#query -#SELECT substring('Spark SQL'::string, 5::int); - -## Original Query: SELECT substring('Spark SQL', 5, 1); -## PySpark 3.5.5 Result: {'substring(Spark SQL, 5, 1)': 'k', 'typeof(substring(Spark SQL, 5, 1))': 'string', 'typeof(Spark SQL)': 'string', 'typeof(5)': 'int', 'typeof(1)': 'int'} -#query -#SELECT substring('Spark SQL'::string, 5::int, 1::int); + +query T +SELECT substring('Spark SQL'::string, 0::int); +---- +Spark SQL + +query T +SELECT substring('Spark SQL'::string, 5::int); +---- +k SQL + +query T +SELECT substring('Spark SQL'::string, 3::int, 1::int); +---- +a + +# Test negative start +query T +SELECT substring('Spark SQL'::string, -3::int); +---- +SQL + +query T +SELECT substring('Spark SQL'::string, -3::int, 2::int); +---- +SQ + +# Test length exceeding string length +query T +SELECT substring('Spark SQL'::string, 2::int, 700::int); +---- +park SQL + +# Test start position beyond string length +query T +SELECT substring('Spark SQL'::string, 30::int); +---- +(empty) + +query T +SELECT substring('Spark SQL'::string, -30::int); +---- +Spark SQL + +# Test negative length +query T +SELECT substring('Spark SQL'::string, 3::int, -1::int); +---- +(empty) + +query T +SELECT substring('Spark SQL'::string, 3::int, 0::int); +---- +(empty) + +# Test unicode strings +query T +SELECT substring('joséésoj'::string, 5::int); +---- +ésoj + +query 
T +SELECT substring('joséésoj'::string, 5::int, 2::int); +---- +és + +# NULL handling +query T +SELECT substring('Spark SQL'::string, NULL::int); +---- +NULL + +query T +SELECT substring(NULL::string, 5::int); +---- +NULL + +query T +SELECT substring(NULL::string, 3::int, 1::int); +---- +NULL + +query T +SELECT substring('Spark SQL'::string, NULL::int, 1::int); +---- +NULL + +query T +SELECT substring('Spark SQL'::string, 3::int, NULL::int); +---- +NULL + +query T +SELECT substring(column1, column2) +FROM VALUES +('Spark SQL'::string, 0::int), +('Spark SQL'::string, 5::int), +('Spark SQL'::string, -3::int), +('Spark SQL'::string, 500::int), +('Spark SQL'::string, -300::int), +(NULL::string, 5::int), +('Spark SQL'::string, NULL::int); +---- +Spark SQL +k SQL +SQL +(empty) +Spark SQL +NULL +NULL + +query T +SELECT substring(column1, column2, column3) +FROM VALUES +('Spark SQL'::string, -3::int, 2::int), +('Spark SQL'::string, 3::int, 1::int), +('Spark SQL'::string, 3::int, 700::int), +('Spark SQL'::string, 3::int, -1::int), +('Spark SQL'::string, 3::int, 0::int), +('Spark SQL'::string, 300::int, 3::int), +('Spark SQL'::string, -300::int, 3::int), +(NULL::string, 3::int, 1::int), +('Spark SQL'::string, NULL::int, 1::int), +('Spark SQL'::string, 3::int, NULL::int); +---- +SQ +a +ark SQL +(empty) +(empty) +(empty) +Spa +NULL +NULL +NULL + +# alias substr + +query T +SELECT substr('Spark SQL'::string, 0::int); +---- +Spark SQL + +query T +SELECT substr(column1, column2) +FROM VALUES +('Spark SQL'::string, 0::int), +('Spark SQL'::string, 5::int), +('Spark SQL'::string, -3::int), +('Spark SQL'::string, 500::int), +('Spark SQL'::string, -300::int), +(NULL::string, 5::int), +('Spark SQL'::string, NULL::int); +---- +Spark SQL +k SQL +SQL +(empty) +Spark SQL +NULL +NULL + +query T +SELECT substr(column1, column2, column3) +FROM VALUES +('Spark SQL'::string, -3::int, 2::int), +('Spark SQL'::string, 3::int, 1::int), +('Spark SQL'::string, 3::int, 700::int), +('Spark 
SQL'::string, 3::int, -1::int), +('Spark SQL'::string, 3::int, 0::int), +('Spark SQL'::string, 300::int, 3::int), +('Spark SQL'::string, -300::int, 3::int), +(NULL::string, 3::int, 1::int), +('Spark SQL'::string, NULL::int, 1::int), +('Spark SQL'::string, 3::int, NULL::int); +---- +SQ +a +ark SQL +(empty) +(empty) +(empty) +Spa +NULL +NULL +NULL diff --git a/datafusion/sqllogictest/test_files/spark/string/unbase64.slt b/datafusion/sqllogictest/test_files/spark/string/unbase64.slt deleted file mode 100644 index 5cf3fbee0455..000000000000 --- a/datafusion/sqllogictest/test_files/spark/string/unbase64.slt +++ /dev/null @@ -1,27 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# This file was originally created by a porting script from: -# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function -# This file is part of the implementation of the datafusion-spark function library. 
-# For more information, please see: -# https://github.com/apache/datafusion/issues/15914 - -## Original Query: SELECT unbase64('U3BhcmsgU1FM'); -## PySpark 3.5.5 Result: {'unbase64(U3BhcmsgU1FM)': bytearray(b'Spark SQL'), 'typeof(unbase64(U3BhcmsgU1FM))': 'binary', 'typeof(U3BhcmsgU1FM)': 'string'} -#query -#SELECT unbase64('U3BhcmsgU1FM'::string); diff --git a/datafusion/sqllogictest/test_files/string/string_query.slt.part b/datafusion/sqllogictest/test_files/string/string_query.slt.part index a182ba8cde11..2884c3518610 100644 --- a/datafusion/sqllogictest/test_files/string/string_query.slt.part +++ b/datafusion/sqllogictest/test_files/string/string_query.slt.part @@ -993,25 +993,27 @@ NULL NULL NULL NULL # Test FIND_IN_SET # -------------------------------------- -query IIII +query IIIIII SELECT FIND_IN_SET(ascii_1, 'a,b,c,d'), FIND_IN_SET(ascii_1, 'Andrew,Xiangpeng,Raphael'), FIND_IN_SET(unicode_1, 'a,b,c,d'), - FIND_IN_SET(unicode_1, 'datafusion📊🔥,datafusion数据融合,datafusionДатаФусион') + FIND_IN_SET(unicode_1, 'datafusion📊🔥,datafusion数据融合,datafusionДатаФусион'), + FIND_IN_SET(NULL, unicode_1), + FIND_IN_SET(unicode_1, NULL) FROM test_basic_operator; ---- -0 1 0 1 -0 2 0 2 -0 3 0 3 -0 0 0 0 -0 0 0 0 -0 0 0 0 -0 0 0 0 -0 0 0 0 -0 0 0 0 -NULL NULL NULL NULL -NULL NULL NULL NULL +0 1 0 1 NULL NULL +0 2 0 2 NULL NULL +0 3 0 3 NULL NULL +0 0 0 0 NULL NULL +0 0 0 0 NULL NULL +0 0 0 0 NULL NULL +0 0 0 0 NULL NULL +0 0 0 0 NULL NULL +0 0 0 0 NULL NULL +NULL NULL NULL NULL NULL NULL +NULL NULL NULL NULL NULL NULL # -------------------------------------- # Test || operator diff --git a/datafusion/sqllogictest/test_files/struct.slt b/datafusion/sqllogictest/test_files/struct.slt index d985af1104da..e20815a58c76 100644 --- a/datafusion/sqllogictest/test_files/struct.slt +++ b/datafusion/sqllogictest/test_files/struct.slt @@ -38,9 +38,9 @@ CREATE TABLE struct_values ( s1 struct, s2 struct ) AS VALUES - (struct(1), struct(1, 'string1')), - (struct(2), struct(2, 'string2')), - 
(struct(3), struct(3, 'string3')) + (struct(1), struct(1 AS a, 'string1' AS b)), + (struct(2), struct(2 AS a, 'string2' AS b)), + (struct(3), struct(3 AS a, 'string3' AS b)) ; query ?? @@ -397,7 +397,8 @@ drop view complex_view; # struct with different keys r1 and r2 is not valid statement ok -create table t(a struct, b struct) as values (struct('red', 1), struct('blue', 2.3)); +create table t(a struct, b struct) as values + (struct('red' AS r1, 1 AS c), struct('blue' AS r2, 2.3 AS c)); # Expect same keys for struct type but got mismatched pair r1,c and r2,c query error @@ -408,7 +409,8 @@ drop table t; # struct with the same key statement ok -create table t(a struct, b struct) as values (struct('red', 1), struct('blue', 2.3)); +create table t(a struct, b struct) as values + (struct('red' AS r, 1 AS c), struct('blue' AS r, 2.3 AS c)); query T select arrow_typeof([a, b]) from t; @@ -442,9 +444,9 @@ CREATE TABLE struct_values ( s1 struct(a int, b varchar), s2 struct(a int, b varchar) ) AS VALUES - (row(1, 'red'), row(1, 'string1')), - (row(2, 'blue'), row(2, 'string2')), - (row(3, 'green'), row(3, 'string3')) + ({a: 1, b: 'red'}, {a: 1, b: 'string1'}), + ({a: 2, b: 'blue'}, {a: 2, b: 'string2'}), + ({a: 3, b: 'green'}, {a: 3, b: 'string3'}) ; statement ok @@ -452,8 +454,8 @@ drop table struct_values; statement ok create table t (c1 struct(r varchar, b int), c2 struct(r varchar, b float)) as values ( - row('red', 2), - row('blue', 2.3) + {r: 'red', b: 2}, + {r: 'blue', b: 2.3} ); query ?? 
@@ -492,9 +494,6 @@ Struct("r": Utf8, "c": Float64) statement ok drop table t; -query error DataFusion error: Optimizer rule 'simplify_expressions' failed[\s\S]*Arrow error: Cast error: Cannot cast string 'a' to value of Float64 type -create table t as values({r: 'a', c: 1}), ({c: 2.3, r: 'b'}); - ################################## ## Test Coalesce with Struct ################################## @@ -504,9 +503,9 @@ CREATE TABLE t ( s1 struct(a int, b varchar), s2 struct(a float, b varchar) ) AS VALUES - (row(1, 'red'), row(1.1, 'string1')), - (row(2, 'blue'), row(2.2, 'string2')), - (row(3, 'green'), row(33.2, 'string3')) + ({a: 1, b: 'red'}, {a: 1.1, b: 'string1'}), + ({a: 2, b: 'blue'}, {a: 2.2, b: 'string2'}), + ({a: 3, b: 'green'}, {a: 33.2, b: 'string3'}) ; query ? @@ -531,9 +530,9 @@ CREATE TABLE t ( s1 struct(a int, b varchar), s2 struct(a float, b varchar) ) AS VALUES - (row(1, 'red'), row(1.1, 'string1')), - (null, row(2.2, 'string2')), - (row(3, 'green'), row(33.2, 'string3')) + ({a: 1, b: 'red'}, {a: 1.1, b: 'string1'}), + (null, {a: 2.2, b: 'string2'}), + ({a: 3, b: 'green'}, {a: 33.2, b: 'string3'}) ; query ? 
@@ -553,16 +552,12 @@ Struct("a": Float32, "b": Utf8View) statement ok drop table t; -# row() with incorrect order +# row() with incorrect order - row() is positional, not name-based statement error DataFusion error: Optimizer rule 'simplify_expressions' failed[\s\S]*Arrow error: Cast error: Cannot cast string 'blue' to value of Float32 type create table t(a struct(r varchar, c int), b struct(r varchar, c float)) as values - (row('red', 1), row(2.3, 'blue')), - (row('purple', 1), row('green', 2.3)); + ({r: 'red', c: 1}, {r: 2.3, c: 'blue'}), + ({r: 'purple', c: 1}, {r: 'green', c: 2.3}); -# out of order struct literal -# TODO: This query should not fail -statement error DataFusion error: Optimizer rule 'simplify_expressions' failed[\s\S]*Arrow error: Cast error: Cannot cast string 'b' to value of Int32 type -create table t(a struct(r varchar, c int)) as values ({r: 'a', c: 1}), ({c: 2, r: 'b'}); ################################## ## Test Array of Struct @@ -573,12 +568,9 @@ select [{r: 'a', c: 1}, {r: 'b', c: 2}]; ---- [{r: a, c: 1}, {r: b, c: 2}] -# Can't create a list of struct with different field types -query error -select [{r: 'a', c: 1}, {c: 2, r: 'b'}]; statement ok -create table t(a struct(r varchar, c int), b struct(r varchar, c float)) as values (row('a', 1), row('b', 2.3)); +create table t(a struct(r varchar, c int), b struct(r varchar, c float)) as values ({r: 'a', c: 1}, {r: 'b', c: 2.3}); query T select arrow_typeof([a, b]) from t; @@ -588,27 +580,17 @@ List(Struct("r": Utf8View, "c": Float32)) statement ok drop table t; -# create table with different struct type is fine -statement ok -create table t(a struct(r varchar, c int), b struct(c float, r varchar)) as values (row('a', 1), row(2.3, 'b')); - -# create array with different struct type is not valid -query error -select arrow_typeof([a, b]) from t; - -statement ok -drop table t; statement ok -create table t(a struct(r varchar, c int, g float), b struct(r varchar, c float, g int)) as values 
(row('a', 1, 2.3), row('b', 2.3, 2)); +create table t(a struct(r varchar, c int, g float), b struct(r varchar, c float, g int)) as values ({r: 'a', c: 1, g: 2.3}, {r: 'b', c: 2.3, g: 2}); -# type of each column should not coerced but perserve as it is +# type of each column should not coerced but preserve as it is query T select arrow_typeof(a) from t; ---- Struct("r": Utf8View, "c": Int32, "g": Float32) -# type of each column should not coerced but perserve as it is +# type of each column should not coerced but preserve as it is query T select arrow_typeof(b) from t; ---- @@ -622,7 +604,7 @@ drop table t; # This tests accessing struct fields using the subscript notation with string literals statement ok -create table test (struct_field struct(substruct int)) as values (struct(1)); +create table test (struct_field struct(substruct int)) as values ({substruct: 1}); query ?? select * @@ -635,7 +617,7 @@ statement ok DROP TABLE test; statement ok -create table test (struct_field struct(substruct struct(subsubstruct int))) as values (struct(struct(1))); +create table test (struct_field struct(substruct struct(subsubstruct int))) as values ({substruct: {subsubstruct: 1}}); query ?? select * @@ -824,3 +806,864 @@ NULL statement ok drop table nullable_parent_test; + +# Test struct casting with field reordering - string fields +query ? +SELECT CAST({b: 'b_value', a: 'a_value'} AS STRUCT(a VARCHAR, b VARCHAR)); +---- +{a: a_value, b: b_value} + +# Test struct casting with field reordering - integer fields +query ? +SELECT CAST({b: 3, a: 4} AS STRUCT(a INT, b INT)); +---- +{a: 4, b: 3} + +# Test with type casting AND field reordering +query ? +SELECT CAST({b: 3, a: 4} AS STRUCT(a BIGINT, b INT)); +---- +{a: 4, b: 3} + +# Test casting with explicit field names +query ? +SELECT CAST({a: 1, b: 'x'} AS STRUCT(a INT, b VARCHAR)); +---- +{a: 1, b: x} + +# Test with missing field - should insert nulls +query ? 
+SELECT CAST({a: 1} AS STRUCT(a INT, b INT)); +---- +{a: 1, b: NULL} + +# Test with extra source field - should be ignored +query ? +SELECT CAST({a: 1, b: 2, extra: 3} AS STRUCT(a INT, b INT)); +---- +{a: 1, b: 2} + +# Test no overlap with mismatched field count - should fail because no field names match +statement error DataFusion error: (Plan error|Error during planning|This feature is not implemented): (Cannot cast struct: at least one field name must match between source and target|Cannot cast struct with 3 fields to 2 fields without name overlap|Unsupported CAST from Struct) +SELECT CAST(struct(1, 'x', 'y') AS STRUCT(a INT, b VARCHAR)); + +# Test nested struct with field reordering +query ? +SELECT CAST( + {inner: {y: 2, x: 1}} + AS STRUCT(inner STRUCT(x INT, y INT)) +); +---- +{inner: {x: 1, y: 2}} + +# Test field reordering with table data +statement ok +CREATE TABLE struct_reorder_test ( + data STRUCT(b INT, a VARCHAR) +) AS VALUES + ({b: 100, a: 'first'}), + ({b: 200, a: 'second'}), + ({b: 300, a: 'third'}) +; + +query ? +SELECT CAST(data AS STRUCT(a VARCHAR, b INT)) AS casted_data FROM struct_reorder_test ORDER BY data['b']; +---- +{a: first, b: 100} +{a: second, b: 200} +{a: third, b: 300} + +statement ok +drop table struct_reorder_test; + +# Test casting struct with multiple levels of nesting and reordering +query ? +SELECT CAST( + {level1: {z: 100, y: 'inner', x: 1}} + AS STRUCT(level1 STRUCT(x INT, y VARCHAR, z INT)) +); +---- +{level1: {x: 1, y: inner, z: 100}} + +# Test field reordering with nulls in source +query ? +SELECT CAST( + {b: NULL::INT, a: 42} + AS STRUCT(a INT, b INT) +); +---- +{a: 42, b: NULL} + +# Test casting preserves struct-level nulls +query ? 
+SELECT CAST(NULL::STRUCT(b INT, a INT) AS STRUCT(a INT, b INT)); +---- +NULL + +############################ +# Implicit Coercion Tests with CREATE TABLE AS VALUES +############################ + +# Test implicit coercion with same field order, different types +statement ok +create table t as values({r: 'a', c: 1}), ({r: 'b', c: 2.3}); + +query T +select arrow_typeof(column1) from t limit 1; +---- +Struct("r": Utf8, "c": Float64) + +query ? +select * from t order by column1.r; +---- +{r: a, c: 1.0} +{r: b, c: 2.3} + +statement ok +drop table t; + +# Test implicit coercion with nullable fields (same order) +statement ok +create table t as values({a: 1, b: 'x'}), ({a: 2, b: 'y'}); + +query T +select arrow_typeof(column1) from t limit 1; +---- +Struct("a": Int64, "b": Utf8) + +query ? +select * from t order by column1.a; +---- +{a: 1, b: x} +{a: 2, b: y} + +statement ok +drop table t; + +# Test implicit coercion with nested structs (same field order) +statement ok +create table t as + select {outer: {x: 1, y: 2}} as column1 + union all + select {outer: {x: 3, y: 4}}; + +query T +select arrow_typeof(column1) from t limit 1; +---- +Struct("outer": Struct("x": Int64, "y": Int64)) + +query ? +select column1 from t order by column1.outer.x; +---- +{outer: {x: 1, y: 2}} +{outer: {x: 3, y: 4}} + +statement ok +drop table t; + +# Test implicit coercion with type widening (Int32 -> Int64) +statement ok +create table t as values({id: 1, val: 100}), ({id: 2, val: 9223372036854775807}); + +query T +select arrow_typeof(column1) from t limit 1; +---- +Struct("id": Int64, "val": Int64) + +query ? 
+select * from t order by column1.id; +---- +{id: 1, val: 100} +{id: 2, val: 9223372036854775807} + +statement ok +drop table t; + +# Test implicit coercion with nested struct and type coercion +statement ok +create table t as + select {name: 'Alice', data: {score: 100, active: true}} as column1 + union all + select {name: 'Bob', data: {score: 200, active: false}}; + +query T +select arrow_typeof(column1) from t limit 1; +---- +Struct("name": Utf8, "data": Struct("score": Int64, "active": Boolean)) + +query ? +select column1 from t order by column1.name; +---- +{name: Alice, data: {score: 100, active: true}} +{name: Bob, data: {score: 200, active: false}} + +statement ok +drop table t; + +############################ +# Field Reordering Tests (using explicit CAST) +############################ + +# Test explicit cast with field reordering in VALUES - basic case +query ? +select CAST({c: 2.3, r: 'b'} AS STRUCT(r VARCHAR, c FLOAT)); +---- +{r: b, c: 2.3} + +# Test explicit cast with field reordering - multiple rows +query ? +select * from (values + (CAST({c: 1, r: 'a'} AS STRUCT(r VARCHAR, c FLOAT))), + (CAST({c: 2.3, r: 'b'} AS STRUCT(r VARCHAR, c FLOAT))) +) order by column1.r; +---- +{r: a, c: 1.0} +{r: b, c: 2.3} + +# Test table with explicit cast for field reordering +statement ok +create table t as select CAST({c: 1, r: 'a'} AS STRUCT(r VARCHAR, c FLOAT)) as s +union all +select CAST({c: 2.3, r: 'b'} AS STRUCT(r VARCHAR, c FLOAT)); + +query T +select arrow_typeof(s) from t limit 1; +---- +Struct("r": Utf8View, "c": Float32) + +query ? +select * from t order by s.r; +---- +{r: a, c: 1.0} +{r: b, c: 2.3} + +statement ok +drop table t; + +# Test field reordering with nullable fields using CAST +query ? +select CAST({b: NULL, a: 42} AS STRUCT(a INT, b INT)); +---- +{a: 42, b: NULL} + +# Test field reordering with nested structs using CAST +query ? 
+select CAST({outer: {y: 4, x: 3}} AS STRUCT(outer STRUCT(x INT, y INT))); +---- +{outer: {x: 3, y: 4}} + +# Test complex nested field reordering +query ? +select CAST( + {data: {active: false, score: 200}, name: 'Bob'} + AS STRUCT(name VARCHAR, data STRUCT(score INT, active BOOLEAN)) +); +---- +{name: Bob, data: {score: 200, active: false}} + +############################ +# Array Literal Tests with Struct Field Reordering (Implicit Coercion) +############################ + +# Test array literal with reordered struct fields - implicit coercion by name +# Field order in unified schema is determined during type coercion +query ? +select [{r: 'a', c: 1}, {c: 2.3, r: 'b'}]; +---- +[{c: 1.0, r: a}, {c: 2.3, r: b}] + +# Test array literal with same-named fields but different order +# Fields are reordered during coercion +query ? +select [{a: 1, b: 2}, {b: 3, a: 4}]; +---- +[{b: 2, a: 1}, {b: 3, a: 4}] + +# Test array literal with explicit cast to unify struct schemas with partial overlap +# Use CAST to explicitly unify schemas when fields don't match completely +query ? +select [ + CAST({a: 1, b: 2} AS STRUCT(a INT, b INT, c INT)), + CAST({b: 3, c: 4} AS STRUCT(a INT, b INT, c INT)) +]; +---- +[{a: 1, b: 2, c: NULL}, {a: NULL, b: 3, c: 4}] + +# Test NULL handling in array literals with reordered but matching fields +query ? +select [{a: NULL, b: 1}, {b: 2, a: NULL}]; +---- +[{b: 1, a: NULL}, {b: 2, a: NULL}] + +# Verify arrow_typeof for array with reordered struct fields +# The unified schema type follows the coercion order +query T +select arrow_typeof([{x: 1, y: 2}, {y: 3, x: 4}]); +---- +List(Struct("y": Int64, "x": Int64)) + +# Test array of structs with matching nested fields in different order +# Inner nested fields are also reordered during coercion +query ? 
+select [ + {id: 1, info: {name: 'Alice', age: 30}}, + {info: {age: 25, name: 'Bob'}, id: 2} +]; +---- +[{info: {age: 30, name: Alice}, id: 1}, {info: {age: 25, name: Bob}, id: 2}] + +# Test nested arrays with matching struct fields (different order) +query ? +select [[{x: 1, y: 2}], [{y: 4, x: 3}]]; +---- +[[{x: 1, y: 2}], [{x: 3, y: 4}]] + +# Test array literal with float type coercion across elements +query ? +select [{val: 1}, {val: 2.5}]; +---- +[{val: 1.0}, {val: 2.5}] + +############################ +# Dynamic Array Construction Tests (from Table Columns) +############################ + +# Setup test table with struct columns for dynamic array construction +statement ok +create table t_complete_overlap ( + s1 struct(x int, y int), + s2 struct(y int, x int) +) as values + ({x: 1, y: 2}, {y: 3, x: 4}), + ({x: 5, y: 6}, {y: 7, x: 8}); + +# Test 1: Complete overlap - same fields, different order +# Verify arrow_typeof for dynamically created array +query T +select arrow_typeof([s1, s2]) from t_complete_overlap limit 1; +---- +List(Struct("y": Int32, "x": Int32)) + +# Verify values are correctly mapped by name in the array +# Field order follows the second column's field order +query ? 
+select [s1, s2] from t_complete_overlap order by s1.x; +---- +[{y: 2, x: 1}, {y: 3, x: 4}] +[{y: 6, x: 5}, {y: 7, x: 8}] + +statement ok +drop table t_complete_overlap; + +# Test 2: Partial overlap - some shared fields between columns +# Note: Columns must have the exact same field set for array construction to work +# Test with identical field set (all fields present in both columns) +statement ok +create table t_partial_overlap ( + col_a struct(name VARCHAR, age int, active boolean), + col_b struct(age int, name VARCHAR, active boolean) +) as values + ({name: 'Alice', age: 30, active: true}, {age: 25, name: 'Bob', active: false}), + ({name: 'Charlie', age: 35, active: true}, {age: 40, name: 'Diana', active: false}); + +# Verify unified type includes all fields from both structs +query T +select arrow_typeof([col_a, col_b]) from t_partial_overlap limit 1; +---- +List(Struct("age": Int32, "name": Utf8View, "active": Boolean)) + +# Verify values are correctly mapped by name in the array +# Field order follows the second column's field order +query ? +select [col_a, col_b] from t_partial_overlap order by col_a.name; +---- +[{age: 30, name: Alice, active: true}, {age: 25, name: Bob, active: false}] +[{age: 35, name: Charlie, active: true}, {age: 40, name: Diana, active: false}] + +statement ok +drop table t_partial_overlap; + +# Test 3: Complete field set matching (no CAST needed) +# Schemas already align; confirm unified type and values +statement ok +create table t_with_cast ( + col_x struct(id int, description VARCHAR), + col_y struct(id int, description VARCHAR) +) as values + ({id: 1, description: 'First'}, {id: 10, description: 'First Value'}), + ({id: 2, description: 'Second'}, {id: 20, description: 'Second Value'}); + +# Verify type unification with all fields +query T +select arrow_typeof([col_x, col_y]) from t_with_cast limit 1; +---- +List(Struct("id": Int32, "description": Utf8View)) + +# Verify values remain aligned by name +query ? 
+select [col_x, col_y] from t_with_cast order by col_x.id; +---- +[{id: 1, description: First}, {id: 10, description: First Value}] +[{id: 2, description: Second}, {id: 20, description: Second Value}] + +statement ok +drop table t_with_cast; + +# Test 4: Explicit CAST for partial field overlap scenarios +# When columns have different field sets, use explicit CAST to unify schemas +query ? +select [ + CAST({id: 1} AS STRUCT(id INT, description VARCHAR)), + CAST({id: 10, description: 'Value'} AS STRUCT(id INT, description VARCHAR)) +]; +---- +[{id: 1, description: NULL}, {id: 10, description: Value}] + +# Test 5: Complex nested structs with field reordering +# Nested fields must have the exact same field set for array construction +statement ok +create table t_nested ( + col_1 struct(id int, outer struct(x int, y int)), + col_2 struct(id int, outer struct(x int, y int)) +) as values + ({id: 100, outer: {x: 1, y: 2}}, {id: 101, outer: {x: 4, y: 3}}), + ({id: 200, outer: {x: 5, y: 6}}, {id: 201, outer: {x: 8, y: 7}}); + +# Verify nested struct in unified schema +query T +select arrow_typeof([col_1, col_2]) from t_nested limit 1; +---- +List(Struct("id": Int32, "outer": Struct("x": Int32, "y": Int32))) + +# Verify nested field values are correctly mapped +query ? +select [col_1, col_2] from t_nested order by col_1.id; +---- +[{id: 100, outer: {x: 1, y: 2}}, {id: 101, outer: {x: 4, y: 3}}] +[{id: 200, outer: {x: 5, y: 6}}, {id: 201, outer: {x: 8, y: 7}}] + +statement ok +drop table t_nested; + +# Test 6: NULL handling with matching field sets +statement ok +create table t_nulls ( + col_a struct(val int, flag boolean), + col_b struct(val int, flag boolean) +) as values + ({val: 1, flag: true}, {val: 10, flag: false}), + ({val: NULL, flag: false}, {val: NULL, flag: true}); + +# Verify NULL values are preserved +query ? 
+select [col_a, col_b] from t_nulls order by col_a.val; +---- +[{val: 1, flag: true}, {val: 10, flag: false}] +[{val: NULL, flag: false}, {val: NULL, flag: true}] + +statement ok +drop table t_nulls; + +# Test 7: Multiple columns with complete field matching +statement ok +create table t_multi ( + col1 struct(a int, b int, c int), + col2 struct(a int, b int, c int) +) as values + ({a: 1, b: 2, c: 3}, {a: 10, b: 20, c: 30}), + ({a: 4, b: 5, c: 6}, {a: 40, b: 50, c: 60}); + +# Verify array with complete field matching +query T +select arrow_typeof([col1, col2]) from t_multi limit 1; +---- +List(Struct("a": Int32, "b": Int32, "c": Int32)) + +# Verify values are correctly unified +query ? +select [col1, col2] from t_multi order by col1.a; +---- +[{a: 1, b: 2, c: 3}, {a: 10, b: 20, c: 30}] +[{a: 4, b: 5, c: 6}, {a: 40, b: 50, c: 60}] + +statement ok +drop table t_multi; + +############################ +# Comprehensive Implicit Struct Coercion Suite +############################ + +# Test 1: VALUES clause with field reordering coerced by name into declared schema +statement ok +create table implicit_values_reorder ( + s struct(a int, b int) +) as values + ({a: 1, b: 2}), + ({b: 3, a: 4}); + +query T +select arrow_typeof(s) from implicit_values_reorder limit 1; +---- +Struct("a": Int32, "b": Int32) + +query ? 
+select s from implicit_values_reorder order by s.a; +---- +{a: 1, b: 2} +{a: 4, b: 3} + +statement ok +drop table implicit_values_reorder; + +# Test 2: Array literal coercion with reordered struct fields +query IIII +select + [{a: 1, b: 2}, {b: 3, a: 4}][1]['a'], + [{a: 1, b: 2}, {b: 3, a: 4}][1]['b'], + [{a: 1, b: 2}, {b: 3, a: 4}][2]['a'], + [{a: 1, b: 2}, {b: 3, a: 4}][2]['b']; +---- +1 2 4 3 + +# Test 3: Array construction from columns with reordered struct fields +statement ok +create table struct_columns_order ( + s1 struct(a int, b int), + s2 struct(b int, a int) +) as values + ({a: 1, b: 2}, {b: 3, a: 4}), + ({a: 5, b: 6}, {b: 7, a: 8}); + +query IIII +select + [s1, s2][1]['a'], + [s1, s2][1]['b'], + [s1, s2][2]['a'], + [s1, s2][2]['b'] +from struct_columns_order +order by s1['a']; +---- +1 2 4 3 +5 6 8 7 + +statement ok +drop table struct_columns_order; + +# Test 4: UNION with struct field reordering +query II +select s['a'], s['b'] +from ( + select {a: 1, b: 2} as s + union all + select {b: 3, a: 4} as s +) t +order by s['a']; +---- +1 2 +4 3 + +# Test 5: CTE with struct coercion across branches +query II +with + t1 as (select {a: 1, b: 2} as s), + t2 as (select {b: 3, a: 4} as s) +select t1.s['a'], t1.s['b'] from t1 +union all +select t2.s['a'], t2.s['b'] from t2 +order by 1; +---- +1 2 +4 3 + +# Test 6: Struct aggregation retains name-based mapping +statement ok +create table agg_structs_reorder ( + k int, + s struct(x int, y int) +) as values + (1, {x: 1, y: 2}), + (1, {y: 3, x: 4}), + (2, {x: 5, y: 6}); + +query I? 
+select k, array_agg(s) from agg_structs_reorder group by k order by k; +---- +1 [{x: 1, y: 2}, {x: 4, y: 3}] +2 [{x: 5, y: 6}] + +statement ok +drop table agg_structs_reorder; + +# Test 7: Nested struct coercion with reordered inner fields +query IIII +with nested as ( + select [{outer: {inner: 1, value: 2}}, {outer: {value: 3, inner: 4}}] as arr +) +select + arr[1]['outer']['inner'], + arr[1]['outer']['value'], + arr[2]['outer']['inner'], + arr[2]['outer']['value'] +from nested; +---- +1 2 4 3 + +# Test 8: Partial name overlap - currently errors (field count mismatch detected) +# This is a documented limitation: structs must have exactly same field set for coercion +query error DataFusion error: Error during planning: Inconsistent data type across values list +select column1 from (values ({a: 1, b: 2}), ({b: 3, c: 4})) order by column1['a']; + +# Negative test: mismatched struct field counts are rejected (documented limitation) +query error DataFusion error: .* +select [{a: 1}, {a: 2, b: 3}]; + +# Test 9: INSERT with name-based struct coercion into target schema +statement ok +create table target_struct_insert (s struct(a int, b int)); + +statement ok +insert into target_struct_insert values ({b: 1, a: 2}); + +query ? 
+select s from target_struct_insert; +---- +{a: 2, b: 1} + +statement ok +drop table target_struct_insert; + +# Test 10: CASE expression with different struct field orders +query II +select + (case when true then {a: 1, b: 2} else {b: 3, a: 4} end)['a'] as a_val, + (case when true then {a: 1, b: 2} else {b: 3, a: 4} end)['b'] as b_val; +---- +1 2 + +############################ +# JOIN Coercion Tests +############################ + +# Test: Struct coercion in JOIN ON condition +statement ok +create table t_left ( + id int, + s struct(x int, y int) +) as values + (1, {x: 1, y: 2}), + (2, {x: 3, y: 4}); + +statement ok +create table t_right ( + id int, + s struct(y int, x int) +) as values + (1, {y: 2, x: 1}), + (2, {y: 4, x: 3}); + +# JOIN on reordered struct fields - matched by name +query IIII +select t_left.id, t_left.s['x'], t_left.s['y'], t_right.id +from t_left +join t_right on t_left.s = t_right.s +order by t_left.id; +---- +1 1 2 1 +2 3 4 2 + +statement ok +drop table t_left; + +statement ok +drop table t_right; + +# Test: Struct coercion with filtered JOIN +statement ok +create table orders ( + order_id int, + customer struct(name varchar, id int) +) as values + (1, {name: 'Alice', id: 100}), + (2, {name: 'Bob', id: 101}), + (3, {name: 'Charlie', id: 102}); + +statement ok +create table customers ( + customer_id int, + info struct(id int, name varchar) +) as values + (100, {id: 100, name: 'Alice'}), + (101, {id: 101, name: 'Bob'}), + (103, {id: 103, name: 'Diana'}); + +# Join with struct field reordering - names matched, not positions +query I +select count(*) from orders +join customers on orders.customer = customers.info +where orders.order_id <= 2; +---- +2 + +statement ok +drop table orders; + +statement ok +drop table customers; + +############################ +# WHERE Predicate Coercion Tests +############################ + +# Test: Struct equality in WHERE clause with field reordering +statement ok +create table t_where ( + id int, + s struct(x int, 
y int) +) as values + (1, {x: 1, y: 2}), + (2, {x: 3, y: 4}), + (3, {x: 1, y: 2}); + +# WHERE clause with struct comparison - coerced by name +query I +select id from t_where +where s = {y: 2, x: 1} +order by id; +---- +1 +3 + +statement ok +drop table t_where; + +# Test: Struct IN clause with reordering +statement ok +create table t_in ( + id int, + s struct(a int, b varchar) +) as values + (1, {a: 1, b: 'x'}), + (2, {a: 2, b: 'y'}), + (3, {a: 1, b: 'x'}); + +# IN clause with reordered struct literals +query I +select id from t_in +where s in ({b: 'x', a: 1}, {b: 'y', a: 2}) +order by id; +---- +1 +2 +3 + +statement ok +drop table t_in; + +# Test: Struct BETWEEN (not supported, but documents limitation) +# Structs don't support BETWEEN, but can use comparison operators + +statement ok +create table t_between ( + id int, + s struct(val int) +) as values + (1, {val: 10}), + (2, {val: 20}), + (3, {val: 30}); + +# Comparison via field extraction works +query I +select id from t_between +where s['val'] >= 20 +order by id; +---- +2 +3 + +statement ok +drop table t_between; + +############################ +# Window Function Coercion Tests +############################ + +# Test: Struct in window function PARTITION BY +statement ok +create table t_window ( + id int, + s struct(category int, value int) +) as values + (1, {category: 1, value: 10}), + (2, {category: 1, value: 20}), + (3, {category: 2, value: 30}), + (4, {category: 2, value: 40}); + +# Window partition on struct field via extraction +query III +select + id, + s['value'], + row_number() over (partition by s['category'] order by s['value']) +from t_window +order by id; +---- +1 10 1 +2 20 2 +3 30 1 +4 40 2 + +statement ok +drop table t_window; + +# Test: Struct in window function ORDER BY with coercion +statement ok +create table t_rank ( + id int, + s struct(rank_val int, group_id int) +) as values + (1, {rank_val: 10, group_id: 1}), + (2, {rank_val: 20, group_id: 1}), + (3, {rank_val: 15, group_id: 2}); + +# 
Window ranking with struct field extraction +query III +select + id, + s['rank_val'], + rank() over (partition by s['group_id'] order by s['rank_val']) +from t_rank +order by id; +---- +1 10 1 +2 20 2 +3 15 1 + +statement ok +drop table t_rank; + +# Test: Aggregate function with struct coercion across window partitions +statement ok +create table t_agg_window ( + id int, + partition_id int, + s struct(amount int) +) as values + (1, 1, {amount: 100}), + (2, 1, {amount: 200}), + (3, 2, {amount: 150}); + +# Running sum via extracted struct field +query III +select + id, + partition_id, + sum(s['amount']) over (partition by partition_id order by id) +from t_agg_window +order by id; +---- +1 1 100 +2 1 300 +3 2 150 + +statement ok +drop table t_agg_window; \ No newline at end of file diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index da0bfc89d584..9c7c2ddb5d85 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -430,7 +430,7 @@ SELECT t1_id, t1_name, t1_int, (select t2_id, t2_name FROM t2 WHERE t2.t2_id = t #subquery_not_allowed #In/Exist Subquery is not allowed in ORDER BY clause. 
-statement error DataFusion error: Invalid \(non-executable\) plan after Analyzer\ncaused by\nError during planning: In/Exist subquery can only be used in Projection, Filter, TableScan, Window functions, Aggregate and Join plan nodes, but was used in \[Sort: t1.t1_int IN \(\) ASC NULLS LAST\] +statement error DataFusion error: Invalid \(non-executable\) plan after Analyzer\ncaused by\nError during planning: In/Exist/SetComparison subquery can only be used in Projection, Filter, TableScan, Window functions, Aggregate and Join plan nodes, but was used in \[Sort: t1.t1_int IN \(\) ASC NULLS LAST\] SELECT t1_id, t1_name, t1_int FROM t1 order by t1_int in (SELECT t2_int FROM t2 WHERE t1.t1_id > t1.t1_int) #non_aggregated_correlated_scalar_subquery @@ -1469,3 +1469,198 @@ logical_plan statement count 0 drop table person; + +# Set comparison subqueries (ANY/ALL) +statement ok +create table set_cmp_t(v int) as values (1), (6), (10); + +statement ok +create table set_cmp_s(v int) as values (5), (null); + +statement ok +create table set_cmp_empty(v int); + +query I rowsort +select v from set_cmp_t where v > any(select v from set_cmp_s); +---- +10 +6 + +query I rowsort +select v from set_cmp_t where v < all(select v from set_cmp_empty); +---- +1 +10 +6 + +statement count 0 +drop table set_cmp_t; + +statement count 0 +drop table set_cmp_s; + +statement count 0 +drop table set_cmp_empty; + +query TT +explain select v from (values (1), (6), (10)) set_cmp_t(v) where v > any(select v from (values (5), (null)) set_cmp_s(v)); +---- +logical_plan +01)Projection: set_cmp_t.v +02)--Filter: __correlated_sq_1.mark OR __correlated_sq_2.mark AND NOT __correlated_sq_3.mark AND Boolean(NULL) +03)----LeftMark Join: Filter: set_cmp_t.v > __correlated_sq_3.v IS TRUE +04)------Filter: __correlated_sq_1.mark OR __correlated_sq_2.mark AND Boolean(NULL) +05)--------LeftMark Join: Filter: set_cmp_t.v > __correlated_sq_2.v IS NULL +06)----------Filter: __correlated_sq_1.mark OR Boolean(NULL) 
+07)------------LeftMark Join: Filter: set_cmp_t.v > __correlated_sq_1.v IS TRUE +08)--------------SubqueryAlias: set_cmp_t +09)----------------Projection: column1 AS v +10)------------------Values: (Int64(1)), (Int64(6)), (Int64(10)) +11)--------------SubqueryAlias: __correlated_sq_1 +12)----------------SubqueryAlias: set_cmp_s +13)------------------Projection: column1 AS v +14)--------------------Values: (Int64(5)), (Int64(NULL)) +15)----------SubqueryAlias: __correlated_sq_2 +16)------------SubqueryAlias: set_cmp_s +17)--------------Projection: column1 AS v +18)----------------Values: (Int64(5)), (Int64(NULL)) +19)------SubqueryAlias: __correlated_sq_3 +20)--------SubqueryAlias: set_cmp_s +21)----------Projection: column1 AS v +22)------------Values: (Int64(5)), (Int64(NULL)) + +# correlated_recursive_scalar_subquery_with_level_3_exists_subquery_referencing_level1_relation +query TT +explain select c_custkey from customer +where c_acctbal < ( + select sum(o_totalprice) from orders + where o_custkey = c_custkey + and exists ( + select * from lineitem where l_orderkey = o_orderkey + and l_extendedprice < c_acctbal + ) +) order by c_custkey; +---- +logical_plan +01)Sort: customer.c_custkey ASC NULLS LAST +02)--Projection: customer.c_custkey +03)----Inner Join: customer.c_custkey = __scalar_sq_2.o_custkey Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_2.sum(orders.o_totalprice) +04)------TableScan: customer projection=[c_custkey, c_acctbal] +05)------SubqueryAlias: __scalar_sq_2 +06)--------Projection: sum(orders.o_totalprice), orders.o_custkey +07)----------Aggregate: groupBy=[[orders.o_custkey]], aggr=[[sum(orders.o_totalprice)]] +08)------------Projection: orders.o_custkey, orders.o_totalprice +09)--------------LeftSemi Join: orders.o_orderkey = __correlated_sq_1.l_orderkey Filter: __correlated_sq_1.l_extendedprice < customer.c_acctbal +10)----------------TableScan: orders projection=[o_orderkey, o_custkey, o_totalprice] 
+11)----------------SubqueryAlias: __correlated_sq_1 +12)------------------TableScan: lineitem projection=[l_orderkey, l_extendedprice] + +# correlated_recursive_scalar_subquery_with_level_3_in_subquery_referencing_level1_relation +query TT +explain select c_custkey from customer +where c_acctbal < ( + select sum(o_totalprice) from orders + where o_custkey = c_custkey + and o_totalprice in ( + select l_extendedprice as price from lineitem where l_orderkey = o_orderkey + and l_extendedprice < c_acctbal + ) +) order by c_custkey; +---- +logical_plan +01)Sort: customer.c_custkey ASC NULLS LAST +02)--Projection: customer.c_custkey +03)----Inner Join: customer.c_custkey = __scalar_sq_2.o_custkey Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_2.sum(orders.o_totalprice) +04)------TableScan: customer projection=[c_custkey, c_acctbal] +05)------SubqueryAlias: __scalar_sq_2 +06)--------Projection: sum(orders.o_totalprice), orders.o_custkey +07)----------Aggregate: groupBy=[[orders.o_custkey]], aggr=[[sum(orders.o_totalprice)]] +08)------------Projection: orders.o_custkey, orders.o_totalprice +09)--------------LeftSemi Join: orders.o_totalprice = __correlated_sq_1.price, orders.o_orderkey = __correlated_sq_1.l_orderkey Filter: __correlated_sq_1.l_extendedprice < customer.c_acctbal +10)----------------TableScan: orders projection=[o_orderkey, o_custkey, o_totalprice] +11)----------------SubqueryAlias: __correlated_sq_1 +12)------------------Projection: lineitem.l_extendedprice AS price, lineitem.l_extendedprice, lineitem.l_orderkey +13)--------------------TableScan: lineitem projection=[l_orderkey, l_extendedprice] + +# Setup tables for recursive correlation tests +statement ok +CREATE TABLE employees ( + employee_id INTEGER, + employee_name VARCHAR, + dept_id INTEGER, + salary DECIMAL +); + +statement ok +CREATE TABLE project_assignments ( + project_id INTEGER, + employee_id INTEGER, + priority INTEGER +); + +# Provided recursive scalar subquery explain 
case +query TT +EXPLAIN SELECT e1.employee_name, e1.salary +FROM employees e1 +WHERE e1.salary > ( + SELECT AVG(e2.salary) + FROM employees e2 + WHERE e2.dept_id = e1.dept_id + AND e2.salary > ( + SELECT AVG(e3.salary) + FROM employees e3 + WHERE e3.dept_id = e1.dept_id + ) +); +---- +logical_plan +01)Projection: e1.employee_name, e1.salary +02)--Inner Join: e1.dept_id = __scalar_sq_1.dept_id Filter: CAST(e1.salary AS Decimal128(38, 14)) > __scalar_sq_1.avg(e2.salary) +03)----SubqueryAlias: e1 +04)------TableScan: employees projection=[employee_name, dept_id, salary] +05)----SubqueryAlias: __scalar_sq_1 +06)------Projection: avg(e2.salary), e2.dept_id +07)--------Aggregate: groupBy=[[e2.dept_id]], aggr=[[avg(e2.salary)]] +08)----------Projection: e2.dept_id, e2.salary +09)------------Inner Join: Filter: CAST(e2.salary AS Decimal128(38, 14)) > __scalar_sq_2.avg(e3.salary) AND __scalar_sq_2.dept_id = e1.dept_id +10)--------------SubqueryAlias: e2 +11)----------------TableScan: employees projection=[dept_id, salary] +12)--------------SubqueryAlias: __scalar_sq_2 +13)----------------Projection: avg(e3.salary), e3.dept_id +14)------------------Aggregate: groupBy=[[e3.dept_id]], aggr=[[avg(e3.salary)]] +15)--------------------SubqueryAlias: e3 +16)----------------------TableScan: employees projection=[dept_id, salary] + +# Check shadowing: `dept_id` should resolve to the nearest outer relation (`e2`) +# in the innermost subquery rather than the outermost +query TT +EXPLAIN SELECT e1.employee_id +FROM employees e1 +WHERE EXISTS ( + SELECT 1 + FROM employees e2 + WHERE EXISTS ( + SELECT 1 + FROM project_assignments p + WHERE p.project_id = dept_id + ) +); +---- +logical_plan +01)LeftSemi Join: +02)--SubqueryAlias: e1 +03)----TableScan: employees projection=[employee_id] +04)--SubqueryAlias: __correlated_sq_2 +05)----Projection: +06)------LeftSemi Join: e2.dept_id = __correlated_sq_1.project_id +07)--------SubqueryAlias: e2 +08)----------TableScan: employees 
projection=[dept_id] +09)--------SubqueryAlias: __correlated_sq_1 +10)----------SubqueryAlias: p +11)------------TableScan: project_assignments projection=[project_id] + +statement count 0 +drop table employees; + +statement count 0 +drop table project_assignments; diff --git a/datafusion/sqllogictest/test_files/table_functions.slt b/datafusion/sqllogictest/test_files/table_functions.slt index cf8a091880d3..f0e00ffc6923 100644 --- a/datafusion/sqllogictest/test_files/table_functions.slt +++ b/datafusion/sqllogictest/test_files/table_functions.slt @@ -160,17 +160,20 @@ physical_plan LazyMemoryExec: partitions=1, batch_generators=[generate_series: s # Test generate_series with invalid arguments # -query error DataFusion error: Error during planning: Start is bigger than end, but increment is positive: Cannot generate infinite series +query I SELECT * FROM generate_series(5, 1) +---- -query error DataFusion error: Error during planning: Start is smaller than end, but increment is negative: Cannot generate infinite series +query I SELECT * FROM generate_series(-6, 6, -1) +---- query error DataFusion error: Error during planning: Step cannot be zero SELECT * FROM generate_series(-6, 6, 0) -query error DataFusion error: Error during planning: Start is bigger than end, but increment is positive: Cannot generate infinite series +query I SELECT * FROM generate_series(6, -6, 1) +---- statement error DataFusion error: Error during planning: generate_series function requires 1 to 3 arguments @@ -298,17 +301,20 @@ physical_plan LazyMemoryExec: partitions=1, batch_generators=[range: start=1, en # Test range with invalid arguments # -query error DataFusion error: Error during planning: Start is bigger than end, but increment is positive: Cannot generate infinite series +query I SELECT * FROM range(5, 1) +---- -query error DataFusion error: Error during planning: Start is smaller than end, but increment is negative: Cannot generate infinite series +query I SELECT * FROM range(-6, 
6, -1) +---- query error DataFusion error: Error during planning: Step cannot be zero SELECT * FROM range(-6, 6, 0) -query error DataFusion error: Error during planning: Start is bigger than end, but increment is positive: Cannot generate infinite series +query I SELECT * FROM range(6, -6, 1) +---- statement error DataFusion error: Error during planning: range function requires 1 to 3 arguments @@ -378,11 +384,13 @@ SELECT * FROM range(TIMESTAMP '2023-01-03T00:00:00', TIMESTAMP '2023-01-01T00:00 2023-01-03T00:00:00 2023-01-02T00:00:00 -query error DataFusion error: Error during planning: Start is bigger than end, but increment is positive: Cannot generate infinite series +query P SELECT * FROM range(TIMESTAMP '2023-01-03T00:00:00', TIMESTAMP '2023-01-01T00:00:00', INTERVAL '1' DAY) +---- -query error DataFusion error: Error during planning: Start is smaller than end, but increment is negative: Cannot generate infinite series +query P SELECT * FROM range(TIMESTAMP '2023-01-01T00:00:00', TIMESTAMP '2023-01-02T00:00:00', INTERVAL '-1' DAY) +---- query error DataFusion error: Error during planning: range function with timestamps requires exactly 3 arguments SELECT * FROM range(TIMESTAMP '2023-01-03T00:00:00', TIMESTAMP '2023-01-01T00:00:00') @@ -489,11 +497,13 @@ query P SELECT * FROM range(DATE '1992-09-01', DATE '1992-10-01', NULL::INTERVAL) ---- -query error DataFusion error: Error during planning: Start is bigger than end, but increment is positive: Cannot generate infinite series +query P SELECT * FROM range(DATE '2023-01-03', DATE '2023-01-01', INTERVAL '1' DAY) +---- -query error DataFusion error: Error during planning: Start is smaller than end, but increment is negative: Cannot generate infinite series +query P SELECT * FROM range(DATE '2023-01-01', DATE '2023-01-02', INTERVAL '-1' DAY) +---- query error DataFusion error: Error during planning: range function with dates requires exactly 3 arguments SELECT * FROM range(DATE '2023-01-01', DATE '2023-01-03') diff 
--git a/datafusion/sqllogictest/test_files/topk.slt b/datafusion/sqllogictest/test_files/topk.slt index aba468d21fd0..8a1fef072229 100644 --- a/datafusion/sqllogictest/test_files/topk.slt +++ b/datafusion/sqllogictest/test_files/topk.slt @@ -383,7 +383,7 @@ physical_plan 03)----ProjectionExec: expr=[__common_expr_1@0 as number_plus, number@1 as number, __common_expr_1@0 as other_number_plus, age@2 as age] 04)------ProjectionExec: expr=[CAST(number@0 AS Int64) + 1 as __common_expr_1, number@0 as number, age@1 as age] 05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1, maintains_sort_order=true -06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, age], output_ordering=[number@0 DESC], file_type=parquet +06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, age], output_ordering=[number@0 DESC], file_type=parquet, predicate=DynamicFilter [ empty ] # Cleanup statement ok diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q16.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q16.slt.part index 0ee60a1e8afb..b01110b567ca 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q16.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q16.slt.part @@ -71,17 +71,18 @@ physical_plan 04)------AggregateExec: mode=FinalPartitioned, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size], aggr=[count(alias1)] 05)--------RepartitionExec: partitioning=Hash([p_brand@0, p_type@1, p_size@2], 4), input_partitions=4 06)----------AggregateExec: mode=Partial, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size], aggr=[count(alias1)] -07)------------AggregateExec: mode=SinglePartitioned, gby=[p_brand@1 as p_brand, p_type@2 as p_type, p_size@3 as p_size, ps_suppkey@0 as alias1], aggr=[] 
-08)--------------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(ps_suppkey@0, s_suppkey@0)] -09)----------------RepartitionExec: partitioning=Hash([ps_suppkey@0], 4), input_partitions=4 -10)------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_partkey@0, p_partkey@0)], projection=[ps_suppkey@1, p_brand@3, p_type@4, p_size@5] -11)--------------------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), input_partitions=4 -12)----------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_partkey, ps_suppkey], file_type=csv, has_header=false -13)--------------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 -14)----------------------FilterExec: p_brand@1 != Brand#45 AND p_type@2 NOT LIKE MEDIUM POLISHED% AND p_size@3 IN (SET) ([49, 14, 23, 45, 19, 3, 36, 9]) -15)------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -16)--------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_type, p_size], file_type=csv, has_header=false -17)----------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=4 -18)------------------FilterExec: s_comment@1 LIKE %Customer%Complaints%, projection=[s_suppkey@0] -19)--------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -20)----------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_comment], 
file_type=csv, has_header=false +07)------------AggregateExec: mode=FinalPartitioned, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size, alias1@3 as alias1], aggr=[] +08)--------------RepartitionExec: partitioning=Hash([p_brand@0, p_type@1, p_size@2, alias1@3], 4), input_partitions=4 +09)----------------AggregateExec: mode=Partial, gby=[p_brand@1 as p_brand, p_type@2 as p_type, p_size@3 as p_size, ps_suppkey@0 as alias1], aggr=[] +10)------------------HashJoinExec: mode=CollectLeft, join_type=LeftAnti, on=[(ps_suppkey@0, s_suppkey@0)] +11)--------------------CoalescePartitionsExec +12)----------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_partkey@0, p_partkey@0)], projection=[ps_suppkey@1, p_brand@3, p_type@4, p_size@5] +13)------------------------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), input_partitions=4 +14)--------------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_partkey, ps_suppkey], file_type=csv, has_header=false +15)------------------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 +16)--------------------------FilterExec: p_brand@1 != Brand#45 AND p_type@2 NOT LIKE MEDIUM POLISHED% AND p_size@3 IN (SET) ([49, 14, 23, 45, 19, 3, 36, 9]) +17)----------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +18)------------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_type, p_size], file_type=csv, has_header=false 
+19)--------------------FilterExec: s_comment@1 LIKE %Customer%Complaints%, projection=[s_suppkey@0] +20)----------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +21)------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_comment], file_type=csv, has_header=false diff --git a/datafusion/sqllogictest/test_files/truncate.slt b/datafusion/sqllogictest/test_files/truncate.slt new file mode 100644 index 000000000000..ad3ccbb1a7cf --- /dev/null +++ b/datafusion/sqllogictest/test_files/truncate.slt @@ -0,0 +1,85 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +########## +## Truncate Tests +########## + +statement ok +create table t1(a int, b varchar, c double, d int); + +statement ok +insert into t1 values (1, 'abc', 3.14, 4), (2, 'def', 2.71, 5); + +# Truncate all rows from table +query TT +explain truncate table t1; +---- +logical_plan +01)Dml: op=[Truncate] table=[t1] +02)--EmptyRelation: rows=0 +physical_plan_error +01)TRUNCATE operation on table 't1' +02)caused by +03)This feature is not implemented: TRUNCATE not supported for Base table + +# Test TRUNCATE with fully qualified table name +statement ok +create schema test_schema; + +statement ok +create table test_schema.t5(a int); + +query TT +explain truncate table test_schema.t5; +---- +logical_plan +01)Dml: op=[Truncate] table=[test_schema.t5] +02)--EmptyRelation: rows=0 +physical_plan_error +01)TRUNCATE operation on table 'test_schema.t5' +02)caused by +03)This feature is not implemented: TRUNCATE not supported for Base table + +# Test TRUNCATE with CASCADE option +statement error TRUNCATE with CASCADE/RESTRICT is not supported +TRUNCATE TABLE t1 CASCADE; + +# Test TRUNCATE with multiple tables +statement error TRUNCATE with multiple tables is not supported +TRUNCATE TABLE t1, t2; + +statement error TRUNCATE with PARTITION is not supported +TRUNCATE TABLE t1 PARTITION (p1); + +statement error TRUNCATE with ONLY is not supported +TRUNCATE ONLY t1; + +statement error TRUNCATE with RESTART/CONTINUE IDENTITY is not supported +TRUNCATE TABLE t1 RESTART IDENTITY; + +# Test TRUNCATE without TABLE keyword +query TT +explain truncate t1; +---- +logical_plan +01)Dml: op=[Truncate] table=[t1] +02)--EmptyRelation: rows=0 +physical_plan_error +01)TRUNCATE operation on table 't1' +02)caused by +03)This feature is not implemented: TRUNCATE not supported for Base table diff --git a/datafusion/sqllogictest/test_files/type_coercion.slt b/datafusion/sqllogictest/test_files/type_coercion.slt index e3baa8fedcf6..7039e66b38b1 100644 --- 
a/datafusion/sqllogictest/test_files/type_coercion.slt +++ b/datafusion/sqllogictest/test_files/type_coercion.slt @@ -254,3 +254,51 @@ DROP TABLE orders; ######################################## ## Test type coercion with UNIONs end ## ######################################## + +# https://github.com/apache/datafusion/issues/15661 +# LIKE is a string pattern matching operator and is not supported for nested types. + +statement ok +CREATE TABLE t0(v0 BIGINT, v1 STRING, v2 BOOLEAN); + +statement ok +INSERT INTO t0(v0, v2) VALUES (123, true); + +query error There isn't a common type to coerce .* in .* expression +SELECT true FROM t0 WHERE ((REGEXP_MATCH(t0.v1, t0.v1)) NOT LIKE (REGEXP_MATCH(t0.v1, t0.v1, 'jH'))); + +query error There isn't a common type to coerce .* in .* expression +SELECT true FROM t0 WHERE (REGEXP_MATCH(t0.v1, t0.v1)) NOT LIKE []; + +query error There isn't a common type to coerce .* in .* expression +SELECT true FROM t0 WHERE (REGEXP_MATCH(t0.v1, t0.v1)) LIKE []; + +query error There isn't a common type to coerce .* in .* expression +SELECT true FROM t0 WHERE (REGEXP_MATCH(t0.v1, t0.v1)) ILIKE []; + +query error There isn't a common type to coerce .* in .* expression +SELECT true FROM t0 WHERE (REGEXP_MATCH(t0.v1, t0.v1)) NOT ILIKE []; + +statement ok +DROP TABLE t0; + +############################################################# +## Test validation for functions with empty argument lists ## +############################################################# + +# https://github.com/apache/datafusion/issues/20201 + +query error does not support zero arguments +SELECT * FROM (SELECT 1) WHERE (STARTS_WITH() IS NULL); + +query error does not support zero arguments +SELECT * FROM (SELECT 1) WHERE (STARTS_WITH() IS NOT NULL); + +query error does not support zero arguments +SELECT * FROM (SELECT 'a') WHERE (STARTS_WITH() SIMILAR TO 'abc%'); + +query error does not support zero arguments +SELECT * FROM (SELECT 1) WHERE CAST(STARTS_WITH() AS STRING) = 'x'; + 
+query error does not support zero arguments +SELECT * FROM (SELECT 1) WHERE TRY_CAST(STARTS_WITH() AS INT) = 1; \ No newline at end of file diff --git a/datafusion/sqllogictest/test_files/union.slt b/datafusion/sqllogictest/test_files/union.slt index b79b6d2fe5e9..d858d0ae3ea4 100644 --- a/datafusion/sqllogictest/test_files/union.slt +++ b/datafusion/sqllogictest/test_files/union.slt @@ -494,22 +494,25 @@ physical_plan 01)CoalescePartitionsExec: fetch=3 02)--UnionExec 03)----ProjectionExec: expr=[count(Int64(1))@0 as cnt] -04)------AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] -05)--------CoalescePartitionsExec -06)----------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] -07)------------ProjectionExec: expr=[] -08)--------------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[] -09)----------------RepartitionExec: partitioning=Hash([c1@0], 4), input_partitions=4 -10)------------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[] -11)--------------------FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434, projection=[c1@0] -12)----------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -13)------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c13], file_type=csv, has_header=true -14)----ProjectionExec: expr=[1 as cnt] -15)------PlaceholderRowExec -16)----ProjectionExec: expr=[lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as cnt] -17)------BoundedWindowAggExec: wdw=[lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Field { "lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING": nullable Int64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING], mode=[Sorted] -18)--------ProjectionExec: expr=[1 as c1] -19)----------PlaceholderRowExec +04)------GlobalLimitExec: skip=0, fetch=3 +05)--------AggregateExec: 
mode=Final, gby=[], aggr=[count(Int64(1))] +06)----------CoalescePartitionsExec +07)------------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] +08)--------------ProjectionExec: expr=[] +09)----------------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[] +10)------------------RepartitionExec: partitioning=Hash([c1@0], 4), input_partitions=4 +11)--------------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[] +12)----------------------FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434, projection=[c1@0] +13)------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +14)--------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c13], file_type=csv, has_header=true +15)----ProjectionExec: expr=[1 as cnt] +16)------GlobalLimitExec: skip=0, fetch=3 +17)--------PlaceholderRowExec +18)----ProjectionExec: expr=[lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as cnt] +19)------GlobalLimitExec: skip=0, fetch=3 +20)--------BoundedWindowAggExec: wdw=[lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Field { "lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING": nullable Int64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING], mode=[Sorted] +21)----------ProjectionExec: expr=[1 as c1] +22)------------PlaceholderRowExec ######## diff --git a/datafusion/sqllogictest/test_files/unnest.slt b/datafusion/sqllogictest/test_files/unnest.slt index 352056adbf81..73aeb6c99d0d 100644 --- a/datafusion/sqllogictest/test_files/unnest.slt +++ b/datafusion/sqllogictest/test_files/unnest.slt @@ -58,6 +58,20 @@ select unnest(struct(1,2,3)); ---- 1 2 3 +## Basic unnest expression in select struct with alias (alias is ignored for struct unnest) +query III +select unnest(struct(1,2,3)) as ignored_alias; +---- +1 2 3 + +## Verify schema output for 
struct unnest with alias (alias is ignored) +query TTT +describe select unnest(struct(1,2,3)) as ignored_alias; +---- +__unnest_placeholder(struct(Int64(1),Int64(2),Int64(3))).c0 Int64 YES +__unnest_placeholder(struct(Int64(1),Int64(2),Int64(3))).c1 Int64 YES +__unnest_placeholder(struct(Int64(1),Int64(2),Int64(3))).c2 Int64 YES + ## Basic unnest list expression in from clause query I select * from unnest([1,2,3]); @@ -652,15 +666,15 @@ explain select unnest(unnest(unnest(column3)['c1'])), column3 from recursive_unn logical_plan 01)Projection: __unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1],depth=2) AS UNNEST(UNNEST(UNNEST(recursive_unnest_table.column3)[c1])), recursive_unnest_table.column3 02)--Unnest: lists[__unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1])|depth=2] structs[] -03)----Projection: get_field(__unnest_placeholder(recursive_unnest_table.column3,depth=1) AS UNNEST(recursive_unnest_table.column3), Utf8("c1")) AS __unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1]), recursive_unnest_table.column3 +03)----Projection: get_field(__unnest_placeholder(recursive_unnest_table.column3,depth=1), Utf8("c1")) AS __unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1]), recursive_unnest_table.column3 04)------Unnest: lists[__unnest_placeholder(recursive_unnest_table.column3)|depth=1] structs[] 05)--------Projection: recursive_unnest_table.column3 AS __unnest_placeholder(recursive_unnest_table.column3), recursive_unnest_table.column3 06)----------TableScan: recursive_unnest_table projection=[column3] physical_plan 01)ProjectionExec: expr=[__unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1],depth=2)@0 as UNNEST(UNNEST(UNNEST(recursive_unnest_table.column3)[c1])), column3@1 as column3] 02)--UnnestExec -03)----ProjectionExec: expr=[get_field(__unnest_placeholder(recursive_unnest_table.column3,depth=1)@0, c1) as __unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1]), column3@1 as column3] 
-04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +04)------ProjectionExec: expr=[get_field(__unnest_placeholder(recursive_unnest_table.column3,depth=1)@0, c1) as __unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1]), column3@1 as column3] 05)--------UnnestExec 06)----------ProjectionExec: expr=[column3@0 as __unnest_placeholder(recursive_unnest_table.column3), column3@0 as column3] 07)------------DataSourceExec: partitions=1, partition_sizes=[1] @@ -798,9 +812,21 @@ NULL 1 query error DataFusion error: Error during planning: Column in SELECT must be in GROUP BY or an aggregate function: While expanding wildcard, column "nested_unnest_table\.column1" must appear in the GROUP BY clause or must be part of an aggregate function, currently only "UNNEST\(nested_unnest_table\.column1\)\[c0\]" appears in the SELECT clause satisfies this requirement select unnest(column1) c1 from nested_unnest_table group by c1.c0; -# TODO: this query should work. see issue: https://github.com/apache/datafusion/issues/12794 -query error DataFusion error: Internal error: Assertion failed: struct_allowed: unnest on struct can only be applied at the root level of select expression +## Unnest struct with alias - alias is ignored (same as DuckDB behavior) +## See: https://github.com/apache/datafusion/issues/12794 +query TT? select unnest(column1) c1 from nested_unnest_table +---- +a b {c0: c} +d e {c0: f} + +## Verify schema output for struct unnest with alias (alias is ignored) +query TTT +describe select unnest(column1) c1 from nested_unnest_table; +---- +__unnest_placeholder(nested_unnest_table.column1).c0 Utf8 YES +__unnest_placeholder(nested_unnest_table.column1).c1 Utf8 YES +__unnest_placeholder(nested_unnest_table.column1).c2 Struct("c0": Utf8) YES query II??I?? 
select unnest(column5), * from unnest_table; diff --git a/datafusion/sqllogictest/test_files/update.slt b/datafusion/sqllogictest/test_files/update.slt index a652ae7633e4..1cd2b626e3b8 100644 --- a/datafusion/sqllogictest/test_files/update.slt +++ b/datafusion/sqllogictest/test_files/update.slt @@ -67,39 +67,48 @@ logical_plan physical_plan_error This feature is not implemented: Physical plan does not support logical expression ScalarSubquery() # set from other table -query TT +# UPDATE ... FROM is currently unsupported +# TODO fix https://github.com/apache/datafusion/issues/19950 +query error DataFusion error: This feature is not implemented: UPDATE ... FROM is not supported explain update t1 set b = t2.b, c = t2.a, d = 1 from t2 where t1.a = t2.a and t1.b > 'foo' and t2.c > 1.0; ----- -logical_plan -01)Dml: op=[Update] table=[t1] -02)--Projection: t1.a AS a, t2.b AS b, CAST(t2.a AS Float64) AS c, CAST(Int64(1) AS Int32) AS d -03)----Filter: t1.a = t2.a AND t1.b > CAST(Utf8("foo") AS Utf8View) AND t2.c > Float64(1) -04)------Cross Join: -05)--------TableScan: t1 -06)--------TableScan: t2 -physical_plan -01)CooperativeExec -02)--DmlResultExec: rows_affected=0 +# test update from other table with actual data statement ok -create table t3(a int, b varchar, c double, d int); +insert into t1 values (1, 'zoo', 2.0, 10), (2, 'qux', 3.0, 20), (3, 'bar', 4.0, 30); + +statement ok +insert into t2 values (1, 'updated_b', 5.0, 40), (2, 'updated_b2', 2.5, 50), (4, 'updated_b3', 1.5, 60); + +# UPDATE ... FROM is currently unsupported - qualifier stripping breaks source column references +# causing assignments like 'b = t2.b' to resolve to target table's 'b' instead of source table's 'b' +# TODO fix https://github.com/apache/datafusion/issues/19950 +statement error DataFusion error: This feature is not implemented: UPDATE ... 
FROM is not supported +update t1 set b = t2.b, c = t2.a, d = 1 from t2 where t1.a = t2.a and t1.b > 'foo' and t2.c > 1.0; # set from multiple tables, DataFusion only supports from one table -query error DataFusion error: Error during planning: Multiple tables in UPDATE SET FROM not yet supported +statement error DataFusion error: This feature is not implemented: Multiple tables in UPDATE SET FROM not yet supported explain update t1 set b = t2.b, c = t3.a, d = 1 from t2, t3 where t1.a = t2.a and t1.a = t3.a; # test table alias -query TT +# UPDATE ... FROM is currently unsupported +# TODO fix https://github.com/apache/datafusion/issues/19950 +statement error DataFusion error: This feature is not implemented: UPDATE ... FROM is not supported explain update t1 as T set b = t2.b, c = t.a, d = 1 from t2 where t.a = t2.a and t.b > 'foo' and t2.c > 1.0; ----- -logical_plan -01)Dml: op=[Update] table=[t1] -02)--Projection: t.a AS a, t2.b AS b, CAST(t.a AS Float64) AS c, CAST(Int64(1) AS Int32) AS d -03)----Filter: t.a = t2.a AND t.b > CAST(Utf8("foo") AS Utf8View) AND t2.c > Float64(1) -04)------Cross Join: -05)--------SubqueryAlias: t -06)----------TableScan: t1 -07)--------TableScan: t2 -physical_plan -01)CooperativeExec -02)--DmlResultExec: rows_affected=0 + +# test update with table alias with actual data +statement ok +delete from t1; + +statement ok +delete from t2; + +statement ok +insert into t1 values (1, 'zebra', 1.5, 5), (2, 'wolf', 2.0, 10), (3, 'apple', 3.5, 15); + +statement ok +insert into t2 values (1, 'new_val', 2.0, 100), (2, 'new_val2', 1.5, 200); + +# UPDATE ... FROM is currently unsupported +# TODO fix https://github.com/apache/datafusion/issues/19950 +statement error DataFusion error: This feature is not implemented: UPDATE ... 
FROM is not supported +update t1 as T set b = t2.b, c = t.a, d = 1 from t2 where t.a = t2.a and t.b > 'foo' and t2.c > 1.0; diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 8ac8724683a8..d444283aa3c3 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -3198,16 +3198,17 @@ EXPLAIN SELECT * FROM (SELECT *, ROW_NUMBER() OVER(ORDER BY a ASC) as rn1 ---- logical_plan 01)Sort: rn1 ASC NULLS LAST -02)--Sort: rn1 ASC NULLS LAST, fetch=5 -03)----Projection: annotated_data_infinite2.a0, annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.c, annotated_data_infinite2.d, row_number() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rn1 -04)------Filter: row_number() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW < UInt64(50) +02)--Filter: rn1 < UInt64(50) +03)----Sort: rn1 ASC NULLS LAST, fetch=5 +04)------Projection: annotated_data_infinite2.a0, annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.c, annotated_data_infinite2.d, row_number() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rn1 05)--------WindowAggr: windowExpr=[[row_number() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] 06)----------TableScan: annotated_data_infinite2 projection=[a0, a, b, c, d] physical_plan -01)ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, row_number() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as rn1] -02)--FilterExec: row_number() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 < 50, fetch=5 -03)----BoundedWindowAggExec: wdw=[row_number() 
ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { "row_number() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW": UInt64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] -04)------StreamingTableExec: partition_sizes=1, projection=[a0, a, b, c, d], infinite_source=true, output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST] +01)FilterExec: rn1@5 < 50 +02)--ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, row_number() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as rn1] +03)----GlobalLimitExec: skip=0, fetch=5 +04)------BoundedWindowAggExec: wdw=[row_number() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { "row_number() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW": UInt64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] +05)--------StreamingTableExec: partition_sizes=1, projection=[a0, a, b, c, d], infinite_source=true, output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST] # Top level sort is pushed down through BoundedWindowAggExec as its SUM result does already satisfy the required # global order. The existing sort is for the second-term lexicographical ordering requirement, which is being @@ -4387,9 +4388,9 @@ LIMIT 5; ---- 78 50 63 38 -3 53 +NULL 19 24 31 -14 94 +24 56 # result should be same with above, when LAG/LEAD algorithm work with pruned data. # decreasing batch size, causes data to be produced in smaller chunks at the source. 
@@ -4406,9 +4407,9 @@ LIMIT 5; ---- 78 50 63 38 -3 53 +NULL 19 24 31 -14 94 +24 56 statement ok set datafusion.execution.batch_size = 100; @@ -6081,3 +6082,49 @@ WHERE acctbal > ( ); ---- 1 + +# Regression test for https://github.com/apache/datafusion/issues/20194 +# Window function with CASE WHEN in ORDER BY combined with NVL filter +# should not trigger SanityCheckPlan error from equivalence normalization +# replacing literals in sort expressions with complex filter expressions. +statement ok +CREATE TABLE issue_20194_t1 ( + value_1_1 decimal(25) NULL, + value_1_2 int NULL, + value_1_3 bigint NULL +); + +statement ok +CREATE TABLE issue_20194_t2 ( + value_2_1 bigint NULL, + value_2_2 varchar(140) NULL, + value_2_3 varchar(140) NULL +); + +statement ok +INSERT INTO issue_20194_t1 (value_1_1, value_1_2, value_1_3) VALUES (6774502793, 10040029, 1120); + +statement ok +INSERT INTO issue_20194_t2 (value_2_1, value_2_2, value_2_3) VALUES (1120, '0', '0'); + +query RII +SELECT + t1.value_1_1, t1.value_1_2, + ROW_NUMBER() OVER ( + PARTITION BY t1.value_1_1, t1.value_1_2 + ORDER BY + CASE WHEN t2.value_2_2 = '0' THEN 1 ELSE 0 END ASC, + CASE WHEN t2.value_2_3 = '0' THEN 1 ELSE 0 END ASC + ) AS ord +FROM issue_20194_t1 t1 +INNER JOIN issue_20194_t2 t2 + ON t1.value_1_3 = t2.value_2_1 + AND nvl(t2.value_2_3, '0') = '0'; +---- +6774502793 10040029 1 + +statement ok +DROP TABLE issue_20194_t1; + +statement ok +DROP TABLE issue_20194_t2; diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index 8bfec86497ef..85479c344860 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -41,12 +41,12 @@ datafusion = { workspace = true, features = ["sql"] } half = { workspace = true } itertools = { workspace = true } object_store = { workspace = true } -pbjson-types = { workspace = true } +# We need to match the version in substrait, so we don't use the workspace version here +pbjson-types = { version = "0.8.0" } prost = { workspace 
= true } -substrait = { version = "0.62", features = ["serde"] } +substrait = { version = "=0.62.2", features = ["serde"] } url = { workspace = true } tokio = { workspace = true, features = ["fs"] } -uuid = { version = "1.19.0", features = ["v4"] } [dev-dependencies] datafusion = { workspace = true, features = ["nested_expressions", "unicode_expressions"] } diff --git a/datafusion/substrait/src/lib.rs b/datafusion/substrait/src/lib.rs index 407408aaa71b..0819fd3a592f 100644 --- a/datafusion/substrait/src/lib.rs +++ b/datafusion/substrait/src/lib.rs @@ -23,7 +23,6 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] -#![deny(clippy::allow_attributes)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] //! Serialize / Deserialize DataFusion Plans to [Substrait.io] diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/field_reference.rs b/datafusion/substrait/src/logical_plan/consumer/expr/field_reference.rs index c17bf9c92edc..50d93a4600a0 100644 --- a/datafusion/substrait/src/logical_plan/consumer/expr/field_reference.rs +++ b/datafusion/substrait/src/logical_plan/consumer/expr/field_reference.rs @@ -16,34 +16,48 @@ // under the License. 
use crate::logical_plan::consumer::SubstraitConsumer; -use datafusion::common::{Column, DFSchema, not_impl_err}; +use datafusion::common::{Column, DFSchema, not_impl_err, substrait_err}; use datafusion::logical_expr::Expr; +use std::sync::Arc; use substrait::proto::expression::FieldReference; use substrait::proto::expression::field_reference::ReferenceType::DirectReference; +use substrait::proto::expression::field_reference::RootType; use substrait::proto::expression::reference_segment::ReferenceType::StructField; pub async fn from_field_reference( - _consumer: &impl SubstraitConsumer, + consumer: &impl SubstraitConsumer, field_ref: &FieldReference, input_schema: &DFSchema, ) -> datafusion::common::Result { - from_substrait_field_reference(field_ref, input_schema) + from_substrait_field_reference(consumer, field_ref, input_schema) } pub(crate) fn from_substrait_field_reference( + consumer: &impl SubstraitConsumer, field_ref: &FieldReference, input_schema: &DFSchema, ) -> datafusion::common::Result { match &field_ref.reference_type { Some(DirectReference(direct)) => match &direct.reference_type.as_ref() { - Some(StructField(x)) => match &x.child.as_ref() { - Some(_) => not_impl_err!( - "Direct reference StructField with child is not supported" - ), - None => Ok(Expr::Column(Column::from( - input_schema.qualified_field(x.field as usize), - ))), - }, + Some(StructField(struct_field)) => { + if struct_field.child.is_some() { + return not_impl_err!( + "Direct reference StructField with child is not supported" + ); + } + let field_idx = struct_field.field as usize; + match &field_ref.root_type { + Some(RootType::RootReference(_)) | None => Ok(Expr::Column( + Column::from(input_schema.qualified_field(field_idx)), + )), + Some(RootType::OuterReference(outer_ref)) => { + resolve_outer_reference(consumer, outer_ref, field_idx) + } + Some(RootType::Expression(_)) => not_impl_err!( + "Expression root type in field reference is not supported" + ), + } + } _ => not_impl_err!( 
"Direct reference with types other than StructField is not supported" ), @@ -51,3 +65,20 @@ pub(crate) fn from_substrait_field_reference( _ => not_impl_err!("unsupported field ref type"), } } + +fn resolve_outer_reference( + consumer: &impl SubstraitConsumer, + outer_ref: &substrait::proto::expression::field_reference::OuterReference, + field_idx: usize, +) -> datafusion::common::Result { + let steps_out = outer_ref.steps_out as usize; + let Some(outer_schema) = consumer.get_outer_schema(steps_out) else { + return substrait_err!( + "OuterReference with steps_out={steps_out} \ + but no outer schema is available" + ); + }; + let (qualifier, field) = outer_schema.qualified_field(field_idx); + let col = Column::from((qualifier, field)); + Ok(Expr::OuterReferenceColumn(Arc::clone(field), col)) +} diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/literal.rs b/datafusion/substrait/src/logical_plan/consumer/expr/literal.rs index 112f1ea374b3..ad38b6addee0 100644 --- a/datafusion/substrait/src/logical_plan/consumer/expr/literal.rs +++ b/datafusion/substrait/src/logical_plan/consumer/expr/literal.rs @@ -102,6 +102,7 @@ pub(crate) fn from_substrait_literal( }, Some(LiteralType::Fp32(f)) => ScalarValue::Float32(Some(*f)), Some(LiteralType::Fp64(f)) => ScalarValue::Float64(Some(*f)), + #[expect(deprecated)] Some(LiteralType::Timestamp(t)) => { // Kept for backwards compatibility, new plans should use PrecisionTimestamp(Tz) instead #[expect(deprecated)] @@ -385,6 +386,7 @@ pub(crate) fn from_substrait_literal( use interval_day_to_second::PrecisionMode; // DF only supports millisecond precision, so for any more granular type we lose precision let milliseconds = match precision_mode { + #[expect(deprecated)] Some(PrecisionMode::Microseconds(ms)) => ms / 1000, None => { if *subseconds != 0 { diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/mod.rs b/datafusion/substrait/src/logical_plan/consumer/expr/mod.rs index 6c2bc652bb19..5d98850c72cc 100644 --- 
a/datafusion/substrait/src/logical_plan/consumer/expr/mod.rs +++ b/datafusion/substrait/src/logical_plan/consumer/expr/mod.rs @@ -88,6 +88,7 @@ pub async fn from_substrait_rex( consumer.consume_subquery(expr.as_ref(), input_schema).await } RexType::Nested(expr) => consumer.consume_nested(expr, input_schema).await, + #[expect(deprecated)] RexType::Enum(expr) => consumer.consume_enum(expr, input_schema).await, RexType::DynamicParameter(expr) => { consumer.consume_dynamic_parameter(expr, input_schema).await @@ -116,10 +117,7 @@ pub async fn from_substrait_extended_expr( return not_impl_err!("Type variation extensions are not supported"); } - let consumer = DefaultSubstraitConsumer { - extensions: &extensions, - state, - }; + let consumer = DefaultSubstraitConsumer::new(&extensions, state); let input_schema = DFSchemaRef::new(match &extended_expr.base_schema { Some(base_schema) => from_substrait_named_struct(&consumer, base_schema), diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/subquery.rs b/datafusion/substrait/src/logical_plan/consumer/expr/subquery.rs index 15fe7947a1e1..83cf8400eebf 100644 --- a/datafusion/substrait/src/logical_plan/consumer/expr/subquery.rs +++ b/datafusion/substrait/src/logical_plan/consumer/expr/subquery.rs @@ -16,14 +16,33 @@ // under the License. 
use crate::logical_plan::consumer::SubstraitConsumer; -use datafusion::common::{DFSchema, Spans, substrait_err}; -use datafusion::logical_expr::expr::{Exists, InSubquery}; -use datafusion::logical_expr::{Expr, Subquery}; +use datafusion::common::{DFSchema, Spans, substrait_datafusion_err, substrait_err}; +use datafusion::logical_expr::expr::{Exists, InSubquery, SetComparison, SetQuantifier}; +use datafusion::logical_expr::{Expr, LogicalPlan, Operator, Subquery}; use std::sync::Arc; +use substrait::proto::Rel; use substrait::proto::expression as substrait_expression; use substrait::proto::expression::subquery::SubqueryType; +use substrait::proto::expression::subquery::set_comparison::{ComparisonOp, ReductionOp}; use substrait::proto::expression::subquery::set_predicate::PredicateOp; +/// Consume a subquery relation, making the enclosing query's schema +/// available for resolving correlated column references. +/// +/// Substrait represents correlated references using `OuterReference` +/// field references with a `steps_out` depth. To resolve these, +/// the consumer maintains a stack of outer schemas. 
+async fn consume_subquery_rel( + consumer: &impl SubstraitConsumer, + rel: &Rel, + outer_schema: &DFSchema, +) -> datafusion::common::Result { + consumer.push_outer_schema(Arc::new(outer_schema.clone())); + let result = consumer.consume_rel(rel).await; + consumer.pop_outer_schema(); + result +} + pub async fn from_subquery( consumer: &impl SubstraitConsumer, subquery: &substrait_expression::Subquery, @@ -40,7 +59,9 @@ pub async fn from_subquery( let needle_expr = &in_predicate.needles[0]; let haystack_expr = &in_predicate.haystack; if let Some(haystack_expr) = haystack_expr { - let haystack_expr = consumer.consume_rel(haystack_expr).await?; + let haystack_expr = + consume_subquery_rel(consumer, haystack_expr, input_schema) + .await?; let outer_refs = haystack_expr.all_out_ref_exprs(); Ok(Expr::InSubquery(InSubquery { expr: Box::new( @@ -63,9 +84,12 @@ pub async fn from_subquery( } } SubqueryType::Scalar(query) => { - let plan = consumer - .consume_rel(&(query.input.clone()).unwrap_or_default()) - .await?; + let plan = consume_subquery_rel( + consumer, + &(query.input.clone()).unwrap_or_default(), + input_schema, + ) + .await?; let outer_ref_columns = plan.all_out_ref_exprs(); Ok(Expr::ScalarSubquery(Subquery { subquery: Arc::new(plan), @@ -78,9 +102,12 @@ pub async fn from_subquery( // exist PredicateOp::Exists => { let relation = &predicate.tuples; - let plan = consumer - .consume_rel(&relation.clone().unwrap_or_default()) - .await?; + let plan = consume_subquery_rel( + consumer, + &relation.clone().unwrap_or_default(), + input_schema, + ) + .await?; let outer_ref_columns = plan.all_out_ref_exprs(); Ok(Expr::Exists(Exists::new( Subquery { @@ -96,8 +123,53 @@ pub async fn from_subquery( ), } } - other_type => { - substrait_err!("Subquery type {other_type:?} not implemented") + SubqueryType::SetComparison(comparison) => { + let left = comparison.left.as_ref().ok_or_else(|| { + substrait_datafusion_err!("SetComparison requires a left expression") + })?; + let right 
= comparison.right.as_ref().ok_or_else(|| { + substrait_datafusion_err!("SetComparison requires a right relation") + })?; + let reduction_op = match ReductionOp::try_from(comparison.reduction_op) { + Ok(ReductionOp::Any) => SetQuantifier::Any, + Ok(ReductionOp::All) => SetQuantifier::All, + _ => { + return substrait_err!( + "Unsupported reduction op for SetComparison: {}", + comparison.reduction_op + ); + } + }; + let comparison_op = match ComparisonOp::try_from(comparison.comparison_op) + { + Ok(ComparisonOp::Eq) => Operator::Eq, + Ok(ComparisonOp::Ne) => Operator::NotEq, + Ok(ComparisonOp::Lt) => Operator::Lt, + Ok(ComparisonOp::Gt) => Operator::Gt, + Ok(ComparisonOp::Le) => Operator::LtEq, + Ok(ComparisonOp::Ge) => Operator::GtEq, + _ => { + return substrait_err!( + "Unsupported comparison op for SetComparison: {}", + comparison.comparison_op + ); + } + }; + + let left_expr = consumer.consume_expression(left, input_schema).await?; + let plan = consume_subquery_rel(consumer, right, input_schema).await?; + let outer_ref_columns = plan.all_out_ref_exprs(); + + Ok(Expr::SetComparison(SetComparison::new( + Box::new(left_expr), + Subquery { + subquery: Arc::new(plan), + outer_ref_columns, + spans: Spans::new(), + }, + comparison_op, + reduction_op, + ))) } }, None => { diff --git a/datafusion/substrait/src/logical_plan/consumer/plan.rs b/datafusion/substrait/src/logical_plan/consumer/plan.rs index d5e10fb60401..407980c4a7f4 100644 --- a/datafusion/substrait/src/logical_plan/consumer/plan.rs +++ b/datafusion/substrait/src/logical_plan/consumer/plan.rs @@ -35,10 +35,7 @@ pub async fn from_substrait_plan( return not_impl_err!("Type variation extensions are not supported"); } - let consumer = DefaultSubstraitConsumer { - extensions: &extensions, - state, - }; + let consumer = DefaultSubstraitConsumer::new(&extensions, state); from_substrait_plan_with_consumer(&consumer, plan).await } diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/exchange_rel.rs 
b/datafusion/substrait/src/logical_plan/consumer/rel/exchange_rel.rs index a6132e047f7d..b275e523f586 100644 --- a/datafusion/substrait/src/logical_plan/consumer/rel/exchange_rel.rs +++ b/datafusion/substrait/src/logical_plan/consumer/rel/exchange_rel.rs @@ -42,7 +42,8 @@ pub async fn from_exchange_rel( let mut partition_columns = vec![]; let input_schema = input.schema(); for field_ref in &scatter_fields.fields { - let column = from_substrait_field_reference(field_ref, input_schema)?; + let column = + from_substrait_field_reference(consumer, field_ref, input_schema)?; partition_columns.push(column); } Partitioning::Hash(partition_columns, exchange.partition_count as usize) diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/fetch_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/fetch_rel.rs index bd6d94736e26..12a8a77199b1 100644 --- a/datafusion/substrait/src/logical_plan/consumer/rel/fetch_rel.rs +++ b/datafusion/substrait/src/logical_plan/consumer/rel/fetch_rel.rs @@ -30,6 +30,7 @@ pub async fn from_fetch_rel( let input = LogicalPlanBuilder::from(consumer.consume_rel(input).await?); let empty_schema = DFSchemaRef::new(DFSchema::empty()); let offset = match &fetch.offset_mode { + #[expect(deprecated)] Some(fetch_rel::OffsetMode::Offset(offset)) => Some(lit(*offset)), Some(fetch_rel::OffsetMode::OffsetExpr(expr)) => { Some(consumer.consume_expression(expr, &empty_schema).await?) 
@@ -37,6 +38,7 @@ pub async fn from_fetch_rel( None => None, }; let count = match &fetch.count_mode { + #[expect(deprecated)] Some(fetch_rel::CountMode::Count(count)) => { // -1 means that ALL records should be returned, equivalent to None (*count != -1).then(|| lit(*count)) diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/join_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/join_rel.rs index 3604630d6f0b..7850dbea797f 100644 --- a/datafusion/substrait/src/logical_plan/consumer/rel/join_rel.rs +++ b/datafusion/substrait/src/logical_plan/consumer/rel/join_rel.rs @@ -18,7 +18,7 @@ use crate::logical_plan::consumer::SubstraitConsumer; use datafusion::common::{Column, JoinType, NullEquality, not_impl_err, plan_err}; use datafusion::logical_expr::requalify_sides_if_needed; -use datafusion::logical_expr::utils::split_conjunction; +use datafusion::logical_expr::utils::split_conjunction_owned; use datafusion::logical_expr::{ BinaryExpr, Expr, LogicalPlan, LogicalPlanBuilder, Operator, }; @@ -56,15 +56,10 @@ pub async fn from_join_rel( // So we extract each part as follows: // - If an Eq or IsNotDistinctFrom op is encountered, add the left column, right column and is_null_equal_nulls to `join_ons` vector // - Otherwise we add the expression to join_filter (use conjunction if filter already exists) - let (join_ons, nulls_equal_nulls, join_filter) = - split_eq_and_noneq_join_predicate_with_nulls_equality(&on); + let (join_ons, null_equality, join_filter) = + split_eq_and_noneq_join_predicate_with_nulls_equality(on); let (left_cols, right_cols): (Vec<_>, Vec<_>) = itertools::multiunzip(join_ons); - let null_equality = if nulls_equal_nulls { - NullEquality::NullEqualsNull - } else { - NullEquality::NullEqualsNothing - }; left.join_detailed( right.build()?, join_type, @@ -89,49 +84,61 @@ pub async fn from_join_rel( } fn split_eq_and_noneq_join_predicate_with_nulls_equality( - filter: &Expr, -) -> (Vec<(Column, Column)>, bool, Option) { - let exprs = 
split_conjunction(filter); + filter: Expr, +) -> (Vec<(Column, Column)>, NullEquality, Option) { + let exprs = split_conjunction_owned(filter); - let mut accum_join_keys: Vec<(Column, Column)> = vec![]; + let mut eq_keys: Vec<(Column, Column)> = vec![]; + let mut indistinct_keys: Vec<(Column, Column)> = vec![]; let mut accum_filters: Vec = vec![]; - let mut nulls_equal_nulls = false; for expr in exprs { - #[expect(clippy::collapsible_match)] match expr { - Expr::BinaryExpr(binary_expr) => match binary_expr { - x @ (BinaryExpr { - left, - op: Operator::Eq, - right, + Expr::BinaryExpr(BinaryExpr { + left, + op: op @ (Operator::Eq | Operator::IsNotDistinctFrom), + right, + }) => match (*left, *right) { + (Expr::Column(l), Expr::Column(r)) => match op { + Operator::Eq => eq_keys.push((l, r)), + Operator::IsNotDistinctFrom => indistinct_keys.push((l, r)), + _ => unreachable!(), + }, + (left, right) => { + accum_filters.push(Expr::BinaryExpr(BinaryExpr { + left: Box::new(left), + op, + right: Box::new(right), + })); } - | BinaryExpr { - left, - op: Operator::IsNotDistinctFrom, - right, - }) => { - nulls_equal_nulls = match x.op { - Operator::Eq => false, - Operator::IsNotDistinctFrom => true, - _ => unreachable!(), - }; - - match (left.as_ref(), right.as_ref()) { - (Expr::Column(l), Expr::Column(r)) => { - accum_join_keys.push((l.clone(), r.clone())); - } - _ => accum_filters.push(expr.clone()), - } - } - _ => accum_filters.push(expr.clone()), }, - _ => accum_filters.push(expr.clone()), + _ => accum_filters.push(expr), } } + let (join_keys, null_equality) = + match (eq_keys.is_empty(), indistinct_keys.is_empty()) { + // Mixed: use eq_keys as equijoin keys, demote indistinct keys to filter + (false, false) => { + for (l, r) in indistinct_keys { + accum_filters.push(Expr::BinaryExpr(BinaryExpr { + left: Box::new(Expr::Column(l)), + op: Operator::IsNotDistinctFrom, + right: Box::new(Expr::Column(r)), + })); + } + (eq_keys, NullEquality::NullEqualsNothing) + } + // Only eq 
keys + (false, true) => (eq_keys, NullEquality::NullEqualsNothing), + // Only indistinct keys + (true, false) => (indistinct_keys, NullEquality::NullEqualsNull), + // No keys at all + (true, true) => (vec![], NullEquality::NullEqualsNothing), + }; + let join_filter = accum_filters.into_iter().reduce(Expr::and); - (accum_join_keys, nulls_equal_nulls, join_filter) + (join_keys, null_equality, join_filter) } fn from_substrait_jointype(join_type: i32) -> datafusion::common::Result { @@ -153,3 +160,102 @@ fn from_substrait_jointype(join_type: i32) -> datafusion::common::Result Expr { + Expr::Column(Column::from_name(name)) + } + + fn indistinct(left: Expr, right: Expr) -> Expr { + Expr::BinaryExpr(BinaryExpr { + left: Box::new(left), + op: Operator::IsNotDistinctFrom, + right: Box::new(right), + }) + } + + fn fmt_keys(keys: &[(Column, Column)]) -> String { + keys.iter() + .map(|(l, r)| format!("{l} = {r}")) + .collect::>() + .join(", ") + } + + #[test] + fn split_only_eq_keys() { + let expr = col("a").eq(col("b")); + let (keys, null_eq, filter) = + split_eq_and_noneq_join_predicate_with_nulls_equality(expr); + + assert_eq!(fmt_keys(&keys), "a = b"); + assert_eq!(null_eq, NullEquality::NullEqualsNothing); + assert!(filter.is_none()); + } + + #[test] + fn split_only_indistinct_keys() { + let expr = indistinct(col("a"), col("b")); + let (keys, null_eq, filter) = + split_eq_and_noneq_join_predicate_with_nulls_equality(expr); + + assert_eq!(fmt_keys(&keys), "a = b"); + assert_eq!(null_eq, NullEquality::NullEqualsNull); + assert!(filter.is_none()); + } + + /// Regression: mixed `equal` + `is_not_distinct_from` must demote + /// the indistinct key to the join filter so the single NullEquality + /// flag stays consistent (NullEqualsNothing for the eq keys). 
+ #[test] + fn split_mixed_eq_and_indistinct_demotes_indistinct_to_filter() { + let expr = + indistinct(col("val_l"), col("val_r")).and(col("id_l").eq(col("id_r"))); + + let (keys, null_eq, filter) = + split_eq_and_noneq_join_predicate_with_nulls_equality(expr); + + assert_eq!(fmt_keys(&keys), "id_l = id_r"); + assert_eq!(null_eq, NullEquality::NullEqualsNothing); + assert_eq!( + filter.unwrap().to_string(), + "val_l IS NOT DISTINCT FROM val_r" + ); + } + + /// Multiple IS NOT DISTINCT FROM keys with a single Eq key should demote + /// all indistinct keys to the filter. + #[test] + fn split_mixed_multiple_indistinct_demoted() { + let expr = indistinct(col("a_l"), col("a_r")) + .and(indistinct(col("b_l"), col("b_r"))) + .and(col("id_l").eq(col("id_r"))); + + let (keys, null_eq, filter) = + split_eq_and_noneq_join_predicate_with_nulls_equality(expr); + + assert_eq!(fmt_keys(&keys), "id_l = id_r"); + assert_eq!(null_eq, NullEquality::NullEqualsNothing); + assert_eq!( + filter.unwrap().to_string(), + "a_l IS NOT DISTINCT FROM a_r AND b_l IS NOT DISTINCT FROM b_r" + ); + } + + #[test] + fn split_non_column_eq_goes_to_filter() { + let expr = Expr::Literal( + datafusion::common::ScalarValue::Utf8(Some("x".into())), + None, + ) + .eq(col("b")); + + let (keys, _, filter) = + split_eq_and_noneq_join_predicate_with_nulls_equality(expr); + + assert!(keys.is_empty()); + assert_eq!(filter.unwrap().to_string(), "Utf8(\"x\") = b"); + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/project_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/project_rel.rs index 07f9a34888fc..d216d4ecf318 100644 --- a/datafusion/substrait/src/logical_plan/consumer/rel/project_rel.rs +++ b/datafusion/substrait/src/logical_plan/consumer/rel/project_rel.rs @@ -62,20 +62,7 @@ pub async fn from_project_rel( // to transform it into a column reference window_exprs.insert(e.clone()); } - // Substrait plans are ordinal based, so they do not provide names for columns. 
- // Names for columns are generated by Datafusion during conversion, and for literals - // Datafusion produces names based on the literal value. It is possible to construct - // valid Substrait plans that result in duplicated names if the same literal value is - // used in multiple relations. To avoid this issue, we alias literals with unique names. - // The name tracker will ensure that two literals in the same project would have - // unique names but, it does not ensure that if a literal column exists in a previous - // project say before a join that it is deduplicated with respect to those columns. - // See: https://github.com/apache/datafusion/pull/17299 - let maybe_apply_alias = match e { - lit @ Expr::Literal(_, _) => lit.alias(uuid::Uuid::new_v4().to_string()), - _ => e, - }; - explicit_exprs.push(name_tracker.get_uniquely_named_expr(maybe_apply_alias)?); + explicit_exprs.push(name_tracker.get_uniquely_named_expr(e)?); } let input = if !window_exprs.is_empty() { diff --git a/datafusion/substrait/src/logical_plan/consumer/substrait_consumer.rs b/datafusion/substrait/src/logical_plan/consumer/substrait_consumer.rs index 4c19227a30c7..a23f1faed1eb 100644 --- a/datafusion/substrait/src/logical_plan/consumer/substrait_consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer/substrait_consumer.rs @@ -31,7 +31,7 @@ use datafusion::common::{ }; use datafusion::execution::{FunctionRegistry, SessionState}; use datafusion::logical_expr::{Expr, Extension, LogicalPlan}; -use std::sync::Arc; +use std::sync::{Arc, RwLock}; use substrait::proto; use substrait::proto::expression as substrait_expression; use substrait::proto::expression::{ @@ -364,6 +364,26 @@ pub trait SubstraitConsumer: Send + Sync + Sized { not_impl_err!("Dynamic Parameter expression not supported") } + // Outer Schema Stack + // These methods manage a stack of outer schemas for correlated subquery support. + // When entering a subquery, the enclosing query's schema is pushed onto the stack. 
+ // Field references with OuterReference root_type use these to resolve columns. + + /// Push an outer schema onto the stack when entering a subquery. + fn push_outer_schema(&self, _schema: Arc) {} + + /// Pop an outer schema from the stack when leaving a subquery. + fn pop_outer_schema(&self) {} + + /// Get the outer schema at the given nesting depth. + /// `steps_out = 1` is the immediately enclosing query, `steps_out = 2` + /// is two levels out, etc. Returns `None` if `steps_out` is 0 or + /// exceeds the current nesting depth (the caller should treat this as + /// an error in the Substrait plan). + fn get_outer_schema(&self, _steps_out: usize) -> Option> { + None + } + // User-Defined Functionality // The details of extension relations, and how to handle them, are fully up to users to specify. @@ -437,11 +457,16 @@ pub trait SubstraitConsumer: Send + Sync + Sized { pub struct DefaultSubstraitConsumer<'a> { pub(super) extensions: &'a Extensions, pub(super) state: &'a SessionState, + outer_schemas: RwLock>>, } impl<'a> DefaultSubstraitConsumer<'a> { pub fn new(extensions: &'a Extensions, state: &'a SessionState) -> Self { - DefaultSubstraitConsumer { extensions, state } + DefaultSubstraitConsumer { + extensions, + state, + outer_schemas: RwLock::new(Vec::new()), + } } } @@ -465,6 +490,24 @@ impl SubstraitConsumer for DefaultSubstraitConsumer<'_> { self.state } + fn push_outer_schema(&self, schema: Arc) { + self.outer_schemas.write().unwrap().push(schema); + } + + fn pop_outer_schema(&self) { + self.outer_schemas.write().unwrap().pop(); + } + + fn get_outer_schema(&self, steps_out: usize) -> Option> { + let schemas = self.outer_schemas.read().unwrap(); + // steps_out=1 → last element, steps_out=2 → second-to-last, etc. + // Returns None for steps_out=0 or steps_out > stack depth. 
+ schemas + .len() + .checked_sub(steps_out) + .and_then(|idx| schemas.get(idx).cloned()) + } + async fn consume_extension_leaf( &self, rel: &ExtensionLeafRel, @@ -520,3 +563,79 @@ impl SubstraitConsumer for DefaultSubstraitConsumer<'_> { Ok(LogicalPlan::Extension(Extension { node: plan })) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::logical_plan::consumer::utils::tests::test_consumer; + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + + fn make_schema(fields: &[(&str, DataType)]) -> Arc { + let arrow_fields: Vec = fields + .iter() + .map(|(name, dt)| Field::new(*name, dt.clone(), true)) + .collect(); + Arc::new( + DFSchema::try_from(Schema::new(arrow_fields)) + .expect("failed to create schema"), + ) + } + + #[test] + fn test_get_outer_schema_empty_stack() { + let consumer = test_consumer(); + + // No schemas pushed — any steps_out should return None + assert!(consumer.get_outer_schema(0).is_none()); + assert!(consumer.get_outer_schema(1).is_none()); + assert!(consumer.get_outer_schema(2).is_none()); + } + + #[test] + fn test_get_outer_schema_single_level() { + let consumer = test_consumer(); + + let schema_a = make_schema(&[("a", DataType::Int64)]); + consumer.push_outer_schema(Arc::clone(&schema_a)); + + // steps_out=1 returns the one pushed schema + let result = consumer.get_outer_schema(1).unwrap(); + assert_eq!(result.fields().len(), 1); + assert_eq!(result.fields()[0].name(), "a"); + + // steps_out=0 and steps_out=2 are out of range + assert!(consumer.get_outer_schema(0).is_none()); + assert!(consumer.get_outer_schema(2).is_none()); + + consumer.pop_outer_schema(); + assert!(consumer.get_outer_schema(1).is_none()); + } + + #[test] + fn test_get_outer_schema_nested() { + let consumer = test_consumer(); + + let schema_a = make_schema(&[("a", DataType::Int64)]); + let schema_b = make_schema(&[("b", DataType::Utf8)]); + + consumer.push_outer_schema(Arc::clone(&schema_a)); + consumer.push_outer_schema(Arc::clone(&schema_b)); + + // 
steps_out=1 returns the most recent (schema_b) + let result = consumer.get_outer_schema(1).unwrap(); + assert_eq!(result.fields()[0].name(), "b"); + + // steps_out=2 returns the grandparent (schema_a) + let result = consumer.get_outer_schema(2).unwrap(); + assert_eq!(result.fields()[0].name(), "a"); + + // steps_out=3 exceeds depth + assert!(consumer.get_outer_schema(3).is_none()); + + // Pop one level — now steps_out=1 returns schema_a + consumer.pop_outer_schema(); + let result = consumer.get_outer_schema(1).unwrap(); + assert_eq!(result.fields()[0].name(), "a"); + assert!(consumer.get_outer_schema(2).is_none()); + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/types.rs b/datafusion/substrait/src/logical_plan/consumer/types.rs index eb2cc967ca23..9ef7a0dd46b8 100644 --- a/datafusion/substrait/src/logical_plan/consumer/types.rs +++ b/datafusion/substrait/src/logical_plan/consumer/types.rs @@ -88,6 +88,7 @@ pub fn from_substrait_type( }, r#type::Kind::Fp32(_) => Ok(DataType::Float32), r#type::Kind::Fp64(_) => Ok(DataType::Float64), + #[expect(deprecated)] r#type::Kind::Timestamp(ts) => { // Kept for backwards compatibility, new plans should use PrecisionTimestamp(Tz) instead #[expect(deprecated)] diff --git a/datafusion/substrait/src/logical_plan/consumer/utils.rs b/datafusion/substrait/src/logical_plan/consumer/utils.rs index 9325926c278a..59cdf4a8fc93 100644 --- a/datafusion/substrait/src/logical_plan/consumer/utils.rs +++ b/datafusion/substrait/src/logical_plan/consumer/utils.rs @@ -23,6 +23,7 @@ use datafusion::common::{ }; use datafusion::logical_expr::expr::Sort; use datafusion::logical_expr::{Cast, Expr, ExprSchemable}; +use datafusion::sql::TableReference; use std::collections::HashSet; use std::sync::Arc; use substrait::proto::SortField; @@ -359,35 +360,71 @@ fn compatible_nullabilities( } pub(super) struct NameTracker { - seen_names: HashSet, -} - -pub(super) enum NameTrackerStatus { - NeverSeen, - SeenBefore, + /// Tracks seen schema 
names (from expr.schema_name()). + /// Used to detect duplicates that would fail validate_unique_names. + seen_schema_names: HashSet, + /// Tracks column names that have been seen with a qualifier. + /// Used to detect ambiguous references (qualified + unqualified with same name). + qualified_names: HashSet, + /// Tracks column names that have been seen without a qualifier. + /// Used to detect ambiguous references. + unqualified_names: HashSet, } impl NameTracker { pub(super) fn new() -> Self { NameTracker { - seen_names: HashSet::default(), + seen_schema_names: HashSet::default(), + qualified_names: HashSet::default(), + unqualified_names: HashSet::default(), } } - pub(super) fn get_unique_name( - &mut self, - name: String, - ) -> (String, NameTrackerStatus) { - match self.seen_names.insert(name.clone()) { - true => (name, NameTrackerStatus::NeverSeen), - false => { - let mut counter = 0; - loop { - let candidate_name = format!("{name}__temp__{counter}"); - if self.seen_names.insert(candidate_name.clone()) { - return (candidate_name, NameTrackerStatus::SeenBefore); - } - counter += 1; - } + + /// Check if the expression would cause a conflict either in: + /// 1. validate_unique_names (duplicate schema_name) + /// 2. 
DFSchema::check_names (ambiguous reference) + fn would_conflict(&self, expr: &Expr) -> bool { + let (qualifier, name) = expr.qualified_name(); + let schema_name = expr.schema_name().to_string(); + self.would_conflict_inner((qualifier, &name), &schema_name) + } + + fn would_conflict_inner( + &self, + qualified_name: (Option, &str), + schema_name: &str, + ) -> bool { + // Check for duplicate schema_name (would fail validate_unique_names) + if self.seen_schema_names.contains(schema_name) { + return true; + } + + // Check for ambiguous reference (would fail DFSchema::check_names) + // This happens when a qualified field and unqualified field have the same name + let (qualifier, name) = qualified_name; + match qualifier { + Some(_) => { + // Adding a qualified name - conflicts if unqualified version exists + self.unqualified_names.contains(name) + } + None => { + // Adding an unqualified name - conflicts if qualified version exists + self.qualified_names.contains(name) + } + } + } + + fn insert(&mut self, expr: &Expr) { + let schema_name = expr.schema_name().to_string(); + self.seen_schema_names.insert(schema_name); + + let (qualifier, name) = expr.qualified_name(); + match qualifier { + Some(_) => { + self.qualified_names.insert(name); + } + None => { + self.unqualified_names.insert(name); } } } @@ -396,10 +433,25 @@ impl NameTracker { &mut self, expr: Expr, ) -> datafusion::common::Result { - match self.get_unique_name(expr.name_for_alias()?) { - (_, NameTrackerStatus::NeverSeen) => Ok(expr), - (name, NameTrackerStatus::SeenBefore) => Ok(expr.alias(name)), + if !self.would_conflict(&expr) { + self.insert(&expr); + return Ok(expr); } + + // Name collision - need to generate a unique alias + let schema_name = expr.schema_name().to_string(); + let mut counter = 0; + let candidate_name = loop { + let candidate_name = format!("{schema_name}__temp__{counter}"); + // .alias always produces an unqualified name so check for conflicts accordingly. 
+ if !self.would_conflict_inner((None, &candidate_name), &candidate_name) { + break candidate_name; + } + counter += 1; + }; + let candidate_expr = expr.alias(&candidate_name); + self.insert(&candidate_expr); + Ok(candidate_expr) } } @@ -469,13 +521,14 @@ pub(crate) fn from_substrait_precision( #[cfg(test)] pub(crate) mod tests { - use super::make_renamed_schema; + use super::{NameTracker, make_renamed_schema}; use crate::extensions::Extensions; use crate::logical_plan::consumer::DefaultSubstraitConsumer; use datafusion::arrow::datatypes::{DataType, Field}; use datafusion::common::DFSchema; use datafusion::error::Result; use datafusion::execution::SessionState; + use datafusion::logical_expr::{Expr, col}; use datafusion::prelude::SessionContext; use datafusion::sql::TableReference; use std::collections::HashMap; @@ -641,4 +694,123 @@ pub(crate) mod tests { ); Ok(()) } + + #[test] + fn name_tracker_unique_names_pass_through() -> Result<()> { + let mut tracker = NameTracker::new(); + + // First expression should pass through unchanged + let expr1 = col("a"); + let result1 = tracker.get_uniquely_named_expr(expr1.clone())?; + assert_eq!(result1, col("a")); + + // Different name should also pass through unchanged + let expr2 = col("b"); + let result2 = tracker.get_uniquely_named_expr(expr2)?; + assert_eq!(result2, col("b")); + + Ok(()) + } + + #[test] + fn name_tracker_duplicate_schema_name_gets_alias() -> Result<()> { + let mut tracker = NameTracker::new(); + + // First expression with name "a" + let expr1 = col("a"); + let result1 = tracker.get_uniquely_named_expr(expr1)?; + assert_eq!(result1, col("a")); + + // Second expression with same name "a" should get aliased + let expr2 = col("a"); + let result2 = tracker.get_uniquely_named_expr(expr2)?; + assert_eq!(result2, col("a").alias("a__temp__0")); + + // Third expression with same name "a" should get a different alias + let expr3 = col("a"); + let result3 = tracker.get_uniquely_named_expr(expr3)?; + 
assert_eq!(result3, col("a").alias("a__temp__1")); + + Ok(()) + } + + #[test] + fn name_tracker_qualified_then_unqualified_conflicts() -> Result<()> { + let mut tracker = NameTracker::new(); + + // First: qualified column "table.a" + let qualified_col = + Expr::Column(datafusion::common::Column::new(Some("table"), "a")); + let result1 = tracker.get_uniquely_named_expr(qualified_col)?; + assert_eq!( + result1, + Expr::Column(datafusion::common::Column::new(Some("table"), "a")) + ); + + // Second: unqualified column "a" - should conflict (ambiguous reference) + let unqualified_col = col("a"); + let result2 = tracker.get_uniquely_named_expr(unqualified_col)?; + // Should be aliased to avoid ambiguous reference + assert_eq!(result2, col("a").alias("a__temp__0")); + + Ok(()) + } + + #[test] + fn name_tracker_unqualified_then_qualified_conflicts() -> Result<()> { + let mut tracker = NameTracker::new(); + + // First: unqualified column "a" + let unqualified_col = col("a"); + let result1 = tracker.get_uniquely_named_expr(unqualified_col)?; + assert_eq!(result1, col("a")); + + // Second: qualified column "table.a" - should conflict (ambiguous reference) + let qualified_col = + Expr::Column(datafusion::common::Column::new(Some("table"), "a")); + let result2 = tracker.get_uniquely_named_expr(qualified_col)?; + // Should be aliased to avoid ambiguous reference + assert_eq!( + result2, + Expr::Column(datafusion::common::Column::new(Some("table"), "a")) + .alias("table.a__temp__0") + ); + + Ok(()) + } + + #[test] + fn name_tracker_different_qualifiers_no_conflict() -> Result<()> { + let mut tracker = NameTracker::new(); + + // First: qualified column "table1.a" + let col1 = Expr::Column(datafusion::common::Column::new(Some("table1"), "a")); + let result1 = tracker.get_uniquely_named_expr(col1.clone())?; + assert_eq!(result1, col1); + + // Second: qualified column "table2.a" - different qualifier, different schema_name + // so should NOT conflict + let col2 = 
Expr::Column(datafusion::common::Column::new(Some("table2"), "a")); + let result2 = tracker.get_uniquely_named_expr(col2.clone())?; + assert_eq!(result2, col2); + + Ok(()) + } + + #[test] + fn name_tracker_aliased_expressions() -> Result<()> { + let mut tracker = NameTracker::new(); + + // First: col("x").alias("result") + let expr1 = col("x").alias("result"); + let result1 = tracker.get_uniquely_named_expr(expr1.clone())?; + assert_eq!(result1, col("x").alias("result")); + + // Second: col("y").alias("result") - same alias name, should conflict + let expr2 = col("y").alias("result"); + let result2 = tracker.get_uniquely_named_expr(expr2)?; + assert_eq!(result2, col("y").alias("result").alias("result__temp__0")); + + Ok(()) + } } diff --git a/datafusion/substrait/src/logical_plan/producer/expr/cast.rs b/datafusion/substrait/src/logical_plan/producer/expr/cast.rs index 53d3d3e12c4b..6eb27fc39df6 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/cast.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/cast.rs @@ -35,7 +35,7 @@ pub fn from_cast( // only the untyped(a null scalar value) null literal need this special handling // since all other kind of nulls are already typed and can be handled by substrait // e.g. null:: or null:: - if matches!(lit, ScalarValue::Null) { + if *lit == ScalarValue::Null { let lit = Literal { nullable: true, type_variation_reference: DEFAULT_TYPE_VARIATION_REF, diff --git a/datafusion/substrait/src/logical_plan/producer/expr/field_reference.rs b/datafusion/substrait/src/logical_plan/producer/expr/field_reference.rs index b6af7d3bbc8e..aa34317a6e29 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/field_reference.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/field_reference.rs @@ -76,6 +76,22 @@ pub(crate) fn try_to_substrait_field_reference( } } +/// Convert an outer reference column to a Substrait field reference. 
+/// Outer reference columns reference columns from an outer query scope in correlated subqueries. +/// We convert them the same way as regular columns since the subquery plan will be +/// reconstructed with the proper schema context during consumption. +pub fn from_outer_reference_column( + col: &Column, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + // OuterReferenceColumn is converted similarly to a regular column reference. + // The schema provided should be the schema context in which the outer reference + // column appears. During Substrait round-trip, the consumer will reconstruct + // the outer reference based on the subquery context. + let index = schema.index_of_column(col)?; + substrait_field_ref(index) +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs index 5057564d370c..3aa8aa2b68bc 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs @@ -139,16 +139,17 @@ pub fn to_substrait_rex( } Expr::WindowFunction(expr) => producer.handle_window_function(expr, schema), Expr::InList(expr) => producer.handle_in_list(expr, schema), - Expr::Exists(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::Exists(expr) => producer.handle_exists(expr, schema), Expr::InSubquery(expr) => producer.handle_in_subquery(expr, schema), - Expr::ScalarSubquery(expr) => { - not_impl_err!("Cannot convert {expr:?} to Substrait") - } + Expr::SetComparison(expr) => producer.handle_set_comparison(expr, schema), + Expr::ScalarSubquery(expr) => producer.handle_scalar_subquery(expr, schema), #[expect(deprecated)] Expr::Wildcard { .. 
} => not_impl_err!("Cannot convert {expr:?} to Substrait"), Expr::GroupingSet(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), Expr::Placeholder(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), Expr::OuterReferenceColumn(_, _) => { + // OuterReferenceColumn requires tracking outer query schema context for correlated + // subqueries. This is a complex feature that is not yet implemented. not_impl_err!("Cannot convert {expr:?} to Substrait") } Expr::Unnest(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), diff --git a/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs b/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs index bd8a9d9a99b5..9f70e903a0bd 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs @@ -344,5 +344,6 @@ pub fn operator_to_name(op: Operator) -> &'static str { Operator::BitwiseXor => "bitwise_xor", Operator::BitwiseShiftRight => "bitwise_shift_right", Operator::BitwiseShiftLeft => "bitwise_shift_left", + Operator::Colon => "colon", } } diff --git a/datafusion/substrait/src/logical_plan/producer/expr/singular_or_list.rs b/datafusion/substrait/src/logical_plan/producer/expr/singular_or_list.rs index 2d53db6501a5..fd09a60d5ead 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/singular_or_list.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/singular_or_list.rs @@ -15,12 +15,11 @@ // specific language governing permissions and limitations // under the License. 
-use crate::logical_plan::producer::SubstraitProducer; +use crate::logical_plan::producer::{SubstraitProducer, negate}; use datafusion::common::DFSchemaRef; use datafusion::logical_expr::expr::InList; -use substrait::proto::expression::{RexType, ScalarFunction, SingularOrList}; -use substrait::proto::function_argument::ArgType; -use substrait::proto::{Expression, FunctionArgument}; +use substrait::proto::Expression; +use substrait::proto::expression::{RexType, SingularOrList}; pub fn from_in_list( producer: &mut impl SubstraitProducer, @@ -46,20 +45,7 @@ pub fn from_in_list( }; if *negated { - let function_anchor = producer.register_function("not".to_string()); - - #[expect(deprecated)] - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments: vec![FunctionArgument { - arg_type: Some(ArgType::Value(substrait_or_list)), - }], - output_type: None, - args: vec![], - options: vec![], - })), - }) + Ok(negate(producer, substrait_or_list)) } else { Ok(substrait_or_list) } diff --git a/datafusion/substrait/src/logical_plan/producer/expr/subquery.rs b/datafusion/substrait/src/logical_plan/producer/expr/subquery.rs index f2e6ff551223..97699c213278 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/subquery.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/subquery.rs @@ -15,13 +15,14 @@ // specific language governing permissions and limitations // under the License. 
-use crate::logical_plan::producer::SubstraitProducer; -use datafusion::common::DFSchemaRef; -use datafusion::logical_expr::expr::InSubquery; -use substrait::proto::expression::subquery::InPredicate; -use substrait::proto::expression::{RexType, ScalarFunction}; -use substrait::proto::function_argument::ArgType; -use substrait::proto::{Expression, FunctionArgument}; +use crate::logical_plan::producer::{SubstraitProducer, negate}; +use datafusion::common::{DFSchemaRef, substrait_err}; +use datafusion::logical_expr::expr::{Exists, InSubquery, SetComparison, SetQuantifier}; +use datafusion::logical_expr::{Operator, Subquery}; +use substrait::proto::Expression; +use substrait::proto::expression::RexType; +use substrait::proto::expression::subquery::set_comparison::{ComparisonOp, ReductionOp}; +use substrait::proto::expression::subquery::{InPredicate, Scalar, SetPredicate}; pub fn from_in_subquery( producer: &mut impl SubstraitProducer, @@ -52,21 +53,111 @@ pub fn from_in_subquery( ))), }; if *negated { - let function_anchor = producer.register_function("not".to_string()); - - #[expect(deprecated)] - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments: vec![FunctionArgument { - arg_type: Some(ArgType::Value(substrait_subquery)), - }], - output_type: None, - args: vec![], - options: vec![], - })), - }) + Ok(negate(producer, substrait_subquery)) } else { Ok(substrait_subquery) } } + +fn comparison_op_to_proto(op: &Operator) -> datafusion::common::Result { + match op { + Operator::Eq => Ok(ComparisonOp::Eq), + Operator::NotEq => Ok(ComparisonOp::Ne), + Operator::Lt => Ok(ComparisonOp::Lt), + Operator::Gt => Ok(ComparisonOp::Gt), + Operator::LtEq => Ok(ComparisonOp::Le), + Operator::GtEq => Ok(ComparisonOp::Ge), + _ => substrait_err!("Unsupported operator {op:?} for SetComparison subquery"), + } +} + +fn reduction_op_to_proto( + quantifier: &SetQuantifier, +) -> datafusion::common::Result { + match 
quantifier { + SetQuantifier::Any => Ok(ReductionOp::Any), + SetQuantifier::All => Ok(ReductionOp::All), + } +} + +pub fn from_set_comparison( + producer: &mut impl SubstraitProducer, + set_comparison: &SetComparison, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let comparison_op = comparison_op_to_proto(&set_comparison.op)? as i32; + let reduction_op = reduction_op_to_proto(&set_comparison.quantifier)? as i32; + let left = producer.handle_expr(set_comparison.expr.as_ref(), schema)?; + let subquery_plan = + producer.handle_plan(set_comparison.subquery.subquery.as_ref())?; + + Ok(Expression { + rex_type: Some(RexType::Subquery(Box::new( + substrait::proto::expression::Subquery { + subquery_type: Some( + substrait::proto::expression::subquery::SubqueryType::SetComparison( + Box::new(substrait::proto::expression::subquery::SetComparison { + reduction_op, + comparison_op, + left: Some(Box::new(left)), + right: Some(subquery_plan), + }), + ), + ), + }, + ))), + }) +} + +/// Convert DataFusion ScalarSubquery to Substrait Scalar subquery type +pub fn from_scalar_subquery( + producer: &mut impl SubstraitProducer, + subquery: &Subquery, + _schema: &DFSchemaRef, +) -> datafusion::common::Result { + let subquery_plan = producer.handle_plan(subquery.subquery.as_ref())?; + + Ok(Expression { + rex_type: Some(RexType::Subquery(Box::new( + substrait::proto::expression::Subquery { + subquery_type: Some( + substrait::proto::expression::subquery::SubqueryType::Scalar( + Box::new(Scalar { + input: Some(subquery_plan), + }), + ), + ), + }, + ))), + }) +} + +/// Convert DataFusion Exists expression to Substrait SetPredicate subquery type +pub fn from_exists( + producer: &mut impl SubstraitProducer, + exists: &Exists, + _schema: &DFSchemaRef, +) -> datafusion::common::Result { + let subquery_plan = producer.handle_plan(exists.subquery.subquery.as_ref())?; + + let substrait_exists = Expression { + rex_type: Some(RexType::Subquery(Box::new( + 
substrait::proto::expression::Subquery { + subquery_type: Some( + substrait::proto::expression::subquery::SubqueryType::SetPredicate( + Box::new(SetPredicate { + predicate_op: substrait::proto::expression::subquery::set_predicate::PredicateOp::Exists as i32, + tuples: Some(subquery_plan), + }), + ), + ), + }, + ))), + }; + + if exists.negated { + Ok(negate(producer, substrait_exists)) + } else { + Ok(substrait_exists) + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs b/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs index ffc920ffe609..51d2c0ca8e78 100644 --- a/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs +++ b/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs @@ -18,16 +18,20 @@ use crate::extensions::Extensions; use crate::logical_plan::producer::{ from_aggregate, from_aggregate_function, from_alias, from_between, from_binary_expr, - from_case, from_cast, from_column, from_distinct, from_empty_relation, from_filter, - from_in_list, from_in_subquery, from_join, from_like, from_limit, from_literal, - from_projection, from_repartition, from_scalar_function, from_sort, - from_subquery_alias, from_table_scan, from_try_cast, from_unary_expr, from_union, - from_values, from_window, from_window_function, to_substrait_rel, to_substrait_rex, + from_case, from_cast, from_column, from_distinct, from_empty_relation, from_exists, + from_filter, from_in_list, from_in_subquery, from_join, from_like, from_limit, + from_literal, from_projection, from_repartition, from_scalar_function, + from_scalar_subquery, from_set_comparison, from_sort, from_subquery_alias, + from_table_scan, from_try_cast, from_unary_expr, from_union, from_values, + from_window, from_window_function, to_substrait_rel, to_substrait_rex, }; use datafusion::common::{Column, DFSchemaRef, ScalarValue, substrait_err}; use datafusion::execution::SessionState; use datafusion::execution::registry::SerializerRegistry; 
-use datafusion::logical_expr::expr::{Alias, InList, InSubquery, WindowFunction}; +use datafusion::logical_expr::Subquery; +use datafusion::logical_expr::expr::{ + Alias, Exists, InList, InSubquery, SetComparison, WindowFunction, +}; use datafusion::logical_expr::{ Aggregate, Between, BinaryExpr, Case, Cast, Distinct, EmptyRelation, Expr, Extension, Filter, Join, Like, Limit, LogicalPlan, Projection, Repartition, Sort, SubqueryAlias, @@ -361,6 +365,29 @@ pub trait SubstraitProducer: Send + Sync + Sized { ) -> datafusion::common::Result { from_in_subquery(self, in_subquery, schema) } + + fn handle_set_comparison( + &mut self, + set_comparison: &SetComparison, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_set_comparison(self, set_comparison, schema) + } + fn handle_scalar_subquery( + &mut self, + subquery: &Subquery, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_scalar_subquery(self, subquery, schema) + } + + fn handle_exists( + &mut self, + exists: &Exists, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_exists(self, exists, schema) + } } pub struct DefaultSubstraitProducer<'a> { diff --git a/datafusion/substrait/src/logical_plan/producer/utils.rs b/datafusion/substrait/src/logical_plan/producer/utils.rs index 820c14809dd7..e8310f4acd31 100644 --- a/datafusion/substrait/src/logical_plan/producer/utils.rs +++ b/datafusion/substrait/src/logical_plan/producer/utils.rs @@ -19,8 +19,8 @@ use crate::logical_plan::producer::SubstraitProducer; use datafusion::arrow::datatypes::{DataType, Field, TimeUnit}; use datafusion::common::{DFSchemaRef, plan_err}; use datafusion::logical_expr::SortExpr; -use substrait::proto::SortField; use substrait::proto::sort_field::{SortDirection, SortKind}; +use substrait::proto::{Expression, SortField}; // Substrait wants a list of all field names, including nested fields from structs, // also from within e.g. lists and maps. 
However, it does not want the list and map field names @@ -85,3 +85,28 @@ pub(crate) fn to_substrait_precision(time_unit: &TimeUnit) -> i32 { TimeUnit::Nanosecond => 9, } } + +/// Wraps an expression with a `not()` function. +pub(crate) fn negate( + producer: &mut impl SubstraitProducer, + expr: Expression, +) -> Expression { + let function_anchor = producer.register_function("not".to_string()); + + #[expect(deprecated)] + Expression { + rex_type: Some(substrait::proto::expression::RexType::ScalarFunction( + substrait::proto::expression::ScalarFunction { + function_reference: function_anchor, + arguments: vec![substrait::proto::FunctionArgument { + arg_type: Some(substrait::proto::function_argument::ArgType::Value( + expr, + )), + }], + output_type: None, + args: vec![], + options: vec![], + }, + )), + } +} diff --git a/datafusion/substrait/src/physical_plan/consumer.rs b/datafusion/substrait/src/physical_plan/consumer.rs index ac0f26722513..ccaf1abec424 100644 --- a/datafusion/substrait/src/physical_plan/consumer.rs +++ b/datafusion/substrait/src/physical_plan/consumer.rs @@ -119,20 +119,14 @@ pub async fn from_substrait_rel( .unwrap(); let size = 0; - let partitioned_file = PartitionedFile { - object_meta: ObjectMeta { + let partitioned_file = + PartitionedFile::new_from_meta(ObjectMeta { last_modified: last_modified.into(), location: path.into(), size, e_tag: None, version: None, - }, - partition_values: vec![], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - }; + }); let part_index = file.partition_index as usize; while part_index >= file_groups.len() { diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs index 194098cf060e..b5d9f36620c6 100644 --- a/datafusion/substrait/tests/cases/consumer_integration.rs +++ b/datafusion/substrait/tests/cases/consumer_integration.rs @@ -25,6 +25,8 @@ #[cfg(test)] mod tests { use 
crate::utils::test::add_plan_schemas_to_ctx; + use datafusion::arrow::record_batch::RecordBatch; + use datafusion::arrow::util::pretty::pretty_format_batches; use datafusion::common::Result; use datafusion::prelude::SessionContext; use datafusion_substrait::logical_plan::consumer::from_substrait_plan; @@ -33,6 +35,34 @@ mod tests { use std::io::BufReader; use substrait::proto::Plan; + async fn execute_plan(name: &str) -> Result> { + let path = format!("tests/testdata/test_plans/{name}"); + let proto = serde_json::from_reader::<_, Plan>(BufReader::new( + File::open(path).expect("file not found"), + )) + .expect("failed to parse json"); + let ctx = SessionContext::new(); + let plan = from_substrait_plan(&ctx.state(), &proto).await?; + ctx.execute_logical_plan(plan).await?.collect().await + } + + /// Pretty-print batches as a table with header on top and data rows sorted. + fn pretty_sorted(batches: &[RecordBatch]) -> String { + let pretty = pretty_format_batches(batches).unwrap().to_string(); + let all_lines: Vec<&str> = pretty.trim().lines().collect(); + let header = &all_lines[..3]; + let mut data: Vec<&str> = all_lines[3..all_lines.len() - 1].to_vec(); + data.sort(); + let footer = &all_lines[all_lines.len() - 1..]; + header + .iter() + .copied() + .chain(data) + .chain(footer.iter().copied()) + .collect::>() + .join("\n") + } + async fn tpch_plan_to_string(query_id: i32) -> Result { let path = format!("tests/testdata/tpch_substrait_plans/query_{query_id:02}_plan.json"); @@ -77,18 +107,18 @@ mod tests { Subquery: Aggregate: groupBy=[[]], aggr=[[min(PARTSUPP.PS_SUPPLYCOST)]] Projection: PARTSUPP.PS_SUPPLYCOST - Filter: PARTSUPP.PS_PARTKEY = PARTSUPP.PS_PARTKEY AND SUPPLIER.S_SUPPKEY = PARTSUPP.PS_SUPPKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_REGIONKEY = REGION.R_REGIONKEY AND REGION.R_NAME = Utf8("EUROPE") - Cross Join: - Cross Join: - Cross Join: + Filter: outer_ref(PART.P_PARTKEY) = PARTSUPP.PS_PARTKEY AND SUPPLIER.S_SUPPKEY = 
PARTSUPP.PS_SUPPKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_REGIONKEY = REGION.R_REGIONKEY AND REGION.R_NAME = Utf8("EUROPE") + Cross Join: + Cross Join: + Cross Join: TableScan: PARTSUPP TableScan: SUPPLIER TableScan: NATION TableScan: REGION - Cross Join: - Cross Join: - Cross Join: - Cross Join: + Cross Join: + Cross Join: + Cross Join: + Cross Join: TableScan: PART TableScan: SUPPLIER TableScan: PARTSUPP @@ -112,8 +142,8 @@ mod tests { Aggregate: groupBy=[[LINEITEM.L_ORDERKEY, ORDERS.O_ORDERDATE, ORDERS.O_SHIPPRIORITY]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT)]] Projection: LINEITEM.L_ORDERKEY, ORDERS.O_ORDERDATE, ORDERS.O_SHIPPRIORITY, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) Filter: CUSTOMER.C_MKTSEGMENT = Utf8("BUILDING") AND CUSTOMER.C_CUSTKEY = ORDERS.O_CUSTKEY AND LINEITEM.L_ORDERKEY = ORDERS.O_ORDERKEY AND ORDERS.O_ORDERDATE < CAST(Utf8("1995-03-15") AS Date32) AND LINEITEM.L_SHIPDATE > CAST(Utf8("1995-03-15") AS Date32) - Cross Join: - Cross Join: + Cross Join: + Cross Join: TableScan: LINEITEM TableScan: CUSTOMER TableScan: ORDERS @@ -134,7 +164,7 @@ mod tests { Projection: ORDERS.O_ORDERPRIORITY Filter: ORDERS.O_ORDERDATE >= CAST(Utf8("1993-07-01") AS Date32) AND ORDERS.O_ORDERDATE < CAST(Utf8("1993-10-01") AS Date32) AND EXISTS () Subquery: - Filter: LINEITEM.L_ORDERKEY = LINEITEM.L_ORDERKEY AND LINEITEM.L_COMMITDATE < LINEITEM.L_RECEIPTDATE + Filter: LINEITEM.L_ORDERKEY = outer_ref(ORDERS.O_ORDERKEY) AND LINEITEM.L_COMMITDATE < LINEITEM.L_RECEIPTDATE TableScan: LINEITEM TableScan: ORDERS "# @@ -153,11 +183,11 @@ mod tests { Aggregate: groupBy=[[NATION.N_NAME]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT)]] Projection: NATION.N_NAME, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) Filter: CUSTOMER.C_CUSTKEY = ORDERS.O_CUSTKEY AND LINEITEM.L_ORDERKEY = ORDERS.O_ORDERKEY AND LINEITEM.L_SUPPKEY = 
SUPPLIER.S_SUPPKEY AND CUSTOMER.C_NATIONKEY = SUPPLIER.S_NATIONKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_REGIONKEY = REGION.R_REGIONKEY AND REGION.R_NAME = Utf8("ASIA") AND ORDERS.O_ORDERDATE >= CAST(Utf8("1994-01-01") AS Date32) AND ORDERS.O_ORDERDATE < CAST(Utf8("1995-01-01") AS Date32) - Cross Join: - Cross Join: - Cross Join: - Cross Join: - Cross Join: + Cross Join: + Cross Join: + Cross Join: + Cross Join: + Cross Join: TableScan: CUSTOMER TableScan: ORDERS TableScan: LINEITEM @@ -221,9 +251,9 @@ mod tests { Aggregate: groupBy=[[CUSTOMER.C_CUSTKEY, CUSTOMER.C_NAME, CUSTOMER.C_ACCTBAL, CUSTOMER.C_PHONE, NATION.N_NAME, CUSTOMER.C_ADDRESS, CUSTOMER.C_COMMENT]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT)]] Projection: CUSTOMER.C_CUSTKEY, CUSTOMER.C_NAME, CUSTOMER.C_ACCTBAL, CUSTOMER.C_PHONE, NATION.N_NAME, CUSTOMER.C_ADDRESS, CUSTOMER.C_COMMENT, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) Filter: CUSTOMER.C_CUSTKEY = ORDERS.O_CUSTKEY AND LINEITEM.L_ORDERKEY = ORDERS.O_ORDERKEY AND ORDERS.O_ORDERDATE >= CAST(Utf8("1993-10-01") AS Date32) AND ORDERS.O_ORDERDATE < CAST(Utf8("1994-01-01") AS Date32) AND LINEITEM.L_RETURNFLAG = Utf8("R") AND CUSTOMER.C_NATIONKEY = NATION.N_NATIONKEY - Cross Join: - Cross Join: - Cross Join: + Cross Join: + Cross Join: + Cross Join: TableScan: CUSTOMER TableScan: ORDERS TableScan: LINEITEM @@ -247,16 +277,16 @@ mod tests { Aggregate: groupBy=[[]], aggr=[[sum(PARTSUPP.PS_SUPPLYCOST * PARTSUPP.PS_AVAILQTY)]] Projection: PARTSUPP.PS_SUPPLYCOST * CAST(PARTSUPP.PS_AVAILQTY AS Decimal128(19, 0)) Filter: PARTSUPP.PS_SUPPKEY = SUPPLIER.S_SUPPKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_NAME = Utf8("JAPAN") - Cross Join: - Cross Join: + Cross Join: + Cross Join: TableScan: PARTSUPP TableScan: SUPPLIER TableScan: NATION Aggregate: groupBy=[[PARTSUPP.PS_PARTKEY]], aggr=[[sum(PARTSUPP.PS_SUPPLYCOST * PARTSUPP.PS_AVAILQTY)]] Projection: 
PARTSUPP.PS_PARTKEY, PARTSUPP.PS_SUPPLYCOST * CAST(PARTSUPP.PS_AVAILQTY AS Decimal128(19, 0)) Filter: PARTSUPP.PS_SUPPKEY = SUPPLIER.S_SUPPKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_NAME = Utf8("JAPAN") - Cross Join: - Cross Join: + Cross Join: + Cross Join: TableScan: PARTSUPP TableScan: SUPPLIER TableScan: NATION @@ -276,7 +306,7 @@ mod tests { Aggregate: groupBy=[[LINEITEM.L_SHIPMODE]], aggr=[[sum(CASE WHEN ORDERS.O_ORDERPRIORITY = Utf8("1-URGENT") OR ORDERS.O_ORDERPRIORITY = Utf8("2-HIGH") THEN Int32(1) ELSE Int32(0) END), sum(CASE WHEN ORDERS.O_ORDERPRIORITY != Utf8("1-URGENT") AND ORDERS.O_ORDERPRIORITY != Utf8("2-HIGH") THEN Int32(1) ELSE Int32(0) END)]] Projection: LINEITEM.L_SHIPMODE, CASE WHEN ORDERS.O_ORDERPRIORITY = Utf8("1-URGENT") OR ORDERS.O_ORDERPRIORITY = Utf8("2-HIGH") THEN Int32(1) ELSE Int32(0) END, CASE WHEN ORDERS.O_ORDERPRIORITY != Utf8("1-URGENT") AND ORDERS.O_ORDERPRIORITY != Utf8("2-HIGH") THEN Int32(1) ELSE Int32(0) END Filter: ORDERS.O_ORDERKEY = LINEITEM.L_ORDERKEY AND (LINEITEM.L_SHIPMODE = CAST(Utf8("MAIL") AS Utf8) OR LINEITEM.L_SHIPMODE = CAST(Utf8("SHIP") AS Utf8)) AND LINEITEM.L_COMMITDATE < LINEITEM.L_RECEIPTDATE AND LINEITEM.L_SHIPDATE < LINEITEM.L_COMMITDATE AND LINEITEM.L_RECEIPTDATE >= CAST(Utf8("1994-01-01") AS Date32) AND LINEITEM.L_RECEIPTDATE < CAST(Utf8("1995-01-01") AS Date32) - Cross Join: + Cross Join: TableScan: ORDERS TableScan: LINEITEM "# @@ -314,7 +344,7 @@ mod tests { Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN PART.P_TYPE LIKE Utf8("PROMO%") THEN LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT ELSE Decimal128(Some(0),19,4) END), sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT)]] Projection: CASE WHEN PART.P_TYPE LIKE CAST(Utf8("PROMO%") AS Utf8) THEN LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) ELSE Decimal128(Some(0),19,4) END, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) Filter: 
LINEITEM.L_PARTKEY = PART.P_PARTKEY AND LINEITEM.L_SHIPDATE >= Date32("1995-09-01") AND LINEITEM.L_SHIPDATE < CAST(Utf8("1995-10-01") AS Date32) - Cross Join: + Cross Join: TableScan: LINEITEM TableScan: PART "# @@ -345,7 +375,7 @@ mod tests { Projection: SUPPLIER.S_SUPPKEY Filter: SUPPLIER.S_COMMENT LIKE CAST(Utf8("%Customer%Complaints%") AS Utf8) TableScan: SUPPLIER - Cross Join: + Cross Join: TableScan: PARTSUPP TableScan: PART "# @@ -353,11 +383,27 @@ mod tests { Ok(()) } - #[ignore] #[tokio::test] async fn tpch_test_17() -> Result<()> { let plan_str = tpch_plan_to_string(17).await?; - assert_snapshot!(plan_str, "panics due to out of bounds field access"); + assert_snapshot!( + plan_str, + @r#" + Projection: sum(LINEITEM.L_EXTENDEDPRICE) / Decimal128(Some(70),2,1) AS AVG_YEARLY + Aggregate: groupBy=[[]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE)]] + Projection: LINEITEM.L_EXTENDEDPRICE + Filter: PART.P_PARTKEY = LINEITEM.L_PARTKEY AND PART.P_BRAND = Utf8("Brand#23") AND PART.P_CONTAINER = Utf8("MED BOX") AND LINEITEM.L_QUANTITY < () + Subquery: + Projection: Decimal128(Some(2),2,1) * avg(LINEITEM.L_QUANTITY) + Aggregate: groupBy=[[]], aggr=[[avg(LINEITEM.L_QUANTITY)]] + Projection: LINEITEM.L_QUANTITY + Filter: LINEITEM.L_PARTKEY = outer_ref(PART.P_PARTKEY) + TableScan: LINEITEM + Cross Join: + TableScan: LINEITEM + TableScan: PART + "# + ); Ok(()) } @@ -379,8 +425,8 @@ mod tests { Aggregate: groupBy=[[LINEITEM.L_ORDERKEY]], aggr=[[sum(LINEITEM.L_QUANTITY)]] Projection: LINEITEM.L_ORDERKEY, LINEITEM.L_QUANTITY TableScan: LINEITEM - Cross Join: - Cross Join: + Cross Join: + Cross Join: TableScan: CUSTOMER TableScan: ORDERS TableScan: LINEITEM @@ -397,7 +443,7 @@ mod tests { Aggregate: groupBy=[[]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) AS REVENUE]] Projection: LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) Filter: PART.P_PARTKEY = LINEITEM.L_PARTKEY AND PART.P_BRAND = Utf8("Brand#12") AND 
(PART.P_CONTAINER = CAST(Utf8("SM CASE") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8("SM BOX") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8("SM PACK") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8("SM PKG") AS Utf8)) AND LINEITEM.L_QUANTITY >= CAST(Int32(1) AS Decimal128(15, 2)) AND LINEITEM.L_QUANTITY <= CAST(Int32(1) + Int32(10) AS Decimal128(15, 2)) AND PART.P_SIZE >= Int32(1) AND PART.P_SIZE <= Int32(5) AND (LINEITEM.L_SHIPMODE = CAST(Utf8("AIR") AS Utf8) OR LINEITEM.L_SHIPMODE = CAST(Utf8("AIR REG") AS Utf8)) AND LINEITEM.L_SHIPINSTRUCT = Utf8("DELIVER IN PERSON") OR PART.P_PARTKEY = LINEITEM.L_PARTKEY AND PART.P_BRAND = Utf8("Brand#23") AND (PART.P_CONTAINER = CAST(Utf8("MED BAG") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8("MED BOX") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8("MED PKG") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8("MED PACK") AS Utf8)) AND LINEITEM.L_QUANTITY >= CAST(Int32(10) AS Decimal128(15, 2)) AND LINEITEM.L_QUANTITY <= CAST(Int32(10) + Int32(10) AS Decimal128(15, 2)) AND PART.P_SIZE >= Int32(1) AND PART.P_SIZE <= Int32(10) AND (LINEITEM.L_SHIPMODE = CAST(Utf8("AIR") AS Utf8) OR LINEITEM.L_SHIPMODE = CAST(Utf8("AIR REG") AS Utf8)) AND LINEITEM.L_SHIPINSTRUCT = Utf8("DELIVER IN PERSON") OR PART.P_PARTKEY = LINEITEM.L_PARTKEY AND PART.P_BRAND = Utf8("Brand#34") AND (PART.P_CONTAINER = CAST(Utf8("LG CASE") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8("LG BOX") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8("LG PACK") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8("LG PKG") AS Utf8)) AND LINEITEM.L_QUANTITY >= CAST(Int32(20) AS Decimal128(15, 2)) AND LINEITEM.L_QUANTITY <= CAST(Int32(20) + Int32(10) AS Decimal128(15, 2)) AND PART.P_SIZE >= Int32(1) AND PART.P_SIZE <= Int32(15) AND (LINEITEM.L_SHIPMODE = CAST(Utf8("AIR") AS Utf8) OR LINEITEM.L_SHIPMODE = CAST(Utf8("AIR REG") AS Utf8)) AND LINEITEM.L_SHIPINSTRUCT = Utf8("DELIVER IN PERSON") - Cross Join: + Cross Join: TableScan: LINEITEM TableScan: PART "# @@ -425,10 +471,10 @@ mod tests { Projection: Decimal128(Some(5),2,1) * 
sum(LINEITEM.L_QUANTITY) Aggregate: groupBy=[[]], aggr=[[sum(LINEITEM.L_QUANTITY)]] Projection: LINEITEM.L_QUANTITY - Filter: LINEITEM.L_PARTKEY = LINEITEM.L_ORDERKEY AND LINEITEM.L_SUPPKEY = LINEITEM.L_PARTKEY AND LINEITEM.L_SHIPDATE >= CAST(Utf8("1994-01-01") AS Date32) AND LINEITEM.L_SHIPDATE < CAST(Utf8("1995-01-01") AS Date32) + Filter: LINEITEM.L_PARTKEY = outer_ref(PARTSUPP.PS_PARTKEY) AND LINEITEM.L_SUPPKEY = outer_ref(PARTSUPP.PS_SUPPKEY) AND LINEITEM.L_SHIPDATE >= CAST(Utf8("1994-01-01") AS Date32) AND LINEITEM.L_SHIPDATE < CAST(Utf8("1995-01-01") AS Date32) TableScan: LINEITEM TableScan: PARTSUPP - Cross Join: + Cross Join: TableScan: SUPPLIER TableScan: NATION "# @@ -449,14 +495,14 @@ mod tests { Projection: SUPPLIER.S_NAME Filter: SUPPLIER.S_SUPPKEY = LINEITEM.L_SUPPKEY AND ORDERS.O_ORDERKEY = LINEITEM.L_ORDERKEY AND ORDERS.O_ORDERSTATUS = Utf8("F") AND LINEITEM.L_RECEIPTDATE > LINEITEM.L_COMMITDATE AND EXISTS () AND NOT EXISTS () AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_NAME = Utf8("SAUDI ARABIA") Subquery: - Filter: LINEITEM.L_ORDERKEY = LINEITEM.L_TAX AND LINEITEM.L_SUPPKEY != LINEITEM.L_LINESTATUS + Filter: LINEITEM.L_ORDERKEY = outer_ref(LINEITEM.L_ORDERKEY) AND LINEITEM.L_SUPPKEY != outer_ref(LINEITEM.L_SUPPKEY) TableScan: LINEITEM Subquery: - Filter: LINEITEM.L_ORDERKEY = LINEITEM.L_TAX AND LINEITEM.L_SUPPKEY != LINEITEM.L_LINESTATUS AND LINEITEM.L_RECEIPTDATE > LINEITEM.L_COMMITDATE + Filter: LINEITEM.L_ORDERKEY = outer_ref(LINEITEM.L_ORDERKEY) AND LINEITEM.L_SUPPKEY != outer_ref(LINEITEM.L_SUPPKEY) AND LINEITEM.L_RECEIPTDATE > LINEITEM.L_COMMITDATE TableScan: LINEITEM - Cross Join: - Cross Join: - Cross Join: + Cross Join: + Cross Join: + Cross Join: TableScan: SUPPLIER TableScan: LINEITEM TableScan: ORDERS @@ -483,7 +529,7 @@ mod tests { Filter: CUSTOMER.C_ACCTBAL > Decimal128(Some(0),3,2) AND (substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("13") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = 
CAST(Utf8("31") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("23") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("29") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("30") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("18") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("17") AS Utf8)) TableScan: CUSTOMER Subquery: - Filter: ORDERS.O_CUSTKEY = ORDERS.O_ORDERKEY + Filter: ORDERS.O_CUSTKEY = outer_ref(CUSTOMER.C_CUSTKEY) TableScan: ORDERS TableScan: CUSTOMER "# @@ -491,6 +537,52 @@ mod tests { Ok(()) } + /// Tests nested correlated subqueries where the innermost subquery + /// references the outermost query (steps_out=2). + /// + /// This tests the outer schema stack with depth > 1. + /// The plan represents: + /// ```sql + /// SELECT * FROM A + /// WHERE EXISTS ( + /// SELECT * FROM B + /// WHERE B.b1 = A.a1 -- steps_out=1 (references immediate parent) + /// AND EXISTS ( + /// SELECT * FROM C + /// WHERE C.c1 = A.a1 -- steps_out=2 (references grandparent) + /// AND C.c2 = B.b2 -- steps_out=1 (references immediate parent) + /// ) + /// ) + /// ``` + /// + #[tokio::test] + async fn test_nested_correlated_subquery() -> Result<()> { + let path = "tests/testdata/test_plans/nested_correlated_subquery.substrait.json"; + let proto = serde_json::from_reader::<_, Plan>(BufReader::new( + File::open(path).expect("file not found"), + )) + .expect("failed to parse json"); + + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto)?; + let plan = from_substrait_plan(&ctx.state(), &proto).await?; + let plan_str = format!("{plan}"); + + assert_snapshot!( + plan_str, + @r#" + Filter: EXISTS () + Subquery: + Filter: B.b1 = outer_ref(A.a1) AND EXISTS () + Subquery: + Filter: C.c1 = outer_ref(A.a1) AND C.c2 = outer_ref(B.b2) + TableScan: C + TableScan: B + TableScan: A + "# + ); + Ok(()) + } + async fn test_plan_to_string(name: &str) -> Result { let path = 
format!("tests/testdata/test_plans/{name}"); let proto = serde_json::from_reader::<_, Plan>(BufReader::new( @@ -651,31 +743,23 @@ mod tests { #[tokio::test] async fn test_multiple_unions() -> Result<()> { let plan_str = test_plan_to_string("multiple_unions.json").await?; - - let mut settings = insta::Settings::clone_current(); - settings.add_filter( - r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}", - "[UUID]", - ); - settings.bind(|| { - assert_snapshot!( - plan_str, - @r#" - Projection: [UUID] AS product_category, [UUID] AS product_type, product_key - Union - Projection: Utf8("people") AS [UUID], Utf8("people") AS [UUID], sales.product_key - Left Join: sales.product_key = food.@food_id - TableScan: sales - TableScan: food - Union - Projection: people.$f3, people.$f5, people.product_key0 - Left Join: people.product_key0 = food.@food_id - TableScan: people - TableScan: food - TableScan: more_products - "# + assert_snapshot!( + plan_str, + @r#" + Projection: Utf8("people") AS product_category, Utf8("people")__temp__0 AS product_type, product_key + Union + Projection: Utf8("people"), Utf8("people") AS Utf8("people")__temp__0, sales.product_key + Left Join: sales.product_key = food.@food_id + TableScan: sales + TableScan: food + Union + Projection: people.$f3, people.$f5, people.product_key0 + Left Join: people.product_key0 = food.@food_id + TableScan: people + TableScan: food + TableScan: more_products + "# ); - }); Ok(()) } @@ -708,4 +792,80 @@ mod tests { Ok(()) } + + /// Substrait join with both `equal` and `is_not_distinct_from` must demote + /// `IS NOT DISTINCT FROM` to the join filter. + #[tokio::test] + async fn test_mixed_join_equal_and_indistinct_inner_join() -> Result<()> { + let plan_str = + test_plan_to_string("mixed_join_equal_and_indistinct.json").await?; + // Eq becomes the equijoin key; IS NOT DISTINCT FROM is demoted to filter. 
+ assert_snapshot!( + plan_str, + @r#" + Projection: left.id, left.val, left.comment, right.id AS id0, right.val AS val0, right.comment AS comment0 + Inner Join: left.id = right.id Filter: left.val IS NOT DISTINCT FROM right.val + SubqueryAlias: left + Values: (Utf8("1"), Utf8("a"), Utf8("c1")), (Utf8("2"), Utf8("b"), Utf8("c2")), (Utf8("3"), Utf8(NULL), Utf8("c3")), (Utf8("4"), Utf8(NULL), Utf8("c4")), (Utf8("5"), Utf8("e"), Utf8("c5"))... + SubqueryAlias: right + Values: (Utf8("1"), Utf8("a"), Utf8("c1")), (Utf8("2"), Utf8("b"), Utf8("c2")), (Utf8("3"), Utf8(NULL), Utf8("c3")), (Utf8("4"), Utf8(NULL), Utf8("c4")), (Utf8("5"), Utf8("e"), Utf8("c5"))... + "# + ); + + // Execute and verify actual rows, including NULL=NULL matches (ids 3,4). + let results = execute_plan("mixed_join_equal_and_indistinct.json").await?; + assert_snapshot!(pretty_sorted(&results), + @r" + +----+-----+---------+-----+------+----------+ + | id | val | comment | id0 | val0 | comment0 | + +----+-----+---------+-----+------+----------+ + | 1 | a | c1 | 1 | a | c1 | + | 2 | b | c2 | 2 | b | c2 | + | 3 | | c3 | 3 | | c3 | + | 4 | | c4 | 4 | | c4 | + | 5 | e | c5 | 5 | e | c5 | + | 6 | f | c6 | 6 | f | c6 | + +----+-----+---------+-----+------+----------+ + " + ); + + Ok(()) + } + + /// Substrait join with both `equal` and `is_not_distinct_from` must demote + /// `IS NOT DISTINCT FROM` to the join filter. 
+ #[tokio::test] + async fn test_mixed_join_equal_and_indistinct_left_join() -> Result<()> { + let plan_str = + test_plan_to_string("mixed_join_equal_and_indistinct_left.json").await?; + assert_snapshot!( + plan_str, + @r#" + Projection: left.id, left.val, left.comment, right.id AS id0, right.val AS val0, right.comment AS comment0 + Left Join: left.id = right.id Filter: left.val IS NOT DISTINCT FROM right.val + SubqueryAlias: left + Values: (Utf8("1"), Utf8("a"), Utf8("c1")), (Utf8("2"), Utf8("b"), Utf8("c2")), (Utf8("3"), Utf8(NULL), Utf8("c3")), (Utf8("4"), Utf8(NULL), Utf8("c4")), (Utf8("5"), Utf8("e"), Utf8("c5"))... + SubqueryAlias: right + Values: (Utf8("1"), Utf8("a"), Utf8("c1")), (Utf8("2"), Utf8("b"), Utf8("c2")), (Utf8("3"), Utf8(NULL), Utf8("c3")), (Utf8("4"), Utf8(NULL), Utf8("c4")), (Utf8("5"), Utf8("e"), Utf8("c5"))... + "# + ); + + let results = execute_plan("mixed_join_equal_and_indistinct_left.json").await?; + assert_snapshot!(pretty_sorted(&results), + @r" + +----+-----+---------+-----+------+----------+ + | id | val | comment | id0 | val0 | comment0 | + +----+-----+---------+-----+------+----------+ + | 1 | a | c1 | 1 | a | c1 | + | 2 | b | c2 | 2 | b | c2 | + | 3 | | c3 | 3 | | c3 | + | 4 | | c4 | 4 | | c4 | + | 5 | e | c5 | 5 | e | c5 | + | 6 | f | c6 | 6 | f | c6 | + +----+-----+---------+-----+------+----------+ + " + ); + + Ok(()) + } } diff --git a/datafusion/substrait/tests/cases/logical_plans.rs b/datafusion/substrait/tests/cases/logical_plans.rs index 5ebacaf5336d..9de7cb8f3835 100644 --- a/datafusion/substrait/tests/cases/logical_plans.rs +++ b/datafusion/substrait/tests/cases/logical_plans.rs @@ -20,6 +20,9 @@ #[cfg(test)] mod tests { use crate::utils::test::{add_plan_schemas_to_ctx, read_json}; + use datafusion::common::test_util::format_batches; + use std::collections::HashSet; + use datafusion::common::Result; use datafusion::dataframe::DataFrame; use datafusion::prelude::SessionContext; @@ -157,28 +160,21 @@ mod tests { let ctx = 
add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; - let mut settings = insta::Settings::clone_current(); - settings.add_filter( - r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}", - "[UUID]", + assert_snapshot!( + plan, + @r" + Projection: left.A, left.Utf8(NULL) AS C, right.D, Utf8(NULL) AS Utf8(NULL)__temp__0 AS E + Left Join: left.A = right.A + SubqueryAlias: left + Union + Projection: A.A, Utf8(NULL) + TableScan: A + Projection: B.A, CAST(B.C AS Utf8) + TableScan: B + SubqueryAlias: right + TableScan: C + " ); - settings.bind(|| { - assert_snapshot!( - plan, - @r" - Projection: left.A, left.[UUID] AS C, right.D, Utf8(NULL) AS [UUID] AS E - Left Join: left.A = right.A - SubqueryAlias: left - Union - Projection: A.A, Utf8(NULL) AS [UUID] - TableScan: A - Projection: B.A, CAST(B.C AS Utf8) - TableScan: B - SubqueryAlias: right - TableScan: C - " - ); - }); // Trigger execution to ensure plan validity DataFrame::new(ctx.state(), plan).show().await?; @@ -229,4 +225,49 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn duplicate_name_in_union() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/duplicate_name_in_union.substrait.json"); + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; + + assert_snapshot!( + plan, + @r" + Projection: foo AS col1, bar AS col2 + Union + Projection: foo, bar + Values: (Int64(100), Int64(200)) + Projection: x, foo + Values: (Int32(300), Int64(400)) + " + ); + + // Trigger execution to ensure plan validity + let results = DataFrame::new(ctx.state(), plan).collect().await?; + + assert_snapshot!( + format_batches(&results)?, + @r" + +------+------+ + | col1 | col2 | + +------+------+ + | 100 | 200 | + | 300 | 400 | + +------+------+ + ", + ); + + // also verify that the output schema has unique field names + let schema = 
results[0].schema(); + for batch in &results { + assert_eq!(schema, batch.schema()); + } + let field_names: HashSet<_> = schema.fields().iter().map(|f| f.name()).collect(); + assert_eq!(field_names.len(), schema.fields().len()); + + Ok(()) + } } diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 98b35bf082ec..5dd4aa4e2be9 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -29,14 +29,15 @@ use std::mem::size_of_val; use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit}; use datafusion::common::tree_node::Transformed; -use datafusion::common::{DFSchema, DFSchemaRef, not_impl_err, plan_err}; +use datafusion::common::{DFSchema, DFSchemaRef, Spans, not_impl_err, plan_err}; use datafusion::error::Result; use datafusion::execution::registry::SerializerRegistry; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::execution::session_state::SessionStateBuilder; +use datafusion::logical_expr::expr::{Exists, SetComparison, SetQuantifier}; use datafusion::logical_expr::{ - EmptyRelation, Extension, InvariantLevel, LogicalPlan, PartitionEvaluator, - Repartition, UserDefinedLogicalNode, Values, Volatility, + EmptyRelation, Extension, InvariantLevel, LogicalPlan, Operator, PartitionEvaluator, + Repartition, Subquery, UserDefinedLogicalNode, Values, Volatility, }; use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST; use datafusion::prelude::*; @@ -689,6 +690,60 @@ async fn roundtrip_exists_filter() -> Result<()> { Ok(()) } +// assemble logical plan manually to ensure SetComparison expr is present (not rewrite away) +#[tokio::test] +async fn roundtrip_set_comparison_any_substrait() -> Result<()> { + let ctx = create_context().await?; + let plan = build_set_comparison_plan(&ctx, SetQuantifier::Any, Operator::Gt).await?; + let 
proto = to_substrait_plan(&plan, &ctx.state())?; + let roundtrip_plan = from_substrait_plan(&ctx.state(), &proto).await?; + assert_set_comparison_predicate(&roundtrip_plan, Operator::Gt, SetQuantifier::Any); + Ok(()) +} + +// assemble logical plan manually to ensure SetComparison expr is present (not rewrite away) +#[tokio::test] +async fn roundtrip_set_comparison_all_substrait() -> Result<()> { + let ctx = create_context().await?; + let plan = + build_set_comparison_plan(&ctx, SetQuantifier::All, Operator::NotEq).await?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + let roundtrip_plan = from_substrait_plan(&ctx.state(), &proto).await?; + assert_set_comparison_predicate(&roundtrip_plan, Operator::NotEq, SetQuantifier::All); + Ok(()) +} + +#[tokio::test] +async fn roundtrip_scalar_subquery_substrait() -> Result<()> { + let ctx = create_context().await?; + let plan = build_scalar_subquery_projection_plan(&ctx).await?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + assert_root_project_has_scalar_subquery(proto.as_ref()); + let roundtrip_plan = from_substrait_plan(&ctx.state(), &proto).await?; + assert_projection_contains_scalar_subquery(&roundtrip_plan); + Ok(()) +} + +#[tokio::test] +async fn roundtrip_exists_substrait() -> Result<()> { + let ctx = create_context().await?; + let plan = build_exists_filter_plan(&ctx, false).await?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + let roundtrip_plan = from_substrait_plan(&ctx.state(), &proto).await?; + assert_exists_predicate(&roundtrip_plan, false); + Ok(()) +} + +#[tokio::test] +async fn roundtrip_not_exists_substrait() -> Result<()> { + let ctx = create_context().await?; + let plan = build_exists_filter_plan(&ctx, true).await?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + let roundtrip_plan = from_substrait_plan(&ctx.state(), &proto).await?; + assert_exists_predicate(&roundtrip_plan, true); + Ok(()) +} + #[tokio::test] async fn roundtrip_not_exists_filter_left_anti_join() -> 
Result<()> { let plan = generate_plan_from_sql( @@ -789,17 +844,50 @@ async fn roundtrip_outer_join() -> Result<()> { async fn roundtrip_self_join() -> Result<()> { // Substrait does currently NOT maintain the alias of the tables. // Instead, when we consume Substrait, we add aliases before a join that'd otherwise collide. - // This roundtrip works because we set aliases to what the Substrait consumer will generate. - roundtrip("SELECT left.a as left_a, left.b, right.a as right_a, right.c FROM data AS left JOIN data AS right ON left.a = right.a").await?; - roundtrip("SELECT left.a as left_a, left.b, right.a as right_a, right.c FROM data AS left JOIN data AS right ON left.b = right.b").await + // The improved NameTracker now adds __temp__0 suffix to handle naming conflicts. + // We verify semantic equivalence rather than exact string match. + let ctx = create_context().await?; + let sql = "SELECT left.a as left_a, left.b, right.a as right_a, right.c FROM data AS left JOIN data AS right ON left.a = right.a"; + let df = ctx.sql(sql).await?; + let plan = df.into_optimized_plan()?; + let plan2 = substrait_roundtrip(&plan, &ctx).await?; + + // Verify schemas are equivalent + assert_eq!(plan.schema(), plan2.schema()); + + // Execute to ensure plan validity + DataFrame::new(ctx.state(), plan2).show().await?; + + // Test second variant + let sql2 = "SELECT left.a as left_a, left.b, right.a as right_a, right.c FROM data AS left JOIN data AS right ON left.b = right.b"; + let df2 = ctx.sql(sql2).await?; + let plan3 = df2.into_optimized_plan()?; + let plan4 = substrait_roundtrip(&plan3, &ctx).await?; + assert_eq!(plan3.schema(), plan4.schema()); + DataFrame::new(ctx.state(), plan4).show().await?; + + Ok(()) } #[tokio::test] async fn roundtrip_self_implicit_cross_join() -> Result<()> { // Substrait does currently NOT maintain the alias of the tables. // Instead, when we consume Substrait, we add aliases before a join that'd otherwise collide. 
- // This roundtrip works because we set aliases to what the Substrait consumer will generate. - roundtrip("SELECT left.a left_a, left.b, right.a right_a, right.c FROM data AS left, data AS right").await + // The improved NameTracker now adds __temp__0 suffix to handle naming conflicts. + // We verify semantic equivalence rather than exact string match. + let ctx = create_context().await?; + let sql = "SELECT left.a left_a, left.b, right.a right_a, right.c FROM data AS left, data AS right"; + let df = ctx.sql(sql).await?; + let plan = df.into_optimized_plan()?; + let plan2 = substrait_roundtrip(&plan, &ctx).await?; + + // Verify schemas are equivalent + assert_eq!(plan.schema(), plan2.schema()); + + // Execute to ensure plan validity + DataFrame::new(ctx.state(), plan2).show().await?; + + Ok(()) } #[tokio::test] @@ -1353,7 +1441,7 @@ async fn roundtrip_literal_named_struct() -> Result<()> { assert_snapshot!( plan, @r#" - Projection: Struct({int_field:1,boolean_field:true,string_field:}) AS named_struct(Utf8("int_field"),Int64(1),Utf8("boolean_field"),Boolean(true),Utf8("string_field"),NULL) + Projection: CAST(Struct({c0:1,c1:true,c2:}) AS Struct("int_field": Int64, "boolean_field": Boolean, "string_field": Utf8View)) AS named_struct(Utf8("int_field"),Int64(1),Utf8("boolean_field"),Boolean(true),Utf8("string_field"),NULL) TableScan: data projection=[] "# ); @@ -1373,10 +1461,10 @@ async fn roundtrip_literal_renamed_struct() -> Result<()> { assert_snapshot!( plan, - @r" - Projection: Struct({int_field:1}) AS Struct({c0:1}) + @r#" + Projection: CAST(Struct({c0:1}) AS Struct("int_field": Int32)) TableScan: data projection=[] - " + "# ); Ok(()) } @@ -1456,16 +1544,26 @@ async fn roundtrip_values_empty_relation() -> Result<()> { async fn roundtrip_values_duplicate_column_join() -> Result<()> { // Substrait does currently NOT maintain the alias of the tables. // Instead, when we consume Substrait, we add aliases before a join that'd otherwise collide. 
- // This roundtrip works because we set aliases to what the Substrait consumer will generate. - roundtrip( - "SELECT left.column1 as c1, right.column1 as c2 \ + // The improved NameTracker now adds __temp__0 suffix to handle naming conflicts. + // We verify semantic equivalence rather than exact string match. + let ctx = create_context().await?; + let sql = "SELECT left.column1 as c1, right.column1 as c2 \ FROM \ (VALUES (1)) AS left \ JOIN \ (VALUES (2)) AS right \ - ON left.column1 == right.column1", - ) - .await + ON left.column1 == right.column1"; + let df = ctx.sql(sql).await?; + let plan = df.into_optimized_plan()?; + let plan2 = substrait_roundtrip(&plan, &ctx).await?; + + // Verify schemas are equivalent + assert_eq!(plan.schema(), plan2.schema()); + + // Execute to ensure plan validity + DataFrame::new(ctx.state(), plan2).show().await?; + + Ok(()) } #[tokio::test] @@ -1865,6 +1963,188 @@ async fn assert_substrait_sql(substrait_plan: Plan, sql: &str) -> Result<()> { Ok(()) } +async fn build_set_comparison_plan( + ctx: &SessionContext, + quantifier: SetQuantifier, + op: Operator, +) -> Result { + let base_scan = ctx.table("data").await?.into_unoptimized_plan(); + let subquery_scan = ctx.table("data2").await?.into_unoptimized_plan(); + let subquery_plan = LogicalPlanBuilder::from(subquery_scan) + .project(vec![col("data2.a")])? + .build()?; + let predicate = Expr::SetComparison(SetComparison::new( + Box::new(col("data.a")), + Subquery { + subquery: Arc::new(subquery_plan), + outer_ref_columns: vec![], + spans: Spans::new(), + }, + op, + quantifier, + )); + + LogicalPlanBuilder::from(base_scan) + .filter(predicate)? + .project(vec![col("data.a")])? + .build() +} + +async fn build_scalar_subquery_projection_plan( + ctx: &SessionContext, +) -> Result { + let subquery_scan = ctx.table("data2").await?.into_unoptimized_plan(); + let subquery_plan = LogicalPlanBuilder::from(subquery_scan) + .project(vec![col("a")])? + .limit(0, Some(1))? 
+ .build()?; + + let scalar_subquery = Expr::ScalarSubquery(Subquery { + subquery: Arc::new(subquery_plan), + outer_ref_columns: vec![], + spans: Spans::new(), + }); + + let outer_empty_relation = LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: true, + schema: DFSchemaRef::new(DFSchema::empty()), + }); + + LogicalPlanBuilder::from(outer_empty_relation) + .project(vec![scalar_subquery.alias("sq")])? + .build() +} + +async fn build_exists_filter_plan( + ctx: &SessionContext, + negated: bool, +) -> Result { + let base_scan = ctx.table("data").await?.into_unoptimized_plan(); + let subquery_scan = ctx.table("data2").await?.into_unoptimized_plan(); + let subquery_plan = LogicalPlanBuilder::from(subquery_scan) + .project(vec![col("data2.a")])? + .build()?; + + let predicate = Expr::Exists(Exists::new( + Subquery { + subquery: Arc::new(subquery_plan), + outer_ref_columns: vec![], + spans: Spans::new(), + }, + negated, + )); + + LogicalPlanBuilder::from(base_scan) + .filter(predicate)? + .project(vec![col("data.a")])? 
+ .build() +} + +fn assert_set_comparison_predicate( + plan: &LogicalPlan, + expected_op: Operator, + expected_quantifier: SetQuantifier, +) { + let predicate = match plan { + LogicalPlan::Projection(p) => match p.input.as_ref() { + LogicalPlan::Filter(filter) => &filter.predicate, + other => panic!("expected Filter inside Projection, got {other:?}"), + }, + LogicalPlan::Filter(filter) => &filter.predicate, + other => panic!("expected Filter plan, got {other:?}"), + }; + + match predicate { + Expr::SetComparison(set_comparison) => { + assert_eq!(set_comparison.op, expected_op); + assert_eq!(set_comparison.quantifier, expected_quantifier); + } + other => panic!("expected SetComparison predicate, got {other:?}"), + } +} + +fn assert_root_project_has_scalar_subquery(proto: &Plan) { + let relation = proto + .relations + .first() + .expect("expected Substrait plan to have at least one relation"); + + let root = match relation.rel_type.as_ref() { + Some(plan_rel::RelType::Root(root)) => root, + other => panic!("expected root relation, got {other:?}"), + }; + + let input = root.input.as_ref().expect("expected root input relation"); + let project = match input.rel_type.as_ref() { + Some(RelType::Project(project)) => project, + other => panic!("expected Project relation at root input, got {other:?}"), + }; + + let expr = project + .expressions + .first() + .expect("expected at least one project expression"); + let subquery = match expr.rex_type.as_ref() { + Some(substrait::proto::expression::RexType::Subquery(subquery)) => subquery, + other => panic!("expected Subquery expression, got {other:?}"), + }; + + assert!( + matches!( + subquery.subquery_type.as_ref(), + Some(substrait::proto::expression::subquery::SubqueryType::Scalar(_)) + ), + "expected scalar subquery type" + ); +} + +fn assert_projection_contains_scalar_subquery(plan: &LogicalPlan) { + let projection = match plan { + LogicalPlan::Projection(projection) => projection, + other => panic!("expected Projection 
plan, got {other:?}"), + }; + + let found_scalar_subquery = projection.expr.iter().any(expr_contains_scalar_subquery); + assert!( + found_scalar_subquery, + "expected Projection to contain ScalarSubquery expression" + ); +} + +fn expr_contains_scalar_subquery(expr: &Expr) -> bool { + match expr { + Expr::ScalarSubquery(_) => true, + Expr::Alias(alias) => expr_contains_scalar_subquery(alias.expr.as_ref()), + _ => false, + } +} + +fn assert_exists_predicate(plan: &LogicalPlan, expected_negated: bool) { + let predicate = match plan { + LogicalPlan::Projection(projection) => match projection.input.as_ref() { + LogicalPlan::Filter(filter) => &filter.predicate, + other => panic!("expected Filter inside Projection, got {other:?}"), + }, + LogicalPlan::Filter(filter) => &filter.predicate, + other => panic!("expected Filter plan, got {other:?}"), + }; + + if expected_negated { + match predicate { + Expr::Not(inner) => match inner.as_ref() { + Expr::Exists(exists) => assert!(!exists.negated), + other => panic!("expected Exists inside NOT, got {other:?}"), + }, + other => panic!("expected NOT EXISTS predicate, got {other:?}"), + } + } else { + match predicate { + Expr::Exists(exists) => assert!(!exists.negated), + other => panic!("expected EXISTS predicate, got {other:?}"), + } + } +} + async fn roundtrip_fill_na(sql: &str) -> Result<()> { let ctx = create_context().await?; let df = ctx.sql(sql).await?; diff --git a/datafusion/substrait/tests/cases/serialize.rs b/datafusion/substrait/tests/cases/serialize.rs index d0f951176093..2d7257fad339 100644 --- a/datafusion/substrait/tests/cases/serialize.rs +++ b/datafusion/substrait/tests/cases/serialize.rs @@ -17,7 +17,6 @@ #[cfg(test)] mod tests { - use datafusion::common::assert_contains; use datafusion::datasource::provider_as_source; use datafusion::logical_expr::LogicalPlanBuilder; use datafusion_substrait::logical_plan::consumer::from_substrait_plan; @@ -44,8 +43,18 @@ mod tests { serializer::deserialize(path).await?; // Test 
case 2: serializing to an existing file should fail. - let got = serializer::serialize(sql, &ctx, path).await.unwrap_err(); - assert_contains!(got.to_string(), "File exists"); + let got = serializer::serialize(sql, &ctx, path) + .await + .unwrap_err() + .to_string(); + assert!( + [ + "File exists", // unix + "os error 80" // windows + ] + .iter() + .any(|s| got.contains(s)) + ); fs::remove_file(path)?; diff --git a/datafusion/substrait/tests/testdata/test_plans/duplicate_name_in_union.substrait.json b/datafusion/substrait/tests/testdata/test_plans/duplicate_name_in_union.substrait.json new file mode 100644 index 000000000000..1da2ff613136 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/duplicate_name_in_union.substrait.json @@ -0,0 +1,171 @@ +{ + "version": { + "minorNumber": 54, + "producer": "datafusion-test" + }, + "relations": [ + { + "root": { + "input": { + "set": { + "common": { + "direct": {} + }, + "inputs": [ + { + "project": { + "common": { + "emit": { + "outputMapping": [2, 3] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": ["foo", "bar"], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "virtualTable": { + "expressions": [ + { + "fields": [ + { + "literal": { + "i64": "100" + } + }, + { + "literal": { + "i64": "200" + } + } + ] + } + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [2, 3] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": ["x", "foo"], + "struct": { + "types": [ + { + "i32": { + 
"nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "virtualTable": { + "expressions": [ + { + "fields": [ + { + "literal": { + "i32": 300 + } + }, + { + "literal": { + "i64": "400" + } + } + ] + } + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + } + ] + } + } + ], + "op": "SET_OP_UNION_ALL" + } + }, + "names": ["col1", "col2"] + } + } + ] +} diff --git a/datafusion/substrait/tests/testdata/test_plans/mixed_join_equal_and_indistinct.json b/datafusion/substrait/tests/testdata/test_plans/mixed_join_equal_and_indistinct.json new file mode 100644 index 000000000000..642256c56299 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/mixed_join_equal_and_indistinct.json @@ -0,0 +1,102 @@ +{ + "extensions": [ + { "extensionFunction": { "functionAnchor": 0, "name": "is_not_distinct_from" } }, + { "extensionFunction": { "functionAnchor": 1, "name": "equal" } }, + { "extensionFunction": { "functionAnchor": 2, "name": "and" } } + ], + "relations": [{ + "root": { + "input": { + "join": { + "common": { "direct": {} }, + "left": { + "read": { + "common": { "direct": {} }, + "baseSchema": { + "names": ["id", "val", "comment"], + "struct": { + "types": [ + { "string": { "nullability": "NULLABILITY_REQUIRED" } }, + { "string": { "nullability": "NULLABILITY_NULLABLE" } }, + { "string": { "nullability": "NULLABILITY_REQUIRED" } } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "virtualTable": { + "values": [ + { "fields": [{ "string": "1", "nullable": false }, { "string": "a", "nullable": true }, { "string": "c1", "nullable": false }] }, + { "fields": [{ "string": "2", "nullable": false }, { "string": "b", "nullable": true }, { "string": 
"c2", "nullable": false }] }, + { "fields": [{ "string": "3", "nullable": false }, { "null": { "string": { "nullability": "NULLABILITY_NULLABLE" } }, "nullable": true }, { "string": "c3", "nullable": false }] }, + { "fields": [{ "string": "4", "nullable": false }, { "null": { "string": { "nullability": "NULLABILITY_NULLABLE" } }, "nullable": true }, { "string": "c4", "nullable": false }] }, + { "fields": [{ "string": "5", "nullable": false }, { "string": "e", "nullable": true }, { "string": "c5", "nullable": false }] }, + { "fields": [{ "string": "6", "nullable": false }, { "string": "f", "nullable": true }, { "string": "c6", "nullable": false }] } + ] + } + } + }, + "right": { + "read": { + "common": { "direct": {} }, + "baseSchema": { + "names": ["id", "val", "comment"], + "struct": { + "types": [ + { "string": { "nullability": "NULLABILITY_REQUIRED" } }, + { "string": { "nullability": "NULLABILITY_NULLABLE" } }, + { "string": { "nullability": "NULLABILITY_REQUIRED" } } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "virtualTable": { + "values": [ + { "fields": [{ "string": "1", "nullable": false }, { "string": "a", "nullable": true }, { "string": "c1", "nullable": false }] }, + { "fields": [{ "string": "2", "nullable": false }, { "string": "b", "nullable": true }, { "string": "c2", "nullable": false }] }, + { "fields": [{ "string": "3", "nullable": false }, { "null": { "string": { "nullability": "NULLABILITY_NULLABLE" } }, "nullable": true }, { "string": "c3", "nullable": false }] }, + { "fields": [{ "string": "4", "nullable": false }, { "null": { "string": { "nullability": "NULLABILITY_NULLABLE" } }, "nullable": true }, { "string": "c4", "nullable": false }] }, + { "fields": [{ "string": "5", "nullable": false }, { "string": "e", "nullable": true }, { "string": "c5", "nullable": false }] }, + { "fields": [{ "string": "6", "nullable": false }, { "string": "f", "nullable": true }, { "string": "c6", "nullable": false }] } + ] + } + } + }, + "expression": 
{ + "scalarFunction": { + "functionReference": 2, + "outputType": { "bool": { "nullability": "NULLABILITY_NULLABLE" } }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 0, + "outputType": { "bool": { "nullability": "NULLABILITY_NULLABLE" } }, + "arguments": [ + { "value": { "selection": { "directReference": { "structField": { "field": 1 } }, "rootReference": {} } } }, + { "value": { "selection": { "directReference": { "structField": { "field": 4 } }, "rootReference": {} } } } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { "bool": { "nullability": "NULLABILITY_NULLABLE" } }, + "arguments": [ + { "value": { "selection": { "directReference": { "structField": { "field": 0 } }, "rootReference": {} } } }, + { "value": { "selection": { "directReference": { "structField": { "field": 3 } }, "rootReference": {} } } } + ] + } + } + } + ] + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "names": ["id", "val", "comment", "id0", "val0", "comment0"] + } + }] +} diff --git a/datafusion/substrait/tests/testdata/test_plans/mixed_join_equal_and_indistinct_left.json b/datafusion/substrait/tests/testdata/test_plans/mixed_join_equal_and_indistinct_left.json new file mode 100644 index 000000000000..f16672947e1e --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/mixed_join_equal_and_indistinct_left.json @@ -0,0 +1,102 @@ +{ + "extensions": [ + { "extensionFunction": { "functionAnchor": 0, "name": "is_not_distinct_from" } }, + { "extensionFunction": { "functionAnchor": 1, "name": "equal" } }, + { "extensionFunction": { "functionAnchor": 2, "name": "and" } } + ], + "relations": [{ + "root": { + "input": { + "join": { + "common": { "direct": {} }, + "left": { + "read": { + "common": { "direct": {} }, + "baseSchema": { + "names": ["id", "val", "comment"], + "struct": { + "types": [ + { "string": { "nullability": "NULLABILITY_REQUIRED" } }, + { "string": { "nullability": "NULLABILITY_NULLABLE" } 
}, + { "string": { "nullability": "NULLABILITY_REQUIRED" } } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "virtualTable": { + "values": [ + { "fields": [{ "string": "1", "nullable": false }, { "string": "a", "nullable": true }, { "string": "c1", "nullable": false }] }, + { "fields": [{ "string": "2", "nullable": false }, { "string": "b", "nullable": true }, { "string": "c2", "nullable": false }] }, + { "fields": [{ "string": "3", "nullable": false }, { "null": { "string": { "nullability": "NULLABILITY_NULLABLE" } }, "nullable": true }, { "string": "c3", "nullable": false }] }, + { "fields": [{ "string": "4", "nullable": false }, { "null": { "string": { "nullability": "NULLABILITY_NULLABLE" } }, "nullable": true }, { "string": "c4", "nullable": false }] }, + { "fields": [{ "string": "5", "nullable": false }, { "string": "e", "nullable": true }, { "string": "c5", "nullable": false }] }, + { "fields": [{ "string": "6", "nullable": false }, { "string": "f", "nullable": true }, { "string": "c6", "nullable": false }] } + ] + } + } + }, + "right": { + "read": { + "common": { "direct": {} }, + "baseSchema": { + "names": ["id", "val", "comment"], + "struct": { + "types": [ + { "string": { "nullability": "NULLABILITY_REQUIRED" } }, + { "string": { "nullability": "NULLABILITY_NULLABLE" } }, + { "string": { "nullability": "NULLABILITY_REQUIRED" } } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "virtualTable": { + "values": [ + { "fields": [{ "string": "1", "nullable": false }, { "string": "a", "nullable": true }, { "string": "c1", "nullable": false }] }, + { "fields": [{ "string": "2", "nullable": false }, { "string": "b", "nullable": true }, { "string": "c2", "nullable": false }] }, + { "fields": [{ "string": "3", "nullable": false }, { "null": { "string": { "nullability": "NULLABILITY_NULLABLE" } }, "nullable": true }, { "string": "c3", "nullable": false }] }, + { "fields": [{ "string": "4", "nullable": false }, { "null": { "string": { "nullability": 
"NULLABILITY_NULLABLE" } }, "nullable": true }, { "string": "c4", "nullable": false }] }, + { "fields": [{ "string": "5", "nullable": false }, { "string": "e", "nullable": true }, { "string": "c5", "nullable": false }] }, + { "fields": [{ "string": "6", "nullable": false }, { "string": "f", "nullable": true }, { "string": "c6", "nullable": false }] } + ] + } + } + }, + "expression": { + "scalarFunction": { + "functionReference": 2, + "outputType": { "bool": { "nullability": "NULLABILITY_NULLABLE" } }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 0, + "outputType": { "bool": { "nullability": "NULLABILITY_NULLABLE" } }, + "arguments": [ + { "value": { "selection": { "directReference": { "structField": { "field": 1 } }, "rootReference": {} } } }, + { "value": { "selection": { "directReference": { "structField": { "field": 4 } }, "rootReference": {} } } } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { "bool": { "nullability": "NULLABILITY_NULLABLE" } }, + "arguments": [ + { "value": { "selection": { "directReference": { "structField": { "field": 0 } }, "rootReference": {} } } }, + { "value": { "selection": { "directReference": { "structField": { "field": 3 } }, "rootReference": {} } } } + ] + } + } + } + ] + } + }, + "type": "JOIN_TYPE_LEFT" + } + }, + "names": ["id", "val", "comment", "id0", "val0", "comment0"] + } + }] +} diff --git a/datafusion/substrait/tests/testdata/test_plans/nested_correlated_subquery.substrait.json b/datafusion/substrait/tests/testdata/test_plans/nested_correlated_subquery.substrait.json new file mode 100644 index 000000000000..6c565a0f94e2 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/nested_correlated_subquery.substrait.json @@ -0,0 +1,265 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + }, { + "extensionUriAnchor": 2, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + 
"extensionFunction": { + "extensionUriReference": 1, + "name": "and:bool" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "equal:any_any" + } + }], + "relations": [{ + "root": { + "input": { + "filter": { + "common": { + "direct": {} + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": ["a1", "a2"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["A"] + } + } + }, + "condition": { + "subquery": { + "setPredicate": { + "predicateOp": "PREDICATE_OP_EXISTS", + "tuples": { + "filter": { + "common": { + "direct": {} + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": ["b1", "b2"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["B"] + } + } + }, + "condition": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "outerReference": { + "stepsOut": 1 + } + } + } + }] + } + } + }, { + "value": { + "subquery": { + "setPredicate": { + "predicateOp": "PREDICATE_OP_EXISTS", + "tuples": { + "filter": { + "common": { + "direct": {} + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": ["c1", "c2"], + "struct": 
{ + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["C"] + } + } + }, + "condition": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "outerReference": { + "stepsOut": 2 + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "outerReference": { + "stepsOut": 1 + } + } + } + }] + } + } + }] + } + } + } + } + } + } + } + }] + } + } + } + } + } + } + } + } + }, + "names": ["a1", "a2"] + } + }] +} diff --git a/datafusion/substrait/tests/utils.rs b/datafusion/substrait/tests/utils.rs index 2d63980aadf0..6a6824579b4e 100644 --- a/datafusion/substrait/tests/utils.rs +++ b/datafusion/substrait/tests/utils.rs @@ -484,6 +484,7 @@ pub mod test { } RexType::DynamicParameter(_) => {} // Enum is deprecated + #[expect(deprecated)] RexType::Enum(_) => {} } Ok(()) diff --git a/datafusion/wasmtest/Cargo.toml b/datafusion/wasmtest/Cargo.toml index 16fa9790f65b..e033056f9984 100644 --- a/datafusion/wasmtest/Cargo.toml +++ b/datafusion/wasmtest/Cargo.toml @@ -47,7 +47,7 @@ chrono = { version = "0.4", features = 
["wasmbind"] } # all the `std::fmt` and `std::panicking` infrastructure, so isn't great for # code size when deploying. console_error_panic_hook = { version = "0.1.1", optional = true } -datafusion = { workspace = true, features = ["parquet", "sql"] } +datafusion = { workspace = true, features = ["compression", "parquet", "sql"] } datafusion-common = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } @@ -59,11 +59,13 @@ getrandom = { version = "0.3", features = ["wasm_js"] } wasm-bindgen = "0.2.99" [dev-dependencies] +bytes = { workspace = true } +futures = { workspace = true } object_store = { workspace = true } # needs to be compiled tokio = { workspace = true } url = { workspace = true } -wasm-bindgen-test = "0.3.56" +wasm-bindgen-test = "0.3.62" [package.metadata.cargo-machete] ignored = ["chrono", "getrandom"] diff --git a/datafusion/wasmtest/datafusion-wasm-app/package-lock.json b/datafusion/wasmtest/datafusion-wasm-app/package-lock.json index 98ee1a34f01e..8f175b000122 100644 --- a/datafusion/wasmtest/datafusion-wasm-app/package-lock.json +++ b/datafusion/wasmtest/datafusion-wasm-app/package-lock.json @@ -13,7 +13,7 @@ }, "devDependencies": { "copy-webpack-plugin": "12.0.2", - "webpack": "5.94.0", + "webpack": "5.105.0", "webpack-cli": "5.1.4", "webpack-dev-server": "5.2.1" } @@ -32,17 +32,13 @@ } }, "node_modules/@jridgewell/gen-mapping": { - "version": "0.3.5", - "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.5.tgz", - "integrity": "sha512-IzL8ZoEDIBRWEzlCcRhOaCupYyN5gdIK+Q6fbFdPDg6HqX6jpkItn7DFIpW9LQzXG6Df9sA7+OKnq0qlz/GaQg==", + "version": "0.3.13", + "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.13.tgz", + "integrity": "sha512-2kkt/7niJ6MgEPxF0bYdQ6etZaA+fQvDcLKckhy1yIQOzaoKjBBjSj63/aLVjYE3qhRt5dvM+uUyfCg6UKCBbA==", "dev": true, "dependencies": { - "@jridgewell/set-array": "^1.2.1", - "@jridgewell/sourcemap-codec": "^1.4.10", + 
"@jridgewell/sourcemap-codec": "^1.5.0", "@jridgewell/trace-mapping": "^0.3.24" - }, - "engines": { - "node": ">=6.0.0" } }, "node_modules/@jridgewell/resolve-uri": { @@ -54,19 +50,10 @@ "node": ">=6.0.0" } }, - "node_modules/@jridgewell/set-array": { - "version": "1.2.1", - "resolved": "https://registry.npmjs.org/@jridgewell/set-array/-/set-array-1.2.1.tgz", - "integrity": "sha512-R8gLRTZeyp03ymzP/6Lil/28tGeGEzhx1q2k703KGWRAI1VdvPIXdG70VJc2pAMw3NA6JKL5hhFu1sJX0Mnn/A==", - "dev": true, - "engines": { - "node": ">=6.0.0" - } - }, "node_modules/@jridgewell/source-map": { - "version": "0.3.6", - "resolved": "https://registry.npmjs.org/@jridgewell/source-map/-/source-map-0.3.6.tgz", - "integrity": "sha512-1ZJTZebgqllO79ue2bm3rIGud/bOe0pP5BjSRCRxxYkEZS8STV7zN84UBbiYu7jy+eCKSnVIUgoWWE/tt+shMQ==", + "version": "0.3.11", + "resolved": "https://registry.npmjs.org/@jridgewell/source-map/-/source-map-0.3.11.tgz", + "integrity": "sha512-ZMp1V8ZFcPG5dIWnQLr3NSI1MiCU7UETdS/A0G8V/XWHvJv3ZsFqutJn1Y5RPmAPX6F3BiE397OqveU/9NCuIA==", "dev": true, "dependencies": { "@jridgewell/gen-mapping": "^0.3.5", @@ -74,15 +61,15 @@ } }, "node_modules/@jridgewell/sourcemap-codec": { - "version": "1.5.0", - "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.5.0.tgz", - "integrity": "sha512-gv3ZRaISU3fjPAgNsriBRqGWQL6quFx04YMPW/zD8XMLsU32mhCCbfbO6KZFLjvYpCZ8zyDEgqsgf+PwPaM7GQ==", + "version": "1.5.5", + "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.5.5.tgz", + "integrity": "sha512-cYQ9310grqxueWbl+WuIUIaiUaDcj7WOq5fVhEljNVgRfOUhY9fy2zTvfoqWsnebh8Sl70VScFbICvJnLKB0Og==", "dev": true }, "node_modules/@jridgewell/trace-mapping": { - "version": "0.3.25", - "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.25.tgz", - "integrity": "sha512-vNk6aEwybGtawWmy/PzwnGDOjCkLWSD2wqvjGGAgOAwCGWySYXfYoxt00IJkTF+8Lb57DwOb3Aa0o9CApepiYQ==", + "version": "0.3.31", + "resolved": 
"https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.31.tgz", + "integrity": "sha512-zzNR+SdQSDJzc8joaeP8QQoCQr8NuYx2dIIytl1QeBEZHJ9uW6hebsrYgbz8hJwUQao3TWCMtmfV8Nu1twOLAw==", "dev": true, "dependencies": { "@jridgewell/resolve-uri": "^3.1.0", @@ -187,10 +174,30 @@ "@types/node": "*" } }, + "node_modules/@types/eslint": { + "version": "9.6.1", + "resolved": "https://registry.npmjs.org/@types/eslint/-/eslint-9.6.1.tgz", + "integrity": "sha512-FXx2pKgId/WyYo2jXw63kk7/+TY7u7AziEJxJAnSFzHlqTAS3Ync6SvgYAN/k4/PQpnnVuzoMuVnByKK2qp0ag==", + "dev": true, + "dependencies": { + "@types/estree": "*", + "@types/json-schema": "*" + } + }, + "node_modules/@types/eslint-scope": { + "version": "3.7.7", + "resolved": "https://registry.npmjs.org/@types/eslint-scope/-/eslint-scope-3.7.7.tgz", + "integrity": "sha512-MzMFlSLBqNF2gcHWO0G1vP/YQyfvrxZ0bF+u7mzUdZ1/xK4A4sru+nraZz5i3iEIk1l1uyicaDVTB4QbbEkAYg==", + "dev": true, + "dependencies": { + "@types/eslint": "*", + "@types/estree": "*" + } + }, "node_modules/@types/estree": { - "version": "1.0.5", - "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.5.tgz", - "integrity": "sha512-/kYRxGDLWzHOB7q+wtSUQlFrtcdUccpfy+X+9iMBpHK8QLLhx2wIPYuS5DYtR9Wa/YlZAbIovy7qVdB1Aq6Lyw==", + "version": "1.0.8", + "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.8.tgz", + "integrity": "sha512-dWHzHa2WqEXI/O1E9OjrocMTKJl2mSrEolh1Iomrv6U+JuNwaHXsXx9bLu5gG7BUWFIN0skIQJQ/L1rIex4X6w==", "dev": true }, "node_modules/@types/express": { @@ -234,9 +241,9 @@ } }, "node_modules/@types/json-schema": { - "version": "7.0.13", - "resolved": "https://registry.npmjs.org/@types/json-schema/-/json-schema-7.0.13.tgz", - "integrity": "sha512-RbSSoHliUbnXj3ny0CNFOoxrIDV6SUGyStHsvDqosw6CkdPV8TtWGlfecuK4ToyMEAql6pzNxgCFKanovUzlgQ==", + "version": "7.0.15", + "resolved": "https://registry.npmjs.org/@types/json-schema/-/json-schema-7.0.15.tgz", + "integrity": 
"sha512-5+fP8P8MFNC+AyZCDxrB2pkZFPGzqQWUzpSeuuVLvm8VMcorNYavBqoFcxK8bQz4Qsbn4oUEEem4wDLfcysGHA==", "dev": true }, "node_modules/@types/mime": { @@ -333,148 +340,148 @@ } }, "node_modules/@webassemblyjs/ast": { - "version": "1.12.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/ast/-/ast-1.12.1.tgz", - "integrity": "sha512-EKfMUOPRRUTy5UII4qJDGPpqfwjOmZ5jeGFwid9mnoqIFK+e0vqoi1qH56JpmZSzEL53jKnNzScdmftJyG5xWg==", + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/ast/-/ast-1.14.1.tgz", + "integrity": "sha512-nuBEDgQfm1ccRp/8bCQrx1frohyufl4JlbMMZ4P1wpeOfDhF6FQkxZJ1b/e+PLwr6X1Nhw6OLme5usuBWYBvuQ==", "dev": true, "dependencies": { - "@webassemblyjs/helper-numbers": "1.11.6", - "@webassemblyjs/helper-wasm-bytecode": "1.11.6" + "@webassemblyjs/helper-numbers": "1.13.2", + "@webassemblyjs/helper-wasm-bytecode": "1.13.2" } }, "node_modules/@webassemblyjs/floating-point-hex-parser": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/floating-point-hex-parser/-/floating-point-hex-parser-1.11.6.tgz", - "integrity": "sha512-ejAj9hfRJ2XMsNHk/v6Fu2dGS+i4UaXBXGemOfQ/JfQ6mdQg/WXtwleQRLLS4OvfDhv8rYnVwH27YJLMyYsxhw==", + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/@webassemblyjs/floating-point-hex-parser/-/floating-point-hex-parser-1.13.2.tgz", + "integrity": "sha512-6oXyTOzbKxGH4steLbLNOu71Oj+C8Lg34n6CqRvqfS2O71BxY6ByfMDRhBytzknj9yGUPVJ1qIKhRlAwO1AovA==", "dev": true }, "node_modules/@webassemblyjs/helper-api-error": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-api-error/-/helper-api-error-1.11.6.tgz", - "integrity": "sha512-o0YkoP4pVu4rN8aTJgAyj9hC2Sv5UlkzCHhxqWj8butaLvnpdc2jOwh4ewE6CX0txSfLn/UYaV/pheS2Txg//Q==", + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-api-error/-/helper-api-error-1.13.2.tgz", + "integrity": 
"sha512-U56GMYxy4ZQCbDZd6JuvvNV/WFildOjsaWD3Tzzvmw/mas3cXzRJPMjP83JqEsgSbyrmaGjBfDtV7KDXV9UzFQ==", "dev": true }, "node_modules/@webassemblyjs/helper-buffer": { - "version": "1.12.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-buffer/-/helper-buffer-1.12.1.tgz", - "integrity": "sha512-nzJwQw99DNDKr9BVCOZcLuJJUlqkJh+kVzVl6Fmq/tI5ZtEyWT1KZMyOXltXLZJmDtvLCDgwsyrkohEtopTXCw==", + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-buffer/-/helper-buffer-1.14.1.tgz", + "integrity": "sha512-jyH7wtcHiKssDtFPRB+iQdxlDf96m0E39yb0k5uJVhFGleZFoNw1c4aeIcVUPPbXUVJ94wwnMOAqUHyzoEPVMA==", "dev": true }, "node_modules/@webassemblyjs/helper-numbers": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-numbers/-/helper-numbers-1.11.6.tgz", - "integrity": "sha512-vUIhZ8LZoIWHBohiEObxVm6hwP034jwmc9kuq5GdHZH0wiLVLIPcMCdpJzG4C11cHoQ25TFIQj9kaVADVX7N3g==", + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-numbers/-/helper-numbers-1.13.2.tgz", + "integrity": "sha512-FE8aCmS5Q6eQYcV3gI35O4J789wlQA+7JrqTTpJqn5emA4U2hvwJmvFRC0HODS+3Ye6WioDklgd6scJ3+PLnEA==", "dev": true, "dependencies": { - "@webassemblyjs/floating-point-hex-parser": "1.11.6", - "@webassemblyjs/helper-api-error": "1.11.6", + "@webassemblyjs/floating-point-hex-parser": "1.13.2", + "@webassemblyjs/helper-api-error": "1.13.2", "@xtuc/long": "4.2.2" } }, "node_modules/@webassemblyjs/helper-wasm-bytecode": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-wasm-bytecode/-/helper-wasm-bytecode-1.11.6.tgz", - "integrity": "sha512-sFFHKwcmBprO9e7Icf0+gddyWYDViL8bpPjJJl0WHxCdETktXdmtWLGVzoHbqUcY4Be1LkNfwTmXOJUFZYSJdA==", + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-wasm-bytecode/-/helper-wasm-bytecode-1.13.2.tgz", + "integrity": "sha512-3QbLKy93F0EAIXLh0ogEVR6rOubA9AoZ+WRYhNbFyuB70j3dRdwH9g+qXhLAO0kiYGlg3TxDV+I4rQTr/YNXkA==", 
"dev": true }, "node_modules/@webassemblyjs/helper-wasm-section": { - "version": "1.12.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-wasm-section/-/helper-wasm-section-1.12.1.tgz", - "integrity": "sha512-Jif4vfB6FJlUlSbgEMHUyk1j234GTNG9dBJ4XJdOySoj518Xj0oGsNi59cUQF4RRMS9ouBUxDDdyBVfPTypa5g==", + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-wasm-section/-/helper-wasm-section-1.14.1.tgz", + "integrity": "sha512-ds5mXEqTJ6oxRoqjhWDU83OgzAYjwsCV8Lo/N+oRsNDmx/ZDpqalmrtgOMkHwxsG0iI//3BwWAErYRHtgn0dZw==", "dev": true, "dependencies": { - "@webassemblyjs/ast": "1.12.1", - "@webassemblyjs/helper-buffer": "1.12.1", - "@webassemblyjs/helper-wasm-bytecode": "1.11.6", - "@webassemblyjs/wasm-gen": "1.12.1" + "@webassemblyjs/ast": "1.14.1", + "@webassemblyjs/helper-buffer": "1.14.1", + "@webassemblyjs/helper-wasm-bytecode": "1.13.2", + "@webassemblyjs/wasm-gen": "1.14.1" } }, "node_modules/@webassemblyjs/ieee754": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/ieee754/-/ieee754-1.11.6.tgz", - "integrity": "sha512-LM4p2csPNvbij6U1f19v6WR56QZ8JcHg3QIJTlSwzFcmx6WSORicYj6I63f9yU1kEUtrpG+kjkiIAkevHpDXrg==", + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/@webassemblyjs/ieee754/-/ieee754-1.13.2.tgz", + "integrity": "sha512-4LtOzh58S/5lX4ITKxnAK2USuNEvpdVV9AlgGQb8rJDHaLeHciwG4zlGr0j/SNWlr7x3vO1lDEsuePvtcDNCkw==", "dev": true, "dependencies": { "@xtuc/ieee754": "^1.2.0" } }, "node_modules/@webassemblyjs/leb128": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/leb128/-/leb128-1.11.6.tgz", - "integrity": "sha512-m7a0FhE67DQXgouf1tbN5XQcdWoNgaAuoULHIfGFIEVKA6tu/edls6XnIlkmS6FrXAquJRPni3ZZKjw6FSPjPQ==", + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/@webassemblyjs/leb128/-/leb128-1.13.2.tgz", + "integrity": "sha512-Lde1oNoIdzVzdkNEAWZ1dZ5orIbff80YPdHx20mrHwHrVNNTjNr8E3xz9BdpcGqRQbAEa+fkrCb+fRFTl/6sQw==", "dev": true, 
"dependencies": { "@xtuc/long": "4.2.2" } }, "node_modules/@webassemblyjs/utf8": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/utf8/-/utf8-1.11.6.tgz", - "integrity": "sha512-vtXf2wTQ3+up9Zsg8sa2yWiQpzSsMyXj0qViVP6xKGCUT8p8YJ6HqI7l5eCnWx1T/FYdsv07HQs2wTFbbof/RA==", + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/@webassemblyjs/utf8/-/utf8-1.13.2.tgz", + "integrity": "sha512-3NQWGjKTASY1xV5m7Hr0iPeXD9+RDobLll3T9d2AO+g3my8xy5peVyjSag4I50mR1bBSN/Ct12lo+R9tJk0NZQ==", "dev": true }, "node_modules/@webassemblyjs/wasm-edit": { - "version": "1.12.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-edit/-/wasm-edit-1.12.1.tgz", - "integrity": "sha512-1DuwbVvADvS5mGnXbE+c9NfA8QRcZ6iKquqjjmR10k6o+zzsRVesil54DKexiowcFCPdr/Q0qaMgB01+SQ1u6g==", + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-edit/-/wasm-edit-1.14.1.tgz", + "integrity": "sha512-RNJUIQH/J8iA/1NzlE4N7KtyZNHi3w7at7hDjvRNm5rcUXa00z1vRz3glZoULfJ5mpvYhLybmVcwcjGrC1pRrQ==", "dev": true, "dependencies": { - "@webassemblyjs/ast": "1.12.1", - "@webassemblyjs/helper-buffer": "1.12.1", - "@webassemblyjs/helper-wasm-bytecode": "1.11.6", - "@webassemblyjs/helper-wasm-section": "1.12.1", - "@webassemblyjs/wasm-gen": "1.12.1", - "@webassemblyjs/wasm-opt": "1.12.1", - "@webassemblyjs/wasm-parser": "1.12.1", - "@webassemblyjs/wast-printer": "1.12.1" + "@webassemblyjs/ast": "1.14.1", + "@webassemblyjs/helper-buffer": "1.14.1", + "@webassemblyjs/helper-wasm-bytecode": "1.13.2", + "@webassemblyjs/helper-wasm-section": "1.14.1", + "@webassemblyjs/wasm-gen": "1.14.1", + "@webassemblyjs/wasm-opt": "1.14.1", + "@webassemblyjs/wasm-parser": "1.14.1", + "@webassemblyjs/wast-printer": "1.14.1" } }, "node_modules/@webassemblyjs/wasm-gen": { - "version": "1.12.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-gen/-/wasm-gen-1.12.1.tgz", - "integrity": 
"sha512-TDq4Ojh9fcohAw6OIMXqiIcTq5KUXTGRkVxbSo1hQnSy6lAM5GSdfwWeSxpAo0YzgsgF182E/U0mDNhuA0tW7w==", + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-gen/-/wasm-gen-1.14.1.tgz", + "integrity": "sha512-AmomSIjP8ZbfGQhumkNvgC33AY7qtMCXnN6bL2u2Js4gVCg8fp735aEiMSBbDR7UQIj90n4wKAFUSEd0QN2Ukg==", "dev": true, "dependencies": { - "@webassemblyjs/ast": "1.12.1", - "@webassemblyjs/helper-wasm-bytecode": "1.11.6", - "@webassemblyjs/ieee754": "1.11.6", - "@webassemblyjs/leb128": "1.11.6", - "@webassemblyjs/utf8": "1.11.6" + "@webassemblyjs/ast": "1.14.1", + "@webassemblyjs/helper-wasm-bytecode": "1.13.2", + "@webassemblyjs/ieee754": "1.13.2", + "@webassemblyjs/leb128": "1.13.2", + "@webassemblyjs/utf8": "1.13.2" } }, "node_modules/@webassemblyjs/wasm-opt": { - "version": "1.12.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-opt/-/wasm-opt-1.12.1.tgz", - "integrity": "sha512-Jg99j/2gG2iaz3hijw857AVYekZe2SAskcqlWIZXjji5WStnOpVoat3gQfT/Q5tb2djnCjBtMocY/Su1GfxPBg==", + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-opt/-/wasm-opt-1.14.1.tgz", + "integrity": "sha512-PTcKLUNvBqnY2U6E5bdOQcSM+oVP/PmrDY9NzowJjislEjwP/C4an2303MCVS2Mg9d3AJpIGdUFIQQWbPds0Sw==", "dev": true, "dependencies": { - "@webassemblyjs/ast": "1.12.1", - "@webassemblyjs/helper-buffer": "1.12.1", - "@webassemblyjs/wasm-gen": "1.12.1", - "@webassemblyjs/wasm-parser": "1.12.1" + "@webassemblyjs/ast": "1.14.1", + "@webassemblyjs/helper-buffer": "1.14.1", + "@webassemblyjs/wasm-gen": "1.14.1", + "@webassemblyjs/wasm-parser": "1.14.1" } }, "node_modules/@webassemblyjs/wasm-parser": { - "version": "1.12.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-parser/-/wasm-parser-1.12.1.tgz", - "integrity": "sha512-xikIi7c2FHXysxXe3COrVUPSheuBtpcfhbpFj4gmu7KRLYOzANztwUU0IbsqvMqzuNK2+glRGWCEqZo1WCLyAQ==", + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-parser/-/wasm-parser-1.14.1.tgz", + 
"integrity": "sha512-JLBl+KZ0R5qB7mCnud/yyX08jWFw5MsoalJ1pQ4EdFlgj9VdXKGuENGsiCIjegI1W7p91rUlcB/LB5yRJKNTcQ==", "dev": true, "dependencies": { - "@webassemblyjs/ast": "1.12.1", - "@webassemblyjs/helper-api-error": "1.11.6", - "@webassemblyjs/helper-wasm-bytecode": "1.11.6", - "@webassemblyjs/ieee754": "1.11.6", - "@webassemblyjs/leb128": "1.11.6", - "@webassemblyjs/utf8": "1.11.6" + "@webassemblyjs/ast": "1.14.1", + "@webassemblyjs/helper-api-error": "1.13.2", + "@webassemblyjs/helper-wasm-bytecode": "1.13.2", + "@webassemblyjs/ieee754": "1.13.2", + "@webassemblyjs/leb128": "1.13.2", + "@webassemblyjs/utf8": "1.13.2" } }, "node_modules/@webassemblyjs/wast-printer": { - "version": "1.12.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wast-printer/-/wast-printer-1.12.1.tgz", - "integrity": "sha512-+X4WAlOisVWQMikjbcvY2e0rwPsKQ9F688lksZhBcPycBBuii3O7m8FACbDMWDojpAqvjIncrG8J0XHKyQfVeA==", + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wast-printer/-/wast-printer-1.14.1.tgz", + "integrity": "sha512-kPSSXE6De1XOR820C90RIo2ogvZG+c3KiHzqUoO/F34Y2shGzesfqv7o57xrxovZJH/MetF5UjroJ/R/3isoiw==", "dev": true, "dependencies": { - "@webassemblyjs/ast": "1.12.1", + "@webassemblyjs/ast": "1.14.1", "@xtuc/long": "4.2.2" } }, @@ -548,9 +555,9 @@ } }, "node_modules/acorn": { - "version": "8.12.1", - "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.12.1.tgz", - "integrity": "sha512-tcpGyI9zbizT9JbV6oYE477V6mTlXvvi0T0G3SNIYE2apm/G5huBa1+K89VGeovbg+jycCrfhl3ADxErOuO6Jg==", + "version": "8.15.0", + "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz", + "integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==", "dev": true, "bin": { "acorn": "bin/acorn" @@ -559,25 +566,28 @@ "node": ">=0.4.0" } }, - "node_modules/acorn-import-attributes": { - "version": "1.9.5", - "resolved": "https://registry.npmjs.org/acorn-import-attributes/-/acorn-import-attributes-1.9.5.tgz", - 
"integrity": "sha512-n02Vykv5uA3eHGM/Z2dQrcD56kL8TyDb2p1+0P83PClMnC/nc+anbQRhIOWnSq4Ke/KvDPrY3C9hDtC/A3eHnQ==", + "node_modules/acorn-import-phases": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/acorn-import-phases/-/acorn-import-phases-1.0.4.tgz", + "integrity": "sha512-wKmbr/DDiIXzEOiWrTTUcDm24kQ2vGfZQvM2fwg2vXqR5uW6aapr7ObPtj1th32b9u90/Pf4AItvdTh42fBmVQ==", "dev": true, + "engines": { + "node": ">=10.13.0" + }, "peerDependencies": { - "acorn": "^8" + "acorn": "^8.14.0" } }, "node_modules/ajv": { - "version": "6.12.6", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", - "integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==", + "version": "8.17.1", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz", + "integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==", "dev": true, "dependencies": { - "fast-deep-equal": "^3.1.1", - "fast-json-stable-stringify": "^2.0.0", - "json-schema-traverse": "^0.4.1", - "uri-js": "^4.2.2" + "fast-deep-equal": "^3.1.3", + "fast-uri": "^3.0.1", + "json-schema-traverse": "^1.0.0", + "require-from-string": "^2.0.2" }, "funding": { "type": "github", @@ -601,35 +611,16 @@ } } }, - "node_modules/ajv-formats/node_modules/ajv": { - "version": "8.12.0", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.12.0.tgz", - "integrity": "sha512-sRu1kpcO9yLtYxBKvqfTeh9KzZEwO3STyX1HT+4CaDzC6HpTGYhIhPIzj9XuKU7KYDwnaeh5hcOwjy1QuJzBPA==", + "node_modules/ajv-keywords": { + "version": "5.1.0", + "resolved": "https://registry.npmjs.org/ajv-keywords/-/ajv-keywords-5.1.0.tgz", + "integrity": "sha512-YCS/JNFAUyr5vAuhk1DWm1CBxRHW9LbJ2ozWeemrIqpbsqKjHVxYPyi5GC0rjZIT5JxJ3virVTS8wk4i/Z+krw==", "dev": true, "dependencies": { - "fast-deep-equal": "^3.1.1", - "json-schema-traverse": "^1.0.0", - "require-from-string": "^2.0.2", - "uri-js": "^4.2.2" + "fast-deep-equal": "^3.1.3" }, - "funding": { - "type": 
"github", - "url": "https://github.com/sponsors/epoberezkin" - } - }, - "node_modules/ajv-formats/node_modules/json-schema-traverse": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", - "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", - "dev": true - }, - "node_modules/ajv-keywords": { - "version": "3.5.2", - "resolved": "https://registry.npmjs.org/ajv-keywords/-/ajv-keywords-3.5.2.tgz", - "integrity": "sha512-5p6WTN0DdTGVQk6VjcEju19IgaHudalcfabD7yhDGeA6bcQnmL+CpveLJq/3hvfwd1aof6L386Ougkx6RfyMIQ==", - "dev": true, "peerDependencies": { - "ajv": "^6.9.1" + "ajv": "^8.8.2" } }, "node_modules/ansi-html-community": { @@ -665,6 +656,15 @@ "dev": true, "license": "MIT" }, + "node_modules/baseline-browser-mapping": { + "version": "2.9.19", + "resolved": "https://registry.npmjs.org/baseline-browser-mapping/-/baseline-browser-mapping-2.9.19.tgz", + "integrity": "sha512-ipDqC8FrAl/76p2SSWKSI+H9tFwm7vYqXQrItCuiVPt26Km0jS+NzSsBWAaBusvSbQcfJG+JitdMm+wZAgTYqg==", + "dev": true, + "bin": { + "baseline-browser-mapping": "dist/cli.js" + } + }, "node_modules/batch": { "version": "0.6.1", "resolved": "https://registry.npmjs.org/batch/-/batch-0.6.1.tgz", @@ -753,9 +753,9 @@ } }, "node_modules/browserslist": { - "version": "4.21.11", - "resolved": "https://registry.npmjs.org/browserslist/-/browserslist-4.21.11.tgz", - "integrity": "sha512-xn1UXOKUz7DjdGlg9RrUr0GGiWzI97UQJnugHtH0OLDfJB7jMgoIkYvRIEO1l9EeEERVqeqLYOcFBW9ldjypbQ==", + "version": "4.28.1", + "resolved": "https://registry.npmjs.org/browserslist/-/browserslist-4.28.1.tgz", + "integrity": "sha512-ZC5Bd0LgJXgwGqUknZY/vkUQ04r8NXnJZ3yYi4vDmSiZmC/pdSN0NbNRPxZpbtO4uAfDUAFffO8IZoM3Gj8IkA==", "dev": true, "funding": [ { @@ -772,10 +772,11 @@ } ], "dependencies": { - "caniuse-lite": "^1.0.30001538", - "electron-to-chromium": "^1.4.526", - "node-releases": "^2.0.13", - "update-browserslist-db": "^1.0.13" 
+ "baseline-browser-mapping": "^2.9.0", + "caniuse-lite": "^1.0.30001759", + "electron-to-chromium": "^1.5.263", + "node-releases": "^2.0.27", + "update-browserslist-db": "^1.2.0" }, "bin": { "browserslist": "cli.js" @@ -847,9 +848,9 @@ } }, "node_modules/caniuse-lite": { - "version": "1.0.30001538", - "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001538.tgz", - "integrity": "sha512-HWJnhnID+0YMtGlzcp3T9drmBJUVDchPJ08tpUGFLs9CYlwWPH2uLgpHn8fND5pCgXVtnGS3H4QR9XLMHVNkHw==", + "version": "1.0.30001768", + "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001768.tgz", + "integrity": "sha512-qY3aDRZC5nWPgHUgIB84WL+nySuo19wk0VJpp/XI9T34lrvkyhRvNVOFJOp2kxClQhiFBu+TaUSudf6oa3vkSA==", "dev": true, "funding": [ { @@ -1092,36 +1093,6 @@ "webpack": "^5.1.0" } }, - "node_modules/copy-webpack-plugin/node_modules/ajv": { - "version": "8.17.1", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz", - "integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==", - "dev": true, - "license": "MIT", - "dependencies": { - "fast-deep-equal": "^3.1.3", - "fast-uri": "^3.0.1", - "json-schema-traverse": "^1.0.0", - "require-from-string": "^2.0.2" - }, - "funding": { - "type": "github", - "url": "https://github.com/sponsors/epoberezkin" - } - }, - "node_modules/copy-webpack-plugin/node_modules/ajv-keywords": { - "version": "5.1.0", - "resolved": "https://registry.npmjs.org/ajv-keywords/-/ajv-keywords-5.1.0.tgz", - "integrity": "sha512-YCS/JNFAUyr5vAuhk1DWm1CBxRHW9LbJ2ozWeemrIqpbsqKjHVxYPyi5GC0rjZIT5JxJ3virVTS8wk4i/Z+krw==", - "dev": true, - "license": "MIT", - "dependencies": { - "fast-deep-equal": "^3.1.3" - }, - "peerDependencies": { - "ajv": "^8.8.2" - } - }, "node_modules/copy-webpack-plugin/node_modules/glob-parent": { "version": "6.0.2", "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz", @@ -1135,33 +1106,6 @@ "node": ">=10.13.0" } }, - 
"node_modules/copy-webpack-plugin/node_modules/json-schema-traverse": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", - "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", - "dev": true, - "license": "MIT" - }, - "node_modules/copy-webpack-plugin/node_modules/schema-utils": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.3.0.tgz", - "integrity": "sha512-Gf9qqc58SpCA/xdziiHz35F4GNIWYWZrEshUc/G/r5BnLph6xpKuLeoJoQuj5WfBIx/eQLf+hmVPYHaxJu7V2g==", - "dev": true, - "license": "MIT", - "dependencies": { - "@types/json-schema": "^7.0.9", - "ajv": "^8.9.0", - "ajv-formats": "^2.1.1", - "ajv-keywords": "^5.1.0" - }, - "engines": { - "node": ">= 10.13.0" - }, - "funding": { - "type": "opencollective", - "url": "https://opencollective.com/webpack" - } - }, "node_modules/core-util-is": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/core-util-is/-/core-util-is-1.0.2.tgz", @@ -1307,9 +1251,9 @@ "license": "MIT" }, "node_modules/electron-to-chromium": { - "version": "1.4.528", - "resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.4.528.tgz", - "integrity": "sha512-UdREXMXzLkREF4jA8t89FQjA8WHI6ssP38PMY4/4KhXFQbtImnghh4GkCgrtiZwLKUKVD2iTVXvDVQjfomEQuA==", + "version": "1.5.286", + "resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.5.286.tgz", + "integrity": "sha512-9tfDXhJ4RKFNerfjdCcZfufu49vg620741MNs26a9+bhLThdB+plgMeou98CAaHu/WATj2iHOOHTp1hWtABj2A==", "dev": true }, "node_modules/encodeurl": { @@ -1323,13 +1267,13 @@ } }, "node_modules/enhanced-resolve": { - "version": "5.17.1", - "resolved": "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-5.17.1.tgz", - "integrity": "sha512-LMHl3dXhTcfv8gM4kEzIUeTQ+7fpdA0l2tUf34BddXPkz2A5xJ5L/Pchd5BL6rdccM9QGvu0sWZzK1Z1t4wwyg==", + "version": "5.19.0", + "resolved": 
"https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-5.19.0.tgz", + "integrity": "sha512-phv3E1Xl4tQOShqSte26C7Fl84EwUdZsyOuSSk9qtAGyyQs2s3jJzComh+Abf4g187lUUAvH+H26omrqia2aGg==", "dev": true, "dependencies": { "graceful-fs": "^4.2.4", - "tapable": "^2.2.0" + "tapable": "^2.3.0" }, "engines": { "node": ">=10.13.0" @@ -1368,9 +1312,9 @@ } }, "node_modules/es-module-lexer": { - "version": "1.3.1", - "resolved": "https://registry.npmjs.org/es-module-lexer/-/es-module-lexer-1.3.1.tgz", - "integrity": "sha512-JUFAyicQV9mXc3YRxPnDlrfBKpqt6hUYzz9/boprUJHs4e4KVr3XwOF70doO6gwXUor6EWZJAyWAfKki84t20Q==", + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/es-module-lexer/-/es-module-lexer-2.0.0.tgz", + "integrity": "sha512-5POEcUuZybH7IdmGsD8wlf0AI55wMecM9rVBTI/qEAy2c1kTOm3DjFYjrBdI2K3BaJjJYfYFeRtM0t9ssnRuxw==", "dev": true }, "node_modules/es-object-atoms": { @@ -1387,9 +1331,9 @@ } }, "node_modules/escalade": { - "version": "3.1.1", - "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.1.1.tgz", - "integrity": "sha512-k0er2gUkLf8O0zKJiAhmkTnJlTvINGv7ygDNPbeIsX/TJjGJZHuh9B2UxbsaEkmlEo9MfhrSzmhIlhRlI2GXnw==", + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.2.0.tgz", + "integrity": "sha512-WUj2qlxaQtO4g6Pq5c29GTcWGDyd8itL8zTlipgECz3JesAiiOKotd8JU6otB3PACgG6xkJUyVhboMS+bje/jA==", "dev": true, "engines": { "node": ">=6" @@ -1604,16 +1548,10 @@ "node": ">=8.6.0" } }, - "node_modules/fast-json-stable-stringify": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/fast-json-stable-stringify/-/fast-json-stable-stringify-2.1.0.tgz", - "integrity": "sha512-lhd/wF+Lk98HZoTCtlVraHtfh5XYijIjalXck7saUtuanSDyLMxnHhSXEDJqHxD7msR8D0uCmqlkwjCV8xvwHw==", - "dev": true - }, "node_modules/fast-uri": { - "version": "3.0.6", - "resolved": "https://registry.npmjs.org/fast-uri/-/fast-uri-3.0.6.tgz", - "integrity": "sha512-Atfo14OibSv5wAp4VWNsFYE1AchQRTv9cBGWET4pZWHzYshFSS9NQI6I57rdKn9croWVMbYFbLhJ+yJvmZIIHw==", + 
"version": "3.1.0", + "resolved": "https://registry.npmjs.org/fast-uri/-/fast-uri-3.1.0.tgz", + "integrity": "sha512-iPeeDKJSWf4IEOasVVrknXpaBV0IApz/gp7S2bb7Z4Lljbl2MGJRqInZiUrQwV16cpzw/D3S5j5Julj/gT52AA==", "dev": true, "funding": [ { @@ -1624,8 +1562,7 @@ "type": "opencollective", "url": "https://opencollective.com/fastify" } - ], - "license": "BSD-3-Clause" + ] }, "node_modules/fastest-levenshtein": { "version": "1.0.16", @@ -2304,9 +2241,9 @@ "dev": true }, "node_modules/json-schema-traverse": { - "version": "0.4.1", - "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-0.4.1.tgz", - "integrity": "sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg==", + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", + "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", "dev": true }, "node_modules/kind-of": { @@ -2330,12 +2267,16 @@ } }, "node_modules/loader-runner": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/loader-runner/-/loader-runner-4.3.0.tgz", - "integrity": "sha512-3R/1M+yS3j5ou80Me59j7F9IMs4PXs3VqRrm0TU3AbKPxlmpoY1TNscJV/oGJXo8qCatFGTfDbY6W6ipGOYXfg==", + "version": "4.3.1", + "resolved": "https://registry.npmjs.org/loader-runner/-/loader-runner-4.3.1.tgz", + "integrity": "sha512-IWqP2SCPhyVFTBtRcgMHdzlf9ul25NwaFx4wCEH/KjAXuuHY4yNjvPXsBokp8jCB936PyWRaPKUNh8NvylLp2Q==", "dev": true, "engines": { "node": ">=6.11.5" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/webpack" } }, "node_modules/locate-path": { @@ -2619,9 +2560,9 @@ } }, "node_modules/node-releases": { - "version": "2.0.13", - "resolved": "https://registry.npmjs.org/node-releases/-/node-releases-2.0.13.tgz", - "integrity": "sha512-uYr7J37ae/ORWdZeQ1xxMJe3NtdmqMC/JZK+geofDrkLUApKRHPd18/TxtBOJ4A0/+uUIliorNrfYV6s1b02eQ==", + "version": "2.0.27", + 
"resolved": "https://registry.npmjs.org/node-releases/-/node-releases-2.0.27.tgz", + "integrity": "sha512-nmh3lCkYZ3grZvqcCH+fjmQ7X+H0OeZgP40OierEaAptX4XofMh5kwNbWh7lBduUzCcV/8kZ+NDLCwm2iorIlA==", "dev": true }, "node_modules/normalize-path": { @@ -2801,9 +2742,9 @@ } }, "node_modules/picocolors": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.0.0.tgz", - "integrity": "sha512-1fygroTLlHu66zi26VoTDv8yRgm0Fccecssto+MhsZ0D/DGW2sm8E8AjW7NU5VVTRt5GxbeZ5qBuJr+HyLYkjQ==", + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.1.tgz", + "integrity": "sha512-xceH2snhtb5M9liqDsmEw56le376mTZkEX/jEb/RxNFyegNul7eNslCXP9FDj/Lcu0X8KEyMceP2ntpaHrDEVA==", "dev": true }, "node_modules/picomatch": { @@ -2860,15 +2801,6 @@ "node": ">= 0.10" } }, - "node_modules/punycode": { - "version": "2.3.0", - "resolved": "https://registry.npmjs.org/punycode/-/punycode-2.3.0.tgz", - "integrity": "sha512-rRV+zQD8tVFys26lAGR9WUuS4iUAngJScM+ZRSKtvl5tKeZ2t5bvdNFdNHBW9FWR4guGHlgmsZ1G7BSm2wTbuA==", - "dev": true, - "engines": { - "node": ">=6" - } - }, "node_modules/qs": { "version": "6.13.0", "resolved": "https://registry.npmjs.org/qs/-/qs-6.13.0.tgz", @@ -3106,14 +3038,15 @@ "license": "MIT" }, "node_modules/schema-utils": { - "version": "3.3.0", - "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-3.3.0.tgz", - "integrity": "sha512-pN/yOAvcC+5rQ5nERGuwrjLlYvLTbCibnZ1I7B1LaiAz9BRBlE9GMgE/eqV30P7aJQUf7Ddimy/RsbYO/GrVGg==", + "version": "4.3.3", + "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.3.3.tgz", + "integrity": "sha512-eflK8wEtyOE6+hsaRVPxvUKYCpRgzLqDTb8krvAsRIwOGlHoSgYLgBXoubGgLd2fT41/OUYdb48v4k4WWHQurA==", "dev": true, "dependencies": { - "@types/json-schema": "^7.0.8", - "ajv": "^6.12.5", - "ajv-keywords": "^3.5.2" + "@types/json-schema": "^7.0.9", + "ajv": "^8.9.0", + "ajv-formats": "^2.1.1", + "ajv-keywords": "^5.1.0" }, "engines": { "node": ">= 10.13.0" @@ -3558,22 
+3491,26 @@ } }, "node_modules/tapable": { - "version": "2.2.1", - "resolved": "https://registry.npmjs.org/tapable/-/tapable-2.2.1.tgz", - "integrity": "sha512-GNzQvQTOIP6RyTfE2Qxb8ZVlNmw0n88vp1szwWRimP02mnTsx3Wtn5qRdqY9w2XduFNUgvOwhNnQsjwCp+kqaQ==", + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/tapable/-/tapable-2.3.0.tgz", + "integrity": "sha512-g9ljZiwki/LfxmQADO3dEY1CbpmXT5Hm2fJ+QaGKwSXUylMybePR7/67YW7jOrrvjEgL1Fmz5kzyAjWVWLlucg==", "dev": true, "engines": { "node": ">=6" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/webpack" } }, "node_modules/terser": { - "version": "5.31.6", - "resolved": "https://registry.npmjs.org/terser/-/terser-5.31.6.tgz", - "integrity": "sha512-PQ4DAriWzKj+qgehQ7LK5bQqCFNMmlhjR2PFFLuqGCpuCAauxemVBWwWOxo3UIwWQx8+Pr61Df++r76wDmkQBg==", + "version": "5.46.0", + "resolved": "https://registry.npmjs.org/terser/-/terser-5.46.0.tgz", + "integrity": "sha512-jTwoImyr/QbOWFFso3YoU3ik0jBBDJ6JTOQiy/J2YxVJdZCc+5u7skhNwiOR3FQIygFqVUPHl7qbbxtjW2K3Qg==", "dev": true, "dependencies": { "@jridgewell/source-map": "^0.3.3", - "acorn": "^8.8.2", + "acorn": "^8.15.0", "commander": "^2.20.0", "source-map-support": "~0.5.20" }, @@ -3585,16 +3522,16 @@ } }, "node_modules/terser-webpack-plugin": { - "version": "5.3.10", - "resolved": "https://registry.npmjs.org/terser-webpack-plugin/-/terser-webpack-plugin-5.3.10.tgz", - "integrity": "sha512-BKFPWlPDndPs+NGGCr1U59t0XScL5317Y0UReNrHaw9/FwhPENlq6bfgs+4yPfyP51vqC1bQ4rp1EfXW5ZSH9w==", + "version": "5.3.16", + "resolved": "https://registry.npmjs.org/terser-webpack-plugin/-/terser-webpack-plugin-5.3.16.tgz", + "integrity": "sha512-h9oBFCWrq78NyWWVcSwZarJkZ01c2AyGrzs1crmHZO3QUg9D61Wu4NPjBy69n7JqylFF5y+CsUZYmYEIZ3mR+Q==", "dev": true, "dependencies": { - "@jridgewell/trace-mapping": "^0.3.20", + "@jridgewell/trace-mapping": "^0.3.25", "jest-worker": "^27.4.5", - "schema-utils": "^3.1.1", - "serialize-javascript": "^6.0.1", - "terser": "^5.26.0" + "schema-utils": 
"^4.3.0", + "serialize-javascript": "^6.0.2", + "terser": "^5.31.1" }, "engines": { "node": ">= 10.13.0" @@ -3691,9 +3628,9 @@ } }, "node_modules/update-browserslist-db": { - "version": "1.0.13", - "resolved": "https://registry.npmjs.org/update-browserslist-db/-/update-browserslist-db-1.0.13.tgz", - "integrity": "sha512-xebP81SNcPuNpPP3uzeW1NYXxI3rxyJzF3pD6sH4jE7o/IX+WtSpwnVU+qIsDPyk0d3hmFQ7mjqc6AtV604hbg==", + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/update-browserslist-db/-/update-browserslist-db-1.2.3.tgz", + "integrity": "sha512-Js0m9cx+qOgDxo0eMiFGEueWztz+d4+M3rGlmKPT+T4IS/jP4ylw3Nwpu6cpTTP8R1MAC1kF4VbdLt3ARf209w==", "dev": true, "funding": [ { @@ -3710,8 +3647,8 @@ } ], "dependencies": { - "escalade": "^3.1.1", - "picocolors": "^1.0.0" + "escalade": "^3.2.0", + "picocolors": "^1.1.1" }, "bin": { "update-browserslist-db": "cli.js" @@ -3720,15 +3657,6 @@ "browserslist": ">= 4.21.0" } }, - "node_modules/uri-js": { - "version": "4.4.1", - "resolved": "https://registry.npmjs.org/uri-js/-/uri-js-4.4.1.tgz", - "integrity": "sha512-7rKUyy33Q1yc98pQ1DAmLtwX109F7TIfWlW1Ydo8Wl1ii1SeHieeh0HHfPeL2fMXK6z0s8ecKs9frCuLJvndBg==", - "dev": true, - "dependencies": { - "punycode": "^2.1.0" - } - }, "node_modules/util-deprecate": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz", @@ -3764,9 +3692,9 @@ } }, "node_modules/watchpack": { - "version": "2.4.2", - "resolved": "https://registry.npmjs.org/watchpack/-/watchpack-2.4.2.tgz", - "integrity": "sha512-TnbFSbcOCcDgjZ4piURLCbJ3nJhznVh9kw6F6iokjiFPl8ONxe9A6nMDVXDiNbrSfLILs6vB07F7wLBrwPYzJw==", + "version": "2.5.1", + "resolved": "https://registry.npmjs.org/watchpack/-/watchpack-2.5.1.tgz", + "integrity": "sha512-Zn5uXdcFNIA1+1Ei5McRd+iRzfhENPCe7LeABkJtNulSxjma+l7ltNx55BWZkRlwRnpOgHqxnjyaDgJnNXnqzg==", "dev": true, "dependencies": { "glob-to-regexp": "^0.4.1", @@ -3786,34 +3714,36 @@ } }, "node_modules/webpack": { - "version": "5.94.0", - "resolved": 
"https://registry.npmjs.org/webpack/-/webpack-5.94.0.tgz", - "integrity": "sha512-KcsGn50VT+06JH/iunZJedYGUJS5FGjow8wb9c0v5n1Om8O1g4L6LjtfxwlXIATopoQu+vOXXa7gYisWxCoPyg==", - "dev": true, - "dependencies": { - "@types/estree": "^1.0.5", - "@webassemblyjs/ast": "^1.12.1", - "@webassemblyjs/wasm-edit": "^1.12.1", - "@webassemblyjs/wasm-parser": "^1.12.1", - "acorn": "^8.7.1", - "acorn-import-attributes": "^1.9.5", - "browserslist": "^4.21.10", + "version": "5.105.0", + "resolved": "https://registry.npmjs.org/webpack/-/webpack-5.105.0.tgz", + "integrity": "sha512-gX/dMkRQc7QOMzgTe6KsYFM7DxeIONQSui1s0n/0xht36HvrgbxtM1xBlgx596NbpHuQU8P7QpKwrZYwUX48nw==", + "dev": true, + "dependencies": { + "@types/eslint-scope": "^3.7.7", + "@types/estree": "^1.0.8", + "@types/json-schema": "^7.0.15", + "@webassemblyjs/ast": "^1.14.1", + "@webassemblyjs/wasm-edit": "^1.14.1", + "@webassemblyjs/wasm-parser": "^1.14.1", + "acorn": "^8.15.0", + "acorn-import-phases": "^1.0.3", + "browserslist": "^4.28.1", "chrome-trace-event": "^1.0.2", - "enhanced-resolve": "^5.17.1", - "es-module-lexer": "^1.2.1", + "enhanced-resolve": "^5.19.0", + "es-module-lexer": "^2.0.0", "eslint-scope": "5.1.1", "events": "^3.2.0", "glob-to-regexp": "^0.4.1", "graceful-fs": "^4.2.11", "json-parse-even-better-errors": "^2.3.1", - "loader-runner": "^4.2.0", + "loader-runner": "^4.3.1", "mime-types": "^2.1.27", "neo-async": "^2.6.2", - "schema-utils": "^3.2.0", - "tapable": "^2.1.1", - "terser-webpack-plugin": "^5.3.10", - "watchpack": "^2.4.1", - "webpack-sources": "^3.2.3" + "schema-utils": "^4.3.3", + "tapable": "^2.3.0", + "terser-webpack-plugin": "^5.3.16", + "watchpack": "^2.5.1", + "webpack-sources": "^3.3.3" }, "bin": { "webpack": "bin/webpack.js" @@ -3915,63 +3845,6 @@ } } }, - "node_modules/webpack-dev-middleware/node_modules/ajv": { - "version": "8.17.1", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz", - "integrity": 
"sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==", - "dev": true, - "license": "MIT", - "dependencies": { - "fast-deep-equal": "^3.1.3", - "fast-uri": "^3.0.1", - "json-schema-traverse": "^1.0.0", - "require-from-string": "^2.0.2" - }, - "funding": { - "type": "github", - "url": "https://github.com/sponsors/epoberezkin" - } - }, - "node_modules/webpack-dev-middleware/node_modules/ajv-keywords": { - "version": "5.1.0", - "resolved": "https://registry.npmjs.org/ajv-keywords/-/ajv-keywords-5.1.0.tgz", - "integrity": "sha512-YCS/JNFAUyr5vAuhk1DWm1CBxRHW9LbJ2ozWeemrIqpbsqKjHVxYPyi5GC0rjZIT5JxJ3virVTS8wk4i/Z+krw==", - "dev": true, - "license": "MIT", - "dependencies": { - "fast-deep-equal": "^3.1.3" - }, - "peerDependencies": { - "ajv": "^8.8.2" - } - }, - "node_modules/webpack-dev-middleware/node_modules/json-schema-traverse": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", - "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", - "dev": true, - "license": "MIT" - }, - "node_modules/webpack-dev-middleware/node_modules/schema-utils": { - "version": "4.3.2", - "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.3.2.tgz", - "integrity": "sha512-Gn/JaSk/Mt9gYubxTtSn/QCV4em9mpAPiR1rqy/Ocu19u/G9J5WWdNoUT4SiV6mFC3y6cxyFcFwdzPM3FgxGAQ==", - "dev": true, - "license": "MIT", - "dependencies": { - "@types/json-schema": "^7.0.9", - "ajv": "^8.9.0", - "ajv-formats": "^2.1.1", - "ajv-keywords": "^5.1.0" - }, - "engines": { - "node": ">= 10.13.0" - }, - "funding": { - "type": "opencollective", - "url": "https://opencollective.com/webpack" - } - }, "node_modules/webpack-dev-server": { "version": "5.2.1", "resolved": "https://registry.npmjs.org/webpack-dev-server/-/webpack-dev-server-5.2.1.tgz", @@ -4030,59 +3903,6 @@ } } }, - "node_modules/webpack-dev-server/node_modules/ajv": { - "version": 
"8.12.0", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.12.0.tgz", - "integrity": "sha512-sRu1kpcO9yLtYxBKvqfTeh9KzZEwO3STyX1HT+4CaDzC6HpTGYhIhPIzj9XuKU7KYDwnaeh5hcOwjy1QuJzBPA==", - "dev": true, - "dependencies": { - "fast-deep-equal": "^3.1.1", - "json-schema-traverse": "^1.0.0", - "require-from-string": "^2.0.2", - "uri-js": "^4.2.2" - }, - "funding": { - "type": "github", - "url": "https://github.com/sponsors/epoberezkin" - } - }, - "node_modules/webpack-dev-server/node_modules/ajv-keywords": { - "version": "5.1.0", - "resolved": "https://registry.npmjs.org/ajv-keywords/-/ajv-keywords-5.1.0.tgz", - "integrity": "sha512-YCS/JNFAUyr5vAuhk1DWm1CBxRHW9LbJ2ozWeemrIqpbsqKjHVxYPyi5GC0rjZIT5JxJ3virVTS8wk4i/Z+krw==", - "dev": true, - "dependencies": { - "fast-deep-equal": "^3.1.3" - }, - "peerDependencies": { - "ajv": "^8.8.2" - } - }, - "node_modules/webpack-dev-server/node_modules/json-schema-traverse": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", - "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", - "dev": true - }, - "node_modules/webpack-dev-server/node_modules/schema-utils": { - "version": "4.2.0", - "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.2.0.tgz", - "integrity": "sha512-L0jRsrPpjdckP3oPug3/VxNKt2trR8TcabrM6FOAAlvC/9Phcmm+cuAgTlxBqdBR1WJx7Naj9WHw+aOmheSVbw==", - "dev": true, - "dependencies": { - "@types/json-schema": "^7.0.9", - "ajv": "^8.9.0", - "ajv-formats": "^2.1.1", - "ajv-keywords": "^5.1.0" - }, - "engines": { - "node": ">= 12.13.0" - }, - "funding": { - "type": "opencollective", - "url": "https://opencollective.com/webpack" - } - }, "node_modules/webpack-merge": { "version": "5.9.0", "resolved": "https://registry.npmjs.org/webpack-merge/-/webpack-merge-5.9.0.tgz", @@ -4096,10 +3916,10 @@ "node": ">=10.0.0" } }, - "node_modules/webpack/node_modules/webpack-sources": { - "version": 
"3.2.3", - "resolved": "https://registry.npmjs.org/webpack-sources/-/webpack-sources-3.2.3.tgz", - "integrity": "sha512-/DyMEOrDgLKKIG0fmvtz+4dUX/3Ghozwgm6iPp8KRhvn+eQf9+Q7GWxVNMk3+uCPWfdXYC4ExGBckIXdFEfH1w==", + "node_modules/webpack-sources": { + "version": "3.3.3", + "resolved": "https://registry.npmjs.org/webpack-sources/-/webpack-sources-3.3.3.tgz", + "integrity": "sha512-yd1RBzSGanHkitROoPFd6qsrxt+oFhg/129YzheDGqeustzX0vTZJZsSsQjVQC4yzBQ56K55XU8gaNCtIzOnTg==", "dev": true, "engines": { "node": ">=10.13.0" @@ -4180,13 +4000,12 @@ "dev": true }, "@jridgewell/gen-mapping": { - "version": "0.3.5", - "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.5.tgz", - "integrity": "sha512-IzL8ZoEDIBRWEzlCcRhOaCupYyN5gdIK+Q6fbFdPDg6HqX6jpkItn7DFIpW9LQzXG6Df9sA7+OKnq0qlz/GaQg==", + "version": "0.3.13", + "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.13.tgz", + "integrity": "sha512-2kkt/7niJ6MgEPxF0bYdQ6etZaA+fQvDcLKckhy1yIQOzaoKjBBjSj63/aLVjYE3qhRt5dvM+uUyfCg6UKCBbA==", "dev": true, "requires": { - "@jridgewell/set-array": "^1.2.1", - "@jridgewell/sourcemap-codec": "^1.4.10", + "@jridgewell/sourcemap-codec": "^1.5.0", "@jridgewell/trace-mapping": "^0.3.24" } }, @@ -4196,16 +4015,10 @@ "integrity": "sha512-bRISgCIjP20/tbWSPWMEi54QVPRZExkuD9lJL+UIxUKtwVJA8wW1Trb1jMs1RFXo1CBTNZ/5hpC9QvmKWdopKw==", "dev": true }, - "@jridgewell/set-array": { - "version": "1.2.1", - "resolved": "https://registry.npmjs.org/@jridgewell/set-array/-/set-array-1.2.1.tgz", - "integrity": "sha512-R8gLRTZeyp03ymzP/6Lil/28tGeGEzhx1q2k703KGWRAI1VdvPIXdG70VJc2pAMw3NA6JKL5hhFu1sJX0Mnn/A==", - "dev": true - }, "@jridgewell/source-map": { - "version": "0.3.6", - "resolved": "https://registry.npmjs.org/@jridgewell/source-map/-/source-map-0.3.6.tgz", - "integrity": "sha512-1ZJTZebgqllO79ue2bm3rIGud/bOe0pP5BjSRCRxxYkEZS8STV7zN84UBbiYu7jy+eCKSnVIUgoWWE/tt+shMQ==", + "version": "0.3.11", + "resolved": 
"https://registry.npmjs.org/@jridgewell/source-map/-/source-map-0.3.11.tgz", + "integrity": "sha512-ZMp1V8ZFcPG5dIWnQLr3NSI1MiCU7UETdS/A0G8V/XWHvJv3ZsFqutJn1Y5RPmAPX6F3BiE397OqveU/9NCuIA==", "dev": true, "requires": { "@jridgewell/gen-mapping": "^0.3.5", @@ -4213,15 +4026,15 @@ } }, "@jridgewell/sourcemap-codec": { - "version": "1.5.0", - "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.5.0.tgz", - "integrity": "sha512-gv3ZRaISU3fjPAgNsriBRqGWQL6quFx04YMPW/zD8XMLsU32mhCCbfbO6KZFLjvYpCZ8zyDEgqsgf+PwPaM7GQ==", + "version": "1.5.5", + "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.5.5.tgz", + "integrity": "sha512-cYQ9310grqxueWbl+WuIUIaiUaDcj7WOq5fVhEljNVgRfOUhY9fy2zTvfoqWsnebh8Sl70VScFbICvJnLKB0Og==", "dev": true }, "@jridgewell/trace-mapping": { - "version": "0.3.25", - "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.25.tgz", - "integrity": "sha512-vNk6aEwybGtawWmy/PzwnGDOjCkLWSD2wqvjGGAgOAwCGWySYXfYoxt00IJkTF+8Lb57DwOb3Aa0o9CApepiYQ==", + "version": "0.3.31", + "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.31.tgz", + "integrity": "sha512-zzNR+SdQSDJzc8joaeP8QQoCQr8NuYx2dIIytl1QeBEZHJ9uW6hebsrYgbz8hJwUQao3TWCMtmfV8Nu1twOLAw==", "dev": true, "requires": { "@jridgewell/resolve-uri": "^3.1.0", @@ -4304,10 +4117,30 @@ "@types/node": "*" } }, + "@types/eslint": { + "version": "9.6.1", + "resolved": "https://registry.npmjs.org/@types/eslint/-/eslint-9.6.1.tgz", + "integrity": "sha512-FXx2pKgId/WyYo2jXw63kk7/+TY7u7AziEJxJAnSFzHlqTAS3Ync6SvgYAN/k4/PQpnnVuzoMuVnByKK2qp0ag==", + "dev": true, + "requires": { + "@types/estree": "*", + "@types/json-schema": "*" + } + }, + "@types/eslint-scope": { + "version": "3.7.7", + "resolved": "https://registry.npmjs.org/@types/eslint-scope/-/eslint-scope-3.7.7.tgz", + "integrity": "sha512-MzMFlSLBqNF2gcHWO0G1vP/YQyfvrxZ0bF+u7mzUdZ1/xK4A4sru+nraZz5i3iEIk1l1uyicaDVTB4QbbEkAYg==", + 
"dev": true, + "requires": { + "@types/eslint": "*", + "@types/estree": "*" + } + }, "@types/estree": { - "version": "1.0.5", - "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.5.tgz", - "integrity": "sha512-/kYRxGDLWzHOB7q+wtSUQlFrtcdUccpfy+X+9iMBpHK8QLLhx2wIPYuS5DYtR9Wa/YlZAbIovy7qVdB1Aq6Lyw==", + "version": "1.0.8", + "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.8.tgz", + "integrity": "sha512-dWHzHa2WqEXI/O1E9OjrocMTKJl2mSrEolh1Iomrv6U+JuNwaHXsXx9bLu5gG7BUWFIN0skIQJQ/L1rIex4X6w==", "dev": true }, "@types/express": { @@ -4350,9 +4183,9 @@ } }, "@types/json-schema": { - "version": "7.0.13", - "resolved": "https://registry.npmjs.org/@types/json-schema/-/json-schema-7.0.13.tgz", - "integrity": "sha512-RbSSoHliUbnXj3ny0CNFOoxrIDV6SUGyStHsvDqosw6CkdPV8TtWGlfecuK4ToyMEAql6pzNxgCFKanovUzlgQ==", + "version": "7.0.15", + "resolved": "https://registry.npmjs.org/@types/json-schema/-/json-schema-7.0.15.tgz", + "integrity": "sha512-5+fP8P8MFNC+AyZCDxrB2pkZFPGzqQWUzpSeuuVLvm8VMcorNYavBqoFcxK8bQz4Qsbn4oUEEem4wDLfcysGHA==", "dev": true }, "@types/mime": { @@ -4443,148 +4276,148 @@ } }, "@webassemblyjs/ast": { - "version": "1.12.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/ast/-/ast-1.12.1.tgz", - "integrity": "sha512-EKfMUOPRRUTy5UII4qJDGPpqfwjOmZ5jeGFwid9mnoqIFK+e0vqoi1qH56JpmZSzEL53jKnNzScdmftJyG5xWg==", + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/ast/-/ast-1.14.1.tgz", + "integrity": "sha512-nuBEDgQfm1ccRp/8bCQrx1frohyufl4JlbMMZ4P1wpeOfDhF6FQkxZJ1b/e+PLwr6X1Nhw6OLme5usuBWYBvuQ==", "dev": true, "requires": { - "@webassemblyjs/helper-numbers": "1.11.6", - "@webassemblyjs/helper-wasm-bytecode": "1.11.6" + "@webassemblyjs/helper-numbers": "1.13.2", + "@webassemblyjs/helper-wasm-bytecode": "1.13.2" } }, "@webassemblyjs/floating-point-hex-parser": { - "version": "1.11.6", - "resolved": 
"https://registry.npmjs.org/@webassemblyjs/floating-point-hex-parser/-/floating-point-hex-parser-1.11.6.tgz", - "integrity": "sha512-ejAj9hfRJ2XMsNHk/v6Fu2dGS+i4UaXBXGemOfQ/JfQ6mdQg/WXtwleQRLLS4OvfDhv8rYnVwH27YJLMyYsxhw==", + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/@webassemblyjs/floating-point-hex-parser/-/floating-point-hex-parser-1.13.2.tgz", + "integrity": "sha512-6oXyTOzbKxGH4steLbLNOu71Oj+C8Lg34n6CqRvqfS2O71BxY6ByfMDRhBytzknj9yGUPVJ1qIKhRlAwO1AovA==", "dev": true }, "@webassemblyjs/helper-api-error": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-api-error/-/helper-api-error-1.11.6.tgz", - "integrity": "sha512-o0YkoP4pVu4rN8aTJgAyj9hC2Sv5UlkzCHhxqWj8butaLvnpdc2jOwh4ewE6CX0txSfLn/UYaV/pheS2Txg//Q==", + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-api-error/-/helper-api-error-1.13.2.tgz", + "integrity": "sha512-U56GMYxy4ZQCbDZd6JuvvNV/WFildOjsaWD3Tzzvmw/mas3cXzRJPMjP83JqEsgSbyrmaGjBfDtV7KDXV9UzFQ==", "dev": true }, "@webassemblyjs/helper-buffer": { - "version": "1.12.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-buffer/-/helper-buffer-1.12.1.tgz", - "integrity": "sha512-nzJwQw99DNDKr9BVCOZcLuJJUlqkJh+kVzVl6Fmq/tI5ZtEyWT1KZMyOXltXLZJmDtvLCDgwsyrkohEtopTXCw==", + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-buffer/-/helper-buffer-1.14.1.tgz", + "integrity": "sha512-jyH7wtcHiKssDtFPRB+iQdxlDf96m0E39yb0k5uJVhFGleZFoNw1c4aeIcVUPPbXUVJ94wwnMOAqUHyzoEPVMA==", "dev": true }, "@webassemblyjs/helper-numbers": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-numbers/-/helper-numbers-1.11.6.tgz", - "integrity": "sha512-vUIhZ8LZoIWHBohiEObxVm6hwP034jwmc9kuq5GdHZH0wiLVLIPcMCdpJzG4C11cHoQ25TFIQj9kaVADVX7N3g==", + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-numbers/-/helper-numbers-1.13.2.tgz", + "integrity": 
"sha512-FE8aCmS5Q6eQYcV3gI35O4J789wlQA+7JrqTTpJqn5emA4U2hvwJmvFRC0HODS+3Ye6WioDklgd6scJ3+PLnEA==", "dev": true, "requires": { - "@webassemblyjs/floating-point-hex-parser": "1.11.6", - "@webassemblyjs/helper-api-error": "1.11.6", + "@webassemblyjs/floating-point-hex-parser": "1.13.2", + "@webassemblyjs/helper-api-error": "1.13.2", "@xtuc/long": "4.2.2" } }, "@webassemblyjs/helper-wasm-bytecode": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-wasm-bytecode/-/helper-wasm-bytecode-1.11.6.tgz", - "integrity": "sha512-sFFHKwcmBprO9e7Icf0+gddyWYDViL8bpPjJJl0WHxCdETktXdmtWLGVzoHbqUcY4Be1LkNfwTmXOJUFZYSJdA==", + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-wasm-bytecode/-/helper-wasm-bytecode-1.13.2.tgz", + "integrity": "sha512-3QbLKy93F0EAIXLh0ogEVR6rOubA9AoZ+WRYhNbFyuB70j3dRdwH9g+qXhLAO0kiYGlg3TxDV+I4rQTr/YNXkA==", "dev": true }, "@webassemblyjs/helper-wasm-section": { - "version": "1.12.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-wasm-section/-/helper-wasm-section-1.12.1.tgz", - "integrity": "sha512-Jif4vfB6FJlUlSbgEMHUyk1j234GTNG9dBJ4XJdOySoj518Xj0oGsNi59cUQF4RRMS9ouBUxDDdyBVfPTypa5g==", + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-wasm-section/-/helper-wasm-section-1.14.1.tgz", + "integrity": "sha512-ds5mXEqTJ6oxRoqjhWDU83OgzAYjwsCV8Lo/N+oRsNDmx/ZDpqalmrtgOMkHwxsG0iI//3BwWAErYRHtgn0dZw==", "dev": true, "requires": { - "@webassemblyjs/ast": "1.12.1", - "@webassemblyjs/helper-buffer": "1.12.1", - "@webassemblyjs/helper-wasm-bytecode": "1.11.6", - "@webassemblyjs/wasm-gen": "1.12.1" + "@webassemblyjs/ast": "1.14.1", + "@webassemblyjs/helper-buffer": "1.14.1", + "@webassemblyjs/helper-wasm-bytecode": "1.13.2", + "@webassemblyjs/wasm-gen": "1.14.1" } }, "@webassemblyjs/ieee754": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/ieee754/-/ieee754-1.11.6.tgz", - "integrity": 
"sha512-LM4p2csPNvbij6U1f19v6WR56QZ8JcHg3QIJTlSwzFcmx6WSORicYj6I63f9yU1kEUtrpG+kjkiIAkevHpDXrg==", + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/@webassemblyjs/ieee754/-/ieee754-1.13.2.tgz", + "integrity": "sha512-4LtOzh58S/5lX4ITKxnAK2USuNEvpdVV9AlgGQb8rJDHaLeHciwG4zlGr0j/SNWlr7x3vO1lDEsuePvtcDNCkw==", "dev": true, "requires": { "@xtuc/ieee754": "^1.2.0" } }, "@webassemblyjs/leb128": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/leb128/-/leb128-1.11.6.tgz", - "integrity": "sha512-m7a0FhE67DQXgouf1tbN5XQcdWoNgaAuoULHIfGFIEVKA6tu/edls6XnIlkmS6FrXAquJRPni3ZZKjw6FSPjPQ==", + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/@webassemblyjs/leb128/-/leb128-1.13.2.tgz", + "integrity": "sha512-Lde1oNoIdzVzdkNEAWZ1dZ5orIbff80YPdHx20mrHwHrVNNTjNr8E3xz9BdpcGqRQbAEa+fkrCb+fRFTl/6sQw==", "dev": true, "requires": { "@xtuc/long": "4.2.2" } }, "@webassemblyjs/utf8": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/utf8/-/utf8-1.11.6.tgz", - "integrity": "sha512-vtXf2wTQ3+up9Zsg8sa2yWiQpzSsMyXj0qViVP6xKGCUT8p8YJ6HqI7l5eCnWx1T/FYdsv07HQs2wTFbbof/RA==", + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/@webassemblyjs/utf8/-/utf8-1.13.2.tgz", + "integrity": "sha512-3NQWGjKTASY1xV5m7Hr0iPeXD9+RDobLll3T9d2AO+g3my8xy5peVyjSag4I50mR1bBSN/Ct12lo+R9tJk0NZQ==", "dev": true }, "@webassemblyjs/wasm-edit": { - "version": "1.12.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-edit/-/wasm-edit-1.12.1.tgz", - "integrity": "sha512-1DuwbVvADvS5mGnXbE+c9NfA8QRcZ6iKquqjjmR10k6o+zzsRVesil54DKexiowcFCPdr/Q0qaMgB01+SQ1u6g==", + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-edit/-/wasm-edit-1.14.1.tgz", + "integrity": "sha512-RNJUIQH/J8iA/1NzlE4N7KtyZNHi3w7at7hDjvRNm5rcUXa00z1vRz3glZoULfJ5mpvYhLybmVcwcjGrC1pRrQ==", "dev": true, "requires": { - "@webassemblyjs/ast": "1.12.1", - "@webassemblyjs/helper-buffer": "1.12.1", - 
"@webassemblyjs/helper-wasm-bytecode": "1.11.6", - "@webassemblyjs/helper-wasm-section": "1.12.1", - "@webassemblyjs/wasm-gen": "1.12.1", - "@webassemblyjs/wasm-opt": "1.12.1", - "@webassemblyjs/wasm-parser": "1.12.1", - "@webassemblyjs/wast-printer": "1.12.1" + "@webassemblyjs/ast": "1.14.1", + "@webassemblyjs/helper-buffer": "1.14.1", + "@webassemblyjs/helper-wasm-bytecode": "1.13.2", + "@webassemblyjs/helper-wasm-section": "1.14.1", + "@webassemblyjs/wasm-gen": "1.14.1", + "@webassemblyjs/wasm-opt": "1.14.1", + "@webassemblyjs/wasm-parser": "1.14.1", + "@webassemblyjs/wast-printer": "1.14.1" } }, "@webassemblyjs/wasm-gen": { - "version": "1.12.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-gen/-/wasm-gen-1.12.1.tgz", - "integrity": "sha512-TDq4Ojh9fcohAw6OIMXqiIcTq5KUXTGRkVxbSo1hQnSy6lAM5GSdfwWeSxpAo0YzgsgF182E/U0mDNhuA0tW7w==", + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-gen/-/wasm-gen-1.14.1.tgz", + "integrity": "sha512-AmomSIjP8ZbfGQhumkNvgC33AY7qtMCXnN6bL2u2Js4gVCg8fp735aEiMSBbDR7UQIj90n4wKAFUSEd0QN2Ukg==", "dev": true, "requires": { - "@webassemblyjs/ast": "1.12.1", - "@webassemblyjs/helper-wasm-bytecode": "1.11.6", - "@webassemblyjs/ieee754": "1.11.6", - "@webassemblyjs/leb128": "1.11.6", - "@webassemblyjs/utf8": "1.11.6" + "@webassemblyjs/ast": "1.14.1", + "@webassemblyjs/helper-wasm-bytecode": "1.13.2", + "@webassemblyjs/ieee754": "1.13.2", + "@webassemblyjs/leb128": "1.13.2", + "@webassemblyjs/utf8": "1.13.2" } }, "@webassemblyjs/wasm-opt": { - "version": "1.12.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-opt/-/wasm-opt-1.12.1.tgz", - "integrity": "sha512-Jg99j/2gG2iaz3hijw857AVYekZe2SAskcqlWIZXjji5WStnOpVoat3gQfT/Q5tb2djnCjBtMocY/Su1GfxPBg==", + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-opt/-/wasm-opt-1.14.1.tgz", + "integrity": "sha512-PTcKLUNvBqnY2U6E5bdOQcSM+oVP/PmrDY9NzowJjislEjwP/C4an2303MCVS2Mg9d3AJpIGdUFIQQWbPds0Sw==", "dev": 
true, "requires": { - "@webassemblyjs/ast": "1.12.1", - "@webassemblyjs/helper-buffer": "1.12.1", - "@webassemblyjs/wasm-gen": "1.12.1", - "@webassemblyjs/wasm-parser": "1.12.1" + "@webassemblyjs/ast": "1.14.1", + "@webassemblyjs/helper-buffer": "1.14.1", + "@webassemblyjs/wasm-gen": "1.14.1", + "@webassemblyjs/wasm-parser": "1.14.1" } }, "@webassemblyjs/wasm-parser": { - "version": "1.12.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-parser/-/wasm-parser-1.12.1.tgz", - "integrity": "sha512-xikIi7c2FHXysxXe3COrVUPSheuBtpcfhbpFj4gmu7KRLYOzANztwUU0IbsqvMqzuNK2+glRGWCEqZo1WCLyAQ==", + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-parser/-/wasm-parser-1.14.1.tgz", + "integrity": "sha512-JLBl+KZ0R5qB7mCnud/yyX08jWFw5MsoalJ1pQ4EdFlgj9VdXKGuENGsiCIjegI1W7p91rUlcB/LB5yRJKNTcQ==", "dev": true, "requires": { - "@webassemblyjs/ast": "1.12.1", - "@webassemblyjs/helper-api-error": "1.11.6", - "@webassemblyjs/helper-wasm-bytecode": "1.11.6", - "@webassemblyjs/ieee754": "1.11.6", - "@webassemblyjs/leb128": "1.11.6", - "@webassemblyjs/utf8": "1.11.6" + "@webassemblyjs/ast": "1.14.1", + "@webassemblyjs/helper-api-error": "1.13.2", + "@webassemblyjs/helper-wasm-bytecode": "1.13.2", + "@webassemblyjs/ieee754": "1.13.2", + "@webassemblyjs/leb128": "1.13.2", + "@webassemblyjs/utf8": "1.13.2" } }, "@webassemblyjs/wast-printer": { - "version": "1.12.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wast-printer/-/wast-printer-1.12.1.tgz", - "integrity": "sha512-+X4WAlOisVWQMikjbcvY2e0rwPsKQ9F688lksZhBcPycBBuii3O7m8FACbDMWDojpAqvjIncrG8J0XHKyQfVeA==", + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wast-printer/-/wast-printer-1.14.1.tgz", + "integrity": "sha512-kPSSXE6De1XOR820C90RIo2ogvZG+c3KiHzqUoO/F34Y2shGzesfqv7o57xrxovZJH/MetF5UjroJ/R/3isoiw==", "dev": true, "requires": { - "@webassemblyjs/ast": "1.12.1", + "@webassemblyjs/ast": "1.14.1", "@xtuc/long": "4.2.2" } }, @@ -4632,28 +4465,28 
@@ } }, "acorn": { - "version": "8.12.1", - "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.12.1.tgz", - "integrity": "sha512-tcpGyI9zbizT9JbV6oYE477V6mTlXvvi0T0G3SNIYE2apm/G5huBa1+K89VGeovbg+jycCrfhl3ADxErOuO6Jg==", + "version": "8.15.0", + "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz", + "integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==", "dev": true }, - "acorn-import-attributes": { - "version": "1.9.5", - "resolved": "https://registry.npmjs.org/acorn-import-attributes/-/acorn-import-attributes-1.9.5.tgz", - "integrity": "sha512-n02Vykv5uA3eHGM/Z2dQrcD56kL8TyDb2p1+0P83PClMnC/nc+anbQRhIOWnSq4Ke/KvDPrY3C9hDtC/A3eHnQ==", + "acorn-import-phases": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/acorn-import-phases/-/acorn-import-phases-1.0.4.tgz", + "integrity": "sha512-wKmbr/DDiIXzEOiWrTTUcDm24kQ2vGfZQvM2fwg2vXqR5uW6aapr7ObPtj1th32b9u90/Pf4AItvdTh42fBmVQ==", "dev": true, "requires": {} }, "ajv": { - "version": "6.12.6", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", - "integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==", + "version": "8.17.1", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz", + "integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==", "dev": true, "requires": { - "fast-deep-equal": "^3.1.1", - "fast-json-stable-stringify": "^2.0.0", - "json-schema-traverse": "^0.4.1", - "uri-js": "^4.2.2" + "fast-deep-equal": "^3.1.3", + "fast-uri": "^3.0.1", + "json-schema-traverse": "^1.0.0", + "require-from-string": "^2.0.2" } }, "ajv-formats": { @@ -4663,34 +4496,16 @@ "dev": true, "requires": { "ajv": "^8.0.0" - }, - "dependencies": { - "ajv": { - "version": "8.12.0", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.12.0.tgz", - "integrity": 
"sha512-sRu1kpcO9yLtYxBKvqfTeh9KzZEwO3STyX1HT+4CaDzC6HpTGYhIhPIzj9XuKU7KYDwnaeh5hcOwjy1QuJzBPA==", - "dev": true, - "requires": { - "fast-deep-equal": "^3.1.1", - "json-schema-traverse": "^1.0.0", - "require-from-string": "^2.0.2", - "uri-js": "^4.2.2" - } - }, - "json-schema-traverse": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", - "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", - "dev": true - } } }, "ajv-keywords": { - "version": "3.5.2", - "resolved": "https://registry.npmjs.org/ajv-keywords/-/ajv-keywords-3.5.2.tgz", - "integrity": "sha512-5p6WTN0DdTGVQk6VjcEju19IgaHudalcfabD7yhDGeA6bcQnmL+CpveLJq/3hvfwd1aof6L386Ougkx6RfyMIQ==", + "version": "5.1.0", + "resolved": "https://registry.npmjs.org/ajv-keywords/-/ajv-keywords-5.1.0.tgz", + "integrity": "sha512-YCS/JNFAUyr5vAuhk1DWm1CBxRHW9LbJ2ozWeemrIqpbsqKjHVxYPyi5GC0rjZIT5JxJ3virVTS8wk4i/Z+krw==", "dev": true, - "requires": {} + "requires": { + "fast-deep-equal": "^3.1.3" + } }, "ansi-html-community": { "version": "0.0.8", @@ -4714,6 +4529,12 @@ "integrity": "sha512-PCVAQswWemu6UdxsDFFX/+gVeYqKAod3D3UVm91jHwynguOwAvYPhx8nNlM++NqRcK6CxxpUafjmhIdKiHibqg==", "dev": true }, + "baseline-browser-mapping": { + "version": "2.9.19", + "resolved": "https://registry.npmjs.org/baseline-browser-mapping/-/baseline-browser-mapping-2.9.19.tgz", + "integrity": "sha512-ipDqC8FrAl/76p2SSWKSI+H9tFwm7vYqXQrItCuiVPt26Km0jS+NzSsBWAaBusvSbQcfJG+JitdMm+wZAgTYqg==", + "dev": true + }, "batch": { "version": "0.6.1", "resolved": "https://registry.npmjs.org/batch/-/batch-0.6.1.tgz", @@ -4783,15 +4604,16 @@ } }, "browserslist": { - "version": "4.21.11", - "resolved": "https://registry.npmjs.org/browserslist/-/browserslist-4.21.11.tgz", - "integrity": "sha512-xn1UXOKUz7DjdGlg9RrUr0GGiWzI97UQJnugHtH0OLDfJB7jMgoIkYvRIEO1l9EeEERVqeqLYOcFBW9ldjypbQ==", + "version": "4.28.1", + "resolved": 
"https://registry.npmjs.org/browserslist/-/browserslist-4.28.1.tgz", + "integrity": "sha512-ZC5Bd0LgJXgwGqUknZY/vkUQ04r8NXnJZ3yYi4vDmSiZmC/pdSN0NbNRPxZpbtO4uAfDUAFffO8IZoM3Gj8IkA==", "dev": true, "requires": { - "caniuse-lite": "^1.0.30001538", - "electron-to-chromium": "^1.4.526", - "node-releases": "^2.0.13", - "update-browserslist-db": "^1.0.13" + "baseline-browser-mapping": "^2.9.0", + "caniuse-lite": "^1.0.30001759", + "electron-to-chromium": "^1.5.263", + "node-releases": "^2.0.27", + "update-browserslist-db": "^1.2.0" } }, "buffer-from": { @@ -4836,9 +4658,9 @@ } }, "caniuse-lite": { - "version": "1.0.30001538", - "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001538.tgz", - "integrity": "sha512-HWJnhnID+0YMtGlzcp3T9drmBJUVDchPJ08tpUGFLs9CYlwWPH2uLgpHn8fND5pCgXVtnGS3H4QR9XLMHVNkHw==", + "version": "1.0.30001768", + "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001768.tgz", + "integrity": "sha512-qY3aDRZC5nWPgHUgIB84WL+nySuo19wk0VJpp/XI9T34lrvkyhRvNVOFJOp2kxClQhiFBu+TaUSudf6oa3vkSA==", "dev": true }, "chokidar": { @@ -4991,27 +4813,6 @@ "serialize-javascript": "^6.0.2" }, "dependencies": { - "ajv": { - "version": "8.17.1", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz", - "integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==", - "dev": true, - "requires": { - "fast-deep-equal": "^3.1.3", - "fast-uri": "^3.0.1", - "json-schema-traverse": "^1.0.0", - "require-from-string": "^2.0.2" - } - }, - "ajv-keywords": { - "version": "5.1.0", - "resolved": "https://registry.npmjs.org/ajv-keywords/-/ajv-keywords-5.1.0.tgz", - "integrity": "sha512-YCS/JNFAUyr5vAuhk1DWm1CBxRHW9LbJ2ozWeemrIqpbsqKjHVxYPyi5GC0rjZIT5JxJ3virVTS8wk4i/Z+krw==", - "dev": true, - "requires": { - "fast-deep-equal": "^3.1.3" - } - }, "glob-parent": { "version": "6.0.2", "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz", @@ -5020,24 +4821,6 @@ 
"requires": { "is-glob": "^4.0.3" } - }, - "json-schema-traverse": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", - "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", - "dev": true - }, - "schema-utils": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.3.0.tgz", - "integrity": "sha512-Gf9qqc58SpCA/xdziiHz35F4GNIWYWZrEshUc/G/r5BnLph6xpKuLeoJoQuj5WfBIx/eQLf+hmVPYHaxJu7V2g==", - "dev": true, - "requires": { - "@types/json-schema": "^7.0.9", - "ajv": "^8.9.0", - "ajv-formats": "^2.1.1", - "ajv-keywords": "^5.1.0" - } } } }, @@ -5145,9 +4928,9 @@ "dev": true }, "electron-to-chromium": { - "version": "1.4.528", - "resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.4.528.tgz", - "integrity": "sha512-UdREXMXzLkREF4jA8t89FQjA8WHI6ssP38PMY4/4KhXFQbtImnghh4GkCgrtiZwLKUKVD2iTVXvDVQjfomEQuA==", + "version": "1.5.286", + "resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.5.286.tgz", + "integrity": "sha512-9tfDXhJ4RKFNerfjdCcZfufu49vg620741MNs26a9+bhLThdB+plgMeou98CAaHu/WATj2iHOOHTp1hWtABj2A==", "dev": true }, "encodeurl": { @@ -5157,13 +4940,13 @@ "dev": true }, "enhanced-resolve": { - "version": "5.17.1", - "resolved": "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-5.17.1.tgz", - "integrity": "sha512-LMHl3dXhTcfv8gM4kEzIUeTQ+7fpdA0l2tUf34BddXPkz2A5xJ5L/Pchd5BL6rdccM9QGvu0sWZzK1Z1t4wwyg==", + "version": "5.19.0", + "resolved": "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-5.19.0.tgz", + "integrity": "sha512-phv3E1Xl4tQOShqSte26C7Fl84EwUdZsyOuSSk9qtAGyyQs2s3jJzComh+Abf4g187lUUAvH+H26omrqia2aGg==", "dev": true, "requires": { "graceful-fs": "^4.2.4", - "tapable": "^2.2.0" + "tapable": "^2.3.0" } }, "envinfo": { @@ -5185,9 +4968,9 @@ "dev": true }, "es-module-lexer": { - "version": "1.3.1", - 
"resolved": "https://registry.npmjs.org/es-module-lexer/-/es-module-lexer-1.3.1.tgz", - "integrity": "sha512-JUFAyicQV9mXc3YRxPnDlrfBKpqt6hUYzz9/boprUJHs4e4KVr3XwOF70doO6gwXUor6EWZJAyWAfKki84t20Q==", + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/es-module-lexer/-/es-module-lexer-2.0.0.tgz", + "integrity": "sha512-5POEcUuZybH7IdmGsD8wlf0AI55wMecM9rVBTI/qEAy2c1kTOm3DjFYjrBdI2K3BaJjJYfYFeRtM0t9ssnRuxw==", "dev": true }, "es-object-atoms": { @@ -5200,9 +4983,9 @@ } }, "escalade": { - "version": "3.1.1", - "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.1.1.tgz", - "integrity": "sha512-k0er2gUkLf8O0zKJiAhmkTnJlTvINGv7ygDNPbeIsX/TJjGJZHuh9B2UxbsaEkmlEo9MfhrSzmhIlhRlI2GXnw==", + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.2.0.tgz", + "integrity": "sha512-WUj2qlxaQtO4g6Pq5c29GTcWGDyd8itL8zTlipgECz3JesAiiOKotd8JU6otB3PACgG6xkJUyVhboMS+bje/jA==", "dev": true }, "escape-html": { @@ -5358,16 +5141,10 @@ "micromatch": "^4.0.8" } }, - "fast-json-stable-stringify": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/fast-json-stable-stringify/-/fast-json-stable-stringify-2.1.0.tgz", - "integrity": "sha512-lhd/wF+Lk98HZoTCtlVraHtfh5XYijIjalXck7saUtuanSDyLMxnHhSXEDJqHxD7msR8D0uCmqlkwjCV8xvwHw==", - "dev": true - }, "fast-uri": { - "version": "3.0.6", - "resolved": "https://registry.npmjs.org/fast-uri/-/fast-uri-3.0.6.tgz", - "integrity": "sha512-Atfo14OibSv5wAp4VWNsFYE1AchQRTv9cBGWET4pZWHzYshFSS9NQI6I57rdKn9croWVMbYFbLhJ+yJvmZIIHw==", + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/fast-uri/-/fast-uri-3.1.0.tgz", + "integrity": "sha512-iPeeDKJSWf4IEOasVVrknXpaBV0IApz/gp7S2bb7Z4Lljbl2MGJRqInZiUrQwV16cpzw/D3S5j5Julj/gT52AA==", "dev": true }, "fastest-levenshtein": { @@ -5831,9 +5608,9 @@ "dev": true }, "json-schema-traverse": { - "version": "0.4.1", - "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-0.4.1.tgz", - "integrity": 
"sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg==", + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", + "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", "dev": true }, "kind-of": { @@ -5853,9 +5630,9 @@ } }, "loader-runner": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/loader-runner/-/loader-runner-4.3.0.tgz", - "integrity": "sha512-3R/1M+yS3j5ou80Me59j7F9IMs4PXs3VqRrm0TU3AbKPxlmpoY1TNscJV/oGJXo8qCatFGTfDbY6W6ipGOYXfg==", + "version": "4.3.1", + "resolved": "https://registry.npmjs.org/loader-runner/-/loader-runner-4.3.1.tgz", + "integrity": "sha512-IWqP2SCPhyVFTBtRcgMHdzlf9ul25NwaFx4wCEH/KjAXuuHY4yNjvPXsBokp8jCB936PyWRaPKUNh8NvylLp2Q==", "dev": true }, "locate-path": { @@ -6035,9 +5812,9 @@ "dev": true }, "node-releases": { - "version": "2.0.13", - "resolved": "https://registry.npmjs.org/node-releases/-/node-releases-2.0.13.tgz", - "integrity": "sha512-uYr7J37ae/ORWdZeQ1xxMJe3NtdmqMC/JZK+geofDrkLUApKRHPd18/TxtBOJ4A0/+uUIliorNrfYV6s1b02eQ==", + "version": "2.0.27", + "resolved": "https://registry.npmjs.org/node-releases/-/node-releases-2.0.27.tgz", + "integrity": "sha512-nmh3lCkYZ3grZvqcCH+fjmQ7X+H0OeZgP40OierEaAptX4XofMh5kwNbWh7lBduUzCcV/8kZ+NDLCwm2iorIlA==", "dev": true }, "normalize-path": { @@ -6159,9 +5936,9 @@ "dev": true }, "picocolors": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.0.0.tgz", - "integrity": "sha512-1fygroTLlHu66zi26VoTDv8yRgm0Fccecssto+MhsZ0D/DGW2sm8E8AjW7NU5VVTRt5GxbeZ5qBuJr+HyLYkjQ==", + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.1.tgz", + "integrity": "sha512-xceH2snhtb5M9liqDsmEw56le376mTZkEX/jEb/RxNFyegNul7eNslCXP9FDj/Lcu0X8KEyMceP2ntpaHrDEVA==", "dev": true }, "picomatch": { @@ -6203,12 +5980,6 @@ } } }, - "punycode": { - "version": "2.3.0", - 
"resolved": "https://registry.npmjs.org/punycode/-/punycode-2.3.0.tgz", - "integrity": "sha512-rRV+zQD8tVFys26lAGR9WUuS4iUAngJScM+ZRSKtvl5tKeZ2t5bvdNFdNHBW9FWR4guGHlgmsZ1G7BSm2wTbuA==", - "dev": true - }, "qs": { "version": "6.13.0", "resolved": "https://registry.npmjs.org/qs/-/qs-6.13.0.tgz", @@ -6362,14 +6133,15 @@ "dev": true }, "schema-utils": { - "version": "3.3.0", - "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-3.3.0.tgz", - "integrity": "sha512-pN/yOAvcC+5rQ5nERGuwrjLlYvLTbCibnZ1I7B1LaiAz9BRBlE9GMgE/eqV30P7aJQUf7Ddimy/RsbYO/GrVGg==", + "version": "4.3.3", + "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.3.3.tgz", + "integrity": "sha512-eflK8wEtyOE6+hsaRVPxvUKYCpRgzLqDTb8krvAsRIwOGlHoSgYLgBXoubGgLd2fT41/OUYdb48v4k4WWHQurA==", "dev": true, "requires": { - "@types/json-schema": "^7.0.8", - "ajv": "^6.12.5", - "ajv-keywords": "^3.5.2" + "@types/json-schema": "^7.0.9", + "ajv": "^8.9.0", + "ajv-formats": "^2.1.1", + "ajv-keywords": "^5.1.0" } }, "select-hose": { @@ -6705,34 +6477,34 @@ "dev": true }, "tapable": { - "version": "2.2.1", - "resolved": "https://registry.npmjs.org/tapable/-/tapable-2.2.1.tgz", - "integrity": "sha512-GNzQvQTOIP6RyTfE2Qxb8ZVlNmw0n88vp1szwWRimP02mnTsx3Wtn5qRdqY9w2XduFNUgvOwhNnQsjwCp+kqaQ==", + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/tapable/-/tapable-2.3.0.tgz", + "integrity": "sha512-g9ljZiwki/LfxmQADO3dEY1CbpmXT5Hm2fJ+QaGKwSXUylMybePR7/67YW7jOrrvjEgL1Fmz5kzyAjWVWLlucg==", "dev": true }, "terser": { - "version": "5.31.6", - "resolved": "https://registry.npmjs.org/terser/-/terser-5.31.6.tgz", - "integrity": "sha512-PQ4DAriWzKj+qgehQ7LK5bQqCFNMmlhjR2PFFLuqGCpuCAauxemVBWwWOxo3UIwWQx8+Pr61Df++r76wDmkQBg==", + "version": "5.46.0", + "resolved": "https://registry.npmjs.org/terser/-/terser-5.46.0.tgz", + "integrity": "sha512-jTwoImyr/QbOWFFso3YoU3ik0jBBDJ6JTOQiy/J2YxVJdZCc+5u7skhNwiOR3FQIygFqVUPHl7qbbxtjW2K3Qg==", "dev": true, "requires": { "@jridgewell/source-map": "^0.3.3", 
- "acorn": "^8.8.2", + "acorn": "^8.15.0", "commander": "^2.20.0", "source-map-support": "~0.5.20" } }, "terser-webpack-plugin": { - "version": "5.3.10", - "resolved": "https://registry.npmjs.org/terser-webpack-plugin/-/terser-webpack-plugin-5.3.10.tgz", - "integrity": "sha512-BKFPWlPDndPs+NGGCr1U59t0XScL5317Y0UReNrHaw9/FwhPENlq6bfgs+4yPfyP51vqC1bQ4rp1EfXW5ZSH9w==", + "version": "5.3.16", + "resolved": "https://registry.npmjs.org/terser-webpack-plugin/-/terser-webpack-plugin-5.3.16.tgz", + "integrity": "sha512-h9oBFCWrq78NyWWVcSwZarJkZ01c2AyGrzs1crmHZO3QUg9D61Wu4NPjBy69n7JqylFF5y+CsUZYmYEIZ3mR+Q==", "dev": true, "requires": { - "@jridgewell/trace-mapping": "^0.3.20", + "@jridgewell/trace-mapping": "^0.3.25", "jest-worker": "^27.4.5", - "schema-utils": "^3.1.1", - "serialize-javascript": "^6.0.1", - "terser": "^5.26.0" + "schema-utils": "^4.3.0", + "serialize-javascript": "^6.0.2", + "terser": "^5.31.1" } }, "thunky": { @@ -6785,22 +6557,13 @@ "dev": true }, "update-browserslist-db": { - "version": "1.0.13", - "resolved": "https://registry.npmjs.org/update-browserslist-db/-/update-browserslist-db-1.0.13.tgz", - "integrity": "sha512-xebP81SNcPuNpPP3uzeW1NYXxI3rxyJzF3pD6sH4jE7o/IX+WtSpwnVU+qIsDPyk0d3hmFQ7mjqc6AtV604hbg==", - "dev": true, - "requires": { - "escalade": "^3.1.1", - "picocolors": "^1.0.0" - } - }, - "uri-js": { - "version": "4.4.1", - "resolved": "https://registry.npmjs.org/uri-js/-/uri-js-4.4.1.tgz", - "integrity": "sha512-7rKUyy33Q1yc98pQ1DAmLtwX109F7TIfWlW1Ydo8Wl1ii1SeHieeh0HHfPeL2fMXK6z0s8ecKs9frCuLJvndBg==", + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/update-browserslist-db/-/update-browserslist-db-1.2.3.tgz", + "integrity": "sha512-Js0m9cx+qOgDxo0eMiFGEueWztz+d4+M3rGlmKPT+T4IS/jP4ylw3Nwpu6cpTTP8R1MAC1kF4VbdLt3ARf209w==", "dev": true, "requires": { - "punycode": "^2.1.0" + "escalade": "^3.2.0", + "picocolors": "^1.1.1" } }, "util-deprecate": { @@ -6828,9 +6591,9 @@ "dev": true }, "watchpack": { - "version": "2.4.2", - "resolved": 
"https://registry.npmjs.org/watchpack/-/watchpack-2.4.2.tgz", - "integrity": "sha512-TnbFSbcOCcDgjZ4piURLCbJ3nJhznVh9kw6F6iokjiFPl8ONxe9A6nMDVXDiNbrSfLILs6vB07F7wLBrwPYzJw==", + "version": "2.5.1", + "resolved": "https://registry.npmjs.org/watchpack/-/watchpack-2.5.1.tgz", + "integrity": "sha512-Zn5uXdcFNIA1+1Ei5McRd+iRzfhENPCe7LeABkJtNulSxjma+l7ltNx55BWZkRlwRnpOgHqxnjyaDgJnNXnqzg==", "dev": true, "requires": { "glob-to-regexp": "^0.4.1", @@ -6847,42 +6610,36 @@ } }, "webpack": { - "version": "5.94.0", - "resolved": "https://registry.npmjs.org/webpack/-/webpack-5.94.0.tgz", - "integrity": "sha512-KcsGn50VT+06JH/iunZJedYGUJS5FGjow8wb9c0v5n1Om8O1g4L6LjtfxwlXIATopoQu+vOXXa7gYisWxCoPyg==", - "dev": true, - "requires": { - "@types/estree": "^1.0.5", - "@webassemblyjs/ast": "^1.12.1", - "@webassemblyjs/wasm-edit": "^1.12.1", - "@webassemblyjs/wasm-parser": "^1.12.1", - "acorn": "^8.7.1", - "acorn-import-attributes": "^1.9.5", - "browserslist": "^4.21.10", + "version": "5.105.0", + "resolved": "https://registry.npmjs.org/webpack/-/webpack-5.105.0.tgz", + "integrity": "sha512-gX/dMkRQc7QOMzgTe6KsYFM7DxeIONQSui1s0n/0xht36HvrgbxtM1xBlgx596NbpHuQU8P7QpKwrZYwUX48nw==", + "dev": true, + "requires": { + "@types/eslint-scope": "^3.7.7", + "@types/estree": "^1.0.8", + "@types/json-schema": "^7.0.15", + "@webassemblyjs/ast": "^1.14.1", + "@webassemblyjs/wasm-edit": "^1.14.1", + "@webassemblyjs/wasm-parser": "^1.14.1", + "acorn": "^8.15.0", + "acorn-import-phases": "^1.0.3", + "browserslist": "^4.28.1", "chrome-trace-event": "^1.0.2", - "enhanced-resolve": "^5.17.1", - "es-module-lexer": "^1.2.1", + "enhanced-resolve": "^5.19.0", + "es-module-lexer": "^2.0.0", "eslint-scope": "5.1.1", "events": "^3.2.0", "glob-to-regexp": "^0.4.1", "graceful-fs": "^4.2.11", "json-parse-even-better-errors": "^2.3.1", - "loader-runner": "^4.2.0", + "loader-runner": "^4.3.1", "mime-types": "^2.1.27", "neo-async": "^2.6.2", - "schema-utils": "^3.2.0", - "tapable": "^2.1.1", - "terser-webpack-plugin": 
"^5.3.10", - "watchpack": "^2.4.1", - "webpack-sources": "^3.2.3" - }, - "dependencies": { - "webpack-sources": { - "version": "3.2.3", - "resolved": "https://registry.npmjs.org/webpack-sources/-/webpack-sources-3.2.3.tgz", - "integrity": "sha512-/DyMEOrDgLKKIG0fmvtz+4dUX/3Ghozwgm6iPp8KRhvn+eQf9+Q7GWxVNMk3+uCPWfdXYC4ExGBckIXdFEfH1w==", - "dev": true - } + "schema-utils": "^4.3.3", + "tapable": "^2.3.0", + "terser-webpack-plugin": "^5.3.16", + "watchpack": "^2.5.1", + "webpack-sources": "^3.3.3" } }, "webpack-cli": { @@ -6926,47 +6683,6 @@ "on-finished": "^2.4.1", "range-parser": "^1.2.1", "schema-utils": "^4.0.0" - }, - "dependencies": { - "ajv": { - "version": "8.17.1", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz", - "integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==", - "dev": true, - "requires": { - "fast-deep-equal": "^3.1.3", - "fast-uri": "^3.0.1", - "json-schema-traverse": "^1.0.0", - "require-from-string": "^2.0.2" - } - }, - "ajv-keywords": { - "version": "5.1.0", - "resolved": "https://registry.npmjs.org/ajv-keywords/-/ajv-keywords-5.1.0.tgz", - "integrity": "sha512-YCS/JNFAUyr5vAuhk1DWm1CBxRHW9LbJ2ozWeemrIqpbsqKjHVxYPyi5GC0rjZIT5JxJ3virVTS8wk4i/Z+krw==", - "dev": true, - "requires": { - "fast-deep-equal": "^3.1.3" - } - }, - "json-schema-traverse": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", - "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", - "dev": true - }, - "schema-utils": { - "version": "4.3.2", - "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.3.2.tgz", - "integrity": "sha512-Gn/JaSk/Mt9gYubxTtSn/QCV4em9mpAPiR1rqy/Ocu19u/G9J5WWdNoUT4SiV6mFC3y6cxyFcFwdzPM3FgxGAQ==", - "dev": true, - "requires": { - "@types/json-schema": "^7.0.9", - "ajv": "^8.9.0", - "ajv-formats": "^2.1.1", - "ajv-keywords": "^5.1.0" - } - } } 
}, "webpack-dev-server": { @@ -7003,47 +6719,6 @@ "spdy": "^4.0.2", "webpack-dev-middleware": "^7.4.2", "ws": "^8.18.0" - }, - "dependencies": { - "ajv": { - "version": "8.12.0", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.12.0.tgz", - "integrity": "sha512-sRu1kpcO9yLtYxBKvqfTeh9KzZEwO3STyX1HT+4CaDzC6HpTGYhIhPIzj9XuKU7KYDwnaeh5hcOwjy1QuJzBPA==", - "dev": true, - "requires": { - "fast-deep-equal": "^3.1.1", - "json-schema-traverse": "^1.0.0", - "require-from-string": "^2.0.2", - "uri-js": "^4.2.2" - } - }, - "ajv-keywords": { - "version": "5.1.0", - "resolved": "https://registry.npmjs.org/ajv-keywords/-/ajv-keywords-5.1.0.tgz", - "integrity": "sha512-YCS/JNFAUyr5vAuhk1DWm1CBxRHW9LbJ2ozWeemrIqpbsqKjHVxYPyi5GC0rjZIT5JxJ3virVTS8wk4i/Z+krw==", - "dev": true, - "requires": { - "fast-deep-equal": "^3.1.3" - } - }, - "json-schema-traverse": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", - "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", - "dev": true - }, - "schema-utils": { - "version": "4.2.0", - "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.2.0.tgz", - "integrity": "sha512-L0jRsrPpjdckP3oPug3/VxNKt2trR8TcabrM6FOAAlvC/9Phcmm+cuAgTlxBqdBR1WJx7Naj9WHw+aOmheSVbw==", - "dev": true, - "requires": { - "@types/json-schema": "^7.0.9", - "ajv": "^8.9.0", - "ajv-formats": "^2.1.1", - "ajv-keywords": "^5.1.0" - } - } } }, "webpack-merge": { @@ -7056,6 +6731,12 @@ "wildcard": "^2.0.0" } }, + "webpack-sources": { + "version": "3.3.3", + "resolved": "https://registry.npmjs.org/webpack-sources/-/webpack-sources-3.3.3.tgz", + "integrity": "sha512-yd1RBzSGanHkitROoPFd6qsrxt+oFhg/129YzheDGqeustzX0vTZJZsSsQjVQC4yzBQ56K55XU8gaNCtIzOnTg==", + "dev": true + }, "websocket-driver": { "version": "0.7.4", "resolved": "https://registry.npmjs.org/websocket-driver/-/websocket-driver-0.7.4.tgz", diff --git 
a/datafusion/wasmtest/datafusion-wasm-app/package.json b/datafusion/wasmtest/datafusion-wasm-app/package.json index b46993de77d9..aecc5b689554 100644 --- a/datafusion/wasmtest/datafusion-wasm-app/package.json +++ b/datafusion/wasmtest/datafusion-wasm-app/package.json @@ -27,7 +27,7 @@ "datafusion-wasmtest": "../pkg" }, "devDependencies": { - "webpack": "5.94.0", + "webpack": "5.105.0", "webpack-cli": "5.1.4", "webpack-dev-server": "5.2.1", "copy-webpack-plugin": "12.0.2" diff --git a/datafusion/wasmtest/src/lib.rs b/datafusion/wasmtest/src/lib.rs index b20e6c24ffea..f545ccf19306 100644 --- a/datafusion/wasmtest/src/lib.rs +++ b/datafusion/wasmtest/src/lib.rs @@ -24,14 +24,12 @@ extern crate wasm_bindgen; -use datafusion_common::{DFSchema, ScalarValue}; -use datafusion_expr::execution_props::ExecutionProps; +use datafusion_common::ScalarValue; use datafusion_expr::lit; use datafusion_expr::simplify::SimplifyContext; use datafusion_optimizer::simplify_expressions::ExprSimplifier; use datafusion_sql::sqlparser::dialect::GenericDialect; use datafusion_sql::sqlparser::parser::Parser; -use std::sync::Arc; use wasm_bindgen::prelude::*; pub fn set_panic_hook() { // When the `console_error_panic_hook` feature is enabled, we can call the @@ -63,10 +61,7 @@ pub fn basic_exprs() { log(&format!("Expr: {expr:?}")); // Simplify Expr (using datafusion-phys-expr and datafusion-optimizer) - let schema = Arc::new(DFSchema::empty()); - let execution_props = ExecutionProps::new(); - let simplifier = - ExprSimplifier::new(SimplifyContext::new(&execution_props).with_schema(schema)); + let simplifier = ExprSimplifier::new(SimplifyContext::default()); let simplified_expr = simplifier.simplify(expr).unwrap(); log(&format!("Simplified Expr: {simplified_expr:?}")); } @@ -82,7 +77,10 @@ pub fn basic_parse() { #[cfg(test)] mod test { - use super::*; + use std::sync::Arc; + + use bytes::Bytes; + use datafusion::datasource::file_format::file_compression_type::FileCompressionType; use 
datafusion::{ arrow::{ array::{ArrayRef, Int32Array, RecordBatch, StringArray}, @@ -90,8 +88,9 @@ mod test { }, datasource::MemTable, execution::context::SessionContext, + prelude::CsvReadOptions, }; - use datafusion_common::test_util::batches_to_string; + use datafusion_common::{DataFusionError, test_util::batches_to_string}; use datafusion_execution::{ config::SessionConfig, disk_manager::{DiskManagerBuilder, DiskManagerMode}, @@ -99,17 +98,18 @@ mod test { }; use datafusion_physical_plan::collect; use datafusion_sql::parser::DFParser; - use object_store::{ObjectStore, memory::InMemory, path::Path}; + use futures::{StreamExt, TryStreamExt, stream}; + use object_store::{ObjectStoreExt, PutPayload, memory::InMemory, path::Path}; use url::Url; use wasm_bindgen_test::wasm_bindgen_test; wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser); + #[cfg(target_arch = "wasm32")] #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] - #[cfg_attr(not(target_arch = "wasm32"), allow(dead_code))] fn datafusion_test() { - basic_exprs(); - basic_parse(); + super::basic_exprs(); + super::basic_parse(); } fn get_ctx() -> Arc { @@ -262,4 +262,55 @@ mod test { +----+-------+" ); } + + #[wasm_bindgen_test(unsupported = tokio::test)] + async fn test_csv_read_xz_compressed() { + let csv_data = "id,value\n1,a\n2,b\n3,c\n"; + let input = Bytes::from(csv_data.as_bytes().to_vec()); + let input_stream = + stream::iter(vec![Ok::(input)]).boxed(); + + let compressed_stream = FileCompressionType::XZ + .convert_to_compress_stream(input_stream) + .unwrap(); + let compressed_data: Vec = compressed_stream.try_collect().await.unwrap(); + + let store = InMemory::new(); + let path = Path::from("data.csv.xz"); + store + .put(&path, PutPayload::from_iter(compressed_data)) + .await + .unwrap(); + + let url = Url::parse("memory://").unwrap(); + let ctx = SessionContext::new(); + ctx.register_object_store(&url, Arc::new(store)); + + let csv_options = CsvReadOptions::new() + .has_header(true) + 
.file_compression_type(FileCompressionType::XZ) + .file_extension("csv.xz"); + ctx.register_csv("compressed", "memory:///data.csv.xz", csv_options) + .await + .unwrap(); + + let result = ctx + .sql("SELECT * FROM compressed") + .await + .unwrap() + .collect() + .await + .unwrap(); + + assert_eq!( + batches_to_string(&result), + "+----+-------+\n\ + | id | value |\n\ + +----+-------+\n\ + | 1 | a |\n\ + | 2 | b |\n\ + | 3 | c |\n\ + +----+-------+" + ); + } } diff --git a/dev/changelog/52.0.0.md b/dev/changelog/52.0.0.md new file mode 100644 index 000000000000..4536fd5a0690 --- /dev/null +++ b/dev/changelog/52.0.0.md @@ -0,0 +1,745 @@ + + +# Apache DataFusion 52.0.0 Changelog + +This release consists of 549 commits from 121 contributors. See credits at the end of this changelog for more information. + +See the [upgrade guide](https://datafusion.apache.org/library-user-guide/upgrading.html) for information on how to upgrade from previous versions. + +**Breaking changes:** + +- Force `FileSource` to be constructed with a `Schema` [#18386](https://github.com/apache/datafusion/pull/18386) (adriangb) +- Support Arrow IPC Stream Files [#18457](https://github.com/apache/datafusion/pull/18457) (corasaurus-hex) +- Change default of `AggregateUDFImpl::supports_null_handling_clause` to `false` [#18441](https://github.com/apache/datafusion/pull/18441) (Jefffrey) +- [Minor] Remove RawTableAllocExt [#18748](https://github.com/apache/datafusion/pull/18748) (Dandandan) +- Change `CacheAccessor::remove` to take `&self` rather than `&mut self` [#18726](https://github.com/apache/datafusion/pull/18726) (alchemist51) +- Move statistics handling into FileScanConfig [#18721](https://github.com/apache/datafusion/pull/18721) (adriangb) +- chore: remove `pyarrow` feature [#18528](https://github.com/apache/datafusion/pull/18528) (timsaucer) +- Limit visibility of internal impl functions in function crates [#18877](https://github.com/apache/datafusion/pull/18877) (Jefffrey) +- FFI: return 
underlying trait type when converting from FFI structs [#18672](https://github.com/apache/datafusion/pull/18672) (timsaucer) +- Refactor crypto functions code [#18664](https://github.com/apache/datafusion/pull/18664) (Jefffrey) +- move projection handling into FileSource [#18627](https://github.com/apache/datafusion/pull/18627) (adriangb) +- Add PhysicalOptimizerRule::optimize_plan to allow passing more context into optimizer rules [#18739](https://github.com/apache/datafusion/pull/18739) (adriangb) +- Optimize planning / stop cloning Strings / Fields so much (2-3% faster planning time) [#18415](https://github.com/apache/datafusion/pull/18415) (alamb) +- Adds memory-bound DefaultListFilesCache [#18855](https://github.com/apache/datafusion/pull/18855) (BlakeOrth) +- Allow Logical expression ScalarVariable to represent an extension type or metadata [#18243](https://github.com/apache/datafusion/pull/18243) (batmnnn) +- feat: Implement the `statistics_cache` function [#19054](https://github.com/apache/datafusion/pull/19054) (nuno-faria) +- Move `newlines_in_values` from `FileScanConfig` to `CsvSource` [#19313](https://github.com/apache/datafusion/pull/19313) (adriangb) +- Remove SchemaAdapter [#19345](https://github.com/apache/datafusion/pull/19345) (adriangb) +- feat: hash partitioning satisfies subset [#19304](https://github.com/apache/datafusion/pull/19304) (gene-bordegaray) +- feat: update FFI TableProvider and ExecutionPlan to use FFI Session and TaskContext [#19281](https://github.com/apache/datafusion/pull/19281) (timsaucer) +- Allow logical optimizer to be run without evaluating now() & refactor SimplifyInfo [#19505](https://github.com/apache/datafusion/pull/19505) (adriangb) +- Make default ListingFilesCache table scoped [#19616](https://github.com/apache/datafusion/pull/19616) (jizezhang) + +**Performance related:** + +- Normalize partitioned and flat object listing [#18146](https://github.com/apache/datafusion/pull/18146) (BlakeOrth) +- perf: Improve NLJ for 
very small right side case [#17562](https://github.com/apache/datafusion/pull/17562) (2010YOUY01) +- Consolidate `EliminateNestedUnion` and `EliminateOneUnion` optimizer rules' [#18678](https://github.com/apache/datafusion/pull/18678) (alamb) +- perf: improve performance of `vectorized_equal_to` for `PrimitiveGroupValueBuilder` in multi group by aggregation [#17977](https://github.com/apache/datafusion/pull/17977) (rluvaton) +- optimizer: Support dynamic filter in `MIN/MAX` aggregates [#18644](https://github.com/apache/datafusion/pull/18644) (2010YOUY01) +- perf: use `new_repeated` when converting scalar to an array [#19018](https://github.com/apache/datafusion/pull/19018) (rluvaton) +- perf: optimize CASE WHEN lookup table (2.5-22.5 times faster) [#18183](https://github.com/apache/datafusion/pull/18183) (rluvaton) +- add specialized InList implementations for common scalar types [#18832](https://github.com/apache/datafusion/pull/18832) (adriangb) +- Add hashing microbenchmark `with_hashes` [#19373](https://github.com/apache/datafusion/pull/19373) (alamb) +- Optimize muti-column grouping with StringView/ByteView (option 2) - 25% faster [#19413](https://github.com/apache/datafusion/pull/19413) (alamb) +- Optimize hashing for StringView and ByteView (15-70% faster) [#19374](https://github.com/apache/datafusion/pull/19374) (alamb) +- perf: Improve performance of `to_hex` (> 2x) [#19503](https://github.com/apache/datafusion/pull/19503) (andygrove) +- perf: improve performance of string repeat [#19502](https://github.com/apache/datafusion/pull/19502) (andygrove) +- perf: Optimize `starts_with` and `ends_with` for scalar arguments [#19516](https://github.com/apache/datafusion/pull/19516) (andygrove) +- perf: improve performance of string replace [#19530](https://github.com/apache/datafusion/pull/19530) (viirya) +- perf: improve performance of levenshtein by reusing cache buffer [#19532](https://github.com/apache/datafusion/pull/19532) (viirya) +- perf: improve 
performance of translate by reusing buffers [#19533](https://github.com/apache/datafusion/pull/19533) (viirya) +- perf: Optimize `contains` for scalar search arg [#19529](https://github.com/apache/datafusion/pull/19529) (andygrove) +- perf: improve performance of lpad/rpad by reusing buffers [#19558](https://github.com/apache/datafusion/pull/19558) (viirya) +- perf: optimize regexp_count to avoid String allocation when start position is provided [#19553](https://github.com/apache/datafusion/pull/19553) (viirya) +- perf: Improve performance of `md5` [#19568](https://github.com/apache/datafusion/pull/19568) (andygrove) +- perf: optimize strpos by eliminating double iteration for UTF-8 [#19572](https://github.com/apache/datafusion/pull/19572) (viirya) +- perf: optimize factorial function performance [#19575](https://github.com/apache/datafusion/pull/19575) (getChan) +- perf: Improve performance of ltrim, rtrim, btrim [#19551](https://github.com/apache/datafusion/pull/19551) (andygrove) +- perf: optimize `HashTableLookupExpr::evaluate` [#19602](https://github.com/apache/datafusion/pull/19602) (UBarney) +- perf: Improve performance of `split_part` [#19570](https://github.com/apache/datafusion/pull/19570) (andygrove) +- Optimize `Nullstate` / accumulators [#19625](https://github.com/apache/datafusion/pull/19625) (Dandandan) + +**Implemented enhancements:** + +- feat: Enhance `array_slice` functionality to support `ListView` and `LargeListView` types [#18432](https://github.com/apache/datafusion/pull/18432) (Weijun-H) +- feat: support complex expr for prepared statement argument [#18383](https://github.com/apache/datafusion/pull/18383) (chenkovsky) +- feat: Implement `SessionState::create_logical_expr_from_sql_expr` [#18423](https://github.com/apache/datafusion/pull/18423) (petern48) +- feat: added clippy::needless_pass_by_value lint rule to datafusion/expr [#18532](https://github.com/apache/datafusion/pull/18532) (Gohlub) +- feat: support nested key for get_field 
[#18394](https://github.com/apache/datafusion/pull/18394) (chenkovsky) +- feat: Add `ansi` enable parameter for execution config [#18635](https://github.com/apache/datafusion/pull/18635) (comphead) +- feat: Add evaluate_to_arrays function [#18446](https://github.com/apache/datafusion/pull/18446) (EmilyMatt) +- feat: support named variables & defaults for `CREATE FUNCTION` [#18450](https://github.com/apache/datafusion/pull/18450) (r1b) +- feat: Add new() constructor for CachedParquetFileReader [#18575](https://github.com/apache/datafusion/pull/18575) (petern48) +- feat: support decimal for math functions: power [#18032](https://github.com/apache/datafusion/pull/18032) (theirix) +- feat: selectivity metrics (for Explain Analyze) in Hash Join [#18488](https://github.com/apache/datafusion/pull/18488) (feniljain) +- feat: Handle edge case with `corr` with single row and `NaN` [#18677](https://github.com/apache/datafusion/pull/18677) (comphead) +- feat: support spark csc [#18642](https://github.com/apache/datafusion/pull/18642) (psvri) +- feat: support spark sec [#18728](https://github.com/apache/datafusion/pull/18728) (psvri) +- feat(parquet): Implement `scan_efficiency_ratio` metric for parquet reading [#18577](https://github.com/apache/datafusion/pull/18577) (petern48) +- feat: Enhance map handling to support NULL map values [#18531](https://github.com/apache/datafusion/pull/18531) (Weijun-H) +- feat: add RESET statement for configuration variables [#18408](https://github.com/apache/datafusion/pull/18408) (Weijun-H) +- feat: add human-readable formatting to EXPLAIN ANALYZE metrics #18689 [#18734](https://github.com/apache/datafusion/pull/18734) (T2MIX) +- feat: support Spark-compatible `abs` math function part 1 - non-ANSI mode [#18205](https://github.com/apache/datafusion/pull/18205) (hsiang-c) +- feat: Support Show runtime settings [#18564](https://github.com/apache/datafusion/pull/18564) (Weijun-H) +- feat(small): Support `` marker in `sqllogictest` for 
non-deterministic expected parts [#18857](https://github.com/apache/datafusion/pull/18857) (2010YOUY01) +- feat: allow custom caching via logical node [#18688](https://github.com/apache/datafusion/pull/18688) (jizezhang) +- feat: add `array_slice` benchmark [#18879](https://github.com/apache/datafusion/pull/18879) (dqkqd) +- feat: Support recursive queries with a distinct 'UNION' [#18254](https://github.com/apache/datafusion/pull/18254) (Tpt) +- feat: Makes error macros hygienic [#18995](https://github.com/apache/datafusion/pull/18995) (Tpt) +- feat: Add builder API for CreateExternalTable to reduce verbosity [#19066](https://github.com/apache/datafusion/pull/19066) (AryanBagade) +- feat(spark): Implement Spark functions `url_encode`, `url_decode` and `try_url_decode` [#17399](https://github.com/apache/datafusion/pull/17399) (anhvdq) +- feat: Move DefaultMetadataCache into its own module [#19125](https://github.com/apache/datafusion/pull/19125) (AryanBagade) +- feat: Add `remove_optimizer_rule` to `SessionContext` [#19209](https://github.com/apache/datafusion/pull/19209) (nuno-faria) +- feat: integrate batch coalescer with repartition exec [#19002](https://github.com/apache/datafusion/pull/19002) (jizezhang) +- feat: Preserve File Partitioning From File Scans [#19124](https://github.com/apache/datafusion/pull/19124) (gene-bordegaray) +- feat: Add constant column extraction and rewriting for projections in ParquetOpener [#19136](https://github.com/apache/datafusion/pull/19136) (Weijun-H) +- feat: Support sliding window queries for MedianAccumulator by implementing `retract_batch` [#19278](https://github.com/apache/datafusion/pull/19278) (petern48) +- feat: add compression level configuration for JSON/CSV writers [#18954](https://github.com/apache/datafusion/pull/18954) (Smotrov) +- feat(spark): implement Spark `try_sum` function [#18569](https://github.com/apache/datafusion/pull/18569) (davidlghellin) +- feat: Support log for Decimal32 and Decimal64 
[#18999](https://github.com/apache/datafusion/pull/18999) (Mark1626) +- feat(proto): Add protobuf serialization for HashExpr [#19379](https://github.com/apache/datafusion/pull/19379) (adriangb) +- feat: Add decimal support for round [#19384](https://github.com/apache/datafusion/pull/19384) (kumarUjjawal) +- Support nested field access in `get_field` with multiple path arguments [#19389](https://github.com/apache/datafusion/pull/19389) (adriangb) +- feat: fix matching for named parameters with non-lowercase signatures [#19378](https://github.com/apache/datafusion/pull/19378) (bubulalabu) +- feat: Add per-expression evaluation timing metrics to ProjectionExec [#19447](https://github.com/apache/datafusion/pull/19447) (2010YOUY01) +- feat: Improve sort memory resilience [#19494](https://github.com/apache/datafusion/pull/19494) (EmilyMatt) +- feat: Add DELETE/UPDATE hooks to TableProvider trait and to MemTable implementation [#19142](https://github.com/apache/datafusion/pull/19142) (ethan-tyler) +- feat: implement partition_statistics for WindowAggExec [#18534](https://github.com/apache/datafusion/pull/18534) (0xPoe) +- feat: integrate batch coalescer with async fn exec [#19342](https://github.com/apache/datafusion/pull/19342) (feniljain) +- feat: output statistics for constant columns in projections [#19419](https://github.com/apache/datafusion/pull/19419) (shashidhar-bm) +- feat: `to_time` function [#19540](https://github.com/apache/datafusion/pull/19540) (kumarUjjawal) +- feat: Implement Spark functions hour, minute, second [#19512](https://github.com/apache/datafusion/pull/19512) (andygrove) +- feat: plan-time SQL expression simplifying [#19311](https://github.com/apache/datafusion/pull/19311) (theirix) +- feat: Implement Spark function `space` [#19610](https://github.com/apache/datafusion/pull/19610) (kazantsev-maksim) +- feat: Implement `partition_statistics` API for `SortMergeJoinExec` [#19567](https://github.com/apache/datafusion/pull/19567) (kumarUjjawal) +- 
feat: add list_files_cache table function for `datafusion-cli` [#19388](https://github.com/apache/datafusion/pull/19388) (jizezhang) +- feat: implement metrics for AsyncFuncExec [#19626](https://github.com/apache/datafusion/pull/19626) (feniljain) +- feat: split BatchPartitioner::try_new into hash and round-robin constructors [#19668](https://github.com/apache/datafusion/pull/19668) (mohit7705) +- feat: add Time type support to date_trunc function [#19640](https://github.com/apache/datafusion/pull/19640) (kumarUjjawal) +- feat: Allow log with non-integer base on decimals [#19372](https://github.com/apache/datafusion/pull/19372) (Yuvraj-cyborg) + +**Fixed bugs:** + +- fix: Eliminate consecutive repartitions [#18521](https://github.com/apache/datafusion/pull/18521) (gene-bordegaray) +- fix: `with_param_values` on `LogicalPlan::EmptyRelation` returns incorrect schema [#18286](https://github.com/apache/datafusion/pull/18286) (dqkqd) +- fix: Nested arrays should not get a field in lookup [#18745](https://github.com/apache/datafusion/pull/18745) (EmilyMatt) +- fix: update schema's data type for `LogicalPlan::Values` after placeholder substitution [#18740](https://github.com/apache/datafusion/pull/18740) (dqkqd) +- fix: Pick correct columns in Sort Merge Equijoin [#18772](https://github.com/apache/datafusion/pull/18772) (tglanz) +- fix: remove `WorkTableExec` special case in `reset_plan_states` [#18803](https://github.com/apache/datafusion/pull/18803) (geoffreyclaude) +- fix: display the failed sqllogictest file and query that failed in case of a panic [#18785](https://github.com/apache/datafusion/pull/18785) (rluvaton) +- fix: preserve byte-size statistics in AggregateExec [#18885](https://github.com/apache/datafusion/pull/18885) (Tamar-Posen) +- fix: Track elapsed_compute metric for CSV scans [#18901](https://github.com/apache/datafusion/pull/18901) (Nithurshen) +- fix: Implement Substrait consumer support for like_match, like_imatch, and negated variants 
[#18929](https://github.com/apache/datafusion/pull/18929) (Nithurshen) +- fix: Initialize CsvOptions::double_quote from proto_opts.double_quote [#18967](https://github.com/apache/datafusion/pull/18967) (martin-g) +- fix: `rstest` is a DEV dependency [#19014](https://github.com/apache/datafusion/pull/19014) (crepererum) +- fix: partition pruning stats pruning when multiple values are present [#18923](https://github.com/apache/datafusion/pull/18923) (Mark1626) +- fix: deprecate data_type_and_nullable and simplify API usage [#18869](https://github.com/apache/datafusion/pull/18869) (BipulLamsal) +- fix: pre-warm listing file statistics cache during listing table creation [#18971](https://github.com/apache/datafusion/pull/18971) (bharath-techie) +- fix: log metadata differences when comparing physical and logical schema [#19070](https://github.com/apache/datafusion/pull/19070) (erratic-pattern) +- fix: fix panic when lo is greater than hi [#19099](https://github.com/apache/datafusion/pull/19099) (tshauck) +- fix: escape underscores when simplifying `starts_with` [#19077](https://github.com/apache/datafusion/pull/19077) (willemv) +- fix: custom nullability for length (#19175) [#19182](https://github.com/apache/datafusion/pull/19182) (skushagra) +- fix: inverted null_percent logic in in_list benchmark [#19204](https://github.com/apache/datafusion/pull/19204) (geoffreyclaude) +- fix: Ensure column names do not change with `expand_views_at_output` [#19019](https://github.com/apache/datafusion/pull/19019) (nuno-faria) +- fix: bitmap_count should report nullability correctly [#19195](https://github.com/apache/datafusion/pull/19195) (harshitsaini17) +- fix: bit_count function to report nullability correctly [#19197](https://github.com/apache/datafusion/pull/19197) (harshitsaini17) +- fix: derive custom nullability for spark `bit_shift` [#19222](https://github.com/apache/datafusion/pull/19222) (kumarUjjawal) +- fix: spark elt custom nullability 
[#19207](https://github.com/apache/datafusion/pull/19207) (EeshanBembi) +- fix: `array_remove`/`array_remove_n`/`array_remove_all` not using the same nullability as the input [#19259](https://github.com/apache/datafusion/pull/19259) (rluvaton) +- fix: typo in sql/ddl [#19276](https://github.com/apache/datafusion/pull/19276) (mag1c1an1) +- fix: flaky cache test [#19140](https://github.com/apache/datafusion/pull/19140) (xonx4l) +- fix: Add custom nullability for Spark ILIKE function [#19206](https://github.com/apache/datafusion/pull/19206) (Eshaan-byte) +- fix: derive custom nullability for spark `map_from_arrays` [#19275](https://github.com/apache/datafusion/pull/19275) (kumarUjjawal) +- fix: derive custom nullability for spark map_from_entries [#19274](https://github.com/apache/datafusion/pull/19274) (kumarUjjawal) +- fix: derive custom nullable for spark `make_dt_interval` [#19236](https://github.com/apache/datafusion/pull/19236) (kumarUjjawal) +- fix: derive custom nullable for the spark last_day [#19232](https://github.com/apache/datafusion/pull/19232) (kumarUjjawal) +- fix: derive custom nullable for spark `date_sub` [#19225](https://github.com/apache/datafusion/pull/19225) (kumarUjjawal) +- fix: Fix a few minor issues with join metrics [#19283](https://github.com/apache/datafusion/pull/19283) (linhr) +- fix: derive nullability for spark `bit_get` [#19220](https://github.com/apache/datafusion/pull/19220) (kumarUjjawal) +- fix: pow() with integer base and negative float exponent returns error [#19303](https://github.com/apache/datafusion/pull/19303) (adriangb) +- fix(concat): correct nullability inference (nullable only if all arguments nullable) [#19189](https://github.com/apache/datafusion/pull/19189) (ujjwaltwri) +- fix: Added nullable return from date_add(#19151) [#19229](https://github.com/apache/datafusion/pull/19229) (manishkr) +- fix: spark sha1 nullability reporting [#19242](https://github.com/apache/datafusion/pull/19242) (shashidhar-bm) +- fix: 
derive custom nullability for the spark `next_day` [#19253](https://github.com/apache/datafusion/pull/19253) (kumarUjjawal) +- fix: preserve ListFilesCache TTL when not set in config [#19401](https://github.com/apache/datafusion/pull/19401) (shashidhar-bm) +- fix: projection for `CooperativeExec` and `CoalesceBatchesExec` [#19400](https://github.com/apache/datafusion/pull/19400) (haohuaijin) +- fix: spark crc32 custom nullability [#19271](https://github.com/apache/datafusion/pull/19271) (watanaberin) +- fix: Fix skip aggregate test to cover regression [#19461](https://github.com/apache/datafusion/pull/19461) (kumarUjjawal) +- fix: [19450]Added flush for tokio file(substrait) write [#19456](https://github.com/apache/datafusion/pull/19456) (manishkr) +- fix: csv schema_infer_max_records set to 0 return null datatype [#19432](https://github.com/apache/datafusion/pull/19432) (haohuaijin) +- fix: Add custom nullability for Spark LIKE function [#19218](https://github.com/apache/datafusion/pull/19218) (KaranPradhan266) +- fix: implement custom nullability for spark abs function [#19395](https://github.com/apache/datafusion/pull/19395) (batmnnn) +- fix: custom nullability for format_string (#19173) [#19190](https://github.com/apache/datafusion/pull/19190) (skushagra) +- fix: Implement `reset_state` for `LazyMemoryExec` [#19362](https://github.com/apache/datafusion/pull/19362) (nuno-faria) +- fix: CteWorkTable: properly apply TableProvider::scan projection argument [#18993](https://github.com/apache/datafusion/pull/18993) (Tpt) +- fix: Median() integer overflow [#19509](https://github.com/apache/datafusion/pull/19509) (kumarUjjawal) +- fix: Reverse row selection should respect the row group index [#19557](https://github.com/apache/datafusion/pull/19557) (zhuqi-lucas) +- fix: emit empty RecordBatch for empty file writes [#19370](https://github.com/apache/datafusion/pull/19370) (nlimpid) +- fix: handle invalid byte ranges in calculate_range for single-line files 
[#19607](https://github.com/apache/datafusion/pull/19607) (vigimite) +- fix: NULL handling in arrow_intersect and arrow_union [#19415](https://github.com/apache/datafusion/pull/19415) (feniljain) +- fix(doc): close #19393, make upgrading guide match v51 api [#19648](https://github.com/apache/datafusion/pull/19648) (mag1c1an1) +- fix(spark): Use wrapping addition/subtraction in `SparkDateAdd` and `SparkDateSub` [#19377](https://github.com/apache/datafusion/pull/19377) (mzabaluev) +- fix(functions): Make translate function postgres compatible [#19630](https://github.com/apache/datafusion/pull/19630) (devanshu0987) +- fix: Return Int for Date - Date instead of duration [#19563](https://github.com/apache/datafusion/pull/19563) (kumarUjjawal) +- fix: DynamicFilterPhysicalExpr violates Hash/Eq contract [#19659](https://github.com/apache/datafusion/pull/19659) (kumarUjjawal) + +**Documentation updates:** + +- [main] Update version to 51.0.0, add Changelog (#18551) [#18565](https://github.com/apache/datafusion/pull/18565) (alamb) +- refactor: include metric output_batches into BaselineMetrics [#18491](https://github.com/apache/datafusion/pull/18491) (nmbr7) +- chore(deps): bump maturin from 1.9.6 to 1.10.0 in /docs [#18590](https://github.com/apache/datafusion/pull/18590) (dependabot[bot]) +- Update release download links on download page [#18550](https://github.com/apache/datafusion/pull/18550) (alamb) +- docs: fix rustup cmd for adding rust-analyzer [#18605](https://github.com/apache/datafusion/pull/18605) (Jefffrey) +- Enforce explicit opt-in for `WITHIN GROUP` syntax in aggregate UDAFs [#18607](https://github.com/apache/datafusion/pull/18607) (kosiew) +- docs: fix broken catalog example links [#18765](https://github.com/apache/datafusion/pull/18765) (nlimpid) +- doc: Add documentation for error handling [#18762](https://github.com/apache/datafusion/pull/18762) (2010YOUY01) +- docs: Fix the examples for char_length() and character_length() 
[#18808](https://github.com/apache/datafusion/pull/18808) (martin-g) +- chore: Support 'untake' for unassigning github issues [#18637](https://github.com/apache/datafusion/pull/18637) (petern48) +- chore: Add filtered pending PRs link to main page [#18854](https://github.com/apache/datafusion/pull/18854) (comphead) +- Docs: Enhance contributor guide with testing section [#18852](https://github.com/apache/datafusion/pull/18852) (alamb) +- Docs: Enhance testing documentation with examples and links [#18851](https://github.com/apache/datafusion/pull/18851) (alamb) +- chore(deps): bump maturin from 1.10.0 to 1.10.2 in /docs [#18905](https://github.com/apache/datafusion/pull/18905) (dependabot[bot]) +- Update links in documentation to point at new example locations [#18931](https://github.com/apache/datafusion/pull/18931) (alamb) +- Add Kubeflow Trainer to known users [#18935](https://github.com/apache/datafusion/pull/18935) (andreyvelich) +- Add PGO documentation section to crate configuration [#18959](https://github.com/apache/datafusion/pull/18959) (jatinkumarsingh) +- Add upgrade guide for PhysicalOptimizerRule::optimize_plan [#19030](https://github.com/apache/datafusion/pull/19030) (adriangb) +- doc: add `FilterExec` metrics to `user-guide/metrics.md` [#19043](https://github.com/apache/datafusion/pull/19043) (2010YOUY01) +- Add `force_filter_selections` to restore `pushdown_filters` behavior prior to parquet 57.1.0 upgrade [#19003](https://github.com/apache/datafusion/pull/19003) (alamb) +- Implement FFI task context and task context provider [#18918](https://github.com/apache/datafusion/pull/18918) (timsaucer) +- Minor: fix link errors in docs [#19088](https://github.com/apache/datafusion/pull/19088) (alamb) +- Cut `Parquet` over to PhysicalExprAdapter, remove `SchemaAdapter` [#18998](https://github.com/apache/datafusion/pull/18998) (adriangb) +- Update Committer / PMC list [#19105](https://github.com/apache/datafusion/pull/19105) (alamb) +- Revert adding 
PhysicalOptimizerRule::optimize_plan [#19186](https://github.com/apache/datafusion/pull/19186) (adriangb) +- Push down InList or hash table references from HashJoinExec depending on the size of the build side [#18393](https://github.com/apache/datafusion/pull/18393) (adriangb) +- Move partition handling out of PhysicalExprAdapter [#19128](https://github.com/apache/datafusion/pull/19128) (adriangb) +- Push down projection expressions into ParquetOpener [#19111](https://github.com/apache/datafusion/pull/19111) (adriangb) +- Track column sizes in Statistics; propagate through projections [#19113](https://github.com/apache/datafusion/pull/19113) (adriangb) +- Improve ProjectionExpr documentation and comments [#19263](https://github.com/apache/datafusion/pull/19263) (alamb) +- Update README according to the new examples (#18529) [#19257](https://github.com/apache/datafusion/pull/19257) (cj-zhukov) +- Add make_time function [#19183](https://github.com/apache/datafusion/pull/19183) (Omega359) +- Update to_date udf function to support a consistent set of argument types [#19134](https://github.com/apache/datafusion/pull/19134) (Omega359) +- Add library user guide for extending SQL syntax [#19265](https://github.com/apache/datafusion/pull/19265) (geoffreyclaude) +- Add runtime config options for `list_files_cache_limit` and `list_files_cache_ttl` [#19108](https://github.com/apache/datafusion/pull/19108) (delamarch3) +- Minor: clean up titles and links in extending operators and optimizer pages [#19317](https://github.com/apache/datafusion/pull/19317) (alamb) +- Establish the high level API for sort pushdown and the optimizer rule and support reverse files and row groups [#19064](https://github.com/apache/datafusion/pull/19064) (zhuqi-lucas) +- Add Decimal support to Ceil and Floor [#18979](https://github.com/apache/datafusion/pull/18979) (kumarUjjawal) +- doc: add example for cache factory [#19139](https://github.com/apache/datafusion/pull/19139) (jizezhang) +- chore(deps): 
bump sphinx-reredirects from 1.0.0 to 1.1.0 in /docs [#19455](https://github.com/apache/datafusion/pull/19455) (dependabot[bot]) +- Add:arrow_metadata() UDF [#19435](https://github.com/apache/datafusion/pull/19435) (xonx4l) +- Update date_bin to support Time32 and Time64 data types [#19341](https://github.com/apache/datafusion/pull/19341) (Omega359) +- Update `to_unixtime` udf function to support a consistent set of argument types [#19442](https://github.com/apache/datafusion/pull/19442) (kumarUjjawal) +- docs: Improve config tables' readability [#19522](https://github.com/apache/datafusion/pull/19522) (nuno-faria) +- Introduce `TypeSignatureClass::Any` [#19485](https://github.com/apache/datafusion/pull/19485) (Jefffrey) +- Enables DefaultListFilesCache by default [#19366](https://github.com/apache/datafusion/pull/19366) (BlakeOrth) +- Fix typo in contributor guide architecture section [#19613](https://github.com/apache/datafusion/pull/19613) (cdegroc) +- docs: fix typos in PartitionEvaluator trait documentation [#19631](https://github.com/apache/datafusion/pull/19631) (SolariSystems) +- Respect execution timezone in to_timestamp and related functions [#19078](https://github.com/apache/datafusion/pull/19078) (Omega359) +- perfect hash join [#19411](https://github.com/apache/datafusion/pull/19411) (UBarney) + +**Other:** + +- chore(deps): bump taiki-e/install-action from 2.62.46 to 2.62.47 [#18508](https://github.com/apache/datafusion/pull/18508) (dependabot[bot]) +- Consolidate builtin functions examples (#18142) [#18523](https://github.com/apache/datafusion/pull/18523) (cj-zhukov) +- refactor: update cmp and nested data in binary operator [#18256](https://github.com/apache/datafusion/pull/18256) (sunng87) +- Fix: topk_aggregate benchmark failing [#18502](https://github.com/apache/datafusion/pull/18502) (randyli) +- refactor: Add `assert_or_internal_err!` macro for more ergonomic internal invariant checks [#18511](https://github.com/apache/datafusion/pull/18511) 
(2010YOUY01) +- chore: enforce clippy lint needless_pass_by_value to datafusion-physical-optimizer [#18555](https://github.com/apache/datafusion/pull/18555) (foskey51) +- chore: enforce clippy lint needless_pass_by_value for datafusion-sql [#18554](https://github.com/apache/datafusion/pull/18554) (foskey51) +- chore: enforce clippy lint needless_pass_by_value to physical-expr-common [#18556](https://github.com/apache/datafusion/pull/18556) (foskey51) +- chore: Enforce lint rule `clippy::needless_pass_by_value` to `datafusion-physical-expr` [#18557](https://github.com/apache/datafusion/pull/18557) (corasaurus-hex) +- Fix out-of-bounds access in SLT runner [#18562](https://github.com/apache/datafusion/pull/18562) (theirix) +- Make array_reverse faster for List and FixedSizeList [#18500](https://github.com/apache/datafusion/pull/18500) (vegarsti) +- Consolidate custom data source examples (#18142) [#18553](https://github.com/apache/datafusion/pull/18553) (cj-zhukov) +- chore(deps): bump taiki-e/install-action from 2.62.47 to 2.62.49 [#18581](https://github.com/apache/datafusion/pull/18581) (dependabot[bot]) +- chore: Remove unused `tokio` dependency and clippy [#18598](https://github.com/apache/datafusion/pull/18598) (comphead) +- minor: enforce `clippy::needless_pass_by_value` for crates that don't require code changes. 
[#18586](https://github.com/apache/datafusion/pull/18586) (2010YOUY01) +- refactor: merge CoalesceAsyncExecInput into CoalesceBatches [#18540](https://github.com/apache/datafusion/pull/18540) (Tim-53) +- Enhance the help message for invalid command in datafusion-cli [#18603](https://github.com/apache/datafusion/pull/18603) (klion26) +- Update Release README.md with latest process [#18549](https://github.com/apache/datafusion/pull/18549) (alamb) +- Add timezone to date_trunc fast path [#18596](https://github.com/apache/datafusion/pull/18596) (hareshkh) +- Coalesce batches inside FilterExec [#18604](https://github.com/apache/datafusion/pull/18604) (Dandandan) +- Fix misleading boolean 'null' interval tests [#18620](https://github.com/apache/datafusion/pull/18620) (pepijnve) +- Clarify tests for `Interval::and`, `Interval::not`, and add `Interval::or` tests [#18621](https://github.com/apache/datafusion/pull/18621) (pepijnve) +- bugfix: correct regression on TableType for into_view [#18617](https://github.com/apache/datafusion/pull/18617) (timsaucer) +- Separating Benchmarks for physical sorted union over large columns in SQL planner based on Datatype [#18599](https://github.com/apache/datafusion/pull/18599) (logan-keede) +- Add RunEndEncoded type coercion [#18561](https://github.com/apache/datafusion/pull/18561) (vegarsti) +- minor: refactor with `assert_or_internal_err!()` in `datafusion/sql` [#18614](https://github.com/apache/datafusion/pull/18614) (2010YOUY01) +- chore: ASF tracking process on `.asf.yaml` [#18636](https://github.com/apache/datafusion/pull/18636) (comphead) +- Refactor bit aggregate functions signature [#18593](https://github.com/apache/datafusion/pull/18593) (Jefffrey) +- chore(deps): bump taiki-e/install-action from 2.62.49 to 2.62.50 [#18645](https://github.com/apache/datafusion/pull/18645) (dependabot[bot]) +- bugfix: select_columns should validate column names [#18623](https://github.com/apache/datafusion/pull/18623) (timsaucer) +- Consolidate 
data io examples (#18142) [#18591](https://github.com/apache/datafusion/pull/18591) (cj-zhukov) +- Correct implementations of `NullableInterval::and` and `NullableInterval::or`. [#18625](https://github.com/apache/datafusion/pull/18625) (pepijnve) +- chore: ASF tracking process on `.asf.yaml` [#18652](https://github.com/apache/datafusion/pull/18652) (comphead) +- Refactor Spark bitshift signature [#18649](https://github.com/apache/datafusion/pull/18649) (Jefffrey) +- chore(deps): bump crate-ci/typos from 1.39.0 to 1.39.1 [#18667](https://github.com/apache/datafusion/pull/18667) (dependabot[bot]) +- Update docs for aggregate repartition test [#18650](https://github.com/apache/datafusion/pull/18650) (xanderbailey) +- chore: Enforce lint rule `clippy::needless_pass_by_value` to `datafusion-catalog` [#18638](https://github.com/apache/datafusion/pull/18638) (Standing-Man) +- [main] Update Changelog (#18592) [#18616](https://github.com/apache/datafusion/pull/18616) (alamb) +- Refactor distinct aggregate implementations to use common buffer [#18348](https://github.com/apache/datafusion/pull/18348) (Jefffrey) +- chore: enforce lint rule `clippy::needless_pass_by_value` to `datafusion-datasource-avro` [#18641](https://github.com/apache/datafusion/pull/18641) (Standing-Man) +- Refactor Spark expm1 signature [#18655](https://github.com/apache/datafusion/pull/18655) (Jefffrey) +- chore(core): Enforce lint rule `clippy::needless_pass_by_value` to `datafusion-core` [#18640](https://github.com/apache/datafusion/pull/18640) (Standing-Man) +- Refactor substr signature [#18653](https://github.com/apache/datafusion/pull/18653) (Jefffrey) +- minor: Use allow->expect to explicitly suppress Clippy lint checks [#18686](https://github.com/apache/datafusion/pull/18686) (2010YOUY01) +- chore(deps): bump taiki-e/install-action from 2.62.50 to 2.62.51 [#18693](https://github.com/apache/datafusion/pull/18693) (dependabot[bot]) +- chore(deps): bump crate-ci/typos from 1.39.1 to 1.39.2 
[#18694](https://github.com/apache/datafusion/pull/18694) (dependabot[bot]) +- Remove FilterExec from CoalesceBatches optimization rule, add fetch support [#18630](https://github.com/apache/datafusion/pull/18630) (Dandandan) +- minor: refactor with `assert_or_internal_err!()` in `datafusion/datasource` [#18697](https://github.com/apache/datafusion/pull/18697) (kumarUjjawal) +- chore: Enforce lint rule `clippy::needless_pass_by_value` to datafusion-datasource [#18682](https://github.com/apache/datafusion/pull/18682) (AryanBagade) +- [main] Update changelog for 51.0.0 RC2 [#18710](https://github.com/apache/datafusion/pull/18710) (alamb) +- Refactor Spark crc32/sha1 signatures [#18662](https://github.com/apache/datafusion/pull/18662) (Jefffrey) +- CI: try free up space in `Rust / cargo test (amd64)` action [#18709](https://github.com/apache/datafusion/pull/18709) (Jefffrey) +- chore: enforce clippy lint needless_pass_by_value to datafusion-proto [#18715](https://github.com/apache/datafusion/pull/18715) (foskey51) +- chore: enforce clippy lint needless_pass_by_value to datafusion-spark [#18714](https://github.com/apache/datafusion/pull/18714) (foskey51) +- minor: refactor with `assert_or_internal_err!()` in `datafusion/optimizer` [#18699](https://github.com/apache/datafusion/pull/18699) (kumarUjjawal) +- minor: refactor with `assert_or_internal_err!()` in `datafusion/functions` [#18700](https://github.com/apache/datafusion/pull/18700) (kumarUjjawal) +- minor: refactor with `assert_or_internal_err!()` in `datafusion/expr-common` [#18702](https://github.com/apache/datafusion/pull/18702) (kumarUjjawal) +- minor: refactor with `assert_or_internal_err!()` in `datafusion/functions-aggregate` [#18716](https://github.com/apache/datafusion/pull/18716) (kumarUjjawal) +- chore: enforce clippy lint needless_pass_by_value to datafusion-execution [#18723](https://github.com/apache/datafusion/pull/18723) (foskey51) +- minor: refactor with `assert_or_internal_err!()` in 
`datafusion/functions-nested` [#18724](https://github.com/apache/datafusion/pull/18724) (kumarUjjawal) +- chore: enforce clippy lint needless_pass_by_value to datafusion-substrait [#18703](https://github.com/apache/datafusion/pull/18703) (foskey51) +- chore: Refactor with assert_or_internal_err!() in datafusion/spark. [#18674](https://github.com/apache/datafusion/pull/18674) (codetyri0n) +- Minor: Add docs to release/README.md about rate limits [#18704](https://github.com/apache/datafusion/pull/18704) (alamb) +- Consolidate query planning examples (#18142) [#18690](https://github.com/apache/datafusion/pull/18690) (cj-zhukov) +- minor: refactor with `assert_or_internal_err!()` in `datafusion/physical-expr-common` [#18735](https://github.com/apache/datafusion/pull/18735) (kumarUjjawal) +- minor: refactor with `assert_or_internal_err!()` in `datafusion/physical-expr` [#18736](https://github.com/apache/datafusion/pull/18736) (kumarUjjawal) +- Consolidate ArrowFileSource and ArrowStreamFileSource [#18720](https://github.com/apache/datafusion/pull/18720) (adriangb) +- minor: refactor with `assert_or_internal_err!()` in `datafusion/physical-optimizer` [#18732](https://github.com/apache/datafusion/pull/18732) (kumarUjjawal) +- refactor: reduce duplication in make_udf_function macro [#18733](https://github.com/apache/datafusion/pull/18733) (shashidhar-bm) +- minor: refactor with `assert_or_internal_err!()` in `datafusion/physical-plan` [#18730](https://github.com/apache/datafusion/pull/18730) (kumarUjjawal) +- chore: enforce clippy lint needless_pass_by_value to datafusion-functions-aggregate-common [#18741](https://github.com/apache/datafusion/pull/18741) (foskey51) +- Optimize NullState::build [#18737](https://github.com/apache/datafusion/pull/18737) (Dandandan) +- chore: enforce clippy lint needless_pass_by_value to datafusion-datasource-parquet [#18695](https://github.com/apache/datafusion/pull/18695) (foskey51) +- minor: refactor with `assert_or_internal_err!()` in 
`datafusion/expr` [#18731](https://github.com/apache/datafusion/pull/18731) (kumarUjjawal) +- minor: Fix an example in the `PruningPredicate` documentation [#18742](https://github.com/apache/datafusion/pull/18742) (2010YOUY01) +- chore(deps): bump indicatif from 0.18.2 to 0.18.3 [#18756](https://github.com/apache/datafusion/pull/18756) (dependabot[bot]) +- Fix map_query_sql benchmark duplicate key error [#18427](https://github.com/apache/datafusion/pull/18427) (atheendre130505) +- minor: enforce lint rule clippy::needless_pass_by_value to datafusion-ffi [#18764](https://github.com/apache/datafusion/pull/18764) (Standing-Man) +- Rename boolean `Interval` constants to match `NullableInterval` [#18654](https://github.com/apache/datafusion/pull/18654) (pepijnve) +- chore(deps): bump bytes from 1.10.1 to 1.11.0 [#18755](https://github.com/apache/datafusion/pull/18755) (dependabot[bot]) +- CI: Fix `main` branch CI test failure [#18792](https://github.com/apache/datafusion/pull/18792) (2010YOUY01) +- chore: Enforce 'clippy::needless_pass_by_value' to datafusion-expr-common [#18775](https://github.com/apache/datafusion/pull/18775) (petern48) +- chore: Finish refactor with `assert_or_internal_err!()` [#18790](https://github.com/apache/datafusion/pull/18790) (2010YOUY01) +- Switch from xz2 to liblzma to reduce duplicate dependencies [#17509](https://github.com/apache/datafusion/pull/17509) (timsaucer) +- chore(deps): bump taiki-e/install-action from 2.62.51 to 2.62.53 [#18796](https://github.com/apache/datafusion/pull/18796) (dependabot[bot]) +- chore(deps): bump actions/checkout from 5.0.0 to 5.0.1 [#18797](https://github.com/apache/datafusion/pull/18797) (dependabot[bot]) +- Misc improvements to ProjectionExprs [#18719](https://github.com/apache/datafusion/pull/18719) (adriangb) +- Fix incorrect link for sql_query.rs example in README [#18807](https://github.com/apache/datafusion/pull/18807) (kondamudikarthik) +- Adds prefix filtering for table URLs 
[#18780](https://github.com/apache/datafusion/pull/18780) (BlakeOrth) +- Refactor InListExpr to support structs by re-using existing hashing infrastructure [#18449](https://github.com/apache/datafusion/pull/18449) (adriangb) +- chore: Add script to protect RC branches during the release [#18660](https://github.com/apache/datafusion/pull/18660) (comphead) +- Prevent overflow and panics when casting DATE to TIMESTAMP by validating bounds [#18761](https://github.com/apache/datafusion/pull/18761) (kosiew) +- chore(deps): bump taiki-e/install-action from 2.62.53 to 2.62.54 [#18815](https://github.com/apache/datafusion/pull/18815) (dependabot[bot]) +- CI: Enforce clippy::needless_pass_by_value rule to datafusion-functions-aggregate [#18805](https://github.com/apache/datafusion/pull/18805) (codetyri0n) +- Consolidate sql operations examples (#18142) [#18743](https://github.com/apache/datafusion/pull/18743) (cj-zhukov) +- Move `GuaranteeRewriter` to datafusion_expr [#18821](https://github.com/apache/datafusion/pull/18821) (pepijnve) +- Refactor state management in `HashJoinExec` and use CASE expressions for more precise filters [#18451](https://github.com/apache/datafusion/pull/18451) (adriangb) +- Refactor avg & sum signatures away from user defined [#18769](https://github.com/apache/datafusion/pull/18769) (Jefffrey) +- Hash UnionArrays [#18718](https://github.com/apache/datafusion/pull/18718) (friendlymatthew) +- CI: add clippy::needless_pass_by_value rule to datafusion-functions-window crate [#18838](https://github.com/apache/datafusion/pull/18838) (codetyri0n) +- Add field to DynamicPhysicalExpr to indicate when the filter is complete or updated [#18799](https://github.com/apache/datafusion/pull/18799) (LiaCastaneda) +- #17801 Improve nullability reporting of case expressions [#17813](https://github.com/apache/datafusion/pull/17813) (pepijnve) +- Consolidate execution monitoring examples (#18142) [#18846](https://github.com/apache/datafusion/pull/18846) (cj-zhukov) 
+- Implement CatalogProviderList in FFI [#18657](https://github.com/apache/datafusion/pull/18657) (timsaucer) +- Removed incorrect union check in enforce_sorting and updated tests [#18661](https://github.com/apache/datafusion/pull/18661) (gene-bordegaray) +- chore(deps): bump actions/checkout from 5.0.1 to 6.0.0 [#18865](https://github.com/apache/datafusion/pull/18865) (dependabot[bot]) +- Remove unnecessary bit counting code from spark `bit_count` [#18841](https://github.com/apache/datafusion/pull/18841) (pepijnve) +- Fix async_udf batch size behaviour [#18819](https://github.com/apache/datafusion/pull/18819) (shivbhatia10) +- Fix Partial AggregateExec correctness issue dropping rows [#18712](https://github.com/apache/datafusion/pull/18712) (xanderbailey) +- chore: Add missing boolean tests to `bit_count` Spark function [#18871](https://github.com/apache/datafusion/pull/18871) (comphead) +- Consolidate proto examples (#18142) [#18861](https://github.com/apache/datafusion/pull/18861) (cj-zhukov) +- Use logical null count in `case_when_with_expr` [#18872](https://github.com/apache/datafusion/pull/18872) (pepijnve) +- chore: enforce `clippy::needless_pass_by_value` to `datafusion-physical-plan` [#18864](https://github.com/apache/datafusion/pull/18864) (2010YOUY01) +- Refactor spark `bit_get()` signature away from user defined [#18836](https://github.com/apache/datafusion/pull/18836) (Jefffrey) +- minor: enforce lint rule clippy::needless_pass_by_value to datafusion-functions [#18768](https://github.com/apache/datafusion/pull/18768) (Standing-Man) +- chore: enforce clippy lint needless_pass_by_value to datafusion-functions-nested [#18839](https://github.com/apache/datafusion/pull/18839) (foskey51) +- chore: fix CI on main [#18876](https://github.com/apache/datafusion/pull/18876) (Jefffrey) +- chore: update Repartition DisplayAs to indicate maintained sort order [#18673](https://github.com/apache/datafusion/pull/18673) (ruchirK) +- implement sum for durations 
[#18853](https://github.com/apache/datafusion/pull/18853) (logan-keede) +- Consolidate dataframe examples (#18142) [#18862](https://github.com/apache/datafusion/pull/18862) (cj-zhukov) +- Avoid the need to rewrite expressions when evaluating logical case nullability [#18849](https://github.com/apache/datafusion/pull/18849) (pepijnve) +- Avoid skew in Roundrobin repartition [#18880](https://github.com/apache/datafusion/pull/18880) (Dandandan) +- Add benchmark for array_has/array_has_all/array_has_any [#18729](https://github.com/apache/datafusion/pull/18729) (zhuqi-lucas) +- chore(deps): bump taiki-e/install-action from 2.62.54 to 2.62.56 [#18899](https://github.com/apache/datafusion/pull/18899) (dependabot[bot]) +- chore(deps): bump indicatif from 0.18.0 to 0.18.3 [#18897](https://github.com/apache/datafusion/pull/18897) (dependabot[bot]) +- chore(deps): bump tokio-util from 0.7.16 to 0.7.17 [#18898](https://github.com/apache/datafusion/pull/18898) (dependabot[bot]) +- Support Non-Literal Expressions in Substrait VirtualTable Values and Improve Round-Trip Robustness [#18866](https://github.com/apache/datafusion/pull/18866) (kosiew) +- chore(deps): bump indexmap from 2.12.0 to 2.12.1 [#18895](https://github.com/apache/datafusion/pull/18895) (dependabot[bot]) +- chore(deps): bump aws-config from 1.8.7 to 1.8.11 [#18896](https://github.com/apache/datafusion/pull/18896) (dependabot[bot]) +- chore(deps): bump flate2 from 1.1.4 to 1.1.5 [#18900](https://github.com/apache/datafusion/pull/18900) (dependabot[bot]) +- Add iter() method to `Extensions` [#18887](https://github.com/apache/datafusion/pull/18887) (gabotechs) +- chore: Enforce `clippy::needless_pass_by_value` globally across the workspace [#18904](https://github.com/apache/datafusion/pull/18904) (2010YOUY01) +- Consolidate external dependency examples (#18142) [#18747](https://github.com/apache/datafusion/pull/18747) (cj-zhukov) +- Optimize planning for projected nested union 
[#18713](https://github.com/apache/datafusion/pull/18713) (logan-keede) +- chore(deps): bump taiki-e/install-action from 2.62.56 to 2.62.57 [#18927](https://github.com/apache/datafusion/pull/18927) (dependabot[bot]) +- chore(deps): bump actions/setup-python from 6.0.0 to 6.1.0 [#18925](https://github.com/apache/datafusion/pull/18925) (dependabot[bot]) +- Fix `map` function alias handling in SQL planner [#18914](https://github.com/apache/datafusion/pull/18914) (friendlymatthew) +- minor: add builder setting `NdJsonReadOptions::schema_infer_max_records` [#18920](https://github.com/apache/datafusion/pull/18920) (Jefffrey) +- Implement Substrait Support for `GROUPING SET CUBE` [#18798](https://github.com/apache/datafusion/pull/18798) (kosiew) +- chore: unify common dependencies as workspace dependencies [#18665](https://github.com/apache/datafusion/pull/18665) (Jefffrey) +- Fix bug where binary types were incorrectly being casted for coercible signatures [#18750](https://github.com/apache/datafusion/pull/18750) (Jefffrey) +- Refactor approx_median signature & support f16 [#18647](https://github.com/apache/datafusion/pull/18647) (Jefffrey) +- Refactor `to_local_time()` signature away from user_defined [#18707](https://github.com/apache/datafusion/pull/18707) (Jefffrey) +- chore(deps-dev): bump node-forge from 1.3.1 to 1.3.2 in /datafusion/wasmtest/datafusion-wasm-app [#18958](https://github.com/apache/datafusion/pull/18958) (dependabot[bot]) +- Support LikeMatch, ILikeMatch, NotLikeMatch, NotILikeMatch operators in protobuf serialization [#18961](https://github.com/apache/datafusion/pull/18961) (zhuqi-lucas) +- chore: cargo fmt to fix CI [#18969](https://github.com/apache/datafusion/pull/18969) (Jefffrey) +- chore(deps): bump Swatinem/rust-cache from 2.8.1 to 2.8.2 [#18963](https://github.com/apache/datafusion/pull/18963) (dependabot[bot]) +- chore(deps): bump taiki-e/install-action from 2.62.57 to 2.62.58 [#18964](https://github.com/apache/datafusion/pull/18964) 
(dependabot[bot]) +- chore(deps): bump crate-ci/typos from 1.39.2 to 1.40.0 [#18965](https://github.com/apache/datafusion/pull/18965) (dependabot[bot]) +- [Minor] Refactor `traverse_chain` macro to function [#18951](https://github.com/apache/datafusion/pull/18951) (Dandandan) +- Enable clippy::allow_attributes lint for datafusion-catalog [#18973](https://github.com/apache/datafusion/pull/18973) (chakkk309) +- chore: update group of crates to rust 2024 edition [#18915](https://github.com/apache/datafusion/pull/18915) (timsaucer) +- chore(deps): bump taiki-e/install-action from 2.62.58 to 2.62.59 [#18978](https://github.com/apache/datafusion/pull/18978) (dependabot[bot]) +- Simplify percentile_cont for 0/1 percentiles [#18837](https://github.com/apache/datafusion/pull/18837) (kumarUjjawal) +- chore: enforce clippy::allow_attributes for functions-\* crates [#18986](https://github.com/apache/datafusion/pull/18986) (carlosahs) +- chore: enforce clippy::allow_attributes for common crates [#18988](https://github.com/apache/datafusion/pull/18988) (chakkk309) +- Fix predicate_rows_pruned & predicate_rows_matched metrics [#18980](https://github.com/apache/datafusion/pull/18980) (xudong963) +- Allocate a buffer of the correct length for ScalarValue::FixedSizeBinary in ScalarValue::to_array_of_size [#18903](https://github.com/apache/datafusion/pull/18903) (tobixdev) +- Fix error planning aggregates with duplicated names in select list [#18831](https://github.com/apache/datafusion/pull/18831) (tshauck) +- chore: remove `deny`s of `needless_pass_by_value` in `lib.rs` files [#18996](https://github.com/apache/datafusion/pull/18996) (Jefffrey) +- Add Explicit Error Handling for Unsupported SQL `FETCH` Clause in Planner and CLI [#18691](https://github.com/apache/datafusion/pull/18691) (kosiew) +- chore(deps): bump criterion from 0.7.0 to 0.8.0 [#19009](https://github.com/apache/datafusion/pull/19009) (dependabot[bot]) +- chore(deps): bump syn from 2.0.108 to 2.0.111 
[#19011](https://github.com/apache/datafusion/pull/19011) (dependabot[bot]) +- chore(deps): bump taiki-e/install-action from 2.62.59 to 2.62.60 [#19012](https://github.com/apache/datafusion/pull/19012) (dependabot[bot]) +- chore: remove redundant clone code [#18997](https://github.com/apache/datafusion/pull/18997) (Smith-Cruise) +- Update to `arrow`, `parquet` to `57.1.0` [#18820](https://github.com/apache/datafusion/pull/18820) (alamb) +- deny on allow_attributes lint in physical-plan [#18983](https://github.com/apache/datafusion/pull/18983) (YuraLitvinov) +- Add additional test coverage of multi-value PartitionPruningStats [#19021](https://github.com/apache/datafusion/pull/19021) (alamb) +- Fix tpch benchmark harness [#19033](https://github.com/apache/datafusion/pull/19033) (alamb) +- Fix data for tpch_csv and tpch_csv10 [#19034](https://github.com/apache/datafusion/pull/19034) (alamb) +- chore: update group of 3 crates to rust 2024 edition [#19001](https://github.com/apache/datafusion/pull/19001) (timsaucer) +- chore(deps-dev): bump express from 4.21.2 to 4.22.1 in /datafusion/wasmtest/datafusion-wasm-app [#19040](https://github.com/apache/datafusion/pull/19040) (dependabot[bot]) +- Allow repartitioning on files with ranges [#18948](https://github.com/apache/datafusion/pull/18948) (Samyak2) +- Support simplify not for physical expr [#18970](https://github.com/apache/datafusion/pull/18970) (xudong963) +- dev: Add typos check to the local `dev/rust_lint.sh` [#17863](https://github.com/apache/datafusion/pull/17863) (2010YOUY01) +- Implement FFI_PhysicalExpr and the structs it needs to support it. 
[#18916](https://github.com/apache/datafusion/pull/18916) (timsaucer) +- chore(deps): bump actions/setup-node from 6.0.0 to 6.1.0 [#19063](https://github.com/apache/datafusion/pull/19063) (dependabot[bot]) +- chore(deps): bump taiki-e/install-action from 2.62.60 to 2.62.61 [#19062](https://github.com/apache/datafusion/pull/19062) (dependabot[bot]) +- chore(deps): bump actions/stale from 10.1.0 to 10.1.1 [#19061](https://github.com/apache/datafusion/pull/19061) (dependabot[bot]) +- chore: merge make_array and spark array [#19006](https://github.com/apache/datafusion/pull/19006) (jizezhang) +- chore(deps): bump actions/checkout from 6.0.0 to 6.0.1 [#19060](https://github.com/apache/datafusion/pull/19060) (dependabot[bot]) +- Add documentation example for `PartitionPruningStatistics` [#19020](https://github.com/apache/datafusion/pull/19020) (alamb) +- chore: upgrade expr and execution crates to rust 2024 edition [#19047](https://github.com/apache/datafusion/pull/19047) (timsaucer) +- refactor: Refactor spark make_interval signature away from user defined [#19027](https://github.com/apache/datafusion/pull/19027) (kumarUjjawal) +- Fix: Align sort_merge_join filter output with join schema to fix right-anti panic [#18800](https://github.com/apache/datafusion/pull/18800) (kumarUjjawal) +- Support Substrait Round-Trip of `EmptyRelation` Including `produce_one_row` Semantics [#18842](https://github.com/apache/datafusion/pull/18842) (kosiew) +- chore(deps): bump taiki-e/install-action from 2.62.61 to 2.62.62 [#19081](https://github.com/apache/datafusion/pull/19081) (dependabot[bot]) +- chore: enforce clippy::allow_attributes for datasource crates [#19068](https://github.com/apache/datafusion/pull/19068) (chakkk309) +- common: Add hashing support for REE arrays [#18981](https://github.com/apache/datafusion/pull/18981) (brancz) +- Use `tpchgen-cli` to generate tpch data in bench.sh [#19035](https://github.com/apache/datafusion/pull/19035) (alamb) +- Update aggregate probe to be 
locked only if skipping aggregation [#18766](https://github.com/apache/datafusion/pull/18766) (hareshkh) +- Fix function doc CI check [#19093](https://github.com/apache/datafusion/pull/19093) (alamb) +- Fix Schema Duplication Errors in Self‑Referential INTERSECT/EXCEPT by Requalifying Input Sides [#18814](https://github.com/apache/datafusion/pull/18814) (kosiew) +- run cargo fmt to fix after #18998 [#19102](https://github.com/apache/datafusion/pull/19102) (adriangb) +- bench: set test_util as required feature for aggregate_vectorized [#19101](https://github.com/apache/datafusion/pull/19101) (rluvaton) +- use ProjectionExprs:project_statistics in FileScanConfig [#19094](https://github.com/apache/datafusion/pull/19094) (adriangb) +- Temporarily ignore test_cache_with_ttl_and_lru test [#19115](https://github.com/apache/datafusion/pull/19115) (alamb) +- refactor: move human readable display utilities to `datafusion-common` crate [#19080](https://github.com/apache/datafusion/pull/19080) (2010YOUY01) +- Always remove unnecessary software from github runners for all jobs (fix intermittent out of space on runners) [#19122](https://github.com/apache/datafusion/pull/19122) (alamb) +- [datafusion-spark]: Refactor make_dt_interval's signature away from user defined [#19083](https://github.com/apache/datafusion/pull/19083) (codetyri0n) +- fix deprecation notes with incorrect versions from #13083 [#19135](https://github.com/apache/datafusion/pull/19135) (adriangb) +- Run the examples in the new format [#18946](https://github.com/apache/datafusion/pull/18946) (cj-zhukov) +- Add constant expression evaluator to physical expression simplifier [#19130](https://github.com/apache/datafusion/pull/19130) (adriangb) +- Fix shuffle function to report nullability correctly [#19184](https://github.com/apache/datafusion/pull/19184) (harshitsaini17) +- chore: enforce clippy::allow_attributes for physical crates [#19185](https://github.com/apache/datafusion/pull/19185) (carlosahs) +- Update 5 
crates to rust 2024 edition [#19091](https://github.com/apache/datafusion/pull/19091) (timsaucer) +- Coalesce batches inside hash join, reuse indices buffer [#18972](https://github.com/apache/datafusion/pull/18972) (Dandandan) +- slt test coverage for `CASE` exprs with constant value lookup tables [#19143](https://github.com/apache/datafusion/pull/19143) (alamb) +- Fix fmt after logical conflict [#19208](https://github.com/apache/datafusion/pull/19208) (alamb) +- chore: Add TPCDS benchmarks [#19138](https://github.com/apache/datafusion/pull/19138) (comphead) +- Arc partition values in TableSchema [#19137](https://github.com/apache/datafusion/pull/19137) (adriangb) +- Add sorted data benchmark. [#19042](https://github.com/apache/datafusion/pull/19042) (zhuqi-lucas) +- Refactor PhysicalExprSimplfier to &self instead of &mut self [#19212](https://github.com/apache/datafusion/pull/19212) (adriangb) +- chore(deps): bump uuid from 1.18.1 to 1.19.0 [#19199](https://github.com/apache/datafusion/pull/19199) (dependabot[bot]) +- chore(deps): bump async-compression from 0.4.34 to 0.4.35 [#19201](https://github.com/apache/datafusion/pull/19201) (dependabot[bot]) +- chore(deps): bump taiki-e/install-action from 2.62.62 to 2.62.63 [#19198](https://github.com/apache/datafusion/pull/19198) (dependabot[bot]) +- chore(deps): bump tracing-subscriber from 0.3.20 to 0.3.22 [#19200](https://github.com/apache/datafusion/pull/19200) (dependabot[bot]) +- chore(deps): bump wasm-bindgen-test from 0.3.55 to 0.3.56 [#19202](https://github.com/apache/datafusion/pull/19202) (dependabot[bot]) +- bench: add dedicated Utf8View benchmarks for InList [#19211](https://github.com/apache/datafusion/pull/19211) (geoffreyclaude) +- Fix PruningPredicate interaction with DynamicFilterPhysicalExpr that references partition columns [#19129](https://github.com/apache/datafusion/pull/19129) (adriangb) +- Implement physical and logical codecs in FFI [#19079](https://github.com/apache/datafusion/pull/19079) 
(timsaucer) +- refactor: Refactor spark width bucket signature away from user defined [#19065](https://github.com/apache/datafusion/pull/19065) (kumarUjjawal) +- Sort Merge Join: Reduce batch concatenation, use `BatchCoalescer`, new benchmarks (TPC-H Q21 SMJ up to ~4000x faster) [#18875](https://github.com/apache/datafusion/pull/18875) (mbutrovich) +- Add relation planner extension support to customize SQL planning [#17843](https://github.com/apache/datafusion/pull/17843) (geoffreyclaude) +- Add additional tests for InListExpr [#19050](https://github.com/apache/datafusion/pull/19050) (adriangb) +- chore(deps): bump taiki-e/install-action from 2.62.63 to 2.62.64 [#19226](https://github.com/apache/datafusion/pull/19226) (dependabot[bot]) +- Use strum in the examples (#19126) [#19205](https://github.com/apache/datafusion/pull/19205) (cj-zhukov) +- [Proto]: Serialization support for `AsyncFuncExec` [#19118](https://github.com/apache/datafusion/pull/19118) (mach-kernel) +- chore: add test case for decimal overflow [#19255](https://github.com/apache/datafusion/pull/19255) (Jefffrey) +- chore(deps): bump taiki-e/install-action from 2.62.64 to 2.62.65 [#19251](https://github.com/apache/datafusion/pull/19251) (dependabot[bot]) +- chore: update 6 crates to rust edition 2024 [#19196](https://github.com/apache/datafusion/pull/19196) (timsaucer) +- Implement FFI_Session [#19223](https://github.com/apache/datafusion/pull/19223) (timsaucer) +- Feat: Add an option for fast tests by gating slow tests to extended_tests feature [#19237](https://github.com/apache/datafusion/pull/19237) (Yuvraj-cyborg) +- chore: enforce clippy::allow_attributes for 7 crates [#19133](https://github.com/apache/datafusion/pull/19133) (chakkk309) +- dev: Add CI doc prettier check to local `rust_lint.sh` [#19254](https://github.com/apache/datafusion/pull/19254) (2010YOUY01) +- bug: Eliminate dead round-robin insertion in enforce distribution [#19132](https://github.com/apache/datafusion/pull/19132) 
(gene-bordegaray) +- Automatically download tpcds benchmark data to the right place [#19244](https://github.com/apache/datafusion/pull/19244) (alamb) +- [datafusion-spark]: Refactor hex's signature away from user_defined [#19235](https://github.com/apache/datafusion/pull/19235) (codetyri0n) +- fix : correct nullability propagation for spark.bitwise_not [#19224](https://github.com/apache/datafusion/pull/19224) (shifluxxc) +- added custom nullability for char [#19268](https://github.com/apache/datafusion/pull/19268) (skushagra) +- replace HashTableLookupExpr with lit(true) in proto serialization [#19300](https://github.com/apache/datafusion/pull/19300) (adriangb) +- chore: fix return_field_from_args doc [#19307](https://github.com/apache/datafusion/pull/19307) (xumingming) +- chore: enforce clippy::allow_attributes for spark,sql,sustrait [#19309](https://github.com/apache/datafusion/pull/19309) (kumarUjjawal) +- Simplify make_date & fix null handling [#19296](https://github.com/apache/datafusion/pull/19296) (Jefffrey) +- Allow base64 encoding of fixedsizebinary arrays [#18950](https://github.com/apache/datafusion/pull/18950) (maxburke) +- chore: update 11 crates to Rust 2024 edition [#19258](https://github.com/apache/datafusion/pull/19258) (timsaucer) +- Minor: remove unnecessary unit tests for fixed size binary [#19318](https://github.com/apache/datafusion/pull/19318) (alamb) +- Populate partition column statistics for PartitionedFile [#19284](https://github.com/apache/datafusion/pull/19284) (adriangb) +- refactor: move metrics module to `datafusion-common` crate [#19247](https://github.com/apache/datafusion/pull/19247) (2010YOUY01) +- chore(deps): bump taiki-e/install-action from 2.62.65 to 2.62.67 [#19295](https://github.com/apache/datafusion/pull/19295) (dependabot[bot]) +- chore(deps): bump ctor from 0.6.1 to 0.6.3 [#19328](https://github.com/apache/datafusion/pull/19328) (dependabot[bot]) +- Refactor `power()` signature away from user defined 
[#18968](https://github.com/apache/datafusion/pull/18968) (Jefffrey) +- chore: enforce `clippy::allow_attributes` for optimizer and macros [#19310](https://github.com/apache/datafusion/pull/19310) (kumarUjjawal) +- chore(deps): bump taiki-e/install-action from 2.62.67 to 2.63.3 [#19349](https://github.com/apache/datafusion/pull/19349) (dependabot[bot]) +- chore(deps): bump clap from 4.5.50 to 4.5.53 [#19326](https://github.com/apache/datafusion/pull/19326) (dependabot[bot]) +- chore(deps): bump insta from 1.43.2 to 1.44.3 [#19327](https://github.com/apache/datafusion/pull/19327) (dependabot[bot]) +- remove repartition exec from coalesce batches optimizer [#19239](https://github.com/apache/datafusion/pull/19239) (jizezhang) +- minor: cleanup unnecessary config in `decimal.slt` [#19352](https://github.com/apache/datafusion/pull/19352) (Jefffrey) +- Fix panic for `GROUPING SETS(())` and handle empty-grouping aggregates [#19252](https://github.com/apache/datafusion/pull/19252) (kosiew) +- Update datafusion-core crate to Rust 2024 edition [#19332](https://github.com/apache/datafusion/pull/19332) (timsaucer) +- Update 4 crates to rust 2024 edition [#19357](https://github.com/apache/datafusion/pull/19357) (timsaucer) +- preserve Field metadata in first_value/last_value [#19335](https://github.com/apache/datafusion/pull/19335) (adriangb) +- Fix flaky SpillPool channel test by synchronizing reader and writer tasks [#19110](https://github.com/apache/datafusion/pull/19110) (kosiew) +- [minor] Upgrade rust version [#19363](https://github.com/apache/datafusion/pull/19363) (Dandandan) +- Minor: fix cargo fmt [#19368](https://github.com/apache/datafusion/pull/19368) (zhuqi-lucas) +- chore: enforce clippy::allow_attributes for proto, pruning, session [#19350](https://github.com/apache/datafusion/pull/19350) (kumarUjjawal) +- Update remaining crates to rust 2024 edition [#19361](https://github.com/apache/datafusion/pull/19361) (timsaucer) +- Minor: Make `ProjectionExpr::new` easier 
to use with constants [#19343](https://github.com/apache/datafusion/pull/19343) (alamb) +- Feat: DefaultListFilesCache prefix-aware for partition pruning optimization [#19298](https://github.com/apache/datafusion/pull/19298) (Yuvraj-cyborg) +- Extend in_list benchmark coverage [#19376](https://github.com/apache/datafusion/pull/19376) (geoffreyclaude) +- [datafusion-cli] Implement average LIST duration for object store profiling [#19127](https://github.com/apache/datafusion/pull/19127) (peterxcli) +- chore(deps): bump taiki-e/install-action from 2.63.3 to 2.64.0 [#19382](https://github.com/apache/datafusion/pull/19382) (dependabot[bot]) +- update insta snapshots [#19381](https://github.com/apache/datafusion/pull/19381) (kosiew) +- Fix regression for negative-scale decimal128 in log [#19315](https://github.com/apache/datafusion/pull/19315) (shifluxxc) +- Fix input handling for encoding functions & various refactors [#18754](https://github.com/apache/datafusion/pull/18754) (Jefffrey) +- Fix ORDER BY positional reference regression with aliased aggregates [#19412](https://github.com/apache/datafusion/pull/19412) (adriangb) +- Implement disk spilling for all grouping ordering modes in GroupedHashAggregateStream [#19287](https://github.com/apache/datafusion/pull/19287) (pepijnve) +- refactor: add ParquetOpenerBuilder to reduce test code duplication [#19405](https://github.com/apache/datafusion/pull/19405) (shashidhar-bm) +- bench: add `range_and_generate_series` [#19428](https://github.com/apache/datafusion/pull/19428) (rluvaton) +- chore: use extend instead of manual loop in multi group by [#19429](https://github.com/apache/datafusion/pull/19429) (rluvaton) +- chore(deps): bump taiki-e/install-action from 2.64.0 to 2.64.2 [#19399](https://github.com/apache/datafusion/pull/19399) (dependabot[bot]) +- Add recursive protection on planner's `create_physical_expr` [#19299](https://github.com/apache/datafusion/pull/19299) (rgehan) +- chore(deps): bump aws-config from 1.8.11 
to 1.8.12 [#19453](https://github.com/apache/datafusion/pull/19453) (dependabot[bot]) +- chore(deps): bump log from 0.4.28 to 0.4.29 [#19452](https://github.com/apache/datafusion/pull/19452) (dependabot[bot]) +- chore(deps): bump taiki-e/install-action from 2.64.2 to 2.65.1 [#19451](https://github.com/apache/datafusion/pull/19451) (dependabot[bot]) +- chore(deps): bump insta from 1.44.3 to 1.45.0 [#19454](https://github.com/apache/datafusion/pull/19454) (dependabot[bot]) +- added support for negative scale for log decimal32/64 and power [#19409](https://github.com/apache/datafusion/pull/19409) (shifluxxc) +- Remove core dependency from ffi [#19422](https://github.com/apache/datafusion/pull/19422) (timsaucer) +- bench: increase in_list benchmark coverage [#19443](https://github.com/apache/datafusion/pull/19443) (geoffreyclaude) +- Use SortMergeJoinExec name consistently in physical plan outputs [#19246](https://github.com/apache/datafusion/pull/19246) (xavlee) +- Fix panic during spill to disk in clickbench query [#19421](https://github.com/apache/datafusion/pull/19421) (alamb) +- Optimize memory footprint of view arrays from `ScalarValue::to_array_of_size` [#19441](https://github.com/apache/datafusion/pull/19441) (Jefffrey) +- minor: refactoring of some `ScalarValue` code [#19439](https://github.com/apache/datafusion/pull/19439) (Jefffrey) +- Refactor Spark crc32 & sha1 to remove unnecessary scalar argument check [#19466](https://github.com/apache/datafusion/pull/19466) (Jefffrey) +- Add link to arrow-rs ticket in comments [#19479](https://github.com/apache/datafusion/pull/19479) (alamb) +- chore(deps): bump taiki-e/install-action from 2.65.1 to 2.65.2 [#19474](https://github.com/apache/datafusion/pull/19474) (dependabot[bot]) +- Improve plan_to_sql handling of empty projections with dialect-specific SELECT list support [#19221](https://github.com/apache/datafusion/pull/19221) (kosiew) +- examples: replace sql_dialect with custom_sql_parser example 
[#19383](https://github.com/apache/datafusion/pull/19383) (geoffreyclaude) +- Replace custom merge operator with arrow-rs implementation [#19424](https://github.com/apache/datafusion/pull/19424) (pepijnve) +- Implement nested recursive CTEs [#18956](https://github.com/apache/datafusion/pull/18956) (Tpt) +- Add: PI upper/lower bound f16 constants to ScalarValue [#19497](https://github.com/apache/datafusion/pull/19497) (xonx4l) +- chore: enforce clippy::allow_attributes for datafusion-ffi crate [#19480](https://github.com/apache/datafusion/pull/19480) (chakkk309) +- Add CI check to ensure examples are documented in README [#19371](https://github.com/apache/datafusion/pull/19371) (cj-zhukov) +- fix : snapshot to the modern multiline format [#19517](https://github.com/apache/datafusion/pull/19517) (Nachiket-Roy) +- chore(deps): bump taiki-e/install-action from 2.65.2 to 2.65.3 [#19499](https://github.com/apache/datafusion/pull/19499) (dependabot[bot]) +- docs : clarify unused test utility [#19508](https://github.com/apache/datafusion/pull/19508) (Nachiket-Roy) +- Date / time / interval arithmetic improvements [#19460](https://github.com/apache/datafusion/pull/19460) (Omega359) +- Preserve ORDER BY in Unparser for projection -> order by pattern [#19483](https://github.com/apache/datafusion/pull/19483) (adriangb) +- Redesign the try_reverse_output to support more cases [#19446](https://github.com/apache/datafusion/pull/19446) (zhuqi-lucas) +- refactor: Spark `ascii` signature away from `user_defined` [#19513](https://github.com/apache/datafusion/pull/19513) (kumarUjjawal) +- Fix: SparkAscii nullability to depend on input nullability [#19531](https://github.com/apache/datafusion/pull/19531) (Yuvraj-cyborg) +- chore(deps): bump tracing from 0.1.41 to 0.1.43 [#19543](https://github.com/apache/datafusion/pull/19543) (dependabot[bot]) +- chore(deps): bump substrait from 0.62.0 to 0.62.2 [#19542](https://github.com/apache/datafusion/pull/19542) (dependabot[bot]) +- 
chore(deps): bump taiki-e/install-action from 2.65.3 to 2.65.6 [#19541](https://github.com/apache/datafusion/pull/19541) (dependabot[bot]) +- minor: run all examples by default [#19506](https://github.com/apache/datafusion/pull/19506) (theirix) +- Refactor TopKHashTable to use HashTable API [#19464](https://github.com/apache/datafusion/pull/19464) (Dandandan) +- Revert Spark Elt nullability change [#19510](https://github.com/apache/datafusion/pull/19510) (Jefffrey) +- minor: implement more arms for `get_data_types()` for `NativeType` [#19449](https://github.com/apache/datafusion/pull/19449) (Jefffrey) +- Upgrade hashbrown to 0.16 [#19554](https://github.com/apache/datafusion/pull/19554) (Dandandan) +- minor : add crypto function benchmark [#19539](https://github.com/apache/datafusion/pull/19539) (getChan) +- chore(deps): bump taiki-e/install-action from 2.65.6 to 2.65.8 [#19559](https://github.com/apache/datafusion/pull/19559) (dependabot[bot]) +- bugfix: preserve schema metadata for record batch in FFI [#19293](https://github.com/apache/datafusion/pull/19293) (timsaucer) +- refactor: extract the data generate out of aggregate_topk benchmark [#19523](https://github.com/apache/datafusion/pull/19523) (haohuaijin) +- Compute Dynamic Filters only when a consumer supports them [#19546](https://github.com/apache/datafusion/pull/19546) (LiaCastaneda) +- Various refactors to string functions [#19402](https://github.com/apache/datafusion/pull/19402) (Jefffrey) +- Implement `partition_statistics` API for `NestedLoopJoinExec` [#19468](https://github.com/apache/datafusion/pull/19468) (kumarUjjawal) +- Replace deprecated structopt with clap in datafusion-benchmarks [#19492](https://github.com/apache/datafusion/pull/19492) (Yuvraj-cyborg) +- Refactor duplicate code in `type_coercion/functions.rs` [#19518](https://github.com/apache/datafusion/pull/19518) (Jefffrey) +- chore(deps): bump taiki-e/install-action from 2.65.8 to 2.65.10 
[#19578](https://github.com/apache/datafusion/pull/19578) (dependabot[bot]) +- perf: Improve performance of hex encoding in spark functions [#19586](https://github.com/apache/datafusion/pull/19586) (shashidhar-bm) +- Add left function benchmark [#19600](https://github.com/apache/datafusion/pull/19600) (viirya) +- chore: Add TPCDS benchmark comparison for PR [#19552](https://github.com/apache/datafusion/pull/19552) (comphead) +- chore(deps): bump taiki-e/install-action from 2.65.10 to 2.65.11 [#19601](https://github.com/apache/datafusion/pull/19601) (dependabot[bot]) +- chore: bump testcontainers-modules to 0.14 and remove testcontainers dep [#19620](https://github.com/apache/datafusion/pull/19620) (Jefffrey) +- Validate parquet writer version [#19515](https://github.com/apache/datafusion/pull/19515) (AlyAbdelmoneim) +- chore(deps): bump insta from 1.45.0 to 1.46.0 [#19643](https://github.com/apache/datafusion/pull/19643) (dependabot[bot]) +- chore(deps): bump taiki-e/install-action from 2.65.11 to 2.65.13 [#19646](https://github.com/apache/datafusion/pull/19646) (dependabot[bot]) +- chore(deps): bump tracing from 0.1.43 to 0.1.44 [#19644](https://github.com/apache/datafusion/pull/19644) (dependabot[bot]) +- chore(deps): bump syn from 2.0.111 to 2.0.113 [#19645](https://github.com/apache/datafusion/pull/19645) (dependabot[bot]) +- Refactor `percentile_cont` to clarify support input types [#19611](https://github.com/apache/datafusion/pull/19611) (Jefffrey) +- Add a protection to release candidate branch 52 [#19660](https://github.com/apache/datafusion/pull/19660) (xudong963) +- Downgrade aws-smithy-runtime, update `rust_decimal`, ignore RUSTSEC-2026-0001 to get clean CI [#19657](https://github.com/apache/datafusion/pull/19657) (alamb) +- Update dependencies [#19667](https://github.com/apache/datafusion/pull/19667) (alamb) +- Refactor PartitionedFile: add ordering field and new_from_meta constructor [#19596](https://github.com/apache/datafusion/pull/19596) (adriangb) 
+- Remove coalesce batches rule and deprecate CoalesceBatchesExec [#19622](https://github.com/apache/datafusion/pull/19622) (feniljain) +- Perf: Optimize `substring_index` via single-byte fast path and direct indexing [#19590](https://github.com/apache/datafusion/pull/19590) (lyne7-sc) +- refactor: Use `Signature::coercible` for isnan/iszero [#19604](https://github.com/apache/datafusion/pull/19604) (kumarUjjawal) +- Parquet: Push down supported list predicates (array_has/any/all) during decoding [#19545](https://github.com/apache/datafusion/pull/19545) (kosiew) +- Remove dependency on `rust_decimal`, remove ignore of `RUSTSEC-2026-0001` [#19666](https://github.com/apache/datafusion/pull/19666) (alamb) +- Store example data directly inside the datafusion-examples (#19141) [#19319](https://github.com/apache/datafusion/pull/19319) (cj-zhukov) +- minor: More comments to `ParquetOpener::open()` [#19677](https://github.com/apache/datafusion/pull/19677) (2010YOUY01) +- Feat: Allow pow with negative & non-integer exponent on decimals [#19369](https://github.com/apache/datafusion/pull/19369) (Yuvraj-cyborg) +- chore(deps): bump taiki-e/install-action from 2.65.13 to 2.65.15 [#19676](https://github.com/apache/datafusion/pull/19676) (dependabot[bot]) +- Refactor cache APIs to support ordering information [#19597](https://github.com/apache/datafusion/pull/19597) (adriangb) +- Record sort order when writing Parquet with WITH ORDER [#19595](https://github.com/apache/datafusion/pull/19595) (adriangb) +- implement var distinct [#19706](https://github.com/apache/datafusion/pull/19706) (thinh2) + +## Credits + +Thank you to everyone who contributed to this release. Here is a breakdown of commits (PRs merged) per contributor. 
+ +``` + 67 dependabot[bot] + 38 Andrew Lamb + 36 Jeffrey Vo + 35 Kumar Ujjawal + 34 Adrian Garcia Badaracco + 22 Tim Saucer + 19 Yongting You + 13 Sergey Zhukov + 11 Pepijn Van Eeckhoudt + 11 kosiew + 10 Daniël Heres + 10 Dhanush + 10 Oleks V + 8 Geoffrey Claude + 8 Raz Luvaton + 7 Andy Grove + 7 Liang-Chi Hsieh + 7 Qi Zhu + 6 Peter Nguyen + 6 Shashidhar B M + 5 Alan Tang + 5 Alex Huang + 5 Bruce Ritchie + 5 Gene Bordegaray + 5 Nuno Faria + 5 Sriram Sundar + 4 Blake Orth + 4 Thomas Tanon + 4 Yuvraj + 4 theirix + 3 Aryan Bagade + 3 Chakkk + 3 Emily Matheys + 3 Huaijin + 3 Khanh Duong + 3 Kushagra S + 3 Vedic Chawla + 3 feniljain + 3 harshit saini + 3 jizezhang + 3 shifluxxc + 3 xonx + 3 xudong.w + 2 Carlos Hurtado + 2 Chen Chongchen + 2 Cora Sutton + 2 Haresh Khanna + 2 Lía Adriana + 2 Manish Kumar + 2 Martin Grigorov + 2 Matthew Kim + 2 Namgung Chan + 2 Nimalan + 2 Nithurshen + 2 Rosai + 2 Shubham Yadav + 2 Trent Hauck + 2 Vegard Stikbakke + 2 Vrishabh + 2 Xander + 2 chakkk309 + 2 mag1c1an1 + 2 nlimpid + 2 yqrz + 1 Adam Curtis + 1 Aly Abdelmoneim + 1 Andrey Velichkevich + 1 Arpit Bandejiya + 1 Bharathwaj G + 1 Bipul Lamsal + 1 Clement de Groc + 1 Congxian Qiu + 1 David López + 1 David Stancu + 1 Devanshu + 1 Dongpo Liu + 1 EeshanBembi + 1 Eshaan Gupta + 1 Ethan Urbanski + 1 Frederic Branczyk + 1 Gabriel + 1 Gohlub + 1 Heran Lin + 1 James Xu + 1 Jatin Kumar singh + 1 Karan Pradhan + 1 Karthik Kondamudi + 1 Kazantsev Maksim + 1 Marco Neumann + 1 Matt Butrovich + 1 Max Burke + 1 Michele Vigilante + 1 Mikhail Zabaluev + 1 Mohit rao + 1 Ning Sun + 1 Peter Lee + 1 Quoc Anh + 1 Ram + 1 Randy + 1 Renan GEHAN + 1 Ruchir Khaitan + 1 Samyak Sarnayak + 1 Shiv Bhatia + 1 Smith Cruise + 1 Smotrov Oleksii + 1 Solari Systems + 1 Suhail + 1 T2MIX + 1 Tal Glanzman + 1 Tamar + 1 Tim-53 + 1 Tobias Schwarzinger + 1 Ujjwal Kumar Tiwari + 1 Willem Verstraeten + 1 YuraLitvinov + 1 bubulalabu + 1 delamarch3 + 1 hsiang-c + 1 r1b + 1 rin + 1 xavlee +``` + +Thank you also to everyone who 
contributed in other ways such as filing issues, reviewing PRs, and providing feedback on this release. diff --git a/dev/changelog/52.1.0.md b/dev/changelog/52.1.0.md new file mode 100644 index 000000000000..97a1435c41a4 --- /dev/null +++ b/dev/changelog/52.1.0.md @@ -0,0 +1,46 @@ + + +# Apache DataFusion 52.1.0 Changelog + +This release consists of 3 commits from 3 contributors. See credits at the end of this changelog for more information. + +See the [upgrade guide](https://datafusion.apache.org/library-user-guide/upgrading.html) for information on how to upgrade from previous versions. + +**Documentation updates:** + +- [branch-52] Fix Internal error: Assertion failed: !self.finished: LimitedBatchCoalescer (#19785) [#19836](https://github.com/apache/datafusion/pull/19836) (alamb) + +**Other:** + +- [branch-52] fix: expose `ListFilesEntry` [#19818](https://github.com/apache/datafusion/pull/19818) (lonless9) +- [branch 52] Fix grouping set subset satisfaction [#19855](https://github.com/apache/datafusion/pull/19855) (gabotechs) +- Add BatchAdapter to simplify using PhysicalExprAdapter / Projector [#19877](https://github.com/apache/datafusion/pull/19877) (alamb) + +## Credits + +Thank you to everyone who contributed to this release. Here is a breakdown of commits (PRs merged) per contributor. + +``` + 1 Andrew Lamb + 1 Gabriel + 1 XL Liang +``` + +Thank you also to everyone who contributed in other ways such as filing issues, reviewing PRs, and providing feedback on this release. diff --git a/dev/changelog/53.0.0.md b/dev/changelog/53.0.0.md new file mode 100644 index 000000000000..11820f3caad7 --- /dev/null +++ b/dev/changelog/53.0.0.md @@ -0,0 +1,640 @@ + + +# Apache DataFusion 53.0.0 Changelog + +This release consists of 475 commits from 114 contributors. See credits at the end of this changelog for more information. 
+ +See the [upgrade guide](https://datafusion.apache.org/library-user-guide/upgrading.html) for information on how to upgrade from previous versions. + +**Breaking changes:** + +- Allow logical optimizer to be run without evaluating now() & refactor SimplifyInfo [#19505](https://github.com/apache/datafusion/pull/19505) (adriangb) +- Make default ListingFilesCache table scoped [#19616](https://github.com/apache/datafusion/pull/19616) (jizezhang) +- chore(deps): Update sqlparser to 0.60 [#19672](https://github.com/apache/datafusion/pull/19672) (Standing-Man) +- Do not require mut in memory reservation methods [#19759](https://github.com/apache/datafusion/pull/19759) (gabotechs) +- refactor: make PhysicalExprAdapterFactory::create fallible [#20017](https://github.com/apache/datafusion/pull/20017) (niebayes) +- Add `ScalarValue::RunEndEncoded` variant [#19895](https://github.com/apache/datafusion/pull/19895) (Jefffrey) +- minor: remove unused crypto functions & narrow public API [#20045](https://github.com/apache/datafusion/pull/20045) (Jefffrey) +- Wrap immutable plan parts into Arc (make creating `ExecutionPlan`s less costly) [#19893](https://github.com/apache/datafusion/pull/19893) (askalt) +- feat: Support planning subqueries with OuterReferenceColumn belongs to non-adjacent outer relations [#19930](https://github.com/apache/datafusion/pull/19930) (mkleen) +- Remove the statistics() api in execution plan [#20319](https://github.com/apache/datafusion/pull/20319) (xudong963) +- Remove recursive const check in `simplify_const_expr` [#20234](https://github.com/apache/datafusion/pull/20234) (AdamGS) +- Cache `PlanProperties`, add fast-path for `with_new_children` [#19792](https://github.com/apache/datafusion/pull/19792) (askalt) +- [branch-53] feat: parse `JsonAccess` as a binary operator, add `Operator::Colon` [#20717](https://github.com/apache/datafusion/pull/20717) (Samyak2) + +**Performance related:** + +- perf: optimize `HashTableLookupExpr::evaluate` 
[#19602](https://github.com/apache/datafusion/pull/19602) (UBarney) +- perf: Improve performance of `split_part` [#19570](https://github.com/apache/datafusion/pull/19570) (andygrove) +- Optimize `Nullstate` / accumulators [#19625](https://github.com/apache/datafusion/pull/19625) (Dandandan) +- perf: optimize `NthValue` when `ignore_nulls` is true [#19496](https://github.com/apache/datafusion/pull/19496) (mzabaluev) +- Optimize `concat/concat_ws` scalar path by pre-allocating memory [#19547](https://github.com/apache/datafusion/pull/19547) (lyne7-sc) +- perf: optimize left function by eliminating double chars() iteration [#19571](https://github.com/apache/datafusion/pull/19571) (viirya) +- perf: Optimize floor and ceil scalar performance [#19752](https://github.com/apache/datafusion/pull/19752) (kumarUjjawal) +- perf: improve performance of `spark hex` function [#19738](https://github.com/apache/datafusion/pull/19738) (lyne7-sc) +- perf: Optimize initcap scalar performance [#19776](https://github.com/apache/datafusion/pull/19776) (kumarUjjawal) +- Row group limit pruning for row groups that entirely match predicates [#18868](https://github.com/apache/datafusion/pull/18868) (xudong963) +- perf: Optimize trunc scalar performance [#19788](https://github.com/apache/datafusion/pull/19788) (kumarUjjawal) +- perf: optimize `spark_hex` dictionary path by avoiding dictionary expansion [#19832](https://github.com/apache/datafusion/pull/19832) (lyne7-sc) +- Add FilterExecBuilder to avoid recomputing properties multiple times [#19854](https://github.com/apache/datafusion/pull/19854) (adriangb) +- perf: Optimize round scalar performance [#19831](https://github.com/apache/datafusion/pull/19831) (kumarUjjawal) +- perf: Optimize signum scalar performance with fast path [#19871](https://github.com/apache/datafusion/pull/19871) (kumarUjjawal) +- perf: Optimize scalar performance for cot [#19888](https://github.com/apache/datafusion/pull/19888) (kumarUjjawal) +- perf: Optimize scalar 
fast path for iszero [#19919](https://github.com/apache/datafusion/pull/19919) (kumarUjjawal) +- Misc hash / hash aggregation performance improvements [#19910](https://github.com/apache/datafusion/pull/19910) (Dandandan) +- perf: Optimize scalar path for ascii function [#19951](https://github.com/apache/datafusion/pull/19951) (kumarUjjawal) +- perf: Optimize factorial scalar path [#19949](https://github.com/apache/datafusion/pull/19949) (kumarUjjawal) +- Speedup statistics_from_parquet_metadata [#20004](https://github.com/apache/datafusion/pull/20004) (Dandandan) +- perf: improve performance of `array_remove`, `array_remove_n` and `array_remove_all` functions [#19996](https://github.com/apache/datafusion/pull/19996) (lyne7-sc) +- perf: Optimize ArrowBytesViewMap with direct view access [#19975](https://github.com/apache/datafusion/pull/19975) (Tushar7012) +- perf: Optimize repeat function for scalar and array fast [#19976](https://github.com/apache/datafusion/pull/19976) (kumarUjjawal) +- perf: Push down join key filters for LEFT/RIGHT/ANTI joins [#19918](https://github.com/apache/datafusion/pull/19918) (nuno-faria) +- perf: Optimize scalar path for chr function [#20073](https://github.com/apache/datafusion/pull/20073) (kumarUjjawal) +- perf: improve performance of `array_repeat` function [#20049](https://github.com/apache/datafusion/pull/20049) (lyne7-sc) +- perf: optimise right for byte access and StringView [#20069](https://github.com/apache/datafusion/pull/20069) (theirix) +- Optimize `PhysicalExprSimplifier` [#20111](https://github.com/apache/datafusion/pull/20111) (AdamGS) +- Improve performance of `CASE WHEN x THEN y ELSE NULL` expressions [#20097](https://github.com/apache/datafusion/pull/20097) (pepijnve) +- perf: Optimize scalar fast path of to_hex function [#20112](https://github.com/apache/datafusion/pull/20112) (kumarUjjawal) +- perf: Optimize scalar fast path & write() encoding for sha2 [#20116](https://github.com/apache/datafusion/pull/20116) 
(kumarUjjawal) +- perf: improve performance of `array_union`/`array_intersect` with batched row conversion [#20243](https://github.com/apache/datafusion/pull/20243) (lyne7-sc) +- perf: various optimizations to eliminate branch misprediction in hash_utils [#20168](https://github.com/apache/datafusion/pull/20168) (notashes) +- perf: Optimize strpos() for ASCII-only inputs [#20295](https://github.com/apache/datafusion/pull/20295) (neilconway) +- perf: Optimize compare_element_to_list [#20323](https://github.com/apache/datafusion/pull/20323) (neilconway) +- perf: Optimize replace() fastpath by avoiding alloc [#20344](https://github.com/apache/datafusion/pull/20344) (neilconway) +- perf: optimize `array_distinct` with batched row conversion [#20364](https://github.com/apache/datafusion/pull/20364) (lyne7-sc) +- perf: Optimize scalar fast path of atan2 [#20336](https://github.com/apache/datafusion/pull/20336) (kumarUjjawal) +- perf: Optimize concat()/concat_ws() UDFs [#20317](https://github.com/apache/datafusion/pull/20317) (neilconway) +- perf: Optimize translate() UDF for scalar inputs [#20305](https://github.com/apache/datafusion/pull/20305) (neilconway) +- perf: Optimize `array_has()` for scalar needle [#20374](https://github.com/apache/datafusion/pull/20374) (neilconway) +- perf: Optimize lpad, rpad for ASCII strings [#20278](https://github.com/apache/datafusion/pull/20278) (neilconway) +- perf: Optimize trim UDFs for single-character trims [#20328](https://github.com/apache/datafusion/pull/20328) (neilconway) +- perf: Optimize scalar fast path for `regexp_like` and rejects g inside combined flags like ig [#20354](https://github.com/apache/datafusion/pull/20354) (kumarUjjawal) +- perf: Use zero-copy slice instead of take kernel in sort merge join [#20463](https://github.com/apache/datafusion/pull/20463) (andygrove) +- perf: Optimize `initcap()` [#20352](https://github.com/apache/datafusion/pull/20352) (neilconway) +- perf: Fix quadratic behavior of 
`to_array_of_size` [#20459](https://github.com/apache/datafusion/pull/20459) (neilconway) +- perf: Optimize `array_has_any()` with scalar arg [#20385](https://github.com/apache/datafusion/pull/20385) (neilconway) +- perf: Use Hashbrown for array_distinct [#20538](https://github.com/apache/datafusion/pull/20538) (neilconway) +- perf: Cache num_output_rows in sort merge join to avoid O(n) recount [#20478](https://github.com/apache/datafusion/pull/20478) (andygrove) +- perf: Optimize heap handling in TopK operator [#20556](https://github.com/apache/datafusion/pull/20556) (AdamGS) +- perf: Optimize `array_position` for scalar needle [#20532](https://github.com/apache/datafusion/pull/20532) (neilconway) +- perf: Use Arrow vectorized eq kernel for IN list with column references [#20528](https://github.com/apache/datafusion/pull/20528) (zhangxffff) +- perf: Optimize `array_agg()` using `GroupsAccumulator` [#20504](https://github.com/apache/datafusion/pull/20504) (neilconway) +- perf: Optimize `array_to_string()`, support more types [#20553](https://github.com/apache/datafusion/pull/20553) (neilconway) +- [branch-53] perf: sort replace free()->try_grow() pattern with try_resize() to reduce memory pool interactions [#20733](https://github.com/apache/datafusion/pull/20733) (mbutrovich) + +**Implemented enhancements:** + +- feat: add list_files_cache table function for `datafusion-cli` [#19388](https://github.com/apache/datafusion/pull/19388) (jizezhang) +- feat: implement metrics for AsyncFuncExec [#19626](https://github.com/apache/datafusion/pull/19626) (feniljain) +- feat: split BatchPartitioner::try_new into hash and round-robin constructors [#19668](https://github.com/apache/datafusion/pull/19668) (mohit7705) +- feat: add Time type support to date_trunc function [#19640](https://github.com/apache/datafusion/pull/19640) (kumarUjjawal) +- feat: Allow log with non-integer base on decimals [#19372](https://github.com/apache/datafusion/pull/19372) (Yuvraj-cyborg) +- 
feat(spark): implement array_repeat function [#19702](https://github.com/apache/datafusion/pull/19702) (cht42) +- feat(spark): Implement collect_list/collect_set aggregate functions [#19699](https://github.com/apache/datafusion/pull/19699) (cht42) +- feat: implement Spark size function for arrays and maps [#19592](https://github.com/apache/datafusion/pull/19592) (CuteChuanChuan) +- feat: support Set Comparison Subquery [#19109](https://github.com/apache/datafusion/pull/19109) (waynexia) +- feat(spark): implement array slice function [#19811](https://github.com/apache/datafusion/pull/19811) (cht42) +- feat(spark): implement substring function [#19805](https://github.com/apache/datafusion/pull/19805) (cht42) +- feat: Add support for 'isoyear' in date_part function [#19821](https://github.com/apache/datafusion/pull/19821) (cht42) +- feat: support `SELECT DISTINCT id FROM t ORDER BY id LIMIT n` query use GroupedTopKAggregateStream [#19653](https://github.com/apache/datafusion/pull/19653) (haohuaijin) +- feat(spark): add trunc, date_trunc and time_trunc functions [#19829](https://github.com/apache/datafusion/pull/19829) (cht42) +- feat(spark): implement Spark `date_diff` function [#19845](https://github.com/apache/datafusion/pull/19845) (cht42) +- feat(spark): implement add_months function [#19711](https://github.com/apache/datafusion/pull/19711) (cht42) +- feat: support pushdown alias on dynamic filter with `ProjectionExec` [#19404](https://github.com/apache/datafusion/pull/19404) (discord9) +- feat(spark): add `base64` and `unbase64` functions [#19968](https://github.com/apache/datafusion/pull/19968) (cht42) +- feat: Show the number of matched Parquet pages in `DataSourceExec` [#19977](https://github.com/apache/datafusion/pull/19977) (nuno-faria) +- feat(spark): Add `SessionStateBuilderSpark` to datafusion-spark [#19865](https://github.com/apache/datafusion/pull/19865) (cht42) +- feat(spark): implement `from/to_utc_timestamp` functions 
[#19880](https://github.com/apache/datafusion/pull/19880) (cht42) +- feat(spark): implement `StringView` for `SparkConcat` [#19984](https://github.com/apache/datafusion/pull/19984) (aryan-212) +- feat(spark): add unix date and timestamp functions [#19892](https://github.com/apache/datafusion/pull/19892) (cht42) +- feat: implement protobuf converter trait to allow control over serialization and deserialization processes [#19437](https://github.com/apache/datafusion/pull/19437) (timsaucer) +- feat: optimise copying in `left` for Utf8 and LargeUtf8 [#19980](https://github.com/apache/datafusion/pull/19980) (theirix) +- feat: support Spark-compatible abs math function part 2 - ANSI mode [#18828](https://github.com/apache/datafusion/pull/18828) (hsiang-c) +- feat: add AggregateMode::PartialReduce for tree-reduce aggregation [#20019](https://github.com/apache/datafusion/pull/20019) (njsmith) +- feat: add ExpressionPlacement enum for optimizer expression placement decisions [#20065](https://github.com/apache/datafusion/pull/20065) (adriangb) +- feat: support f16 in coercion logic [#18944](https://github.com/apache/datafusion/pull/18944) (Jefffrey) +- feat: unify left and right functions and benches [#20114](https://github.com/apache/datafusion/pull/20114) (theirix) +- feat(spark): Adds negative spark function [#20006](https://github.com/apache/datafusion/pull/20006) (SubhamSinghal) +- feat: support limited deletion [#20137](https://github.com/apache/datafusion/pull/20137) (askalt) +- feat: Pushdown filters through `UnionExec` nodes [#20145](https://github.com/apache/datafusion/pull/20145) (haohuaijin) +- feat: support Spark-compatible `string_to_map` function [#20120](https://github.com/apache/datafusion/pull/20120) (unknowntpo) +- feat: Add `partition_stats()` for `EmptyExec` [#20203](https://github.com/apache/datafusion/pull/20203) (jonathanc-n) +- feat: add ExtractLeafExpressions optimizer rule for get_field pushdown 
[#20117](https://github.com/apache/datafusion/pull/20117) (adriangb) +- feat: Push limit into hash join [#20228](https://github.com/apache/datafusion/pull/20228) (jonathanc-n) +- feat: Optimize hash util for `MapArray` [#20179](https://github.com/apache/datafusion/pull/20179) (jonathanc-n) +- feat: Implement Spark `bitmap_bit_position` function [#20275](https://github.com/apache/datafusion/pull/20275) (kazantsev-maksim) +- feat: support sqllogictest output coloring [#20368](https://github.com/apache/datafusion/pull/20368) (theirix) +- feat: support Spark-compatible `json_tuple` function [#20412](https://github.com/apache/datafusion/pull/20412) (CuteChuanChuan) +- feat: Implement Spark `bitmap_bucket_number` function [#20288](https://github.com/apache/datafusion/pull/20288) (kazantsev-maksim) +- feat: support `arrays_zip` function [#20440](https://github.com/apache/datafusion/pull/20440) (comphead) +- feat: Implement Spark `bin` function [#20479](https://github.com/apache/datafusion/pull/20479) (kazantsev-maksim) +- feat: support extension planner for `TableScan` [#20548](https://github.com/apache/datafusion/pull/20548) (linhr) + +**Fixed bugs:** + +- fix: Return Int for Date - Date instead of duration [#19563](https://github.com/apache/datafusion/pull/19563) (kumarUjjawal) +- fix: DynamicFilterPhysicalExpr violates Hash/Eq contract [#19659](https://github.com/apache/datafusion/pull/19659) (kumarUjjawal) +- fix: unnest struct field with an alias failed with internal error [#19698](https://github.com/apache/datafusion/pull/19698) (kumarUjjawal) +- fix(accumulators): preserve state in evaluate() for window frame queries [#19618](https://github.com/apache/datafusion/pull/19618) (GaneshPatil7517) +- fix: Don't treat quoted column names as placeholder variables in SQL [#19339](https://github.com/apache/datafusion/pull/19339) (pmallex) +- fix: enhance CTE resolution with identifier normalization [#19519](https://github.com/apache/datafusion/pull/19519) (kysshsy) +- feat: 
Add null-aware anti join support [#19635](https://github.com/apache/datafusion/pull/19635) (viirya) +- fix: expose `ListFilesEntry` [#19804](https://github.com/apache/datafusion/pull/19804) (lonless9) +- fix: trunc function with precision uses round instead of trunc semantics [#19794](https://github.com/apache/datafusion/pull/19794) (kumarUjjawal) +- fix: calculate total seconds from interval fields for `extract(epoch)` [#19807](https://github.com/apache/datafusion/pull/19807) (lemorage) +- fix: predicate cache stats calculation [#19561](https://github.com/apache/datafusion/pull/19561) (feniljain) +- fix: preserve state in DistinctMedianAccumulator::evaluate() for window frame queries [#19887](https://github.com/apache/datafusion/pull/19887) (kumarUjjawal) +- fix: null in array_agg with DISTINCT and IGNORE [#19736](https://github.com/apache/datafusion/pull/19736) (davidlghellin) +- fix: union should return error instead of panic when input schema's len different [#19922](https://github.com/apache/datafusion/pull/19922) (haohuaijin) +- fix: change token consumption to pick to test on EOF in parser [#19927](https://github.com/apache/datafusion/pull/19927) (askalt) +- fix: maintain inner list nullability for `array_sort` [#19948](https://github.com/apache/datafusion/pull/19948) (Jefffrey) +- fix: Make `generate_series` return an empty set with invalid ranges [#19999](https://github.com/apache/datafusion/pull/19999) (nuno-faria) +- fix: return correct length array for scalar null input to `calculate_binary_math` [#19861](https://github.com/apache/datafusion/pull/19861) (Jefffrey) +- fix: respect DataFrameWriteOptions::with_single_file_output for paths without extensions [#19931](https://github.com/apache/datafusion/pull/19931) (kumarUjjawal) +- fix: correct weight handling in approx_percentile_cont_with_weight [#19941](https://github.com/apache/datafusion/pull/19941) (sesteves) +- fix: The limit_pushdown physical optimization rule removes limits in some cases leading 
to incorrect results [#20048](https://github.com/apache/datafusion/pull/20048) (masonh22) +- Add duplicate name error reproducer [#20106](https://github.com/apache/datafusion/pull/20106) (gabotechs) +- fix: filter pushdown when merge filter [#20110](https://github.com/apache/datafusion/pull/20110) (haohuaijin) +- fix: Make `serialize_to_file` test cross platform [#20147](https://github.com/apache/datafusion/pull/20147) (nuno-faria) +- fix: regression of `dict_id` in physical plan proto [#20063](https://github.com/apache/datafusion/pull/20063) (kumarUjjawal) +- fix: panic in ListingTableFactory when session is not SessionState [#20139](https://github.com/apache/datafusion/pull/20139) (evangelisilva) +- fix: update comment on FilterPushdownPropagation [#20040](https://github.com/apache/datafusion/pull/20040) (niebayes) +- fix: datatype_is_logically_equal for dictionaries [#20153](https://github.com/apache/datafusion/pull/20153) (dd-annarose) +- fix: Avoid integer overflow in split_part() [#20198](https://github.com/apache/datafusion/pull/20198) (neilconway) +- fix: Fix panic in regexp_like() [#20200](https://github.com/apache/datafusion/pull/20200) (neilconway) +- fix: Handle NULL inputs correctly in find_in_set() [#20209](https://github.com/apache/datafusion/pull/20209) (neilconway) +- fix: Ensure columns are casted to the correct names with Unions [#20146](https://github.com/apache/datafusion/pull/20146) (nuno-faria) +- fix: Avoid assertion failure on divide-by-zero [#20216](https://github.com/apache/datafusion/pull/20216) (neilconway) +- fix: Throw coercion error for `LIKE` operations for nested types. 
[#20212](https://github.com/apache/datafusion/pull/20212) (jonathanc-n) +- fix: disable dynamic filter pushdown for non min/max aggregates [#20279](https://github.com/apache/datafusion/pull/20279) (notashes) +- fix: Avoid integer overflow in substr() [#20199](https://github.com/apache/datafusion/pull/20199) (neilconway) +- fix: Fix scalar broadcast for to_timestamp() [#20224](https://github.com/apache/datafusion/pull/20224) (neilconway) +- fix: Add integer check for bitwise coercion [#20241](https://github.com/apache/datafusion/pull/20241) (Acfboy) +- fix: percentile_cont interpolation causes NaN for f16 input [#20208](https://github.com/apache/datafusion/pull/20208) (kumarUjjawal) +- fix: validate inter-file ordering in eq_properties() [#20329](https://github.com/apache/datafusion/pull/20329) (adriangb) +- fix: update filter predicates for min/max aggregates only if bounds change [#20380](https://github.com/apache/datafusion/pull/20380) (notashes) +- fix: Handle Utf8View and LargeUtf8 separators in concat_ws [#20361](https://github.com/apache/datafusion/pull/20361) (neilconway) +- fix: HashJoin panic with dictionary-encoded columns in multi-key joins [#20441](https://github.com/apache/datafusion/pull/20441) (Tim-53) +- fix: handle out of range errors in DATE_BIN instead of panicking [#20221](https://github.com/apache/datafusion/pull/20221) (mishop-15) +- fix: prevent duplicate alias collision with user-provided \_\_datafusion_extracted names [#20432](https://github.com/apache/datafusion/pull/20432) (adriangb) +- fix: SortMergeJoin don't wait for all input before emitting [#20482](https://github.com/apache/datafusion/pull/20482) (rluvaton) +- fix: `cardinality()` of an empty array should be zero [#20533](https://github.com/apache/datafusion/pull/20533) (neilconway) +- fix: Unaccounted spill sort in row_hash [#20314](https://github.com/apache/datafusion/pull/20314) (EmilyMatt) +- fix: IS NULL panic with invalid function without input arguments 
[#20306](https://github.com/apache/datafusion/pull/20306) (Acfboy) +- fix: handle empty delimiter in split_part (closes #20503) [#20542](https://github.com/apache/datafusion/pull/20542) (gferrate) +- fix(substrait): Correctly parse field references in subqueries [#20439](https://github.com/apache/datafusion/pull/20439) (neilconway) +- fix: increase ROUND decimal precision to prevent overflow truncation [#19926](https://github.com/apache/datafusion/pull/19926) (kumarUjjawal) +- fix: Fix `array_to_string` with columnar third arg [#20536](https://github.com/apache/datafusion/pull/20536) (neilconway) +- fix: Fix and Refactor Spark `shuffle` function [#20484](https://github.com/apache/datafusion/pull/20484) (erenavsarogullari) + +**Documentation updates:** + +- perfect hash join [#19411](https://github.com/apache/datafusion/pull/19411) (UBarney) +- docs: Fix two small issues in introduction.md [#19712](https://github.com/apache/datafusion/pull/19712) (AdamGS) +- docs: Refine Communication documentation to highlight Discord [#19714](https://github.com/apache/datafusion/pull/19714) (alamb) +- chore(deps): bump maturin from 1.10.2 to 1.11.5 in /docs [#19740](https://github.com/apache/datafusion/pull/19740) (dependabot[bot]) +- chore: remove LZO Parquet compression [#19726](https://github.com/apache/datafusion/pull/19726) (kumarUjjawal) +- Update 52.0.0 release version number and changelog [#19767](https://github.com/apache/datafusion/pull/19767) (xudong963) +- Update the upgrading.md [#19769](https://github.com/apache/datafusion/pull/19769) (xudong963) +- chore: update copyright notice year [#19758](https://github.com/apache/datafusion/pull/19758) (Jefffrey) +- doc: Add an auto-generated dependency graph for internal crates [#19280](https://github.com/apache/datafusion/pull/19280) (2010YOUY01) +- Docs: Fix some links in docs [#19834](https://github.com/apache/datafusion/pull/19834) (alamb) +- Docs: add additional links to blog posts 
[#19833](https://github.com/apache/datafusion/pull/19833) (alamb) +- Ensure null inputs to array setop functions return null output [#19683](https://github.com/apache/datafusion/pull/19683) (Jefffrey) +- chore(deps): bump sphinx from 8.2.3 to 9.1.0 in /docs [#19647](https://github.com/apache/datafusion/pull/19647) (dependabot[bot]) +- Fix struct casts to align fields by name (prevent positional mis-casts) [#19674](https://github.com/apache/datafusion/pull/19674) (kosiew) +- chore(deps): bump setuptools from 80.9.0 to 80.10.1 in /docs [#19988](https://github.com/apache/datafusion/pull/19988) (dependabot[bot]) +- minor: Fix doc about `write_batch_size` [#19979](https://github.com/apache/datafusion/pull/19979) (nuno-faria) +- Fix broken links in the documentation [#19964](https://github.com/apache/datafusion/pull/19964) (alamb) +- minor: Add favicon [#20000](https://github.com/apache/datafusion/pull/20000) (nuno-faria) +- docs: Fix some broken / missing links in the DataFusion documentation [#19958](https://github.com/apache/datafusion/pull/19958) (alamb) +- chore(deps): bump setuptools from 80.10.1 to 80.10.2 in /docs [#20022](https://github.com/apache/datafusion/pull/20022) (dependabot[bot]) +- docs: Automatically update DataFusion version in docs [#20001](https://github.com/apache/datafusion/pull/20001) (nuno-faria) +- docs: update data_types.md to reflect current Arrow type mappings [#20072](https://github.com/apache/datafusion/pull/20072) (karuppuchamysuresh) +- Runs-on for `linux-build-lib` and `linux-test` (2X faster CI) [#20107](https://github.com/apache/datafusion/pull/20107) (blaginin) +- Disallow positional struct casting when field names don’t overlap [#19955](https://github.com/apache/datafusion/pull/19955) (kosiew) +- docs: fix docstring formatting [#20158](https://github.com/apache/datafusion/pull/20158) (Jefffrey) +- Break upgrade guides into separate pages [#20183](https://github.com/apache/datafusion/pull/20183) (mishop-15) +- Better document the 
relationship between `FileFormat::projection` / `FileFormat::filter` and `FileScanConfig::Statistics` [#20188](https://github.com/apache/datafusion/pull/20188) (alamb) +- Document the relationship between FileFormat::projection / FileFormat::filter and FileScanConfig::output_ordering [#20196](https://github.com/apache/datafusion/pull/20196) (alamb) +- More documentation on `FileSource::table_schema` and `FileSource::projection` [#20242](https://github.com/apache/datafusion/pull/20242) (alamb) +- chore(deps): bump setuptools from 80.10.2 to 82.0.0 in /docs [#20255](https://github.com/apache/datafusion/pull/20255) (dependabot[bot]) +- docs: fix typos and improve wording in README [#20301](https://github.com/apache/datafusion/pull/20301) (iampratap7997-dot) +- Reduce ExtractLeafExpressions optimizer overhead with fast pre-scan [#20341](https://github.com/apache/datafusion/pull/20341) (adriangb) +- chore(deps): bump maturin from 1.11.5 to 1.12.2 in /docs [#20400](https://github.com/apache/datafusion/pull/20400) (dependabot[bot]) +- Migrate Python usage to uv workspace [#20414](https://github.com/apache/datafusion/pull/20414) (adriangb) +- test: Extend Spark Array functions: `array_repeat `, `shuffle` and `slice` test coverage [#20420](https://github.com/apache/datafusion/pull/20420) (erenavsarogullari) +- Runs-on for more actions [#20274](https://github.com/apache/datafusion/pull/20274) (blaginin) +- docs: Document that adding new optimizer rules are expensive [#20348](https://github.com/apache/datafusion/pull/20348) (alamb) +- add redirect for old upgrading.html URL to fix broken changelog links [#20582](https://github.com/apache/datafusion/pull/20582) (mishop-15) +- Upgrade DataFusion to arrow-rs/parquet 58.0.0 / `object_store` 0.13.0 [#19728](https://github.com/apache/datafusion/pull/19728) (alamb) +- Document guidance on how to evaluate breaking API changes [#20584](https://github.com/apache/datafusion/pull/20584) (alamb) +- [branch-53] chore: prepare 53 release 
[#20649](https://github.com/apache/datafusion/pull/20649) (comphead) + +**Other:** + +- [branch-53] chore: Add branch protection (comphead) +- Add a protection to release candidate branch 52 [#19660](https://github.com/apache/datafusion/pull/19660) (xudong963) +- Downgrade aws-smithy-runtime, update `rust_decimal`, ignore RUSTSEC-2026-0001 to get clean CI [#19657](https://github.com/apache/datafusion/pull/19657) (alamb) +- Update dependencies [#19667](https://github.com/apache/datafusion/pull/19667) (alamb) +- Refactor PartitionedFile: add ordering field and new_from_meta constructor [#19596](https://github.com/apache/datafusion/pull/19596) (adriangb) +- Remove coalesce batches rule and deprecate CoalesceBatchesExec [#19622](https://github.com/apache/datafusion/pull/19622) (feniljain) +- Perf: Optimize `substring_index` via single-byte fast path and direct indexing [#19590](https://github.com/apache/datafusion/pull/19590) (lyne7-sc) +- refactor: Use `Signature::coercible` for isnan/iszero [#19604](https://github.com/apache/datafusion/pull/19604) (kumarUjjawal) +- Parquet: Push down supported list predicates (array_has/any/all) during decoding [#19545](https://github.com/apache/datafusion/pull/19545) (kosiew) +- Remove dependency on `rust_decimal`, remove ignore of `RUSTSEC-2026-0001` [#19666](https://github.com/apache/datafusion/pull/19666) (alamb) +- Store example data directly inside the datafusion-examples (#19141) [#19319](https://github.com/apache/datafusion/pull/19319) (cj-zhukov) +- minor: More comments to `ParquetOpener::open()` [#19677](https://github.com/apache/datafusion/pull/19677) (2010YOUY01) +- Feat: Allow pow with negative & non-integer exponent on decimals [#19369](https://github.com/apache/datafusion/pull/19369) (Yuvraj-cyborg) +- chore(deps): bump taiki-e/install-action from 2.65.13 to 2.65.15 [#19676](https://github.com/apache/datafusion/pull/19676) (dependabot[bot]) +- Refactor cache APIs to support ordering information 
[#19597](https://github.com/apache/datafusion/pull/19597) (adriangb) +- Record sort order when writing Parquet with WITH ORDER [#19595](https://github.com/apache/datafusion/pull/19595) (adriangb) +- implement var distinct [#19706](https://github.com/apache/datafusion/pull/19706) (thinh2) +- Fix TopK aggregation for UTF-8/Utf8View group keys and add safe fallback for unsupported string aggregates [#19285](https://github.com/apache/datafusion/pull/19285) (kosiew) +- infer parquet file order from metadata and use it to optimize scans [#19433](https://github.com/apache/datafusion/pull/19433) (adriangb) +- Add support for additional numeric types in to_timestamp functions [#19663](https://github.com/apache/datafusion/pull/19663) (gokselk) +- Fix internal error "Physical input schema should be the same as the one converted from logical input schema." [#18412](https://github.com/apache/datafusion/pull/18412) (alamb) +- fix(functions-aggregate): drain CORR state vectors for streaming aggregation [#19669](https://github.com/apache/datafusion/pull/19669) (geoffreyclaude) +- chore: bump dependabot PR limit for cargo from 5 to 15 [#19730](https://github.com/apache/datafusion/pull/19730) (Jefffrey) +- chore(deps): bump taiki-e/install-action from 2.65.15 to 2.66.1 [#19741](https://github.com/apache/datafusion/pull/19741) (dependabot[bot]) +- chore(deps): bump sqllogictest from 0.28.4 to 0.29.0 [#19744](https://github.com/apache/datafusion/pull/19744) (dependabot[bot]) +- chore(deps): bump blake3 from 1.8.2 to 1.8.3 [#19746](https://github.com/apache/datafusion/pull/19746) (dependabot[bot]) +- chore(deps): bump libc from 0.2.179 to 0.2.180 [#19748](https://github.com/apache/datafusion/pull/19748) (dependabot[bot]) +- chore(deps): bump async-compression from 0.4.36 to 0.4.37 [#19742](https://github.com/apache/datafusion/pull/19742) (dependabot[bot]) +- chore(deps): bump indexmap from 2.12.1 to 2.13.0 [#19747](https://github.com/apache/datafusion/pull/19747) (dependabot[bot]) +- 
Improve comment for predicate_cache_inner_records [#19762](https://github.com/apache/datafusion/pull/19762) (xudong963) +- Fix dynamic filter is_used function [#19734](https://github.com/apache/datafusion/pull/19734) (LiaCastaneda) +- slt: Add test for REE arrays in group by [#19763](https://github.com/apache/datafusion/pull/19763) (brancz) +- Fix run_tpcds data dir [#19771](https://github.com/apache/datafusion/pull/19771) (gabotechs) +- chore(deps): bump taiki-e/install-action from 2.66.1 to 2.66.2 [#19778](https://github.com/apache/datafusion/pull/19778) (dependabot[bot]) +- Include .proto files in datafusion-proto distribution [#19490](https://github.com/apache/datafusion/pull/19490) (DarkWanderer) +- Simplify `expr = L1 AND expr != L2` to `expr = L1` when `L1 != L2` [#19731](https://github.com/apache/datafusion/pull/19731) (simonvandel) +- chore(deps): bump flate2 from 1.1.5 to 1.1.8 [#19780](https://github.com/apache/datafusion/pull/19780) (dependabot[bot]) +- Upgrade DataFusion to arrow-rs/parquet 57.2.0 [#19355](https://github.com/apache/datafusion/pull/19355) (alamb) +- Expose Spilling Progress Interface in DataFusion [#19708](https://github.com/apache/datafusion/pull/19708) (xudong963) +- dev: Add a script to auto fix all lint violations [#19560](https://github.com/apache/datafusion/pull/19560) (2010YOUY01) +- refactor: Optimize `required_columns` from `BTreeSet` to `Vec` in struct `PushdownChecker` [#19678](https://github.com/apache/datafusion/pull/19678) (kumarUjjawal) +- Revert Workround for Empty FixedSizeBinary Values Buffer After arrow-rs Upgrade [#19801](https://github.com/apache/datafusion/pull/19801) (tobixdev) +- chore(deps): bump taiki-e/install-action from 2.66.2 to 2.66.3 [#19802](https://github.com/apache/datafusion/pull/19802) (dependabot[bot]) +- Add Reproducer for Issues with LEFT joins on Fixed Size Binary Columns [#19800](https://github.com/apache/datafusion/pull/19800) (tobixdev) +- Improvements to `list_files_cache` table function 
[#19703](https://github.com/apache/datafusion/pull/19703) (alamb) +- Issue 19781 : Internal error: Assertion failed: !self.finished: LimitedBatchCoalescer [#19785](https://github.com/apache/datafusion/pull/19785) (bert-beyondloops) +- physical plan: add `reset_plan_states `, plan re-use benchmark [#19806](https://github.com/apache/datafusion/pull/19806) (askalt) +- chore(deps): bump actions/setup-node from 6.1.0 to 6.2.0 [#19825](https://github.com/apache/datafusion/pull/19825) (dependabot[bot]) +- Use correct setting for click bench queries in sql_planner benchmark [#19835](https://github.com/apache/datafusion/pull/19835) (alamb) +- chore(deps): bump taiki-e/install-action from 2.66.3 to 2.66.5 [#19824](https://github.com/apache/datafusion/pull/19824) (dependabot[bot]) +- chore: refactor scalarvalue/encoding using available upstream arrow-rs methods [#19797](https://github.com/apache/datafusion/pull/19797) (Jefffrey) +- Refactor Spark `date_add`/`date_sub`/`bitwise_not` to remove unnecessary scalar arg check [#19473](https://github.com/apache/datafusion/pull/19473) (Jefffrey) +- Add BatchAdapter to simplify using PhysicalExprAdapter / Projector to map RecordBatch between schemas [#19716](https://github.com/apache/datafusion/pull/19716) (adriangb) +- [Minor] Reuse indices buffer in RepartitionExec [#19775](https://github.com/apache/datafusion/pull/19775) (Dandandan) +- Fix(optimizer): Make `EnsureCooperative` optimizer idempotent under multiple runs [#19757](https://github.com/apache/datafusion/pull/19757) (danielhumanmod) +- Allow dropping qualified columns [#19549](https://github.com/apache/datafusion/pull/19549) (ntjohnson1) +- Doc: Add more blog links to doc comments [#19837](https://github.com/apache/datafusion/pull/19837) (alamb) +- datafusion/common: Add support for hashing ListView arrays [#19814](https://github.com/apache/datafusion/pull/19814) (brancz) +- Project sort expressions in StreamingTable [#19719](https://github.com/apache/datafusion/pull/19719) 
(timsaucer) +- Fix grouping set subset satisfaction [#19853](https://github.com/apache/datafusion/pull/19853) (freakyzoidberg) +- Spark date part [#19823](https://github.com/apache/datafusion/pull/19823) (cht42) +- chore(deps): bump wasm-bindgen-test from 0.3.56 to 0.3.58 [#19898](https://github.com/apache/datafusion/pull/19898) (dependabot[bot]) +- chore(deps): bump tokio-postgres from 0.7.15 to 0.7.16 [#19899](https://github.com/apache/datafusion/pull/19899) (dependabot[bot]) +- chore(deps): bump postgres-types from 0.2.11 to 0.2.12 [#19902](https://github.com/apache/datafusion/pull/19902) (dependabot[bot]) +- chore(deps): bump insta from 1.46.0 to 1.46.1 [#19901](https://github.com/apache/datafusion/pull/19901) (dependabot[bot]) +- chore(deps): bump taiki-e/install-action from 2.66.5 to 2.66.7 [#19883](https://github.com/apache/datafusion/pull/19883) (dependabot[bot]) +- Consolidate cte_quoted_reference.slt into cte.slt [#19862](https://github.com/apache/datafusion/pull/19862) (AnjaliChoudhary99) +- Disable failing `array_union` edge-case with nested null array [#19904](https://github.com/apache/datafusion/pull/19904) (Jefffrey) +- chore(deps): bump the proto group across 1 directory with 5 updates [#19745](https://github.com/apache/datafusion/pull/19745) (dependabot[bot]) +- test(wasmtest): enable compression feature for wasm build [#19860](https://github.com/apache/datafusion/pull/19860) (ChanTsune) +- Feat : added truncate table support [#19633](https://github.com/apache/datafusion/pull/19633) (Nachiket-Roy) +- Remove UDAF manual Debug impls and simplify signatures [#19727](https://github.com/apache/datafusion/pull/19727) (Jefffrey) +- chore(deps): bump thiserror from 2.0.17 to 2.0.18 [#19900](https://github.com/apache/datafusion/pull/19900) (dependabot[bot]) +- Include license and notice files in more crates [#19913](https://github.com/apache/datafusion/pull/19913) (ankane) +- chore(deps): bump actions/setup-python from 6.1.0 to 6.2.0 
[#19935](https://github.com/apache/datafusion/pull/19935) (dependabot[bot]) +- Coerce expressions to udtf [#19915](https://github.com/apache/datafusion/pull/19915) (XiangpengHao) +- Fix trailing whitespace in CROSS JOIN logical plan formatting [#19936](https://github.com/apache/datafusion/pull/19936) (mkleen) +- chore(deps): bump chrono from 0.4.42 to 0.4.43 [#19897](https://github.com/apache/datafusion/pull/19897) (dependabot[bot]) +- Improve error message when string functions receive Binary types [#19819](https://github.com/apache/datafusion/pull/19819) (lemorage) +- Refactor ListArray hashing to consider only sliced values [#19500](https://github.com/apache/datafusion/pull/19500) (Jefffrey) +- feat(datafusion-spark): implement spark compatible `unhex` function [#19909](https://github.com/apache/datafusion/pull/19909) (lyne7-sc) +- Support API for "pre-image" for pruning predicate evaluation [#19722](https://github.com/apache/datafusion/pull/19722) (sdf-jkl) +- Support LargeUtf8 as partition column [#19942](https://github.com/apache/datafusion/pull/19942) (paleolimbot) +- chore(deps): bump actions/checkout from 6.0.1 to 6.0.2 [#19953](https://github.com/apache/datafusion/pull/19953) (dependabot[bot]) +- preserve FilterExec batch size during ser/de [#19960](https://github.com/apache/datafusion/pull/19960) (askalt) +- Add struct pushdown query benchmark and projection pushdown tests [#19962](https://github.com/apache/datafusion/pull/19962) (adriangb) +- Improve error messages with nicer formatting of Date and Time types [#19954](https://github.com/apache/datafusion/pull/19954) (emilk) +- export `SessionState::register_catalog_list(...)` [#19925](https://github.com/apache/datafusion/pull/19925) (askalt) +- Change GitHub actions dependabot schedule to weekly [#19981](https://github.com/apache/datafusion/pull/19981) (Jefffrey) +- chore(deps): bump taiki-e/install-action from 2.66.7 to 2.67.9 [#19987](https://github.com/apache/datafusion/pull/19987) (dependabot[bot]) 
+- chore(deps): bump quote from 1.0.43 to 1.0.44 [#19992](https://github.com/apache/datafusion/pull/19992) (dependabot[bot]) +- chore(deps): bump nix from 0.30.1 to 0.31.1 [#19991](https://github.com/apache/datafusion/pull/19991) (dependabot[bot]) +- chore(deps): bump sysinfo from 0.37.2 to 0.38.0 [#19990](https://github.com/apache/datafusion/pull/19990) (dependabot[bot]) +- chore(deps): bump uuid from 1.19.0 to 1.20.0 [#19993](https://github.com/apache/datafusion/pull/19993) (dependabot[bot]) +- minor: pull `uuid` into workspace dependencies [#19997](https://github.com/apache/datafusion/pull/19997) (Jefffrey) +- Fix ClickBench EventDate handling by casting UInt16 days-since-epoch to DATE via `hits` view [#19881](https://github.com/apache/datafusion/pull/19881) (kosiew) +- refactor: extract pushdown test utilities to shared module [#20010](https://github.com/apache/datafusion/pull/20010) (adriangb) +- chore(deps): bump taiki-e/install-action from 2.67.9 to 2.67.13 [#20020](https://github.com/apache/datafusion/pull/20020) (dependabot[bot]) +- add more projection pushdown slt tests [#20015](https://github.com/apache/datafusion/pull/20015) (adriangb) +- minor: Move metric `page_index_rows_pruned` to verbose level in `EXPLAIN ANALYZE` [#20026](https://github.com/apache/datafusion/pull/20026) (2010YOUY01) +- Tweak `adapter serialization` example [#20035](https://github.com/apache/datafusion/pull/20035) (adriangb) +- Simplify wait_complete function [#19937](https://github.com/apache/datafusion/pull/19937) (LiaCastaneda) +- [main] Update version to `52.1.0` (#19878) [#20028](https://github.com/apache/datafusion/pull/20028) (alamb) +- Fix/parquet opener page index policy [#19890](https://github.com/apache/datafusion/pull/19890) (aviralgarg05) +- minor: add tests for coercible signature considering nulls/dicts/ree [#19459](https://github.com/apache/datafusion/pull/19459) (Jefffrey) +- Enforce `clippy::allow_attributes` globally across workspace 
[#19576](https://github.com/apache/datafusion/pull/19576) (Jefffrey) +- Fix constant value from stats [#20042](https://github.com/apache/datafusion/pull/20042) (gabotechs) +- Simplify Spark `sha2` implementation [#19475](https://github.com/apache/datafusion/pull/19475) (Jefffrey) +- Further refactoring of type coercion function code [#19603](https://github.com/apache/datafusion/pull/19603) (Jefffrey) +- replace private is_volatile_expression_tree with equivalent public is_volatile [#20056](https://github.com/apache/datafusion/pull/20056) (adriangb) +- Improve documentation for ScalarUDFImpl::preimage [#20008](https://github.com/apache/datafusion/pull/20008) (alamb) +- Use BooleanBufferBuilder rather than Vec in ArrowBytesViewMap [#20064](https://github.com/apache/datafusion/pull/20064) (etk18) +- chore: Add microbenchmark (compared to ExprOrExpr) [#20076](https://github.com/apache/datafusion/pull/20076) (CuteChuanChuan) +- Minor: update tests in limit_pushdown.rs to insta [#20066](https://github.com/apache/datafusion/pull/20066) (alamb) +- Reduce number of traversals per node in `PhysicalExprSimplifier` [#20082](https://github.com/apache/datafusion/pull/20082) (AdamGS) +- Automatically generate examples documentation adv (#19294) [#19750](https://github.com/apache/datafusion/pull/19750) (cj-zhukov) +- Implement preimage for floor function to enable predicate pushdown [#20059](https://github.com/apache/datafusion/pull/20059) (devanshu0987) +- Refactor `iszero()` and `isnan()` to accept all numeric types [#20093](https://github.com/apache/datafusion/pull/20093) (kumarUjjawal) +- Use return_field_from_args in information schema and date_trunc [#20079](https://github.com/apache/datafusion/pull/20079) (AndreaBozzo) +- Preserve PhysicalExpr graph in proto round trip using Arc pointers as unique identifiers [#20037](https://github.com/apache/datafusion/pull/20037) (adriangb) +- add ability to customize tokens in parser 
[#19978](https://github.com/apache/datafusion/pull/19978) (askalt) +- Adjust `case_when DivideByZeroProtection` benchmark so that "percentage of zeroes" corresponds to "number of times protection is needed" [#20105](https://github.com/apache/datafusion/pull/20105) (pepijnve) +- refactor: Rename `FileSource::try_reverse_output` to `FileSource::try_pushdown_sort` [#20043](https://github.com/apache/datafusion/pull/20043) (kumarUjjawal) +- Improve memory accounting for ArrowBytesViewMap [#20077](https://github.com/apache/datafusion/pull/20077) (vigneshsiva11) +- chore: reduce production noise by using `debug` macro [#19885](https://github.com/apache/datafusion/pull/19885) (Standing-Man) +- chore(deps): bump taiki-e/install-action from 2.67.13 to 2.67.18 [#20124](https://github.com/apache/datafusion/pull/20124) (dependabot[bot]) +- chore(deps): bump actions/setup-node from 4 to 6 [#20125](https://github.com/apache/datafusion/pull/20125) (dependabot[bot]) +- chore(deps): bump tonic from 0.14.2 to 0.14.3 [#20127](https://github.com/apache/datafusion/pull/20127) (dependabot[bot]) +- chore(deps): bump insta from 1.46.1 to 1.46.3 [#20129](https://github.com/apache/datafusion/pull/20129) (dependabot[bot]) +- chore(deps): bump flate2 from 1.1.8 to 1.1.9 [#20130](https://github.com/apache/datafusion/pull/20130) (dependabot[bot]) +- chore(deps): bump clap from 4.5.54 to 4.5.56 [#20131](https://github.com/apache/datafusion/pull/20131) (dependabot[bot]) +- Add BufferExec execution plan [#19760](https://github.com/apache/datafusion/pull/19760) (gabotechs) +- Optimize the evaluation of date_part() == when pushed down [#19733](https://github.com/apache/datafusion/pull/19733) (sdf-jkl) +- chore(deps): bump bytes from 1.11.0 to 1.11.1 [#20141](https://github.com/apache/datafusion/pull/20141) (dependabot[bot]) +- Make session state builder clonable [#20136](https://github.com/apache/datafusion/pull/20136) (askalt) +- chore: remove datatype check functions in favour of upstream versions 
[#20104](https://github.com/apache/datafusion/pull/20104) (Jefffrey) +- Add Decimal support for floor preimage [#20099](https://github.com/apache/datafusion/pull/20099) (devanshu0987) +- Add more struct pushdown tests and planning benchmark [#20143](https://github.com/apache/datafusion/pull/20143) (adriangb) +- Add RepartitionExec test to projection_pushdown.slt [#20156](https://github.com/apache/datafusion/pull/20156) (adriangb) +- chore: Fix typos in comments [#20157](https://github.com/apache/datafusion/pull/20157) (neilconway) +- Fix `array_repeat` handling of null count values [#20102](https://github.com/apache/datafusion/pull/20102) (lyne7-sc) +- Refactor schema rewriter: remove lifetimes, extract column/cast helpers, add mismatch coverage [#20166](https://github.com/apache/datafusion/pull/20166) (kosiew) +- chore(deps): bump time from 0.3.44 to 0.3.47 [#20172](https://github.com/apache/datafusion/pull/20172) (dependabot[bot]) +- chore(deps-dev): bump webpack from 5.94.0 to 5.105.0 in /datafusion/wasmtest/datafusion-wasm-app [#20178](https://github.com/apache/datafusion/pull/20178) (dependabot[bot]) +- Fix Arrow Spill Underrun [#20159](https://github.com/apache/datafusion/pull/20159) (cetra3) +- nom parser instead of ad-hoc in examples [#20122](https://github.com/apache/datafusion/pull/20122) (cj-zhukov) +- fix(datafusion-cli): solve row count bug adding`saturating_add` to prevent potential overflow [#20185](https://github.com/apache/datafusion/pull/20185) (dariocurr) +- Enable inlist support for preimage [#20051](https://github.com/apache/datafusion/pull/20051) (sdf-jkl) +- unify the prettier versions [#20167](https://github.com/apache/datafusion/pull/20167) (cj-zhukov) +- chore: Unbreak doctest CI [#20218](https://github.com/apache/datafusion/pull/20218) (neilconway) +- Minor: verify plan output and unique field names [#20220](https://github.com/apache/datafusion/pull/20220) (alamb) +- Add more tests to projection_pushdown.slt 
[#20236](https://github.com/apache/datafusion/pull/20236) (adriangb) +- Add Expr::Alias passthrough to Expr::placement() [#20237](https://github.com/apache/datafusion/pull/20237) (adriangb) +- Make PushDownFilter and CommonSubexprEliminate aware of Expr::placement [#20239](https://github.com/apache/datafusion/pull/20239) (adriangb) +- Refactor example metadata parsing utilities(#20204) [#20233](https://github.com/apache/datafusion/pull/20233) (cj-zhukov) +- add module structure and unit tests for expression pushdown logical optimizer [#20238](https://github.com/apache/datafusion/pull/20238) (adriangb) +- repro and disable dyn filter for preserve file partitions [#20175](https://github.com/apache/datafusion/pull/20175) (gene-bordegaray) +- chore(deps): bump taiki-e/install-action from 2.67.18 to 2.67.27 [#20254](https://github.com/apache/datafusion/pull/20254) (dependabot[bot]) +- chore(deps): bump sysinfo from 0.38.0 to 0.38.1 [#20261](https://github.com/apache/datafusion/pull/20261) (dependabot[bot]) +- chore(deps): bump clap from 4.5.56 to 4.5.57 [#20265](https://github.com/apache/datafusion/pull/20265) (dependabot[bot]) +- chore(deps): bump tempfile from 3.24.0 to 3.25.0 [#20262](https://github.com/apache/datafusion/pull/20262) (dependabot[bot]) +- chore(deps): bump regex from 1.12.2 to 1.12.3 [#20260](https://github.com/apache/datafusion/pull/20260) (dependabot[bot]) +- chore(deps): bump criterion from 0.8.1 to 0.8.2 [#20258](https://github.com/apache/datafusion/pull/20258) (dependabot[bot]) +- chore(deps): bump regex-syntax from 0.8.8 to 0.8.9 [#20264](https://github.com/apache/datafusion/pull/20264) (dependabot[bot]) +- chore(deps): bump aws-config from 1.8.12 to 1.8.13 [#20263](https://github.com/apache/datafusion/pull/20263) (dependabot[bot]) +- chore(deps): bump async-compression from 0.4.37 to 0.4.39 [#20259](https://github.com/apache/datafusion/pull/20259) (dependabot[bot]) +- Support JSON arrays reader/parse for datafusion 
[#19924](https://github.com/apache/datafusion/pull/19924) (zhuqi-lucas) +- chore: Add confirmation before tarball is released [#20207](https://github.com/apache/datafusion/pull/20207) (milenkovicm) +- FilterExec should remap indices of parent dynamic filters [#20286](https://github.com/apache/datafusion/pull/20286) (jackkleeman) +- Clean up expression placement UDF usage in tests [#20272](https://github.com/apache/datafusion/pull/20272) (adriangb) +- chore(deps): bump the arrow-parquet group with 7 updates [#20256](https://github.com/apache/datafusion/pull/20256) (dependabot[bot]) +- Cleanup example metadata parsing utilities(#20251) [#20252](https://github.com/apache/datafusion/pull/20252) (cj-zhukov) +- Add `StructArray` and `RunArray` benchmark tests to `with_hashes` [#20182](https://github.com/apache/datafusion/pull/20182) (notashes) +- Add protoc support for ArrowScanExecNode (#20280) [#20284](https://github.com/apache/datafusion/pull/20284) (JoshElkind) +- Improve ExternalSorter ResourcesExhausted Error Message [#20226](https://github.com/apache/datafusion/pull/20226) (erenavsarogullari) +- Introduce ProjectionExprs::unproject_exprs/project_exprs and improve docs [#20193](https://github.com/apache/datafusion/pull/20193) (alamb) +- chore: Remove "extern crate criterion" in benches [#20299](https://github.com/apache/datafusion/pull/20299) (neilconway) +- Support pushing down empty projections into joins [#20191](https://github.com/apache/datafusion/pull/20191) (jackkleeman) +- chore: change width_bucket buckets parameter from i32 to i64 [#20330](https://github.com/apache/datafusion/pull/20330) (comphead) +- fix null handling for `nanvl` & implement fast path [#20205](https://github.com/apache/datafusion/pull/20205) (kumarUjjawal) +- unify the prettier version adv(#20024) [#20311](https://github.com/apache/datafusion/pull/20311) (cj-zhukov) +- chore: Make memchr a workspace dependency [#20345](https://github.com/apache/datafusion/pull/20345) (neilconway) +- 
feat(datafusion-cli): enhance CLI helper with default hint [#20310](https://github.com/apache/datafusion/pull/20310) (dariocurr) +- Adds support for ANSI mode in negative function [#20189](https://github.com/apache/datafusion/pull/20189) (SubhamSinghal) +- Support parent dynamic filters for more join types [#20192](https://github.com/apache/datafusion/pull/20192) (jackkleeman) +- Fix incorrect `SortExec` removal before `AggregateExec` (option 2) [#20247](https://github.com/apache/datafusion/pull/20247) (alamb) +- Fix `try_shrink` not freeing back to pool [#20382](https://github.com/apache/datafusion/pull/20382) (cetra3) +- chore(deps): bump sysinfo from 0.38.1 to 0.38.2 [#20411](https://github.com/apache/datafusion/pull/20411) (dependabot[bot]) +- chore(deps): bump indicatif from 0.18.3 to 0.18.4 [#20410](https://github.com/apache/datafusion/pull/20410) (dependabot[bot]) +- chore(deps): bump liblzma from 0.4.5 to 0.4.6 [#20409](https://github.com/apache/datafusion/pull/20409) (dependabot[bot]) +- chore(deps): bump aws-config from 1.8.13 to 1.8.14 [#20407](https://github.com/apache/datafusion/pull/20407) (dependabot[bot]) +- chore(deps): bump tonic from 0.14.3 to 0.14.4 [#20406](https://github.com/apache/datafusion/pull/20406) (dependabot[bot]) +- chore(deps): bump clap from 4.5.57 to 4.5.59 [#20404](https://github.com/apache/datafusion/pull/20404) (dependabot[bot]) +- chore(deps): bump sqllogictest from 0.29.0 to 0.29.1 [#20405](https://github.com/apache/datafusion/pull/20405) (dependabot[bot]) +- chore(deps): bump env_logger from 0.11.8 to 0.11.9 [#20402](https://github.com/apache/datafusion/pull/20402) (dependabot[bot]) +- chore(deps): bump actions/stale from 10.1.1 to 10.2.0 [#20397](https://github.com/apache/datafusion/pull/20397) (dependabot[bot]) +- chore(deps): bump uuid from 1.20.0 to 1.21.0 [#20401](https://github.com/apache/datafusion/pull/20401) (dependabot[bot]) +- [Minor] Update object_store to 0.12.5 
[#20378](https://github.com/apache/datafusion/pull/20378) (Dandandan) +- chore(deps): bump syn from 2.0.114 to 2.0.116 [#20399](https://github.com/apache/datafusion/pull/20399) (dependabot[bot]) +- chore(deps): bump taiki-e/install-action from 2.67.27 to 2.68.0 [#20398](https://github.com/apache/datafusion/pull/20398) (dependabot[bot]) +- chore: Cleanup returning null arrays [#20423](https://github.com/apache/datafusion/pull/20423) (neilconway) +- chore: fix labeler for `datafusion-functions-nested` [#20442](https://github.com/apache/datafusion/pull/20442) (comphead) +- build: update Rust toolchain version from 1.92.0 to 1.93.0 in `rust-toolchain.toml` [#20309](https://github.com/apache/datafusion/pull/20309) (dariocurr) +- chore: Cleanup "!is_valid(i)" -> "is_null(i)" [#20453](https://github.com/apache/datafusion/pull/20453) (neilconway) +- refactor: Extract sort-merge join filter logic into separate module [#19614](https://github.com/apache/datafusion/pull/19614) (viirya) +- Implement FFI table provider factory [#20326](https://github.com/apache/datafusion/pull/20326) (davisp) +- bench: Add criterion benchmark for sort merge join [#20464](https://github.com/apache/datafusion/pull/20464) (andygrove) +- chore: group minor dependencies into single PR [#20457](https://github.com/apache/datafusion/pull/20457) (comphead) +- chore(deps): bump taiki-e/install-action from 2.68.0 to 2.68.6 [#20467](https://github.com/apache/datafusion/pull/20467) (dependabot[bot]) +- chore(deps): bump astral-sh/setup-uv from 6.1.0 to 7.3.0 [#20468](https://github.com/apache/datafusion/pull/20468) (dependabot[bot]) +- chore(deps): bump the all-other-cargo-deps group with 6 updates [#20470](https://github.com/apache/datafusion/pull/20470) (dependabot[bot]) +- chore(deps): bump testcontainers-modules from 0.14.0 to 0.15.0 [#20471](https://github.com/apache/datafusion/pull/20471) (dependabot[bot]) +- [Minor] Use buffer_unordered [#20462](https://github.com/apache/datafusion/pull/20462) 
(Dandandan) +- bench: Add IN list benchmarks for non-constant list expressions [#20444](https://github.com/apache/datafusion/pull/20444) (zhangxffff) +- feat(memory-tracking): implement arrow_buffer::MemoryPool for MemoryPool [#18928](https://github.com/apache/datafusion/pull/18928) (notfilippo) +- chore: Avoid build fails on MinIO rate limits [#20472](https://github.com/apache/datafusion/pull/20472) (comphead) +- chore: Add end-to-end benchmark for array_agg, code cleanup [#20496](https://github.com/apache/datafusion/pull/20496) (neilconway) +- Upgrade to sqlparser 0.61.0 [#20177](https://github.com/apache/datafusion/pull/20177) (alamb) +- Switch to the latest Mac OS [#20510](https://github.com/apache/datafusion/pull/20510) (blaginin) +- Fix name tracker [#19856](https://github.com/apache/datafusion/pull/19856) (xanderbailey) +- Runs-on for extended CI checks [#20511](https://github.com/apache/datafusion/pull/20511) (blaginin) +- chore(deps): bump strum from 0.27.2 to 0.28.0 [#20520](https://github.com/apache/datafusion/pull/20520) (dependabot[bot]) +- chore(deps): bump taiki-e/install-action from 2.68.6 to 2.68.8 [#20518](https://github.com/apache/datafusion/pull/20518) (dependabot[bot]) +- chore(deps): bump the all-other-cargo-deps group with 2 updates [#20519](https://github.com/apache/datafusion/pull/20519) (dependabot[bot]) +- Make `custom_file_casts` example schema nullable to allow null `id` values during casting [#20486](https://github.com/apache/datafusion/pull/20486) (kosiew) +- Add support for FFI config extensions [#19469](https://github.com/apache/datafusion/pull/19469) (timsaucer) +- chore: Cleanup code to use `repeat_n` in a few places [#20527](https://github.com/apache/datafusion/pull/20527) (neilconway) +- chore(deps): bump strum_macros from 0.27.2 to 0.28.0 [#20521](https://github.com/apache/datafusion/pull/20521) (dependabot[bot]) +- chore: Replace `matches!` on fieldless enums with `==` [#20525](https://github.com/apache/datafusion/pull/20525) 
(neilconway) +- Update comments on OptimizerRule about function name matching [#20346](https://github.com/apache/datafusion/pull/20346) (alamb) +- Fix incorrect regex pattern in regex_replace_posix_groups [#19827](https://github.com/apache/datafusion/pull/19827) (GaneshPatil7517) +- Improve `HashJoinExecBuilder` to save state from previous fields [#20276](https://github.com/apache/datafusion/pull/20276) (askalt) +- [Minor] Fix error messages for `shrink` and `try_shrink` [#20422](https://github.com/apache/datafusion/pull/20422) (hareshkh) +- Fix physical expr adapter to resolve physical fields by name, not column index [#20485](https://github.com/apache/datafusion/pull/20485) (kosiew) +- [fix] Add type coercion from NULL to Interval to make date_bin more postgres compatible [#20499](https://github.com/apache/datafusion/pull/20499) (LiaCastaneda) +- Clamp early aggregation emit to the sort boundary when using partial group ordering [#20446](https://github.com/apache/datafusion/pull/20446) (jackkleeman) +- Split `push_down_filter.slt` into standalone sqllogictest files to reduce long-tail runtime [#20566](https://github.com/apache/datafusion/pull/20566) (kosiew) +- Add deterministic per-file timing summary to sqllogictest runner [#20569](https://github.com/apache/datafusion/pull/20569) (kosiew) +- chore: Enable workspace lint for all workspace members [#20577](https://github.com/apache/datafusion/pull/20577) (neilconway) +- Fix serde of window lead/lag defaults [#20608](https://github.com/apache/datafusion/pull/20608) (avantgardnerio) +- [branch-53] fix: make the `sql` feature truly optional (#20625) [#20680](https://github.com/apache/datafusion/pull/20680) (linhr) +- [53] fix: Fix bug in `array_has` scalar path with sliced arrays (#20677) [#20700](https://github.com/apache/datafusion/pull/20700) (neilconway) +- [branch-53] fix: Return `probe_side.len()` for RightMark/Anti count(\*) queries (#… [#20726](https://github.com/apache/datafusion/pull/20726) (jonathanc-n) 
+- [branch-53] FFI_TableOptions are using default values only [#20722](https://github.com/apache/datafusion/pull/20722) (timsaucer) +- chore(deps): pin substrait to `0.62.2` [#20827](https://github.com/apache/datafusion/pull/20827) (milenkovicm) +- chore(deps): pin substrait version [#20848](https://github.com/apache/datafusion/pull/20848) (milenkovicm) +- [branch-53] Fix repartition from dropping data when spilling (#20672) [#20792](https://github.com/apache/datafusion/pull/20792) (xanderbailey) +- [branch-53] fix: `HashJoin` panic with String dictionary keys (don't flatten keys) (#20505) [#20791](https://github.com/apache/datafusion/pull/20791) (alamb) +- [branch-53] cli: Fix datafusion-cli hint edge cases (#20609) [#20887](https://github.com/apache/datafusion/pull/20887) (comphead) +- [branch-53] perf: Optimize `to_char` to allocate less, fix NULL handling (#20635) [#20885](https://github.com/apache/datafusion/pull/20885) (neilconway) +- [branch-53] fix: interval analysis error when have two filterexec that inner filter proves zero selectivity (#20743) [#20882](https://github.com/apache/datafusion/pull/20882) (haohuaijin) +- [branch-53] correct parquet leaf index mapping when schema contains struct cols (#20698) [#20884](https://github.com/apache/datafusion/pull/20884) (friendlymatthew) +- [branch-53] ser/de fetch in FilterExec (#20738) [#20883](https://github.com/apache/datafusion/pull/20883) (haohuaijin) +- [branch-53] fix: use try_shrink instead of shrink in try_resize (#20424) [#20890](https://github.com/apache/datafusion/pull/20890) (ariel-miculas) +- [branch-53] Reattach parquet metadata cache after deserializing in datafusion-proto (#20574) [#20891](https://github.com/apache/datafusion/pull/20891) (nathanb9) +- [branch-53] fix: do not recompute hash join exec properties if not required (#20900) [#20903](https://github.com/apache/datafusion/pull/20903) (askalt) +- [branch-53] fix(spark): handle divide-by-zero in Spark `mod`/`pmod` with ANSI mode support 
(#20461) [#20896](https://github.com/apache/datafusion/pull/20896) (davidlghellin) +- [branch-53] fix: Provide more generic API for the capacity limit parsing (#20372) [#20893](https://github.com/apache/datafusion/pull/20893) (erenavsarogullari) +- [branch-53] fix: sqllogictest cannot convert to Substrait (#19739) [#20897](https://github.com/apache/datafusion/pull/20897) (kumarUjjawal) +- [branch-53] Fix DELETE/UPDATE filter extraction when predicates are pushed down into TableScan (#19884) [#20898](https://github.com/apache/datafusion/pull/20898) (kosiew) +- [branch-53] fix: preserve None projection semantics across FFI boundary in ForeignTableProvider::scan (#20393) [#20895](https://github.com/apache/datafusion/pull/20895) (Kontinuation) +- [branch-53] Fix FilterExec converting Absent column stats to Exact(NULL) (#20391) [#20892](https://github.com/apache/datafusion/pull/20892) (fwojciec) +- [branch-53] backport: Support Spark `array_contains` builtin function (#20685) [#20914](https://github.com/apache/datafusion/pull/20914) (comphead) +- [branch-53] Fix duplicate group keys after hash aggregation spill (#20724) (#20858) [#20918](https://github.com/apache/datafusion/pull/20918) (gboucher90) +- [branch-53] fix: SanityCheckPlan error with window functions and NVL filter (#20231) [#20932](https://github.com/apache/datafusion/pull/20932) (EeshanBembi) + +## Credits + +Thank you to everyone who contributed to this release. Here is a breakdown of commits (PRs merged) per contributor. 
+ +``` + 73 dependabot[bot] + 37 Neil Conway + 32 Kumar Ujjawal + 28 Andrew Lamb + 26 Adrian Garcia Badaracco + 21 Jeffrey Vo + 13 cht42 + 11 Albert Skalt + 11 kosiew + 10 lyne + 8 Nuno Faria + 8 Oleks V + 7 Sergey Zhukov + 7 xudong.w + 6 Daniël Heres + 6 Huaijin + 5 Adam Gutglick + 5 Gabriel + 5 Jonathan Chen + 4 Andy Grove + 4 Dmitrii Blaginin + 4 Eren Avsarogullari + 4 Jack Kleeman + 4 notashes + 4 theirix + 4 Tim Saucer + 4 Yongting You + 3 dario curreri + 3 feniljain + 3 Kazantsev Maksim + 3 Kosta Tarasov + 3 Liang-Chi Hsieh + 3 Lía Adriana + 3 Marko Milenković + 3 mishop-15 + 3 Yu-Chuan Hung + 2 Acfboy + 2 Alan Tang + 2 David López + 2 Devanshu + 2 Frederic Branczyk + 2 Ganesh Patil + 2 Heran Lin + 2 jizezhang + 2 Miao + 2 Michael Kleen + 2 niebayes + 2 Pepijn Van Eeckhoudt + 2 Peter L + 2 Subham Singhal + 2 Tobias Schwarzinger + 2 UBarney + 2 Xander + 2 Yuvraj + 2 Zhang Xiaofeng + 1 Andrea Bozzo + 1 Andrew Kane + 1 Anjali Choudhary + 1 Anna-Rose Lescure + 1 Ariel Miculas-Trif + 1 Aryan Anand + 1 Aviral Garg + 1 Bert Vermeiren + 1 Brent Gardner + 1 ChanTsune + 1 comphead + 1 danielhumanmod + 1 Dewey Dunnington + 1 discord9 + 1 Divyansh Pratap Singh + 1 Eesh Sagar Singh + 1 EeshanBembi + 1 Emil Ernerfeldt + 1 Emily Matheys + 1 Eric Chang + 1 Evangeli Silva + 1 Filip Wojciechowski + 1 Filippo + 1 Gabriel Ferraté + 1 Gene Bordegaray + 1 Geoffrey Claude + 1 Goksel Kabadayi + 1 Guillaume Boucher + 1 Haresh Khanna + 1 hsiang-c + 1 iamthinh + 1 Josh Elkind + 1 karuppuchamysuresh + 1 Kristin Cowalcijk + 1 Mason + 1 Matt Butrovich + 1 Matthew Kim + 1 Mikhail Zabaluev + 1 Mohit rao + 1 nathan + 1 Nathaniel J. Smith + 1 Nick + 1 Oleg V. Kozlyuk + 1 Paul J. 
Davis + 1 Pierre Lacave + 1 pmallex + 1 Qi Zhu + 1 Raz Luvaton + 1 Rosai + 1 Ruihang Xia + 1 Samyak Sarnayak + 1 Sergio Esteves + 1 Simon Vandel Sillesen + 1 Siyuan Huang + 1 Tim-53 + 1 Tushar Das + 1 Vignesh + 1 Xiangpeng Hao + 1 XL Liang +``` + +Thank you also to everyone who contributed in other ways such as filing issues, reviewing PRs, and providing feedback on this release. diff --git a/dev/changelog/53.1.0.md b/dev/changelog/53.1.0.md new file mode 100644 index 000000000000..5e39e0041f4a --- /dev/null +++ b/dev/changelog/53.1.0.md @@ -0,0 +1,51 @@ + + +# Apache DataFusion 53.1.0 Changelog + +This release consists of 10 commits from 4 contributors. See credits at the end of this changelog for more information. + +See the [upgrade guide](https://datafusion.apache.org/library-user-guide/upgrading.html) for information on how to upgrade from previous versions. + +**Other:** + +- [branch-53] fix: InList Dictionary filter pushdown type mismatch (#20962) [#20996](https://github.com/apache/datafusion/pull/20996) (alamb) +- [branch-53] Planning speed improve (port of #21084) [#21137](https://github.com/apache/datafusion/pull/21137) (blaginin) +- [branch-53] Fix push_down_filter for children with non-empty fetch fields (#21057) [#21142](https://github.com/apache/datafusion/pull/21142) (hareshkh) +- [branch-53] Substrait join consumer should not merge nullability of join keys (#21121) [#21162](https://github.com/apache/datafusion/pull/21162) (hareshkh) +- [branch-53] chore: Optimize schema rewriter usages (#21158) [#21183](https://github.com/apache/datafusion/pull/21183) (comphead) +- [branch-53] fix: use spill writer's schema instead of the first batch schema for … [#21451](https://github.com/apache/datafusion/pull/21451) (comphead) +- [branch-53] fix: use datafusion_expr instead of datafusion crate in spark bitmap/… [#21452](https://github.com/apache/datafusion/pull/21452) (comphead) +- [branch-53] fix: FilterExec should drop projection when apply projection pushdown 
[#21492](https://github.com/apache/datafusion/pull/21492) (comphead) +- [branch-53] fix: foreign inner ffi types (#21439) [#21524](https://github.com/apache/datafusion/pull/21524) (alamb) +- [branch-53] Restore Sort unparser guard for correct ORDER BY placement (#20658) [#21523](https://github.com/apache/datafusion/pull/21523) (alamb) + +## Credits + +Thank you to everyone who contributed to this release. Here is a breakdown of commits (PRs merged) per contributor. + +``` + 4 Oleks V + 3 Andrew Lamb + 2 Haresh Khanna + 1 Dmitrii Blaginin + +``` + +Thank you also to everyone who contributed in other ways such as filing issues, reviewing PRs, and providing feedback on this release. diff --git a/dev/pyproject.toml b/dev/pyproject.toml new file mode 100644 index 000000000000..a2f5653d9d87 --- /dev/null +++ b/dev/pyproject.toml @@ -0,0 +1,5 @@ +[project] +name = "datafusion-dev" +version = "0.1.0" +requires-python = ">=3.11" +dependencies = ["tomlkit", "PyGithub", "requests"] diff --git a/dev/release/README.md b/dev/release/README.md index 898bceb6f7f4..0ca13c175f23 100644 --- a/dev/release/README.md +++ b/dev/release/README.md @@ -178,10 +178,10 @@ We maintain a [changelog] so our users know what has been changed between releas The changelog is generated using a Python script. -To run the script, you will need a GitHub Personal Access Token (described in the prerequisites section) and the `PyGitHub` library. First install the `PyGitHub` dependency via `pip`: +To run the script, you will need a GitHub Personal Access Token (described in the prerequisites section) and the `PyGitHub` library. 
First install the dev dependencies via `uv`: ```shell -pip3 install PyGitHub +uv sync ``` To generate the changelog, set the `GITHUB_TOKEN` environment variable and then run `./dev/release/generate-changelog.py` @@ -199,7 +199,7 @@ to generate a change log of all changes between the `50.3.0` tag and `branch-51` ```shell export GITHUB_TOKEN= -./dev/release/generate-changelog.py 50.3.0 branch-51 51.0.0 > dev/changelog/51.0.0.md +uv run ./dev/release/generate-changelog.py 50.3.0 branch-51 51.0.0 > dev/changelog/51.0.0.md ``` This script creates a changelog from GitHub PRs based on the labels associated with them as well as looking for diff --git a/dev/release/release-tarball.sh b/dev/release/release-tarball.sh index bd858d23a767..a284b6c4351f 100755 --- a/dev/release/release-tarball.sh +++ b/dev/release/release-tarball.sh @@ -43,6 +43,13 @@ fi version=$1 rc=$2 +read -r -p "Proceed to release tarball for ${version}-rc${rc}? [y/N]: " answer +answer=${answer:-no} +if [ "${answer}" != "y" ]; then + echo "Cancelled tarball release!" 
+ exit 1 +fi + tmp_dir=tmp-apache-datafusion-dist echo "Recreate temporary directory: ${tmp_dir}" diff --git a/dev/release/verify-release-candidate.sh b/dev/release/verify-release-candidate.sh index 9ecbe1bc1713..9ddd1d3ba855 100755 --- a/dev/release/verify-release-candidate.sh +++ b/dev/release/verify-release-candidate.sh @@ -22,7 +22,7 @@ check_dependencies() { local missing_deps=0 local required_deps=("curl" "git" "gpg" "cc" "protoc") - + # Either shasum or sha256sum/sha512sum are required local has_sha_tools=0 @@ -32,7 +32,7 @@ check_dependencies() { missing_deps=1 fi done - + # Check for either shasum or sha256sum/sha512sum if command -v shasum &> /dev/null; then has_sha_tools=1 @@ -42,7 +42,7 @@ check_dependencies() { echo "Error: Neither shasum nor sha256sum/sha512sum are installed or in PATH" missing_deps=1 fi - + if [ $missing_deps -ne 0 ]; then echo "Please install missing dependencies and try again" exit 1 @@ -163,7 +163,7 @@ test_source_distribution() { git clone https://github.com/apache/parquet-testing.git parquet-testing cargo build - cargo test --all --features=avro + cargo test --profile=ci --all --features=avro if ( find -iname 'Cargo.toml' | xargs grep SNAPSHOT ); then echo "Cargo.toml version should not contain SNAPSHOT for releases" diff --git a/dev/requirements.txt b/dev/requirements.txt deleted file mode 100644 index 7fcba0493129..000000000000 --- a/dev/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -tomlkit -PyGitHub \ No newline at end of file diff --git a/dev/rust_lint.sh b/dev/rust_lint.sh index 21d461184641..43d29bd88166 100755 --- a/dev/rust_lint.sh +++ b/dev/rust_lint.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file @@ -23,30 +23,103 @@ # Note: The installed checking tools (e.g., taplo) are not guaranteed to match # the CI versions for simplicity, there might be some minor differences. 
Check # `.github/workflows` for the CI versions. +# +# +# +# For each lint scripts: +# +# By default, they run in check mode: +# ./ci/scripts/rust_fmt.sh +# +# With `--write`, scripts perform best-effort auto fixes: +# ./ci/scripts/rust_fmt.sh --write +# +# The `--write` flag assumes a clean git repository (no uncommitted changes); to force +# auto fixes even if there are unstaged changes, use `--allow-dirty`: +# ./ci/scripts/rust_fmt.sh --write --allow-dirty +# +# New scripts can use `rust_fmt.sh` as a reference. + +set -euo pipefail + +usage() { + cat >&2 < /dev/null; then + echo "Installing $cmd using: $install_cmd" + eval "$install_cmd" + fi +} + +MODE="check" +ALLOW_DIRTY=0 + +while [[ $# -gt 0 ]]; do + case "$1" in + --write) + MODE="write" + ;; + --allow-dirty) + ALLOW_DIRTY=1 + ;; + -h|--help) + usage + ;; + *) + usage + ;; + esac + shift +done + +SCRIPT_NAME="$(basename "${BASH_SOURCE[0]}")" + +ensure_tool "taplo" "cargo install taplo-cli --locked" +ensure_tool "hawkeye" "cargo install hawkeye --locked" +ensure_tool "typos" "cargo install typos-cli --locked" + +run_step() { + local name="$1" + shift + echo "[${SCRIPT_NAME}] Running ${name}" + "$@" +} + +declare -a WRITE_STEPS=( + "ci/scripts/rust_fmt.sh|true" + "ci/scripts/rust_clippy.sh|true" + "ci/scripts/rust_toml_fmt.sh|true" + "ci/scripts/license_header.sh|true" + "ci/scripts/typos_check.sh|true" + "ci/scripts/doc_prettier_check.sh|true" +) + +declare -a READONLY_STEPS=( + "ci/scripts/rust_docs.sh|false" +) -# For `.toml` format checking -set -e -if ! command -v taplo &> /dev/null; then - echo "Installing taplo using cargo" - cargo install taplo-cli -fi - -# For Apache licence header checking -if ! command -v hawkeye &> /dev/null; then - echo "Installing hawkeye using cargo" - cargo install hawkeye --locked -fi - -# For spelling checks -if ! 
command -v typos &> /dev/null; then - echo "Installing typos using cargo" - cargo install typos-cli --locked -fi - -ci/scripts/rust_fmt.sh -ci/scripts/rust_clippy.sh -ci/scripts/rust_toml_fmt.sh -ci/scripts/rust_docs.sh -ci/scripts/license_header.sh -ci/scripts/typos_check.sh -ci/scripts/doc_prettier_check.sh +for entry in "${WRITE_STEPS[@]}" "${READONLY_STEPS[@]}"; do + IFS='|' read -r script_path supports_write <<<"$entry" + script_name="$(basename "$script_path")" + args=() + if [[ "$supports_write" == "true" && "$MODE" == "write" ]]; then + args+=(--write) + [[ $ALLOW_DIRTY -eq 1 ]] && args+=(--allow-dirty) + fi + if [[ ${#args[@]} -gt 0 ]]; then + run_step "$script_name" "$script_path" "${args[@]}" + else + run_step "$script_name" "$script_path" + fi +done diff --git a/dev/update_arrow_deps.py b/dev/update_arrow_deps.py index 6bd5d47ff059..bdfdfe22eaeb 100755 --- a/dev/update_arrow_deps.py +++ b/dev/update_arrow_deps.py @@ -19,7 +19,7 @@ # Script that updates the arrow dependencies in datafusion locally # # installation: -# pip install tomlkit requests +# uv sync # # pin all arrow crates deps to a specific version: # diff --git a/dev/update_config_docs.sh b/dev/update_config_docs.sh index 90bbc5d3bad0..f39bdda3aee8 100755 --- a/dev/update_config_docs.sh +++ b/dev/update_config_docs.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. 
See the NOTICE file @@ -20,14 +20,16 @@ set -e -SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -cd "${SOURCE_DIR}/../" && pwd +ROOT_DIR="$(git rev-parse --show-toplevel)" +cd "${ROOT_DIR}" + +# Load centralized tool versions +source "${ROOT_DIR}/ci/scripts/utils/tool_versions.sh" TARGET_FILE="docs/source/user-guide/configs.md" PRINT_CONFIG_DOCS_COMMAND="cargo run --manifest-path datafusion/core/Cargo.toml --bin print_config_docs" PRINT_RUNTIME_CONFIG_DOCS_COMMAND="cargo run --manifest-path datafusion/core/Cargo.toml --bin print_runtime_config_docs" - echo "Inserting header" cat <<'EOF' > "$TARGET_FILE" + +# Workspace Dependency Graph + +This page shows the dependency relationships between DataFusion's workspace +crates. This only includes internal dependencies, external crates like `Arrow` are not included + +The dependency graph is auto-generated by `docs/scripts/generate_dependency_graph.sh` to ensure it stays up-to-date, and the script now runs automatically as part of `docs/build.sh`. + +## Dependency Graph for Workspace Crates + + + +```{raw} html + + +``` + +### Legend + +- black lines: normal dependency +- blue lines: dev-dependency +- green lines: build-dependency +- dotted lines: optional dependency (could be removed by disabling a cargo feature) + +Transitive dependencies are intentionally ignored to keep the graph readable. + +The dependency graph is generated through `cargo depgraph` by `docs/scripts/generate_dependency_graph.sh`. diff --git a/docs/source/contributor-guide/communication.md b/docs/source/contributor-guide/communication.md index 5d4561a3512c..ad80ea498f50 100644 --- a/docs/source/contributor-guide/communication.md +++ b/docs/source/contributor-guide/communication.md @@ -17,56 +17,56 @@ under the License. --> -# Communication +# Community Communication We welcome participation from everyone and encourage you to join us, ask questions, and get involved. 
- All participation in the Apache DataFusion project is governed by the Apache Software Foundation's [code of conduct](https://www.apache.org/foundation/policies/conduct.html). ## GitHub -The vast majority of communication occurs in the open on our -[github repository](https://github.com/apache/datafusion) in the form of tickets, issues, discussions, and Pull Requests. +The primary means of communication is the +[GitHub repository](https://github.com/apache/datafusion) in the form of issues, discussions, and Pull Requests. +Our repository is open to everyone. We encourage you to +participate by reporting issues, asking questions, and contributing code. -## Slack and Discord +## Chat -We use the Slack and Discord platforms for informal discussions and coordination. These are great places to -meet other contributors and get guidance on where to contribute. It is important to note that any technical designs and -decisions are made fully in the open, on GitHub. +We also use the Discord and Slack platforms for lower latency, informal discussions and coordination. +These are great places to +meet other members of the community, ask questions, and brainstorm ideas. +However, to ensure technical discussions are archived and accessible to everyone, +all technical designs are recorded and formalized in GitHub issues. -Most of us use the [ASF Slack -workspace](https://s.apache.org/slack-invite) and the [Arrow Rust Discord -server][discord-link] for discussions. +### Discord -There are specific channels for Arrow, DataFusion, and the DataFusion subprojects (Ballista, Comet, Python, etc). +Historically, the most active discussion forum has been the [Arrow Rust Discord +server][discord-link] which has specific channels for Arrow, DataFusion, and +DataFusion subprojects such as Ballista, Comet, Python, etc. +DataFusion specific channels are prefixed with the `#datafusion-` tag. +We recommend new users join this server for real-time discussions with the community. 
-In Slack we use these channels: +### Slack -- #arrow -- #arrow-rust -- #datafusion -- #datafusion-ballista -- #datafusion-comet -- #datafusion-python +Some of the community also uses the [ASF Slack workspace] for discussions. This +has historically been much less active than the Discord server. +Unfortunately, due to spammers, the ASF Slack workspace [requires an invitation] +to join. We are happy to invite any community member -- please ask for an +invitation in the Discord server. -In Discord we use these channels: +[asf slack workspace]: https://the-asf.slack.com/ +[requires an invitation]: https://s.apache.org/slack-invite -- #ballista -- #comet -- #contrib-federation -- #datafusion -- #datafusion-python -- #dolomite-optimizer -- #general -- #hiring -- #incremental-materialized-views +In Slack, we use these channels: -Unfortunately, due to spammers, the ASF Slack workspace requires an invitation -to join. We are happy to invite you -- please ask for an invitation in the -Discord server. +- `#arrow` +- `#arrow-rust` +- `#datafusion` +- `#datafusion-ballista` +- `#datafusion-comet` +- `#datafusion-python` ### Job Board @@ -77,8 +77,8 @@ Please feel free to post links to DataFusion related jobs there. ## Mailing Lists Like other Apache projects, we use [mailing lists] for certain purposes, most -importantly release coordination. Other than the release process, most -DataFusion mailing list traffic will simply link to a GitHub issue or PR where +importantly release coordination and announcing new committers and PMC members. +Other than these processes, most DataFusion mailing list traffic will link to a GitHub issue or PR where the actual discussion occurs. 
The project mailing lists are: - [`dev@datafusion.apache.org`](mailto:dev@datafusion.apache.org): the main diff --git a/docs/source/contributor-guide/howtos.md b/docs/source/contributor-guide/howtos.md index 1b38e95bf35d..18d9391d24bb 100644 --- a/docs/source/contributor-guide/howtos.md +++ b/docs/source/contributor-guide/howtos.md @@ -187,4 +187,4 @@ valid installation of [protoc] (see [installation instructions] for details). ``` [protoc]: https://github.com/protocolbuffers/protobuf#protocol-compiler-installation -[installation instructions]: https://datafusion.apache.org/contributor-guide/getting_started.html#protoc-installation +[installation instructions]: https://datafusion.apache.org/contributor-guide/development_environment.html#protoc-installation diff --git a/docs/source/contributor-guide/index.md b/docs/source/contributor-guide/index.md index ea42329f2c00..2ee8a2aaac6c 100644 --- a/docs/source/contributor-guide/index.md +++ b/docs/source/contributor-guide/index.md @@ -199,3 +199,35 @@ Please understand the reviewing capacity is **very limited** for the project, so ### Better ways to contribute than an “AI dump” It's recommended to write a high-quality issue with a clear problem statement and a minimal, reproducible example. This can make it easier for others to contribute. + +### CI Runners + +#### Runs-On + +We use [Runs-On](https://runs-on.com/) for some actions in the main repository, which run in the ASF AWS account to speed up CI. In forks, these actions run on the default GitHub runners since forks do not have access to ASF infrastructure. + +To configure them, we use the following format: + +`runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }}` + +This is a conditional expression that uses Runs-On custom runners for the main repository and falls back to the standard GitHub runners for forks. 
Runs-On configuration follows the [Runs-On pattern](https://runs-on.com/configuration/job-labels/). + +For those actions we also use the [Runs-On action](https://runs-on.com/caching/magic-cache/#how-to-use), which adds support for external caching and reports job metrics: + +`- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e` + +For the standard GitHub runners, this action will do nothing. + +##### Spot Instances + +By default, Runs-On actions run as [spot instances](https://runs-on.com/configuration/spot-instances/), which means they might occasionally be interrupted. In the CI you would see: + +``` +Error: The operation was canceled. +``` + +According to Runs-On, spot instance termination is extremely rare for instances running for less than 1h. Those actions will be restarted automatically. + +#### GitHub Runners + +We also use standard GitHub runners for some actions in the main repository; these are also runnable in forks. diff --git a/docs/source/contributor-guide/testing.md b/docs/source/contributor-guide/testing.md index 81ceabb646bf..43b727211de7 100644 --- a/docs/source/contributor-guide/testing.md +++ b/docs/source/contributor-guide/testing.md @@ -70,7 +70,9 @@ DataFusion's SQL implementation is tested using [sqllogictest](https://github.co cargo test --profile=ci --test sqllogictests # Run a specific test file cargo test --profile=ci --test sqllogictests -- aggregate.slt -# Run and update expected outputs +# Run a specific test file and update expected outputs +cargo test --profile=ci --test sqllogictests -- aggregate.slt --complete +# Run and update expected outputs for all test files cargo test --profile=ci --test sqllogictests -- --complete ``` @@ -104,6 +106,7 @@ locally by following the [instructions in the documentation]. 
[sqlite test suite]: https://www.sqlite.org/sqllogictest/dir?ci=tip [instructions in the documentation]: https://github.com/apache/datafusion/tree/main/datafusion/sqllogictest#running-tests-sqlite +[extended.yml]: https://github.com/apache/datafusion/blob/main/.github/workflows/extended.yml ## Rust Integration Tests diff --git a/docs/source/download.md b/docs/source/download.md index 7a62e398c02b..ed8fc06440f0 100644 --- a/docs/source/download.md +++ b/docs/source/download.md @@ -26,7 +26,7 @@ For example: ```toml [dependencies] -datafusion = "41.0.0" +datafusion = "53.0.0" ``` While DataFusion is distributed via [crates.io] as a convenience, the diff --git a/docs/source/index.rst b/docs/source/index.rst index 9764e6c99526..4d57faa0cbf7 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -61,7 +61,7 @@ The following related subprojects target end users and have separate documentati "Out of the box," DataFusion offers `SQL `_ and `Dataframe `_ APIs, excellent `performance `_, built-in support for CSV, Parquet, JSON, and Avro, -extensive customization, and a great community. +extensive customization, and a great `community`_. `Python Bindings `_ are also available. `Ballista `_ is Apache DataFusion extension enabling the parallelized execution of workloads across multiple nodes in a distributed environment. @@ -81,6 +81,7 @@ To get started, see .. _datafusion-examples: https://github.com/apache/datafusion/tree/main/datafusion-examples .. _developer’s guide: contributor-guide/index.html#developer-s-guide .. _library user guide: library-user-guide/index.html +.. _community: contributor-guide/communication.html .. _communication: contributor-guide/communication.html .. 
_toc.asf-links: @@ -133,7 +134,7 @@ To get started, see :caption: Library User Guide library-user-guide/index - library-user-guide/upgrading + library-user-guide/upgrading/index library-user-guide/extensions library-user-guide/using-the-sql-api library-user-guide/extending-sql @@ -158,6 +159,7 @@ To get started, see contributor-guide/communication contributor-guide/development_environment contributor-guide/architecture + contributor-guide/architecture/dependency-graph contributor-guide/testing contributor-guide/api-health contributor-guide/howtos diff --git a/docs/source/library-user-guide/custom-table-providers.md b/docs/source/library-user-guide/custom-table-providers.md index 8e1dee9e843a..50005a7527da 100644 --- a/docs/source/library-user-guide/custom-table-providers.md +++ b/docs/source/library-user-guide/custom-table-providers.md @@ -108,7 +108,7 @@ impl ExecutionPlan for CustomExec { } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { unreachable!() } @@ -232,7 +232,7 @@ The `scan` method of the `TableProvider` returns a `Result &PlanProperties { +# fn properties(&self) -> &Arc { # unreachable!() # } # @@ -424,7 +424,7 @@ This will allow you to use the custom table provider in DataFusion. For example, # } # # -# fn properties(&self) -> &PlanProperties { +# fn properties(&self) -> &Arc { # unreachable!() # } # diff --git a/docs/source/library-user-guide/extending-sql.md b/docs/source/library-user-guide/extending-sql.md index 409a0fb89a32..687d884895c8 100644 --- a/docs/source/library-user-guide/extending-sql.md +++ b/docs/source/library-user-guide/extending-sql.md @@ -27,6 +27,11 @@ need to: - Add custom data types not natively supported - Implement SQL constructs like `TABLESAMPLE`, `PIVOT`/`UNPIVOT`, or `MATCH_RECOGNIZE` +You can read more about this topic in the [Extending SQL in DataFusion: from ->> +to TABLESAMPLE] blog. 
+ +[extending sql in datafusion: from ->> to tablesample]: https://datafusion.apache.org/blog/2026/01/12/extending-sql + ## Architecture Overview When DataFusion processes a SQL query, it goes through these stages: @@ -329,7 +334,7 @@ SELECT * FROM sales [`executionplan`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.ExecutionPlan.html [`sessioncontext`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html [`sessionstatebuilder`]: https://docs.rs/datafusion/latest/datafusion/execution/session_state/struct.SessionStateBuilder.html -[`relationplannercontext`]: https://docs.rs/datafusion/latest/datafusion/sql/planner/trait.RelationPlannerContext.html +[`relationplannercontext`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/planner/trait.RelationPlannerContext.html [exprplanner api documentation]: https://docs.rs/datafusion/latest/datafusion/logical_expr/planner/trait.ExprPlanner.html [typeplanner api documentation]: https://docs.rs/datafusion/latest/datafusion/logical_expr/planner/trait.TypePlanner.html [relationplanner api documentation]: https://docs.rs/datafusion/latest/datafusion/logical_expr/planner/trait.RelationPlanner.html diff --git a/docs/source/library-user-guide/functions/adding-udfs.md b/docs/source/library-user-guide/functions/adding-udfs.md index 5d033ae3f9e9..48162d6abcdf 100644 --- a/docs/source/library-user-guide/functions/adding-udfs.md +++ b/docs/source/library-user-guide/functions/adding-udfs.md @@ -583,7 +583,6 @@ For async UDF implementation details, see [`async_udf.rs`](https://github.com/ap [`scalarudf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.ScalarUDF.html [`create_udf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udf.html -[`process_scalar_func_inputs`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/functions/fn.process_scalar_func_inputs.html [`advanced_udf.rs`]: 
https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/udf/advanced_udf.rs ## Named Arguments @@ -684,6 +683,10 @@ No function matches the given name and argument types substr(Utf8). Scalar UDFs are functions that take a row of data and return a single value. Window UDFs are similar, but they also have access to the rows around them. Access to the proximal rows is helpful, but adds some complexity to the implementation. +For background and other considerations, see the [User defined Window Functions in DataFusion] blog. + +[user defined window functions in datafusion]: https://datafusion.apache.org/blog/2025/04/19/user-defined-window-functions + For example, we will declare a user defined window function that computes a moving average. ```rust diff --git a/docs/source/library-user-guide/query-optimizer.md b/docs/source/library-user-guide/query-optimizer.md index 8ed6593d5620..2254776bf6e3 100644 --- a/docs/source/library-user-guide/query-optimizer.md +++ b/docs/source/library-user-guide/query-optimizer.md @@ -25,11 +25,21 @@ format. DataFusion has modular design, allowing individual crates to be re-used in other projects. This crate is a submodule of DataFusion that provides a query optimizer for logical plans, and -contains an extensive set of [`OptimizerRule`]s and [`PhysicalOptimizerRules`] that may rewrite the plan and/or its expressions so +contains an extensive set of [`OptimizerRule`]s and [`PhysicalOptimizerRule`]s that may rewrite the plan and/or its expressions so they execute more quickly while still computing the same result. +For a deeper background on optimizer architecture and rule types and predicates, see +[Optimizing SQL (and DataFrames) in DataFusion, Part 1], [Part 2], +[Using Ordering for Better Plans in Apache DataFusion], and +[Dynamic Filters: Passing Information Between Operators During Execution for 25x Faster Queries]. 
+ [`optimizerrule`]: https://docs.rs/datafusion/latest/datafusion/optimizer/trait.OptimizerRule.html -[`physicaloptimizerrules`]: https://docs.rs/datafusion/latest/datafusion/physical_optimizer/trait.PhysicalOptimizerRule.html +[`physicaloptimizerrule`]: https://docs.rs/datafusion/latest/datafusion/physical_optimizer/trait.PhysicalOptimizerRule.html +[optimizing sql (and dataframes) in datafusion, part 1]: https://datafusion.apache.org/blog/2025/06/15/optimizing-sql-dataframes-part-one +[part 2]: https://datafusion.apache.org/blog/2025/06/15/optimizing-sql-dataframes-part-two +[using ordering for better plans in apache datafusion]: https://datafusion.apache.org/blog/2025/03/11/ordering-analysis +[dynamic filters: passing information between operators during execution for 25x faster queries]: https://datafusion.apache.org/blog/2025/09/10/dynamic-filters +[`logicalplan`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.LogicalPlan.html ## Running the Optimizer @@ -75,7 +85,7 @@ Please refer to the example to learn more about the general approach to writing optimizer rules and then move onto studying the existing rules. -`OptimizerRule` transforms one ['LogicalPlan'] into another which +`OptimizerRule` transforms one [`LogicalPlan`] into another which computes the same results, but in a potentially more efficient way. If there are no suitable transformations for the input plan, the optimizer can simply return it as is. 
@@ -504,3 +514,5 @@ fn analyze_filter_example() -> Result<()> { Ok(()) } ``` + +[treenode api]: https://docs.rs/datafusion/latest/datafusion/common/tree_node/trait.TreeNode.html diff --git a/docs/source/library-user-guide/table-constraints.md b/docs/source/library-user-guide/table-constraints.md index dea746463d23..252817822d99 100644 --- a/docs/source/library-user-guide/table-constraints.md +++ b/docs/source/library-user-guide/table-constraints.md @@ -37,6 +37,6 @@ They are provided for informational purposes and can be used by custom - **Foreign keys and check constraints**: These constraints are parsed but are not validated or used during query planning. -[`tableconstraint`]: https://docs.rs/datafusion/latest/datafusion/sql/planner/enum.TableConstraint.html -[`constraints`]: https://docs.rs/datafusion/latest/datafusion/common/functional_dependencies/struct.Constraints.html -[`field`]: https://docs.rs/arrow/latest/arrow/datatype/struct.Field.html +[`tableconstraint`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/sqlparser/ast/enum.TableConstraint.html +[`constraints`]: https://docs.rs/datafusion/latest/datafusion/common/struct.Constraints.html +[`field`]: https://docs.rs/arrow/latest/arrow/datatypes/struct.Field.html diff --git a/docs/source/library-user-guide/upgrading.md b/docs/source/library-user-guide/upgrading.md deleted file mode 100644 index 61246f00dfe7..000000000000 --- a/docs/source/library-user-guide/upgrading.md +++ /dev/null @@ -1,2085 +0,0 @@ - - -# Upgrade Guides - -## DataFusion `52.0.0` - -**Note:** DataFusion `52.0.0` has not been released yet. The information provided in this section pertains to features and changes that have already been merged to the main branch and are awaiting release in this version. 
- -You can see the current [status of the `52.0.0`release here](https://github.com/apache/datafusion/issues/18566) - -### Changes to DFSchema API - -To permit more efficient planning, several methods on `DFSchema` have been -changed to return references to the underlying [`&FieldRef`] rather than -[`&Field`]. This allows planners to more cheaply copy the references via -`Arc::clone` rather than cloning the entire `Field` structure. - -You may need to change code to use `Arc::clone` instead of `.as_ref().clone()` -directly on the `Field`. For example: - -```diff -- let field = df_schema.field("my_column").as_ref().clone(); -+ let field = Arc::clone(df_schema.field("my_column")); -``` - -### ListingTableProvider now caches `LIST` commands - -In prior versions, `ListingTableProvider` would issue `LIST` commands to -the underlying object store each time it needed to list files for a query. -To improve performance, `ListingTableProvider` now caches the results of -`LIST` commands for the lifetime of the `ListingTableProvider` instance or -until a cache entry expires. - -Note that by default the cache has no expiration time, so if files are added or removed -from the underlying object store, the `ListingTableProvider` will not see -those changes until the `ListingTableProvider` instance is dropped and recreated. - -You can configure the maximum cache size and cache entry expiration time via configuration options: - -- `datafusion.runtime.list_files_cache_limit` - Limits the size of the cache in bytes -- `datafusion.runtime.list_files_cache_ttl` - Limits the TTL (time-to-live) of an entry in seconds - -Detailed configuration information can be found in the [DataFusion Runtime -Configuration](https://datafusion.apache.org/user-guide/configs.html#runtime-configuration-settings) user's guide. 
- -Caching can be disabled by setting the limit to 0: - -```sql -SET datafusion.runtime.list_files_cache_limit TO "0K"; -``` - -Note that the internal API has changed to use a trait `ListFilesCache` instead of a type alias. - -### `newlines_in_values` moved from `FileScanConfig` to `CsvOptions` - -The CSV-specific `newlines_in_values` configuration option has been moved from `FileScanConfig` to `CsvOptions`, as it only applies to CSV file parsing. - -**Who is affected:** - -- Users who set `newlines_in_values` via `FileScanConfigBuilder::with_newlines_in_values()` - -**Migration guide:** - -Set `newlines_in_values` in `CsvOptions` instead of on `FileScanConfigBuilder`: - -**Before:** - -```rust,ignore -let source = Arc::new(CsvSource::new(file_schema.clone())); -let config = FileScanConfigBuilder::new(object_store_url, source) - .with_newlines_in_values(true) - .build(); -``` - -**After:** - -```rust,ignore -let options = CsvOptions { - newlines_in_values: Some(true), - ..Default::default() -}; -let source = Arc::new(CsvSource::new(file_schema.clone()) - .with_csv_options(options)); -let config = FileScanConfigBuilder::new(object_store_url, source) - .build(); -``` - -### Removal of `pyarrow` feature - -The `pyarrow` feature flag has been removed. This feature has been migrated to -the `datafusion-python` repository since version `44.0.0`. - -### Refactoring of `FileSource` constructors and `FileScanConfigBuilder` to accept schemas upfront - -The way schemas are passed to file sources and scan configurations has been significantly refactored. File sources now require the schema (including partition columns) to be provided at construction time, and `FileScanConfigBuilder` no longer takes a separate schema parameter. - -**Who is affected:** - -- Users who create `FileScanConfig` or file sources (`ParquetSource`, `CsvSource`, `JsonSource`, `AvroSource`) directly -- Users who implement custom `FileFormat` implementations - -**Key changes:** - -1. 
**FileSource constructors now require TableSchema**: All built-in file sources now take the schema in their constructor: - - ```diff - - let source = ParquetSource::default(); - + let source = ParquetSource::new(table_schema); - ``` - -2. **FileScanConfigBuilder no longer takes schema as a parameter**: The schema is now passed via the FileSource: - - ```diff - - FileScanConfigBuilder::new(url, schema, source) - + FileScanConfigBuilder::new(url, source) - ``` - -3. **Partition columns are now part of TableSchema**: The `with_table_partition_cols()` method has been removed from `FileScanConfigBuilder`. Partition columns are now passed as part of the `TableSchema` to the FileSource constructor: - - ```diff - + let table_schema = TableSchema::new( - + file_schema, - + vec![Arc::new(Field::new("date", DataType::Utf8, false))], - + ); - + let source = ParquetSource::new(table_schema); - let config = FileScanConfigBuilder::new(url, source) - - .with_table_partition_cols(vec![Field::new("date", DataType::Utf8, false)]) - .with_file(partitioned_file) - .build(); - ``` - -4. 
**FileFormat::file_source() now takes TableSchema parameter**: Custom `FileFormat` implementations must be updated: - ```diff - impl FileFormat for MyFileFormat { - - fn file_source(&self) -> Arc { - + fn file_source(&self, table_schema: TableSchema) -> Arc { - - Arc::new(MyFileSource::default()) - + Arc::new(MyFileSource::new(table_schema)) - } - } - ``` - -**Migration examples:** - -For Parquet files: - -```diff -- let source = Arc::new(ParquetSource::default()); -- let config = FileScanConfigBuilder::new(url, schema, source) -+ let table_schema = TableSchema::new(schema, vec![]); -+ let source = Arc::new(ParquetSource::new(table_schema)); -+ let config = FileScanConfigBuilder::new(url, source) - .with_file(partitioned_file) - .build(); -``` - -For CSV files with partition columns: - -```diff -- let source = Arc::new(CsvSource::new(true, b',', b'"')); -- let config = FileScanConfigBuilder::new(url, file_schema, source) -- .with_table_partition_cols(vec![Field::new("year", DataType::Int32, false)]) -+ let options = CsvOptions { -+ has_header: Some(true), -+ delimiter: b',', -+ quote: b'"', -+ ..Default::default() -+ }; -+ let table_schema = TableSchema::new( -+ file_schema, -+ vec![Arc::new(Field::new("year", DataType::Int32, false))], -+ ); -+ let source = Arc::new(CsvSource::new(table_schema).with_csv_options(options)); -+ let config = FileScanConfigBuilder::new(url, source) - .build(); -``` - -### Adaptive filter representation in Parquet filter pushdown - -As of Arrow 57.1.0, DataFusion uses a new adaptive filter strategy when -evaluating pushed down filters for Parquet files. This new strategy improves -performance for certain types of queries where the results of filtering are -more efficiently represented with a bitmask rather than a selection. -See [arrow-rs #5523] for more details. - -This change only applies to the built-in Parquet data source with filter-pushdown enabled ( -which is [not yet the default behavior]). 
- -You can disable the new behavior by setting the -`datafusion.execution.parquet.force_filter_selections` [configuration setting] to true. - -```sql -> set datafusion.execution.parquet.force_filter_selections = true; -``` - -[arrow-rs #5523]: https://github.com/apache/arrow-rs/issues/5523 -[configuration setting]: https://datafusion.apache.org/user-guide/configs.html -[not yet the default behavior]: https://github.com/apache/datafusion/issues/3463 - -### Statistics handling moved from `FileSource` to `FileScanConfig` - -Statistics are now managed directly by `FileScanConfig` instead of being delegated to `FileSource` implementations. This simplifies the `FileSource` trait and provides more consistent statistics handling across all file formats. - -**Who is affected:** - -- Users who have implemented custom `FileSource` implementations - -**Breaking changes:** - -Two methods have been removed from the `FileSource` trait: - -- `with_statistics(&self, statistics: Statistics) -> Arc` -- `statistics(&self) -> Result` - -**Migration guide:** - -If you have a custom `FileSource` implementation, you need to: - -1. Remove the `with_statistics` method implementation -2. Remove the `statistics` method implementation -3. Remove any internal state that was storing statistics - -**Before:** - -```rust,ignore -#[derive(Clone)] -struct MyCustomSource { - table_schema: TableSchema, - projected_statistics: Option, - // other fields... -} - -impl FileSource for MyCustomSource { - fn with_statistics(&self, statistics: Statistics) -> Arc { - Arc::new(Self { - table_schema: self.table_schema.clone(), - projected_statistics: Some(statistics), - // other fields... - }) - } - - fn statistics(&self) -> Result { - Ok(self.projected_statistics.clone().unwrap_or_else(|| - Statistics::new_unknown(self.table_schema.file_schema()) - )) - } - - // other methods... 
-} -``` - -**After:** - -```rust,ignore -#[derive(Clone)] -struct MyCustomSource { - table_schema: TableSchema, - // projected_statistics field removed - // other fields... -} - -impl FileSource for MyCustomSource { - // with_statistics method removed - // statistics method removed - - // other methods... -} -``` - -**Accessing statistics:** - -Statistics are now accessed through `FileScanConfig` instead of `FileSource`: - -```diff -- let stats = config.file_source.statistics()?; -+ let stats = config.statistics(); -``` - -Note that `FileScanConfig::statistics()` automatically marks statistics as inexact when filters are present, ensuring correctness when filters are pushed down. - -### Partition column handling moved out of `PhysicalExprAdapter` - -Partition column replacement is now a separate preprocessing step performed before expression rewriting via `PhysicalExprAdapter`. This change provides better separation of concerns and makes the adapter more focused on schema differences rather than partition value substitution. - -**Who is affected:** - -- Users who have custom implementations of `PhysicalExprAdapterFactory` that handle partition columns -- Users who directly use the `FilePruner` API - -**Breaking changes:** - -1. `FilePruner::try_new()` signature changed: the `partition_fields` parameter has been removed since partition column handling is now done separately -2. 
Partition column replacement must now be done via `replace_columns_with_literals()` before expressions are passed to the adapter - -**Migration guide:** - -If you have code that creates a `FilePruner` with partition fields: - -**Before:** - -```rust,ignore -use datafusion_pruning::FilePruner; - -let pruner = FilePruner::try_new( - predicate, - file_schema, - partition_fields, // This parameter is removed - file_stats, -)?; -``` - -**After:** - -```rust,ignore -use datafusion_pruning::FilePruner; - -// Partition fields are no longer needed -let pruner = FilePruner::try_new( - predicate, - file_schema, - file_stats, -)?; -``` - -If you have custom code that relies on `PhysicalExprAdapter` to handle partition columns, you must now call `replace_columns_with_literals()` separately: - -**Before:** - -```rust,ignore -// Adapter handled partition column replacement internally -let adapted_expr = adapter.rewrite(expr)?; -``` - -**After:** - -```rust,ignore -use datafusion_physical_expr_adapter::replace_columns_with_literals; - -// Replace partition columns first -let expr_with_literals = replace_columns_with_literals(expr, &partition_values)?; -// Then apply the adapter -let adapted_expr = adapter.rewrite(expr_with_literals)?; -``` - -### `build_row_filter` signature simplified - -The `build_row_filter` function in `datafusion-datasource-parquet` has been simplified to take a single schema parameter instead of two. -The expectation is now that the filter has been adapted to the physical file schema (the arrow representation of the parquet file's schema) before being passed to this function -using a `PhysicalExprAdapter` for example. 
- -**Who is affected:** - -- Users who call `build_row_filter` directly - -**Breaking changes:** - -The function signature changed from: - -```rust,ignore -pub fn build_row_filter( - expr: &Arc, - physical_file_schema: &SchemaRef, - predicate_file_schema: &SchemaRef, // removed - metadata: &ParquetMetaData, - reorder_predicates: bool, - file_metrics: &ParquetFileMetrics, -) -> Result> -``` - -To: - -```rust,ignore -pub fn build_row_filter( - expr: &Arc, - file_schema: &SchemaRef, - metadata: &ParquetMetaData, - reorder_predicates: bool, - file_metrics: &ParquetFileMetrics, -) -> Result> -``` - -**Migration guide:** - -Remove the duplicate schema parameter from your call: - -```diff -- build_row_filter(&predicate, &file_schema, &file_schema, metadata, reorder, metrics) -+ build_row_filter(&predicate, &file_schema, metadata, reorder, metrics) -``` - -### Planner now requires explicit opt-in for WITHIN GROUP syntax - -The SQL planner now enforces the aggregate UDF contract more strictly: the -`WITHIN GROUP (ORDER BY ...)` syntax is accepted only if the aggregate UDAF -explicitly advertises support by returning `true` from -`AggregateUDFImpl::supports_within_group_clause()`. - -Previously the planner forwarded a `WITHIN GROUP` clause to order-sensitive -aggregates even when they did not implement ordered-set semantics, which could -cause queries such as `SUM(x) WITHIN GROUP (ORDER BY x)` to plan successfully. -This behavior was too permissive and has been changed to match PostgreSQL and -the documented semantics. - -Migration: If your UDAF intentionally implements ordered-set semantics and -wants to accept the `WITHIN GROUP` SQL syntax, update your implementation to -return `true` from `supports_within_group_clause()` and handle the ordering -semantics in your accumulator implementation. 
If your UDAF is merely -order-sensitive (but not an ordered-set aggregate), do not advertise -`supports_within_group_clause()` and clients should use alternative function -signatures (for example, explicit ordering as a function argument) instead. - -### `AggregateUDFImpl::supports_null_handling_clause` now defaults to `false` - -This method specifies whether an aggregate function allows `IGNORE NULLS`/`RESPECT NULLS` -during SQL parsing, with the implication it respects these configs during computation. - -Most DataFusion aggregate functions silently ignored this syntax in prior versions -as they did not make use of it and it was permitted by default. We change this so -only the few functions which do respect this clause (e.g. `array_agg`, `first_value`, -`last_value`) need to implement it. - -Custom user defined aggregate functions will also error if this syntax is used, -unless they explicitly declare support by overriding the method. - -For example, SQL parsing will now fail for queries such as this: - -```sql -SELECT median(c1) IGNORE NULLS FROM table -``` - -Instead of silently succeeding. - -### API change for `CacheAccessor` trait - -The remove API no longer requires a mutable instance - -### FFI crate updates - -Many of the structs in the `datafusion-ffi` crate have been updated to allow easier -conversion to the underlying trait types they represent. This simplifies some code -paths, but also provides an additional improvement in cases where library code goes -through a round trip via the foreign function interface. - -To update your code, suppose you have a `FFI_SchemaProvider` called `ffi_provider` -and you wish to use this as a `SchemaProvider`. 
In the old approach you would do -something like: - -```rust,ignore - let foreign_provider: ForeignSchemaProvider = ffi_provider.into(); - let foreign_provider = Arc::new(foreign_provider) as Arc; -``` - -This code should now be written as: - -```rust,ignore - let foreign_provider: Arc = ffi_provider.into(); - let foreign_provider = foreign_provider as Arc; -``` - -For the case of user defined functions, the updates are similar but you -may need to change the way you call the creation of the `ScalarUDF`. -Aggregate and window functions follow the same pattern. - -Previously you may write: - -```rust,ignore - let foreign_udf: ForeignScalarUDF = ffi_udf.try_into()?; - let foreign_udf: ScalarUDF = foreign_udf.into(); -``` - -Instead this should now be: - -```rust,ignore - let foreign_udf: Arc = ffi_udf.into(); - let foreign_udf = ScalarUDF::new_from_shared_impl(foreign_udf); -``` - -When creating any of the following structs, we now require the user to -provide a `TaskContextProvider` and optionally a `LogicalExtensionCodec`: - -- `FFI_CatalogListProvider` -- `FFI_CatalogProvider` -- `FFI_SchemaProvider` -- `FFI_TableProvider` -- `FFI_TableFunction` - -Each of these structs has a `new()` and a `new_with_ffi_codec()` method for -instantiation. For example, when you previously would write - -```rust,ignore - let table = Arc::new(MyTableProvider::new()); - let ffi_table = FFI_TableProvider::new(table, None); -``` - -Now you will need to provide a `TaskContextProvider`. The most common -implementation of this trait is `SessionContext`. - -```rust,ignore - let ctx = Arc::new(SessionContext::default()); - let table = Arc::new(MyTableProvider::new()); - let ffi_table = FFI_TableProvider::new(table, None, ctx, None); -``` - -The alternative function to create these structures may be more convenient -if you are doing many of these operations. A `FFI_LogicalExtensionCodec` will -store the `TaskContextProvider` as well. 
- -```rust,ignore - let codec = Arc::new(DefaultLogicalExtensionCodec {}); - let ctx = Arc::new(SessionContext::default()); - let ffi_codec = FFI_LogicalExtensionCodec::new(codec, None, ctx); - let table = Arc::new(MyTableProvider::new()); - let ffi_table = FFI_TableProvider::new_with_ffi_codec(table, None, ffi_codec); -``` - -Additional information about the usage of the `TaskContextProvider` can be -found in the crate README. - -Additionally, the FFI structure for Scalar UDF's no longer contains a -`return_type` call. This code was not used since the `ForeignScalarUDF` -struct implements the `return_field_from_args` instead. - -### Projection handling moved from FileScanConfig to FileSource - -Projection handling has been moved from `FileScanConfig` into `FileSource` implementations. This enables format-specific projection pushdown (e.g., Parquet can push down struct field access, Vortex can push down computed expressions into un-decoded data). - -**Who is affected:** - -- Users who have implemented custom `FileSource` implementations -- Users who use `FileScanConfigBuilder::with_projection_indices` directly - -**Breaking changes:** - -1. **`FileSource::with_projection` replaced with `try_pushdown_projection`:** - - The `with_projection(&self, config: &FileScanConfig) -> Arc` method has been removed and replaced with `try_pushdown_projection(&self, projection: &ProjectionExprs) -> Result>>`. - -2. **`FileScanConfig.projection_exprs` field removed:** - - Projections are now stored in the `FileSource` directly, not in `FileScanConfig`. - Various public helper methods that access projection information have been removed from `FileScanConfig`. - -3. **`FileScanConfigBuilder::with_projection_indices` now returns `Result`:** - - This method can now fail if the projection pushdown fails. - -4. **`FileSource::create_file_opener` now returns `Result>`:** - - Previously returned `Arc` directly. 
- Any `FileSource` implementation that may fail to create a `FileOpener` should now return an appropriate error. - -5. **`DataSource::try_swapping_with_projection` signature changed:** - - Parameter changed from `&[ProjectionExpr]` to `&ProjectionExprs`. - -**Migration guide:** - -If you have a custom `FileSource` implementation: - -**Before:** - -```rust,ignore -impl FileSource for MyCustomSource { - fn with_projection(&self, config: &FileScanConfig) -> Arc { - // Apply projection from config - Arc::new(Self { /* ... */ }) - } - - fn create_file_opener( - &self, - object_store: Arc, - base_config: &FileScanConfig, - partition: usize, - ) -> Arc { - Arc::new(MyOpener { /* ... */ }) - } -} -``` - -**After:** - -```rust,ignore -impl FileSource for MyCustomSource { - fn try_pushdown_projection( - &self, - projection: &ProjectionExprs, - ) -> Result>> { - // Return None if projection cannot be pushed down - // Return Some(new_source) with projection applied if it can - Ok(Some(Arc::new(Self { - projection: Some(projection.clone()), - /* ... */ - }))) - } - - fn projection(&self) -> Option<&ProjectionExprs> { - self.projection.as_ref() - } - - fn create_file_opener( - &self, - object_store: Arc, - base_config: &FileScanConfig, - partition: usize, - ) -> Result> { - Ok(Arc::new(MyOpener { /* ... */ })) - } -} -``` - -We recommend you look at [#18627](https://github.com/apache/datafusion/pull/18627) -that introduced these changes for more examples for how this was handled for the various built in file sources. - -We have added [`SplitProjection`](https://docs.rs/datafusion-datasource/latest/datafusion_datasource/projection/struct.SplitProjection.html) and [`ProjectionOpener`](https://docs.rs/datafusion-datasource/latest/datafusion_datasource/projection/struct.ProjectionOpener.html) helpers to make it easier to handle projections in your `FileSource` implementations. 
- -For file sources that can only handle simple column selections (not computed expressions), use the `SplitProjection` and `ProjectionOpener` helpers to split the projection into pushdownable and non-pushdownable parts: - -```rust,ignore -use datafusion_datasource::projection::{SplitProjection, ProjectionOpener}; - -// In try_pushdown_projection: -let split = SplitProjection::new(projection, self.table_schema())?; -// Use split.file_projection() for what to push down to the file format -// The ProjectionOpener wrapper will handle the rest -``` - -**For `FileScanConfigBuilder` users:** - -```diff -let config = FileScanConfigBuilder::new(url, source) -- .with_projection_indices(Some(vec![0, 2, 3])) -+ .with_projection_indices(Some(vec![0, 2, 3]))? - .build(); -``` - -### `SchemaAdapter` and `SchemaAdapterFactory` completely removed - -Following the deprecation announced in [DataFusion 49.0.0](#deprecating-schemaadapterfactory-and-schemaadapter), `SchemaAdapterFactory` has been fully removed from Parquet scanning. This applies to both: - -The following symbols have been deprecated and will be removed in the next release: - -- `SchemaAdapter` trait -- `SchemaAdapterFactory` trait -- `SchemaMapper` trait -- `SchemaMapping` struct -- `DefaultSchemaAdapterFactory` struct - -These types were previously used to adapt record batch schemas during file reading. -This functionality has been replaced by `PhysicalExprAdapterFactory`, which rewrites expressions at planning time rather than transforming batches at runtime. -If you were using a custom `SchemaAdapterFactory` for schema adaptation (e.g., default column values, type coercion), you should now implement `PhysicalExprAdapterFactory` instead. -See the [default column values example](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/custom_data_source/default_column_values.rs) for how to implement a custom `PhysicalExprAdapterFactory`. 
- -**Migration guide:** - -If you implemented a custom `SchemaAdapterFactory`, migrate to `PhysicalExprAdapterFactory`. -See the [default column values example](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/custom_data_source/default_column_values.rs) for a complete implementation. - -## DataFusion `51.0.0` - -### `arrow` / `parquet` updated to 57.0.0 - -### Upgrade to arrow `57.0.0` and parquet `57.0.0` - -This version of DataFusion upgrades the underlying Apache Arrow implementation -to version `57.0.0`, including several dependent crates such as `prost`, -`tonic`, `pyo3`, and `substrait`. . See the [release -notes](https://github.com/apache/arrow-rs/releases/tag/57.0.0) for more details. - -### `MSRV` updated to 1.88.0 - -The Minimum Supported Rust Version (MSRV) has been updated to [`1.88.0`]. - -[`1.88.0`]: https://releases.rs/docs/1.88.0/ - -### `FunctionRegistry` exposes two additional methods - -`FunctionRegistry` exposes two additional methods `udafs` and `udwfs` which expose set of registered user defined aggregation and window function names. To upgrade implement methods returning set of registered function names: - -```diff -impl FunctionRegistry for FunctionRegistryImpl { - fn udfs(&self) -> HashSet { - self.scalar_functions.keys().cloned().collect() - } -+ fn udafs(&self) -> HashSet { -+ self.aggregate_functions.keys().cloned().collect() -+ } -+ -+ fn udwfs(&self) -> HashSet { -+ self.window_functions.keys().cloned().collect() -+ } -} -``` - -### `datafusion-proto` use `TaskContext` rather than `SessionContext` in physical plan serde methods - -There have been changes in the public API methods of `datafusion-proto` which handle physical plan serde. 
- -Methods like `physical_plan_from_bytes`, `parse_physical_expr` and similar, expect `TaskContext` instead of `SessionContext` - -```diff -- let plan2 = physical_plan_from_bytes(&bytes, &ctx)?; -+ let plan2 = physical_plan_from_bytes(&bytes, &ctx.task_ctx())?; -``` - -as `TaskContext` contains `RuntimeEnv` methods such as `try_into_physical_plan` will not have explicit `RuntimeEnv` parameter. - -```diff -let result_exec_plan: Arc = proto -- .try_into_physical_plan(&ctx, runtime.deref(), &composed_codec) -+. .try_into_physical_plan(&ctx.task_ctx(), &composed_codec) -``` - -`PhysicalExtensionCodec::try_decode()` expects `TaskContext` instead of `FunctionRegistry`: - -```diff -pub trait PhysicalExtensionCodec { - fn try_decode( - &self, - buf: &[u8], - inputs: &[Arc], -- registry: &dyn FunctionRegistry, -+ ctx: &TaskContext, - ) -> Result>; -``` - -See [issue #17601] for more details. - -[issue #17601]: https://github.com/apache/datafusion/issues/17601 - -### `SessionState`'s `sql_to_statement` method takes `Dialect` rather than a `str` - -The `dialect` parameter of `sql_to_statement` method defined in `datafusion::execution::session_state::SessionState` -has changed from `&str` to `&Dialect`. -`Dialect` is an enum defined in the `datafusion-common` -crate under the `config` module that provides type safety -and better validation for SQL dialect selection - -### Reorganization of `ListingTable` into `datafusion-catalog-listing` crate - -There has been a long standing request to remove features such as `ListingTable` -from the `datafusion` crate to support faster build times. The structs -`ListingOptions`, `ListingTable`, and `ListingTableConfig` are now available -within the `datafusion-catalog-listing` crate. These are re-exported in -the `datafusion` crate, so this should be a minimal impact to existing users. - -See [issue #14462] and [issue #17713] for more details. 
- -[issue #14462]: https://github.com/apache/datafusion/issues/14462 -[issue #17713]: https://github.com/apache/datafusion/issues/17713 - -### Reorganization of `ArrowSource` into `datafusion-datasource-arrow` crate - -To support [issue #17713] the `ArrowSource` code has been removed from -the `datafusion` core crate into it's own crate, `datafusion-datasource-arrow`. -This follows the pattern for the AVRO, CSV, JSON, and Parquet data sources. -Users may need to update their paths to account for these changes. - -See [issue #17713] for more details. - -### `FileScanConfig::projection` renamed to `FileScanConfig::projection_exprs` - -The `projection` field in `FileScanConfig` has been renamed to `projection_exprs` and its type has changed from `Option>` to `Option`. This change enables more powerful projection pushdown capabilities by supporting arbitrary physical expressions rather than just column indices. - -**Impact on direct field access:** - -If you directly access the `projection` field: - -```rust,ignore -let config: FileScanConfig = ...; -let projection = config.projection; -``` - -You should update to: - -```rust,ignore -let config: FileScanConfig = ...; -let projection_exprs = config.projection_exprs; -``` - -**Impact on builders:** - -The `FileScanConfigBuilder::with_projection()` method has been deprecated in favor of `with_projection_indices()`: - -```diff -let config = FileScanConfigBuilder::new(url, file_source) -- .with_projection(Some(vec![0, 2, 3])) -+ .with_projection_indices(Some(vec![0, 2, 3])) - .build(); -``` - -Note: `with_projection()` still works but is deprecated and will be removed in a future release. - -**What is `ProjectionExprs`?** - -`ProjectionExprs` is a new type that represents a list of physical expressions for projection. 
While it can be constructed from column indices (which is what `with_projection_indices` does internally), it also supports arbitrary physical expressions, enabling advanced features like expression evaluation during scanning. - -You can access column indices from `ProjectionExprs` using its methods if needed: - -```rust,ignore -let projection_exprs: ProjectionExprs = ...; -// Get the column indices if the projection only contains simple column references -let indices = projection_exprs.column_indices(); -``` - -### `DESCRIBE query` support - -`DESCRIBE query` was previously an alias for `EXPLAIN query`, which outputs the -_execution plan_ of the query. With this release, `DESCRIBE query` now outputs -the computed _schema_ of the query, consistent with the behavior of `DESCRIBE table_name`. - -### `datafusion.execution.time_zone` default configuration changed - -The default value for `datafusion.execution.time_zone` previously was a string value of `+00:00` (GMT/Zulu time). -This was changed to be an `Option` with a default of `None`. If you want to change the timezone back -to the previous value you can execute the sql: - -```sql -SET -TIMEZONE = '+00:00'; -``` - -This change was made to better support using the default timezone in scalar UDF functions such as -`now`, `current_date`, `current_time`, and `to_timestamp` among others. - -### Introduction of `TableSchema` and changes to `FileSource::with_schema()` method - -A new `TableSchema` struct has been introduced in the `datafusion-datasource` crate to better manage table schemas with partition columns. 
This struct helps distinguish between: - -- **File schema**: The schema of actual data files on disk -- **Partition columns**: Columns derived from directory structure (e.g., Hive-style partitioning) -- **Table schema**: The complete schema combining both file and partition columns - -As part of this change, the `FileSource::with_schema()` method signature has changed from accepting a `SchemaRef` to accepting a `TableSchema`. - -**Who is affected:** - -- Users who have implemented custom `FileSource` implementations will need to update their code -- Users who only use built-in file sources (Parquet, CSV, JSON, AVRO, Arrow) are not affected - -**Migration guide for custom `FileSource` implementations:** - -```diff - use datafusion_datasource::file::FileSource; --use arrow::datatypes::SchemaRef; -+use datafusion_datasource::TableSchema; - - impl FileSource for MyCustomSource { -- fn with_schema(&self, schema: SchemaRef) -> Arc { -+ fn with_schema(&self, schema: TableSchema) -> Arc { - Arc::new(Self { -- schema: Some(schema), -+ // Use schema.file_schema() to get the file schema without partition columns -+ schema: Some(Arc::clone(schema.file_schema())), - ..self.clone() - }) - } - } -``` - -For implementations that need access to partition columns: - -```rust,ignore -fn with_schema(&self, schema: TableSchema) -> Arc { - Arc::new(Self { - file_schema: Arc::clone(schema.file_schema()), - partition_cols: schema.table_partition_cols().clone(), - table_schema: Arc::clone(schema.table_schema()), - ..self.clone() - }) -} -``` - -**Note**: Most `FileSource` implementations only need to store the file schema (without partition columns), as shown in the first example. 
The second pattern of storing all three schema components is typically only needed for advanced use cases where you need access to different schema representations for different operations (e.g., ParquetSource uses the file schema for building pruning predicates but needs the table schema for filter pushdown logic). - -**Using `TableSchema` directly:** - -If you're constructing a `FileScanConfig` or working with table schemas and partition columns, you can now use `TableSchema`: - -```rust -use datafusion_datasource::TableSchema; -use arrow::datatypes::{Schema, Field, DataType}; -use std::sync::Arc; - -// Create a TableSchema with partition columns -let file_schema = Arc::new(Schema::new(vec![ - Field::new("user_id", DataType::Int64, false), - Field::new("amount", DataType::Float64, false), -])); - -let partition_cols = vec![ - Arc::new(Field::new("date", DataType::Utf8, false)), - Arc::new(Field::new("region", DataType::Utf8, false)), -]; - -let table_schema = TableSchema::new(file_schema, partition_cols); - -// Access different schema representations -let file_schema_ref = table_schema.file_schema(); // Schema without partition columns -let full_schema = table_schema.table_schema(); // Complete schema with partition columns -let partition_cols_ref = table_schema.table_partition_cols(); // Just the partition columns -``` - -### `AggregateUDFImpl::is_ordered_set_aggregate` has been renamed to `AggregateUDFImpl::supports_within_group_clause` - -This method has been renamed to better reflect the actual impact it has for aggregate UDF implementations. -The accompanying `AggregateUDF::is_ordered_set_aggregate` has also been renamed to `AggregateUDF::supports_within_group_clause`. -No functionality has been changed with regards to this method; it still refers only to permitting use of `WITHIN GROUP` -SQL syntax for the aggregate function. 
- -## DataFusion `50.0.0` - -### ListingTable automatically detects Hive Partitioned tables - -DataFusion 50.0.0 automatically infers Hive partitions when using the `ListingTableFactory` and `CREATE EXTERNAL TABLE`. Previously, -when creating a `ListingTable`, datasets that use Hive partitioning (e.g. -`/table_root/column1=value1/column2=value2/data.parquet`) would not have the Hive columns reflected in -the table's schema or data. The previous behavior can be -restored by setting the `datafusion.execution.listing_table_factory_infer_partitions` configuration option to `false`. -See [issue #17049] for more details. - -[issue #17049]: https://github.com/apache/datafusion/issues/17049 - -### `MSRV` updated to 1.86.0 - -The Minimum Supported Rust Version (MSRV) has been updated to [`1.86.0`]. -See [#17230] for details. - -[`1.86.0`]: https://releases.rs/docs/1.86.0/ -[#17230]: https://github.com/apache/datafusion/pull/17230 - -### `ScalarUDFImpl`, `AggregateUDFImpl` and `WindowUDFImpl` traits now require `PartialEq`, `Eq`, and `Hash` traits - -To address error-proneness of `ScalarUDFImpl::equals`, `AggregateUDFImpl::equals` and -`WindowUDFImpl::equals` methods and to make it easy to implement function equality correctly, -the `equals` and `hash_value` methods have been removed from `ScalarUDFImpl`, `AggregateUDFImpl` -and `WindowUDFImpl` traits. They are replaced by the requirement to implement the `PartialEq`, `Eq`, -and `Hash` traits on any type implementing `ScalarUDFImpl`, `AggregateUDFImpl` or `WindowUDFImpl`. -Please see [issue #16677] for more details. - -Most of the scalar functions are stateless and have a `signature` field. These can be migrated -using regular expressions - -- search for `\#\[derive\(Debug\)\](\n *(pub )?struct \w+ \{\n *signature\: Signature\,\n *\})`, -- replace with `#[derive(Debug, PartialEq, Eq, Hash)]$1`, -- review all the changes and make sure only function structs were changed.
- -[issue #16677]: https://github.com/apache/datafusion/issues/16677 - -### `AsyncScalarUDFImpl::invoke_async_with_args` returns `ColumnarValue` - -In order to enable single value optimizations and be consistent with other -user defined function APIs, the `AsyncScalarUDFImpl::invoke_async_with_args` method now -returns a `ColumnarValue` instead of a `ArrayRef`. - -To upgrade, change the return type of your implementation - -```rust -# /* comment to avoid running -impl AsyncScalarUDFImpl for AskLLM { - async fn invoke_async_with_args( - &self, - args: ScalarFunctionArgs, - _option: &ConfigOptions, - ) -> Result { - .. - return array_ref; // old code - } -} -# */ -``` - -To return a `ColumnarValue` - -```rust -# /* comment to avoid running -impl AsyncScalarUDFImpl for AskLLM { - async fn invoke_async_with_args( - &self, - args: ScalarFunctionArgs, - _option: &ConfigOptions, - ) -> Result { - .. - return ColumnarValue::from(array_ref); // new code - } -} -# */ -``` - -See [#16896](https://github.com/apache/datafusion/issues/16896) for more details. - -### `ProjectionExpr` changed from type alias to struct - -`ProjectionExpr` has been changed from a type alias to a struct with named fields to improve code clarity and maintainability. - -**Before:** - -```rust,ignore -pub type ProjectionExpr = (Arc, String); -``` - -**After:** - -```rust,ignore -#[derive(Debug, Clone)] -pub struct ProjectionExpr { - pub expr: Arc, - pub alias: String, -} -``` - -To upgrade your code: - -- Replace tuple construction `(expr, alias)` with `ProjectionExpr::new(expr, alias)` or `ProjectionExpr { expr, alias }` -- Replace tuple field access `.0` and `.1` with `.expr` and `.alias` -- Update pattern matching from `(expr, alias)` to `ProjectionExpr { expr, alias }` - -This mainly impacts use of `ProjectionExec`. 
- -This change was done in [#17398] - -[#17398]: https://github.com/apache/datafusion/pull/17398 - -### `SessionState`, `SessionConfig`, and `OptimizerConfig` return `&Arc` instead of `&ConfigOptions` - -To provide broader access to `ConfigOptions` and reduce required clones, some -APIs have been changed to return a `&Arc` instead of a -`&ConfigOptions`. This allows sharing the same `ConfigOptions` across multiple -threads without needing to clone the entire `ConfigOptions` structure unless it -is modified. - -Most users will not be impacted by this change since the Rust compiler typically -automatically dereferences the `Arc` when needed. However, in some cases you may -have to change your code to explicitly call `as_ref()` for example, from - -```rust -# /* comment to avoid running -let optimizer_config: &ConfigOptions = state.options(); -# */ -``` - -To - -```rust -# /* comment to avoid running -let optimizer_config: &ConfigOptions = state.options().as_ref(); -# */ -``` - -See PR [#16970](https://github.com/apache/datafusion/pull/16970) - -### API Change to `AsyncScalarUDFImpl::invoke_async_with_args` - -The `invoke_async_with_args` method of the `AsyncScalarUDFImpl` trait has been -updated to remove the `_option: &ConfigOptions` parameter to simplify the API -now that the `ConfigOptions` can be accessed through the `ScalarFunctionArgs` -parameter. - -You can change your code like this - -```rust -# /* comment to avoid running -impl AsyncScalarUDFImpl for AskLLM { - async fn invoke_async_with_args( - &self, - args: ScalarFunctionArgs, - _option: &ConfigOptions, - ) -> Result { - .. - } - ... -} -# */ -``` - -To this: - -```rust -# /* comment to avoid running - -impl AsyncScalarUDFImpl for AskLLM { - async fn invoke_async_with_args( - &self, - args: ScalarFunctionArgs, - ) -> Result { - let options = &args.config_options; - .. - } - ...
-} -# */ -``` - -### Schema Rewriter Module Moved to New Crate - -The `schema_rewriter` module and its associated symbols have been moved from `datafusion_physical_expr` to a new crate `datafusion_physical_expr_adapter`. This affects the following symbols: - -- `DefaultPhysicalExprAdapter` -- `DefaultPhysicalExprAdapterFactory` -- `PhysicalExprAdapter` -- `PhysicalExprAdapterFactory` - -To upgrade, change your imports to: - -```rust -use datafusion_physical_expr_adapter::{ - DefaultPhysicalExprAdapter, DefaultPhysicalExprAdapterFactory, - PhysicalExprAdapter, PhysicalExprAdapterFactory -}; -``` - -### Upgrade to arrow `56.0.0` and parquet `56.0.0` - -This version of DataFusion upgrades the underlying Apache Arrow implementation -to version `56.0.0`. See the [release notes](https://github.com/apache/arrow-rs/releases/tag/56.0.0) -for more details. - -### Added `ExecutionPlan::reset_state` - -In order to fix a bug in DataFusion `49.0.0` where dynamic filters (currently only generated in the presence of a query such as `ORDER BY ... LIMIT ...`) -produced incorrect results in recursive queries, a new method `reset_state` has been added to the `ExecutionPlan` trait. - -Any `ExecutionPlan` that needs to maintain internal state or references to other nodes in the execution plan tree should implement this method to reset that state. -See [#17028] for more details and an example implementation for `SortExec`. - -[#17028]: https://github.com/apache/datafusion/pull/17028 - -### Nested Loop Join input sort order cannot be preserved - -The Nested Loop Join operator has been rewritten from scratch to improve performance and memory efficiency. From the micro-benchmarks: this change introduces up to 5X speed-up and uses only 1% memory in extreme cases compared to the previous implementation. - -However, the new implementation cannot preserve input sort order like the old version could. 
This is a fundamental design trade-off that prioritizes performance and memory efficiency over sort order preservation. - -See [#16996] for details. - -[#16996]: https://github.com/apache/datafusion/pull/16996 - -### Add `as_any()` method to `LazyBatchGenerator` - -To help with protobuf serialization, the `as_any()` method has been added to the `LazyBatchGenerator` trait. This means you will need to add `as_any()` to your implementation of `LazyBatchGenerator`: - -```rust -# /* comment to avoid running - -impl LazyBatchGenerator for MyBatchGenerator { - fn as_any(&self) -> &dyn Any { - self - } - - ... -} - -# */ -``` - -See [#17200](https://github.com/apache/datafusion/pull/17200) for details. - -### Refactored `DataSource::try_swapping_with_projection` - -We refactored `DataSource::try_swapping_with_projection` to simplify the method and minimize leakage across the ExecutionPlan <-> DataSource abstraction layer. -Reimplementation for any custom `DataSource` should be relatively straightforward, see [#17395] for more details. - -[#17395]: https://github.com/apache/datafusion/pull/17395/ - -### `FileOpenFuture` now uses `DataFusionError` instead of `ArrowError` - -The `FileOpenFuture` type alias has been updated to use `DataFusionError` instead of `ArrowError` for its error type. This change affects the `FileOpener` trait and any implementations that work with file streaming operations. - -**Before:** - -```rust,ignore -pub type FileOpenFuture = BoxFuture<'static, Result>>>; -``` - -**After:** - -```rust,ignore -pub type FileOpenFuture = BoxFuture<'static, Result>>>; -``` - -If you have custom implementations of `FileOpener` or work directly with `FileOpenFuture`, you'll need to update your error handling to use `DataFusionError` instead of `ArrowError`. The `FileStreamState` enum's `Open` variant has also been updated accordingly. See [#17397] for more details. 
- -[#17397]: https://github.com/apache/datafusion/pull/17397 - -### FFI user defined aggregate function signature change - -The Foreign Function Interface (FFI) signature for user defined aggregate functions -has been updated to call `return_field` instead of `return_type` on the underlying -aggregate function. This is to support metadata handling with these aggregate functions. -This change should be transparent to most users. If you have written unit tests to call -`return_type` directly, you may need to change them to calling `return_field` instead. - -This update is a breaking change to the FFI API. The current best practice when using the -FFI crate is to ensure that all libraries that are interacting are using the same -underlying Rust version. Issue [#17374] has been opened to discuss stabilization of -this interface so that these libraries can be used across different DataFusion versions. - -See [#17407] for details. - -[#17407]: https://github.com/apache/datafusion/pull/17407 -[#17374]: https://github.com/apache/datafusion/issues/17374 - -### Added `PhysicalExpr::is_volatile_node` - -We added a method to `PhysicalExpr` to mark a `PhysicalExpr` as volatile: - -```rust,ignore -impl PhysicalExpr for MyRandomExpr { - fn is_volatile_node(&self) -> bool { - true - } -} -``` - -We've shipped this with a default value of `false` to minimize breakage but we highly recommend that implementers of `PhysicalExpr` opt into a behavior, even if it is returning `false`. - -You can see more discussion and example implementations in [#17351]. - -[#17351]: https://github.com/apache/datafusion/pull/17351 - -## DataFusion `49.0.0` - -### `MSRV` updated to 1.85.1 - -The Minimum Supported Rust Version (MSRV) has been updated to [`1.85.1`]. See -[#16728] for details. 
- -[`1.85.1`]: https://releases.rs/docs/1.85.1/ -[#16728]: https://github.com/apache/datafusion/pull/16728 - -### `DataFusionError` variants are now `Box`ed - -To reduce the size of `DataFusionError`, several variants that were previously stored inline are now `Box`ed. This reduces the size of `Result` and thus stack usage and async state machine size. Please see [#16652] for more details. - -The following variants of `DataFusionError` are now boxed: - -- `ArrowError` -- `SQL` -- `SchemaError` - -This is a breaking change. Code that constructs or matches on these variants will need to be updated. - -For example, to create a `SchemaError`, instead of: - -```rust -# /* comment to avoid running -use datafusion_common::{DataFusionError, SchemaError}; -DataFusionError::SchemaError( - SchemaError::DuplicateUnqualifiedField { name: "foo".to_string() }, - Box::new(None) -) -# */ -``` - -You now need to `Box` the inner error: - -```rust -# /* comment to avoid running -use datafusion_common::{DataFusionError, SchemaError}; -DataFusionError::SchemaError( - Box::new(SchemaError::DuplicateUnqualifiedField { name: "foo".to_string() }), - Box::new(None) -) -# */ -``` - -[#16652]: https://github.com/apache/datafusion/issues/16652 - -### Metadata on Arrow Types is now represented by `FieldMetadata` - -Metadata from the Arrow `Field` is now stored using the `FieldMetadata` -structure. In prior versions it was stored as both a `HashMap` -and a `BTreeMap`. `FieldMetadata` is easier to work with and -is more efficient. - -To create `FieldMetadata` from a `Field`: - -```rust -# /* comment to avoid running - let metadata = FieldMetadata::from(&field); -# */ -``` - -To add metadata to a `Field`, use the `add_to_field` method: - -```rust -# /* comment to avoid running -let updated_field = metadata.add_to_field(field); -# */ -``` - -See [#16317] for details.
- -[#16317]: https://github.com/apache/datafusion/pull/16317 - -### New `datafusion.execution.spill_compression` configuration option - -DataFusion 49.0.0 adds support for compressing spill files when data is written to disk during spilling query execution. A new configuration option `datafusion.execution.spill_compression` controls the compression codec used. - -**Configuration:** - -- **Key**: `datafusion.execution.spill_compression` -- **Default**: `uncompressed` -- **Valid values**: `uncompressed`, `lz4_frame`, `zstd` - -**Usage:** - -```rust -# /* comment to avoid running -use datafusion::prelude::*; -use datafusion_common::config::SpillCompression; - -let config = SessionConfig::default() - .with_spill_compression(SpillCompression::Zstd); -let ctx = SessionContext::new_with_config(config); -# */ -``` - -Or via SQL: - -```sql -SET datafusion.execution.spill_compression = 'zstd'; -``` - -For more details about this configuration option, including performance trade-offs between different compression codecs, see the [Configuration Settings](../user-guide/configs.md) documentation. - -### Deprecated `map_varchar_to_utf8view` configuration option - -See [issue #16290](https://github.com/apache/datafusion/pull/16290) for more information -The old configuration - -```text -datafusion.sql_parser.map_varchar_to_utf8view -``` - -is now **deprecated** in favor of the unified option below.\ -If you previously used this to control only `VARCHAR`→`Utf8View` mapping, please migrate to `map_string_types_to_utf8view`. - ---- - -### New `map_string_types_to_utf8view` configuration option - -To unify **all** SQL string types (`CHAR`, `VARCHAR`, `TEXT`, `STRING`) to Arrow’s zero‑copy `Utf8View`, DataFusion 49.0.0 introduces: - -- **Key**: `datafusion.sql_parser.map_string_types_to_utf8view` -- **Default**: `true` - -**Description:** - -- When **true** (default), **all** SQL string types are mapped to `Utf8View`, avoiding full‑copy UTF‑8 allocations and improving performance. 
-- When **false**, DataFusion falls back to the legacy `Utf8` mapping for **all** string types. - -#### Examples - -```rust -# /* comment to avoid running -// Disable Utf8View mapping for all SQL string types -let opts = datafusion::sql::planner::ParserOptions::new() - .with_map_string_types_to_utf8view(false); - -// Verify the setting is applied -assert!(!opts.map_string_types_to_utf8view); -# */ -``` - ---- - -```sql --- Disable Utf8View mapping globally -SET datafusion.sql_parser.map_string_types_to_utf8view = false; - --- Now VARCHAR, CHAR, TEXT, STRING all use Utf8 rather than Utf8View -CREATE TABLE my_table (a VARCHAR, b TEXT, c STRING); -DESCRIBE my_table; -``` - -### Deprecating `SchemaAdapterFactory` and `SchemaAdapter` - -We are moving away from converting data (using `SchemaAdapter`) to converting the expressions themselves (which is more efficient and flexible). - -See [issue #16800](https://github.com/apache/datafusion/issues/16800) for more information -The first place this change has taken place is in predicate pushdown for Parquet. -By default if you do not use a custom `SchemaAdapterFactory` we will use expression conversion instead. -If you do set a custom `SchemaAdapterFactory` we will continue to use it but emit a warning about that code path being deprecated. - -To resolve this you need to implement a custom `PhysicalExprAdapterFactory` and use that instead of a `SchemaAdapterFactory`. -See the [default values](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/custom_data_source/default_column_values.rs) for an example of how to do this. -Opting into the new APIs will set you up for future changes since we plan to expand use of `PhysicalExprAdapterFactory` to other areas of DataFusion. - -See [#16800] for details. 
- -[#16800]: https://github.com/apache/datafusion/issues/16800 - -### `TableParquetOptions` Updated - -The `TableParquetOptions` struct has a new `crypto` field to specify encryption -options for Parquet files. The `ParquetEncryptionOptions` implements `Default` -so you can upgrade your existing code like this: - -```rust -# /* comment to avoid running -TableParquetOptions { - global, - column_specific_options, - key_value_metadata, -} -# */ -``` - -To this: - -```rust -# /* comment to avoid running -TableParquetOptions { - global, - column_specific_options, - key_value_metadata, - crypto: Default::default(), // New crypto field -} -# */ -``` - -## DataFusion `48.0.1` - -### `datafusion.execution.collect_statistics` now defaults to `true` - -The default value of the `datafusion.execution.collect_statistics` configuration -setting is now true. This change impacts users that use that value directly and relied -on its default value being `false`. - -This change also restores the default behavior of `ListingTable` to its previous. If you use it directly -you can maintain the current behavior by overriding the default value in your code. - -```rust -# /* comment to avoid running -ListingOptions::new(Arc::new(ParquetFormat::default())) - .with_collect_stat(false) - // other options -# */ -``` - -## DataFusion `48.0.0` - -### `Expr::Literal` has optional metadata - -The [`Expr::Literal`] variant now includes optional metadata, which allows for -carrying through Arrow field metadata to support extension types and other uses. - -This means code such as - -```rust -# /* comment to avoid running -match expr { -... - Expr::Literal(scalar) => ... -... -} -# */ -``` - -Should be updated to: - -```rust -# /* comment to avoid running -match expr { -... - Expr::Literal(scalar, _metadata) => ... -... -} -# */ -``` - -Likewise constructing `Expr::Literal` requires metadata as well. The [`lit`] function -has not changed and returns an `Expr::Literal` with no metadata. 
- -[`expr::literal`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.Expr.html#variant.Literal -[`lit`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.lit.html - -### `Expr::WindowFunction` is now `Box`ed - -`Expr::WindowFunction` is now a `Box` instead of a `WindowFunction` directly. -This change was made to reduce the size of `Expr` and improve performance when -planning queries (see [details on #16207]). - -This is a breaking change, so you will need to update your code if you match -on `Expr::WindowFunction` directly. For example, if you have code like this: - -```rust -# /* comment to avoid running -match expr { - Expr::WindowFunction(WindowFunction { - params: - WindowFunctionParams { - partition_by, - order_by, - .. - } - }) => { - // Use partition_by and order_by as needed - } - _ => { - // other expr - } -} -# */ -``` - -You will need to change it to: - -```rust -# /* comment to avoid running -match expr { - Expr::WindowFunction(window_fun) => { - let WindowFunction { - fun, - params: WindowFunctionParams { - args, - partition_by, - .. - }, - } = window_fun.as_ref(); - // Use partition_by and order_by as needed - } - _ => { - // other expr - } -} -# */ -``` - -[details on #16207]: https://github.com/apache/datafusion/pull/16207#issuecomment-2922659103 - -### The `VARCHAR` SQL type is now represented as `Utf8View` in Arrow - -The mapping of the SQL `VARCHAR` type has been changed from `Utf8` to `Utf8View` -which improves performance for many string operations. You can read more about -`Utf8View` in the [DataFusion blog post on German-style strings] - -[datafusion blog post on german-style strings]: https://datafusion.apache.org/blog/2024/09/13/string-view-german-style-strings-part-1/ - -This means that when you create a table with a `VARCHAR` column, it will now use -`Utf8View` as the underlying data type. For example: - -```sql -> CREATE TABLE my_table (my_column VARCHAR); -0 row(s) fetched. -Elapsed 0.001 seconds. 
- -> DESCRIBE my_table; -+-------------+-----------+-------------+ -| column_name | data_type | is_nullable | -+-------------+-----------+-------------+ -| my_column | Utf8View | YES | -+-------------+-----------+-------------+ -1 row(s) fetched. -Elapsed 0.000 seconds. -``` - -You can restore the old behavior of using `Utf8` by changing the -`datafusion.sql_parser.map_varchar_to_utf8view` configuration setting. For -example - -```sql -> set datafusion.sql_parser.map_varchar_to_utf8view = false; -0 row(s) fetched. -Elapsed 0.001 seconds. - -> CREATE TABLE my_table (my_column VARCHAR); -0 row(s) fetched. -Elapsed 0.014 seconds. - -> DESCRIBE my_table; -+-------------+-----------+-------------+ -| column_name | data_type | is_nullable | -+-------------+-----------+-------------+ -| my_column | Utf8 | YES | -+-------------+-----------+-------------+ -1 row(s) fetched. -Elapsed 0.004 seconds. -``` - -### `ListingOptions` default for `collect_stat` changed from `true` to `false` - -This makes it agree with the default for `SessionConfig`. -Most users won't be impacted by this change but if you were using `ListingOptions` directly -and relied on the default value of `collect_stat` being `true`, you will need to -explicitly set it to `true` in your code. - -```rust -# /* comment to avoid running -ListingOptions::new(Arc::new(ParquetFormat::default())) - .with_collect_stat(true) - // other options -# */ -``` - -### Processing `FieldRef` instead of `DataType` for user defined functions - -In order to support metadata handling and extension types, user defined functions are -now switching to traits which use `FieldRef` rather than a `DataType` and nullability. -This gives a single interface to both of these parameters and additionally allows -access to metadata fields, which can be used for extension types. - -To upgrade structs which implement `ScalarUDFImpl`, if you have implemented -`return_type_from_args` you need instead to implement `return_field_from_args`. 
-If your functions do not need to handle metadata, this should be straightforward -repackaging of the output data into a `FieldRef`. The name you specify on the -field is not important. It will be overwritten during planning. `ReturnInfo` -has been removed, so you will need to remove all references to it. - -`ScalarFunctionArgs` now contains a field called `arg_fields`. You can use this -to access the metadata associated with the columnar values during invocation. - -To upgrade user defined aggregate functions, there is now a function -`return_field` that will allow you to specify both metadata and nullability of -your function. You are not required to implement this if you do not need to -handle metadata. - -The largest change to aggregate functions happens in the accumulator arguments. -Both the `AccumulatorArgs` and `StateFieldsArgs` now contain `FieldRef` rather -than `DataType`. - -To upgrade window functions, `ExpressionArgs` now contains input fields instead -of input data types. When setting these fields, the name of the field is -not important since this gets overwritten during the planning stage. All you -should need to do is wrap your existing data types in fields with nullability -set depending on your use case. - -### Physical Expression return `Field` - -To support the changes to user defined functions processing metadata, the -`PhysicalExpr` trait now must specify a return `Field` based on the input -schema. To upgrade structs which implement `PhysicalExpr` you need to implement -the `return_field` function. There are numerous examples in the `physical-expr` -crate. - -### `FileFormat::supports_filters_pushdown` replaced with `FileSource::try_pushdown_filters` - -To support more general filter pushdown, the `FileFormat::supports_filters_pushdown` was replaced with -`FileSource::try_pushdown_filters`. -If you implemented a custom `FileFormat` that uses a custom `FileSource` you will need to implement -`FileSource::try_pushdown_filters`.
-See `ParquetSource::try_pushdown_filters` for an example of how to implement this. - -`FileFormat::supports_filters_pushdown` has been removed. - -### `ParquetExec`, `AvroExec`, `CsvExec`, `JsonExec` Removed - -`ParquetExec`, `AvroExec`, `CsvExec`, and `JsonExec` were deprecated in -DataFusion 46 and are removed in DataFusion 48. This is sooner than the normal -process described in the [API Deprecation Guidelines] because all the tests -cover the new `DataSourceExec` rather than the older structures. As we evolve -`DataSource`, the old structures began to show signs of "bit rotting" (not -working but no one knows due to lack of test coverage). - -[api deprecation guidelines]: https://datafusion.apache.org/contributor-guide/api-health.html#deprecation-guidelines - -### `PartitionedFile` added as an argument to the `FileOpener` trait - -This is necessary to properly fix filter pushdown for filters that combine partition -columns and file columns (e.g. `day = username['dob']`). - -If you implemented a custom `FileOpener` you will need to add the `PartitionedFile` argument -but are not required to use it in any way. - -## DataFusion `47.0.0` - -This section calls out some of the major changes in the `47.0.0` release of DataFusion. 
- -Here are some example upgrade PRs that demonstrate changes required when upgrading from DataFusion 46.0.0: - -- [delta-rs Upgrade to `47.0.0`](https://github.com/delta-io/delta-rs/pull/3378) -- [DataFusion Comet Upgrade to `47.0.0`](https://github.com/apache/datafusion-comet/pull/1563) -- [Sail Upgrade to `47.0.0`](https://github.com/lakehq/sail/pull/434) - -### Upgrades to `arrow-rs` and `arrow-parquet` 55.0.0 and `object_store` 0.12.0 - -Several APIs are changed in the underlying arrow and parquet libraries to use a -`u64` instead of `usize` to better support WASM (See [#7371] and [#6961]) - -Additionally `ObjectStore::list` and `ObjectStore::list_with_offset` have been changed to return `static` lifetimes (See [#6619]) - -[#6619]: https://github.com/apache/arrow-rs/pull/6619 -[#7371]: https://github.com/apache/arrow-rs/pull/7371 - -This requires converting from `usize` to `u64` occasionally as well as changes to `ObjectStore` implementations such as - -```rust -# /* comment to avoid running -impl Objectstore { - ... - // The range is now a u64 instead of usize - async fn get_range(&self, location: &Path, range: Range) -> ObjectStoreResult { - self.inner.get_range(location, range).await - } - ... - // the lifetime is now 'static instead of `_ (meaning the captured closure can't contain references) - // (this also applies to list_with_offset) - fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, ObjectStoreResult> { - self.inner.list(prefix) - } -} -# */ -``` - -The `ParquetObjectReader` has been updated to no longer require the object size -(it can be fetched using a single suffix request). 
See [#7334] for details - -[#7334]: https://github.com/apache/arrow-rs/pull/7334 - -Pattern in DataFusion `46.0.0`: - -```rust -# /* comment to avoid running -let meta: ObjectMeta = ...; -let reader = ParquetObjectReader::new(store, meta); -# */ -``` - -Pattern in DataFusion `47.0.0`: - -```rust -# /* comment to avoid running -let meta: ObjectMeta = ...; -let reader = ParquetObjectReader::new(store, location) - .with_file_size(meta.size); -# */ -``` - -### `DisplayFormatType::TreeRender` - -DataFusion now supports [`tree` style explain plans]. Implementations of -`ExecutionPlan` must also provide a description in the -`DisplayFormatType::TreeRender` format. This can be the same as the existing -`DisplayFormatType::Default`. - -[`tree` style explain plans]: https://datafusion.apache.org/user-guide/sql/explain.html#tree-format-default - -### Removed Deprecated APIs - -Several APIs have been removed in this release. These were either deprecated -previously or were hard to use correctly such as the multiple different -`ScalarUDFImpl::invoke*` APIs. See [#15130], [#15123], and [#15027] for more -details. - -[#15130]: https://github.com/apache/datafusion/pull/15130 -[#15123]: https://github.com/apache/datafusion/pull/15123 -[#15027]: https://github.com/apache/datafusion/pull/15027 - -### `FileScanConfig` --> `FileScanConfigBuilder` - -Previously, `FileScanConfig::build()` directly created ExecutionPlans. In -DataFusion 47.0.0 this has been changed to use `FileScanConfigBuilder`. See -[#15352] for details. - -[#15352]: https://github.com/apache/datafusion/pull/15352 - -Pattern in DataFusion `46.0.0`: - -```rust -# /* comment to avoid running -let plan = FileScanConfig::new(url, schema, Arc::new(file_source)) - .with_statistics(stats) - ... - .build() -# */ -``` - -Pattern in DataFusion `47.0.0`: - -```rust -# /* comment to avoid running -let config = FileScanConfigBuilder::new(url, Arc::new(file_source)) - .with_statistics(stats) - ...
- .build(); -let scan = DataSourceExec::from_data_source(config); -# */ -``` - -## DataFusion `46.0.0` - -### Use `invoke_with_args` instead of `invoke()` and `invoke_batch()` - -DataFusion is moving to a consistent API for invoking ScalarUDFs, -[`ScalarUDFImpl::invoke_with_args()`], and deprecating -[`ScalarUDFImpl::invoke()`], [`ScalarUDFImpl::invoke_batch()`], and [`ScalarUDFImpl::invoke_no_args()`] - -If you see errors such as the following it means the older APIs are being used: - -```text -This feature is not implemented: Function concat does not implement invoke but called -``` - -To fix this error, use [`ScalarUDFImpl::invoke_with_args()`] instead, as shown -below. See [PR 14876] for an example. - -Given existing code like this: - -```rust -# /* comment to avoid running -impl ScalarUDFImpl for SparkConcat { -... - fn invoke_batch(&self, args: &[ColumnarValue], number_rows: usize) -> Result { - if args - .iter() - .any(|arg| matches!(arg.data_type(), DataType::List(_))) - { - ArrayConcat::new().invoke_batch(args, number_rows) - } else { - ConcatFunc::new().invoke_batch(args, number_rows) - } - } -} -# */ -``` - -To - -```rust -# /* comment to avoid running -impl ScalarUDFImpl for SparkConcat { - ... 
- fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - if args - .args - .iter() - .any(|arg| matches!(arg.data_type(), DataType::List(_))) - { - ArrayConcat::new().invoke_with_args(args) - } else { - ConcatFunc::new().invoke_with_args(args) - } - } -} - # */ -``` - -[`scalarudfimpl::invoke()`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/trait.ScalarUDFImpl.html#method.invoke -[`scalarudfimpl::invoke_batch()`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/trait.ScalarUDFImpl.html#method.invoke_batch -[`scalarudfimpl::invoke_no_args()`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/trait.ScalarUDFImpl.html#method.invoke_no_args -[`scalarudfimpl::invoke_with_args()`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/trait.ScalarUDFImpl.html#method.invoke_with_args -[pr 14876]: https://github.com/apache/datafusion/pull/14876 - -### `ParquetExec`, `AvroExec`, `CsvExec`, `JsonExec` deprecated - -DataFusion 46 has a major change to how the built in DataSources are organized. -Instead of individual `ExecutionPlan`s for the different file formats they now -all use `DataSourceExec` and the format specific information is embodied in new -traits `DataSource` and `FileSource`. - -Here is more information about - -- [Design Ticket] -- Change PR [PR #14224] -- Example of an Upgrade [PR in delta-rs] - -[design ticket]: https://github.com/apache/datafusion/issues/13838 -[pr #14224]: https://github.com/apache/datafusion/pull/14224 -[pr in delta-rs]: https://github.com/delta-io/delta-rs/pull/3261 - -### Cookbook: Changes to `ParquetExecBuilder` - -Code that looks for `ParquetExec` like this will no longer work: - -```rust -# /* comment to avoid running - if let Some(parquet_exec) = plan.as_any().downcast_ref::() { - // Do something with ParquetExec here - } -# */ -``` - -Instead, with `DataSourceExec`, the same information is now on `FileScanConfig` and -`ParquetSource`. 
The equivalent code is - -```rust -# /* comment to avoid running -if let Some(datasource_exec) = plan.as_any().downcast_ref::() { - if let Some(scan_config) = datasource_exec.data_source().as_any().downcast_ref::() { - // FileGroups, and other information is on the FileScanConfig - // parquet - if let Some(parquet_source) = scan_config.file_source.as_any().downcast_ref::() - { - // Information on PruningPredicates and parquet options are here - } -} -# */ -``` - -### Cookbook: Changes to `ParquetExecBuilder` - -Likewise code that builds `ParquetExec` using the `ParquetExecBuilder` such as -the following must be changed: - -```rust -# /* comment to avoid running -let mut exec_plan_builder = ParquetExecBuilder::new( - FileScanConfig::new(self.log_store.object_store_url(), file_schema) - .with_projection(self.projection.cloned()) - .with_limit(self.limit) - .with_table_partition_cols(table_partition_cols), -) -.with_schema_adapter_factory(Arc::new(DeltaSchemaAdapterFactory {})) -.with_table_parquet_options(parquet_options); - -// Add filter -if let Some(predicate) = logical_filter { - if config.enable_parquet_pushdown { - exec_plan_builder = exec_plan_builder.with_predicate(predicate); - } -}; -# */ -``` - -New code should use `FileScanConfig` to build the appropriate `DataSourceExec`: - -```rust -# /* comment to avoid running -let mut file_source = ParquetSource::new(parquet_options) - .with_schema_adapter_factory(Arc::new(DeltaSchemaAdapterFactory {})); - -// Add filter -if let Some(predicate) = logical_filter { - if config.enable_parquet_pushdown { - file_source = file_source.with_predicate(predicate); - } -}; - -let file_scan_config = FileScanConfig::new( - self.log_store.object_store_url(), - file_schema, - Arc::new(file_source), -) -.with_statistics(stats) -.with_projection(self.projection.cloned()) -.with_limit(self.limit) -.with_table_partition_cols(table_partition_cols); - -// Build the actual scan like this -parquet_scan: file_scan_config.build(), -# */ -``` 
- -### `datafusion-cli` no longer automatically unescapes strings - -`datafusion-cli` previously would incorrectly unescape string literals (see [ticket] for more details). - -To escape `'` in SQL literals, use `''`: - -```sql -> select 'it''s escaped'; -+----------------------+ -| Utf8("it's escaped") | -+----------------------+ -| it's escaped | -+----------------------+ -1 row(s) fetched. -``` - -To include special characters (such as newlines via `\n`) you can use an `E` literal string. For example - -```sql -> select 'foo\nbar'; -+------------------+ -| Utf8("foo\nbar") | -+------------------+ -| foo\nbar | -+------------------+ -1 row(s) fetched. -Elapsed 0.005 seconds. -``` - -### Changes to array scalar function signatures - -DataFusion 46 has changed the way scalar array function signatures are -declared. Previously, functions needed to select from a list of predefined -signatures within the `ArrayFunctionSignature` enum. Now the signatures -can be defined via a `Vec` of pseudo-types, which each correspond to a -single argument. Those pseudo-types are the variants of the -`ArrayFunctionArgument` enum and are as follows: - -- `Array`: An argument of type List/LargeList/FixedSizeList. All Array - arguments must be coercible to the same type. -- `Element`: An argument that is coercible to the inner type of the `Array` - arguments. -- `Index`: An `Int64` argument. 
- -Each of the old variants can be converted to the new format as follows: - -`TypeSignature::ArraySignature(ArrayFunctionSignature::ArrayAndElement)`: - -```rust -# use datafusion::common::utils::ListCoercion; -# use datafusion_expr_common::signature::{ArrayFunctionArgument, ArrayFunctionSignature, TypeSignature}; - -TypeSignature::ArraySignature(ArrayFunctionSignature::Array { - arguments: vec![ArrayFunctionArgument::Array, ArrayFunctionArgument::Element], - array_coercion: Some(ListCoercion::FixedSizedListToList), -}); -``` - -`TypeSignature::ArraySignature(ArrayFunctionSignature::ElementAndArray)`: - -```rust -# use datafusion::common::utils::ListCoercion; -# use datafusion_expr_common::signature::{ArrayFunctionArgument, ArrayFunctionSignature, TypeSignature}; - -TypeSignature::ArraySignature(ArrayFunctionSignature::Array { - arguments: vec![ArrayFunctionArgument::Element, ArrayFunctionArgument::Array], - array_coercion: Some(ListCoercion::FixedSizedListToList), -}); -``` - -`TypeSignature::ArraySignature(ArrayFunctionSignature::ArrayAndIndex)`: - -```rust -# use datafusion::common::utils::ListCoercion; -# use datafusion_expr_common::signature::{ArrayFunctionArgument, ArrayFunctionSignature, TypeSignature}; - -TypeSignature::ArraySignature(ArrayFunctionSignature::Array { - arguments: vec![ArrayFunctionArgument::Array, ArrayFunctionArgument::Index], - array_coercion: None, -}); -``` - -`TypeSignature::ArraySignature(ArrayFunctionSignature::ArrayAndElementAndOptionalIndex)`: - -```rust -# use datafusion::common::utils::ListCoercion; -# use datafusion_expr_common::signature::{ArrayFunctionArgument, ArrayFunctionSignature, TypeSignature}; - -TypeSignature::OneOf(vec![ - TypeSignature::ArraySignature(ArrayFunctionSignature::Array { - arguments: vec![ArrayFunctionArgument::Array, ArrayFunctionArgument::Element], - array_coercion: None, - }), - TypeSignature::ArraySignature(ArrayFunctionSignature::Array { - arguments: vec![ - ArrayFunctionArgument::Array, - 
ArrayFunctionArgument::Element, - ArrayFunctionArgument::Index, - ], - array_coercion: None, - }), -]); -``` - -`TypeSignature::ArraySignature(ArrayFunctionSignature::Array)`: - -```rust -# use datafusion::common::utils::ListCoercion; -# use datafusion_expr_common::signature::{ArrayFunctionArgument, ArrayFunctionSignature, TypeSignature}; - -TypeSignature::ArraySignature(ArrayFunctionSignature::Array { - arguments: vec![ArrayFunctionArgument::Array], - array_coercion: None, -}); -``` - -Alternatively, you can switch to using one of the following functions which -take care of constructing the `TypeSignature` for you: - -- `Signature::array_and_element` -- `Signature::array_and_element_and_optional_index` -- `Signature::array_and_index` -- `Signature::array` - -[ticket]: https://github.com/apache/datafusion/issues/13286 diff --git a/docs/source/library-user-guide/upgrading/46.0.0.md b/docs/source/library-user-guide/upgrading/46.0.0.md new file mode 100644 index 000000000000..e38d18c3d660 --- /dev/null +++ b/docs/source/library-user-guide/upgrading/46.0.0.md @@ -0,0 +1,310 @@ + + +# Upgrade Guides + +## DataFusion 46.0.0 + +### Use `invoke_with_args` instead of `invoke()` and `invoke_batch()` + +DataFusion is moving to a consistent API for invoking ScalarUDFs, +[`ScalarUDFImpl::invoke_with_args()`], and deprecating +[`ScalarUDFImpl::invoke()`], [`ScalarUDFImpl::invoke_batch()`], and [`ScalarUDFImpl::invoke_no_args()`] + +If you see errors such as the following it means the older APIs are being used: + +```text +This feature is not implemented: Function concat does not implement invoke but called +``` + +To fix this error, use [`ScalarUDFImpl::invoke_with_args()`] instead, as shown +below. See [PR 14876] for an example. + +Given existing code like this: + +```rust +# /* comment to avoid running +impl ScalarUDFImpl for SparkConcat { +... 
+ fn invoke_batch(&self, args: &[ColumnarValue], number_rows: usize) -> Result { + if args + .iter() + .any(|arg| matches!(arg.data_type(), DataType::List(_))) + { + ArrayConcat::new().invoke_batch(args, number_rows) + } else { + ConcatFunc::new().invoke_batch(args, number_rows) + } + } +} +# */ +``` + +To + +```rust +# /* comment to avoid running +impl ScalarUDFImpl for SparkConcat { + ... + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + if args + .args + .iter() + .any(|arg| matches!(arg.data_type(), DataType::List(_))) + { + ArrayConcat::new().invoke_with_args(args) + } else { + ConcatFunc::new().invoke_with_args(args) + } + } +} + # */ +``` + +[`scalarudfimpl::invoke()`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/trait.ScalarUDFImpl.html#method.invoke +[`scalarudfimpl::invoke_batch()`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/trait.ScalarUDFImpl.html#method.invoke_batch +[`scalarudfimpl::invoke_no_args()`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/trait.ScalarUDFImpl.html#method.invoke_no_args +[`scalarudfimpl::invoke_with_args()`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/trait.ScalarUDFImpl.html#method.invoke_with_args +[pr 14876]: https://github.com/apache/datafusion/pull/14876 + +### `ParquetExec`, `AvroExec`, `CsvExec`, `JsonExec` deprecated + +DataFusion 46 has a major change to how the built in DataSources are organized. +Instead of individual `ExecutionPlan`s for the different file formats they now +all use `DataSourceExec` and the format specific information is embodied in new +traits `DataSource` and `FileSource`. 
+ +Here is more information about + +- [Design Ticket] +- Change PR [PR #14224] +- Example of an Upgrade [PR in delta-rs] + +[design ticket]: https://github.com/apache/datafusion/issues/13838 +[pr #14224]: https://github.com/apache/datafusion/pull/14224 +[pr in delta-rs]: https://github.com/delta-io/delta-rs/pull/3261 + +### Cookbook: Changes to `ParquetExecBuilder` + +Code that looks for `ParquetExec` like this will no longer work: + +```rust +# /* comment to avoid running + if let Some(parquet_exec) = plan.as_any().downcast_ref::() { + // Do something with ParquetExec here + } +# */ +``` + +Instead, with `DataSourceExec`, the same information is now on `FileScanConfig` and +`ParquetSource`. The equivalent code is + +```rust +# /* comment to avoid running +if let Some(datasource_exec) = plan.as_any().downcast_ref::() { + if let Some(scan_config) = datasource_exec.data_source().as_any().downcast_ref::() { + // FileGroups, and other information is on the FileScanConfig + // parquet + if let Some(parquet_source) = scan_config.file_source.as_any().downcast_ref::() + { + // Information on PruningPredicates and parquet options are here + } +} +# */ +``` + +### Cookbook: Changes to `ParquetExecBuilder` + +Likewise code that builds `ParquetExec` using the `ParquetExecBuilder` such as +the following must be changed: + +```rust +# /* comment to avoid running +let mut exec_plan_builder = ParquetExecBuilder::new( + FileScanConfig::new(self.log_store.object_store_url(), file_schema) + .with_projection(self.projection.cloned()) + .with_limit(self.limit) + .with_table_partition_cols(table_partition_cols), +) +.with_schema_adapter_factory(Arc::new(DeltaSchemaAdapterFactory {})) +.with_table_parquet_options(parquet_options); + +// Add filter +if let Some(predicate) = logical_filter { + if config.enable_parquet_pushdown { + exec_plan_builder = exec_plan_builder.with_predicate(predicate); + } +}; +# */ +``` + +New code should use `FileScanConfig` to build the appropriate 
`DataSourceExec`: + +```rust +# /* comment to avoid running +let mut file_source = ParquetSource::new(parquet_options) + .with_schema_adapter_factory(Arc::new(DeltaSchemaAdapterFactory {})); + +// Add filter +if let Some(predicate) = logical_filter { + if config.enable_parquet_pushdown { + file_source = file_source.with_predicate(predicate); + } +}; + +let file_scan_config = FileScanConfig::new( + self.log_store.object_store_url(), + file_schema, + Arc::new(file_source), +) +.with_statistics(stats) +.with_projection(self.projection.cloned()) +.with_limit(self.limit) +.with_table_partition_cols(table_partition_cols); + +// Build the actual scan like this +parquet_scan: file_scan_config.build(), +# */ +``` + +### `datafusion-cli` no longer automatically unescapes strings + +`datafusion-cli` previously would incorrectly unescape string literals (see [ticket] for more details). + +To escape `'` in SQL literals, use `''`: + +```sql +> select 'it''s escaped'; ++----------------------+ +| Utf8("it's escaped") | ++----------------------+ +| it's escaped | ++----------------------+ +1 row(s) fetched. +``` + +To include special characters (such as newlines via `\n`) you can use an `E` literal string. For example + +```sql +> select 'foo\nbar'; ++------------------+ +| Utf8("foo\nbar") | ++------------------+ +| foo\nbar | ++------------------+ +1 row(s) fetched. +Elapsed 0.005 seconds. +``` + +### Changes to array scalar function signatures + +DataFusion 46 has changed the way scalar array function signatures are +declared. Previously, functions needed to select from a list of predefined +signatures within the `ArrayFunctionSignature` enum. Now the signatures +can be defined via a `Vec` of pseudo-types, which each correspond to a +single argument. Those pseudo-types are the variants of the +`ArrayFunctionArgument` enum and are as follows: + +- `Array`: An argument of type List/LargeList/FixedSizeList. All Array + arguments must be coercible to the same type. 
+- `Element`: An argument that is coercible to the inner type of the `Array` + arguments. +- `Index`: An `Int64` argument. + +Each of the old variants can be converted to the new format as follows: + +`TypeSignature::ArraySignature(ArrayFunctionSignature::ArrayAndElement)`: + +```rust +# use datafusion::common::utils::ListCoercion; +# use datafusion_expr_common::signature::{ArrayFunctionArgument, ArrayFunctionSignature, TypeSignature}; + +TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ArrayFunctionArgument::Array, ArrayFunctionArgument::Element], + array_coercion: Some(ListCoercion::FixedSizedListToList), +}); +``` + +`TypeSignature::ArraySignature(ArrayFunctionSignature::ElementAndArray)`: + +```rust +# use datafusion::common::utils::ListCoercion; +# use datafusion_expr_common::signature::{ArrayFunctionArgument, ArrayFunctionSignature, TypeSignature}; + +TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ArrayFunctionArgument::Element, ArrayFunctionArgument::Array], + array_coercion: Some(ListCoercion::FixedSizedListToList), +}); +``` + +`TypeSignature::ArraySignature(ArrayFunctionSignature::ArrayAndIndex)`: + +```rust +# use datafusion::common::utils::ListCoercion; +# use datafusion_expr_common::signature::{ArrayFunctionArgument, ArrayFunctionSignature, TypeSignature}; + +TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ArrayFunctionArgument::Array, ArrayFunctionArgument::Index], + array_coercion: None, +}); +``` + +`TypeSignature::ArraySignature(ArrayFunctionSignature::ArrayAndElementAndOptionalIndex)`: + +```rust +# use datafusion::common::utils::ListCoercion; +# use datafusion_expr_common::signature::{ArrayFunctionArgument, ArrayFunctionSignature, TypeSignature}; + +TypeSignature::OneOf(vec![ + TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ArrayFunctionArgument::Array, ArrayFunctionArgument::Element], + array_coercion: None, + }), + 
TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Element, + ArrayFunctionArgument::Index, + ], + array_coercion: None, + }), +]); +``` + +`TypeSignature::ArraySignature(ArrayFunctionSignature::Array)`: + +```rust +# use datafusion::common::utils::ListCoercion; +# use datafusion_expr_common::signature::{ArrayFunctionArgument, ArrayFunctionSignature, TypeSignature}; + +TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ArrayFunctionArgument::Array], + array_coercion: None, +}); +``` + +Alternatively, you can switch to using one of the following functions which +take care of constructing the `TypeSignature` for you: + +- `Signature::array_and_element` +- `Signature::array_and_element_and_optional_index` +- `Signature::array_and_index` +- `Signature::array` + +[ticket]: https://github.com/apache/datafusion/issues/13286 diff --git a/docs/source/library-user-guide/upgrading/47.0.0.md b/docs/source/library-user-guide/upgrading/47.0.0.md new file mode 100644 index 000000000000..354b6740df02 --- /dev/null +++ b/docs/source/library-user-guide/upgrading/47.0.0.md @@ -0,0 +1,135 @@ + + +# Upgrade Guides + +## DataFusion 47.0.0 + +This section calls out some of the major changes in the `47.0.0` release of DataFusion. 
+ +Here are some example upgrade PRs that demonstrate changes required when upgrading from DataFusion 46.0.0: + +- [delta-rs Upgrade to `47.0.0`](https://github.com/delta-io/delta-rs/pull/3378) +- [DataFusion Comet Upgrade to `47.0.0`](https://github.com/apache/datafusion-comet/pull/1563) +- [Sail Upgrade to `47.0.0`](https://github.com/lakehq/sail/pull/434) + +### Upgrades to `arrow-rs` and `arrow-parquet` 55.0.0 and `object_store` 0.12.0 + +Several APIs are changed in the underlying arrow and parquet libraries to use a +`u64` instead of `usize` to better support WASM (See [#7371] and [#6961]) + +Additionally `ObjectStore::list` and `ObjectStore::list_with_offset` have been changed to return `static` lifetimes (See [#6619]) + +[#6619]: https://github.com/apache/arrow-rs/pull/6619 +[#7371]: https://github.com/apache/arrow-rs/pull/7371 + +This requires converting from `usize` to `u64` occasionally as well as changes to `ObjectStore` implementations such as + +```rust +# /* comment to avoid running +impl Objectstore { + ... + // The range is now a u64 instead of usize + async fn get_range(&self, location: &Path, range: Range) -> ObjectStoreResult { + self.inner.get_range(location, range).await + } + ... + // the lifetime is now 'static instead of `_ (meaning the captured closure can't contain references) + // (this also applies to list_with_offset) + fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, ObjectStoreResult> { + self.inner.list(prefix) + } +} +# */ +``` + +The `ParquetObjectReader` has been updated to no longer require the object size +(it can be fetched using a single suffix request). 
See [#7334] for details + +[#7334]: https://github.com/apache/arrow-rs/pull/7334 + +Pattern in DataFusion `46.0.0`: + +```rust +# /* comment to avoid running +let meta: ObjectMeta = ...; +let reader = ParquetObjectReader::new(store, meta); +# */ +``` + +Pattern in DataFusion `47.0.0`: + +```rust +# /* comment to avoid running +let meta: ObjectMeta = ...; +let reader = ParquetObjectReader::new(store, location) + .with_file_size(meta.size); +# */ +``` + +### `DisplayFormatType::TreeRender` + +DataFusion now supports [`tree` style explain plans]. Implementations of +`ExecutionPlan` must also provide a description in the +`DisplayFormatType::TreeRender` format. This can be the same as the existing +`DisplayFormatType::Default`. + +[`tree` style explain plans]: https://datafusion.apache.org/user-guide/sql/explain.html#tree-format-default + +### Removed Deprecated APIs + +Several APIs have been removed in this release. These were either deprecated +previously or were hard to use correctly such as the multiple different +`ScalarUDFImpl::invoke*` APIs. See [#15130], [#15123], and [#15027] for more +details. + +[#15130]: https://github.com/apache/datafusion/pull/15130 +[#15123]: https://github.com/apache/datafusion/pull/15123 +[#15027]: https://github.com/apache/datafusion/pull/15027 + +### `FileScanConfig` --> `FileScanConfigBuilder` + +Previously, `FileScanConfig::build()` directly created ExecutionPlans. In +DataFusion 47.0.0 this has been changed to use `FileScanConfigBuilder`. See +[#15352] for details. + +[#15352]: https://github.com/apache/datafusion/pull/15352 + +Pattern in DataFusion `46.0.0`: + +```rust +# /* comment to avoid running +let plan = FileScanConfig::new(url, schema, Arc::new(file_source)) + .with_statistics(stats) + ... + .build() +# */ +``` + +Pattern in DataFusion `47.0.0`: + +```rust +# /* comment to avoid running +let config = FileScanConfigBuilder::new(url, Arc::new(file_source)) + .with_statistics(stats) + ... 
+ .build(); +let scan = DataSourceExec::from_data_source(config); +# */ +``` diff --git a/docs/source/library-user-guide/upgrading/48.0.0.md b/docs/source/library-user-guide/upgrading/48.0.0.md new file mode 100644 index 000000000000..7872a6f54f24 --- /dev/null +++ b/docs/source/library-user-guide/upgrading/48.0.0.md @@ -0,0 +1,244 @@ + + +# Upgrade Guides + +## DataFusion 48.0.0 + +### `Expr::Literal` has optional metadata + +The [`Expr::Literal`] variant now includes optional metadata, which allows for +carrying through Arrow field metadata to support extension types and other uses. + +This means code such as + +```rust +# /* comment to avoid running +match expr { +... + Expr::Literal(scalar) => ... +... +} +# */ +``` + +Should be updated to: + +```rust +# /* comment to avoid running +match expr { +... + Expr::Literal(scalar, _metadata) => ... +... +} +# */ +``` + +Likewise constructing `Expr::Literal` requires metadata as well. The [`lit`] function +has not changed and returns an `Expr::Literal` with no metadata. + +[`expr::literal`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.Expr.html#variant.Literal +[`lit`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.lit.html + +### `Expr::WindowFunction` is now `Box`ed + +`Expr::WindowFunction` is now a `Box` instead of a `WindowFunction` directly. +This change was made to reduce the size of `Expr` and improve performance when +planning queries (see [details on #16207]). + +This is a breaking change, so you will need to update your code if you match +on `Expr::WindowFunction` directly. For example, if you have code like this: + +```rust +# /* comment to avoid running +match expr { + Expr::WindowFunction(WindowFunction { + params: + WindowFunctionParams { + partition_by, + order_by, + .. 
+ } + }) => { + // Use partition_by and order_by as needed + } + _ => { + // other expr + } +} +# */ +``` + +You will need to change it to: + +```rust +# /* comment to avoid running +match expr { + Expr::WindowFunction(window_fun) => { + let WindowFunction { + fun, + params: WindowFunctionParams { + args, + partition_by, + .. + }, + } = window_fun.as_ref(); + // Use partition_by and order_by as needed + } + _ => { + // other expr + } +} +# */ +``` + +[details on #16207]: https://github.com/apache/datafusion/pull/16207#issuecomment-2922659103 + +### The `VARCHAR` SQL type is now represented as `Utf8View` in Arrow + +The mapping of the SQL `VARCHAR` type has been changed from `Utf8` to `Utf8View` +which improves performance for many string operations. You can read more about +`Utf8View` in the [DataFusion blog post on German-style strings] + +[datafusion blog post on german-style strings]: https://datafusion.apache.org/blog/2024/09/13/string-view-german-style-strings-part-1/ + +This means that when you create a table with a `VARCHAR` column, it will now use +`Utf8View` as the underlying data type. For example: + +```sql +> CREATE TABLE my_table (my_column VARCHAR); +0 row(s) fetched. +Elapsed 0.001 seconds. + +> DESCRIBE my_table; ++-------------+-----------+-------------+ +| column_name | data_type | is_nullable | ++-------------+-----------+-------------+ +| my_column | Utf8View | YES | ++-------------+-----------+-------------+ +1 row(s) fetched. +Elapsed 0.000 seconds. +``` + +You can restore the old behavior of using `Utf8` by changing the +`datafusion.sql_parser.map_varchar_to_utf8view` configuration setting. For +example + +```sql +> set datafusion.sql_parser.map_varchar_to_utf8view = false; +0 row(s) fetched. +Elapsed 0.001 seconds. + +> CREATE TABLE my_table (my_column VARCHAR); +0 row(s) fetched. +Elapsed 0.014 seconds. 
+ +> DESCRIBE my_table; ++-------------+-----------+-------------+ +| column_name | data_type | is_nullable | ++-------------+-----------+-------------+ +| my_column | Utf8 | YES | ++-------------+-----------+-------------+ +1 row(s) fetched. +Elapsed 0.004 seconds. +``` + +### `ListingOptions` default for `collect_stat` changed from `true` to `false` + +This makes it agree with the default for `SessionConfig`. +Most users won't be impacted by this change but if you were using `ListingOptions` directly +and relied on the default value of `collect_stat` being `true`, you will need to +explicitly set it to `true` in your code. + +```rust +# /* comment to avoid running +ListingOptions::new(Arc::new(ParquetFormat::default())) + .with_collect_stat(true) + // other options +# */ +``` + +### Processing `FieldRef` instead of `DataType` for user defined functions + +In order to support metadata handling and extension types, user defined functions are +now switching to traits which use `FieldRef` rather than a `DataType` and nullability. +This gives a single interface to both of these parameters and additionally allows +access to metadata fields, which can be used for extension types. + +To upgrade structs which implement `ScalarUDFImpl`, if you have implemented +`return_type_from_args` you need instead to implement `return_field_from_args`. +If your functions do not need to handle metadata, this should be straightforward +repackaging of the output data into a `FieldRef`. The name you specify on the +field is not important. It will be overwritten during planning. `ReturnInfo` +has been removed, so you will need to remove all references to it. + +`ScalarFunctionArgs` now contains a field called `arg_fields`. You can use this +to access the metadata associated with the columnar values during invocation. + +To upgrade user defined aggregate functions, there is now a function +`return_field` that will allow you to specify both metadata and nullability of +your function. 
You are not required to implement this if you do not need to +handle metadata. + +The largest change to aggregate functions happens in the accumulator arguments. +Both the `AccumulatorArgs` and `StateFieldsArgs` now contain `FieldRef` rather +than `DataType`. + +To upgrade window functions, `ExpressionArgs` now contains input fields instead +of input data types. When setting these fields, the name of the field is +not important since this gets overwritten during the planning stage. All you +should need to do is wrap your existing data types in fields with nullability +set depending on your use case. + +### Physical Expression return `Field` + +To support the changes to user defined functions processing metadata, the +`PhysicalExpr` trait, which now must specify a return `Field` based on the input +schema. To upgrade structs which implement `PhysicalExpr` you need to implement +the `return_field` function. There are numerous examples in the `physical-expr` +crate. + +### `FileFormat::supports_filters_pushdown` replaced with `FileSource::try_pushdown_filters` + +To support more general filter pushdown, the `FileFormat::supports_filters_pushdown` was replaced with +`FileSource::try_pushdown_filters`. +If you implemented a custom `FileFormat` that uses a custom `FileSource` you will need to implement +`FileSource::try_pushdown_filters`. +See `ParquetSource::try_pushdown_filters` for an example of how to implement this. + +`FileFormat::supports_filters_pushdown` has been removed. + +### `ParquetExec`, `AvroExec`, `CsvExec`, `JsonExec` Removed + +`ParquetExec`, `AvroExec`, `CsvExec`, and `JsonExec` were deprecated in +DataFusion 46 and are removed in DataFusion 48. This is sooner than the normal +process described in the [API Deprecation Guidelines] because all the tests +cover the new `DataSourceExec` rather than the older structures. 
As we evolve +`DataSource`, the old structures began to show signs of "bit rotting" (not +working but no one knows due to lack of test coverage). + +[api deprecation guidelines]: https://datafusion.apache.org/contributor-guide/api-health.html#deprecation-guidelines + +### `PartitionedFile` added as an argument to the `FileOpener` trait + +This is necessary to properly fix filter pushdown for filters that combine partition +columns and file columns (e.g. `day = username['dob']`). + +If you implemented a custom `FileOpener` you will need to add the `PartitionedFile` argument +but are not required to use it in any way. diff --git a/docs/source/library-user-guide/upgrading/48.0.1.md b/docs/source/library-user-guide/upgrading/48.0.1.md new file mode 100644 index 000000000000..5dfb9e1e3d0b --- /dev/null +++ b/docs/source/library-user-guide/upgrading/48.0.1.md @@ -0,0 +1,39 @@ + + +# Upgrade Guides + +## DataFusion 48.0.1 + +### `datafusion.execution.collect_statistics` now defaults to `true` + +The default value of the `datafusion.execution.collect_statistics` configuration +setting is now true. This change impacts users that use that value directly and relied +on its default value being `false`. + +This change also restores the default behavior of `ListingTable` to its previous default. If you use it directly +you can maintain the current behavior by overriding the default value in your code. + +```rust +# /* comment to avoid running +ListingOptions::new(Arc::new(ParquetFormat::default())) + .with_collect_stat(false) + // other options +# */ +``` diff --git a/docs/source/library-user-guide/upgrading/49.0.0.md b/docs/source/library-user-guide/upgrading/49.0.0.md new file mode 100644 index 000000000000..92dee8135590 --- /dev/null +++ b/docs/source/library-user-guide/upgrading/49.0.0.md @@ -0,0 +1,222 @@ + + +# Upgrade Guides + +## DataFusion 49.0.0 + +### `MSRV` updated to 1.85.1 + +The Minimum Supported Rust Version (MSRV) has been updated to [`1.85.1`]. 
See +[#16728] for details. + +[`1.85.1`]: https://releases.rs/docs/1.85.1/ +[#16728]: https://github.com/apache/datafusion/pull/16728 + +### `DataFusionError` variants are now `Box`ed + +To reduce the size of `DataFusionError`, several variants that were previously stored inline are now `Box`ed. This reduces the size of `Result` and thus stack usage and async state machine size. Please see [#16652] for more details. + +The following variants of `DataFusionError` are now boxed: + +- `ArrowError` +- `SQL` +- `SchemaError` + +This is a breaking change. Code that constructs or matches on these variants will need to be updated. + +For example, to create a `SchemaError`, instead of: + +```rust +# /* comment to avoid running +use datafusion_common::{DataFusionError, SchemaError}; +DataFusionError::SchemaError( + SchemaError::DuplicateUnqualifiedField { name: "foo".to_string() }, + Box::new(None) +) +# */ +``` + +You now need to `Box` the inner error: + +```rust +# /* comment to avoid running +use datafusion_common::{DataFusionError, SchemaError}; +DataFusionError::SchemaError( + Box::new(SchemaError::DuplicateUnqualifiedField { name: "foo".to_string() }), + Box::new(None) +) +# */ +``` + +[#16652]: https://github.com/apache/datafusion/issues/16652 + +### Metadata on Arrow Types is now represented by `FieldMetadata` + +Metadata from the Arrow `Field` is now stored using the `FieldMetadata` +structure. In prior versions it was stored as both a `HashMap` +and a `BTreeMap`. `FieldMetadata` is easier to work with +and is more efficient. + +To create `FieldMetadata` from a `Field`: + +```rust +# /* comment to avoid running + let metadata = FieldMetadata::from(&field); +# */ +``` + +To add metadata to a `Field`, use the `add_to_field` method: + +```rust +# /* comment to avoid running +let updated_field = metadata.add_to_field(field); +# */ +``` + +See [#16317] for details. 
+ +[#16317]: https://github.com/apache/datafusion/pull/16317 + +### New `datafusion.execution.spill_compression` configuration option + +DataFusion 49.0.0 adds support for compressing spill files when data is written to disk during spilling query execution. A new configuration option `datafusion.execution.spill_compression` controls the compression codec used. + +**Configuration:** + +- **Key**: `datafusion.execution.spill_compression` +- **Default**: `uncompressed` +- **Valid values**: `uncompressed`, `lz4_frame`, `zstd` + +**Usage:** + +```rust +# /* comment to avoid running +use datafusion::prelude::*; +use datafusion_common::config::SpillCompression; + +let config = SessionConfig::default() + .with_spill_compression(SpillCompression::Zstd); +let ctx = SessionContext::new_with_config(config); +# */ +``` + +Or via SQL: + +```sql +SET datafusion.execution.spill_compression = 'zstd'; +``` + +For more details about this configuration option, including performance trade-offs between different compression codecs, see the [Configuration Settings](../../user-guide/configs) documentation. + +### Deprecated `map_varchar_to_utf8view` configuration option + +See [issue #16290](https://github.com/apache/datafusion/pull/16290) for more information +The old configuration + +```text +datafusion.sql_parser.map_varchar_to_utf8view +``` + +is now **deprecated** in favor of the unified option below.\ +If you previously used this to control only `VARCHAR`→`Utf8View` mapping, please migrate to `map_string_types_to_utf8view`. + +--- + +### New `map_string_types_to_utf8view` configuration option + +To unify **all** SQL string types (`CHAR`, `VARCHAR`, `TEXT`, `STRING`) to Arrow’s zero‑copy `Utf8View`, DataFusion 49.0.0 introduces: + +- **Key**: `datafusion.sql_parser.map_string_types_to_utf8view` +- **Default**: `true` + +**Description:** + +- When **true** (default), **all** SQL string types are mapped to `Utf8View`, avoiding full‑copy UTF‑8 allocations and improving performance. 
+- When **false**, DataFusion falls back to the legacy `Utf8` mapping for **all** string types. + +#### Examples + +```rust +# /* comment to avoid running +// Disable Utf8View mapping for all SQL string types +let opts = datafusion::sql::planner::ParserOptions::new() + .with_map_string_types_to_utf8view(false); + +// Verify the setting is applied +assert!(!opts.map_string_types_to_utf8view); +# */ +``` + +--- + +```sql +-- Disable Utf8View mapping globally +SET datafusion.sql_parser.map_string_types_to_utf8view = false; + +-- Now VARCHAR, CHAR, TEXT, STRING all use Utf8 rather than Utf8View +CREATE TABLE my_table (a VARCHAR, b TEXT, c STRING); +DESCRIBE my_table; +``` + +### Deprecating `SchemaAdapterFactory` and `SchemaAdapter` + +We are moving away from converting data (using `SchemaAdapter`) to converting the expressions themselves (which is more efficient and flexible). + +See [issue #16800](https://github.com/apache/datafusion/issues/16800) for more information +The first place this change has taken place is in predicate pushdown for Parquet. +By default if you do not use a custom `SchemaAdapterFactory` we will use expression conversion instead. +If you do set a custom `SchemaAdapterFactory` we will continue to use it but emit a warning about that code path being deprecated. + +To resolve this you need to implement a custom `PhysicalExprAdapterFactory` and use that instead of a `SchemaAdapterFactory`. +See the [default values](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/custom_data_source/default_column_values.rs) for an example of how to do this. +Opting into the new APIs will set you up for future changes since we plan to expand use of `PhysicalExprAdapterFactory` to other areas of DataFusion. + +See [#16800] for details. 
+ +[#16800]: https://github.com/apache/datafusion/issues/16800 + +### `TableParquetOptions` Updated + +The `TableParquetOptions` struct has a new `crypto` field to specify encryption +options for Parquet files. The `ParquetEncryptionOptions` implements `Default` +so you can upgrade your existing code like this: + +```rust +# /* comment to avoid running +TableParquetOptions { + global, + column_specific_options, + key_value_metadata, +} +# */ +``` + +To this: + +```rust +# /* comment to avoid running +TableParquetOptions { + global, + column_specific_options, + key_value_metadata, + crypto: Default::default(), // New crypto field +} +# */ +``` diff --git a/docs/source/library-user-guide/upgrading/50.0.0.md b/docs/source/library-user-guide/upgrading/50.0.0.md new file mode 100644 index 000000000000..d8155dab5896 --- /dev/null +++ b/docs/source/library-user-guide/upgrading/50.0.0.md @@ -0,0 +1,330 @@ + + +# Upgrade Guides + +## DataFusion 50.0.0 + +### ListingTable automatically detects Hive Partitioned tables + +DataFusion 50.0.0 automatically infers Hive partitions when using the `ListingTableFactory` and `CREATE EXTERNAL TABLE`. Previously, +when creating a `ListingTable`, datasets that use Hive partitioning (e.g. +`/table_root/column1=value1/column2=value2/data.parquet`) would not have the Hive columns reflected in +the table's schema or data. The previous behavior can be +restored by setting the `datafusion.execution.listing_table_factory_infer_partitions` configuration option to `false`. +See [issue #17049] for more details. + +[issue #17049]: https://github.com/apache/datafusion/issues/17049 + +### `MSRV` updated to 1.86.0 + +The Minimum Supported Rust Version (MSRV) has been updated to [`1.86.0`]. +See [#17230] for details. 
+ +[`1.86.0`]: https://releases.rs/docs/1.86.0/ +[#17230]: https://github.com/apache/datafusion/pull/17230 + +### `ScalarUDFImpl`, `AggregateUDFImpl` and `WindowUDFImpl` traits now require `PartialEq`, `Eq`, and `Hash` traits + +To address error-proneness of `ScalarUDFImpl::equals`, `AggregateUDFImpl::equals` and +`WindowUDFImpl::equals` methods and to make it easy to implement function equality correctly, +the `equals` and `hash_value` methods have been removed from `ScalarUDFImpl`, `AggregateUDFImpl` +and `WindowUDFImpl` traits. They are replaced by the requirement to implement the `PartialEq`, `Eq`, +and `Hash` traits on any type implementing `ScalarUDFImpl`, `AggregateUDFImpl` or `WindowUDFImpl`. +Please see [issue #16677] for more details. + +Most of the scalar functions are stateless and have a `signature` field. These can be migrated +using regular expressions: + +- search for `\#\[derive\(Debug\)\](\n *(pub )?struct \w+ \{\n *signature\: Signature\,\n *\})`, +- replace with `#[derive(Debug, PartialEq, Eq, Hash)]$1`, +- review all the changes and make sure only function structs were changed. + +[issue #16677]: https://github.com/apache/datafusion/issues/16677 + +### `AsyncScalarUDFImpl::invoke_async_with_args` returns `ColumnarValue` + +In order to enable single value optimizations and be consistent with other +user defined function APIs, the `AsyncScalarUDFImpl::invoke_async_with_args` method now +returns a `ColumnarValue` instead of an `ArrayRef`. + +To upgrade, change the return type of your implementation: + +```rust +# /* comment to avoid running +impl AsyncScalarUDFImpl for AskLLM { + async fn invoke_async_with_args( + &self, + args: ScalarFunctionArgs, + _option: &ConfigOptions, + ) -> Result { .. 
+ return array_ref; // old code + } +} +# */ +``` + +To return a `ColumnarValue` + +```rust +# /* comment to avoid running +impl AsyncScalarUDFImpl for AskLLM { + async fn invoke_async_with_args( + &self, + args: ScalarFunctionArgs, + _option: &ConfigOptions, + ) -> Result { + .. + return ColumnarValue::from(array_ref); // new code + } +} +# */ +``` + +See [#16896](https://github.com/apache/datafusion/issues/16896) for more details. + +### `ProjectionExpr` changed from type alias to struct + +`ProjectionExpr` has been changed from a type alias to a struct with named fields to improve code clarity and maintainability. + +**Before:** + +```rust,ignore +pub type ProjectionExpr = (Arc, String); +``` + +**After:** + +```rust,ignore +#[derive(Debug, Clone)] +pub struct ProjectionExpr { + pub expr: Arc, + pub alias: String, +} +``` + +To upgrade your code: + +- Replace tuple construction `(expr, alias)` with `ProjectionExpr::new(expr, alias)` or `ProjectionExpr { expr, alias }` +- Replace tuple field access `.0` and `.1` with `.expr` and `.alias` +- Update pattern matching from `(expr, alias)` to `ProjectionExpr { expr, alias }` + +This mainly impacts use of `ProjectionExec`. + +This change was done in [#17398] + +[#17398]: https://github.com/apache/datafusion/pull/17398 + +### `SessionState`, `SessionConfig`, and `OptimizerConfig` returns `&Arc` instead of `&ConfigOptions` + +To provide broader access to `ConfigOptions` and reduce required clones, some +APIs have been changed to return a `&Arc` instead of a +`&ConfigOptions`. This allows sharing the same `ConfigOptions` across multiple +threads without needing to clone the entire `ConfigOptions` structure unless it +is modified. + +Most users will not be impacted by this change since the Rust compiler typically +automatically dereference the `Arc` when needed. 
However, in some cases you may +have to change your code to explicitly call `as_ref()` for example, from + +```rust +# /* comment to avoid running +let optimizer_config: &ConfigOptions = state.options(); +# */ +``` + +To + +```rust +# /* comment to avoid running +let optimizer_config: &ConfigOptions = state.options().as_ref(); +# */ +``` + +See PR [#16970](https://github.com/apache/datafusion/pull/16970) + +### API Change to `AsyncScalarUDFImpl::invoke_async_with_args` + +The `invoke_async_with_args` method of the `AsyncScalarUDFImpl` trait has been +updated to remove the `_option: &ConfigOptions` parameter to simplify the API +now that the `ConfigOptions` can be accessed through the `ScalarFunctionArgs` +parameter. + +You can change your code like this + +```rust +# /* comment to avoid running +impl AsyncScalarUDFImpl for AskLLM { + async fn invoke_async_with_args( + &self, + args: ScalarFunctionArgs, + _option: &ConfigOptions, + ) -> Result { + .. + } + ... +} +# */ +``` + +To this: + +```rust +# /* comment to avoid running + +impl AsyncScalarUDFImpl for AskLLM { + async fn invoke_async_with_args( + &self, + args: ScalarFunctionArgs, + ) -> Result { + let options = &args.config_options; + .. + } + ... +} +# */ +``` + +### Schema Rewriter Module Moved to New Crate + +The `schema_rewriter` module and its associated symbols have been moved from `datafusion_physical_expr` to a new crate `datafusion_physical_expr_adapter`. This affects the following symbols: + +- `DefaultPhysicalExprAdapter` +- `DefaultPhysicalExprAdapterFactory` +- `PhysicalExprAdapter` +- `PhysicalExprAdapterFactory` + +To upgrade, change your imports to: + +```rust +use datafusion_physical_expr_adapter::{ + DefaultPhysicalExprAdapter, DefaultPhysicalExprAdapterFactory, + PhysicalExprAdapter, PhysicalExprAdapterFactory +}; +``` + +### Upgrade to arrow `56.0.0` and parquet `56.0.0` + +This version of DataFusion upgrades the underlying Apache Arrow implementation +to version `56.0.0`. 
See the [release notes](https://github.com/apache/arrow-rs/releases/tag/56.0.0) +for more details. + +### Added `ExecutionPlan::reset_state` + +In order to fix a bug in DataFusion `49.0.0` where dynamic filters (currently only generated in the presence of a query such as `ORDER BY ... LIMIT ...`) +produced incorrect results in recursive queries, a new method `reset_state` has been added to the `ExecutionPlan` trait. + +Any `ExecutionPlan` that needs to maintain internal state or references to other nodes in the execution plan tree should implement this method to reset that state. +See [#17028] for more details and an example implementation for `SortExec`. + +[#17028]: https://github.com/apache/datafusion/pull/17028 + +### Nested Loop Join input sort order cannot be preserved + +The Nested Loop Join operator has been rewritten from scratch to improve performance and memory efficiency. From the micro-benchmarks: this change introduces up to 5X speed-up and uses only 1% memory in extreme cases compared to the previous implementation. + +However, the new implementation cannot preserve input sort order like the old version could. This is a fundamental design trade-off that prioritizes performance and memory efficiency over sort order preservation. + +See [#16996] for details. + +[#16996]: https://github.com/apache/datafusion/pull/16996 + +### Add `as_any()` method to `LazyBatchGenerator` + +To help with protobuf serialization, the `as_any()` method has been added to the `LazyBatchGenerator` trait. This means you will need to add `as_any()` to your implementation of `LazyBatchGenerator`: + +```rust +# /* comment to avoid running + +impl LazyBatchGenerator for MyBatchGenerator { + fn as_any(&self) -> &dyn Any { + self + } + + ... +} + +# */ +``` + +See [#17200](https://github.com/apache/datafusion/pull/17200) for details. 
+ +### Refactored `DataSource::try_swapping_with_projection` + +We refactored `DataSource::try_swapping_with_projection` to simplify the method and minimize leakage across the ExecutionPlan <-> DataSource abstraction layer. +Reimplementation for any custom `DataSource` should be relatively straightforward, see [#17395] for more details. + +[#17395]: https://github.com/apache/datafusion/pull/17395/ + +### `FileOpenFuture` now uses `DataFusionError` instead of `ArrowError` + +The `FileOpenFuture` type alias has been updated to use `DataFusionError` instead of `ArrowError` for its error type. This change affects the `FileOpener` trait and any implementations that work with file streaming operations. + +**Before:** + +```rust,ignore +pub type FileOpenFuture = BoxFuture<'static, Result>>>; +``` + +**After:** + +```rust,ignore +pub type FileOpenFuture = BoxFuture<'static, Result>>>; +``` + +If you have custom implementations of `FileOpener` or work directly with `FileOpenFuture`, you'll need to update your error handling to use `DataFusionError` instead of `ArrowError`. The `FileStreamState` enum's `Open` variant has also been updated accordingly. See [#17397] for more details. + +[#17397]: https://github.com/apache/datafusion/pull/17397 + +### FFI user defined aggregate function signature change + +The Foreign Function Interface (FFI) signature for user defined aggregate functions +has been updated to call `return_field` instead of `return_type` on the underlying +aggregate function. This is to support metadata handling with these aggregate functions. +This change should be transparent to most users. If you have written unit tests to call +`return_type` directly, you may need to change them to calling `return_field` instead. + +This update is a breaking change to the FFI API. The current best practice when using the +FFI crate is to ensure that all libraries that are interacting are using the same +underlying Rust version. 
Issue [#17374] has been opened to discuss stabilization of +this interface so that these libraries can be used across different DataFusion versions. + +See [#17407] for details. + +[#17407]: https://github.com/apache/datafusion/pull/17407 +[#17374]: https://github.com/apache/datafusion/issues/17374 + +### Added `PhysicalExpr::is_volatile_node` + +We added a method to `PhysicalExpr` to mark a `PhysicalExpr` as volatile: + +```rust,ignore +impl PhysicalExpr for MyRandomExpr { + fn is_volatile_node(&self) -> bool { + true + } +} +``` + +We've shipped this with a default value of `false` to minimize breakage but we highly recommend that implementers of `PhysicalExpr` opt into a behavior, even if it is returning `false`. + +You can see more discussion and example implementations in [#17351]. + +[#17351]: https://github.com/apache/datafusion/pull/17351 diff --git a/docs/source/library-user-guide/upgrading/51.0.0.md b/docs/source/library-user-guide/upgrading/51.0.0.md new file mode 100644 index 000000000000..c3acfe15c493 --- /dev/null +++ b/docs/source/library-user-guide/upgrading/51.0.0.md @@ -0,0 +1,272 @@ + + +# Upgrade Guides + +## DataFusion 51.0.0 + +### `arrow` / `parquet` updated to 57.0.0 + +### Upgrade to arrow `57.0.0` and parquet `57.0.0` + +This version of DataFusion upgrades the underlying Apache Arrow implementation +to version `57.0.0`, including several dependent crates such as `prost`, +`tonic`, `pyo3`, and `substrait`. See the [release +notes](https://github.com/apache/arrow-rs/releases/tag/57.0.0) for more details. + +### `MSRV` updated to 1.88.0 + +The Minimum Supported Rust Version (MSRV) has been updated to [`1.88.0`]. + +[`1.88.0`]: https://releases.rs/docs/1.88.0/ + +### `FunctionRegistry` exposes two additional methods + +`FunctionRegistry` exposes two additional methods `udafs` and `udwfs` which expose the set of registered user defined aggregation and window function names. 
To upgrade, implement methods returning the set of registered function names: + +```diff +impl FunctionRegistry for FunctionRegistryImpl { + fn udfs(&self) -> HashSet { + self.scalar_functions.keys().cloned().collect() + } ++ fn udafs(&self) -> HashSet { ++ self.aggregate_functions.keys().cloned().collect() ++ } ++ ++ fn udwfs(&self) -> HashSet { ++ self.window_functions.keys().cloned().collect() ++ } +} +``` + +### `datafusion-proto` use `TaskContext` rather than `SessionContext` in physical plan serde methods + +There have been changes in the public API methods of `datafusion-proto` which handle physical plan serde. + +Methods like `physical_plan_from_bytes`, `parse_physical_expr` and similar, expect `TaskContext` instead of `SessionContext`: + +```diff +- let plan2 = physical_plan_from_bytes(&bytes, &ctx)?; ++ let plan2 = physical_plan_from_bytes(&bytes, &ctx.task_ctx())?; +``` + +As `TaskContext` contains `RuntimeEnv`, methods such as `try_into_physical_plan` will not have an explicit `RuntimeEnv` parameter. + +```diff +let result_exec_plan: Arc = proto +- .try_into_physical_plan(&ctx, runtime.deref(), &composed_codec) ++ .try_into_physical_plan(&ctx.task_ctx(), &composed_codec) +``` + +`PhysicalExtensionCodec::try_decode()` expects `TaskContext` instead of `FunctionRegistry`: + +```diff +pub trait PhysicalExtensionCodec { + fn try_decode( + &self, + buf: &[u8], + inputs: &[Arc], +- registry: &dyn FunctionRegistry, ++ ctx: &TaskContext, + ) -> Result>; +``` + +See [issue #17601] for more details. + +[issue #17601]: https://github.com/apache/datafusion/issues/17601 + +### `SessionState`'s `sql_to_statement` method takes `Dialect` rather than a `str` + +The `dialect` parameter of `sql_to_statement` method defined in `datafusion::execution::session_state::SessionState` +has changed from `&str` to `&Dialect`. 
+`Dialect` is an enum defined in the `datafusion-common` +crate under the `config` module that provides type safety +and better validation for SQL dialect selection. + +### Reorganization of `ListingTable` into `datafusion-catalog-listing` crate + +There has been a long standing request to remove features such as `ListingTable` +from the `datafusion` crate to support faster build times. The structs +`ListingOptions`, `ListingTable`, and `ListingTableConfig` are now available +within the `datafusion-catalog-listing` crate. These are re-exported in +the `datafusion` crate, so this should have minimal impact on existing users. + +See [issue #14462] and [issue #17713] for more details. + +[issue #14462]: https://github.com/apache/datafusion/issues/14462 +[issue #17713]: https://github.com/apache/datafusion/issues/17713 + +### Reorganization of `ArrowSource` into `datafusion-datasource-arrow` crate + +To support [issue #17713] the `ArrowSource` code has been removed from +the `datafusion` core crate into its own crate, `datafusion-datasource-arrow`. +This follows the pattern for the AVRO, CSV, JSON, and Parquet data sources. +Users may need to update their paths to account for these changes. + +See [issue #17713] for more details. + +### `FileScanConfig::projection` renamed to `FileScanConfig::projection_exprs` + +The `projection` field in `FileScanConfig` has been renamed to `projection_exprs` and its type has changed from `Option>` to `Option`. This change enables more powerful projection pushdown capabilities by supporting arbitrary physical expressions rather than just column indices. 
+ +**Impact on direct field access:** + +If you directly access the `projection` field: + +```rust,ignore +let config: FileScanConfig = ...; +let projection = config.projection; +``` + +You should update to: + +```rust,ignore +let config: FileScanConfig = ...; +let projection_exprs = config.projection_exprs; +``` + +**Impact on builders:** + +The `FileScanConfigBuilder::with_projection()` method has been deprecated in favor of `with_projection_indices()`: + +```diff +let config = FileScanConfigBuilder::new(url, file_source) +- .with_projection(Some(vec![0, 2, 3])) ++ .with_projection_indices(Some(vec![0, 2, 3])) + .build(); +``` + +Note: `with_projection()` still works but is deprecated and will be removed in a future release. + +**What is `ProjectionExprs`?** + +`ProjectionExprs` is a new type that represents a list of physical expressions for projection. While it can be constructed from column indices (which is what `with_projection_indices` does internally), it also supports arbitrary physical expressions, enabling advanced features like expression evaluation during scanning. + +You can access column indices from `ProjectionExprs` using its methods if needed: + +```rust,ignore +let projection_exprs: ProjectionExprs = ...; +// Get the column indices if the projection only contains simple column references +let indices = projection_exprs.column_indices(); +``` + +### `DESCRIBE query` support + +`DESCRIBE query` was previously an alias for `EXPLAIN query`, which outputs the +_execution plan_ of the query. With this release, `DESCRIBE query` now outputs +the computed _schema_ of the query, consistent with the behavior of `DESCRIBE table_name`. + +### `datafusion.execution.time_zone` default configuration changed + +The default value for `datafusion.execution.time_zone` previously was a string value of `+00:00` (GMT/Zulu time). +This was changed to be an `Option` with a default of `None`. 
If you want to change the timezone back +to the previous value you can execute the sql: + +```sql +SET +TIMEZONE = '+00:00'; +``` + +This change was made to better support using the default timezone in scalar UDF functions such as +`now`, `current_date`, `current_time`, and `to_timestamp` among others. + +### Introduction of `TableSchema` and changes to `FileSource::with_schema()` method + +A new `TableSchema` struct has been introduced in the `datafusion-datasource` crate to better manage table schemas with partition columns. This struct helps distinguish between: + +- **File schema**: The schema of actual data files on disk +- **Partition columns**: Columns derived from directory structure (e.g., Hive-style partitioning) +- **Table schema**: The complete schema combining both file and partition columns + +As part of this change, the `FileSource::with_schema()` method signature has changed from accepting a `SchemaRef` to accepting a `TableSchema`. + +**Who is affected:** + +- Users who have implemented custom `FileSource` implementations will need to update their code +- Users who only use built-in file sources (Parquet, CSV, JSON, AVRO, Arrow) are not affected + +**Migration guide for custom `FileSource` implementations:** + +```diff + use datafusion_datasource::file::FileSource; +-use arrow::datatypes::SchemaRef; ++use datafusion_datasource::TableSchema; + + impl FileSource for MyCustomSource { +- fn with_schema(&self, schema: SchemaRef) -> Arc { ++ fn with_schema(&self, schema: TableSchema) -> Arc { + Arc::new(Self { +- schema: Some(schema), ++ // Use schema.file_schema() to get the file schema without partition columns ++ schema: Some(Arc::clone(schema.file_schema())), + ..self.clone() + }) + } + } +``` + +For implementations that need access to partition columns: + +```rust,ignore +fn with_schema(&self, schema: TableSchema) -> Arc { + Arc::new(Self { + file_schema: Arc::clone(schema.file_schema()), + partition_cols: schema.table_partition_cols().clone(), + 
table_schema: Arc::clone(schema.table_schema()), + ..self.clone() + }) +} +``` + +**Note**: Most `FileSource` implementations only need to store the file schema (without partition columns), as shown in the first example. The second pattern of storing all three schema components is typically only needed for advanced use cases where you need access to different schema representations for different operations (e.g., ParquetSource uses the file schema for building pruning predicates but needs the table schema for filter pushdown logic). + +**Using `TableSchema` directly:** + +If you're constructing a `FileScanConfig` or working with table schemas and partition columns, you can now use `TableSchema`: + +```rust +use datafusion_datasource::TableSchema; +use arrow::datatypes::{Schema, Field, DataType}; +use std::sync::Arc; + +// Create a TableSchema with partition columns +let file_schema = Arc::new(Schema::new(vec![ + Field::new("user_id", DataType::Int64, false), + Field::new("amount", DataType::Float64, false), +])); + +let partition_cols = vec![ + Arc::new(Field::new("date", DataType::Utf8, false)), + Arc::new(Field::new("region", DataType::Utf8, false)), +]; + +let table_schema = TableSchema::new(file_schema, partition_cols); + +// Access different schema representations +let file_schema_ref = table_schema.file_schema(); // Schema without partition columns +let full_schema = table_schema.table_schema(); // Complete schema with partition columns +let partition_cols_ref = table_schema.table_partition_cols(); // Just the partition columns +``` + +### `AggregateUDFImpl::is_ordered_set_aggregate` has been renamed to `AggregateUDFImpl::supports_within_group_clause` + +This method has been renamed to better reflect the actual impact it has for aggregate UDF implementations. +The accompanying `AggregateUDF::is_ordered_set_aggregate` has also been renamed to `AggregateUDF::supports_within_group_clause`. 
+No functionality has been changed with regards to this method; it still refers only to permitting use of `WITHIN GROUP` +SQL syntax for the aggregate function. diff --git a/docs/source/library-user-guide/upgrading/52.0.0.md b/docs/source/library-user-guide/upgrading/52.0.0.md new file mode 100644 index 000000000000..4c659b6118fe --- /dev/null +++ b/docs/source/library-user-guide/upgrading/52.0.0.md @@ -0,0 +1,669 @@ + + +# Upgrade Guides + +## DataFusion 52.0.0 + +### Changes to DFSchema API + +To permit more efficient planning, several methods on `DFSchema` have been +changed to return references to the underlying [`&FieldRef`] rather than +[`&Field`]. This allows planners to more cheaply copy the references via +`Arc::clone` rather than cloning the entire `Field` structure. + +You may need to change code to use `Arc::clone` instead of `.as_ref().clone()` +directly on the `Field`. For example: + +```diff +- let field = df_schema.field("my_column").as_ref().clone(); ++ let field = Arc::clone(df_schema.field("my_column")); +``` + +### ListingTableProvider now caches `LIST` commands + +In prior versions, `ListingTableProvider` would issue `LIST` commands to +the underlying object store each time it needed to list files for a query. +To improve performance, `ListingTableProvider` now caches the results of +`LIST` commands for the lifetime of the `ListingTableProvider` instance or +until a cache entry expires. + +Note that by default the cache has no expiration time, so if files are added or removed +from the underlying object store, the `ListingTableProvider` will not see +those changes until the `ListingTableProvider` instance is dropped and recreated. 
+ +You can configure the maximum cache size and cache entry expiration time via configuration options: + +- `datafusion.runtime.list_files_cache_limit` - Limits the size of the cache in bytes +- `datafusion.runtime.list_files_cache_ttl` - Limits the TTL (time-to-live) of an entry in seconds + +Detailed configuration information can be found in the [DataFusion Runtime +Configuration](https://datafusion.apache.org/user-guide/configs.html#runtime-configuration-settings) user's guide. + +Caching can be disabled by setting the limit to 0: + +```sql +SET datafusion.runtime.list_files_cache_limit TO "0K"; +``` + +Note that the internal API has changed to use a trait `ListFilesCache` instead of a type alias. + +### `newlines_in_values` moved from `FileScanConfig` to `CsvOptions` + +The CSV-specific `newlines_in_values` configuration option has been moved from `FileScanConfig` to `CsvOptions`, as it only applies to CSV file parsing. + +**Who is affected:** + +- Users who set `newlines_in_values` via `FileScanConfigBuilder::with_newlines_in_values()` + +**Migration guide:** + +Set `newlines_in_values` in `CsvOptions` instead of on `FileScanConfigBuilder`: + +**Before:** + +```rust,ignore +let source = Arc::new(CsvSource::new(file_schema.clone())); +let config = FileScanConfigBuilder::new(object_store_url, source) + .with_newlines_in_values(true) + .build(); +``` + +**After:** + +```rust,ignore +let options = CsvOptions { + newlines_in_values: Some(true), + ..Default::default() +}; +let source = Arc::new(CsvSource::new(file_schema.clone()) + .with_csv_options(options)); +let config = FileScanConfigBuilder::new(object_store_url, source) + .build(); +``` + +### Removal of `pyarrow` feature + +The `pyarrow` feature flag has been removed. This feature has been migrated to +the `datafusion-python` repository since version `44.0.0`. 
+ +### Refactoring of `FileSource` constructors and `FileScanConfigBuilder` to accept schemas upfront + +The way schemas are passed to file sources and scan configurations has been significantly refactored. File sources now require the schema (including partition columns) to be provided at construction time, and `FileScanConfigBuilder` no longer takes a separate schema parameter. + +**Who is affected:** + +- Users who create `FileScanConfig` or file sources (`ParquetSource`, `CsvSource`, `JsonSource`, `AvroSource`) directly +- Users who implement custom `FileFormat` implementations + +**Key changes:** + +1. **FileSource constructors now require TableSchema**: All built-in file sources now take the schema in their constructor: + + ```diff + - let source = ParquetSource::default(); + + let source = ParquetSource::new(table_schema); + ``` + +2. **FileScanConfigBuilder no longer takes schema as a parameter**: The schema is now passed via the FileSource: + + ```diff + - FileScanConfigBuilder::new(url, schema, source) + + FileScanConfigBuilder::new(url, source) + ``` + +3. **Partition columns are now part of TableSchema**: The `with_table_partition_cols()` method has been removed from `FileScanConfigBuilder`. Partition columns are now passed as part of the `TableSchema` to the FileSource constructor: + + ```diff + + let table_schema = TableSchema::new( + + file_schema, + + vec![Arc::new(Field::new("date", DataType::Utf8, false))], + + ); + + let source = ParquetSource::new(table_schema); + let config = FileScanConfigBuilder::new(url, source) + - .with_table_partition_cols(vec![Field::new("date", DataType::Utf8, false)]) + .with_file(partitioned_file) + .build(); + ``` + +4. 
**FileFormat::file_source() now takes TableSchema parameter**: Custom `FileFormat` implementations must be updated: + ```diff + impl FileFormat for MyFileFormat { + - fn file_source(&self) -> Arc { + + fn file_source(&self, table_schema: TableSchema) -> Arc { + - Arc::new(MyFileSource::default()) + + Arc::new(MyFileSource::new(table_schema)) + } + } + ``` + +**Migration examples:** + +For Parquet files: + +```diff +- let source = Arc::new(ParquetSource::default()); +- let config = FileScanConfigBuilder::new(url, schema, source) ++ let table_schema = TableSchema::new(schema, vec![]); ++ let source = Arc::new(ParquetSource::new(table_schema)); ++ let config = FileScanConfigBuilder::new(url, source) + .with_file(partitioned_file) + .build(); +``` + +For CSV files with partition columns: + +```diff +- let source = Arc::new(CsvSource::new(true, b',', b'"')); +- let config = FileScanConfigBuilder::new(url, file_schema, source) +- .with_table_partition_cols(vec![Field::new("year", DataType::Int32, false)]) ++ let options = CsvOptions { ++ has_header: Some(true), ++ delimiter: b',', ++ quote: b'"', ++ ..Default::default() ++ }; ++ let table_schema = TableSchema::new( ++ file_schema, ++ vec![Arc::new(Field::new("year", DataType::Int32, false))], ++ ); ++ let source = Arc::new(CsvSource::new(table_schema).with_csv_options(options)); ++ let config = FileScanConfigBuilder::new(url, source) + .build(); +``` + +### Adaptive filter representation in Parquet filter pushdown + +As of Arrow 57.1.0, DataFusion uses a new adaptive filter strategy when +evaluating pushed down filters for Parquet files. This new strategy improves +performance for certain types of queries where the results of filtering are +more efficiently represented with a bitmask rather than a selection. +See [arrow-rs #5523] for more details. + +This change only applies to the built-in Parquet data source with filter-pushdown enabled ( +which is [not yet the default behavior]). 
+ +You can disable the new behavior by setting the +`datafusion.execution.parquet.force_filter_selections` [configuration setting] to true. + +```sql +> set datafusion.execution.parquet.force_filter_selections = true; +``` + +[arrow-rs #5523]: https://github.com/apache/arrow-rs/issues/5523 +[configuration setting]: https://datafusion.apache.org/user-guide/configs.html +[not yet the default behavior]: https://github.com/apache/datafusion/issues/3463 + +### Statistics handling moved from `FileSource` to `FileScanConfig` + +Statistics are now managed directly by `FileScanConfig` instead of being delegated to `FileSource` implementations. This simplifies the `FileSource` trait and provides more consistent statistics handling across all file formats. + +**Who is affected:** + +- Users who have implemented custom `FileSource` implementations + +**Breaking changes:** + +Two methods have been removed from the `FileSource` trait: + +- `with_statistics(&self, statistics: Statistics) -> Arc` +- `statistics(&self) -> Result` + +**Migration guide:** + +If you have a custom `FileSource` implementation, you need to: + +1. Remove the `with_statistics` method implementation +2. Remove the `statistics` method implementation +3. Remove any internal state that was storing statistics + +**Before:** + +```rust,ignore +#[derive(Clone)] +struct MyCustomSource { + table_schema: TableSchema, + projected_statistics: Option, + // other fields... +} + +impl FileSource for MyCustomSource { + fn with_statistics(&self, statistics: Statistics) -> Arc { + Arc::new(Self { + table_schema: self.table_schema.clone(), + projected_statistics: Some(statistics), + // other fields... + }) + } + + fn statistics(&self) -> Result { + Ok(self.projected_statistics.clone().unwrap_or_else(|| + Statistics::new_unknown(self.table_schema.file_schema()) + )) + } + + // other methods... 
+} +``` + +**After:** + +```rust,ignore +#[derive(Clone)] +struct MyCustomSource { + table_schema: TableSchema, + // projected_statistics field removed + // other fields... +} + +impl FileSource for MyCustomSource { + // with_statistics method removed + // statistics method removed + + // other methods... +} +``` + +**Accessing statistics:** + +Statistics are now accessed through `FileScanConfig` instead of `FileSource`: + +```diff +- let stats = config.file_source.statistics()?; ++ let stats = config.statistics(); +``` + +Note that `FileScanConfig::statistics()` automatically marks statistics as inexact when filters are present, ensuring correctness when filters are pushed down. + +### Partition column handling moved out of `PhysicalExprAdapter` + +Partition column replacement is now a separate preprocessing step performed before expression rewriting via `PhysicalExprAdapter`. This change provides better separation of concerns and makes the adapter more focused on schema differences rather than partition value substitution. + +**Who is affected:** + +- Users who have custom implementations of `PhysicalExprAdapterFactory` that handle partition columns +- Users who directly use the `FilePruner` API + +**Breaking changes:** + +1. `FilePruner::try_new()` signature changed: the `partition_fields` parameter has been removed since partition column handling is now done separately +2. 
Partition column replacement must now be done via `replace_columns_with_literals()` before expressions are passed to the adapter + +**Migration guide:** + +If you have code that creates a `FilePruner` with partition fields: + +**Before:** + +```rust,ignore +use datafusion_pruning::FilePruner; + +let pruner = FilePruner::try_new( + predicate, + file_schema, + partition_fields, // This parameter is removed + file_stats, +)?; +``` + +**After:** + +```rust,ignore +use datafusion_pruning::FilePruner; + +// Partition fields are no longer needed +let pruner = FilePruner::try_new( + predicate, + file_schema, + file_stats, +)?; +``` + +If you have custom code that relies on `PhysicalExprAdapter` to handle partition columns, you must now call `replace_columns_with_literals()` separately: + +**Before:** + +```rust,ignore +// Adapter handled partition column replacement internally +let adapted_expr = adapter.rewrite(expr)?; +``` + +**After:** + +```rust,ignore +use datafusion_physical_expr_adapter::replace_columns_with_literals; + +// Replace partition columns first +let expr_with_literals = replace_columns_with_literals(expr, &partition_values)?; +// Then apply the adapter +let adapted_expr = adapter.rewrite(expr_with_literals)?; +``` + +### `build_row_filter` signature simplified + +The `build_row_filter` function in `datafusion-datasource-parquet` has been simplified to take a single schema parameter instead of two. +The expectation is now that the filter has been adapted to the physical file schema (the arrow representation of the parquet file's schema) before being passed to this function +using a `PhysicalExprAdapter` for example. 
+ +**Who is affected:** + +- Users who call `build_row_filter` directly + +**Breaking changes:** + +The function signature changed from: + +```rust,ignore +pub fn build_row_filter( + expr: &Arc<dyn PhysicalExpr>, + physical_file_schema: &SchemaRef, + predicate_file_schema: &SchemaRef, // removed + metadata: &ParquetMetaData, + reorder_predicates: bool, + file_metrics: &ParquetFileMetrics, +) -> Result<Option<RowFilter>> +``` + +To: + +```rust,ignore +pub fn build_row_filter( + expr: &Arc<dyn PhysicalExpr>, + file_schema: &SchemaRef, + metadata: &ParquetMetaData, + reorder_predicates: bool, + file_metrics: &ParquetFileMetrics, +) -> Result<Option<RowFilter>> +``` + +**Migration guide:** + +Remove the duplicate schema parameter from your call: + +```diff +- build_row_filter(&predicate, &file_schema, &file_schema, metadata, reorder, metrics) ++ build_row_filter(&predicate, &file_schema, metadata, reorder, metrics) +``` + +### Planner now requires explicit opt-in for WITHIN GROUP syntax + +The SQL planner now enforces the aggregate UDF contract more strictly: the +`WITHIN GROUP (ORDER BY ...)` syntax is accepted only if the aggregate UDAF +explicitly advertises support by returning `true` from +`AggregateUDFImpl::supports_within_group_clause()`. + +Previously the planner forwarded a `WITHIN GROUP` clause to order-sensitive +aggregates even when they did not implement ordered-set semantics, which could +cause queries such as `SUM(x) WITHIN GROUP (ORDER BY x)` to plan successfully. +This behavior was too permissive and has been changed to match PostgreSQL and +the documented semantics. + +Migration: If your UDAF intentionally implements ordered-set semantics and +wants to accept the `WITHIN GROUP` SQL syntax, update your implementation to +return `true` from `supports_within_group_clause()` and handle the ordering +semantics in your accumulator implementation.
If your UDAF is merely +order-sensitive (but not an ordered-set aggregate), do not advertise +`supports_within_group_clause()` and clients should use alternative function +signatures (for example, explicit ordering as a function argument) instead. + +### `AggregateUDFImpl::supports_null_handling_clause` now defaults to `false` + +This method specifies whether an aggregate function allows `IGNORE NULLS`/`RESPECT NULLS` +during SQL parsing, with the implication it respects these configs during computation. + +Most DataFusion aggregate functions silently ignored this syntax in prior versions +as they did not make use of it and it was permitted by default. We change this so +only the few functions which do respect this clause (e.g. `array_agg`, `first_value`, +`last_value`) need to implement it. + +Custom user defined aggregate functions will also error if this syntax is used, +unless they explicitly declare support by overriding the method. + +For example, SQL parsing will now fail for queries such as this: + +```sql +SELECT median(c1) IGNORE NULLS FROM table +``` + +Instead of silently succeeding. + +### API change for `CacheAccessor` trait + +The remove API no longer requires a mutable instance + +### FFI crate updates + +Many of the structs in the `datafusion-ffi` crate have been updated to allow easier +conversion to the underlying trait types they represent. This simplifies some code +paths, but also provides an additional improvement in cases where library code goes +through a round trip via the foreign function interface. + +To update your code, suppose you have a `FFI_SchemaProvider` called `ffi_provider` +and you wish to use this as a `SchemaProvider`. 
In the old approach you would do +something like: + +```rust,ignore + let foreign_provider: ForeignSchemaProvider = ffi_provider.into(); + let foreign_provider = Arc::new(foreign_provider) as Arc; +``` + +This code should now be written as: + +```rust,ignore + let foreign_provider: Arc = ffi_provider.into(); + let foreign_provider = foreign_provider as Arc; +``` + +For the case of user defined functions, the updates are similar but you +may need to change the way you call the creation of the `ScalarUDF`. +Aggregate and window functions follow the same pattern. + +Previously you may write: + +```rust,ignore + let foreign_udf: ForeignScalarUDF = ffi_udf.try_into()?; + let foreign_udf: ScalarUDF = foreign_udf.into(); +``` + +Instead this should now be: + +```rust,ignore + let foreign_udf: Arc = ffi_udf.into(); + let foreign_udf = ScalarUDF::new_from_shared_impl(foreign_udf); +``` + +When creating any of the following structs, we now require the user to +provide a `TaskContextProvider` and optionally a `LogicalExtensionCodec`: + +- `FFI_CatalogListProvider` +- `FFI_CatalogProvider` +- `FFI_SchemaProvider` +- `FFI_TableProvider` +- `FFI_TableFunction` + +Each of these structs has a `new()` and a `new_with_ffi_codec()` method for +instantiation. For example, when you previously would write + +```rust,ignore + let table = Arc::new(MyTableProvider::new()); + let ffi_table = FFI_TableProvider::new(table, None); +``` + +Now you will need to provide a `TaskContextProvider`. The most common +implementation of this trait is `SessionContext`. + +```rust,ignore + let ctx = Arc::new(SessionContext::default()); + let table = Arc::new(MyTableProvider::new()); + let ffi_table = FFI_TableProvider::new(table, None, ctx, None); +``` + +The alternative function to create these structures may be more convenient +if you are doing many of these operations. A `FFI_LogicalExtensionCodec` will +store the `TaskContextProvider` as well. 
+ +```rust,ignore + let codec = Arc::new(DefaultLogicalExtensionCodec {}); + let ctx = Arc::new(SessionContext::default()); + let ffi_codec = FFI_LogicalExtensionCodec::new(codec, None, ctx); + let table = Arc::new(MyTableProvider::new()); + let ffi_table = FFI_TableProvider::new_with_ffi_codec(table, None, ffi_codec); +``` + +Additional information about the usage of the `TaskContextProvider` can be +found in the crate README. + +Additionally, the FFI structure for Scalar UDF's no longer contains a +`return_type` call. This code was not used since the `ForeignScalarUDF` +struct implements the `return_field_from_args` instead. + +### Projection handling moved from FileScanConfig to FileSource + +Projection handling has been moved from `FileScanConfig` into `FileSource` implementations. This enables format-specific projection pushdown (e.g., Parquet can push down struct field access, Vortex can push down computed expressions into un-decoded data). + +**Who is affected:** + +- Users who have implemented custom `FileSource` implementations +- Users who use `FileScanConfigBuilder::with_projection_indices` directly + +**Breaking changes:** + +1. **`FileSource::with_projection` replaced with `try_pushdown_projection`:** + + The `with_projection(&self, config: &FileScanConfig) -> Arc` method has been removed and replaced with `try_pushdown_projection(&self, projection: &ProjectionExprs) -> Result>>`. + +2. **`FileScanConfig.projection_exprs` field removed:** + + Projections are now stored in the `FileSource` directly, not in `FileScanConfig`. + Various public helper methods that access projection information have been removed from `FileScanConfig`. + +3. **`FileScanConfigBuilder::with_projection_indices` now returns `Result`:** + + This method can now fail if the projection pushdown fails. + +4. **`FileSource::create_file_opener` now returns `Result>`:** + + Previously returned `Arc` directly. 
+ Any `FileSource` implementation that may fail to create a `FileOpener` should now return an appropriate error. + +5. **`DataSource::try_swapping_with_projection` signature changed:** + + Parameter changed from `&[ProjectionExpr]` to `&ProjectionExprs`. + +**Migration guide:** + +If you have a custom `FileSource` implementation: + +**Before:** + +```rust,ignore +impl FileSource for MyCustomSource { + fn with_projection(&self, config: &FileScanConfig) -> Arc { + // Apply projection from config + Arc::new(Self { /* ... */ }) + } + + fn create_file_opener( + &self, + object_store: Arc, + base_config: &FileScanConfig, + partition: usize, + ) -> Arc { + Arc::new(MyOpener { /* ... */ }) + } +} +``` + +**After:** + +```rust,ignore +impl FileSource for MyCustomSource { + fn try_pushdown_projection( + &self, + projection: &ProjectionExprs, + ) -> Result>> { + // Return None if projection cannot be pushed down + // Return Some(new_source) with projection applied if it can + Ok(Some(Arc::new(Self { + projection: Some(projection.clone()), + /* ... */ + }))) + } + + fn projection(&self) -> Option<&ProjectionExprs> { + self.projection.as_ref() + } + + fn create_file_opener( + &self, + object_store: Arc, + base_config: &FileScanConfig, + partition: usize, + ) -> Result> { + Ok(Arc::new(MyOpener { /* ... */ })) + } +} +``` + +We recommend you look at [#18627](https://github.com/apache/datafusion/pull/18627) +that introduced these changes for more examples for how this was handled for the various built in file sources. + +We have added [`SplitProjection`](https://docs.rs/datafusion-datasource/latest/datafusion_datasource/projection/struct.SplitProjection.html) and [`ProjectionOpener`](https://docs.rs/datafusion-datasource/latest/datafusion_datasource/projection/struct.ProjectionOpener.html) helpers to make it easier to handle projections in your `FileSource` implementations. 
+ +For file sources that can only handle simple column selections (not computed expressions), use the `SplitProjection` and `ProjectionOpener` helpers to split the projection into pushdownable and non-pushdownable parts: + +```rust,ignore +use datafusion_datasource::projection::{SplitProjection, ProjectionOpener}; + +// In try_pushdown_projection: +let split = SplitProjection::new(projection, self.table_schema())?; +// Use split.file_projection() for what to push down to the file format +// The ProjectionOpener wrapper will handle the rest +``` + +**For `FileScanConfigBuilder` users:** + +```diff +let config = FileScanConfigBuilder::new(url, source) +- .with_projection_indices(Some(vec![0, 2, 3])) ++ .with_projection_indices(Some(vec![0, 2, 3]))? + .build(); +``` + +### `SchemaAdapter` and `SchemaAdapterFactory` completely removed + +Following the deprecation announced in [DataFusion 49.0.0](49.0.0.md#deprecating-schemaadapterfactory-and-schemaadapter), `SchemaAdapterFactory` has been fully removed from Parquet scanning. This applies to both: + +The following symbols have been deprecated and will be removed in the next release: + +- `SchemaAdapter` trait +- `SchemaAdapterFactory` trait +- `SchemaMapper` trait +- `SchemaMapping` struct +- `DefaultSchemaAdapterFactory` struct + +These types were previously used to adapt record batch schemas during file reading. +This functionality has been replaced by `PhysicalExprAdapterFactory`, which rewrites expressions at planning time rather than transforming batches at runtime. +If you were using a custom `SchemaAdapterFactory` for schema adaptation (e.g., default column values, type coercion), you should now implement `PhysicalExprAdapterFactory` instead. +See the [default column values example](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/custom_data_source/default_column_values.rs) for how to implement a custom `PhysicalExprAdapterFactory`. 
+ +**Migration guide:** + +If you implemented a custom `SchemaAdapterFactory`, migrate to `PhysicalExprAdapterFactory`. +See the [default column values example](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/custom_data_source/default_column_values.rs) for a complete implementation. diff --git a/docs/source/library-user-guide/upgrading/53.0.0.md b/docs/source/library-user-guide/upgrading/53.0.0.md new file mode 100644 index 000000000000..ef5f5743f5ea --- /dev/null +++ b/docs/source/library-user-guide/upgrading/53.0.0.md @@ -0,0 +1,474 @@ + + +# Upgrade Guides + +## DataFusion 53.0.0 + +**Note:** DataFusion `53.0.0` has not been released yet. The information provided +in this section pertains to features and changes that have already been merged +to the main branch and are awaiting release in this version. See [#19692] for +more details. + +[#19692]: https://github.com/apache/datafusion/issues/19692 + +### Upgrade arrow/parquet to 58.0.0 and object_store to 0.13.0 + +DataFusion 53.0.0 uses `arrow` and `parquet` 58.0.0, and `object_store` 0.13.0. +This may require updates to your Cargo.toml if you have direct dependencies on +these crates. + +See the [Arrow 58.0.0 release notes] and the [object_store 0.13.0 upgrade guide] for details on breaking changes in those versions. + +[arrow 58.0.0 release notes]: https://github.com/apache/arrow-rs/releases/tag/58.0.0 +[object_store 0.13.0 upgrade guide]: https://github.com/apache/arrow-rs/releases/tag/58.0.0 + +### `ExecutionPlan::properties` now returns `&Arc<PlanProperties>` + +Now `ExecutionPlan::properties()` returns `&Arc<PlanProperties>` instead of a +`&PlanProperties` reference. This makes it possible to cheaply clone properties and reuse them across multiple +`ExecutionPlans`. It also makes it possible to optimize [`ExecutionPlan::with_new_children`] +to reuse properties when the children plans have not changed, which can significantly reduce +planning time for complex queries.
+ +[`ExecutionPlan::with_new_children`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.ExecutionPlan.html#tymethod.with_new_children + +To migrate, in all `ExecutionPlan` implementations, you will likely need to wrap +stored `PlanProperties` in an `Arc`: + +```diff +- cache: PlanProperties, ++ cache: Arc<PlanProperties>, + +... + +- fn properties(&self) -> &PlanProperties { ++ fn properties(&self) -> &Arc<PlanProperties> { + &self.cache + } +``` + +To improve performance of `with_new_children` for custom `ExecutionPlan` +implementations, you can use the new macro: `check_if_same_properties`. For it +to work, you need to implement the function: +`with_new_children_and_same_properties` with semantics identical to +`with_new_children`, but operating under the assumption that the properties of +the children plans have not changed. + +An example of supporting this optimization for `ProjectionExec`: + +```diff + impl ProjectionExec { ++ fn with_new_children_and_same_properties( ++ &self, ++ mut children: Vec<Arc<dyn ExecutionPlan>>, ++ ) -> Self { ++ Self { ++ input: children.swap_remove(0), ++ metrics: ExecutionPlanMetricsSet::new(), ++ ..Self::clone(self) ++ } ++ } + } + + impl ExecutionPlan for ProjectionExec { + fn with_new_children( + self: Arc<Self>, + mut children: Vec<Arc<dyn ExecutionPlan>>, + ) -> Result<Arc<dyn ExecutionPlan>> { ++ check_if_same_properties!(self, children); + ProjectionExec::try_new( + self.projector.projection().into_iter().cloned(), + children.swap_remove(0), + ) + .map(|p| Arc::new(p) as _) + } + } +``` + +### `PlannerContext` outer query schema API now uses a stack + +`PlannerContext` no longer stores a single `outer_query_schema`. It now tracks a +stack of outer relation schemas so nested subqueries can access non-adjacent +outer relations.
+ +**Before:** + +```rust,ignore +let old_outer_query_schema = + planner_context.set_outer_query_schema(Some(input_schema.clone().into())); +let sub_plan = self.query_to_plan(subquery, planner_context)?; +planner_context.set_outer_query_schema(old_outer_query_schema); +``` + +**After:** + +```rust,ignore +planner_context.append_outer_query_schema(input_schema.clone().into()); +let sub_plan = self.query_to_plan(subquery, planner_context)?; +planner_context.pop_outer_query_schema(); +``` + +### `FileSinkConfig` adds `file_output_mode` + +`FileSinkConfig` now includes a `file_output_mode: FileOutputMode` field to control +single-file vs directory output behavior. Any code constructing `FileSinkConfig` via struct +literals must initialize this field. + +The `FileOutputMode` enum has three variants: + +- `Automatic` (default): Infer output mode from the URL (extension/trailing `/` heuristic) +- `SingleFile`: Write to a single file at the exact output path +- `Directory`: Write to a directory with generated filenames + +**Before:** + +```rust,ignore +FileSinkConfig { + // ... + file_extension: "parquet".into(), +} +``` + +**After:** + +```rust,ignore +use datafusion_datasource::file_sink_config::FileOutputMode; + +FileSinkConfig { + // ... + file_extension: "parquet".into(), + file_output_mode: FileOutputMode::Automatic, +} +``` + +### `SimplifyInfo` trait removed, `SimplifyContext` now uses builder-style API + +The `SimplifyInfo` trait has been removed and replaced with the concrete `SimplifyContext` struct. This simplifies the expression simplification API and removes the need for trait objects. + +**Who is affected:** + +- Users who implemented custom `SimplifyInfo` implementations +- Users who implemented `ScalarUDFImpl::simplify()` for custom scalar functions +- Users who directly use `SimplifyContext` or `ExprSimplifier` + +**Breaking changes:** + +1. The `SimplifyInfo` trait has been removed entirely +2. 
`SimplifyContext` no longer takes `&ExecutionProps` - it now uses a builder-style API with direct fields +3. `ScalarUDFImpl::simplify()` now takes `&SimplifyContext` instead of `&dyn SimplifyInfo` +4. Time-dependent function simplification (e.g., `now()`) is now optional - if `query_execution_start_time` is `None`, these functions won't be simplified + +**Migration guide:** + +If you implemented a custom `SimplifyInfo`: + +**Before:** + +```rust,ignore +impl SimplifyInfo for MySimplifyInfo { + fn is_boolean_type(&self, expr: &Expr) -> Result { ... } + fn nullable(&self, expr: &Expr) -> Result { ... } + fn execution_props(&self) -> &ExecutionProps { ... } + fn get_data_type(&self, expr: &Expr) -> Result { ... } +} +``` + +**After:** + +Use `SimplifyContext` directly with the builder-style API: + +```rust,ignore +let context = SimplifyContext::default() + .with_schema(schema) + .with_config_options(config_options) + .with_query_execution_start_time(Some(Utc::now())); // or use .with_current_time() +``` + +If you implemented `ScalarUDFImpl::simplify()`: + +**Before:** + +```rust,ignore +fn simplify( + &self, + args: Vec, + info: &dyn SimplifyInfo, +) -> Result { + let now_ts = info.execution_props().query_execution_start_time; + // ... +} +``` + +**After:** + +```rust,ignore +fn simplify( + &self, + args: Vec, + info: &SimplifyContext, +) -> Result { + // query_execution_start_time is now Option> + // Return Original if time is not set (simplification skipped) + let Some(now_ts) = info.query_execution_start_time() else { + return Ok(ExprSimplifyResult::Original(args)); + }; + // ... 
+} +``` + +If you created `SimplifyContext` from `ExecutionProps`: + +**Before:** + +```rust,ignore +let props = ExecutionProps::new(); +let context = SimplifyContext::new(&props).with_schema(schema); +``` + +**After:** + +```rust,ignore +let context = SimplifyContext::default() + .with_schema(schema) + .with_config_options(config_options) + .with_current_time(); // Sets query_execution_start_time to Utc::now() +``` + +See [`SimplifyContext` documentation](https://docs.rs/datafusion-expr/latest/datafusion_expr/simplify/struct.SimplifyContext.html) for more details. + +### Struct Casting Now Requires Field Name Overlap + +DataFusion's struct casting mechanism previously allowed casting between structs with differing field names if the field counts matched. This "positional fallback" behavior could silently misalign fields and cause data corruption. + +**Breaking Change:** + +Starting with DataFusion 53.0.0, struct casts now require **at least one overlapping field name** between the source and target structs. Casts without field name overlap are rejected at plan time with a clear error message. + +**Who is affected:** + +- Applications that cast between structs with no overlapping field names +- Queries that rely on positional struct field mapping (e.g., casting `struct(x, y)` to `struct(a, b)` based solely on position) +- Code that constructs or transforms struct columns programmatically + +**Migration guide:** + +If you encounter an error like: + +```text +Cannot cast struct with 2 fields to 2 fields because there is no field name overlap +``` + +You must explicitly rename or map fields to ensure at least one field name matches. 
Here are common patterns: + +**Example 1: Source and target field names already match (Name-based casting)** + +**Success case (field names align):** + +```sql +-- source_col has schema: STRUCT +-- Casting to the same field names succeeds (no-op or type validation only) +SELECT CAST(source_col AS STRUCT) FROM table1; +``` + +**Example 2: Source and target field names differ (Migration scenario)** + +**What fails now (no field name overlap):** + +```sql +-- source_col has schema: STRUCT +-- This FAILS because there is no field name overlap: +-- ❌ SELECT CAST(source_col AS STRUCT) FROM table1; +-- Error: Cannot cast struct with 2 fields to 2 fields because there is no field name overlap +``` + +**Migration options (must align names):** + +**Option A: Use struct constructor for explicit field mapping** + +```sql +-- source_col has schema: STRUCT +-- Use STRUCT_CONSTRUCT with explicit field names +SELECT STRUCT_CONSTRUCT( + 'x', source_col.a, + 'y', source_col.b +) AS renamed_struct FROM table1; +``` + +**Option B: Rename in the cast target to match source names** + +```sql +-- source_col has schema: STRUCT +-- Cast to target with matching field names +SELECT CAST(source_col AS STRUCT) FROM table1; +``` + +**Example 3: Using struct constructors in Rust API** + +If you need to map fields programmatically, build the target struct explicitly: + +```rust,ignore +// Build the target struct with explicit field names +let target_struct_type = DataType::Struct(vec![ + FieldRef::new("x", DataType::Int32), + FieldRef::new("y", DataType::Utf8), +]); + +// Use struct constructors rather than casting for field mapping +// This makes the field mapping explicit and unambiguous +// Use struct builders or row constructors that preserve your mapping logic +``` + +**Why this change:** + +1. **Safety:** Field names are now the primary contract for struct compatibility +2. **Explicitness:** Prevents silent data misalignment caused by positional assumptions +3. 
**Consistency:** Matches DuckDB's behavior and aligns with other SQL engines that enforce name-based matching +4. **Debuggability:** Errors now appear at plan time rather than as silent data corruption + +See [Issue #19841](https://github.com/apache/datafusion/issues/19841) and [PR #19955](https://github.com/apache/datafusion/pull/19955) for more details. + +### `FilterExec` builder methods deprecated + +The following methods on `FilterExec` have been deprecated in favor of using `FilterExecBuilder`: + +- `with_projection()` +- `with_batch_size()` + +**Who is affected:** + +- Users who create `FilterExec` instances and use these methods to configure them + +**Migration guide:** + +Use `FilterExecBuilder` instead of chaining method calls on `FilterExec`: + +**Before:** + +```rust,ignore +let filter = FilterExec::try_new(predicate, input)? + .with_projection(Some(vec![0, 2]))? + .with_batch_size(8192)?; +``` + +**After:** + +```rust,ignore +let filter = FilterExecBuilder::new(predicate, input) + .with_projection(Some(vec![0, 2])) + .with_batch_size(8192) + .build()?; +``` + +The builder pattern is more efficient as it computes properties once during `build()` rather than recomputing them for each method call. + +Note: `with_default_selectivity()` is not deprecated as it simply updates a field value and does not require the overhead of the builder pattern. + +### Protobuf conversion trait added + +A new trait, `PhysicalProtoConverterExtension`, has been added to the `datafusion-proto` +crate. This is used for controlling the process of conversion of physical plans and +expressions to and from their protobuf equivalents. The methods for conversion now +require an additional parameter. + +The primary APIs for interacting with this crate have not been modified, so most users +should not need to make any changes. If you do require this trait, you can use the +`DefaultPhysicalProtoConverter` implementation. 
+ +For example, to convert a sort expression protobuf node you can make the following +updates: + +**Before:** + +```rust,ignore +let sort_expr = parse_physical_sort_expr( + sort_proto, + ctx, + input_schema, + codec, +); +``` + +**After:** + +```rust,ignore +let converter = DefaultPhysicalProtoConverter {}; +let sort_expr = parse_physical_sort_expr( + sort_proto, + ctx, + input_schema, + codec, + &converter +); +``` + +Similarly to convert from a physical sort expression into a protobuf node: + +**Before:** + +```rust,ignore +let sort_proto = serialize_physical_sort_expr( + sort_expr, + codec, +); +``` + +**After:** + +```rust,ignore +let converter = DefaultPhysicalProtoConverter {}; +let sort_proto = serialize_physical_sort_expr( + sort_expr, + codec, + &converter, +); +``` + +### `generate_series` and `range` table functions changed + +The `generate_series` and `range` table functions now return an empty set when the interval is invalid, instead of an error. +This behavior is consistent with systems like PostgreSQL. + +Before: + +```sql +> select * from generate_series(0, -1); +Error during planning: Start is bigger than end, but increment is positive: Cannot generate infinite series + +> select * from range(0, -1); +Error during planning: Start is bigger than end, but increment is positive: Cannot generate infinite series +``` + +Now: + +```sql +> select * from generate_series(0, -1); ++-------+ +| value | ++-------+ ++-------+ +0 row(s) fetched. + +> select * from range(0, -1); ++-------+ +| value | ++-------+ ++-------+ +0 row(s) fetched. +``` diff --git a/docs/source/library-user-guide/upgrading/index.rst b/docs/source/library-user-guide/upgrading/index.rst new file mode 100644 index 000000000000..16bb33b7592a --- /dev/null +++ b/docs/source/library-user-guide/upgrading/index.rst @@ -0,0 +1,32 @@ +.. Licensed to the Apache Software Foundation (ASF) under one +.. or more contributor license agreements. See the NOTICE file +.. 
distributed with this work for additional information +.. regarding copyright ownership. The ASF licenses this file +.. to you under the Apache License, Version 2.0 (the +.. "License"); you may not use this file except in compliance +.. with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, +.. software distributed under the License is distributed on an +.. "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +.. KIND, either express or implied. See the License for the +.. specific language governing permissions and limitations +.. under the License. + +Upgrade Guides +============== + +.. toctree:: + :maxdepth: 1 + + DataFusion 53.0.0 <53.0.0> + DataFusion 52.0.0 <52.0.0> + DataFusion 51.0.0 <51.0.0> + DataFusion 50.0.0 <50.0.0> + DataFusion 49.0.0 <49.0.0> + DataFusion 48.0.1 <48.0.1> + DataFusion 48.0.0 <48.0.0> + DataFusion 47.0.0 <47.0.0> + DataFusion 46.0.0 <46.0.0> diff --git a/docs/source/user-guide/arrow-introduction.md b/docs/source/user-guide/arrow-introduction.md index 89662a0c29c5..5a225782adfd 100644 --- a/docs/source/user-guide/arrow-introduction.md +++ b/docs/source/user-guide/arrow-introduction.md @@ -220,14 +220,15 @@ When working with Arrow and RecordBatches, watch out for these common issues: - [Schema](https://docs.rs/arrow-schema/latest/arrow_schema/struct.Schema.html) - Describes the structure of a RecordBatch (column names and types) [apache arrow]: https://arrow.apache.org/docs/index.html +[arrow-rs]: https://github.com/apache/arrow-rs [`arc`]: https://doc.rust-lang.org/std/sync/struct.Arc.html [`arrayref`]: https://docs.rs/arrow-array/latest/arrow_array/array/type.ArrayRef.html [`cast`]: https://docs.rs/arrow/latest/arrow/compute/fn.cast.html [`field`]: https://docs.rs/arrow-schema/latest/arrow_schema/struct.Field.html [`schema`]: https://docs.rs/arrow-schema/latest/arrow_schema/struct.Schema.html [`datatype`]: 
https://docs.rs/arrow-schema/latest/arrow_schema/enum.DataType.html -[`int32array`]: https://docs.rs/arrow-array/latest/arrow_array/array/struct.Int32Array.html -[`stringarray`]: https://docs.rs/arrow-array/latest/arrow_array/array/struct.StringArray.html +[`int32array`]: https://docs.rs/arrow/latest/arrow/array/type.Int32Array.html +[`stringarray`]: https://docs.rs/arrow/latest/arrow/array/type.StringArray.html [`int32`]: https://docs.rs/arrow-schema/latest/arrow_schema/enum.DataType.html#variant.Int32 [`int64`]: https://docs.rs/arrow-schema/latest/arrow_schema/enum.DataType.html#variant.Int64 [extension points]: ../library-user-guide/extensions.md @@ -241,8 +242,8 @@ When working with Arrow and RecordBatches, watch out for these common issues: [`.show()`]: https://docs.rs/datafusion/latest/datafusion/dataframe/struct.DataFrame.html#method.show [`memtable`]: https://docs.rs/datafusion/latest/datafusion/datasource/struct.MemTable.html [`sessioncontext`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html -[`csvreadoptions`]: https://docs.rs/datafusion/latest/datafusion/execution/options/struct.CsvReadOptions.html -[`parquetreadoptions`]: https://docs.rs/datafusion/latest/datafusion/execution/options/struct.ParquetReadOptions.html +[`csvreadoptions`]: https://docs.rs/datafusion/latest/datafusion/datasource/file_format/options/struct.CsvReadOptions.html +[`parquetreadoptions`]: https://docs.rs/datafusion/latest/datafusion/datasource/file_format/options/struct.ParquetReadOptions.html [`recordbatch`]: https://docs.rs/arrow-array/latest/arrow_array/struct.RecordBatch.html [`read_csv`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html#method.read_csv [`read_parquet`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html#method.read_parquet diff --git a/docs/source/user-guide/cli/functions.md b/docs/source/user-guide/cli/functions.md index 
f3b0163534c4..ea353d5c8dcc 100644 --- a/docs/source/user-guide/cli/functions.md +++ b/docs/source/user-guide/cli/functions.md @@ -170,5 +170,55 @@ The columns of the returned table are: | table_size_bytes | Utf8 | Size of the table, in bytes | | statistics_size_bytes | UInt64 | Size of the cached statistics in memory | +## `list_files_cache` + +The `list_files_cache` function shows information about the `ListFilesCache` that is used by the [`ListingTable`] implementation in DataFusion. When creating a [`ListingTable`], DataFusion lists the files in the table's location and caches results in the `ListFilesCache`. Subsequent queries against the same table can reuse this cached information instead of re-listing the files. Cache entries are scoped to tables. + +You can inspect the cache by querying the `list_files_cache` function. For example, + +```sql +> set datafusion.runtime.list_files_cache_ttl = "30s"; +> create external table overturemaps +stored as parquet +location 's3://overturemaps-us-west-2/release/2025-12-17.0/theme=base/type=infrastructure'; +0 row(s) fetched. 
+> select table, path, metadata_size_bytes, expires_in, unnest(metadata_list)['file_size_bytes'] as file_size_bytes, unnest(metadata_list)['e_tag'] as e_tag from list_files_cache() limit 10; ++--------------+-----------------------------------------------------+---------------------+-----------------------------------+-----------------+---------------------------------------+ +| table | path | metadata_size_bytes | expires_in | file_size_bytes | e_tag | ++--------------+-----------------------------------------------------+---------------------+-----------------------------------+-----------------+---------------------------------------+ +| overturemaps | release/2025-12-17.0/theme=base/type=infrastructure | 2750 | 0 days 0 hours 0 mins 25.264 secs | 999055952 | "35fc8fbe8400960b54c66fbb408c48e8-60" | +| overturemaps | release/2025-12-17.0/theme=base/type=infrastructure | 2750 | 0 days 0 hours 0 mins 25.264 secs | 975592768 | "8a16e10b722681cdc00242564b502965-59" | +| overturemaps | release/2025-12-17.0/theme=base/type=infrastructure | 2750 | 0 days 0 hours 0 mins 25.264 secs | 1082925747 | "24cd13ddb5e0e438952d2499f5dabe06-65" | +| overturemaps | release/2025-12-17.0/theme=base/type=infrastructure | 2750 | 0 days 0 hours 0 mins 25.264 secs | 1008425557 | "37663e31c7c64d4ef355882bcd47e361-61" | +| overturemaps | release/2025-12-17.0/theme=base/type=infrastructure | 2750 | 0 days 0 hours 0 mins 25.264 secs | 1065561905 | "4e7c50d2d1b3c5ed7b82b4898f5ac332-64" | +| overturemaps | release/2025-12-17.0/theme=base/type=infrastructure | 2750 | 0 days 0 hours 0 mins 25.264 secs | 1045655427 | "8fff7e6a72d375eba668727c55d4f103-63" | +| overturemaps | release/2025-12-17.0/theme=base/type=infrastructure | 2750 | 0 days 0 hours 0 mins 25.264 secs | 1086822683 | "b67167d8022d778936c330a52a5f1922-65" | +| overturemaps | release/2025-12-17.0/theme=base/type=infrastructure | 2750 | 0 days 0 hours 0 mins 25.264 secs | 1016732378 | "6d70857a0473ed9ed3fc6e149814168b-61" | +| 
overturemaps | release/2025-12-17.0/theme=base/type=infrastructure | 2750 | 0 days 0 hours 0 mins 25.264 secs | 991363784 | "c9cafb42fcbb413f851691c895dd7c2b-60" | +| overturemaps | release/2025-12-17.0/theme=base/type=infrastructure | 2750 | 0 days 0 hours 0 mins 25.264 secs | 1032469715 | "7540252d0d67158297a67038a3365e0f-62" | ++--------------+-----------------------------------------------------+---------------------+-----------------------------------+-----------------+---------------------------------------+ +``` + +The columns of the returned table are: +| column_name | data_type | Description | +| ------------------- | ------------ | ----------------------------------------------------------------------------------------- | +| table | Utf8 | Name of the table | +| path | Utf8 | File path relative to the object store / filesystem root | +| metadata_size_bytes | UInt64 | Size of the cached metadata in memory (not its thrift encoded form) | +| expires_in | Duration(ms) | Time remaining until this cache entry expires | +| metadata_list | List(Struct) | List of metadata entries, one for each file under the path.
| + +A metadata struct in the metadata_list contains the following fields: + +```text +{ + "file_path": "release/2025-12-17.0/theme=base/type=infrastructure/part-00000-d556e455-e0c5-4940-b367-daff3287a952-c000.zstd.parquet", + "file_modified": "2025-12-17T22:20:29", + "file_size_bytes": 999055952, + "e_tag": "35fc8fbe8400960b54c66fbb408c48e8-60", + "version": null +} +``` + [`listingtable`]: https://docs.rs/datafusion/latest/datafusion/datasource/listing/struct.ListingTable.html [entity tag]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/ETag diff --git a/docs/source/user-guide/concepts-readings-events.md b/docs/source/user-guide/concepts-readings-events.md index ad444ef91c47..3946ca7b16f6 100644 --- a/docs/source/user-guide/concepts-readings-events.md +++ b/docs/source/user-guide/concepts-readings-events.md @@ -21,7 +21,7 @@ ## 🧭 Background Concepts -- **2024-06-13**: [2024 ACM SIGMOD International Conference on Management of Data: Apache Arrow DataFusion: A Fast, Embeddable, Modular Analytic Query Engine](https://dl.acm.org/doi/10.1145/3626246.3653368) - [Download](http://andrew.nerdnetworks.org/other/SIGMOD-2024-lamb.pdf), [Talk](https://youtu.be/-DpKcPfnNms), [Slides](https://docs.google.com/presentation/d/1gqcxSNLGVwaqN0_yJtCbNm19-w5pqPuktII5_EDA6_k/edit#slide=id.p), [Recording ](https://youtu.be/-DpKcPfnNms) +- **2024-06-13**: [2024 ACM SIGMOD International Conference on Management of Data: Apache Arrow DataFusion: A Fast, Embeddable, Modular Analytic Query Engine](https://dl.acm.org/doi/10.1145/3626246.3653368) - [Download](https://andrew.nerdnetworks.org/pdf/SIGMOD-2024-lamb.pdf), [Talk](https://youtu.be/-DpKcPfnNms), [Slides](https://docs.google.com/presentation/d/1gqcxSNLGVwaqN0_yJtCbNm19-w5pqPuktII5_EDA6_k/edit#slide=id.p), [Recording ](https://youtu.be/-DpKcPfnNms) - **2024-06-07**: [Video: SIGMOD 2024 Practice: Apache Arrow DataFusion A Fast, Embeddable, Modular Analytic Query Engine](https://www.youtube.com/watch?v=-DpKcPfnNms&t=5s) - 
[Slides](https://docs.google.com/presentation/d/1gqcxSNLGVwaqN0_yJtCbNm19-w5pqPuktII5_EDA6_k/edit#slide=id.p) @@ -37,6 +37,34 @@ This is a list of DataFusion related blog posts, articles, and other resources. Please open a PR to add any new resources you create or find +- **2026-01-12** [Blog: Extending SQL in DataFusion: from ->> to TABLESAMPLE](https://datafusion.apache.org/blog/2026/01/12/extending-sql) + +- **2025-12-15** [Blog: Optimizing Repartitions in DataFusion: How I Went From Database Noob to Core Contribution](https://datafusion.apache.org/blog/2025/12/15/avoid-consecutive-repartitions) + +- **2025-09-21** [Blog: Implementing User Defined Types and Custom Metadata in DataFusion](https://datafusion.apache.org/blog/2025/09/21/custom-types-using-metadata) + +- **2025-09-10** [Blog: Dynamic Filters: Passing Information Between Operators During Execution for 25x Faster Queries](https://datafusion.apache.org/blog/2025/09/10/dynamic-filters) + +- **2025-08-15** [Blog: Using External Indexes, Metadata Stores, Catalogs and Caches to Accelerate Queries on Apache Parquet](https://datafusion.apache.org/blog/2025/08/15/external-parquet-indexes) + +- **2025-07-14** [Blog: Embedding User-Defined Indexes in Apache Parquet Files](https://datafusion.apache.org/blog/2025/07/14/user-defined-parquet-indexes) + +- **2025-06-30** [Blog: Using Rust async for Query Execution and Cancelling Long-Running Queries](https://datafusion.apache.org/blog/2025/06/30/cancellation) + +- **2025-06-15** [Blog: Optimizing SQL (and DataFrames) in DataFusion, Part 1: Query Optimization Overview](https://datafusion.apache.org/blog/2025/06/15/optimizing-sql-dataframes-part-one) + +- **2025-06-15** [Blog: Optimizing SQL (and DataFrames) in DataFusion, Part 2: Optimizers in Apache DataFusion](https://datafusion.apache.org/blog/2025/06/15/optimizing-sql-dataframes-part-two) + +- **2025-04-19** [Blog: User defined Window Functions in 
DataFusion](https://datafusion.apache.org/blog/2025/04/19/user-defined-window-functions) + +- **2025-04-10** [Blog: tpchgen-rs World's fastest open source TPC-H data generator, written in Rust](https://datafusion.apache.org/blog/2025/04/10/fastest-tpch-generator) + +- **2025-03-11** [Blog: Using Ordering for Better Plans in Apache DataFusion](https://datafusion.apache.org/blog/2025/03/11/ordering-analysis) + +- **2024-05-07** [Blog: Announcing Apache Arrow DataFusion is now Apache DataFusion](https://datafusion.apache.org/blog/2024/05/07/datafusion-tlp) + +- **2024-03-06** [Blog: Announcing Apache Arrow DataFusion Comet](https://datafusion.apache.org/blog/2024/03/06/comet-donation) + - **2025-03-21** [Blog: Efficient Filter Pushdown in Parquet](https://datafusion.apache.org/blog/2025/03/21/parquet-pushdown/) - **2025-03-20** [Blog: Parquet Pruning in DataFusion: Read Only What Matters](https://datafusion.apache.org/blog/2025/03/20/parquet-pruning/) @@ -59,16 +87,14 @@ This is a list of DataFusion related blog posts, articles, and other resources. 
- **2024-10-29** [Video: MiDAS Seminar Fall 2024 on "Apache DataFusion" by Andrew Lamb](https://www.youtube.com/watch?v=CpnxuBwHbUc) -- **2024-10-27** [Blog: Caching in DataFusion: Don't read twice](https://blog.haoxp.xyz/posts/caching-datafusion) +- **2024-10-27** [Blog: Caching in DataFusion: Don't read twice](https://blog.xiangpeng.systems/posts/caching-datafusion/) -- **2024-10-24** [Blog: Parquet pruning in DataFusion: Read no more than you need](https://blog.haoxp.xyz/posts/parquet-to-arrow/) +- **2024-10-24** [Blog: Parquet pruning in DataFusion: Read no more than you need](https://blog.xiangpeng.systems/posts/parquet-to-arrow/) - **2024-09-13** [Blog: Using StringView / German Style Strings to make Queries Faster: Part 2 - String Operations](https://www.influxdata.com/blog/faster-queries-with-stringview-part-two-influxdb/) | [Reposted on DataFusion Blog](https://datafusion.apache.org/blog/2024/09/13/string-view-german-style-strings-part-2/) - **2024-09-13** [Blog: Using StringView / German Style Strings to Make Queries Faster: Part 1- Reading Parquet](https://www.influxdata.com/blog/faster-queries-with-stringview-part-one-influxdb/) | [Reposted on Datafusion Blog](https://datafusion.apache.org/blog/2024/09/13/string-view-german-style-strings-part-1/) -- **2024-10-16** [Blog: Candle Image Segmentation](https://www.letsql.com/posts/candle-image-segmentation/) - - **2024-09-23 → 2024-12-02** [Talks: Carnegie Mellon University: Database Building Blocks Seminar Series - Fall 2024](https://db.cs.cmu.edu/seminar2024/) - **2024-11-12** [Video: Building InfluxDB 3.0 with the FDAP Stack: Apache Flight, DataFusion, Arrow and Parquet (Paul Dix)](https://www.youtube.com/watch?v=AGS4GNGDK_4) diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index c9222afe8ceb..1245e59d477a 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -74,6 +74,8 @@ The following configuration settings are available: | 
datafusion.catalog.has_header | true | Default value for `format.has_header` for `CREATE EXTERNAL TABLE` if not specified explicitly in the statement. | | datafusion.catalog.newlines_in_values | false | Specifies whether newlines in (quoted) CSV values are supported. This is the default value for `format.newlines_in_values` for `CREATE EXTERNAL TABLE` if not specified explicitly in the statement. Parsing newlines in quoted values may be affected by execution behaviour such as parallel file scanning. Setting this to `true` ensures that newlines in values are parsed successfully, which may reduce performance. | | datafusion.execution.batch_size | 8192 | Default batch size while creating new batches, it's especially useful for buffer-in-memory batches since creating tiny batches would result in too much metadata memory consumption | +| datafusion.execution.perfect_hash_join_small_build_threshold | 1024 | A perfect hash join (see `HashJoinExec` for more details) will be considered if the range of keys (max - min) on the build side is < this threshold. This provides a fast path for joins with very small key ranges, bypassing the density check. Currently only supports cases where build_side.num_rows() < u32::MAX. Support for build_side.num_rows() >= u32::MAX will be added in the future. | +| datafusion.execution.perfect_hash_join_min_key_density | 0.15 | The minimum required density of join keys on the build side to consider a perfect hash join (see `HashJoinExec` for more details). Density is calculated as: `(number of rows) / (max_key - min_key + 1)`. A perfect hash join may be used if the actual key density > this value. Currently only supports cases where build_side.num_rows() < u32::MAX. Support for build_side.num_rows() >= u32::MAX will be added in the future. | | datafusion.execution.coalesce_batches | true | When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. 
This is helpful when there are highly selective filters or joins that could produce tiny output batches. The target batch size is determined by the configuration setting | | datafusion.execution.collect_statistics | true | Should DataFusion collect statistics when first creating a table. Has no effect after the table is created. Applies to the default `ListingTableProvider` in DataFusion. Defaults to true. | | datafusion.execution.target_partitions | 0 | Number of partitions for query execution. Increasing partitions can increase concurrency. Defaults to the number of CPU cores on the system | @@ -91,15 +93,15 @@ The following configuration settings are available: | datafusion.execution.parquet.bloom_filter_on_read | true | (reading) Use any available bloom filters when reading parquet files | | datafusion.execution.parquet.max_predicate_cache_size | NULL | (reading) The maximum predicate cache size, in bytes. When `pushdown_filters` is enabled, sets the maximum memory used to cache the results of predicate evaluation between filter evaluation and output generation. Decreasing this value will reduce memory usage, but may increase IO and CPU usage. None means use the default parquet reader setting. 0 means no caching. | | datafusion.execution.parquet.data_pagesize_limit | 1048576 | (writing) Sets best effort maximum size of data page in bytes | -| datafusion.execution.parquet.write_batch_size | 1024 | (writing) Sets write_batch_size in bytes | +| datafusion.execution.parquet.write_batch_size | 1024 | (writing) Sets write_batch_size in rows | | datafusion.execution.parquet.writer_version | 1.0 | (writing) Sets parquet writer version valid values are "1.0" and "2.0" | | datafusion.execution.parquet.skip_arrow_metadata | false | (writing) Skip encoding the embedded arrow metadata in the KV_meta This is analogous to the `ArrowWriterOptions::with_skip_arrow_metadata`. 
Refer to | -| datafusion.execution.parquet.compression | zstd(3) | (writing) Sets default parquet compression codec. Valid values are: uncompressed, snappy, gzip(level), lzo, brotli(level), lz4, zstd(level), and lz4_raw. These values are not case sensitive. If NULL, uses default parquet writer setting Note that this default setting is not the same as the default parquet writer setting. | +| datafusion.execution.parquet.compression | zstd(3) | (writing) Sets default parquet compression codec. Valid values are: uncompressed, snappy, gzip(level), brotli(level), lz4, zstd(level), and lz4_raw. These values are not case sensitive. If NULL, uses default parquet writer setting Note that this default setting is not the same as the default parquet writer setting. | | datafusion.execution.parquet.dictionary_enabled | true | (writing) Sets if dictionary encoding is enabled. If NULL, uses default parquet writer setting | | datafusion.execution.parquet.dictionary_page_size_limit | 1048576 | (writing) Sets best effort maximum dictionary page size, in bytes | | datafusion.execution.parquet.statistics_enabled | page | (writing) Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting | | datafusion.execution.parquet.max_row_group_size | 1048576 | (writing) Target maximum number of rows in each row group (defaults to 1M rows). Writing larger row groups requires more memory to write, but can get better compression and be faster to read. 
| -| datafusion.execution.parquet.created_by | datafusion version 51.0.0 | (writing) Sets "created by" property | +| datafusion.execution.parquet.created_by | datafusion version 53.1.0 | (writing) Sets "created by" property | | datafusion.execution.parquet.column_index_truncate_length | 64 | (writing) Sets column index truncate length | | datafusion.execution.parquet.statistics_truncate_length | 64 | (writing) Sets statistics truncate length. If NULL, uses default parquet writer setting | | datafusion.execution.parquet.data_page_row_count_limit | 20000 | (writing) Sets best effort maximum number of rows in data page | @@ -163,6 +165,7 @@ The following configuration settings are available: | datafusion.optimizer.prefer_existing_union | false | When set to true, the optimizer will not attempt to convert Union to Interleave | | datafusion.optimizer.expand_views_at_output | false | When set to true, if the returned type is a view type then the output will be coerced to a non-view. Coerces `Utf8View` to `LargeUtf8`, and `BinaryView` to `LargeBinary`. | | datafusion.optimizer.enable_sort_pushdown | true | Enable sort pushdown optimization. When enabled, attempts to push sort requirements down to data sources that can natively handle them (e.g., by reversing file/row group read order). Returns **inexact ordering**: Sort operator is kept for correctness, but optimized input enables early termination for TopK queries (ORDER BY ... LIMIT N), providing significant speedup. Memory: No additional overhead (only changes read order). Future: Will add option to detect perfectly sorted data and eliminate Sort completely. Default: true | +| datafusion.optimizer.enable_leaf_expression_pushdown | true | When set to true, the optimizer will extract leaf expressions (such as `get_field`) from filter/sort/join nodes into projections closer to the leaf table scans, and push those projections down towards the leaf nodes. 
| | datafusion.explain.logical_plan_only | false | When set to true, the explain statement will only print logical plans | | datafusion.explain.physical_plan_only | false | When set to true, the explain statement will only print physical plans | | datafusion.explain.show_statistics | false | When set to true, the explain statement will print operator statistics for physical plans | diff --git a/docs/source/user-guide/crate-configuration.md b/docs/source/user-guide/crate-configuration.md index 83a46b50c004..2acb2140efcb 100644 --- a/docs/source/user-guide/crate-configuration.md +++ b/docs/source/user-guide/crate-configuration.md @@ -24,6 +24,7 @@ your Rust project. The [Configuration Settings] section lists options that control additional aspects DataFusion's runtime behavior. [configuration settings]: configs.md +[support for adding dependencies]: https://doc.rust-lang.org/cargo/reference/specifying-dependencies.html#specifying-dependencies ## Using the nightly DataFusion builds @@ -155,7 +156,7 @@ By default, Datafusion returns errors as a plain text message. You can enable mo such as backtraces by enabling the `backtrace` feature to your `Cargo.toml` file like this: ```toml -datafusion = { version = "31.0.0", features = ["backtrace"]} +datafusion = { version = "53.0.0", features = ["backtrace"]} ``` Set environment [variables](https://doc.rust-lang.org/std/backtrace/index.html#environment-variables) diff --git a/docs/source/user-guide/example-usage.md b/docs/source/user-guide/example-usage.md index 6108315f398a..fd755715eec9 100644 --- a/docs/source/user-guide/example-usage.md +++ b/docs/source/user-guide/example-usage.md @@ -29,7 +29,7 @@ Find latest available Datafusion version on [DataFusion's crates.io] page. 
Add the dependency to your `Cargo.toml` file: ```toml -datafusion = "latest_version" +datafusion = "53.0.0" tokio = { version = "1.0", features = ["rt-multi-thread"] } ``` @@ -103,8 +103,8 @@ exported by DataFusion, for example: use datafusion::arrow::datatypes::Schema; ``` -For example, [DataFusion `25.0.0` dependencies] require `arrow` -`39.0.0`. If instead you used `arrow` `40.0.0` in your project you may +For example, [DataFusion `26.0.0` dependencies] require `arrow` +`40.0.0`. If instead you used `arrow` `41.0.0` in your project you may see errors such as: ```text diff --git a/docs/source/user-guide/explain-usage.md b/docs/source/user-guide/explain-usage.md index 5a1184539c03..c047659e9940 100644 --- a/docs/source/user-guide/explain-usage.md +++ b/docs/source/user-guide/explain-usage.md @@ -226,8 +226,10 @@ Again, reading from bottom up: When predicate pushdown is enabled, `DataSourceExec` with `ParquetSource` gains the following metrics: - `page_index_rows_pruned`: number of rows evaluated by page index filters. The metric reports both how many rows were considered in total and how many matched (were not pruned). +- `page_index_pages_pruned`: number of pages evaluated by page index filters. The metric reports both how many pages were considered in total and how many matched (were not pruned). - `row_groups_pruned_bloom_filter`: number of row groups evaluated by Bloom Filters, reporting both total checked groups and groups that matched. - `row_groups_pruned_statistics`: number of row groups evaluated by row-group statistics (min/max), reporting both total checked groups and groups that matched. +- `limit_pruned_row_groups`: number of row groups pruned by the limit. - `pushdown_rows_matched`: rows that were tested by any of the above filters, and passed all of them. - `pushdown_rows_pruned`: rows that were tested by any of the above filters, and did not pass at least one of them. 
- `predicate_evaluation_errors`: number of times evaluating the filter expression failed (expected to be zero in normal operation) diff --git a/docs/source/user-guide/introduction.md b/docs/source/user-guide/introduction.md index 66076e6b73ff..9ad42a2a1015 100644 --- a/docs/source/user-guide/introduction.md +++ b/docs/source/user-guide/introduction.md @@ -147,13 +147,14 @@ Here are some less active projects that used DataFusion: - [Flock] - [Tensorbase] +If you know of another project, please submit a PR to add a link! + [ballista]: https://github.com/apache/datafusion-ballista [auron]: https://github.com/apache/auron [cloudfuse buzz]: https://github.com/cloudfuse-io/buzz-rust [cnosdb]: https://github.com/cnosdb/cnosdb [cube store]: https://github.com/cube-js/cube.js/tree/master/rust [dask sql]: https://github.com/dask-contrib/dask-sql -[datafusion-tui]: https://github.com/datafusion-contrib/datafusion-tui [delta-rs]: https://github.com/delta-io/delta-rs [edb postgres lakehouse]: https://www.enterprisedb.com/products/analytics [exon]: https://github.com/wheretrue/exon @@ -172,7 +173,7 @@ Here are some less active projects that used DataFusion: [synnada]: https://synnada.ai/ [tensorbase]: https://github.com/tensorbase/tensorbase [vegafusion]: https://vegafusion.io/ -[vortex]: https://vortex.dev/ "if you know of another project, please submit a PR to add a link!" +[vortex]: https://vortex.dev/ ## Integrations and Extensions diff --git a/docs/source/user-guide/sql/data_types.md b/docs/source/user-guide/sql/data_types.md index 02edb6371ce3..502193df41a6 100644 --- a/docs/source/user-guide/sql/data_types.md +++ b/docs/source/user-guide/sql/data_types.md @@ -25,6 +25,11 @@ execution. The SQL types from are mapped to [Arrow data types](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) according to the following table. This mapping occurs when defining the schema in a `CREATE EXTERNAL TABLE` command or when performing a SQL `CAST` operation. 
+For background on extension types and custom metadata, see the +[Implementing User Defined Types and Custom Metadata in DataFusion] blog. + +[implementing user defined types and custom metadata in datafusion]: https://datafusion.apache.org/blog/2025/09/21/custom-types-using-metadata + You can see the corresponding Arrow type for any SQL expression using the `arrow_typeof` function. For example: @@ -64,27 +69,32 @@ select arrow_cast(now(), 'Timestamp(Second, None)') as "now()"; | SQL DataType | Arrow DataType | | ------------ | -------------- | -| `CHAR` | `Utf8` | -| `VARCHAR` | `Utf8` | -| `TEXT` | `Utf8` | -| `STRING` | `Utf8` | +| `CHAR` | `Utf8View` | +| `VARCHAR` | `Utf8View` | +| `TEXT` | `Utf8View` | +| `STRING` | `Utf8View` | + +By default, string types are mapped to `Utf8View`. This can be configured using the `datafusion.sql_parser.map_string_types_to_utf8view` setting. When set to `false`, string types are mapped to `Utf8` instead. ## Numeric Types -| SQL DataType | Arrow DataType | -| ------------------------------------ | :----------------------------- | -| `TINYINT` | `Int8` | -| `SMALLINT` | `Int16` | -| `INT` or `INTEGER` | `Int32` | -| `BIGINT` | `Int64` | -| `TINYINT UNSIGNED` | `UInt8` | -| `SMALLINT UNSIGNED` | `UInt16` | -| `INT UNSIGNED` or `INTEGER UNSIGNED` | `UInt32` | -| `BIGINT UNSIGNED` | `UInt64` | -| `FLOAT` | `Float32` | -| `REAL` | `Float32` | -| `DOUBLE` | `Float64` | -| `DECIMAL(precision, scale)` | `Decimal128(precision, scale)` | +| SQL DataType | Arrow DataType | +| ------------------------------------------------ | :----------------------------- | +| `TINYINT` | `Int8` | +| `SMALLINT` | `Int16` | +| `INT` or `INTEGER` | `Int32` | +| `BIGINT` | `Int64` | +| `TINYINT UNSIGNED` | `UInt8` | +| `SMALLINT UNSIGNED` | `UInt16` | +| `INT UNSIGNED` or `INTEGER UNSIGNED` | `UInt32` | +| `BIGINT UNSIGNED` | `UInt64` | +| `FLOAT` | `Float32` | +| `REAL` | `Float32` | +| `DOUBLE` | `Float64` | +| `DECIMAL(precision, scale)` where precision 
≤ 38 | `Decimal128(precision, scale)` | +| `DECIMAL(precision, scale)` where precision > 38 | `Decimal256(precision, scale)` | + +The maximum supported precision for `DECIMAL` types is 76. ## Date/Time Types @@ -126,42 +136,3 @@ You can create binary literals using a hex string literal such as | `ENUM` | _Not yet supported_ | | `SET` | _Not yet supported_ | | `DATETIME` | _Not yet supported_ | - -## Supported Arrow Types - -The following types are supported by the `arrow_typeof` function: - -| Arrow Type | -| ----------------------------------------------------------- | -| `Null` | -| `Boolean` | -| `Int8` | -| `Int16` | -| `Int32` | -| `Int64` | -| `UInt8` | -| `UInt16` | -| `UInt32` | -| `UInt64` | -| `Float16` | -| `Float32` | -| `Float64` | -| `Utf8` | -| `LargeUtf8` | -| `Binary` | -| `Timestamp(Second, None)` | -| `Timestamp(Millisecond, None)` | -| `Timestamp(Microsecond, None)` | -| `Timestamp(Nanosecond, None)` | -| `Time32` | -| `Time64` | -| `Duration(Second)` | -| `Duration(Millisecond)` | -| `Duration(Microsecond)` | -| `Duration(Nanosecond)` | -| `Interval(YearMonth)` | -| `Interval(DayTime)` | -| `Interval(MonthDayNano)` | -| `FixedSizeBinary()` (e.g. `FixedSizeBinary(16)`) | -| `Decimal128(, )` e.g. `Decimal128(3, 10)` | -| `Decimal256(, )` e.g. `Decimal256(3, 10)` | diff --git a/docs/source/user-guide/sql/format_options.md b/docs/source/user-guide/sql/format_options.md index d349bc1c98c7..338508031413 100644 --- a/docs/source/user-guide/sql/format_options.md +++ b/docs/source/user-guide/sql/format_options.md @@ -132,38 +132,38 @@ OPTIONS('DELIMITER' '|', 'HAS_HEADER' 'true', 'NEWLINES_IN_VALUES' 'true'); The following options are available when reading or writing Parquet files. If any unsupported option is specified, an error will be raised and the query will fail. If a column-specific option is specified for a column that does not exist, the option will be ignored without error. -| Option | Can be Column Specific? 
| Description | OPTIONS Key | Default Value | -| ------------------------------------------ | ----------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------- | ------------------------ | -| COMPRESSION | Yes | Sets the internal Parquet **compression codec** for data pages, optionally including the compression level. Applies globally if set without `::col`, or specifically to a column if set using `'compression::column_name'`. Valid values: `uncompressed`, `snappy`, `gzip(level)`, `lzo`, `brotli(level)`, `lz4`, `zstd(level)`, `lz4_raw`. | `'compression'` or `'compression::col'` | zstd(3) | -| ENCODING | Yes | Sets the **encoding** scheme for data pages. Valid values: `plain`, `plain_dictionary`, `rle`, `bit_packed`, `delta_binary_packed`, `delta_length_byte_array`, `delta_byte_array`, `rle_dictionary`, `byte_stream_split`. Use key `'encoding'` or `'encoding::col'` in OPTIONS. | `'encoding'` or `'encoding::col'` | None | -| DICTIONARY_ENABLED | Yes | Sets whether dictionary encoding should be enabled globally or for a specific column. | `'dictionary_enabled'` or `'dictionary_enabled::col'` | true | -| STATISTICS_ENABLED | Yes | Sets the level of statistics to write (`none`, `chunk`, `page`). | `'statistics_enabled'` or `'statistics_enabled::col'` | page | -| BLOOM_FILTER_ENABLED | Yes | Sets whether a bloom filter should be written for a specific column. | `'bloom_filter_enabled::column_name'` | None | -| BLOOM_FILTER_FPP | Yes | Sets bloom filter false positive probability (global or per column). 
| `'bloom_filter_fpp'` or `'bloom_filter_fpp::col'` | None | -| BLOOM_FILTER_NDV | Yes | Sets bloom filter number of distinct values (global or per column). | `'bloom_filter_ndv'` or `'bloom_filter_ndv::col'` | None | -| MAX_ROW_GROUP_SIZE | No | Sets the maximum number of rows per row group. Larger groups require more memory but can improve compression and scan efficiency. | `'max_row_group_size'` | 1048576 | -| ENABLE_PAGE_INDEX | No | If true, reads the Parquet data page level metadata (the Page Index), if present, to reduce I/O and decoding. | `'enable_page_index'` | true | -| PRUNING | No | If true, enables row group pruning based on min/max statistics. | `'pruning'` | true | -| SKIP_METADATA | No | If true, skips optional embedded metadata in the file schema. | `'skip_metadata'` | true | -| METADATA_SIZE_HINT | No | Sets the size hint (in bytes) for fetching Parquet file metadata. | `'metadata_size_hint'` | None | -| PUSHDOWN_FILTERS | No | If true, enables filter pushdown during Parquet decoding. | `'pushdown_filters'` | false | -| REORDER_FILTERS | No | If true, enables heuristic reordering of filters during Parquet decoding. | `'reorder_filters'` | false | -| SCHEMA_FORCE_VIEW_TYPES | No | If true, reads Utf8/Binary columns as view types. | `'schema_force_view_types'` | true | -| BINARY_AS_STRING | No | If true, reads Binary columns as strings. | `'binary_as_string'` | false | -| DATA_PAGESIZE_LIMIT | No | Sets best effort maximum size of data page in bytes. | `'data_pagesize_limit'` | 1048576 | -| DATA_PAGE_ROW_COUNT_LIMIT | No | Sets best effort maximum number of rows in data page. | `'data_page_row_count_limit'` | 20000 | -| DICTIONARY_PAGE_SIZE_LIMIT | No | Sets best effort maximum dictionary page size, in bytes. | `'dictionary_page_size_limit'` | 1048576 | -| WRITE_BATCH_SIZE | No | Sets write_batch_size in bytes. | `'write_batch_size'` | 1024 | -| WRITER_VERSION | No | Sets the Parquet writer version (`1.0` or `2.0`). 
| `'writer_version'` | 1.0 | -| SKIP_ARROW_METADATA | No | If true, skips writing Arrow schema information into the Parquet file metadata. | `'skip_arrow_metadata'` | false | -| CREATED_BY | No | Sets the "created by" string in the Parquet file metadata. | `'created_by'` | datafusion version X.Y.Z | -| COLUMN_INDEX_TRUNCATE_LENGTH | No | Sets the length (in bytes) to truncate min/max values in column indexes. | `'column_index_truncate_length'` | 64 | -| STATISTICS_TRUNCATE_LENGTH | No | Sets statistics truncate length. | `'statistics_truncate_length'` | None | -| BLOOM_FILTER_ON_WRITE | No | Sets whether bloom filters should be written for all columns by default (can be overridden per column). | `'bloom_filter_on_write'` | false | -| ALLOW_SINGLE_FILE_PARALLELISM | No | Enables parallel serialization of columns in a single file. | `'allow_single_file_parallelism'` | true | -| MAXIMUM_PARALLEL_ROW_GROUP_WRITERS | No | Maximum number of parallel row group writers. | `'maximum_parallel_row_group_writers'` | 1 | -| MAXIMUM_BUFFERED_RECORD_BATCHES_PER_STREAM | No | Maximum number of buffered record batches per stream. | `'maximum_buffered_record_batches_per_stream'` | 2 | -| KEY_VALUE_METADATA | No (Key is specific) | Adds custom key-value pairs to the file metadata. Use the format `'metadata::your_key_name' 'your_value'`. Multiple entries allowed. | `'metadata::key_name'` | None | +| Option | Can be Column Specific? 
| Description | OPTIONS Key | Default Value | +| ------------------------------------------ | ----------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ----------------------------------------------------- | ------------------------ | +| COMPRESSION | Yes | Sets the internal Parquet **compression codec** for data pages, optionally including the compression level. Applies globally if set without `::col`, or specifically to a column if set using `'compression::column_name'`. Valid values: `uncompressed`, `snappy`, `gzip(level)`, `brotli(level)`, `lz4`, `zstd(level)`, `lz4_raw`. | `'compression'` or `'compression::col'` | zstd(3) | +| ENCODING | Yes | Sets the **encoding** scheme for data pages. Valid values: `plain`, `plain_dictionary`, `rle`, `bit_packed`, `delta_binary_packed`, `delta_length_byte_array`, `delta_byte_array`, `rle_dictionary`, `byte_stream_split`. Use key `'encoding'` or `'encoding::col'` in OPTIONS. | `'encoding'` or `'encoding::col'` | None | +| DICTIONARY_ENABLED | Yes | Sets whether dictionary encoding should be enabled globally or for a specific column. | `'dictionary_enabled'` or `'dictionary_enabled::col'` | true | +| STATISTICS_ENABLED | Yes | Sets the level of statistics to write (`none`, `chunk`, `page`). | `'statistics_enabled'` or `'statistics_enabled::col'` | page | +| BLOOM_FILTER_ENABLED | Yes | Sets whether a bloom filter should be written for a specific column. | `'bloom_filter_enabled::column_name'` | None | +| BLOOM_FILTER_FPP | Yes | Sets bloom filter false positive probability (global or per column). 
| `'bloom_filter_fpp'` or `'bloom_filter_fpp::col'` | None | +| BLOOM_FILTER_NDV | Yes | Sets bloom filter number of distinct values (global or per column). | `'bloom_filter_ndv'` or `'bloom_filter_ndv::col'` | None | +| MAX_ROW_GROUP_SIZE | No | Sets the maximum number of rows per row group. Larger groups require more memory but can improve compression and scan efficiency. | `'max_row_group_size'` | 1048576 | +| ENABLE_PAGE_INDEX | No | If true, reads the Parquet data page level metadata (the Page Index), if present, to reduce I/O and decoding. | `'enable_page_index'` | true | +| PRUNING | No | If true, enables row group pruning based on min/max statistics. | `'pruning'` | true | +| SKIP_METADATA | No | If true, skips optional embedded metadata in the file schema. | `'skip_metadata'` | true | +| METADATA_SIZE_HINT | No | Sets the size hint (in bytes) for fetching Parquet file metadata. | `'metadata_size_hint'` | None | +| PUSHDOWN_FILTERS | No | If true, enables filter pushdown during Parquet decoding. | `'pushdown_filters'` | false | +| REORDER_FILTERS | No | If true, enables heuristic reordering of filters during Parquet decoding. | `'reorder_filters'` | false | +| SCHEMA_FORCE_VIEW_TYPES | No | If true, reads Utf8/Binary columns as view types. | `'schema_force_view_types'` | true | +| BINARY_AS_STRING | No | If true, reads Binary columns as strings. | `'binary_as_string'` | false | +| DATA_PAGESIZE_LIMIT | No | Sets best effort maximum size of data page in bytes. | `'data_pagesize_limit'` | 1048576 | +| DATA_PAGE_ROW_COUNT_LIMIT | No | Sets best effort maximum number of rows in data page. | `'data_page_row_count_limit'` | 20000 | +| DICTIONARY_PAGE_SIZE_LIMIT | No | Sets best effort maximum dictionary page size, in bytes. | `'dictionary_page_size_limit'` | 1048576 | +| WRITE_BATCH_SIZE | No | Sets write_batch_size in rows. | `'write_batch_size'` | 1024 | +| WRITER_VERSION | No | Sets the Parquet writer version (`1.0` or `2.0`). 
| `'writer_version'` | 1.0 | +| SKIP_ARROW_METADATA | No | If true, skips writing Arrow schema information into the Parquet file metadata. | `'skip_arrow_metadata'` | false | +| CREATED_BY | No | Sets the "created by" string in the Parquet file metadata. | `'created_by'` | datafusion version X.Y.Z | +| COLUMN_INDEX_TRUNCATE_LENGTH | No | Sets the length (in bytes) to truncate min/max values in column indexes. | `'column_index_truncate_length'` | 64 | +| STATISTICS_TRUNCATE_LENGTH | No | Sets statistics truncate length. | `'statistics_truncate_length'` | None | +| BLOOM_FILTER_ON_WRITE | No | Sets whether bloom filters should be written for all columns by default (can be overridden per column). | `'bloom_filter_on_write'` | false | +| ALLOW_SINGLE_FILE_PARALLELISM | No | Enables parallel serialization of columns in a single file. | `'allow_single_file_parallelism'` | true | +| MAXIMUM_PARALLEL_ROW_GROUP_WRITERS | No | Maximum number of parallel row group writers. | `'maximum_parallel_row_group_writers'` | 1 | +| MAXIMUM_BUFFERED_RECORD_BATCHES_PER_STREAM | No | Maximum number of buffered record batches per stream. | `'maximum_buffered_record_batches_per_stream'` | 2 | +| KEY_VALUE_METADATA | No (Key is specific) | Adds custom key-value pairs to the file metadata. Use the format `'metadata::your_key_name' 'your_value'`. Multiple entries allowed. 
| `'metadata::key_name'` | None | **Example:** diff --git a/docs/source/user-guide/sql/index.rst b/docs/source/user-guide/sql/index.rst index a13d40334b63..f1fef45f705a 100644 --- a/docs/source/user-guide/sql/index.rst +++ b/docs/source/user-guide/sql/index.rst @@ -22,6 +22,7 @@ SQL Reference :maxdepth: 2 data_types + struct_coercion select subqueries ddl diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index fe1ed1cab6bd..254151c2c20e 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1225,7 +1225,7 @@ bit_length(str) ### `btrim` -Trims the specified trim string from the start and end of a string. If no trim string is provided, all whitespace is removed from the start and end of the input string. +Trims the specified trim string from the start and end of a string. If no trim string is provided, all spaces are removed from the start and end of the input string. ```sql btrim(str[, trim_str]) @@ -1234,7 +1234,7 @@ btrim(str[, trim_str]) #### Arguments - **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **trim_str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. _Default is whitespace characters._ +- **trim_str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. _Default is a space._ #### Example @@ -1592,7 +1592,7 @@ lpad(str, n[, padding_str]) #### Arguments - **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **n**: String length to pad to. +- **n**: String length to pad to. If the input string is longer than this length, it is truncated (on the right). - **padding_str**: Optional string expression to pad with. Can be a constant, column, or function, and any combination of string operators. 
_Default is a space._ #### Example @@ -1612,7 +1612,7 @@ lpad(str, n[, padding_str]) ### `ltrim` -Trims the specified trim string from the beginning of a string. If no trim string is provided, all whitespace is removed from the start of the input string. +Trims the specified trim string from the beginning of a string. If no trim string is provided, spaces are removed from the start of the input string. ```sql ltrim(str[, trim_str]) @@ -1621,7 +1621,7 @@ ltrim(str[, trim_str]) #### Arguments - **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **trim_str**: String expression to trim from the beginning of the input string. Can be a constant, column, or function, and any combination of arithmetic operators. _Default is whitespace characters._ +- **trim_str**: String expression to trim from the beginning of the input string. Can be a constant, column, or function, and any combination of arithmetic operators. _Default is a space._ #### Example @@ -1820,7 +1820,7 @@ rpad(str, n[, padding_str]) #### Arguments - **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **n**: String length to pad to. +- **n**: String length to pad to. If the input string is longer than this length, it is truncated. - **padding_str**: String expression to pad with. Can be a constant, column, or function, and any combination of string operators. _Default is a space._ #### Example @@ -1840,7 +1840,7 @@ rpad(str, n[, padding_str]) ### `rtrim` -Trims the specified trim string from the end of a string. If no trim string is provided, all whitespace is removed from the end of the input string. +Trims the specified trim string from the end of a string. If no trim string is provided, all spaces are removed from the end of the input string. ```sql rtrim(str[, trim_str]) @@ -1849,7 +1849,7 @@ rtrim(str[, trim_str]) #### Arguments - **str**: String expression to operate on. 
Can be a constant, column, or function, and any combination of operators. -- **trim_str**: String expression to trim from the end of the input string. Can be a constant, column, or function, and any combination of arithmetic operators. _Default is whitespace characters._ +- **trim_str**: String expression to trim from the end of the input string. Can be a constant, column, or function, and any combination of arithmetic operators. _Default is a space._ #### Example @@ -1891,7 +1891,7 @@ split_part(str, delimiter, pos) - **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. - **delimiter**: String or character to split on. -- **pos**: Position of the part to return. +- **pos**: Position of the part to return (counting from 1). Negative values count backward from the end of the string. #### Example @@ -2068,17 +2068,17 @@ to_hex(int) ### `translate` -Translates characters in a string to specified translation characters. +Performs character-wise substitution based on a mapping. ```sql -translate(str, chars, translation) +translate(str, from, to) ``` #### Arguments - **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **chars**: Characters to translate. -- **translation**: Translation characters. Translation characters replace only characters at the same position in the **chars** string. +- **from**: The characters to be replaced. +- **to**: The characters to replace them with. Each character in **from** that is found in **str** is replaced by the character at the same index in **to**. Any characters in **from** that don't have a corresponding character in **to** are removed. If a character appears more than once in **from**, the first occurrence determines the mapping. 
#### Example @@ -2175,7 +2175,7 @@ encode(expression, format) #### Arguments - **expression**: Expression containing string or binary data -- **format**: Supported formats are: `base64`, `hex` +- **format**: Supported formats are: `base64`, `base64pad`, `hex` **Related functions**: @@ -2519,6 +2519,7 @@ date_part(part, expression) - **part**: Part of the date to return. The following date parts are supported: - year + - isoyear (ISO 8601 week-numbering year) - quarter (emits value in inclusive range [1, 4] based on which quartile of the year the date is in) - month - week (week of the year) @@ -2531,7 +2532,7 @@ date_part(part, expression) - nanosecond - dow (day of the week where Sunday is 0) - doy (day of the year) - - epoch (seconds since Unix epoch) + - epoch (seconds since Unix epoch for timestamps/dates, total seconds for intervals) - isodow (day of the week where Monday is 0) - **expression**: Time expression to operate on. Can be a constant, column, or function. @@ -2548,7 +2549,7 @@ extract(field FROM source) ### `date_trunc` -Truncates a timestamp value to a specified precision. +Truncates a timestamp or time value to a specified precision. ```sql date_trunc(precision, expression) @@ -2558,6 +2559,8 @@ date_trunc(precision, expression) - **precision**: Time precision to truncate to. The following precisions are supported: + For Timestamp types: + - year / YEAR - quarter / QUARTER - month / MONTH @@ -2569,7 +2572,15 @@ date_trunc(precision, expression) - millisecond / MILLISECOND - microsecond / MICROSECOND -- **expression**: Time expression to operate on. Can be a constant, column, or function. + For Time types (hour, minute, second, millisecond, microsecond only): + + - hour / HOUR + - minute / MINUTE + - second / SECOND + - millisecond / MILLISECOND + - microsecond / MICROSECOND + +- **expression**: Timestamp or time expression to operate on. Can be a constant, column, or function. 
#### Aliases @@ -2869,7 +2880,8 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000`) in the session time zone. Supports strings, integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') -if no [Chrono formats] are provided. Strings that parse without a time zone are treated as if they are in the +if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. +Strings that parse without a time zone are treated as if they are in the session time zone, or UTC if no session time zone is set. Integers, unsigned integers, and doubles are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). @@ -2917,7 +2929,8 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000`) in the session time zone. Supports strings, integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') -if no [Chrono formats] are provided. Strings that parse without a time zone are treated as if they are in the +if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. +Strings that parse without a time zone are treated as if they are in the session time zone, or UTC if no session time zone is set. Integers, unsigned integers, and doubles are interpreted as microseconds since the unix epoch (`1970-01-01T00:00:00Z`). @@ -2960,7 +2973,8 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000`) in the session time zone. Supports strings, integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') -if no [Chrono formats] are provided. 
Strings that parse without a time zone are treated as if they are in the +if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. +Strings that parse without a time zone are treated as if they are in the session time zone, or UTC if no session time zone is set. Integers, unsigned integers, and doubles are interpreted as milliseconds since the unix epoch (`1970-01-01T00:00:00Z`). @@ -3003,7 +3017,8 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000000`) in the session time zone. Supports strings, integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') -if no [Chrono formats] are provided. Strings that parse without a time zone are treated as if they are in the +if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. +Strings that parse without a time zone are treated as if they are in the session time zone. Integers, unsigned integers, and doubles are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`). The session time zone can be set using the statement `SET TIMEZONE = 'desired time zone'`. @@ -3045,7 +3060,8 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo Converts a value to a timestamp (`YYYY-MM-DDT00:00:00`) in the session time zone. Supports strings, integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') -if no [Chrono formats] are provided. Strings that parse without a time zone are treated as if they are in the +if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. +Strings that parse without a time zone are treated as if they are in the session time zone, or UTC if no session time zone is set. 
Integers, unsigned integers, and doubles are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). @@ -3086,7 +3102,11 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo ### `to_unixtime` -Converts a value to seconds since the unix epoch (`1970-01-01T00:00:00`). Supports strings, dates, timestamps, integer, unsigned integer, and float types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. Integers, unsigned integers, and floats are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00`). +Converts a value to seconds since the unix epoch (`1970-01-01T00:00:00`). +Supports strings, dates, timestamps, integer, unsigned integer, and float types as input. +Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') +if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. +Integers, unsigned integers, and floats are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00`). ```sql to_unixtime(expression[, ..., format_n]) @@ -3163,6 +3183,7 @@ _Alias of [current_date](#current_date)._ - [array_to_string](#array_to_string) - [array_union](#array_union) - [arrays_overlap](#arrays_overlap) +- [arrays_zip](#arrays_zip) - [cardinality](#cardinality) - [empty](#empty) - [flatten](#flatten) @@ -3208,6 +3229,7 @@ _Alias of [current_date](#current_date)._ - [list_sort](#list_sort) - [list_to_string](#list_to_string) - [list_union](#list_union) +- [list_zip](#list_zip) - [make_array](#make_array) - [make_list](#make_list) - [range](#range) @@ -3523,16 +3545,16 @@ array_has_all(array, sub-array) ### `array_has_any` -Returns true if any elements exist in both arrays. +Returns true if the arrays have any elements in common. ```sql -array_has_any(array, sub-array) +array_has_any(array1, array2) ``` #### Arguments -- **array**: Array expression. 
Can be a constant, column, or function, and any combination of array operators. -- **sub-array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **array1**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **array2**: Array expression. Can be a constant, column, or function, and any combination of array operators. #### Example @@ -3754,7 +3776,7 @@ array_pop_front(array) ### `array_position` -Returns the position of the first occurrence of the specified element in the array, or NULL if not found. +Returns the position of the first occurrence of the specified element in the array, or NULL if not found. Comparisons are done using `IS DISTINCT FROM` semantics, so NULL is considered to match NULL. ```sql array_position(array, element) @@ -3764,7 +3786,7 @@ array_position(array, element, index) #### Arguments - **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. -- **element**: Element to search for position in the array. +- **element**: Element to search for in the array. - **index**: Index at which to start searching (1-indexed). #### Example @@ -4190,7 +4212,7 @@ array_to_string(array, delimiter[, null_string]) - **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. - **delimiter**: Array element separator. -- **null_string**: Optional. String to replace null values in the array. If not provided, nulls will be handled by default behavior. +- **null_string**: Optional. String to use for null values in the output. If not provided, nulls will be omitted. #### Example @@ -4211,7 +4233,7 @@ array_to_string(array, delimiter[, null_string]) ### `array_union` -Returns an array of elements that are present in both arrays (all elements from both arrays) with out duplicates. 
+Returns an array of elements that are present in both arrays (all elements from both arrays) without duplicates. ```sql array_union(array1, array2) @@ -4247,6 +4269,41 @@ array_union(array1, array2) _Alias of [array_has_any](#array_has_any)._ +### `arrays_zip` + +Returns an array of structs created by combining the elements of each input array at the same index. If the arrays have different lengths, shorter arrays are padded with NULLs. + +```sql +arrays_zip(array1, array2[, ..., array_n]) +``` + +#### Arguments + +- **array1**: First array expression. +- **array2**: Second array expression. +- **array_n**: Subsequent array expressions. + +#### Example + +```sql +> select arrays_zip([1, 2, 3], ['a', 'b', 'c']); ++---------------------------------------------------+ +| arrays_zip([1, 2, 3], ['a', 'b', 'c']) | ++---------------------------------------------------+ +| [{c0: 1, c1: a}, {c0: 2, c1: b}, {c0: 3, c1: c}] | ++---------------------------------------------------+ +> select arrays_zip([1, 2], [3, 4, 5]); ++---------------------------------------------------+ +| arrays_zip([1, 2], [3, 4, 5]) | ++---------------------------------------------------+ +| [{c0: 1, c1: 3}, {c0: 2, c1: 4}, {c0: , c1: 5}] | ++---------------------------------------------------+ +``` + +#### Aliases + +- list_zip + ### `cardinality` Returns the total number of elements in the array. @@ -4516,6 +4573,10 @@ _Alias of [array_to_string](#array_to_string)._ _Alias of [array_union](#array_union)._ +### `list_zip` + +_Alias of [arrays_zip](#arrays_zip)._ + ### `make_array` Returns an array using the specified input expressions. 
diff --git a/docs/source/user-guide/sql/struct_coercion.md b/docs/source/user-guide/sql/struct_coercion.md new file mode 100644 index 000000000000..d2a32fcee265 --- /dev/null +++ b/docs/source/user-guide/sql/struct_coercion.md @@ -0,0 +1,354 @@ + + +# Struct Type Coercion and Field Mapping + +DataFusion uses **name-based field mapping** when coercing struct types across different operations. This document explains how struct coercion works, when it applies, and how to handle NULL fields. + +## Overview: Name-Based vs Positional Mapping + +When combining structs from different sources (e.g., in UNION, array construction, or JOINs), DataFusion matches struct fields by **name** rather than by **position**. This provides more robust and predictable behavior compared to positional matching. + +### Example: Field Reordering is Handled Transparently + +```sql +-- These two structs have the same fields in different order +SELECT [{a: 1, b: 2}, {b: 3, a: 4}]; + +-- Result: Field names matched, values unified +-- [{"a": 1, "b": 2}, {"a": 4, "b": 3}] +``` + +## Coercion Paths Using Name-Based Matching + +The following query operations use name-based field mapping for struct coercion: + +### 1. Array Literal Construction + +When creating array literals with struct elements that have different field orders: + +```sql +-- Structs with reordered fields in array literal +SELECT [{x: 1, y: 2}, {y: 3, x: 4}]; + +-- Unified type: List(Struct("x": Int32, "y": Int32)) +-- Values: [{"x": 1, "y": 2}, {"x": 4, "y": 3}] +``` + +**When it applies:** + +- Array literals with struct elements: `[{...}, {...}]` +- Nested arrays with structs: `[[{x: 1}, {x: 2}]]` + +### 2. 
Array Construction from Columns + +When constructing arrays from table columns with different struct schemas: + +```sql +CREATE TABLE t_left (s struct(x int, y int)) AS VALUES ({x: 1, y: 2}); +CREATE TABLE t_right (s struct(y int, x int)) AS VALUES ({y: 3, x: 4}); + +-- Dynamically constructs unified array schema +SELECT [t_left.s, t_right.s] FROM t_left JOIN t_right; + +-- Result: [{"x": 1, "y": 2}, {"x": 4, "y": 3}] +``` + +**When it applies:** + +- Array construction with column references: `[col1, col2]` +- Array construction in joins with matching field names + +### 3. UNION Operations + +When combining query results with different struct field orders: + +```sql +SELECT {a: 1, b: 2} as s +UNION ALL +SELECT {b: 3, a: 4} as s; + +-- Result: {"a": 1, "b": 2} and {"a": 4, "b": 3} +``` + +**When it applies:** + +- UNION ALL with structs: field names matched across branches +- UNION (deduplicated) with structs + +### 4. Common Table Expressions (CTEs) + +When multiple CTEs produce structs with different field orders that are combined: + +```sql +WITH + t1 AS (SELECT {a: 1, b: 2} as s), + t2 AS (SELECT {b: 3, a: 4} as s) +SELECT s FROM t1 +UNION ALL +SELECT s FROM t2; + +-- Result: Field names matched across CTEs +``` + +### 5. VALUES Clauses + +When creating tables or temporary results with struct values in different field orders: + +```sql +CREATE TABLE t AS VALUES ({a: 1, b: 2}), ({b: 3, a: 4}); + +-- Table schema unified: struct(a: int, b: int) +-- Values: {a: 1, b: 2} and {a: 4, b: 3} +``` + +### 6. JOIN Operations + +When joining tables where the JOIN condition involves structs with different field orders: + +```sql +CREATE TABLE orders (customer struct(name varchar, id int)); +CREATE TABLE customers (info struct(id int, name varchar)); + +-- Join matches struct fields by name +SELECT * FROM orders +JOIN customers ON orders.customer = customers.info; +``` + +### 7. 
Aggregate Functions + +When collecting structs with different field orders using aggregate functions like `array_agg`: + +```sql +SELECT array_agg(s) FROM ( + SELECT {x: 1, y: 2} as s, 1 as category + UNION ALL + SELECT {y: 3, x: 4} as s, 1 as category +) t +GROUP BY category; + +-- Result: Array of structs with unified field order +``` + +### 8. Window Functions + +When using window functions with struct expressions having different field orders: + +```sql +SELECT + id, + row_number() over (partition by s order by id) as rn +FROM ( + SELECT {category: 1, value: 10} as s, 1 as id + UNION ALL + SELECT {value: 20, category: 1} as s, 2 as id +); + +-- Fields matched by name in PARTITION BY clause +``` + +## NULL Handling for Missing Fields + +When structs have different field sets, missing fields are filled with **NULL** values during coercion. + +### Example: Partial Field Overlap + +```sql +-- Struct in first position has fields: a, b +-- Struct in second position has fields: b, c +-- Unified schema includes all fields: a, b, c + +SELECT [ + CAST({a: 1, b: 2} AS STRUCT(a INT, b INT, c INT)), + CAST({b: 3, c: 4} AS STRUCT(a INT, b INT, c INT)) +]; + +-- Result: +-- [ +-- {"a": 1, "b": 2, "c": NULL}, +-- {"a": NULL, "b": 3, "c": 4} +-- ] +``` + +### Limitations + +**Field count must match exactly.** If structs have different numbers of fields and their field names don't completely overlap, the query will fail: + +```sql +-- This fails because field sets don't match: +-- t_left has {x, y} but t_right has {x, y, z} +SELECT [t_left.s, t_right.s] FROM t_left JOIN t_right; +-- Error: Cannot coerce struct with mismatched field counts +``` + +**Workaround: Use explicit CAST** + +To handle partial field overlap, explicitly cast structs to a unified schema: + +```sql +SELECT [ + CAST(t_left.s AS STRUCT(x INT, y INT, z INT)), + CAST(t_right.s AS STRUCT(x INT, y INT, z INT)) +] FROM t_left JOIN t_right; +``` + +## Migration Guide: From Positional to Name-Based Matching + +If you have existing code that 
relied on **positional** struct field matching, you may need to update it. + +### Example: Query That Changes Behavior + +**Old behavior (positional):** + +```sql +-- These would have been positionally mapped (left-to-right) +SELECT [{x: 1, y: 2}, {y: 3, x: 4}]; +-- Old result (positional): [{"x": 1, "y": 2}, {"y": 3, "x": 4}] +``` + +**New behavior (name-based):** + +```sql +-- Now uses name-based matching +SELECT [{x: 1, y: 2}, {y: 3, x: 4}]; +-- New result (by name): [{"x": 1, "y": 2}, {"x": 4, "y": 3}] +``` + +### Migration Steps + +1. **Review struct operations** - Look for queries that combine structs from different sources +2. **Check field names** - Verify that field names match as expected (not positions) +3. **Test with new coercion** - Run queries and verify the results match your expectations +4. **Handle field reordering** - If you need specific field orders, use explicit CAST operations + +### Using Explicit CAST for Compatibility + +If you need precise control over struct field order and types, use explicit `CAST`: + +```sql +-- Guarantee specific field order and types +SELECT CAST({b: 3, a: 4} AS STRUCT(a INT, b INT)); +-- Result: {"a": 4, "b": 3} +``` + +## Best Practices + +### 1. Be Explicit with Schema Definitions + +When joining or combining structs, define target schemas explicitly: + +```sql +-- Good: explicit schema definition +SELECT CAST(data AS STRUCT(id INT, name VARCHAR, active BOOLEAN)) +FROM external_source; +``` + +### 2. Use Named Struct Constructors + +Prefer named struct constructors for clarity: + +```sql +-- Good: field names are explicit +SELECT named_struct('id', 1, 'name', 'Alice', 'active', true); + +-- Or using struct literal syntax +SELECT {id: 1, name: 'Alice', active: true}; +``` + +### 3. 
Test Field Mappings + +Always verify that field mappings work as expected: + +```sql +-- Use arrow_typeof to verify unified schema +SELECT arrow_typeof([{x: 1, y: 2}, {y: 3, x: 4}]); +-- Result: List(Struct("x": Int32, "y": Int32)) +``` + +### 4. Handle Partial Field Overlap Explicitly + +When combining structs with partial field overlap, use explicit CAST: + +```sql +-- Instead of relying on implicit coercion +SELECT [ + CAST(left_struct AS STRUCT(x INT, y INT, z INT)), + CAST(right_struct AS STRUCT(x INT, y INT, z INT)) +]; +``` + +### 5. Document Struct Schemas + +In complex queries, document the expected struct schemas: + +```sql +-- Expected schema: {customer_id: INT, name: VARCHAR, age: INT} +SELECT { + customer_id: c.id, + name: c.name, + age: c.age +} as customer_info +FROM customers c; +``` + +## Error Messages and Troubleshooting + +### "Cannot coerce struct with different field counts" + +**Cause:** Trying to combine structs with different numbers of fields. + +**Solution:** + +```sql +-- Use explicit CAST to handle missing fields +SELECT [ + CAST(struct1 AS STRUCT(a INT, b INT, c INT)), + CAST(struct2 AS STRUCT(a INT, b INT, c INT)) +]; +``` + +### "Field X not found in struct" + +**Cause:** Referencing a field name that doesn't exist in the struct. + +**Solution:** + +```sql +-- Verify field names match exactly (case-sensitive) +SELECT s['field_name'] FROM my_table; -- Use bracket notation for access +-- Or use get_field function +SELECT get_field(s, 'field_name') FROM my_table; +``` + +### Unexpected NULL values after coercion + +**Cause:** Struct coercion added NULL for missing fields. 
+ +**Solution:** Check that all structs have the required fields, or explicitly handle NULLs: + +```sql +SELECT COALESCE(s['field'], default_value) FROM my_table; +``` + +## Related Functions + +- `arrow_typeof()` - Returns the Arrow type of an expression +- `struct()` / `named_struct()` - Creates struct values +- `get_field()` - Extracts field values from structs +- `CAST()` - Explicitly casts structs to specific schemas diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000000..6fc7705d8536 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,2 @@ +[tool.uv.workspace] +members = ["benchmarks", "dev", "docs"] diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 4e3ea12e2f28..f351f58a7117 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -19,5 +19,5 @@ # to compile this workspace and run CI jobs. [toolchain] -channel = "1.92.0" +channel = "1.93.0" components = ["rustfmt", "clippy"] diff --git a/test-utils/src/data_gen.rs b/test-utils/src/data_gen.rs index 2228010b28dd..bb8fdad5a0f8 100644 --- a/test-utils/src/data_gen.rs +++ b/test-utils/src/data_gen.rs @@ -129,7 +129,7 @@ impl BatchBuilder { } } - #[allow(clippy::too_many_arguments)] + #[expect(clippy::too_many_arguments)] fn append_row( &mut self, rng: &mut StdRng, diff --git a/uv.lock b/uv.lock new file mode 100644 index 000000000000..f44c83e5f539 --- /dev/null +++ b/uv.lock @@ -0,0 +1,1149 @@ +version = 1 +revision = 3 +requires-python = ">=3.11" +resolution-markers = [ + "python_full_version >= '3.12'", + "python_full_version < '3.12'", +] + +[manifest] +members = [ + "datafusion-benchmarks", + "datafusion-dev", + "datafusion-docs", +] + +[[package]] +name = "accessible-pygments" +version = "0.0.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bc/c1/bbac6a50d02774f91572938964c582fff4270eee73ab822a4aeea4d8b11b/accessible_pygments-0.0.5.tar.gz", hash = 
"sha256:40918d3e6a2b619ad424cb91e556bd3bd8865443d9f22f1dcdf79e33c8046872", size = 1377899, upload-time = "2024-05-10T11:23:10.216Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8d/3f/95338030883d8c8b91223b4e21744b04d11b161a3ef117295d8241f50ab4/accessible_pygments-0.0.5-py3-none-any.whl", hash = "sha256:88ae3211e68a1d0b011504b2ffc1691feafce124b845bd072ab6f9f66f34d4b7", size = 1395903, upload-time = "2024-05-10T11:23:08.421Z" }, +] + +[[package]] +name = "alabaster" +version = "1.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a6/f8/d9c74d0daf3f742840fd818d69cfae176fa332022fd44e3469487d5a9420/alabaster-1.0.0.tar.gz", hash = "sha256:c00dca57bca26fa62a6d7d0a9fcce65f3e026e9bfe33e9c538fd3fbb2144fd9e", size = 24210, upload-time = "2024-07-26T18:15:03.762Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/b3/6b4067be973ae96ba0d615946e314c5ae35f9f993eca561b356540bb0c2b/alabaster-1.0.0-py3-none-any.whl", hash = "sha256:fc6786402dc3fcb2de3cabd5fe455a2db534b371124f1f21de8731783dec828b", size = 13929, upload-time = "2024-07-26T18:15:02.05Z" }, +] + +[[package]] +name = "annotated-doc" +version = "0.0.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/57/ba/046ceea27344560984e26a590f90bc7f4a75b06701f653222458922b558c/annotated_doc-0.0.4.tar.gz", hash = "sha256:fbcda96e87e9c92ad167c2e53839e57503ecfda18804ea28102353485033faa4", size = 7288, upload-time = "2025-11-10T22:07:42.062Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/d3/26bf1008eb3d2daa8ef4cacc7f3bfdc11818d111f7e2d0201bc6e3b49d45/annotated_doc-0.0.4-py3-none-any.whl", hash = "sha256:571ac1dc6991c450b25a9c2d84a3705e2ae7a53467b5d111c24fa8baabbed320", size = 5303, upload-time = "2025-11-10T22:07:40.673Z" }, +] + +[[package]] +name = "babel" +version = "2.18.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/7d/b2/51899539b6ceeeb420d40ed3cd4b7a40519404f9baf3d4ac99dc413a834b/babel-2.18.0.tar.gz", hash = "sha256:b80b99a14bd085fcacfa15c9165f651fbb3406e66cc603abf11c5750937c992d", size = 9959554, upload-time = "2026-02-01T12:30:56.078Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/77/f5/21d2de20e8b8b0408f0681956ca2c69f1320a3848ac50e6e7f39c6159675/babel-2.18.0-py3-none-any.whl", hash = "sha256:e2b422b277c2b9a9630c1d7903c2a00d0830c409c59ac8cae9081c92f1aeba35", size = 10196845, upload-time = "2026-02-01T12:30:53.445Z" }, +] + +[[package]] +name = "beautifulsoup4" +version = "4.14.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "soupsieve" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c3/b0/1c6a16426d389813b48d95e26898aff79abbde42ad353958ad95cc8c9b21/beautifulsoup4-4.14.3.tar.gz", hash = "sha256:6292b1c5186d356bba669ef9f7f051757099565ad9ada5dd630bd9de5fa7fb86", size = 627737, upload-time = "2025-11-30T15:08:26.084Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1a/39/47f9197bdd44df24d67ac8893641e16f386c984a0619ef2ee4c51fbbc019/beautifulsoup4-4.14.3-py3-none-any.whl", hash = "sha256:0918bfe44902e6ad8d57732ba310582e98da931428d231a5ecb9e7c703a735bb", size = 107721, upload-time = "2025-11-30T15:08:24.087Z" }, +] + +[[package]] +name = "certifi" +version = "2026.1.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e0/2d/a891ca51311197f6ad14a7ef42e2399f36cf2f9bd44752b3dc4eab60fdc5/certifi-2026.1.4.tar.gz", hash = "sha256:ac726dd470482006e014ad384921ed6438c457018f4b3d204aea4281258b2120", size = 154268, upload-time = "2026-01-04T02:42:41.825Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e6/ad/3cc14f097111b4de0040c83a525973216457bbeeb63739ef1ed275c1c021/certifi-2026.1.4-py3-none-any.whl", hash = 
"sha256:9943707519e4add1115f44c2bc244f782c0249876bf51b6599fee1ffbedd685c", size = 152900, upload-time = "2026-01-04T02:42:40.15Z" }, +] + +[[package]] +name = "cffi" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pycparser", marker = "implementation_name != 'PyPy'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/eb/56/b1ba7935a17738ae8453301356628e8147c79dbb825bcbc73dc7401f9846/cffi-2.0.0.tar.gz", hash = "sha256:44d1b5909021139fe36001ae048dbdde8214afa20200eda0f64c068cac5d5529", size = 523588, upload-time = "2025-09-08T23:24:04.541Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/12/4a/3dfd5f7850cbf0d06dc84ba9aa00db766b52ca38d8b86e3a38314d52498c/cffi-2.0.0-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:b4c854ef3adc177950a8dfc81a86f5115d2abd545751a304c5bcf2c2c7283cfe", size = 184344, upload-time = "2025-09-08T23:22:26.456Z" }, + { url = "https://files.pythonhosted.org/packages/4f/8b/f0e4c441227ba756aafbe78f117485b25bb26b1c059d01f137fa6d14896b/cffi-2.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2de9a304e27f7596cd03d16f1b7c72219bd944e99cc52b84d0145aefb07cbd3c", size = 180560, upload-time = "2025-09-08T23:22:28.197Z" }, + { url = "https://files.pythonhosted.org/packages/b1/b7/1200d354378ef52ec227395d95c2576330fd22a869f7a70e88e1447eb234/cffi-2.0.0-cp311-cp311-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:baf5215e0ab74c16e2dd324e8ec067ef59e41125d3eade2b863d294fd5035c92", size = 209613, upload-time = "2025-09-08T23:22:29.475Z" }, + { url = "https://files.pythonhosted.org/packages/b8/56/6033f5e86e8cc9bb629f0077ba71679508bdf54a9a5e112a3c0b91870332/cffi-2.0.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:730cacb21e1bdff3ce90babf007d0a0917cc3e6492f336c2f0134101e0944f93", size = 216476, upload-time = "2025-09-08T23:22:31.063Z" }, + { url = 
"https://files.pythonhosted.org/packages/dc/7f/55fecd70f7ece178db2f26128ec41430d8720f2d12ca97bf8f0a628207d5/cffi-2.0.0-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:6824f87845e3396029f3820c206e459ccc91760e8fa24422f8b0c3d1731cbec5", size = 203374, upload-time = "2025-09-08T23:22:32.507Z" }, + { url = "https://files.pythonhosted.org/packages/84/ef/a7b77c8bdc0f77adc3b46888f1ad54be8f3b7821697a7b89126e829e676a/cffi-2.0.0-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:9de40a7b0323d889cf8d23d1ef214f565ab154443c42737dfe52ff82cf857664", size = 202597, upload-time = "2025-09-08T23:22:34.132Z" }, + { url = "https://files.pythonhosted.org/packages/d7/91/500d892b2bf36529a75b77958edfcd5ad8e2ce4064ce2ecfeab2125d72d1/cffi-2.0.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:8941aaadaf67246224cee8c3803777eed332a19d909b47e29c9842ef1e79ac26", size = 215574, upload-time = "2025-09-08T23:22:35.443Z" }, + { url = "https://files.pythonhosted.org/packages/44/64/58f6255b62b101093d5df22dcb752596066c7e89dd725e0afaed242a61be/cffi-2.0.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:a05d0c237b3349096d3981b727493e22147f934b20f6f125a3eba8f994bec4a9", size = 218971, upload-time = "2025-09-08T23:22:36.805Z" }, + { url = "https://files.pythonhosted.org/packages/ab/49/fa72cebe2fd8a55fbe14956f9970fe8eb1ac59e5df042f603ef7c8ba0adc/cffi-2.0.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:94698a9c5f91f9d138526b48fe26a199609544591f859c870d477351dc7b2414", size = 211972, upload-time = "2025-09-08T23:22:38.436Z" }, + { url = "https://files.pythonhosted.org/packages/0b/28/dd0967a76aab36731b6ebfe64dec4e981aff7e0608f60c2d46b46982607d/cffi-2.0.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:5fed36fccc0612a53f1d4d9a816b50a36702c28a2aa880cb8a122b3466638743", size = 217078, upload-time = "2025-09-08T23:22:39.776Z" }, + { url = 
"https://files.pythonhosted.org/packages/2b/c0/015b25184413d7ab0a410775fdb4a50fca20f5589b5dab1dbbfa3baad8ce/cffi-2.0.0-cp311-cp311-win32.whl", hash = "sha256:c649e3a33450ec82378822b3dad03cc228b8f5963c0c12fc3b1e0ab940f768a5", size = 172076, upload-time = "2025-09-08T23:22:40.95Z" }, + { url = "https://files.pythonhosted.org/packages/ae/8f/dc5531155e7070361eb1b7e4c1a9d896d0cb21c49f807a6c03fd63fc877e/cffi-2.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:66f011380d0e49ed280c789fbd08ff0d40968ee7b665575489afa95c98196ab5", size = 182820, upload-time = "2025-09-08T23:22:42.463Z" }, + { url = "https://files.pythonhosted.org/packages/95/5c/1b493356429f9aecfd56bc171285a4c4ac8697f76e9bbbbb105e537853a1/cffi-2.0.0-cp311-cp311-win_arm64.whl", hash = "sha256:c6638687455baf640e37344fe26d37c404db8b80d037c3d29f58fe8d1c3b194d", size = 177635, upload-time = "2025-09-08T23:22:43.623Z" }, + { url = "https://files.pythonhosted.org/packages/ea/47/4f61023ea636104d4f16ab488e268b93008c3d0bb76893b1b31db1f96802/cffi-2.0.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:6d02d6655b0e54f54c4ef0b94eb6be0607b70853c45ce98bd278dc7de718be5d", size = 185271, upload-time = "2025-09-08T23:22:44.795Z" }, + { url = "https://files.pythonhosted.org/packages/df/a2/781b623f57358e360d62cdd7a8c681f074a71d445418a776eef0aadb4ab4/cffi-2.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8eca2a813c1cb7ad4fb74d368c2ffbbb4789d377ee5bb8df98373c2cc0dee76c", size = 181048, upload-time = "2025-09-08T23:22:45.938Z" }, + { url = "https://files.pythonhosted.org/packages/ff/df/a4f0fbd47331ceeba3d37c2e51e9dfc9722498becbeec2bd8bc856c9538a/cffi-2.0.0-cp312-cp312-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:21d1152871b019407d8ac3985f6775c079416c282e431a4da6afe7aefd2bccbe", size = 212529, upload-time = "2025-09-08T23:22:47.349Z" }, + { url = 
"https://files.pythonhosted.org/packages/d5/72/12b5f8d3865bf0f87cf1404d8c374e7487dcf097a1c91c436e72e6badd83/cffi-2.0.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:b21e08af67b8a103c71a250401c78d5e0893beff75e28c53c98f4de42f774062", size = 220097, upload-time = "2025-09-08T23:22:48.677Z" }, + { url = "https://files.pythonhosted.org/packages/c2/95/7a135d52a50dfa7c882ab0ac17e8dc11cec9d55d2c18dda414c051c5e69e/cffi-2.0.0-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:1e3a615586f05fc4065a8b22b8152f0c1b00cdbc60596d187c2a74f9e3036e4e", size = 207983, upload-time = "2025-09-08T23:22:50.06Z" }, + { url = "https://files.pythonhosted.org/packages/3a/c8/15cb9ada8895957ea171c62dc78ff3e99159ee7adb13c0123c001a2546c1/cffi-2.0.0-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:81afed14892743bbe14dacb9e36d9e0e504cd204e0b165062c488942b9718037", size = 206519, upload-time = "2025-09-08T23:22:51.364Z" }, + { url = "https://files.pythonhosted.org/packages/78/2d/7fa73dfa841b5ac06c7b8855cfc18622132e365f5b81d02230333ff26e9e/cffi-2.0.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3e17ed538242334bf70832644a32a7aae3d83b57567f9fd60a26257e992b79ba", size = 219572, upload-time = "2025-09-08T23:22:52.902Z" }, + { url = "https://files.pythonhosted.org/packages/07/e0/267e57e387b4ca276b90f0434ff88b2c2241ad72b16d31836adddfd6031b/cffi-2.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3925dd22fa2b7699ed2617149842d2e6adde22b262fcbfada50e3d195e4b3a94", size = 222963, upload-time = "2025-09-08T23:22:54.518Z" }, + { url = "https://files.pythonhosted.org/packages/b6/75/1f2747525e06f53efbd878f4d03bac5b859cbc11c633d0fb81432d98a795/cffi-2.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2c8f814d84194c9ea681642fd164267891702542f028a15fc97d4674b6206187", size = 221361, upload-time = "2025-09-08T23:22:55.867Z" }, + { url = 
"https://files.pythonhosted.org/packages/7b/2b/2b6435f76bfeb6bbf055596976da087377ede68df465419d192acf00c437/cffi-2.0.0-cp312-cp312-win32.whl", hash = "sha256:da902562c3e9c550df360bfa53c035b2f241fed6d9aef119048073680ace4a18", size = 172932, upload-time = "2025-09-08T23:22:57.188Z" }, + { url = "https://files.pythonhosted.org/packages/f8/ed/13bd4418627013bec4ed6e54283b1959cf6db888048c7cf4b4c3b5b36002/cffi-2.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:da68248800ad6320861f129cd9c1bf96ca849a2771a59e0344e88681905916f5", size = 183557, upload-time = "2025-09-08T23:22:58.351Z" }, + { url = "https://files.pythonhosted.org/packages/95/31/9f7f93ad2f8eff1dbc1c3656d7ca5bfd8fb52c9d786b4dcf19b2d02217fa/cffi-2.0.0-cp312-cp312-win_arm64.whl", hash = "sha256:4671d9dd5ec934cb9a73e7ee9676f9362aba54f7f34910956b84d727b0d73fb6", size = 177762, upload-time = "2025-09-08T23:22:59.668Z" }, + { url = "https://files.pythonhosted.org/packages/4b/8d/a0a47a0c9e413a658623d014e91e74a50cdd2c423f7ccfd44086ef767f90/cffi-2.0.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:00bdf7acc5f795150faa6957054fbbca2439db2f775ce831222b66f192f03beb", size = 185230, upload-time = "2025-09-08T23:23:00.879Z" }, + { url = "https://files.pythonhosted.org/packages/4a/d2/a6c0296814556c68ee32009d9c2ad4f85f2707cdecfd7727951ec228005d/cffi-2.0.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:45d5e886156860dc35862657e1494b9bae8dfa63bf56796f2fb56e1679fc0bca", size = 181043, upload-time = "2025-09-08T23:23:02.231Z" }, + { url = "https://files.pythonhosted.org/packages/b0/1e/d22cc63332bd59b06481ceaac49d6c507598642e2230f201649058a7e704/cffi-2.0.0-cp313-cp313-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:07b271772c100085dd28b74fa0cd81c8fb1a3ba18b21e03d7c27f3436a10606b", size = 212446, upload-time = "2025-09-08T23:23:03.472Z" }, + { url = 
"https://files.pythonhosted.org/packages/a9/f5/a2c23eb03b61a0b8747f211eb716446c826ad66818ddc7810cc2cc19b3f2/cffi-2.0.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d48a880098c96020b02d5a1f7d9251308510ce8858940e6fa99ece33f610838b", size = 220101, upload-time = "2025-09-08T23:23:04.792Z" }, + { url = "https://files.pythonhosted.org/packages/f2/7f/e6647792fc5850d634695bc0e6ab4111ae88e89981d35ac269956605feba/cffi-2.0.0-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:f93fd8e5c8c0a4aa1f424d6173f14a892044054871c771f8566e4008eaa359d2", size = 207948, upload-time = "2025-09-08T23:23:06.127Z" }, + { url = "https://files.pythonhosted.org/packages/cb/1e/a5a1bd6f1fb30f22573f76533de12a00bf274abcdc55c8edab639078abb6/cffi-2.0.0-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:dd4f05f54a52fb558f1ba9f528228066954fee3ebe629fc1660d874d040ae5a3", size = 206422, upload-time = "2025-09-08T23:23:07.753Z" }, + { url = "https://files.pythonhosted.org/packages/98/df/0a1755e750013a2081e863e7cd37e0cdd02664372c754e5560099eb7aa44/cffi-2.0.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c8d3b5532fc71b7a77c09192b4a5a200ea992702734a2e9279a37f2478236f26", size = 219499, upload-time = "2025-09-08T23:23:09.648Z" }, + { url = "https://files.pythonhosted.org/packages/50/e1/a969e687fcf9ea58e6e2a928ad5e2dd88cc12f6f0ab477e9971f2309b57c/cffi-2.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:d9b29c1f0ae438d5ee9acb31cadee00a58c46cc9c0b2f9038c6b0b3470877a8c", size = 222928, upload-time = "2025-09-08T23:23:10.928Z" }, + { url = "https://files.pythonhosted.org/packages/36/54/0362578dd2c9e557a28ac77698ed67323ed5b9775ca9d3fe73fe191bb5d8/cffi-2.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6d50360be4546678fc1b79ffe7a66265e28667840010348dd69a314145807a1b", size = 221302, upload-time = "2025-09-08T23:23:12.42Z" }, + { url = 
"https://files.pythonhosted.org/packages/eb/6d/bf9bda840d5f1dfdbf0feca87fbdb64a918a69bca42cfa0ba7b137c48cb8/cffi-2.0.0-cp313-cp313-win32.whl", hash = "sha256:74a03b9698e198d47562765773b4a8309919089150a0bb17d829ad7b44b60d27", size = 172909, upload-time = "2025-09-08T23:23:14.32Z" }, + { url = "https://files.pythonhosted.org/packages/37/18/6519e1ee6f5a1e579e04b9ddb6f1676c17368a7aba48299c3759bbc3c8b3/cffi-2.0.0-cp313-cp313-win_amd64.whl", hash = "sha256:19f705ada2530c1167abacb171925dd886168931e0a7b78f5bffcae5c6b5be75", size = 183402, upload-time = "2025-09-08T23:23:15.535Z" }, + { url = "https://files.pythonhosted.org/packages/cb/0e/02ceeec9a7d6ee63bb596121c2c8e9b3a9e150936f4fbef6ca1943e6137c/cffi-2.0.0-cp313-cp313-win_arm64.whl", hash = "sha256:256f80b80ca3853f90c21b23ee78cd008713787b1b1e93eae9f3d6a7134abd91", size = 177780, upload-time = "2025-09-08T23:23:16.761Z" }, + { url = "https://files.pythonhosted.org/packages/92/c4/3ce07396253a83250ee98564f8d7e9789fab8e58858f35d07a9a2c78de9f/cffi-2.0.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:fc33c5141b55ed366cfaad382df24fe7dcbc686de5be719b207bb248e3053dc5", size = 185320, upload-time = "2025-09-08T23:23:18.087Z" }, + { url = "https://files.pythonhosted.org/packages/59/dd/27e9fa567a23931c838c6b02d0764611c62290062a6d4e8ff7863daf9730/cffi-2.0.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:c654de545946e0db659b3400168c9ad31b5d29593291482c43e3564effbcee13", size = 181487, upload-time = "2025-09-08T23:23:19.622Z" }, + { url = "https://files.pythonhosted.org/packages/d6/43/0e822876f87ea8a4ef95442c3d766a06a51fc5298823f884ef87aaad168c/cffi-2.0.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:24b6f81f1983e6df8db3adc38562c83f7d4a0c36162885ec7f7b77c7dcbec97b", size = 220049, upload-time = "2025-09-08T23:23:20.853Z" }, + { url = 
"https://files.pythonhosted.org/packages/b4/89/76799151d9c2d2d1ead63c2429da9ea9d7aac304603de0c6e8764e6e8e70/cffi-2.0.0-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:12873ca6cb9b0f0d3a0da705d6086fe911591737a59f28b7936bdfed27c0d47c", size = 207793, upload-time = "2025-09-08T23:23:22.08Z" }, + { url = "https://files.pythonhosted.org/packages/bb/dd/3465b14bb9e24ee24cb88c9e3730f6de63111fffe513492bf8c808a3547e/cffi-2.0.0-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:d9b97165e8aed9272a6bb17c01e3cc5871a594a446ebedc996e2397a1c1ea8ef", size = 206300, upload-time = "2025-09-08T23:23:23.314Z" }, + { url = "https://files.pythonhosted.org/packages/47/d9/d83e293854571c877a92da46fdec39158f8d7e68da75bf73581225d28e90/cffi-2.0.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:afb8db5439b81cf9c9d0c80404b60c3cc9c3add93e114dcae767f1477cb53775", size = 219244, upload-time = "2025-09-08T23:23:24.541Z" }, + { url = "https://files.pythonhosted.org/packages/2b/0f/1f177e3683aead2bb00f7679a16451d302c436b5cbf2505f0ea8146ef59e/cffi-2.0.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:737fe7d37e1a1bffe70bd5754ea763a62a066dc5913ca57e957824b72a85e205", size = 222828, upload-time = "2025-09-08T23:23:26.143Z" }, + { url = "https://files.pythonhosted.org/packages/c6/0f/cafacebd4b040e3119dcb32fed8bdef8dfe94da653155f9d0b9dc660166e/cffi-2.0.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:38100abb9d1b1435bc4cc340bb4489635dc2f0da7456590877030c9b3d40b0c1", size = 220926, upload-time = "2025-09-08T23:23:27.873Z" }, + { url = "https://files.pythonhosted.org/packages/3e/aa/df335faa45b395396fcbc03de2dfcab242cd61a9900e914fe682a59170b1/cffi-2.0.0-cp314-cp314-win32.whl", hash = "sha256:087067fa8953339c723661eda6b54bc98c5625757ea62e95eb4898ad5e776e9f", size = 175328, upload-time = "2025-09-08T23:23:44.61Z" }, + { url = 
"https://files.pythonhosted.org/packages/bb/92/882c2d30831744296ce713f0feb4c1cd30f346ef747b530b5318715cc367/cffi-2.0.0-cp314-cp314-win_amd64.whl", hash = "sha256:203a48d1fb583fc7d78a4c6655692963b860a417c0528492a6bc21f1aaefab25", size = 185650, upload-time = "2025-09-08T23:23:45.848Z" }, + { url = "https://files.pythonhosted.org/packages/9f/2c/98ece204b9d35a7366b5b2c6539c350313ca13932143e79dc133ba757104/cffi-2.0.0-cp314-cp314-win_arm64.whl", hash = "sha256:dbd5c7a25a7cb98f5ca55d258b103a2054f859a46ae11aaf23134f9cc0d356ad", size = 180687, upload-time = "2025-09-08T23:23:47.105Z" }, + { url = "https://files.pythonhosted.org/packages/3e/61/c768e4d548bfa607abcda77423448df8c471f25dbe64fb2ef6d555eae006/cffi-2.0.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:9a67fc9e8eb39039280526379fb3a70023d77caec1852002b4da7e8b270c4dd9", size = 188773, upload-time = "2025-09-08T23:23:29.347Z" }, + { url = "https://files.pythonhosted.org/packages/2c/ea/5f76bce7cf6fcd0ab1a1058b5af899bfbef198bea4d5686da88471ea0336/cffi-2.0.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:7a66c7204d8869299919db4d5069a82f1561581af12b11b3c9f48c584eb8743d", size = 185013, upload-time = "2025-09-08T23:23:30.63Z" }, + { url = "https://files.pythonhosted.org/packages/be/b4/c56878d0d1755cf9caa54ba71e5d049479c52f9e4afc230f06822162ab2f/cffi-2.0.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7cc09976e8b56f8cebd752f7113ad07752461f48a58cbba644139015ac24954c", size = 221593, upload-time = "2025-09-08T23:23:31.91Z" }, + { url = "https://files.pythonhosted.org/packages/e0/0d/eb704606dfe8033e7128df5e90fee946bbcb64a04fcdaa97321309004000/cffi-2.0.0-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:92b68146a71df78564e4ef48af17551a5ddd142e5190cdf2c5624d0c3ff5b2e8", size = 209354, upload-time = "2025-09-08T23:23:33.214Z" }, + { url = 
"https://files.pythonhosted.org/packages/d8/19/3c435d727b368ca475fb8742ab97c9cb13a0de600ce86f62eab7fa3eea60/cffi-2.0.0-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:b1e74d11748e7e98e2f426ab176d4ed720a64412b6a15054378afdb71e0f37dc", size = 208480, upload-time = "2025-09-08T23:23:34.495Z" }, + { url = "https://files.pythonhosted.org/packages/d0/44/681604464ed9541673e486521497406fadcc15b5217c3e326b061696899a/cffi-2.0.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:28a3a209b96630bca57cce802da70c266eb08c6e97e5afd61a75611ee6c64592", size = 221584, upload-time = "2025-09-08T23:23:36.096Z" }, + { url = "https://files.pythonhosted.org/packages/25/8e/342a504ff018a2825d395d44d63a767dd8ebc927ebda557fecdaca3ac33a/cffi-2.0.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:7553fb2090d71822f02c629afe6042c299edf91ba1bf94951165613553984512", size = 224443, upload-time = "2025-09-08T23:23:37.328Z" }, + { url = "https://files.pythonhosted.org/packages/e1/5e/b666bacbbc60fbf415ba9988324a132c9a7a0448a9a8f125074671c0f2c3/cffi-2.0.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:6c6c373cfc5c83a975506110d17457138c8c63016b563cc9ed6e056a82f13ce4", size = 223437, upload-time = "2025-09-08T23:23:38.945Z" }, + { url = "https://files.pythonhosted.org/packages/a0/1d/ec1a60bd1a10daa292d3cd6bb0b359a81607154fb8165f3ec95fe003b85c/cffi-2.0.0-cp314-cp314t-win32.whl", hash = "sha256:1fc9ea04857caf665289b7a75923f2c6ed559b8298a1b8c49e59f7dd95c8481e", size = 180487, upload-time = "2025-09-08T23:23:40.423Z" }, + { url = "https://files.pythonhosted.org/packages/bf/41/4c1168c74fac325c0c8156f04b6749c8b6a8f405bbf91413ba088359f60d/cffi-2.0.0-cp314-cp314t-win_amd64.whl", hash = "sha256:d68b6cef7827e8641e8ef16f4494edda8b36104d79773a334beaa1e3521430f6", size = 191726, upload-time = "2025-09-08T23:23:41.742Z" }, + { url = 
"https://files.pythonhosted.org/packages/ae/3a/dbeec9d1ee0844c679f6bb5d6ad4e9f198b1224f4e7a32825f47f6192b0c/cffi-2.0.0-cp314-cp314t-win_arm64.whl", hash = "sha256:0a1527a803f0a659de1af2e1fd700213caba79377e27e4693648c2923da066f9", size = 184195, upload-time = "2025-09-08T23:23:43.004Z" }, +] + +[[package]] +name = "charset-normalizer" +version = "3.4.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/13/69/33ddede1939fdd074bce5434295f38fae7136463422fe4fd3e0e89b98062/charset_normalizer-3.4.4.tar.gz", hash = "sha256:94537985111c35f28720e43603b8e7b43a6ecfb2ce1d3058bbe955b73404e21a", size = 129418, upload-time = "2025-10-14T04:42:32.879Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ed/27/c6491ff4954e58a10f69ad90aca8a1b6fe9c5d3c6f380907af3c37435b59/charset_normalizer-3.4.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:6e1fcf0720908f200cd21aa4e6750a48ff6ce4afe7ff5a79a90d5ed8a08296f8", size = 206988, upload-time = "2025-10-14T04:40:33.79Z" }, + { url = "https://files.pythonhosted.org/packages/94/59/2e87300fe67ab820b5428580a53cad894272dbb97f38a7a814a2a1ac1011/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5f819d5fe9234f9f82d75bdfa9aef3a3d72c4d24a6e57aeaebba32a704553aa0", size = 147324, upload-time = "2025-10-14T04:40:34.961Z" }, + { url = "https://files.pythonhosted.org/packages/07/fb/0cf61dc84b2b088391830f6274cb57c82e4da8bbc2efeac8c025edb88772/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:a59cb51917aa591b1c4e6a43c132f0cdc3c76dbad6155df4e28ee626cc77a0a3", size = 142742, upload-time = "2025-10-14T04:40:36.105Z" }, + { url = 
"https://files.pythonhosted.org/packages/62/8b/171935adf2312cd745d290ed93cf16cf0dfe320863ab7cbeeae1dcd6535f/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:8ef3c867360f88ac904fd3f5e1f902f13307af9052646963ee08ff4f131adafc", size = 160863, upload-time = "2025-10-14T04:40:37.188Z" }, + { url = "https://files.pythonhosted.org/packages/09/73/ad875b192bda14f2173bfc1bc9a55e009808484a4b256748d931b6948442/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d9e45d7faa48ee908174d8fe84854479ef838fc6a705c9315372eacbc2f02897", size = 157837, upload-time = "2025-10-14T04:40:38.435Z" }, + { url = "https://files.pythonhosted.org/packages/6d/fc/de9cce525b2c5b94b47c70a4b4fb19f871b24995c728e957ee68ab1671ea/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:840c25fb618a231545cbab0564a799f101b63b9901f2569faecd6b222ac72381", size = 151550, upload-time = "2025-10-14T04:40:40.053Z" }, + { url = "https://files.pythonhosted.org/packages/55/c2/43edd615fdfba8c6f2dfbd459b25a6b3b551f24ea21981e23fb768503ce1/charset_normalizer-3.4.4-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:ca5862d5b3928c4940729dacc329aa9102900382fea192fc5e52eb69d6093815", size = 149162, upload-time = "2025-10-14T04:40:41.163Z" }, + { url = "https://files.pythonhosted.org/packages/03/86/bde4ad8b4d0e9429a4e82c1e8f5c659993a9a863ad62c7df05cf7b678d75/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d9c7f57c3d666a53421049053eaacdd14bbd0a528e2186fcb2e672effd053bb0", size = 150019, upload-time = "2025-10-14T04:40:42.276Z" }, + { url = "https://files.pythonhosted.org/packages/1f/86/a151eb2af293a7e7bac3a739b81072585ce36ccfb4493039f49f1d3cae8c/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_armv7l.whl", hash = 
"sha256:277e970e750505ed74c832b4bf75dac7476262ee2a013f5574dd49075879e161", size = 143310, upload-time = "2025-10-14T04:40:43.439Z" }, + { url = "https://files.pythonhosted.org/packages/b5/fe/43dae6144a7e07b87478fdfc4dbe9efd5defb0e7ec29f5f58a55aeef7bf7/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:31fd66405eaf47bb62e8cd575dc621c56c668f27d46a61d975a249930dd5e2a4", size = 162022, upload-time = "2025-10-14T04:40:44.547Z" }, + { url = "https://files.pythonhosted.org/packages/80/e6/7aab83774f5d2bca81f42ac58d04caf44f0cc2b65fc6db2b3b2e8a05f3b3/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:0d3d8f15c07f86e9ff82319b3d9ef6f4bf907608f53fe9d92b28ea9ae3d1fd89", size = 149383, upload-time = "2025-10-14T04:40:46.018Z" }, + { url = "https://files.pythonhosted.org/packages/4f/e8/b289173b4edae05c0dde07f69f8db476a0b511eac556dfe0d6bda3c43384/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:9f7fcd74d410a36883701fafa2482a6af2ff5ba96b9a620e9e0721e28ead5569", size = 159098, upload-time = "2025-10-14T04:40:47.081Z" }, + { url = "https://files.pythonhosted.org/packages/d8/df/fe699727754cae3f8478493c7f45f777b17c3ef0600e28abfec8619eb49c/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ebf3e58c7ec8a8bed6d66a75d7fb37b55e5015b03ceae72a8e7c74495551e224", size = 152991, upload-time = "2025-10-14T04:40:48.246Z" }, + { url = "https://files.pythonhosted.org/packages/1a/86/584869fe4ddb6ffa3bd9f491b87a01568797fb9bd8933f557dba9771beaf/charset_normalizer-3.4.4-cp311-cp311-win32.whl", hash = "sha256:eecbc200c7fd5ddb9a7f16c7decb07b566c29fa2161a16cf67b8d068bd21690a", size = 99456, upload-time = "2025-10-14T04:40:49.376Z" }, + { url = "https://files.pythonhosted.org/packages/65/f6/62fdd5feb60530f50f7e38b4f6a1d5203f4d16ff4f9f0952962c044e919a/charset_normalizer-3.4.4-cp311-cp311-win_amd64.whl", hash = "sha256:5ae497466c7901d54b639cf42d5b8c1b6a4fead55215500d2f486d34db48d016", size = 
106978, upload-time = "2025-10-14T04:40:50.844Z" }, + { url = "https://files.pythonhosted.org/packages/7a/9d/0710916e6c82948b3be62d9d398cb4fcf4e97b56d6a6aeccd66c4b2f2bd5/charset_normalizer-3.4.4-cp311-cp311-win_arm64.whl", hash = "sha256:65e2befcd84bc6f37095f5961e68a6f077bf44946771354a28ad434c2cce0ae1", size = 99969, upload-time = "2025-10-14T04:40:52.272Z" }, + { url = "https://files.pythonhosted.org/packages/f3/85/1637cd4af66fa687396e757dec650f28025f2a2f5a5531a3208dc0ec43f2/charset_normalizer-3.4.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:0a98e6759f854bd25a58a73fa88833fba3b7c491169f86ce1180c948ab3fd394", size = 208425, upload-time = "2025-10-14T04:40:53.353Z" }, + { url = "https://files.pythonhosted.org/packages/9d/6a/04130023fef2a0d9c62d0bae2649b69f7b7d8d24ea5536feef50551029df/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b5b290ccc2a263e8d185130284f8501e3e36c5e02750fc6b6bdeb2e9e96f1e25", size = 148162, upload-time = "2025-10-14T04:40:54.558Z" }, + { url = "https://files.pythonhosted.org/packages/78/29/62328d79aa60da22c9e0b9a66539feae06ca0f5a4171ac4f7dc285b83688/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:74bb723680f9f7a6234dcf67aea57e708ec1fbdf5699fb91dfd6f511b0a320ef", size = 144558, upload-time = "2025-10-14T04:40:55.677Z" }, + { url = "https://files.pythonhosted.org/packages/86/bb/b32194a4bf15b88403537c2e120b817c61cd4ecffa9b6876e941c3ee38fe/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:f1e34719c6ed0b92f418c7c780480b26b5d9c50349e9a9af7d76bf757530350d", size = 161497, upload-time = "2025-10-14T04:40:57.217Z" }, + { url = 
"https://files.pythonhosted.org/packages/19/89/a54c82b253d5b9b111dc74aca196ba5ccfcca8242d0fb64146d4d3183ff1/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:2437418e20515acec67d86e12bf70056a33abdacb5cb1655042f6538d6b085a8", size = 159240, upload-time = "2025-10-14T04:40:58.358Z" }, + { url = "https://files.pythonhosted.org/packages/c0/10/d20b513afe03acc89ec33948320a5544d31f21b05368436d580dec4e234d/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:11d694519d7f29d6cd09f6ac70028dba10f92f6cdd059096db198c283794ac86", size = 153471, upload-time = "2025-10-14T04:40:59.468Z" }, + { url = "https://files.pythonhosted.org/packages/61/fa/fbf177b55bdd727010f9c0a3c49eefa1d10f960e5f09d1d887bf93c2e698/charset_normalizer-3.4.4-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:ac1c4a689edcc530fc9d9aa11f5774b9e2f33f9a0c6a57864e90908f5208d30a", size = 150864, upload-time = "2025-10-14T04:41:00.623Z" }, + { url = "https://files.pythonhosted.org/packages/05/12/9fbc6a4d39c0198adeebbde20b619790e9236557ca59fc40e0e3cebe6f40/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:21d142cc6c0ec30d2efee5068ca36c128a30b0f2c53c1c07bd78cb6bc1d3be5f", size = 150647, upload-time = "2025-10-14T04:41:01.754Z" }, + { url = "https://files.pythonhosted.org/packages/ad/1f/6a9a593d52e3e8c5d2b167daf8c6b968808efb57ef4c210acb907c365bc4/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:5dbe56a36425d26d6cfb40ce79c314a2e4dd6211d51d6d2191c00bed34f354cc", size = 145110, upload-time = "2025-10-14T04:41:03.231Z" }, + { url = "https://files.pythonhosted.org/packages/30/42/9a52c609e72471b0fc54386dc63c3781a387bb4fe61c20231a4ebcd58bdd/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:5bfbb1b9acf3334612667b61bd3002196fe2a1eb4dd74d247e0f2a4d50ec9bbf", size = 162839, 
upload-time = "2025-10-14T04:41:04.715Z" }, + { url = "https://files.pythonhosted.org/packages/c4/5b/c0682bbf9f11597073052628ddd38344a3d673fda35a36773f7d19344b23/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:d055ec1e26e441f6187acf818b73564e6e6282709e9bcb5b63f5b23068356a15", size = 150667, upload-time = "2025-10-14T04:41:05.827Z" }, + { url = "https://files.pythonhosted.org/packages/e4/24/a41afeab6f990cf2daf6cb8c67419b63b48cf518e4f56022230840c9bfb2/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:af2d8c67d8e573d6de5bc30cdb27e9b95e49115cd9baad5ddbd1a6207aaa82a9", size = 160535, upload-time = "2025-10-14T04:41:06.938Z" }, + { url = "https://files.pythonhosted.org/packages/2a/e5/6a4ce77ed243c4a50a1fecca6aaaab419628c818a49434be428fe24c9957/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:780236ac706e66881f3b7f2f32dfe90507a09e67d1d454c762cf642e6e1586e0", size = 154816, upload-time = "2025-10-14T04:41:08.101Z" }, + { url = "https://files.pythonhosted.org/packages/a8/ef/89297262b8092b312d29cdb2517cb1237e51db8ecef2e9af5edbe7b683b1/charset_normalizer-3.4.4-cp312-cp312-win32.whl", hash = "sha256:5833d2c39d8896e4e19b689ffc198f08ea58116bee26dea51e362ecc7cd3ed26", size = 99694, upload-time = "2025-10-14T04:41:09.23Z" }, + { url = "https://files.pythonhosted.org/packages/3d/2d/1e5ed9dd3b3803994c155cd9aacb60c82c331bad84daf75bcb9c91b3295e/charset_normalizer-3.4.4-cp312-cp312-win_amd64.whl", hash = "sha256:a79cfe37875f822425b89a82333404539ae63dbdddf97f84dcbc3d339aae9525", size = 107131, upload-time = "2025-10-14T04:41:10.467Z" }, + { url = "https://files.pythonhosted.org/packages/d0/d9/0ed4c7098a861482a7b6a95603edce4c0d9db2311af23da1fb2b75ec26fc/charset_normalizer-3.4.4-cp312-cp312-win_arm64.whl", hash = "sha256:376bec83a63b8021bb5c8ea75e21c4ccb86e7e45ca4eb81146091b56599b80c3", size = 100390, upload-time = "2025-10-14T04:41:11.915Z" }, + { url = 
"https://files.pythonhosted.org/packages/97/45/4b3a1239bbacd321068ea6e7ac28875b03ab8bc0aa0966452db17cd36714/charset_normalizer-3.4.4-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:e1f185f86a6f3403aa2420e815904c67b2f9ebc443f045edd0de921108345794", size = 208091, upload-time = "2025-10-14T04:41:13.346Z" }, + { url = "https://files.pythonhosted.org/packages/7d/62/73a6d7450829655a35bb88a88fca7d736f9882a27eacdca2c6d505b57e2e/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6b39f987ae8ccdf0d2642338faf2abb1862340facc796048b604ef14919e55ed", size = 147936, upload-time = "2025-10-14T04:41:14.461Z" }, + { url = "https://files.pythonhosted.org/packages/89/c5/adb8c8b3d6625bef6d88b251bbb0d95f8205831b987631ab0c8bb5d937c2/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:3162d5d8ce1bb98dd51af660f2121c55d0fa541b46dff7bb9b9f86ea1d87de72", size = 144180, upload-time = "2025-10-14T04:41:15.588Z" }, + { url = "https://files.pythonhosted.org/packages/91/ed/9706e4070682d1cc219050b6048bfd293ccf67b3d4f5a4f39207453d4b99/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:81d5eb2a312700f4ecaa977a8235b634ce853200e828fbadf3a9c50bab278328", size = 161346, upload-time = "2025-10-14T04:41:16.738Z" }, + { url = "https://files.pythonhosted.org/packages/d5/0d/031f0d95e4972901a2f6f09ef055751805ff541511dc1252ba3ca1f80cf5/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:5bd2293095d766545ec1a8f612559f6b40abc0eb18bb2f5d1171872d34036ede", size = 158874, upload-time = "2025-10-14T04:41:17.923Z" }, + { url = 
"https://files.pythonhosted.org/packages/f5/83/6ab5883f57c9c801ce5e5677242328aa45592be8a00644310a008d04f922/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a8a8b89589086a25749f471e6a900d3f662d1d3b6e2e59dcecf787b1cc3a1894", size = 153076, upload-time = "2025-10-14T04:41:19.106Z" }, + { url = "https://files.pythonhosted.org/packages/75/1e/5ff781ddf5260e387d6419959ee89ef13878229732732ee73cdae01800f2/charset_normalizer-3.4.4-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:bc7637e2f80d8530ee4a78e878bce464f70087ce73cf7c1caf142416923b98f1", size = 150601, upload-time = "2025-10-14T04:41:20.245Z" }, + { url = "https://files.pythonhosted.org/packages/d7/57/71be810965493d3510a6ca79b90c19e48696fb1ff964da319334b12677f0/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f8bf04158c6b607d747e93949aa60618b61312fe647a6369f88ce2ff16043490", size = 150376, upload-time = "2025-10-14T04:41:21.398Z" }, + { url = "https://files.pythonhosted.org/packages/e5/d5/c3d057a78c181d007014feb7e9f2e65905a6c4ef182c0ddf0de2924edd65/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:554af85e960429cf30784dd47447d5125aaa3b99a6f0683589dbd27e2f45da44", size = 144825, upload-time = "2025-10-14T04:41:22.583Z" }, + { url = "https://files.pythonhosted.org/packages/e6/8c/d0406294828d4976f275ffbe66f00266c4b3136b7506941d87c00cab5272/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:74018750915ee7ad843a774364e13a3db91682f26142baddf775342c3f5b1133", size = 162583, upload-time = "2025-10-14T04:41:23.754Z" }, + { url = "https://files.pythonhosted.org/packages/d7/24/e2aa1f18c8f15c4c0e932d9287b8609dd30ad56dbe41d926bd846e22fb8d/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:c0463276121fdee9c49b98908b3a89c39be45d86d1dbaa22957e38f6321d4ce3", size = 150366, upload-time = "2025-10-14T04:41:25.27Z" }, 
+ { url = "https://files.pythonhosted.org/packages/e4/5b/1e6160c7739aad1e2df054300cc618b06bf784a7a164b0f238360721ab86/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:362d61fd13843997c1c446760ef36f240cf81d3ebf74ac62652aebaf7838561e", size = 160300, upload-time = "2025-10-14T04:41:26.725Z" }, + { url = "https://files.pythonhosted.org/packages/7a/10/f882167cd207fbdd743e55534d5d9620e095089d176d55cb22d5322f2afd/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9a26f18905b8dd5d685d6d07b0cdf98a79f3c7a918906af7cc143ea2e164c8bc", size = 154465, upload-time = "2025-10-14T04:41:28.322Z" }, + { url = "https://files.pythonhosted.org/packages/89/66/c7a9e1b7429be72123441bfdbaf2bc13faab3f90b933f664db506dea5915/charset_normalizer-3.4.4-cp313-cp313-win32.whl", hash = "sha256:9b35f4c90079ff2e2edc5b26c0c77925e5d2d255c42c74fdb70fb49b172726ac", size = 99404, upload-time = "2025-10-14T04:41:29.95Z" }, + { url = "https://files.pythonhosted.org/packages/c4/26/b9924fa27db384bdcd97ab83b4f0a8058d96ad9626ead570674d5e737d90/charset_normalizer-3.4.4-cp313-cp313-win_amd64.whl", hash = "sha256:b435cba5f4f750aa6c0a0d92c541fb79f69a387c91e61f1795227e4ed9cece14", size = 107092, upload-time = "2025-10-14T04:41:31.188Z" }, + { url = "https://files.pythonhosted.org/packages/af/8f/3ed4bfa0c0c72a7ca17f0380cd9e4dd842b09f664e780c13cff1dcf2ef1b/charset_normalizer-3.4.4-cp313-cp313-win_arm64.whl", hash = "sha256:542d2cee80be6f80247095cc36c418f7bddd14f4a6de45af91dfad36d817bba2", size = 100408, upload-time = "2025-10-14T04:41:32.624Z" }, + { url = "https://files.pythonhosted.org/packages/2a/35/7051599bd493e62411d6ede36fd5af83a38f37c4767b92884df7301db25d/charset_normalizer-3.4.4-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:da3326d9e65ef63a817ecbcc0df6e94463713b754fe293eaa03da99befb9a5bd", size = 207746, upload-time = "2025-10-14T04:41:33.773Z" }, + { url = 
"https://files.pythonhosted.org/packages/10/9a/97c8d48ef10d6cd4fcead2415523221624bf58bcf68a802721a6bc807c8f/charset_normalizer-3.4.4-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8af65f14dc14a79b924524b1e7fffe304517b2bff5a58bf64f30b98bbc5079eb", size = 147889, upload-time = "2025-10-14T04:41:34.897Z" }, + { url = "https://files.pythonhosted.org/packages/10/bf/979224a919a1b606c82bd2c5fa49b5c6d5727aa47b4312bb27b1734f53cd/charset_normalizer-3.4.4-cp314-cp314-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:74664978bb272435107de04e36db5a9735e78232b85b77d45cfb38f758efd33e", size = 143641, upload-time = "2025-10-14T04:41:36.116Z" }, + { url = "https://files.pythonhosted.org/packages/ba/33/0ad65587441fc730dc7bd90e9716b30b4702dc7b617e6ba4997dc8651495/charset_normalizer-3.4.4-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:752944c7ffbfdd10c074dc58ec2d5a8a4cd9493b314d367c14d24c17684ddd14", size = 160779, upload-time = "2025-10-14T04:41:37.229Z" }, + { url = "https://files.pythonhosted.org/packages/67/ed/331d6b249259ee71ddea93f6f2f0a56cfebd46938bde6fcc6f7b9a3d0e09/charset_normalizer-3.4.4-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d1f13550535ad8cff21b8d757a3257963e951d96e20ec82ab44bc64aeb62a191", size = 159035, upload-time = "2025-10-14T04:41:38.368Z" }, + { url = "https://files.pythonhosted.org/packages/67/ff/f6b948ca32e4f2a4576aa129d8bed61f2e0543bf9f5f2b7fc3758ed005c9/charset_normalizer-3.4.4-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ecaae4149d99b1c9e7b88bb03e3221956f68fd6d50be2ef061b2381b61d20838", size = 152542, upload-time = "2025-10-14T04:41:39.862Z" }, + { url = 
"https://files.pythonhosted.org/packages/16/85/276033dcbcc369eb176594de22728541a925b2632f9716428c851b149e83/charset_normalizer-3.4.4-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:cb6254dc36b47a990e59e1068afacdcd02958bdcce30bb50cc1700a8b9d624a6", size = 149524, upload-time = "2025-10-14T04:41:41.319Z" }, + { url = "https://files.pythonhosted.org/packages/9e/f2/6a2a1f722b6aba37050e626530a46a68f74e63683947a8acff92569f979a/charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:c8ae8a0f02f57a6e61203a31428fa1d677cbe50c93622b4149d5c0f319c1d19e", size = 150395, upload-time = "2025-10-14T04:41:42.539Z" }, + { url = "https://files.pythonhosted.org/packages/60/bb/2186cb2f2bbaea6338cad15ce23a67f9b0672929744381e28b0592676824/charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:47cc91b2f4dd2833fddaedd2893006b0106129d4b94fdb6af1f4ce5a9965577c", size = 143680, upload-time = "2025-10-14T04:41:43.661Z" }, + { url = "https://files.pythonhosted.org/packages/7d/a5/bf6f13b772fbb2a90360eb620d52ed8f796f3c5caee8398c3b2eb7b1c60d/charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:82004af6c302b5d3ab2cfc4cc5f29db16123b1a8417f2e25f9066f91d4411090", size = 162045, upload-time = "2025-10-14T04:41:44.821Z" }, + { url = "https://files.pythonhosted.org/packages/df/c5/d1be898bf0dc3ef9030c3825e5d3b83f2c528d207d246cbabe245966808d/charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:2b7d8f6c26245217bd2ad053761201e9f9680f8ce52f0fcd8d0755aeae5b2152", size = 149687, upload-time = "2025-10-14T04:41:46.442Z" }, + { url = "https://files.pythonhosted.org/packages/a5/42/90c1f7b9341eef50c8a1cb3f098ac43b0508413f33affd762855f67a410e/charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:799a7a5e4fb2d5898c60b640fd4981d6a25f1c11790935a44ce38c54e985f828", size = 160014, upload-time = "2025-10-14T04:41:47.631Z" }, + { url = 
"https://files.pythonhosted.org/packages/76/be/4d3ee471e8145d12795ab655ece37baed0929462a86e72372fd25859047c/charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:99ae2cffebb06e6c22bdc25801d7b30f503cc87dbd283479e7b606f70aff57ec", size = 154044, upload-time = "2025-10-14T04:41:48.81Z" }, + { url = "https://files.pythonhosted.org/packages/b0/6f/8f7af07237c34a1defe7defc565a9bc1807762f672c0fde711a4b22bf9c0/charset_normalizer-3.4.4-cp314-cp314-win32.whl", hash = "sha256:f9d332f8c2a2fcbffe1378594431458ddbef721c1769d78e2cbc06280d8155f9", size = 99940, upload-time = "2025-10-14T04:41:49.946Z" }, + { url = "https://files.pythonhosted.org/packages/4b/51/8ade005e5ca5b0d80fb4aff72a3775b325bdc3d27408c8113811a7cbe640/charset_normalizer-3.4.4-cp314-cp314-win_amd64.whl", hash = "sha256:8a6562c3700cce886c5be75ade4a5db4214fda19fede41d9792d100288d8f94c", size = 107104, upload-time = "2025-10-14T04:41:51.051Z" }, + { url = "https://files.pythonhosted.org/packages/da/5f/6b8f83a55bb8278772c5ae54a577f3099025f9ade59d0136ac24a0df4bde/charset_normalizer-3.4.4-cp314-cp314-win_arm64.whl", hash = "sha256:de00632ca48df9daf77a2c65a484531649261ec9f25489917f09e455cb09ddb2", size = 100743, upload-time = "2025-10-14T04:41:52.122Z" }, + { url = "https://files.pythonhosted.org/packages/0a/4c/925909008ed5a988ccbb72dcc897407e5d6d3bd72410d69e051fc0c14647/charset_normalizer-3.4.4-py3-none-any.whl", hash = "sha256:7a32c560861a02ff789ad905a2fe94e3f840803362c84fecf1851cb4cf3dc37f", size = 53402, upload-time = "2025-10-14T04:42:31.76Z" }, +] + +[[package]] +name = "click" +version = "8.3.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3d/fa/656b739db8587d7b5dfa22e22ed02566950fbfbcdc20311993483657a5c0/click-8.3.1.tar.gz", hash = "sha256:12ff4785d337a1bb490bb7e9c2b1ee5da3112e94a8622f26a6c77f5d2fc6842a", size = 295065, upload-time = 
"2025-11-15T20:45:42.706Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/98/78/01c019cdb5d6498122777c1a43056ebb3ebfeef2076d9d026bfe15583b2b/click-8.3.1-py3-none-any.whl", hash = "sha256:981153a64e25f12d547d3426c367a4857371575ee7ad18df2a6183ab0545b2a6", size = 108274, upload-time = "2025-11-15T20:45:41.139Z" }, +] + +[[package]] +name = "colorama" +version = "0.4.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, +] + +[[package]] +name = "cryptography" +version = "46.0.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cffi", marker = "platform_python_implementation != 'PyPy'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/60/04/ee2a9e8542e4fa2773b81771ff8349ff19cdd56b7258a0cc442639052edb/cryptography-46.0.5.tar.gz", hash = "sha256:abace499247268e3757271b2f1e244b36b06f8515cf27c4d49468fc9eb16e93d", size = 750064, upload-time = "2026-02-10T19:18:38.255Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f7/81/b0bb27f2ba931a65409c6b8a8b358a7f03c0e46eceacddff55f7c84b1f3b/cryptography-46.0.5-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:351695ada9ea9618b3500b490ad54c739860883df6c1f555e088eaf25b1bbaad", size = 7176289, upload-time = "2026-02-10T19:17:08.274Z" }, + { url = 
"https://files.pythonhosted.org/packages/ff/9e/6b4397a3e3d15123de3b1806ef342522393d50736c13b20ec4c9ea6693a6/cryptography-46.0.5-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:c18ff11e86df2e28854939acde2d003f7984f721eba450b56a200ad90eeb0e6b", size = 4275637, upload-time = "2026-02-10T19:17:10.53Z" }, + { url = "https://files.pythonhosted.org/packages/63/e7/471ab61099a3920b0c77852ea3f0ea611c9702f651600397ac567848b897/cryptography-46.0.5-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d7e3d356b8cd4ea5aff04f129d5f66ebdc7b6f8eae802b93739ed520c47c79b", size = 4424742, upload-time = "2026-02-10T19:17:12.388Z" }, + { url = "https://files.pythonhosted.org/packages/37/53/a18500f270342d66bf7e4d9f091114e31e5ee9e7375a5aba2e85a91e0044/cryptography-46.0.5-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:50bfb6925eff619c9c023b967d5b77a54e04256c4281b0e21336a130cd7fc263", size = 4277528, upload-time = "2026-02-10T19:17:13.853Z" }, + { url = "https://files.pythonhosted.org/packages/22/29/c2e812ebc38c57b40e7c583895e73c8c5adb4d1e4a0cc4c5a4fdab2b1acc/cryptography-46.0.5-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:803812e111e75d1aa73690d2facc295eaefd4439be1023fefc4995eaea2af90d", size = 4947993, upload-time = "2026-02-10T19:17:15.618Z" }, + { url = "https://files.pythonhosted.org/packages/6b/e7/237155ae19a9023de7e30ec64e5d99a9431a567407ac21170a046d22a5a3/cryptography-46.0.5-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3ee190460e2fbe447175cda91b88b84ae8322a104fc27766ad09428754a618ed", size = 4456855, upload-time = "2026-02-10T19:17:17.221Z" }, + { url = "https://files.pythonhosted.org/packages/2d/87/fc628a7ad85b81206738abbd213b07702bcbdada1dd43f72236ef3cffbb5/cryptography-46.0.5-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:f145bba11b878005c496e93e257c1e88f154d278d2638e6450d17e0f31e558d2", size = 3984635, upload-time = "2026-02-10T19:17:18.792Z" }, + { url = 
"https://files.pythonhosted.org/packages/84/29/65b55622bde135aedf4565dc509d99b560ee4095e56989e815f8fd2aa910/cryptography-46.0.5-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:e9251e3be159d1020c4030bd2e5f84d6a43fe54b6c19c12f51cde9542a2817b2", size = 4277038, upload-time = "2026-02-10T19:17:20.256Z" }, + { url = "https://files.pythonhosted.org/packages/bc/36/45e76c68d7311432741faf1fbf7fac8a196a0a735ca21f504c75d37e2558/cryptography-46.0.5-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:47fb8a66058b80e509c47118ef8a75d14c455e81ac369050f20ba0d23e77fee0", size = 4912181, upload-time = "2026-02-10T19:17:21.825Z" }, + { url = "https://files.pythonhosted.org/packages/6d/1a/c1ba8fead184d6e3d5afcf03d569acac5ad063f3ac9fb7258af158f7e378/cryptography-46.0.5-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:4c3341037c136030cb46e4b1e17b7418ea4cbd9dd207e4a6f3b2b24e0d4ac731", size = 4456482, upload-time = "2026-02-10T19:17:25.133Z" }, + { url = "https://files.pythonhosted.org/packages/f9/e5/3fb22e37f66827ced3b902cf895e6a6bc1d095b5b26be26bd13c441fdf19/cryptography-46.0.5-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:890bcb4abd5a2d3f852196437129eb3667d62630333aacc13dfd470fad3aaa82", size = 4405497, upload-time = "2026-02-10T19:17:26.66Z" }, + { url = "https://files.pythonhosted.org/packages/1a/df/9d58bb32b1121a8a2f27383fabae4d63080c7ca60b9b5c88be742be04ee7/cryptography-46.0.5-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:80a8d7bfdf38f87ca30a5391c0c9ce4ed2926918e017c29ddf643d0ed2778ea1", size = 4667819, upload-time = "2026-02-10T19:17:28.569Z" }, + { url = "https://files.pythonhosted.org/packages/ea/ed/325d2a490c5e94038cdb0117da9397ece1f11201f425c4e9c57fe5b9f08b/cryptography-46.0.5-cp311-abi3-win32.whl", hash = "sha256:60ee7e19e95104d4c03871d7d7dfb3d22ef8a9b9c6778c94e1c8fcc8365afd48", size = 3028230, upload-time = "2026-02-10T19:17:30.518Z" }, + { url = 
"https://files.pythonhosted.org/packages/e9/5a/ac0f49e48063ab4255d9e3b79f5def51697fce1a95ea1370f03dc9db76f6/cryptography-46.0.5-cp311-abi3-win_amd64.whl", hash = "sha256:38946c54b16c885c72c4f59846be9743d699eee2b69b6988e0a00a01f46a61a4", size = 3480909, upload-time = "2026-02-10T19:17:32.083Z" }, + { url = "https://files.pythonhosted.org/packages/00/13/3d278bfa7a15a96b9dc22db5a12ad1e48a9eb3d40e1827ef66a5df75d0d0/cryptography-46.0.5-cp314-cp314t-macosx_10_9_universal2.whl", hash = "sha256:94a76daa32eb78d61339aff7952ea819b1734b46f73646a07decb40e5b3448e2", size = 7119287, upload-time = "2026-02-10T19:17:33.801Z" }, + { url = "https://files.pythonhosted.org/packages/67/c8/581a6702e14f0898a0848105cbefd20c058099e2c2d22ef4e476dfec75d7/cryptography-46.0.5-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5be7bf2fb40769e05739dd0046e7b26f9d4670badc7b032d6ce4db64dddc0678", size = 4265728, upload-time = "2026-02-10T19:17:35.569Z" }, + { url = "https://files.pythonhosted.org/packages/dd/4a/ba1a65ce8fc65435e5a849558379896c957870dd64fecea97b1ad5f46a37/cryptography-46.0.5-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fe346b143ff9685e40192a4960938545c699054ba11d4f9029f94751e3f71d87", size = 4408287, upload-time = "2026-02-10T19:17:36.938Z" }, + { url = "https://files.pythonhosted.org/packages/f8/67/8ffdbf7b65ed1ac224d1c2df3943553766914a8ca718747ee3871da6107e/cryptography-46.0.5-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:c69fd885df7d089548a42d5ec05be26050ebcd2283d89b3d30676eb32ff87dee", size = 4270291, upload-time = "2026-02-10T19:17:38.748Z" }, + { url = "https://files.pythonhosted.org/packages/f8/e5/f52377ee93bc2f2bba55a41a886fd208c15276ffbd2569f2ddc89d50e2c5/cryptography-46.0.5-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:8293f3dea7fc929ef7240796ba231413afa7b68ce38fd21da2995549f5961981", size = 4927539, upload-time = "2026-02-10T19:17:40.241Z" }, + { url = 
"https://files.pythonhosted.org/packages/3b/02/cfe39181b02419bbbbcf3abdd16c1c5c8541f03ca8bda240debc467d5a12/cryptography-46.0.5-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:1abfdb89b41c3be0365328a410baa9df3ff8a9110fb75e7b52e66803ddabc9a9", size = 4442199, upload-time = "2026-02-10T19:17:41.789Z" }, + { url = "https://files.pythonhosted.org/packages/c0/96/2fcaeb4873e536cf71421a388a6c11b5bc846e986b2b069c79363dc1648e/cryptography-46.0.5-cp314-cp314t-manylinux_2_31_armv7l.whl", hash = "sha256:d66e421495fdb797610a08f43b05269e0a5ea7f5e652a89bfd5a7d3c1dee3648", size = 3960131, upload-time = "2026-02-10T19:17:43.379Z" }, + { url = "https://files.pythonhosted.org/packages/d8/d2/b27631f401ddd644e94c5cf33c9a4069f72011821cf3dc7309546b0642a0/cryptography-46.0.5-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:4e817a8920bfbcff8940ecfd60f23d01836408242b30f1a708d93198393a80b4", size = 4270072, upload-time = "2026-02-10T19:17:45.481Z" }, + { url = "https://files.pythonhosted.org/packages/f4/a7/60d32b0370dae0b4ebe55ffa10e8599a2a59935b5ece1b9f06edb73abdeb/cryptography-46.0.5-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:68f68d13f2e1cb95163fa3b4db4bf9a159a418f5f6e7242564fc75fcae667fd0", size = 4892170, upload-time = "2026-02-10T19:17:46.997Z" }, + { url = "https://files.pythonhosted.org/packages/d2/b9/cf73ddf8ef1164330eb0b199a589103c363afa0cf794218c24d524a58eab/cryptography-46.0.5-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:a3d1fae9863299076f05cb8a778c467578262fae09f9dc0ee9b12eb4268ce663", size = 4441741, upload-time = "2026-02-10T19:17:48.661Z" }, + { url = "https://files.pythonhosted.org/packages/5f/eb/eee00b28c84c726fe8fa0158c65afe312d9c3b78d9d01daf700f1f6e37ff/cryptography-46.0.5-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:c4143987a42a2397f2fc3b4d7e3a7d313fbe684f67ff443999e803dd75a76826", size = 4396728, upload-time = "2026-02-10T19:17:50.058Z" }, + { url = 
"https://files.pythonhosted.org/packages/65/f4/6bc1a9ed5aef7145045114b75b77c2a8261b4d38717bd8dea111a63c3442/cryptography-46.0.5-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:7d731d4b107030987fd61a7f8ab512b25b53cef8f233a97379ede116f30eb67d", size = 4652001, upload-time = "2026-02-10T19:17:51.54Z" }, + { url = "https://files.pythonhosted.org/packages/86/ef/5d00ef966ddd71ac2e6951d278884a84a40ffbd88948ef0e294b214ae9e4/cryptography-46.0.5-cp314-cp314t-win32.whl", hash = "sha256:c3bcce8521d785d510b2aad26ae2c966092b7daa8f45dd8f44734a104dc0bc1a", size = 3003637, upload-time = "2026-02-10T19:17:52.997Z" }, + { url = "https://files.pythonhosted.org/packages/b7/57/f3f4160123da6d098db78350fdfd9705057aad21de7388eacb2401dceab9/cryptography-46.0.5-cp314-cp314t-win_amd64.whl", hash = "sha256:4d8ae8659ab18c65ced284993c2265910f6c9e650189d4e3f68445ef82a810e4", size = 3469487, upload-time = "2026-02-10T19:17:54.549Z" }, + { url = "https://files.pythonhosted.org/packages/e2/fa/a66aa722105ad6a458bebd64086ca2b72cdd361fed31763d20390f6f1389/cryptography-46.0.5-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:4108d4c09fbbf2789d0c926eb4152ae1760d5a2d97612b92d508d96c861e4d31", size = 7170514, upload-time = "2026-02-10T19:17:56.267Z" }, + { url = "https://files.pythonhosted.org/packages/0f/04/c85bdeab78c8bc77b701bf0d9bdcf514c044e18a46dcff330df5448631b0/cryptography-46.0.5-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7d1f30a86d2757199cb2d56e48cce14deddf1f9c95f1ef1b64ee91ea43fe2e18", size = 4275349, upload-time = "2026-02-10T19:17:58.419Z" }, + { url = "https://files.pythonhosted.org/packages/5c/32/9b87132a2f91ee7f5223b091dc963055503e9b442c98fc0b8a5ca765fab0/cryptography-46.0.5-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:039917b0dc418bb9f6edce8a906572d69e74bd330b0b3fea4f79dab7f8ddd235", size = 4420667, upload-time = "2026-02-10T19:18:00.619Z" }, + { url = 
"https://files.pythonhosted.org/packages/a1/a6/a7cb7010bec4b7c5692ca6f024150371b295ee1c108bdc1c400e4c44562b/cryptography-46.0.5-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:ba2a27ff02f48193fc4daeadf8ad2590516fa3d0adeeb34336b96f7fa64c1e3a", size = 4276980, upload-time = "2026-02-10T19:18:02.379Z" }, + { url = "https://files.pythonhosted.org/packages/8e/7c/c4f45e0eeff9b91e3f12dbd0e165fcf2a38847288fcfd889deea99fb7b6d/cryptography-46.0.5-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:61aa400dce22cb001a98014f647dc21cda08f7915ceb95df0c9eaf84b4b6af76", size = 4939143, upload-time = "2026-02-10T19:18:03.964Z" }, + { url = "https://files.pythonhosted.org/packages/37/19/e1b8f964a834eddb44fa1b9a9976f4e414cbb7aa62809b6760c8803d22d1/cryptography-46.0.5-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3ce58ba46e1bc2aac4f7d9290223cead56743fa6ab94a5d53292ffaac6a91614", size = 4453674, upload-time = "2026-02-10T19:18:05.588Z" }, + { url = "https://files.pythonhosted.org/packages/db/ed/db15d3956f65264ca204625597c410d420e26530c4e2943e05a0d2f24d51/cryptography-46.0.5-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:420d0e909050490d04359e7fdb5ed7e667ca5c3c402b809ae2563d7e66a92229", size = 3978801, upload-time = "2026-02-10T19:18:07.167Z" }, + { url = "https://files.pythonhosted.org/packages/41/e2/df40a31d82df0a70a0daf69791f91dbb70e47644c58581d654879b382d11/cryptography-46.0.5-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:582f5fcd2afa31622f317f80426a027f30dc792e9c80ffee87b993200ea115f1", size = 4276755, upload-time = "2026-02-10T19:18:09.813Z" }, + { url = "https://files.pythonhosted.org/packages/33/45/726809d1176959f4a896b86907b98ff4391a8aa29c0aaaf9450a8a10630e/cryptography-46.0.5-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:bfd56bb4b37ed4f330b82402f6f435845a5f5648edf1ad497da51a8452d5d62d", size = 4901539, upload-time = "2026-02-10T19:18:11.263Z" }, + { url = 
"https://files.pythonhosted.org/packages/99/0f/a3076874e9c88ecb2ecc31382f6e7c21b428ede6f55aafa1aa272613e3cd/cryptography-46.0.5-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:a3d507bb6a513ca96ba84443226af944b0f7f47dcc9a399d110cd6146481d24c", size = 4452794, upload-time = "2026-02-10T19:18:12.914Z" }, + { url = "https://files.pythonhosted.org/packages/02/ef/ffeb542d3683d24194a38f66ca17c0a4b8bf10631feef44a7ef64e631b1a/cryptography-46.0.5-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9f16fbdf4da055efb21c22d81b89f155f02ba420558db21288b3d0035bafd5f4", size = 4404160, upload-time = "2026-02-10T19:18:14.375Z" }, + { url = "https://files.pythonhosted.org/packages/96/93/682d2b43c1d5f1406ed048f377c0fc9fc8f7b0447a478d5c65ab3d3a66eb/cryptography-46.0.5-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:ced80795227d70549a411a4ab66e8ce307899fad2220ce5ab2f296e687eacde9", size = 4667123, upload-time = "2026-02-10T19:18:15.886Z" }, + { url = "https://files.pythonhosted.org/packages/45/2d/9c5f2926cb5300a8eefc3f4f0b3f3df39db7f7ce40c8365444c49363cbda/cryptography-46.0.5-cp38-abi3-win32.whl", hash = "sha256:02f547fce831f5096c9a567fd41bc12ca8f11df260959ecc7c3202555cc47a72", size = 3010220, upload-time = "2026-02-10T19:18:17.361Z" }, + { url = "https://files.pythonhosted.org/packages/48/ef/0c2f4a8e31018a986949d34a01115dd057bf536905dca38897bacd21fac3/cryptography-46.0.5-cp38-abi3-win_amd64.whl", hash = "sha256:556e106ee01aa13484ce9b0239bca667be5004efb0aabbed28d353df86445595", size = 3467050, upload-time = "2026-02-10T19:18:18.899Z" }, + { url = "https://files.pythonhosted.org/packages/eb/dd/2d9fdb07cebdf3d51179730afb7d5e576153c6744c3ff8fded23030c204e/cryptography-46.0.5-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:3b4995dc971c9fb83c25aa44cf45f02ba86f71ee600d81091c2f0cbae116b06c", size = 3476964, upload-time = "2026-02-10T19:18:20.687Z" }, + { url = 
"https://files.pythonhosted.org/packages/e9/6f/6cc6cc9955caa6eaf83660b0da2b077c7fe8ff9950a3c5e45d605038d439/cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:bc84e875994c3b445871ea7181d424588171efec3e185dced958dad9e001950a", size = 4218321, upload-time = "2026-02-10T19:18:22.349Z" }, + { url = "https://files.pythonhosted.org/packages/3e/5d/c4da701939eeee699566a6c1367427ab91a8b7088cc2328c09dbee940415/cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:2ae6971afd6246710480e3f15824ed3029a60fc16991db250034efd0b9fb4356", size = 4381786, upload-time = "2026-02-10T19:18:24.529Z" }, + { url = "https://files.pythonhosted.org/packages/ac/97/a538654732974a94ff96c1db621fa464f455c02d4bb7d2652f4edc21d600/cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:d861ee9e76ace6cf36a6a89b959ec08e7bc2493ee39d07ffe5acb23ef46d27da", size = 4217990, upload-time = "2026-02-10T19:18:25.957Z" }, + { url = "https://files.pythonhosted.org/packages/ae/11/7e500d2dd3ba891197b9efd2da5454b74336d64a7cc419aa7327ab74e5f6/cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:2b7a67c9cd56372f3249b39699f2ad479f6991e62ea15800973b956f4b73e257", size = 4381252, upload-time = "2026-02-10T19:18:27.496Z" }, + { url = "https://files.pythonhosted.org/packages/bc/58/6b3d24e6b9bc474a2dcdee65dfd1f008867015408a271562e4b690561a4d/cryptography-46.0.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:8456928655f856c6e1533ff59d5be76578a7157224dbd9ce6872f25055ab9ab7", size = 3407605, upload-time = "2026-02-10T19:18:29.233Z" }, +] + +[[package]] +name = "datafusion-benchmarks" +version = "0.1.0" +source = { virtual = "benchmarks" } +dependencies = [ + { name = "falsa" }, + { name = "rich" }, + { name = "typing-extensions" }, +] + +[package.metadata] +requires-dist = [ + { name = "falsa" }, + { name = "rich" }, + { name = "typing-extensions" }, +] + +[[package]] +name = "datafusion-dev" +version = "0.1.0" 
+source = { virtual = "dev" } +dependencies = [ + { name = "pygithub" }, + { name = "requests" }, + { name = "tomlkit" }, +] + +[package.metadata] +requires-dist = [ + { name = "pygithub" }, + { name = "requests" }, + { name = "tomlkit" }, +] + +[[package]] +name = "datafusion-docs" +version = "0.1.0" +source = { virtual = "docs" } +dependencies = [ + { name = "jinja2" }, + { name = "maturin" }, + { name = "myst-parser" }, + { name = "pydata-sphinx-theme" }, + { name = "setuptools" }, + { name = "sphinx", version = "9.0.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.12'" }, + { name = "sphinx", version = "9.1.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, + { name = "sphinx-reredirects" }, +] + +[package.metadata] +requires-dist = [ + { name = "jinja2", specifier = ">=3.1,<4" }, + { name = "maturin", specifier = ">=1.11,<2" }, + { name = "myst-parser", specifier = ">=5,<6" }, + { name = "pydata-sphinx-theme", specifier = ">=0.16,<1" }, + { name = "setuptools", specifier = ">=82,<83" }, + { name = "sphinx", specifier = ">=9,<10" }, + { name = "sphinx-reredirects", specifier = ">=1.1,<2" }, +] + +[[package]] +name = "docutils" +version = "0.22.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ae/b6/03bb70946330e88ffec97aefd3ea75ba575cb2e762061e0e62a213befee8/docutils-0.22.4.tar.gz", hash = "sha256:4db53b1fde9abecbb74d91230d32ab626d94f6badfc575d6db9194a49df29968", size = 2291750, upload-time = "2025-12-18T19:00:26.443Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/02/10/5da547df7a391dcde17f59520a231527b8571e6f46fc8efb02ccb370ab12/docutils-0.22.4-py3-none-any.whl", hash = "sha256:d0013f540772d1420576855455d050a2180186c91c15779301ac2ccb3eeb68de", size = 633196, upload-time = "2025-12-18T19:00:18.077Z" }, +] + +[[package]] +name = "falsa" +version = "0.0.6" +source = { registry = 
"https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "pyarrow" }, + { name = "typer" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/36/65/0f51f3509cfe4f8cc5b9b1a7ba614a5c0ca0b7ada7a2f8de4275ddc5d979/falsa-0.0.6.tar.gz", hash = "sha256:1b037941886755a73a77f3c80ecb661ee4732085bd68947c0ec788f77b487b32", size = 524238, upload-time = "2025-09-20T07:35:15.162Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/46/61/9fb4f242b37ecf4b706703cdc1c8ca0e8333edab42172340d27680c19c86/falsa-0.0.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:048d6b23fe7d2457761a406c667110904634685bac4816732455ee0c4f38ad0b", size = 437619, upload-time = "2025-09-20T07:33:31.806Z" }, + { url = "https://files.pythonhosted.org/packages/4d/cd/efb9c57f94d339a9dc7cf3ae555fa7dabcdf9c4c5d18bd1cf464b93e5457/falsa-0.0.6-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:85d96e0a0c481f50023ff5aa18b4dd663cdad7b778d2f98ca7d21e3fa132eef3", size = 435477, upload-time = "2025-09-20T07:33:43.118Z" }, + { url = "https://files.pythonhosted.org/packages/17/85/814e049f046f25611be25352959be8a9a711ef384b46cba7c0797fe03882/falsa-0.0.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e44ecdff3361e4ecbfc67b84dc0ed04e3f73d37b20ebfb435c8d1ebca7b85bb9", size = 652226, upload-time = "2025-09-20T07:33:54.515Z" }, + { url = "https://files.pythonhosted.org/packages/ee/a3/0a064fedccc3462ea413c87d15b35da854878b300d432bd79a3404b4de36/falsa-0.0.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5dc08fbb6833ead8bf63106837615236e259dd05fc4d1dd4b1b91b949ba632e2", size = 476290, upload-time = "2025-09-20T07:34:05.171Z" }, + { url = "https://files.pythonhosted.org/packages/46/38/d7f9182a505439d893c9741acf12a9daa04ea2ae9c9afff01a65fc5619ef/falsa-0.0.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:b57b6ef70842776c5698498d04c1c38602b255083ee6822fe6d8a67aa32b3260", size = 598436, upload-time = "2025-09-20T07:34:26.207Z" }, + { url = "https://files.pythonhosted.org/packages/61/03/6199cc9011e8e708bef3e0420009b4e93be517f642184ee1f564b33b16d5/falsa-0.0.6-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9305aabafdf1be131b157d97ba7e105da115eef0e02af73f4716bcae64a18041", size = 461327, upload-time = "2025-09-20T07:34:16.337Z" }, + { url = "https://files.pythonhosted.org/packages/85/58/8d72300acf63c671f4ed8fcf6d74312581e6ad72d530676ec4a8c30e2b06/falsa-0.0.6-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:a0ffaf1c24296b16320b11116420d221b4678f1c4942ecf88599b33b094e78c7", size = 616922, upload-time = "2025-09-20T07:34:34.73Z" }, + { url = "https://files.pythonhosted.org/packages/31/09/da0a47ef5f56d3b9466f24b0451d6f326c6637da383b3b95b07ccd7be7c3/falsa-0.0.6-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:828f151c6737ed4d9051edbf695738e4d758815c316b58fa18166e0ab3d1fea7", size = 699657, upload-time = "2025-09-20T07:34:45.774Z" }, + { url = "https://files.pythonhosted.org/packages/4b/98/bc733bc0d88fb975577b530dca848cfcfbae20010af1884822d18fed634e/falsa-0.0.6-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:938f5170282f699638e0c7a941cc80235bd5ca8a8c5a19b65615aa0dc6fbf3f8", size = 632823, upload-time = "2025-09-20T07:34:56.436Z" }, + { url = "https://files.pythonhosted.org/packages/42/8e/eb5a164f44dddf674c6c248da8d4f241dc8d2bf1fcff4db74bc00f9c0036/falsa-0.0.6-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:56e500c635ad608fe3cf7d2634bd6e3d736aa432dfe00498af14e470eb354254", size = 605256, upload-time = "2025-09-20T07:35:06.564Z" }, + { url = "https://files.pythonhosted.org/packages/fd/20/3d74be0cc90d3d6d4edea625c5e57efa404a388428506c54f11cbd8413f0/falsa-0.0.6-cp311-cp311-win32.whl", hash = "sha256:fe0ff809e7246d1b06e03662c3a84f2e10d252590f62e06d0f937d498cda24d8", size = 253058, upload-time = "2025-09-20T07:35:21.813Z" }, + { url 
= "https://files.pythonhosted.org/packages/a3/f4/95c01bd3fda06fbe711e69252ba99a99484a701ca426481556cb362a7121/falsa-0.0.6-cp311-cp311-win_amd64.whl", hash = "sha256:050bc5eb7cbd1c0c6551851af0d3ef6a6db1794123c49718bdf2472103facf65", size = 276389, upload-time = "2025-09-20T07:35:17.047Z" }, + { url = "https://files.pythonhosted.org/packages/1c/f7/bce7df04f3ea86c88e6b2b82bd4cfce3d50b0057b68ae98fb1703730ad3e/falsa-0.0.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f2a17bf26161fd5fdde8db3bcb0f290bbcad679ae231842d53bfebd506130faf", size = 436615, upload-time = "2025-09-20T07:33:32.811Z" }, + { url = "https://files.pythonhosted.org/packages/a4/34/e42d33525910f37b165ba765a8548eca8079ee94ec4ca4001a3f13e7eab1/falsa-0.0.6-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c685c34779b33e8db9d13517931d3ea6df785756fea26b7ac11a49059c1375ca", size = 435130, upload-time = "2025-09-20T07:33:44.498Z" }, + { url = "https://files.pythonhosted.org/packages/53/dc/212f5b3b7e7a99a3867af1d49745e393d79610aa4c2218c72b6a4c9e9312/falsa-0.0.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6682631faa42ad303730872db6dce7b809da94842546fbd15431ebabba2b99bc", size = 651373, upload-time = "2025-09-20T07:33:55.721Z" }, + { url = "https://files.pythonhosted.org/packages/d7/e5/076c350bd7f6887463f28d7c49d97abb738daaeab356da5c5793720d32ba/falsa-0.0.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cf5d69cce8670b8d8617daa0a874e5bcb0a3409d368bfb044354b0db9404ff72", size = 475126, upload-time = "2025-09-20T07:34:06.562Z" }, + { url = "https://files.pythonhosted.org/packages/4b/3c/44d9e23b01da33b094bd4ee4cdae4f667a1cf0e123413981d16509660609/falsa-0.0.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:818ed089f8088ff9d170f366ad3df07c1458581d864ec3153b48be5bf06fc6f3", size = 597193, upload-time = "2025-09-20T07:34:27.531Z" }, + { url = 
"https://files.pythonhosted.org/packages/11/aa/70afcfbb1d76ccf275d7fb1cb6ee99720039a11b9d66ed23219f6cd4209a/falsa-0.0.6-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0e48df7acf762af490fcc3bfe9baeaeec82d151669e111c7630b37d38707bf73", size = 460932, upload-time = "2025-09-20T07:34:17.351Z" }, + { url = "https://files.pythonhosted.org/packages/8c/54/bd69faa0989fbbdf61793dedff7d953cd3832580ef35398f9f5a43443b29/falsa-0.0.6-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:eee10e87d74efe7a089db0a58c8cb6e02082b80618c8be70c75816e818d0194a", size = 616017, upload-time = "2025-09-20T07:34:36.222Z" }, + { url = "https://files.pythonhosted.org/packages/26/29/06a92316c7799337a40c7e3d8737827ea3590b1bdc66fb8341c720d96e8e/falsa-0.0.6-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:a97cc63f77f635e9ec738584565edf933d31078e94825788c236864488e7b062", size = 698946, upload-time = "2025-09-20T07:34:47.185Z" }, + { url = "https://files.pythonhosted.org/packages/df/14/5081e53d8e2927f86af70007e7d424a8bc3992527f87db78d8f21541e89c/falsa-0.0.6-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:4a3494b7c352e506c64c708b64e85afcb593419d541dbadf38405dc0fbc02f61", size = 632186, upload-time = "2025-09-20T07:34:57.499Z" }, + { url = "https://files.pythonhosted.org/packages/b9/fe/8d691ed9f2159726828cbe0765c579c032d35eb647ccfeb6ab10ffaa2f48/falsa-0.0.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:852d57713f169043d9ecbdb2ae6b8a93e87de68aa790e800f487fa61dfed1729", size = 603671, upload-time = "2025-09-20T07:35:07.65Z" }, + { url = "https://files.pythonhosted.org/packages/e7/70/425e1ad3b447a86c4f246433020d6c5ff359f278120e57e08e4b0b91cd16/falsa-0.0.6-cp312-cp312-win_amd64.whl", hash = "sha256:ea831bfdcbca03c2ca220dc61b2a8de14526af9a9a6a014f275299aace25f5c5", size = 275829, upload-time = "2025-09-20T07:35:18.074Z" }, + { url = 
"https://files.pythonhosted.org/packages/0a/8f/fb2e90057ae3f69b89f188c83dc4b930b34e6ecf89d7e5b7d99ae07e6b52/falsa-0.0.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c7bb9884f8cf468e1de57f0fa59532ed99c8bfd41999cf85e57e78a9fb8fd0ca", size = 436591, upload-time = "2025-09-20T07:33:34.336Z" }, + { url = "https://files.pythonhosted.org/packages/5e/c2/57e1b88757e637865fb2390560f927fd9eb60e793d82bbcf18d411b36104/falsa-0.0.6-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bc80e361b29d19d5739a6cb1ace1e00765f139e1d065c70693a644f7c4375089", size = 434955, upload-time = "2025-09-20T07:33:45.802Z" }, + { url = "https://files.pythonhosted.org/packages/ed/29/79585d31bce867fa083d2ca11bb469a3530077407ea2549046d6e496df24/falsa-0.0.6-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ca667084eb89a07893c373bbe05492235482a214b23b13da39626d71c9028ce7", size = 650688, upload-time = "2025-09-20T07:33:56.767Z" }, + { url = "https://files.pythonhosted.org/packages/fa/50/cda029ec50341601c283b040748172ba9cacc0a16880e93e4cb6239a715e/falsa-0.0.6-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d4811ab6aa0b2a155180aac6b3800ae5ea800bf422bddf8fb11daa509908c793", size = 475074, upload-time = "2025-09-20T07:34:07.88Z" }, + { url = "https://files.pythonhosted.org/packages/7c/62/1272b0c50203d0be2df3253e237f1ddbadce1642117d9dab4fb658fd241a/falsa-0.0.6-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d9cd0e7075aa22daaa970ca113502c51d1e0d89cf3322be116213099f61aa5fe", size = 597359, upload-time = "2025-09-20T07:34:28.566Z" }, + { url = "https://files.pythonhosted.org/packages/72/c9/4cc472d2e734bd4788ff5ce43825aaeba4715fc70f4900f2bfd6099b809e/falsa-0.0.6-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:37882088385512187511311d56a26226d45fd4f53dad081e50fdb07f587e0201", size = 461025, upload-time = "2025-09-20T07:34:18.436Z" }, + { url = 
"https://files.pythonhosted.org/packages/b0/a3/32206b72a42c06d771cd18b1211321d2fa413695e4cc9616b72d80708252/falsa-0.0.6-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:22f0c8dd927e857480c83b4db1e4209021e0a301efb8e76b2d3a91ad747b3768", size = 616183, upload-time = "2025-09-20T07:34:37.526Z" }, + { url = "https://files.pythonhosted.org/packages/54/57/244227fd859a5173938501a17bd2ec81c09ce25a60472dceb1f54dbb529b/falsa-0.0.6-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:fd06795b6873926a507f685eb147a06fb6c7282789ceb550558c42325bcbc637", size = 698951, upload-time = "2025-09-20T07:34:48.241Z" }, + { url = "https://files.pythonhosted.org/packages/41/6f/57d82f555f288ea9106b7a7ffb1978d27f8ffc1bf52753b8c2c4298acc00/falsa-0.0.6-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:5b6fd2c5cc4bbcae5b1a28f533705eb95ba0e220c8b70c67c830e86309477fb5", size = 632175, upload-time = "2025-09-20T07:34:58.664Z" }, + { url = "https://files.pythonhosted.org/packages/e9/a4/64c6c7dfe0e73948ead7e19217e38116853fa49512ee91dfdf41e8f799ca/falsa-0.0.6-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ea73bd1b098198b0cabd94eec7952de37051024b26805a30906ed350d3b474a8", size = 604022, upload-time = "2025-09-20T07:35:08.71Z" }, + { url = "https://files.pythonhosted.org/packages/a7/e2/42d9b92f09671cacc629a000d08656fe4f0da4ec818f4841fa700a0651f0/falsa-0.0.6-cp313-cp313-win_amd64.whl", hash = "sha256:80908855b7e8144add3d5f9b1ff7ef58d2fc574a6e8f7ac755437a178058d2ac", size = 275625, upload-time = "2025-09-20T07:35:19.664Z" }, + { url = "https://files.pythonhosted.org/packages/90/9e/304d3ce465ca33055ed22560e7694dd8418f200d1c6eaca16236aa24035e/falsa-0.0.6-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f6177b18bb6e61f333cca5c73d1c60a809a688937090130f8baeea4363366b9e", size = 436505, upload-time = "2025-09-20T07:33:35.655Z" }, + { url = 
"https://files.pythonhosted.org/packages/f9/e8/0f51c6562ee4e0c572e3cac4c9ea338678a15e349351474e4f298184f8c0/falsa-0.0.6-cp313-cp313t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5df6bedb01acf73134f565b0352493b981aa3ea84d09fd4e8d6f2c618042a1f3", size = 433993, upload-time = "2025-09-20T07:33:47.056Z" }, + { url = "https://files.pythonhosted.org/packages/46/6e/7a0a4acfc0bf397fd6f3c749040287c75e6fc9677d32ec20bca8e06ae4e0/falsa-0.0.6-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:731acd74b9b41e9bca388176c7e7be6ea48b5ba136f149f41bdfaaaaa53a40e4", size = 649979, upload-time = "2025-09-20T07:33:57.991Z" }, + { url = "https://files.pythonhosted.org/packages/f3/2a/19d66b0b38232d6230ed163e9c24c55683f38348930e25c7e36188b9e7a1/falsa-0.0.6-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d7aa02f407b473fe81a5e94d3cbaa5ba34e243da35593fbfb1b71351093eac8", size = 474443, upload-time = "2025-09-20T07:34:08.949Z" }, + { url = "https://files.pythonhosted.org/packages/e1/df/80bea42472af460b2b18c3bb547ae5eaf55bea9eff63f5abf266dca51b5a/falsa-0.0.6-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:f0214f94434924e03308b48a81ddf246d0c8c9e1e4b323184bb417fe81df190e", size = 615972, upload-time = "2025-09-20T07:34:38.639Z" }, + { url = "https://files.pythonhosted.org/packages/2c/6d/449f03ad7b5c31f7cac1fc7177419a67d0c53b7733c83034772ca491b697/falsa-0.0.6-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:4e2982b9ef053fedca216f6abeb5d7325d73f4df24540dd9a0fe8463a9c80abd", size = 698052, upload-time = "2025-09-20T07:34:49.336Z" }, + { url = "https://files.pythonhosted.org/packages/34/6f/723bed02c00e9b3741a2b8fdbbca1afb7ba3fc2ad398be85cd477408f611/falsa-0.0.6-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:4953ae9f87aefed8a3936562dbab20dd6b3a6cdadf32f009ef552e9e5df96a56", size = 631684, upload-time = "2025-09-20T07:34:59.715Z" }, + { url = 
"https://files.pythonhosted.org/packages/54/70/a8a0bda4afa93bd602ce05efe3f615f25e2145880e5abb0f8138312fcaed/falsa-0.0.6-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:fcf31b451835037ccdf6b9adb9353d99981178d6e96601b6b023fbac1db74342", size = 604314, upload-time = "2025-09-20T07:35:09.78Z" }, + { url = "https://files.pythonhosted.org/packages/8b/47/6e1a6a2cf730e7cf5b2a5159066590a5151867b0cf1c913386285b39d52c/falsa-0.0.6-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de0bd27c505b47c8870463ef9376e52e72d54a7f3bb7b393e6a0f5fe8227c95e", size = 597105, upload-time = "2025-09-20T07:34:29.668Z" }, + { url = "https://files.pythonhosted.org/packages/a0/a0/3d697341c44c238e635af6f4ccc87d1150edbb5374c67e6f7c86c9818336/falsa-0.0.6-cp314-cp314-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:bf8f50d6f8f65009ae5b986f4220dd823cb22d704221e29ca91a06dd0c178599", size = 461233, upload-time = "2025-09-20T07:34:19.704Z" }, + { url = "https://files.pythonhosted.org/packages/8a/a6/a59e8d6f27c049a0955f3b7d7a229633213f485b0175d6a348fc66047bdd/falsa-0.0.6-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b8714397240eeb05f490b8e2c1ca6592edb2e6c5e6652baaf1d29ea4bd2c4a6", size = 438116, upload-time = "2025-09-20T07:33:39.668Z" }, + { url = "https://files.pythonhosted.org/packages/b3/e8/27f367c60dd662e009dd2945c1fdbc74fad277c6b668d02ee004ba41e2ee/falsa-0.0.6-pp311-pypy311_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:47a610301a11f1b53c12092d97b5dff80e576b1534883e62a02d019bc759d06f", size = 436210, upload-time = "2025-09-20T07:33:50.477Z" }, + { url = "https://files.pythonhosted.org/packages/cb/a4/6163320b1130da9333f851633a6f7b726ea42974bafc6db333fc3c0a69e0/falsa-0.0.6-pp311-pypy311_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:13c98c49225232016dfd8bdd0e5f2e10649f9d0388fde9b1020b04d7409c9078", size = 651561, upload-time = "2025-09-20T07:34:01.522Z" }, + { url = 
"https://files.pythonhosted.org/packages/f1/5d/f06f625cb2e9af5769f0f755154469e9a280b9ce6bedfff15564bce9483a/falsa-0.0.6-pp311-pypy311_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e8d63db146847709114032382c4cdaf7274654781d3a56732eb5e622350654f2", size = 476530, upload-time = "2025-09-20T07:34:12.248Z" }, + { url = "https://files.pythonhosted.org/packages/54/cb/81fd6f2d542ef1833485d95f766c29bf5a9bf73213d4c6dad8b2c4541327/falsa-0.0.6-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b552b1525300b14abd2400dc692cfb79de6813cec725deca03aaf251ca94111", size = 598516, upload-time = "2025-09-20T07:34:31.807Z" }, + { url = "https://files.pythonhosted.org/packages/97/33/07809af6ff17d1fc3e059ea1a73a76cc5593661832cf0c91498be9bc8172/falsa-0.0.6-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:535f9d6cc9a745d7aed0b108f8447de1780e548fc30fbeb0d360f8403ed86b6e", size = 461808, upload-time = "2025-09-20T07:34:24.119Z" }, + { url = "https://files.pythonhosted.org/packages/f7/6a/0b4a3903f7c8ed15e2f5c8b4d226e0cf214f7f32dca1b74a8064f6d27c47/falsa-0.0.6-pp311-pypy311_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:04109d8e1c58cd8d87d513546fa945db4b5883e1ddc29a1dc14b9bb999991d6d", size = 617349, upload-time = "2025-09-20T07:34:42.168Z" }, + { url = "https://files.pythonhosted.org/packages/08/cc/3a7d98bd4f8569c9ec683d358379b6167e19911007263fcc45e4f414f407/falsa-0.0.6-pp311-pypy311_pp73-musllinux_1_2_armv7l.whl", hash = "sha256:9623ada575625e65245488ec6ef7cf09e40e134245c5ab8a440267338212f73e", size = 700202, upload-time = "2025-09-20T07:34:52.724Z" }, + { url = "https://files.pythonhosted.org/packages/dc/5c/88e1a1d2c29b83e0c5da30960815f830dd79694c474f6b7ae2eb716a8e65/falsa-0.0.6-pp311-pypy311_pp73-musllinux_1_2_i686.whl", hash = "sha256:9a8e8cd40e0389f56c2fb41bd0a0c2472c2365265b78966c7f187aaf3409558a", size = 633105, upload-time = "2025-09-20T07:35:03.315Z" }, + { url = 
"https://files.pythonhosted.org/packages/37/03/94f5e53369796b3e93c3d942d6c010f3215957330a697a2c715fe93f2ac6/falsa-0.0.6-pp311-pypy311_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:afaadf6ac8599bbf2e42f54bccda76e9f0218f6d6429085186d38d243c6b28da", size = 605690, upload-time = "2025-09-20T07:35:13.015Z" }, +] + +[[package]] +name = "idna" +version = "3.11" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6f/6d/0703ccc57f3a7233505399edb88de3cbd678da106337b9fcde432b65ed60/idna-3.11.tar.gz", hash = "sha256:795dafcc9c04ed0c1fb032c2aa73654d8e8c5023a7df64a53f39190ada629902", size = 194582, upload-time = "2025-10-12T14:55:20.501Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea", size = 71008, upload-time = "2025-10-12T14:55:18.883Z" }, +] + +[[package]] +name = "imagesize" +version = "1.4.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a7/84/62473fb57d61e31fef6e36d64a179c8781605429fd927b5dd608c997be31/imagesize-1.4.1.tar.gz", hash = "sha256:69150444affb9cb0d5cc5a92b3676f0b2fb7cd9ae39e947a5e11a36b4497cd4a", size = 1280026, upload-time = "2022-07-01T12:21:05.687Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ff/62/85c4c919272577931d407be5ba5d71c20f0b616d31a0befe0ae45bb79abd/imagesize-1.4.1-py2.py3-none-any.whl", hash = "sha256:0d8d18d08f840c19d0ee7ca1fd82490fdc3729b7ac93f49870406ddde8ef8d8b", size = 8769, upload-time = "2022-07-01T12:21:02.467Z" }, +] + +[[package]] +name = "jinja2" +version = "3.1.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/df/bf/f7da0350254c0ed7c72f3e33cef02e048281fec7ecec5f032d4aac52226b/jinja2-3.1.6.tar.gz", 
hash = "sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d", size = 245115, upload-time = "2025-03-05T20:05:02.478Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" }, +] + +[[package]] +name = "markdown-it-py" +version = "4.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mdurl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5b/f5/4ec618ed16cc4f8fb3b701563655a69816155e79e24a17b651541804721d/markdown_it_py-4.0.0.tar.gz", hash = "sha256:cb0a2b4aa34f932c007117b194e945bd74e0ec24133ceb5bac59009cda1cb9f3", size = 73070, upload-time = "2025-08-11T12:57:52.854Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/94/54/e7d793b573f298e1c9013b8c4dade17d481164aa517d1d7148619c2cedbf/markdown_it_py-4.0.0-py3-none-any.whl", hash = "sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147", size = 87321, upload-time = "2025-08-11T12:57:51.923Z" }, +] + +[[package]] +name = "markupsafe" +version = "3.0.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/7e/99/7690b6d4034fffd95959cbe0c02de8deb3098cc577c67bb6a24fe5d7caa7/markupsafe-3.0.3.tar.gz", hash = "sha256:722695808f4b6457b320fdc131280796bdceb04ab50fe1795cd540799ebe1698", size = 80313, upload-time = "2025-09-27T18:37:40.426Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/08/db/fefacb2136439fc8dd20e797950e749aa1f4997ed584c62cfb8ef7c2be0e/markupsafe-3.0.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1cc7ea17a6824959616c525620e387f6dd30fec8cb44f649e31712db02123dad", size = 11631, upload-time = "2025-09-27T18:36:18.185Z" }, + { url = 
"https://files.pythonhosted.org/packages/e1/2e/5898933336b61975ce9dc04decbc0a7f2fee78c30353c5efba7f2d6ff27a/markupsafe-3.0.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4bd4cd07944443f5a265608cc6aab442e4f74dff8088b0dfc8238647b8f6ae9a", size = 12058, upload-time = "2025-09-27T18:36:19.444Z" }, + { url = "https://files.pythonhosted.org/packages/1d/09/adf2df3699d87d1d8184038df46a9c80d78c0148492323f4693df54e17bb/markupsafe-3.0.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6b5420a1d9450023228968e7e6a9ce57f65d148ab56d2313fcd589eee96a7a50", size = 24287, upload-time = "2025-09-27T18:36:20.768Z" }, + { url = "https://files.pythonhosted.org/packages/30/ac/0273f6fcb5f42e314c6d8cd99effae6a5354604d461b8d392b5ec9530a54/markupsafe-3.0.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0bf2a864d67e76e5c9a34dc26ec616a66b9888e25e7b9460e1c76d3293bd9dbf", size = 22940, upload-time = "2025-09-27T18:36:22.249Z" }, + { url = "https://files.pythonhosted.org/packages/19/ae/31c1be199ef767124c042c6c3e904da327a2f7f0cd63a0337e1eca2967a8/markupsafe-3.0.3-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:bc51efed119bc9cfdf792cdeaa4d67e8f6fcccab66ed4bfdd6bde3e59bfcbb2f", size = 21887, upload-time = "2025-09-27T18:36:23.535Z" }, + { url = "https://files.pythonhosted.org/packages/b2/76/7edcab99d5349a4532a459e1fe64f0b0467a3365056ae550d3bcf3f79e1e/markupsafe-3.0.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:068f375c472b3e7acbe2d5318dea141359e6900156b5b2ba06a30b169086b91a", size = 23692, upload-time = "2025-09-27T18:36:24.823Z" }, + { url = "https://files.pythonhosted.org/packages/a4/28/6e74cdd26d7514849143d69f0bf2399f929c37dc2b31e6829fd2045b2765/markupsafe-3.0.3-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:7be7b61bb172e1ed687f1754f8e7484f1c8019780f6f6b0786e76bb01c2ae115", size = 21471, upload-time = "2025-09-27T18:36:25.95Z" }, + { url = 
"https://files.pythonhosted.org/packages/62/7e/a145f36a5c2945673e590850a6f8014318d5577ed7e5920a4b3448e0865d/markupsafe-3.0.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f9e130248f4462aaa8e2552d547f36ddadbeaa573879158d721bbd33dfe4743a", size = 22923, upload-time = "2025-09-27T18:36:27.109Z" }, + { url = "https://files.pythonhosted.org/packages/0f/62/d9c46a7f5c9adbeeeda52f5b8d802e1094e9717705a645efc71b0913a0a8/markupsafe-3.0.3-cp311-cp311-win32.whl", hash = "sha256:0db14f5dafddbb6d9208827849fad01f1a2609380add406671a26386cdf15a19", size = 14572, upload-time = "2025-09-27T18:36:28.045Z" }, + { url = "https://files.pythonhosted.org/packages/83/8a/4414c03d3f891739326e1783338e48fb49781cc915b2e0ee052aa490d586/markupsafe-3.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:de8a88e63464af587c950061a5e6a67d3632e36df62b986892331d4620a35c01", size = 15077, upload-time = "2025-09-27T18:36:29.025Z" }, + { url = "https://files.pythonhosted.org/packages/35/73/893072b42e6862f319b5207adc9ae06070f095b358655f077f69a35601f0/markupsafe-3.0.3-cp311-cp311-win_arm64.whl", hash = "sha256:3b562dd9e9ea93f13d53989d23a7e775fdfd1066c33494ff43f5418bc8c58a5c", size = 13876, upload-time = "2025-09-27T18:36:29.954Z" }, + { url = "https://files.pythonhosted.org/packages/5a/72/147da192e38635ada20e0a2e1a51cf8823d2119ce8883f7053879c2199b5/markupsafe-3.0.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d53197da72cc091b024dd97249dfc7794d6a56530370992a5e1a08983ad9230e", size = 11615, upload-time = "2025-09-27T18:36:30.854Z" }, + { url = "https://files.pythonhosted.org/packages/9a/81/7e4e08678a1f98521201c3079f77db69fb552acd56067661f8c2f534a718/markupsafe-3.0.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1872df69a4de6aead3491198eaf13810b565bdbeec3ae2dc8780f14458ec73ce", size = 12020, upload-time = "2025-09-27T18:36:31.971Z" }, + { url = 
"https://files.pythonhosted.org/packages/1e/2c/799f4742efc39633a1b54a92eec4082e4f815314869865d876824c257c1e/markupsafe-3.0.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3a7e8ae81ae39e62a41ec302f972ba6ae23a5c5396c8e60113e9066ef893da0d", size = 24332, upload-time = "2025-09-27T18:36:32.813Z" }, + { url = "https://files.pythonhosted.org/packages/3c/2e/8d0c2ab90a8c1d9a24f0399058ab8519a3279d1bd4289511d74e909f060e/markupsafe-3.0.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d6dd0be5b5b189d31db7cda48b91d7e0a9795f31430b7f271219ab30f1d3ac9d", size = 22947, upload-time = "2025-09-27T18:36:33.86Z" }, + { url = "https://files.pythonhosted.org/packages/2c/54/887f3092a85238093a0b2154bd629c89444f395618842e8b0c41783898ea/markupsafe-3.0.3-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:94c6f0bb423f739146aec64595853541634bde58b2135f27f61c1ffd1cd4d16a", size = 21962, upload-time = "2025-09-27T18:36:35.099Z" }, + { url = "https://files.pythonhosted.org/packages/c9/2f/336b8c7b6f4a4d95e91119dc8521402461b74a485558d8f238a68312f11c/markupsafe-3.0.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:be8813b57049a7dc738189df53d69395eba14fb99345e0a5994914a3864c8a4b", size = 23760, upload-time = "2025-09-27T18:36:36.001Z" }, + { url = "https://files.pythonhosted.org/packages/32/43/67935f2b7e4982ffb50a4d169b724d74b62a3964bc1a9a527f5ac4f1ee2b/markupsafe-3.0.3-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:83891d0e9fb81a825d9a6d61e3f07550ca70a076484292a70fde82c4b807286f", size = 21529, upload-time = "2025-09-27T18:36:36.906Z" }, + { url = "https://files.pythonhosted.org/packages/89/e0/4486f11e51bbba8b0c041098859e869e304d1c261e59244baa3d295d47b7/markupsafe-3.0.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:77f0643abe7495da77fb436f50f8dab76dbc6e5fd25d39589a0f1fe6548bfa2b", size = 23015, upload-time = "2025-09-27T18:36:37.868Z" }, + { url 
= "https://files.pythonhosted.org/packages/2f/e1/78ee7a023dac597a5825441ebd17170785a9dab23de95d2c7508ade94e0e/markupsafe-3.0.3-cp312-cp312-win32.whl", hash = "sha256:d88b440e37a16e651bda4c7c2b930eb586fd15ca7406cb39e211fcff3bf3017d", size = 14540, upload-time = "2025-09-27T18:36:38.761Z" }, + { url = "https://files.pythonhosted.org/packages/aa/5b/bec5aa9bbbb2c946ca2733ef9c4ca91c91b6a24580193e891b5f7dbe8e1e/markupsafe-3.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:26a5784ded40c9e318cfc2bdb30fe164bdb8665ded9cd64d500a34fb42067b1c", size = 15105, upload-time = "2025-09-27T18:36:39.701Z" }, + { url = "https://files.pythonhosted.org/packages/e5/f1/216fc1bbfd74011693a4fd837e7026152e89c4bcf3e77b6692fba9923123/markupsafe-3.0.3-cp312-cp312-win_arm64.whl", hash = "sha256:35add3b638a5d900e807944a078b51922212fb3dedb01633a8defc4b01a3c85f", size = 13906, upload-time = "2025-09-27T18:36:40.689Z" }, + { url = "https://files.pythonhosted.org/packages/38/2f/907b9c7bbba283e68f20259574b13d005c121a0fa4c175f9bed27c4597ff/markupsafe-3.0.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e1cf1972137e83c5d4c136c43ced9ac51d0e124706ee1c8aa8532c1287fa8795", size = 11622, upload-time = "2025-09-27T18:36:41.777Z" }, + { url = "https://files.pythonhosted.org/packages/9c/d9/5f7756922cdd676869eca1c4e3c0cd0df60ed30199ffd775e319089cb3ed/markupsafe-3.0.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:116bb52f642a37c115f517494ea5feb03889e04df47eeff5b130b1808ce7c219", size = 12029, upload-time = "2025-09-27T18:36:43.257Z" }, + { url = "https://files.pythonhosted.org/packages/00/07/575a68c754943058c78f30db02ee03a64b3c638586fba6a6dd56830b30a3/markupsafe-3.0.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:133a43e73a802c5562be9bbcd03d090aa5a1fe899db609c29e8c8d815c5f6de6", size = 24374, upload-time = "2025-09-27T18:36:44.508Z" }, + { url = 
"https://files.pythonhosted.org/packages/a9/21/9b05698b46f218fc0e118e1f8168395c65c8a2c750ae2bab54fc4bd4e0e8/markupsafe-3.0.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ccfcd093f13f0f0b7fdd0f198b90053bf7b2f02a3927a30e63f3ccc9df56b676", size = 22980, upload-time = "2025-09-27T18:36:45.385Z" }, + { url = "https://files.pythonhosted.org/packages/7f/71/544260864f893f18b6827315b988c146b559391e6e7e8f7252839b1b846a/markupsafe-3.0.3-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:509fa21c6deb7a7a273d629cf5ec029bc209d1a51178615ddf718f5918992ab9", size = 21990, upload-time = "2025-09-27T18:36:46.916Z" }, + { url = "https://files.pythonhosted.org/packages/c2/28/b50fc2f74d1ad761af2f5dcce7492648b983d00a65b8c0e0cb457c82ebbe/markupsafe-3.0.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a4afe79fb3de0b7097d81da19090f4df4f8d3a2b3adaa8764138aac2e44f3af1", size = 23784, upload-time = "2025-09-27T18:36:47.884Z" }, + { url = "https://files.pythonhosted.org/packages/ed/76/104b2aa106a208da8b17a2fb72e033a5a9d7073c68f7e508b94916ed47a9/markupsafe-3.0.3-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:795e7751525cae078558e679d646ae45574b47ed6e7771863fcc079a6171a0fc", size = 21588, upload-time = "2025-09-27T18:36:48.82Z" }, + { url = "https://files.pythonhosted.org/packages/b5/99/16a5eb2d140087ebd97180d95249b00a03aa87e29cc224056274f2e45fd6/markupsafe-3.0.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:8485f406a96febb5140bfeca44a73e3ce5116b2501ac54fe953e488fb1d03b12", size = 23041, upload-time = "2025-09-27T18:36:49.797Z" }, + { url = "https://files.pythonhosted.org/packages/19/bc/e7140ed90c5d61d77cea142eed9f9c303f4c4806f60a1044c13e3f1471d0/markupsafe-3.0.3-cp313-cp313-win32.whl", hash = "sha256:bdd37121970bfd8be76c5fb069c7751683bdf373db1ed6c010162b2a130248ed", size = 14543, upload-time = "2025-09-27T18:36:51.584Z" }, + { url = 
"https://files.pythonhosted.org/packages/05/73/c4abe620b841b6b791f2edc248f556900667a5a1cf023a6646967ae98335/markupsafe-3.0.3-cp313-cp313-win_amd64.whl", hash = "sha256:9a1abfdc021a164803f4d485104931fb8f8c1efd55bc6b748d2f5774e78b62c5", size = 15113, upload-time = "2025-09-27T18:36:52.537Z" }, + { url = "https://files.pythonhosted.org/packages/f0/3a/fa34a0f7cfef23cf9500d68cb7c32dd64ffd58a12b09225fb03dd37d5b80/markupsafe-3.0.3-cp313-cp313-win_arm64.whl", hash = "sha256:7e68f88e5b8799aa49c85cd116c932a1ac15caaa3f5db09087854d218359e485", size = 13911, upload-time = "2025-09-27T18:36:53.513Z" }, + { url = "https://files.pythonhosted.org/packages/e4/d7/e05cd7efe43a88a17a37b3ae96e79a19e846f3f456fe79c57ca61356ef01/markupsafe-3.0.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:218551f6df4868a8d527e3062d0fb968682fe92054e89978594c28e642c43a73", size = 11658, upload-time = "2025-09-27T18:36:54.819Z" }, + { url = "https://files.pythonhosted.org/packages/99/9e/e412117548182ce2148bdeacdda3bb494260c0b0184360fe0d56389b523b/markupsafe-3.0.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:3524b778fe5cfb3452a09d31e7b5adefeea8c5be1d43c4f810ba09f2ceb29d37", size = 12066, upload-time = "2025-09-27T18:36:55.714Z" }, + { url = "https://files.pythonhosted.org/packages/bc/e6/fa0ffcda717ef64a5108eaa7b4f5ed28d56122c9a6d70ab8b72f9f715c80/markupsafe-3.0.3-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4e885a3d1efa2eadc93c894a21770e4bc67899e3543680313b09f139e149ab19", size = 25639, upload-time = "2025-09-27T18:36:56.908Z" }, + { url = "https://files.pythonhosted.org/packages/96/ec/2102e881fe9d25fc16cb4b25d5f5cde50970967ffa5dddafdb771237062d/markupsafe-3.0.3-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8709b08f4a89aa7586de0aadc8da56180242ee0ada3999749b183aa23df95025", size = 23569, upload-time = "2025-09-27T18:36:57.913Z" }, + { url = 
"https://files.pythonhosted.org/packages/4b/30/6f2fce1f1f205fc9323255b216ca8a235b15860c34b6798f810f05828e32/markupsafe-3.0.3-cp313-cp313t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:b8512a91625c9b3da6f127803b166b629725e68af71f8184ae7e7d54686a56d6", size = 23284, upload-time = "2025-09-27T18:36:58.833Z" }, + { url = "https://files.pythonhosted.org/packages/58/47/4a0ccea4ab9f5dcb6f79c0236d954acb382202721e704223a8aafa38b5c8/markupsafe-3.0.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:9b79b7a16f7fedff2495d684f2b59b0457c3b493778c9eed31111be64d58279f", size = 24801, upload-time = "2025-09-27T18:36:59.739Z" }, + { url = "https://files.pythonhosted.org/packages/6a/70/3780e9b72180b6fecb83a4814d84c3bf4b4ae4bf0b19c27196104149734c/markupsafe-3.0.3-cp313-cp313t-musllinux_1_2_riscv64.whl", hash = "sha256:12c63dfb4a98206f045aa9563db46507995f7ef6d83b2f68eda65c307c6829eb", size = 22769, upload-time = "2025-09-27T18:37:00.719Z" }, + { url = "https://files.pythonhosted.org/packages/98/c5/c03c7f4125180fc215220c035beac6b9cb684bc7a067c84fc69414d315f5/markupsafe-3.0.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:8f71bc33915be5186016f675cd83a1e08523649b0e33efdb898db577ef5bb009", size = 23642, upload-time = "2025-09-27T18:37:01.673Z" }, + { url = "https://files.pythonhosted.org/packages/80/d6/2d1b89f6ca4bff1036499b1e29a1d02d282259f3681540e16563f27ebc23/markupsafe-3.0.3-cp313-cp313t-win32.whl", hash = "sha256:69c0b73548bc525c8cb9a251cddf1931d1db4d2258e9599c28c07ef3580ef354", size = 14612, upload-time = "2025-09-27T18:37:02.639Z" }, + { url = "https://files.pythonhosted.org/packages/2b/98/e48a4bfba0a0ffcf9925fe2d69240bfaa19c6f7507b8cd09c70684a53c1e/markupsafe-3.0.3-cp313-cp313t-win_amd64.whl", hash = "sha256:1b4b79e8ebf6b55351f0d91fe80f893b4743f104bff22e90697db1590e47a218", size = 15200, upload-time = "2025-09-27T18:37:03.582Z" }, + { url = 
"https://files.pythonhosted.org/packages/0e/72/e3cc540f351f316e9ed0f092757459afbc595824ca724cbc5a5d4263713f/markupsafe-3.0.3-cp313-cp313t-win_arm64.whl", hash = "sha256:ad2cf8aa28b8c020ab2fc8287b0f823d0a7d8630784c31e9ee5edea20f406287", size = 13973, upload-time = "2025-09-27T18:37:04.929Z" }, + { url = "https://files.pythonhosted.org/packages/33/8a/8e42d4838cd89b7dde187011e97fe6c3af66d8c044997d2183fbd6d31352/markupsafe-3.0.3-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:eaa9599de571d72e2daf60164784109f19978b327a3910d3e9de8c97b5b70cfe", size = 11619, upload-time = "2025-09-27T18:37:06.342Z" }, + { url = "https://files.pythonhosted.org/packages/b5/64/7660f8a4a8e53c924d0fa05dc3a55c9cee10bbd82b11c5afb27d44b096ce/markupsafe-3.0.3-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:c47a551199eb8eb2121d4f0f15ae0f923d31350ab9280078d1e5f12b249e0026", size = 12029, upload-time = "2025-09-27T18:37:07.213Z" }, + { url = "https://files.pythonhosted.org/packages/da/ef/e648bfd021127bef5fa12e1720ffed0c6cbb8310c8d9bea7266337ff06de/markupsafe-3.0.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f34c41761022dd093b4b6896d4810782ffbabe30f2d443ff5f083e0cbbb8c737", size = 24408, upload-time = "2025-09-27T18:37:09.572Z" }, + { url = "https://files.pythonhosted.org/packages/41/3c/a36c2450754618e62008bf7435ccb0f88053e07592e6028a34776213d877/markupsafe-3.0.3-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:457a69a9577064c05a97c41f4e65148652db078a3a509039e64d3467b9e7ef97", size = 23005, upload-time = "2025-09-27T18:37:10.58Z" }, + { url = "https://files.pythonhosted.org/packages/bc/20/b7fdf89a8456b099837cd1dc21974632a02a999ec9bf7ca3e490aacd98e7/markupsafe-3.0.3-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:e8afc3f2ccfa24215f8cb28dcf43f0113ac3c37c2f0f0806d8c70e4228c5cf4d", size = 22048, upload-time = "2025-09-27T18:37:11.547Z" }, + { url = 
"https://files.pythonhosted.org/packages/9a/a7/591f592afdc734f47db08a75793a55d7fbcc6902a723ae4cfbab61010cc5/markupsafe-3.0.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:ec15a59cf5af7be74194f7ab02d0f59a62bdcf1a537677ce67a2537c9b87fcda", size = 23821, upload-time = "2025-09-27T18:37:12.48Z" }, + { url = "https://files.pythonhosted.org/packages/7d/33/45b24e4f44195b26521bc6f1a82197118f74df348556594bd2262bda1038/markupsafe-3.0.3-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:0eb9ff8191e8498cca014656ae6b8d61f39da5f95b488805da4bb029cccbfbaf", size = 21606, upload-time = "2025-09-27T18:37:13.485Z" }, + { url = "https://files.pythonhosted.org/packages/ff/0e/53dfaca23a69fbfbbf17a4b64072090e70717344c52eaaaa9c5ddff1e5f0/markupsafe-3.0.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:2713baf880df847f2bece4230d4d094280f4e67b1e813eec43b4c0e144a34ffe", size = 23043, upload-time = "2025-09-27T18:37:14.408Z" }, + { url = "https://files.pythonhosted.org/packages/46/11/f333a06fc16236d5238bfe74daccbca41459dcd8d1fa952e8fbd5dccfb70/markupsafe-3.0.3-cp314-cp314-win32.whl", hash = "sha256:729586769a26dbceff69f7a7dbbf59ab6572b99d94576a5592625d5b411576b9", size = 14747, upload-time = "2025-09-27T18:37:15.36Z" }, + { url = "https://files.pythonhosted.org/packages/28/52/182836104b33b444e400b14f797212f720cbc9ed6ba34c800639d154e821/markupsafe-3.0.3-cp314-cp314-win_amd64.whl", hash = "sha256:bdc919ead48f234740ad807933cdf545180bfbe9342c2bb451556db2ed958581", size = 15341, upload-time = "2025-09-27T18:37:16.496Z" }, + { url = "https://files.pythonhosted.org/packages/6f/18/acf23e91bd94fd7b3031558b1f013adfa21a8e407a3fdb32745538730382/markupsafe-3.0.3-cp314-cp314-win_arm64.whl", hash = "sha256:5a7d5dc5140555cf21a6fefbdbf8723f06fcd2f63ef108f2854de715e4422cb4", size = 14073, upload-time = "2025-09-27T18:37:17.476Z" }, + { url = 
"https://files.pythonhosted.org/packages/3c/f0/57689aa4076e1b43b15fdfa646b04653969d50cf30c32a102762be2485da/markupsafe-3.0.3-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:1353ef0c1b138e1907ae78e2f6c63ff67501122006b0f9abad68fda5f4ffc6ab", size = 11661, upload-time = "2025-09-27T18:37:18.453Z" }, + { url = "https://files.pythonhosted.org/packages/89/c3/2e67a7ca217c6912985ec766c6393b636fb0c2344443ff9d91404dc4c79f/markupsafe-3.0.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:1085e7fbddd3be5f89cc898938f42c0b3c711fdcb37d75221de2666af647c175", size = 12069, upload-time = "2025-09-27T18:37:19.332Z" }, + { url = "https://files.pythonhosted.org/packages/f0/00/be561dce4e6ca66b15276e184ce4b8aec61fe83662cce2f7d72bd3249d28/markupsafe-3.0.3-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1b52b4fb9df4eb9ae465f8d0c228a00624de2334f216f178a995ccdcf82c4634", size = 25670, upload-time = "2025-09-27T18:37:20.245Z" }, + { url = "https://files.pythonhosted.org/packages/50/09/c419f6f5a92e5fadde27efd190eca90f05e1261b10dbd8cbcb39cd8ea1dc/markupsafe-3.0.3-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fed51ac40f757d41b7c48425901843666a6677e3e8eb0abcff09e4ba6e664f50", size = 23598, upload-time = "2025-09-27T18:37:21.177Z" }, + { url = "https://files.pythonhosted.org/packages/22/44/a0681611106e0b2921b3033fc19bc53323e0b50bc70cffdd19f7d679bb66/markupsafe-3.0.3-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:f190daf01f13c72eac4efd5c430a8de82489d9cff23c364c3ea822545032993e", size = 23261, upload-time = "2025-09-27T18:37:22.167Z" }, + { url = "https://files.pythonhosted.org/packages/5f/57/1b0b3f100259dc9fffe780cfb60d4be71375510e435efec3d116b6436d43/markupsafe-3.0.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:e56b7d45a839a697b5eb268c82a71bd8c7f6c94d6fd50c3d577fa39a9f1409f5", size = 24835, upload-time = "2025-09-27T18:37:23.296Z" }, + { 
url = "https://files.pythonhosted.org/packages/26/6a/4bf6d0c97c4920f1597cc14dd720705eca0bf7c787aebc6bb4d1bead5388/markupsafe-3.0.3-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:f3e98bb3798ead92273dc0e5fd0f31ade220f59a266ffd8a4f6065e0a3ce0523", size = 22733, upload-time = "2025-09-27T18:37:24.237Z" }, + { url = "https://files.pythonhosted.org/packages/14/c7/ca723101509b518797fedc2fdf79ba57f886b4aca8a7d31857ba3ee8281f/markupsafe-3.0.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:5678211cb9333a6468fb8d8be0305520aa073f50d17f089b5b4b477ea6e67fdc", size = 23672, upload-time = "2025-09-27T18:37:25.271Z" }, + { url = "https://files.pythonhosted.org/packages/fb/df/5bd7a48c256faecd1d36edc13133e51397e41b73bb77e1a69deab746ebac/markupsafe-3.0.3-cp314-cp314t-win32.whl", hash = "sha256:915c04ba3851909ce68ccc2b8e2cd691618c4dc4c4232fb7982bca3f41fd8c3d", size = 14819, upload-time = "2025-09-27T18:37:26.285Z" }, + { url = "https://files.pythonhosted.org/packages/1a/8a/0402ba61a2f16038b48b39bccca271134be00c5c9f0f623208399333c448/markupsafe-3.0.3-cp314-cp314t-win_amd64.whl", hash = "sha256:4faffd047e07c38848ce017e8725090413cd80cbc23d86e55c587bf979e579c9", size = 15426, upload-time = "2025-09-27T18:37:27.316Z" }, + { url = "https://files.pythonhosted.org/packages/70/bc/6f1c2f612465f5fa89b95bead1f44dcb607670fd42891d8fdcd5d039f4f4/markupsafe-3.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:32001d6a8fc98c8cb5c947787c5d08b0a50663d139f1305bac5885d98d9b40fa", size = 14146, upload-time = "2025-09-27T18:37:28.327Z" }, +] + +[[package]] +name = "maturin" +version = "1.12.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ae/13/aeff8a21835ed0e40c329c286750fcdcdcbf231f1a5cb327378666c5def6/maturin-1.12.2.tar.gz", hash = "sha256:d6253079f53dbb692395a13abddc0f2d3d96af32f8c0b32e2912849713c55794", size = 257279, upload-time = "2026-02-16T13:56:20.221Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/8a/9d/4811e1fcaa346a0b9fad6aee0ac0eec9eb376a24fe27c66d5d4fe975586e/maturin-1.12.2-py3-none-linux_armv6l.whl", hash = "sha256:0ed31b6a392928ad23645a470edc4f3814b952a416e41f8e5daac42d7bfbabc6", size = 9653200, upload-time = "2026-02-16T13:56:16.216Z" }, + { url = "https://files.pythonhosted.org/packages/69/db/74d582af74c32bbda12e4d7e153b389884409a1c5cd31edc9d3194d515f7/maturin-1.12.2-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:f1c2e4ee43bf286b052091a3b2356a157978985837c7aed42354deb2947a4006", size = 18870087, upload-time = "2026-02-16T13:56:18.463Z" }, + { url = "https://files.pythonhosted.org/packages/9d/6f/71be226c6780387f032c0b4ab791c390c7162ed62f93a11e600f9266dafd/maturin-1.12.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:04c9c4f9c9f904f007cbfcd4640c406e53f19d04c220f5940d1537edb914d325", size = 9762083, upload-time = "2026-02-16T13:56:27.853Z" }, + { url = "https://files.pythonhosted.org/packages/6a/cc/989dce6140227277b4184aab248d07fe67fa11f95411ccf90e272542287d/maturin-1.12.2-py3-none-manylinux_2_12_i686.manylinux2010_i686.musllinux_1_1_i686.whl", hash = "sha256:4bdc486b9ab80d8b50143ecc9a1924b890866fe95be150dd9a59fa22a6b37238", size = 9710711, upload-time = "2026-02-16T13:56:21.364Z" }, + { url = "https://files.pythonhosted.org/packages/a1/e8/02bb64f7150013d8af3ca622944e22f550beb312b6d5cf8760dc2896cce8/maturin-1.12.2-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.musllinux_1_1_x86_64.whl", hash = "sha256:134e895578258a693ba1d55b166c2ba96e9f51067e106b8a74d422432653d45b", size = 10205015, upload-time = "2026-02-16T13:56:07.994Z" }, + { url = "https://files.pythonhosted.org/packages/84/81/b603a74bef68fabd402d1e54f43560213ea69c3c01467610d0256eea013b/maturin-1.12.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.musllinux_1_1_aarch64.whl", hash = "sha256:39665d622dcc950ab17b9569e8cab84a4d64eea6a18b540a8b49e00c0f7dda02", size = 9536887, upload-time = 
"2026-02-16T13:56:25.658Z" }, + { url = "https://files.pythonhosted.org/packages/70/a5/387c7bced34f7fd8d08d399c6b1ac3d94d7ca50c9f87db9e1bc0dd8c8d08/maturin-1.12.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.musllinux_1_1_armv7l.whl", hash = "sha256:ca3b20bcc3aff115c9eaf97340e78bff58829ea1efa16764940dd0d858dcf6af", size = 9487394, upload-time = "2026-02-16T13:56:29.875Z" }, + { url = "https://files.pythonhosted.org/packages/6d/30/d5ae812c54a70d5d3a5b67b073e92d1d14d36675242e2d00e6a175fa6117/maturin-1.12.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.musllinux_1_1_ppc64le.whl", hash = "sha256:d1617989b4a5dc543fea6d23c28b2f07fadb2c726ff00fe959538ee71a301384", size = 12577754, upload-time = "2026-02-16T13:56:31.902Z" }, + { url = "https://files.pythonhosted.org/packages/84/f4/7baac2fa5324ccdc3f888ff5f6a793f3eb5a7805d89bc17a8bacbe9fc566/maturin-1.12.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6af778e7ee048612e55a1255488db7678741bea2ba881e66a19712f59f2534cb", size = 10375409, upload-time = "2026-02-16T13:56:23.316Z" }, + { url = "https://files.pythonhosted.org/packages/6f/ed/5680efbb1becb4f47da3ada8ea4eb6844d2fd91ae558287e1dd0871cb603/maturin-1.12.2-py3-none-manylinux_2_31_riscv64.musllinux_1_1_riscv64.whl", hash = "sha256:72aad9efe09a6392de9930f2bea80bfcc36fd98e18caa621f512571179c02d41", size = 10010584, upload-time = "2026-02-16T13:56:10.357Z" }, + { url = "https://files.pythonhosted.org/packages/86/20/7e27e07dd2270b707dd0124256cd46bef7c8832476b0aefa2ecd74835365/maturin-1.12.2-py3-none-win32.whl", hash = "sha256:9763d277e143409cf0ce309eb1a493fc4e1e75777364d67ccac39a161b51b5b0", size = 8483122, upload-time = "2026-02-16T13:56:12.606Z" }, + { url = "https://files.pythonhosted.org/packages/3b/6e/9cc0e19c9a336fbc1b9664c1a7955caa6d8fd510c0047ace9be66a33704a/maturin-1.12.2-py3-none-win_amd64.whl", hash = "sha256:c06d218931985035d7ab4d0211ba96027e1bc7e4b01a87c8c4e30a57790403ec", size = 9825577, upload-time = 
"2026-02-16T13:56:34.193Z" }, + { url = "https://files.pythonhosted.org/packages/2e/67/07ea2c991ca1a55c6b08cd821710736276af7a3e160e1f869ea5c41c78c3/maturin-1.12.2-py3-none-win_arm64.whl", hash = "sha256:a882cc80c241b1e2c27bd1acd713b09e9ac9266a3159cc1e34e8c7b77f049bba", size = 8522702, upload-time = "2026-02-16T13:56:14.42Z" }, +] + +[[package]] +name = "mdit-py-plugins" +version = "0.5.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markdown-it-py" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b2/fd/a756d36c0bfba5f6e39a1cdbdbfdd448dc02692467d83816dff4592a1ebc/mdit_py_plugins-0.5.0.tar.gz", hash = "sha256:f4918cb50119f50446560513a8e311d574ff6aaed72606ddae6d35716fe809c6", size = 44655, upload-time = "2025-08-11T07:25:49.083Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fb/86/dd6e5db36df29e76c7a7699123569a4a18c1623ce68d826ed96c62643cae/mdit_py_plugins-0.5.0-py3-none-any.whl", hash = "sha256:07a08422fc1936a5d26d146759e9155ea466e842f5ab2f7d2266dd084c8dab1f", size = 57205, upload-time = "2025-08-11T07:25:47.597Z" }, +] + +[[package]] +name = "mdurl" +version = "0.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729, upload-time = "2022-08-14T12:40:10.846Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" }, +] + +[[package]] +name = "myst-parser" +version = "5.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "docutils" }, + { name = "jinja2" }, + { name = "markdown-it-py" }, + { 
name = "mdit-py-plugins" }, + { name = "pyyaml" }, + { name = "sphinx", version = "9.0.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.12'" }, + { name = "sphinx", version = "9.1.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/33/fa/7b45eef11b7971f0beb29d27b7bfe0d747d063aa29e170d9edd004733c8a/myst_parser-5.0.0.tar.gz", hash = "sha256:f6f231452c56e8baa662cc352c548158f6a16fcbd6e3800fc594978002b94f3a", size = 98535, upload-time = "2026-01-15T09:08:18.036Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d3/ac/686789b9145413f1a61878c407210e41bfdb097976864e0913078b24098c/myst_parser-5.0.0-py3-none-any.whl", hash = "sha256:ab31e516024918296e169139072b81592336f2fef55b8986aa31c9f04b5f7211", size = 84533, upload-time = "2026-01-15T09:08:16.788Z" }, +] + +[[package]] +name = "numpy" +version = "2.4.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/57/fd/0005efbd0af48e55eb3c7208af93f2862d4b1a56cd78e84309a2d959208d/numpy-2.4.2.tar.gz", hash = "sha256:659a6107e31a83c4e33f763942275fd278b21d095094044eb35569e86a21ddae", size = 20723651, upload-time = "2026-01-31T23:13:10.135Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d3/44/71852273146957899753e69986246d6a176061ea183407e95418c2aa4d9a/numpy-2.4.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e7e88598032542bd49af7c4747541422884219056c268823ef6e5e89851c8825", size = 16955478, upload-time = "2026-01-31T23:10:25.623Z" }, + { url = "https://files.pythonhosted.org/packages/74/41/5d17d4058bd0cd96bcbd4d9ff0fb2e21f52702aab9a72e4a594efa18692f/numpy-2.4.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7edc794af8b36ca37ef5fcb5e0d128c7e0595c7b96a2318d1badb6fcd8ee86b1", size = 14965467, upload-time = "2026-01-31T23:10:28.186Z" }, + { url = 
"https://files.pythonhosted.org/packages/49/48/fb1ce8136c19452ed15f033f8aee91d5defe515094e330ce368a0647846f/numpy-2.4.2-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:6e9f61981ace1360e42737e2bae58b27bf28a1b27e781721047d84bd754d32e7", size = 5475172, upload-time = "2026-01-31T23:10:30.848Z" }, + { url = "https://files.pythonhosted.org/packages/40/a9/3feb49f17bbd1300dd2570432961f5c8a4ffeff1db6f02c7273bd020a4c9/numpy-2.4.2-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:cb7bbb88aa74908950d979eeaa24dbdf1a865e3c7e45ff0121d8f70387b55f73", size = 6805145, upload-time = "2026-01-31T23:10:32.352Z" }, + { url = "https://files.pythonhosted.org/packages/3f/39/fdf35cbd6d6e2fcad42fcf85ac04a85a0d0fbfbf34b30721c98d602fd70a/numpy-2.4.2-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4f069069931240b3fc703f1e23df63443dbd6390614c8c44a87d96cd0ec81eb1", size = 15966084, upload-time = "2026-01-31T23:10:34.502Z" }, + { url = "https://files.pythonhosted.org/packages/1b/46/6fa4ea94f1ddf969b2ee941290cca6f1bfac92b53c76ae5f44afe17ceb69/numpy-2.4.2-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c02ef4401a506fb60b411467ad501e1429a3487abca4664871d9ae0b46c8ba32", size = 16899477, upload-time = "2026-01-31T23:10:37.075Z" }, + { url = "https://files.pythonhosted.org/packages/09/a1/2a424e162b1a14a5bd860a464ab4e07513916a64ab1683fae262f735ccd2/numpy-2.4.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:2653de5c24910e49c2b106499803124dde62a5a1fe0eedeaecf4309a5f639390", size = 17323429, upload-time = "2026-01-31T23:10:39.704Z" }, + { url = "https://files.pythonhosted.org/packages/ce/a2/73014149ff250628df72c58204822ac01d768697913881aacf839ff78680/numpy-2.4.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1ae241bbfc6ae276f94a170b14785e561cb5e7f626b6688cf076af4110887413", size = 18635109, upload-time = "2026-01-31T23:10:41.924Z" }, + { url = 
"https://files.pythonhosted.org/packages/6c/0c/73e8be2f1accd56df74abc1c5e18527822067dced5ec0861b5bb882c2ce0/numpy-2.4.2-cp311-cp311-win32.whl", hash = "sha256:df1b10187212b198dd45fa943d8985a3c8cf854aed4923796e0e019e113a1bda", size = 6237915, upload-time = "2026-01-31T23:10:45.26Z" }, + { url = "https://files.pythonhosted.org/packages/76/ae/e0265e0163cf127c24c3969d29f1c4c64551a1e375d95a13d32eab25d364/numpy-2.4.2-cp311-cp311-win_amd64.whl", hash = "sha256:b9c618d56a29c9cb1c4da979e9899be7578d2e0b3c24d52079c166324c9e8695", size = 12607972, upload-time = "2026-01-31T23:10:47.021Z" }, + { url = "https://files.pythonhosted.org/packages/29/a5/c43029af9b8014d6ea157f192652c50042e8911f4300f8f6ed3336bf437f/numpy-2.4.2-cp311-cp311-win_arm64.whl", hash = "sha256:47c5a6ed21d9452b10227e5e8a0e1c22979811cad7dcc19d8e3e2fb8fa03f1a3", size = 10485763, upload-time = "2026-01-31T23:10:50.087Z" }, + { url = "https://files.pythonhosted.org/packages/51/6e/6f394c9c77668153e14d4da83bcc247beb5952f6ead7699a1a2992613bea/numpy-2.4.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:21982668592194c609de53ba4933a7471880ccbaadcc52352694a59ecc860b3a", size = 16667963, upload-time = "2026-01-31T23:10:52.147Z" }, + { url = "https://files.pythonhosted.org/packages/1f/f8/55483431f2b2fd015ae6ed4fe62288823ce908437ed49db5a03d15151678/numpy-2.4.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40397bda92382fcec844066efb11f13e1c9a3e2a8e8f318fb72ed8b6db9f60f1", size = 14693571, upload-time = "2026-01-31T23:10:54.789Z" }, + { url = "https://files.pythonhosted.org/packages/2f/20/18026832b1845cdc82248208dd929ca14c9d8f2bac391f67440707fff27c/numpy-2.4.2-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:b3a24467af63c67829bfaa61eecf18d5432d4f11992688537be59ecd6ad32f5e", size = 5203469, upload-time = "2026-01-31T23:10:57.343Z" }, + { url = "https://files.pythonhosted.org/packages/7d/33/2eb97c8a77daaba34eaa3fa7241a14ac5f51c46a6bd5911361b644c4a1e2/numpy-2.4.2-cp312-cp312-macosx_14_0_x86_64.whl", hash = 
"sha256:805cc8de9fd6e7a22da5aed858e0ab16be5a4db6c873dde1d7451c541553aa27", size = 6550820, upload-time = "2026-01-31T23:10:59.429Z" }, + { url = "https://files.pythonhosted.org/packages/b1/91/b97fdfd12dc75b02c44e26c6638241cc004d4079a0321a69c62f51470c4c/numpy-2.4.2-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6d82351358ffbcdcd7b686b90742a9b86632d6c1c051016484fa0b326a0a1548", size = 15663067, upload-time = "2026-01-31T23:11:01.291Z" }, + { url = "https://files.pythonhosted.org/packages/f5/c6/a18e59f3f0b8071cc85cbc8d80cd02d68aa9710170b2553a117203d46936/numpy-2.4.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9e35d3e0144137d9fdae62912e869136164534d64a169f86438bc9561b6ad49f", size = 16619782, upload-time = "2026-01-31T23:11:03.669Z" }, + { url = "https://files.pythonhosted.org/packages/b7/83/9751502164601a79e18847309f5ceec0b1446d7b6aa12305759b72cf98b2/numpy-2.4.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:adb6ed2ad29b9e15321d167d152ee909ec73395901b70936f029c3bc6d7f4460", size = 17013128, upload-time = "2026-01-31T23:11:05.913Z" }, + { url = "https://files.pythonhosted.org/packages/61/c4/c4066322256ec740acc1c8923a10047818691d2f8aec254798f3dd90f5f2/numpy-2.4.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:8906e71fd8afcb76580404e2a950caef2685df3d2a57fe82a86ac8d33cc007ba", size = 18345324, upload-time = "2026-01-31T23:11:08.248Z" }, + { url = "https://files.pythonhosted.org/packages/ab/af/6157aa6da728fa4525a755bfad486ae7e3f76d4c1864138003eb84328497/numpy-2.4.2-cp312-cp312-win32.whl", hash = "sha256:ec055f6dae239a6299cace477b479cca2fc125c5675482daf1dd886933a1076f", size = 5960282, upload-time = "2026-01-31T23:11:10.497Z" }, + { url = "https://files.pythonhosted.org/packages/92/0f/7ceaaeaacb40567071e94dbf2c9480c0ae453d5bb4f52bea3892c39dc83c/numpy-2.4.2-cp312-cp312-win_amd64.whl", hash = "sha256:209fae046e62d0ce6435fcfe3b1a10537e858249b3d9b05829e2a05218296a85", size = 12314210, upload-time = 
"2026-01-31T23:11:12.176Z" }, + { url = "https://files.pythonhosted.org/packages/2f/a3/56c5c604fae6dd40fa2ed3040d005fca97e91bd320d232ac9931d77ba13c/numpy-2.4.2-cp312-cp312-win_arm64.whl", hash = "sha256:fbde1b0c6e81d56f5dccd95dd4a711d9b95df1ae4009a60887e56b27e8d903fa", size = 10220171, upload-time = "2026-01-31T23:11:14.684Z" }, + { url = "https://files.pythonhosted.org/packages/a1/22/815b9fe25d1d7ae7d492152adbc7226d3eff731dffc38fe970589fcaaa38/numpy-2.4.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:25f2059807faea4b077a2b6837391b5d830864b3543627f381821c646f31a63c", size = 16663696, upload-time = "2026-01-31T23:11:17.516Z" }, + { url = "https://files.pythonhosted.org/packages/09/f0/817d03a03f93ba9c6c8993de509277d84e69f9453601915e4a69554102a1/numpy-2.4.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:bd3a7a9f5847d2fb8c2c6d1c862fa109c31a9abeca1a3c2bd5a64572955b2979", size = 14688322, upload-time = "2026-01-31T23:11:19.883Z" }, + { url = "https://files.pythonhosted.org/packages/da/b4/f805ab79293c728b9a99438775ce51885fd4f31b76178767cfc718701a39/numpy-2.4.2-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:8e4549f8a3c6d13d55041925e912bfd834285ef1dd64d6bc7d542583355e2e98", size = 5198157, upload-time = "2026-01-31T23:11:22.375Z" }, + { url = "https://files.pythonhosted.org/packages/74/09/826e4289844eccdcd64aac27d13b0fd3f32039915dd5b9ba01baae1f436c/numpy-2.4.2-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:aea4f66ff44dfddf8c2cffd66ba6538c5ec67d389285292fe428cb2c738c8aef", size = 6546330, upload-time = "2026-01-31T23:11:23.958Z" }, + { url = "https://files.pythonhosted.org/packages/19/fb/cbfdbfa3057a10aea5422c558ac57538e6acc87ec1669e666d32ac198da7/numpy-2.4.2-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c3cd545784805de05aafe1dde61752ea49a359ccba9760c1e5d1c88a93bbf2b7", size = 15660968, upload-time = "2026-01-31T23:11:25.713Z" }, + { url = 
"https://files.pythonhosted.org/packages/04/dc/46066ce18d01645541f0186877377b9371b8fa8017fa8262002b4ef22612/numpy-2.4.2-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d0d9b7c93578baafcbc5f0b83eaf17b79d345c6f36917ba0c67f45226911d499", size = 16607311, upload-time = "2026-01-31T23:11:28.117Z" }, + { url = "https://files.pythonhosted.org/packages/14/d9/4b5adfc39a43fa6bf918c6d544bc60c05236cc2f6339847fc5b35e6cb5b0/numpy-2.4.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f74f0f7779cc7ae07d1810aab8ac6b1464c3eafb9e283a40da7309d5e6e48fbb", size = 17012850, upload-time = "2026-01-31T23:11:30.888Z" }, + { url = "https://files.pythonhosted.org/packages/b7/20/adb6e6adde6d0130046e6fdfb7675cc62bc2f6b7b02239a09eb58435753d/numpy-2.4.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:c7ac672d699bf36275c035e16b65539931347d68b70667d28984c9fb34e07fa7", size = 18334210, upload-time = "2026-01-31T23:11:33.214Z" }, + { url = "https://files.pythonhosted.org/packages/78/0e/0a73b3dff26803a8c02baa76398015ea2a5434d9b8265a7898a6028c1591/numpy-2.4.2-cp313-cp313-win32.whl", hash = "sha256:8e9afaeb0beff068b4d9cd20d322ba0ee1cecfb0b08db145e4ab4dd44a6b5110", size = 5958199, upload-time = "2026-01-31T23:11:35.385Z" }, + { url = "https://files.pythonhosted.org/packages/43/bc/6352f343522fcb2c04dbaf94cb30cca6fd32c1a750c06ad6231b4293708c/numpy-2.4.2-cp313-cp313-win_amd64.whl", hash = "sha256:7df2de1e4fba69a51c06c28f5a3de36731eb9639feb8e1cf7e4a7b0daf4cf622", size = 12310848, upload-time = "2026-01-31T23:11:38.001Z" }, + { url = "https://files.pythonhosted.org/packages/6e/8d/6da186483e308da5da1cc6918ce913dcfe14ffde98e710bfeff2a6158d4e/numpy-2.4.2-cp313-cp313-win_arm64.whl", hash = "sha256:0fece1d1f0a89c16b03442eae5c56dc0be0c7883b5d388e0c03f53019a4bfd71", size = 10221082, upload-time = "2026-01-31T23:11:40.392Z" }, + { url = 
"https://files.pythonhosted.org/packages/25/a1/9510aa43555b44781968935c7548a8926274f815de42ad3997e9e83680dd/numpy-2.4.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:5633c0da313330fd20c484c78cdd3f9b175b55e1a766c4a174230c6b70ad8262", size = 14815866, upload-time = "2026-01-31T23:11:42.495Z" }, + { url = "https://files.pythonhosted.org/packages/36/30/6bbb5e76631a5ae46e7923dd16ca9d3f1c93cfa8d4ed79a129814a9d8db3/numpy-2.4.2-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:d9f64d786b3b1dd742c946c42d15b07497ed14af1a1f3ce840cce27daa0ce913", size = 5325631, upload-time = "2026-01-31T23:11:44.7Z" }, + { url = "https://files.pythonhosted.org/packages/46/00/3a490938800c1923b567b3a15cd17896e68052e2145d8662aaf3e1ffc58f/numpy-2.4.2-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:b21041e8cb6a1eb5312dd1d2f80a94d91efffb7a06b70597d44f1bd2dfc315ab", size = 6646254, upload-time = "2026-01-31T23:11:46.341Z" }, + { url = "https://files.pythonhosted.org/packages/d3/e9/fac0890149898a9b609caa5af7455a948b544746e4b8fe7c212c8edd71f8/numpy-2.4.2-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:00ab83c56211a1d7c07c25e3217ea6695e50a3e2f255053686b081dc0b091a82", size = 15720138, upload-time = "2026-01-31T23:11:48.082Z" }, + { url = "https://files.pythonhosted.org/packages/ea/5c/08887c54e68e1e28df53709f1893ce92932cc6f01f7c3d4dc952f61ffd4e/numpy-2.4.2-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2fb882da679409066b4603579619341c6d6898fc83a8995199d5249f986e8e8f", size = 16655398, upload-time = "2026-01-31T23:11:50.293Z" }, + { url = "https://files.pythonhosted.org/packages/4d/89/253db0fa0e66e9129c745e4ef25631dc37d5f1314dad2b53e907b8538e6d/numpy-2.4.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:66cb9422236317f9d44b67b4d18f44efe6e9c7f8794ac0462978513359461554", size = 17079064, upload-time = "2026-01-31T23:11:52.927Z" }, + { url = 
"https://files.pythonhosted.org/packages/2a/d5/cbade46ce97c59c6c3da525e8d95b7abe8a42974a1dc5c1d489c10433e88/numpy-2.4.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:0f01dcf33e73d80bd8dc0f20a71303abbafa26a19e23f6b68d1aa9990af90257", size = 18379680, upload-time = "2026-01-31T23:11:55.22Z" }, + { url = "https://files.pythonhosted.org/packages/40/62/48f99ae172a4b63d981babe683685030e8a3df4f246c893ea5c6ef99f018/numpy-2.4.2-cp313-cp313t-win32.whl", hash = "sha256:52b913ec40ff7ae845687b0b34d8d93b60cb66dcee06996dd5c99f2fc9328657", size = 6082433, upload-time = "2026-01-31T23:11:58.096Z" }, + { url = "https://files.pythonhosted.org/packages/07/38/e054a61cfe48ad9f1ed0d188e78b7e26859d0b60ef21cd9de4897cdb5326/numpy-2.4.2-cp313-cp313t-win_amd64.whl", hash = "sha256:5eea80d908b2c1f91486eb95b3fb6fab187e569ec9752ab7d9333d2e66bf2d6b", size = 12451181, upload-time = "2026-01-31T23:11:59.782Z" }, + { url = "https://files.pythonhosted.org/packages/6e/a4/a05c3a6418575e185dd84d0b9680b6bb2e2dc3e4202f036b7b4e22d6e9dc/numpy-2.4.2-cp313-cp313t-win_arm64.whl", hash = "sha256:fd49860271d52127d61197bb50b64f58454e9f578cb4b2c001a6de8b1f50b0b1", size = 10290756, upload-time = "2026-01-31T23:12:02.438Z" }, + { url = "https://files.pythonhosted.org/packages/18/88/b7df6050bf18fdcfb7046286c6535cabbdd2064a3440fca3f069d319c16e/numpy-2.4.2-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:444be170853f1f9d528428eceb55f12918e4fda5d8805480f36a002f1415e09b", size = 16663092, upload-time = "2026-01-31T23:12:04.521Z" }, + { url = "https://files.pythonhosted.org/packages/25/7a/1fee4329abc705a469a4afe6e69b1ef7e915117747886327104a8493a955/numpy-2.4.2-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:d1240d50adff70c2a88217698ca844723068533f3f5c5fa6ee2e3220e3bdb000", size = 14698770, upload-time = "2026-01-31T23:12:06.96Z" }, + { url = "https://files.pythonhosted.org/packages/fb/0b/f9e49ba6c923678ad5bc38181c08ac5e53b7a5754dbca8e581aa1a56b1ff/numpy-2.4.2-cp314-cp314-macosx_14_0_arm64.whl", hash = 
"sha256:7cdde6de52fb6664b00b056341265441192d1291c130e99183ec0d4b110ff8b1", size = 5208562, upload-time = "2026-01-31T23:12:09.632Z" }, + { url = "https://files.pythonhosted.org/packages/7d/12/d7de8f6f53f9bb76997e5e4c069eda2051e3fe134e9181671c4391677bb2/numpy-2.4.2-cp314-cp314-macosx_14_0_x86_64.whl", hash = "sha256:cda077c2e5b780200b6b3e09d0b42205a3d1c68f30c6dceb90401c13bff8fe74", size = 6543710, upload-time = "2026-01-31T23:12:11.969Z" }, + { url = "https://files.pythonhosted.org/packages/09/63/c66418c2e0268a31a4cf8a8b512685748200f8e8e8ec6c507ce14e773529/numpy-2.4.2-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d30291931c915b2ab5717c2974bb95ee891a1cf22ebc16a8006bd59cd210d40a", size = 15677205, upload-time = "2026-01-31T23:12:14.33Z" }, + { url = "https://files.pythonhosted.org/packages/5d/6c/7f237821c9642fb2a04d2f1e88b4295677144ca93285fd76eff3bcba858d/numpy-2.4.2-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bba37bc29d4d85761deed3954a1bc62be7cf462b9510b51d367b769a8c8df325", size = 16611738, upload-time = "2026-01-31T23:12:16.525Z" }, + { url = "https://files.pythonhosted.org/packages/c2/a7/39c4cdda9f019b609b5c473899d87abff092fc908cfe4d1ecb2fcff453b0/numpy-2.4.2-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:b2f0073ed0868db1dcd86e052d37279eef185b9c8db5bf61f30f46adac63c909", size = 17028888, upload-time = "2026-01-31T23:12:19.306Z" }, + { url = "https://files.pythonhosted.org/packages/da/b3/e84bb64bdfea967cc10950d71090ec2d84b49bc691df0025dddb7c26e8e3/numpy-2.4.2-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:7f54844851cdb630ceb623dcec4db3240d1ac13d4990532446761baede94996a", size = 18339556, upload-time = "2026-01-31T23:12:21.816Z" }, + { url = "https://files.pythonhosted.org/packages/88/f5/954a291bc1192a27081706862ac62bb5920fbecfbaa302f64682aa90beed/numpy-2.4.2-cp314-cp314-win32.whl", hash = "sha256:12e26134a0331d8dbd9351620f037ec470b7c75929cb8a1537f6bfe411152a1a", size = 6006899, 
upload-time = "2026-01-31T23:12:24.14Z" }, + { url = "https://files.pythonhosted.org/packages/05/cb/eff72a91b2efdd1bc98b3b8759f6a1654aa87612fc86e3d87d6fe4f948c4/numpy-2.4.2-cp314-cp314-win_amd64.whl", hash = "sha256:068cdb2d0d644cdb45670810894f6a0600797a69c05f1ac478e8d31670b8ee75", size = 12443072, upload-time = "2026-01-31T23:12:26.33Z" }, + { url = "https://files.pythonhosted.org/packages/37/75/62726948db36a56428fce4ba80a115716dc4fad6a3a4352487f8bb950966/numpy-2.4.2-cp314-cp314-win_arm64.whl", hash = "sha256:6ed0be1ee58eef41231a5c943d7d1375f093142702d5723ca2eb07db9b934b05", size = 10494886, upload-time = "2026-01-31T23:12:28.488Z" }, + { url = "https://files.pythonhosted.org/packages/36/2f/ee93744f1e0661dc267e4b21940870cabfae187c092e1433b77b09b50ac4/numpy-2.4.2-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:98f16a80e917003a12c0580f97b5f875853ebc33e2eaa4bccfc8201ac6869308", size = 14818567, upload-time = "2026-01-31T23:12:30.709Z" }, + { url = "https://files.pythonhosted.org/packages/a7/24/6535212add7d76ff938d8bdc654f53f88d35cddedf807a599e180dcb8e66/numpy-2.4.2-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:20abd069b9cda45874498b245c8015b18ace6de8546bf50dfa8cea1696ed06ef", size = 5328372, upload-time = "2026-01-31T23:12:32.962Z" }, + { url = "https://files.pythonhosted.org/packages/5e/9d/c48f0a035725f925634bf6b8994253b43f2047f6778a54147d7e213bc5a7/numpy-2.4.2-cp314-cp314t-macosx_14_0_x86_64.whl", hash = "sha256:e98c97502435b53741540a5717a6749ac2ada901056c7db951d33e11c885cc7d", size = 6649306, upload-time = "2026-01-31T23:12:34.797Z" }, + { url = "https://files.pythonhosted.org/packages/81/05/7c73a9574cd4a53a25907bad38b59ac83919c0ddc8234ec157f344d57d9a/numpy-2.4.2-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:da6cad4e82cb893db4b69105c604d805e0c3ce11501a55b5e9f9083b47d2ffe8", size = 15722394, upload-time = "2026-01-31T23:12:36.565Z" }, + { url = 
"https://files.pythonhosted.org/packages/35/fa/4de10089f21fc7d18442c4a767ab156b25c2a6eaf187c0db6d9ecdaeb43f/numpy-2.4.2-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9e4424677ce4b47fe73c8b5556d876571f7c6945d264201180db2dc34f676ab5", size = 16653343, upload-time = "2026-01-31T23:12:39.188Z" }, + { url = "https://files.pythonhosted.org/packages/b8/f9/d33e4ffc857f3763a57aa85650f2e82486832d7492280ac21ba9efda80da/numpy-2.4.2-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:2b8f157c8a6f20eb657e240f8985cc135598b2b46985c5bccbde7616dc9c6b1e", size = 17078045, upload-time = "2026-01-31T23:12:42.041Z" }, + { url = "https://files.pythonhosted.org/packages/c8/b8/54bdb43b6225badbea6389fa038c4ef868c44f5890f95dd530a218706da3/numpy-2.4.2-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:5daf6f3914a733336dab21a05cdec343144600e964d2fcdabaac0c0269874b2a", size = 18380024, upload-time = "2026-01-31T23:12:44.331Z" }, + { url = "https://files.pythonhosted.org/packages/a5/55/6e1a61ded7af8df04016d81b5b02daa59f2ea9252ee0397cb9f631efe9e5/numpy-2.4.2-cp314-cp314t-win32.whl", hash = "sha256:8c50dd1fc8826f5b26a5ee4d77ca55d88a895f4e4819c7ecc2a9f5905047a443", size = 6153937, upload-time = "2026-01-31T23:12:47.229Z" }, + { url = "https://files.pythonhosted.org/packages/45/aa/fa6118d1ed6d776b0983f3ceac9b1a5558e80df9365b1c3aa6d42bf9eee4/numpy-2.4.2-cp314-cp314t-win_amd64.whl", hash = "sha256:fcf92bee92742edd401ba41135185866f7026c502617f422eb432cfeca4fe236", size = 12631844, upload-time = "2026-01-31T23:12:48.997Z" }, + { url = "https://files.pythonhosted.org/packages/32/0a/2ec5deea6dcd158f254a7b372fb09cfba5719419c8d66343bab35237b3fb/numpy-2.4.2-cp314-cp314t-win_arm64.whl", hash = "sha256:1f92f53998a17265194018d1cc321b2e96e900ca52d54c7c77837b71b9465181", size = 10565379, upload-time = "2026-01-31T23:12:51.345Z" }, + { url = 
"https://files.pythonhosted.org/packages/f4/f8/50e14d36d915ef64d8f8bc4a087fc8264d82c785eda6711f80ab7e620335/numpy-2.4.2-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:89f7268c009bc492f506abd6f5265defa7cb3f7487dc21d357c3d290add45082", size = 16833179, upload-time = "2026-01-31T23:12:53.5Z" }, + { url = "https://files.pythonhosted.org/packages/17/17/809b5cad63812058a8189e91a1e2d55a5a18fd04611dbad244e8aeae465c/numpy-2.4.2-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:e6dee3bb76aa4009d5a912180bf5b2de012532998d094acee25d9cb8dee3e44a", size = 14889755, upload-time = "2026-01-31T23:12:55.933Z" }, + { url = "https://files.pythonhosted.org/packages/3e/ea/181b9bcf7627fc8371720316c24db888dcb9829b1c0270abf3d288b2e29b/numpy-2.4.2-pp311-pypy311_pp73-macosx_14_0_arm64.whl", hash = "sha256:cd2bd2bbed13e213d6b55dc1d035a4f91748a7d3edc9480c13898b0353708920", size = 5399500, upload-time = "2026-01-31T23:12:58.671Z" }, + { url = "https://files.pythonhosted.org/packages/33/9f/413adf3fc955541ff5536b78fcf0754680b3c6d95103230252a2c9408d23/numpy-2.4.2-pp311-pypy311_pp73-macosx_14_0_x86_64.whl", hash = "sha256:cf28c0c1d4c4bf00f509fa7eb02c58d7caf221b50b467bcb0d9bbf1584d5c821", size = 6714252, upload-time = "2026-01-31T23:13:00.518Z" }, + { url = "https://files.pythonhosted.org/packages/91/da/643aad274e29ccbdf42ecd94dafe524b81c87bcb56b83872d54827f10543/numpy-2.4.2-pp311-pypy311_pp73-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e04ae107ac591763a47398bb45b568fc38f02dbc4aa44c063f67a131f99346cb", size = 15797142, upload-time = "2026-01-31T23:13:02.219Z" }, + { url = "https://files.pythonhosted.org/packages/66/27/965b8525e9cb5dc16481b30a1b3c21e50c7ebf6e9dbd48d0c4d0d5089c7e/numpy-2.4.2-pp311-pypy311_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:602f65afdef699cda27ec0b9224ae5dc43e328f4c24c689deaf77133dbee74d0", size = 16727979, upload-time = "2026-01-31T23:13:04.62Z" }, + { url = 
"https://files.pythonhosted.org/packages/de/e5/b7d20451657664b07986c2f6e3be564433f5dcaf3482d68eaecd79afaf03/numpy-2.4.2-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:be71bf1edb48ebbbf7f6337b5bfd2f895d1902f6335a5830b20141fc126ffba0", size = 12502577, upload-time = "2026-01-31T23:13:07.08Z" }, +] + +[[package]] +name = "packaging" +version = "26.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/65/ee/299d360cdc32edc7d2cf530f3accf79c4fca01e96ffc950d8a52213bd8e4/packaging-26.0.tar.gz", hash = "sha256:00243ae351a257117b6a241061796684b084ed1c516a08c48a3f7e147a9d80b4", size = 143416, upload-time = "2026-01-21T20:50:39.064Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/b9/c538f279a4e237a006a2c98387d081e9eb060d203d8ed34467cc0f0b9b53/packaging-26.0-py3-none-any.whl", hash = "sha256:b36f1fef9334a5588b4166f8bcd26a14e521f2b55e6b9de3aaa80d3ff7a37529", size = 74366, upload-time = "2026-01-21T20:50:37.788Z" }, +] + +[[package]] +name = "pyarrow" +version = "23.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/88/22/134986a4cc224d593c1afde5494d18ff629393d74cc2eddb176669f234a4/pyarrow-23.0.1.tar.gz", hash = "sha256:b8c5873e33440b2bc2f4a79d2b47017a89c5a24116c055625e6f2ee50523f019", size = 1167336, upload-time = "2026-02-16T10:14:12.39Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b0/41/8e6b6ef7e225d4ceead8459427a52afdc23379768f54dd3566014d7618c1/pyarrow-23.0.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:6f0147ee9e0386f519c952cc670eb4a8b05caa594eeffe01af0e25f699e4e9bb", size = 34302230, upload-time = "2026-02-16T10:09:03.859Z" }, + { url = "https://files.pythonhosted.org/packages/bf/4a/1472c00392f521fea03ae93408bf445cc7bfa1ab81683faf9bc188e36629/pyarrow-23.0.1-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:0ae6e17c828455b6265d590100c295193f93cc5675eb0af59e49dbd00d2de350", size = 35850050, upload-time = 
"2026-02-16T10:09:11.877Z" }, + { url = "https://files.pythonhosted.org/packages/0c/b2/bd1f2f05ded56af7f54d702c8364c9c43cd6abb91b0e9933f3d77b4f4132/pyarrow-23.0.1-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:fed7020203e9ef273360b9e45be52a2a47d3103caf156a30ace5247ffb51bdbd", size = 44491918, upload-time = "2026-02-16T10:09:18.144Z" }, + { url = "https://files.pythonhosted.org/packages/0b/62/96459ef5b67957eac38a90f541d1c28833d1b367f014a482cb63f3b7cd2d/pyarrow-23.0.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:26d50dee49d741ac0e82185033488d28d35be4d763ae6f321f97d1140eb7a0e9", size = 47562811, upload-time = "2026-02-16T10:09:25.792Z" }, + { url = "https://files.pythonhosted.org/packages/7d/94/1170e235add1f5f45a954e26cd0e906e7e74e23392dcb560de471f7366ec/pyarrow-23.0.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:3c30143b17161310f151f4a2bcfe41b5ff744238c1039338779424e38579d701", size = 48183766, upload-time = "2026-02-16T10:09:34.645Z" }, + { url = "https://files.pythonhosted.org/packages/0e/2d/39a42af4570377b99774cdb47f63ee6c7da7616bd55b3d5001aa18edfe4f/pyarrow-23.0.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:db2190fa79c80a23fdd29fef4b8992893f024ae7c17d2f5f4db7171fa30c2c78", size = 50607669, upload-time = "2026-02-16T10:09:44.153Z" }, + { url = "https://files.pythonhosted.org/packages/00/ca/db94101c187f3df742133ac837e93b1f269ebdac49427f8310ee40b6a58f/pyarrow-23.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:f00f993a8179e0e1c9713bcc0baf6d6c01326a406a9c23495ec1ba9c9ebf2919", size = 27527698, upload-time = "2026-02-16T10:09:50.263Z" }, + { url = "https://files.pythonhosted.org/packages/9a/4b/4166bb5abbfe6f750fc60ad337c43ecf61340fa52ab386da6e8dbf9e63c4/pyarrow-23.0.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:f4b0dbfa124c0bb161f8b5ebb40f1a680b70279aa0c9901d44a2b5a20806039f", size = 34214575, upload-time = "2026-02-16T10:09:56.225Z" }, + { url = 
"https://files.pythonhosted.org/packages/e1/da/3f941e3734ac8088ea588b53e860baeddac8323ea40ce22e3d0baa865cc9/pyarrow-23.0.1-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:7707d2b6673f7de054e2e83d59f9e805939038eebe1763fe811ee8fa5c0cd1a7", size = 35832540, upload-time = "2026-02-16T10:10:03.428Z" }, + { url = "https://files.pythonhosted.org/packages/88/7c/3d841c366620e906d54430817531b877ba646310296df42ef697308c2705/pyarrow-23.0.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:86ff03fb9f1a320266e0de855dee4b17da6794c595d207f89bba40d16b5c78b9", size = 44470940, upload-time = "2026-02-16T10:10:10.704Z" }, + { url = "https://files.pythonhosted.org/packages/2c/a5/da83046273d990f256cb79796a190bbf7ec999269705ddc609403f8c6b06/pyarrow-23.0.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:813d99f31275919c383aab17f0f455a04f5a429c261cc411b1e9a8f5e4aaaa05", size = 47586063, upload-time = "2026-02-16T10:10:17.95Z" }, + { url = "https://files.pythonhosted.org/packages/5b/3c/b7d2ebcff47a514f47f9da1e74b7949138c58cfeb108cdd4ee62f43f0cf3/pyarrow-23.0.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:bf5842f960cddd2ef757d486041d57c96483efc295a8c4a0e20e704cbbf39c67", size = 48173045, upload-time = "2026-02-16T10:10:25.363Z" }, + { url = "https://files.pythonhosted.org/packages/43/b2/b40961262213beaba6acfc88698eb773dfce32ecdf34d19291db94c2bd73/pyarrow-23.0.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:564baf97c858ecc03ec01a41062e8f4698abc3e6e2acd79c01c2e97880a19730", size = 50621741, upload-time = "2026-02-16T10:10:33.477Z" }, + { url = "https://files.pythonhosted.org/packages/f6/70/1fdda42d65b28b078e93d75d371b2185a61da89dda4def8ba6ba41ebdeb4/pyarrow-23.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:07deae7783782ac7250989a7b2ecde9b3c343a643f82e8a4df03d93b633006f0", size = 27620678, upload-time = "2026-02-16T10:10:39.31Z" }, + { url = 
"https://files.pythonhosted.org/packages/47/10/2cbe4c6f0fb83d2de37249567373d64327a5e4d8db72f486db42875b08f6/pyarrow-23.0.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:6b8fda694640b00e8af3c824f99f789e836720aa8c9379fb435d4c4953a756b8", size = 34210066, upload-time = "2026-02-16T10:10:45.487Z" }, + { url = "https://files.pythonhosted.org/packages/cb/4f/679fa7e84dadbaca7a65f7cdba8d6c83febbd93ca12fa4adf40ba3b6362b/pyarrow-23.0.1-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:8ff51b1addc469b9444b7c6f3548e19dc931b172ab234e995a60aea9f6e6025f", size = 35825526, upload-time = "2026-02-16T10:10:52.266Z" }, + { url = "https://files.pythonhosted.org/packages/f9/63/d2747d930882c9d661e9398eefc54f15696547b8983aaaf11d4a2e8b5426/pyarrow-23.0.1-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:71c5be5cbf1e1cb6169d2a0980850bccb558ddc9b747b6206435313c47c37677", size = 44473279, upload-time = "2026-02-16T10:11:01.557Z" }, + { url = "https://files.pythonhosted.org/packages/b3/93/10a48b5e238de6d562a411af6467e71e7aedbc9b87f8d3a35f1560ae30fb/pyarrow-23.0.1-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:9b6f4f17b43bc39d56fec96e53fe89d94bac3eb134137964371b45352d40d0c2", size = 47585798, upload-time = "2026-02-16T10:11:09.401Z" }, + { url = "https://files.pythonhosted.org/packages/5c/20/476943001c54ef078dbf9542280e22741219a184a0632862bca4feccd666/pyarrow-23.0.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9fc13fc6c403d1337acab46a2c4346ca6c9dec5780c3c697cf8abfd5e19b6b37", size = 48179446, upload-time = "2026-02-16T10:11:17.781Z" }, + { url = "https://files.pythonhosted.org/packages/4b/b6/5dd0c47b335fcd8edba9bfab78ad961bd0fd55ebe53468cc393f45e0be60/pyarrow-23.0.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5c16ed4f53247fa3ffb12a14d236de4213a4415d127fe9cebed33d51671113e2", size = 50623972, upload-time = "2026-02-16T10:11:26.185Z" }, + { url = 
"https://files.pythonhosted.org/packages/d5/09/a532297c9591a727d67760e2e756b83905dd89adb365a7f6e9c72578bcc1/pyarrow-23.0.1-cp313-cp313-win_amd64.whl", hash = "sha256:cecfb12ef629cf6be0b1887f9f86463b0dd3dc3195ae6224e74006be4736035a", size = 27540749, upload-time = "2026-02-16T10:12:23.297Z" }, + { url = "https://files.pythonhosted.org/packages/a5/8e/38749c4b1303e6ae76b3c80618f84861ae0c55dd3c2273842ea6f8258233/pyarrow-23.0.1-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:29f7f7419a0e30264ea261fdc0e5fe63ce5a6095003db2945d7cd78df391a7e1", size = 34471544, upload-time = "2026-02-16T10:11:32.535Z" }, + { url = "https://files.pythonhosted.org/packages/a3/73/f237b2bc8c669212f842bcfd842b04fc8d936bfc9d471630569132dc920d/pyarrow-23.0.1-cp313-cp313t-macosx_12_0_x86_64.whl", hash = "sha256:33d648dc25b51fd8055c19e4261e813dfc4d2427f068bcecc8b53d01b81b0500", size = 35949911, upload-time = "2026-02-16T10:11:39.813Z" }, + { url = "https://files.pythonhosted.org/packages/0c/86/b912195eee0903b5611bf596833def7d146ab2d301afeb4b722c57ffc966/pyarrow-23.0.1-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:cd395abf8f91c673dd3589cadc8cc1ee4e8674fa61b2e923c8dd215d9c7d1f41", size = 44520337, upload-time = "2026-02-16T10:11:47.764Z" }, + { url = "https://files.pythonhosted.org/packages/69/c2/f2a717fb824f62d0be952ea724b4f6f9372a17eed6f704b5c9526f12f2f1/pyarrow-23.0.1-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:00be9576d970c31defb5c32eb72ef585bf600ef6d0a82d5eccaae96639cf9d07", size = 47548944, upload-time = "2026-02-16T10:11:56.607Z" }, + { url = "https://files.pythonhosted.org/packages/84/a7/90007d476b9f0dc308e3bc57b832d004f848fd6c0da601375d20d92d1519/pyarrow-23.0.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:c2139549494445609f35a5cda4eb94e2c9e4d704ce60a095b342f82460c73a83", size = 48236269, upload-time = "2026-02-16T10:12:04.47Z" }, + { url = 
"https://files.pythonhosted.org/packages/b0/3f/b16fab3e77709856eb6ac328ce35f57a6d4a18462c7ca5186ef31b45e0e0/pyarrow-23.0.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:7044b442f184d84e2351e5084600f0d7343d6117aabcbc1ac78eb1ae11eb4125", size = 50604794, upload-time = "2026-02-16T10:12:11.797Z" }, + { url = "https://files.pythonhosted.org/packages/e9/a1/22df0620a9fac31d68397a75465c344e83c3dfe521f7612aea33e27ab6c0/pyarrow-23.0.1-cp313-cp313t-win_amd64.whl", hash = "sha256:a35581e856a2fafa12f3f54fce4331862b1cfb0bef5758347a858a4aa9d6bae8", size = 27660642, upload-time = "2026-02-16T10:12:17.746Z" }, + { url = "https://files.pythonhosted.org/packages/8d/1b/6da9a89583ce7b23ac611f183ae4843cd3a6cf54f079549b0e8c14031e73/pyarrow-23.0.1-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:5df1161da23636a70838099d4aaa65142777185cc0cdba4037a18cee7d8db9ca", size = 34238755, upload-time = "2026-02-16T10:12:32.819Z" }, + { url = "https://files.pythonhosted.org/packages/ae/b5/d58a241fbe324dbaeb8df07be6af8752c846192d78d2272e551098f74e88/pyarrow-23.0.1-cp314-cp314-macosx_12_0_x86_64.whl", hash = "sha256:fa8e51cb04b9f8c9c5ace6bab63af9a1f88d35c0d6cbf53e8c17c098552285e1", size = 35847826, upload-time = "2026-02-16T10:12:38.949Z" }, + { url = "https://files.pythonhosted.org/packages/54/a5/8cbc83f04aba433ca7b331b38f39e000efd9f0c7ce47128670e737542996/pyarrow-23.0.1-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:0b95a3994f015be13c63148fef8832e8a23938128c185ee951c98908a696e0eb", size = 44536859, upload-time = "2026-02-16T10:12:45.467Z" }, + { url = "https://files.pythonhosted.org/packages/36/2e/c0f017c405fcdc252dbccafbe05e36b0d0eb1ea9a958f081e01c6972927f/pyarrow-23.0.1-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:4982d71350b1a6e5cfe1af742c53dfb759b11ce14141870d05d9e540d13bc5d1", size = 47614443, upload-time = "2026-02-16T10:12:55.525Z" }, + { url = 
"https://files.pythonhosted.org/packages/af/6b/2314a78057912f5627afa13ba43809d9d653e6630859618b0fd81a4e0759/pyarrow-23.0.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:c250248f1fe266db627921c89b47b7c06fee0489ad95b04d50353537d74d6886", size = 48232991, upload-time = "2026-02-16T10:13:04.729Z" }, + { url = "https://files.pythonhosted.org/packages/40/f2/1bcb1d3be3460832ef3370d621142216e15a2c7c62602a4ea19ec240dd64/pyarrow-23.0.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:5f4763b83c11c16e5f4c15601ba6dfa849e20723b46aa2617cb4bffe8768479f", size = 50645077, upload-time = "2026-02-16T10:13:14.147Z" }, + { url = "https://files.pythonhosted.org/packages/eb/3f/b1da7b61cd66566a4d4c8383d376c606d1c34a906c3f1cb35c479f59d1aa/pyarrow-23.0.1-cp314-cp314-win_amd64.whl", hash = "sha256:3a4c85ef66c134161987c17b147d6bffdca4566f9a4c1d81a0a01cdf08414ea5", size = 28234271, upload-time = "2026-02-16T10:14:09.397Z" }, + { url = "https://files.pythonhosted.org/packages/b5/78/07f67434e910a0f7323269be7bfbf58699bd0c1d080b18a1ab49ba943fe8/pyarrow-23.0.1-cp314-cp314t-macosx_12_0_arm64.whl", hash = "sha256:17cd28e906c18af486a499422740298c52d7c6795344ea5002a7720b4eadf16d", size = 34488692, upload-time = "2026-02-16T10:13:21.541Z" }, + { url = "https://files.pythonhosted.org/packages/50/76/34cf7ae93ece1f740a04910d9f7e80ba166b9b4ab9596a953e9e62b90fe1/pyarrow-23.0.1-cp314-cp314t-macosx_12_0_x86_64.whl", hash = "sha256:76e823d0e86b4fb5e1cf4a58d293036e678b5a4b03539be933d3b31f9406859f", size = 35964383, upload-time = "2026-02-16T10:13:28.63Z" }, + { url = "https://files.pythonhosted.org/packages/46/90/459b827238936d4244214be7c684e1b366a63f8c78c380807ae25ed92199/pyarrow-23.0.1-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:a62e1899e3078bf65943078b3ad2a6ddcacf2373bc06379aac61b1e548a75814", size = 44538119, upload-time = "2026-02-16T10:13:35.506Z" }, + { url = 
"https://files.pythonhosted.org/packages/28/a1/93a71ae5881e99d1f9de1d4554a87be37da11cd6b152239fb5bd924fdc64/pyarrow-23.0.1-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:df088e8f640c9fae3b1f495b3c64755c4e719091caf250f3a74d095ddf3c836d", size = 47571199, upload-time = "2026-02-16T10:13:42.504Z" }, + { url = "https://files.pythonhosted.org/packages/88/a3/d2c462d4ef313521eaf2eff04d204ac60775263f1fb08c374b543f79f610/pyarrow-23.0.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:46718a220d64677c93bc243af1d44b55998255427588e400677d7192671845c7", size = 48259435, upload-time = "2026-02-16T10:13:49.226Z" }, + { url = "https://files.pythonhosted.org/packages/cc/f1/11a544b8c3d38a759eb3fbb022039117fd633e9a7b19e4841cc3da091915/pyarrow-23.0.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:a09f3876e87f48bc2f13583ab551f0379e5dfb83210391e68ace404181a20690", size = 50629149, upload-time = "2026-02-16T10:13:57.238Z" }, + { url = "https://files.pythonhosted.org/packages/50/f2/c0e76a0b451ffdf0cf788932e182758eb7558953f4f27f1aff8e2518b653/pyarrow-23.0.1-cp314-cp314t-win_amd64.whl", hash = "sha256:527e8d899f14bd15b740cd5a54ad56b7f98044955373a17179d5956ddb93d9ce", size = 28365807, upload-time = "2026-02-16T10:14:03.892Z" }, +] + +[[package]] +name = "pycparser" +version = "3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/1b/7d/92392ff7815c21062bea51aa7b87d45576f649f16458d78b7cf94b9ab2e6/pycparser-3.0.tar.gz", hash = "sha256:600f49d217304a5902ac3c37e1281c9fe94e4d0489de643a9504c5cdfdfc6b29", size = 103492, upload-time = "2026-01-21T14:26:51.89Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0c/c3/44f3fbbfa403ea2a7c779186dc20772604442dde72947e7d01069cbe98e3/pycparser-3.0-py3-none-any.whl", hash = "sha256:b727414169a36b7d524c1c3e31839a521725078d7b2ff038656844266160a992", size = 48172, upload-time = "2026-01-21T14:26:50.693Z" }, +] + +[[package]] +name = "pydata-sphinx-theme" +version = 
"0.16.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "accessible-pygments" }, + { name = "babel" }, + { name = "beautifulsoup4" }, + { name = "docutils" }, + { name = "pygments" }, + { name = "sphinx", version = "9.0.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.12'" }, + { name = "sphinx", version = "9.1.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/00/20/bb50f9de3a6de69e6abd6b087b52fa2418a0418b19597601605f855ad044/pydata_sphinx_theme-0.16.1.tar.gz", hash = "sha256:a08b7f0b7f70387219dc659bff0893a7554d5eb39b59d3b8ef37b8401b7642d7", size = 2412693, upload-time = "2024-12-17T10:53:39.537Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e2/0d/8ba33fa83a7dcde13eb3c1c2a0c1cc29950a048bfed6d9b0d8b6bd710b4c/pydata_sphinx_theme-0.16.1-py3-none-any.whl", hash = "sha256:225331e8ac4b32682c18fcac5a57a6f717c4e632cea5dd0e247b55155faeccde", size = 6723264, upload-time = "2024-12-17T10:53:35.645Z" }, +] + +[[package]] +name = "pygithub" +version = "2.8.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyjwt", extra = ["crypto"] }, + { name = "pynacl" }, + { name = "requests" }, + { name = "typing-extensions" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c1/74/e560bdeffea72ecb26cff27f0fad548bbff5ecc51d6a155311ea7f9e4c4c/pygithub-2.8.1.tar.gz", hash = "sha256:341b7c78521cb07324ff670afd1baa2bf5c286f8d9fd302c1798ba594a5400c9", size = 2246994, upload-time = "2025-09-02T17:41:54.674Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/07/ba/7049ce39f653f6140aac4beb53a5aaf08b4407b6a3019aae394c1c5244ff/pygithub-2.8.1-py3-none-any.whl", hash = "sha256:23a0a5bca93baef082e03411bf0ce27204c32be8bfa7abc92fe4a3e132936df0", size = 432709, upload-time = 
"2025-09-02T17:41:52.947Z" }, +] + +[[package]] +name = "pygments" +version = "2.19.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, +] + +[[package]] +name = "pyjwt" +version = "2.11.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/5c/5a/b46fa56bf322901eee5b0454a34343cdbdae202cd421775a8ee4e42fd519/pyjwt-2.11.0.tar.gz", hash = "sha256:35f95c1f0fbe5d5ba6e43f00271c275f7a1a4db1dab27bf708073b75318ea623", size = 98019, upload-time = "2026-01-30T19:59:55.694Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6f/01/c26ce75ba460d5cd503da9e13b21a33804d38c2165dec7b716d06b13010c/pyjwt-2.11.0-py3-none-any.whl", hash = "sha256:94a6bde30eb5c8e04fee991062b534071fd1439ef58d2adc9ccb823e7bcd0469", size = 28224, upload-time = "2026-01-30T19:59:54.539Z" }, +] + +[package.optional-dependencies] +crypto = [ + { name = "cryptography" }, +] + +[[package]] +name = "pynacl" +version = "1.6.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cffi", marker = "platform_python_implementation != 'PyPy'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d9/9a/4019b524b03a13438637b11538c82781a5eda427394380381af8f04f467a/pynacl-1.6.2.tar.gz", hash = "sha256:018494d6d696ae03c7e656e5e74cdfd8ea1326962cc401bcf018f1ed8436811c", size = 3511692, upload-time = "2026-01-01T17:48:10.851Z" } 
+wheels = [ + { url = "https://files.pythonhosted.org/packages/4b/79/0e3c34dc3c4671f67d251c07aa8eb100916f250ee470df230b0ab89551b4/pynacl-1.6.2-cp314-cp314t-macosx_10_10_universal2.whl", hash = "sha256:622d7b07cc5c02c666795792931b50c91f3ce3c2649762efb1ef0d5684c81594", size = 390064, upload-time = "2026-01-01T17:31:57.264Z" }, + { url = "https://files.pythonhosted.org/packages/eb/1c/23a26e931736e13b16483795c8a6b2f641bf6a3d5238c22b070a5112722c/pynacl-1.6.2-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d071c6a9a4c94d79eb665db4ce5cedc537faf74f2355e4d502591d850d3913c0", size = 809370, upload-time = "2026-01-01T17:31:59.198Z" }, + { url = "https://files.pythonhosted.org/packages/87/74/8d4b718f8a22aea9e8dcc8b95deb76d4aae380e2f5b570cc70b5fd0a852d/pynacl-1.6.2-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fe9847ca47d287af41e82be1dd5e23023d3c31a951da134121ab02e42ac218c9", size = 1408304, upload-time = "2026-01-01T17:32:01.162Z" }, + { url = "https://files.pythonhosted.org/packages/fd/73/be4fdd3a6a87fe8a4553380c2b47fbd1f7f58292eb820902f5c8ac7de7b0/pynacl-1.6.2-cp314-cp314t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:04316d1fc625d860b6c162fff704eb8426b1a8bcd3abacea11142cbd99a6b574", size = 844871, upload-time = "2026-01-01T17:32:02.824Z" }, + { url = "https://files.pythonhosted.org/packages/55/ad/6efc57ab75ee4422e96b5f2697d51bbcf6cdcc091e66310df91fbdc144a8/pynacl-1.6.2-cp314-cp314t-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:44081faff368d6c5553ccf55322ef2819abb40e25afaec7e740f159f74813634", size = 1446356, upload-time = "2026-01-01T17:32:04.452Z" }, + { url = "https://files.pythonhosted.org/packages/78/b7/928ee9c4779caa0a915844311ab9fb5f99585621c5d6e4574538a17dca07/pynacl-1.6.2-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:a9f9932d8d2811ce1a8ffa79dcbdf3970e7355b5c8eb0c1a881a57e7f7d96e88", size = 826814, upload-time = "2026-01-01T17:32:06.078Z" }, + { url = 
"https://files.pythonhosted.org/packages/f7/a9/1bdba746a2be20f8809fee75c10e3159d75864ef69c6b0dd168fc60e485d/pynacl-1.6.2-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:bc4a36b28dd72fb4845e5d8f9760610588a96d5a51f01d84d8c6ff9849968c14", size = 1411742, upload-time = "2026-01-01T17:32:07.651Z" }, + { url = "https://files.pythonhosted.org/packages/f3/2f/5e7ea8d85f9f3ea5b6b87db1d8388daa3587eed181bdeb0306816fdbbe79/pynacl-1.6.2-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:3bffb6d0f6becacb6526f8f42adfb5efb26337056ee0831fb9a7044d1a964444", size = 801714, upload-time = "2026-01-01T17:32:09.558Z" }, + { url = "https://files.pythonhosted.org/packages/06/ea/43fe2f7eab5f200e40fb10d305bf6f87ea31b3bbc83443eac37cd34a9e1e/pynacl-1.6.2-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:2fef529ef3ee487ad8113d287a593fa26f48ee3620d92ecc6f1d09ea38e0709b", size = 1372257, upload-time = "2026-01-01T17:32:11.026Z" }, + { url = "https://files.pythonhosted.org/packages/4d/54/c9ea116412788629b1347e415f72195c25eb2f3809b2d3e7b25f5c79f13a/pynacl-1.6.2-cp314-cp314t-win32.whl", hash = "sha256:a84bf1c20339d06dc0c85d9aea9637a24f718f375d861b2668b2f9f96fa51145", size = 231319, upload-time = "2026-01-01T17:32:12.46Z" }, + { url = "https://files.pythonhosted.org/packages/ce/04/64e9d76646abac2dccf904fccba352a86e7d172647557f35b9fe2a5ee4a1/pynacl-1.6.2-cp314-cp314t-win_amd64.whl", hash = "sha256:320ef68a41c87547c91a8b58903c9caa641ab01e8512ce291085b5fe2fcb7590", size = 244044, upload-time = "2026-01-01T17:32:13.781Z" }, + { url = "https://files.pythonhosted.org/packages/33/33/7873dc161c6a06f43cda13dec67b6fe152cb2f982581151956fa5e5cdb47/pynacl-1.6.2-cp314-cp314t-win_arm64.whl", hash = "sha256:d29bfe37e20e015a7d8b23cfc8bd6aa7909c92a1b8f41ee416bbb3e79ef182b2", size = 188740, upload-time = "2026-01-01T17:32:15.083Z" }, + { url = "https://files.pythonhosted.org/packages/be/7b/4845bbf88e94586ec47a432da4e9107e3fc3ce37eb412b1398630a37f7dd/pynacl-1.6.2-cp38-abi3-macosx_10_10_universal2.whl", 
hash = "sha256:c949ea47e4206af7c8f604b8278093b674f7c79ed0d4719cc836902bf4517465", size = 388458, upload-time = "2026-01-01T17:32:16.829Z" }, + { url = "https://files.pythonhosted.org/packages/1e/b4/e927e0653ba63b02a4ca5b4d852a8d1d678afbf69b3dbf9c4d0785ac905c/pynacl-1.6.2-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8845c0631c0be43abdd865511c41eab235e0be69c81dc66a50911594198679b0", size = 800020, upload-time = "2026-01-01T17:32:18.34Z" }, + { url = "https://files.pythonhosted.org/packages/7f/81/d60984052df5c97b1d24365bc1e30024379b42c4edcd79d2436b1b9806f2/pynacl-1.6.2-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:22de65bb9010a725b0dac248f353bb072969c94fa8d6b1f34b87d7953cf7bbe4", size = 1399174, upload-time = "2026-01-01T17:32:20.239Z" }, + { url = "https://files.pythonhosted.org/packages/68/f7/322f2f9915c4ef27d140101dd0ed26b479f7e6f5f183590fd32dfc48c4d3/pynacl-1.6.2-cp38-abi3-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:46065496ab748469cdd999246d17e301b2c24ae2fdf739132e580a0e94c94a87", size = 835085, upload-time = "2026-01-01T17:32:22.24Z" }, + { url = "https://files.pythonhosted.org/packages/3e/d0/f301f83ac8dbe53442c5a43f6a39016f94f754d7a9815a875b65e218a307/pynacl-1.6.2-cp38-abi3-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8a66d6fb6ae7661c58995f9c6435bda2b1e68b54b598a6a10247bfcdadac996c", size = 1437614, upload-time = "2026-01-01T17:32:23.766Z" }, + { url = "https://files.pythonhosted.org/packages/c4/58/fc6e649762b029315325ace1a8c6be66125e42f67416d3dbd47b69563d61/pynacl-1.6.2-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:26bfcd00dcf2cf160f122186af731ae30ab120c18e8375684ec2670dccd28130", size = 818251, upload-time = "2026-01-01T17:32:25.69Z" }, + { url = "https://files.pythonhosted.org/packages/c9/a8/b917096b1accc9acd878819a49d3d84875731a41eb665f6ebc826b1af99e/pynacl-1.6.2-cp38-abi3-manylinux_2_34_x86_64.whl", hash = 
"sha256:c8a231e36ec2cab018c4ad4358c386e36eede0319a0c41fed24f840b1dac59f6", size = 1402859, upload-time = "2026-01-01T17:32:27.215Z" }, + { url = "https://files.pythonhosted.org/packages/85/42/fe60b5f4473e12c72f977548e4028156f4d340b884c635ec6b063fe7e9a5/pynacl-1.6.2-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:68be3a09455743ff9505491220b64440ced8973fe930f270c8e07ccfa25b1f9e", size = 791926, upload-time = "2026-01-01T17:32:29.314Z" }, + { url = "https://files.pythonhosted.org/packages/fa/f9/e40e318c604259301cc091a2a63f237d9e7b424c4851cafaea4ea7c4834e/pynacl-1.6.2-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:8b097553b380236d51ed11356c953bf8ce36a29a3e596e934ecabe76c985a577", size = 1363101, upload-time = "2026-01-01T17:32:31.263Z" }, + { url = "https://files.pythonhosted.org/packages/48/47/e761c254f410c023a469284a9bc210933e18588ca87706ae93002c05114c/pynacl-1.6.2-cp38-abi3-win32.whl", hash = "sha256:5811c72b473b2f38f7e2a3dc4f8642e3a3e9b5e7317266e4ced1fba85cae41aa", size = 227421, upload-time = "2026-01-01T17:32:33.076Z" }, + { url = "https://files.pythonhosted.org/packages/41/ad/334600e8cacc7d86587fe5f565480fde569dfb487389c8e1be56ac21d8ac/pynacl-1.6.2-cp38-abi3-win_amd64.whl", hash = "sha256:62985f233210dee6548c223301b6c25440852e13d59a8b81490203c3227c5ba0", size = 239754, upload-time = "2026-01-01T17:32:34.557Z" }, + { url = "https://files.pythonhosted.org/packages/29/7d/5945b5af29534641820d3bd7b00962abbbdfee84ec7e19f0d5b3175f9a31/pynacl-1.6.2-cp38-abi3-win_arm64.whl", hash = "sha256:834a43af110f743a754448463e8fd61259cd4ab5bbedcf70f9dabad1d28a394c", size = 184801, upload-time = "2026-01-01T17:32:36.309Z" }, +] + +[[package]] +name = "pyyaml" +version = "6.0.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/05/8e/961c0007c59b8dd7729d542c61a4d537767a59645b82a0b521206e1e25c2/pyyaml-6.0.3.tar.gz", hash = "sha256:d76623373421df22fb4cf8817020cbb7ef15c725b9d5e45f17e189bfc384190f", size = 130960, 
upload-time = "2025-09-25T21:33:16.546Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6d/16/a95b6757765b7b031c9374925bb718d55e0a9ba8a1b6a12d25962ea44347/pyyaml-6.0.3-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:44edc647873928551a01e7a563d7452ccdebee747728c1080d881d68af7b997e", size = 185826, upload-time = "2025-09-25T21:31:58.655Z" }, + { url = "https://files.pythonhosted.org/packages/16/19/13de8e4377ed53079ee996e1ab0a9c33ec2faf808a4647b7b4c0d46dd239/pyyaml-6.0.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:652cb6edd41e718550aad172851962662ff2681490a8a711af6a4d288dd96824", size = 175577, upload-time = "2025-09-25T21:32:00.088Z" }, + { url = "https://files.pythonhosted.org/packages/0c/62/d2eb46264d4b157dae1275b573017abec435397aa59cbcdab6fc978a8af4/pyyaml-6.0.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:10892704fc220243f5305762e276552a0395f7beb4dbf9b14ec8fd43b57f126c", size = 775556, upload-time = "2025-09-25T21:32:01.31Z" }, + { url = "https://files.pythonhosted.org/packages/10/cb/16c3f2cf3266edd25aaa00d6c4350381c8b012ed6f5276675b9eba8d9ff4/pyyaml-6.0.3-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:850774a7879607d3a6f50d36d04f00ee69e7fc816450e5f7e58d7f17f1ae5c00", size = 882114, upload-time = "2025-09-25T21:32:03.376Z" }, + { url = "https://files.pythonhosted.org/packages/71/60/917329f640924b18ff085ab889a11c763e0b573da888e8404ff486657602/pyyaml-6.0.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b8bb0864c5a28024fac8a632c443c87c5aa6f215c0b126c449ae1a150412f31d", size = 806638, upload-time = "2025-09-25T21:32:04.553Z" }, + { url = "https://files.pythonhosted.org/packages/dd/6f/529b0f316a9fd167281a6c3826b5583e6192dba792dd55e3203d3f8e655a/pyyaml-6.0.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1d37d57ad971609cf3c53ba6a7e365e40660e3be0e5175fa9f2365a379d6095a", size = 
767463, upload-time = "2025-09-25T21:32:06.152Z" }, + { url = "https://files.pythonhosted.org/packages/f2/6a/b627b4e0c1dd03718543519ffb2f1deea4a1e6d42fbab8021936a4d22589/pyyaml-6.0.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:37503bfbfc9d2c40b344d06b2199cf0e96e97957ab1c1b546fd4f87e53e5d3e4", size = 794986, upload-time = "2025-09-25T21:32:07.367Z" }, + { url = "https://files.pythonhosted.org/packages/45/91/47a6e1c42d9ee337c4839208f30d9f09caa9f720ec7582917b264defc875/pyyaml-6.0.3-cp311-cp311-win32.whl", hash = "sha256:8098f252adfa6c80ab48096053f512f2321f0b998f98150cea9bd23d83e1467b", size = 142543, upload-time = "2025-09-25T21:32:08.95Z" }, + { url = "https://files.pythonhosted.org/packages/da/e3/ea007450a105ae919a72393cb06f122f288ef60bba2dc64b26e2646fa315/pyyaml-6.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:9f3bfb4965eb874431221a3ff3fdcddc7e74e3b07799e0e84ca4a0f867d449bf", size = 158763, upload-time = "2025-09-25T21:32:09.96Z" }, + { url = "https://files.pythonhosted.org/packages/d1/33/422b98d2195232ca1826284a76852ad5a86fe23e31b009c9886b2d0fb8b2/pyyaml-6.0.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7f047e29dcae44602496db43be01ad42fc6f1cc0d8cd6c83d342306c32270196", size = 182063, upload-time = "2025-09-25T21:32:11.445Z" }, + { url = "https://files.pythonhosted.org/packages/89/a0/6cf41a19a1f2f3feab0e9c0b74134aa2ce6849093d5517a0c550fe37a648/pyyaml-6.0.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:fc09d0aa354569bc501d4e787133afc08552722d3ab34836a80547331bb5d4a0", size = 173973, upload-time = "2025-09-25T21:32:12.492Z" }, + { url = "https://files.pythonhosted.org/packages/ed/23/7a778b6bd0b9a8039df8b1b1d80e2e2ad78aa04171592c8a5c43a56a6af4/pyyaml-6.0.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9149cad251584d5fb4981be1ecde53a1ca46c891a79788c0df828d2f166bda28", size = 775116, upload-time = "2025-09-25T21:32:13.652Z" }, + { url = 
"https://files.pythonhosted.org/packages/65/30/d7353c338e12baef4ecc1b09e877c1970bd3382789c159b4f89d6a70dc09/pyyaml-6.0.3-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:5fdec68f91a0c6739b380c83b951e2c72ac0197ace422360e6d5a959d8d97b2c", size = 844011, upload-time = "2025-09-25T21:32:15.21Z" }, + { url = "https://files.pythonhosted.org/packages/8b/9d/b3589d3877982d4f2329302ef98a8026e7f4443c765c46cfecc8858c6b4b/pyyaml-6.0.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ba1cc08a7ccde2d2ec775841541641e4548226580ab850948cbfda66a1befcdc", size = 807870, upload-time = "2025-09-25T21:32:16.431Z" }, + { url = "https://files.pythonhosted.org/packages/05/c0/b3be26a015601b822b97d9149ff8cb5ead58c66f981e04fedf4e762f4bd4/pyyaml-6.0.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8dc52c23056b9ddd46818a57b78404882310fb473d63f17b07d5c40421e47f8e", size = 761089, upload-time = "2025-09-25T21:32:17.56Z" }, + { url = "https://files.pythonhosted.org/packages/be/8e/98435a21d1d4b46590d5459a22d88128103f8da4c2d4cb8f14f2a96504e1/pyyaml-6.0.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:41715c910c881bc081f1e8872880d3c650acf13dfa8214bad49ed4cede7c34ea", size = 790181, upload-time = "2025-09-25T21:32:18.834Z" }, + { url = "https://files.pythonhosted.org/packages/74/93/7baea19427dcfbe1e5a372d81473250b379f04b1bd3c4c5ff825e2327202/pyyaml-6.0.3-cp312-cp312-win32.whl", hash = "sha256:96b533f0e99f6579b3d4d4995707cf36df9100d67e0c8303a0c55b27b5f99bc5", size = 137658, upload-time = "2025-09-25T21:32:20.209Z" }, + { url = "https://files.pythonhosted.org/packages/86/bf/899e81e4cce32febab4fb42bb97dcdf66bc135272882d1987881a4b519e9/pyyaml-6.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:5fcd34e47f6e0b794d17de1b4ff496c00986e1c83f7ab2fb8fcfe9616ff7477b", size = 154003, upload-time = "2025-09-25T21:32:21.167Z" }, + { url = 
"https://files.pythonhosted.org/packages/1a/08/67bd04656199bbb51dbed1439b7f27601dfb576fb864099c7ef0c3e55531/pyyaml-6.0.3-cp312-cp312-win_arm64.whl", hash = "sha256:64386e5e707d03a7e172c0701abfb7e10f0fb753ee1d773128192742712a98fd", size = 140344, upload-time = "2025-09-25T21:32:22.617Z" }, + { url = "https://files.pythonhosted.org/packages/d1/11/0fd08f8192109f7169db964b5707a2f1e8b745d4e239b784a5a1dd80d1db/pyyaml-6.0.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8da9669d359f02c0b91ccc01cac4a67f16afec0dac22c2ad09f46bee0697eba8", size = 181669, upload-time = "2025-09-25T21:32:23.673Z" }, + { url = "https://files.pythonhosted.org/packages/b1/16/95309993f1d3748cd644e02e38b75d50cbc0d9561d21f390a76242ce073f/pyyaml-6.0.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:2283a07e2c21a2aa78d9c4442724ec1eb15f5e42a723b99cb3d822d48f5f7ad1", size = 173252, upload-time = "2025-09-25T21:32:25.149Z" }, + { url = "https://files.pythonhosted.org/packages/50/31/b20f376d3f810b9b2371e72ef5adb33879b25edb7a6d072cb7ca0c486398/pyyaml-6.0.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ee2922902c45ae8ccada2c5b501ab86c36525b883eff4255313a253a3160861c", size = 767081, upload-time = "2025-09-25T21:32:26.575Z" }, + { url = "https://files.pythonhosted.org/packages/49/1e/a55ca81e949270d5d4432fbbd19dfea5321eda7c41a849d443dc92fd1ff7/pyyaml-6.0.3-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:a33284e20b78bd4a18c8c2282d549d10bc8408a2a7ff57653c0cf0b9be0afce5", size = 841159, upload-time = "2025-09-25T21:32:27.727Z" }, + { url = "https://files.pythonhosted.org/packages/74/27/e5b8f34d02d9995b80abcef563ea1f8b56d20134d8f4e5e81733b1feceb2/pyyaml-6.0.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0f29edc409a6392443abf94b9cf89ce99889a1dd5376d94316ae5145dfedd5d6", size = 801626, upload-time = "2025-09-25T21:32:28.878Z" }, + { url = 
"https://files.pythonhosted.org/packages/f9/11/ba845c23988798f40e52ba45f34849aa8a1f2d4af4b798588010792ebad6/pyyaml-6.0.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f7057c9a337546edc7973c0d3ba84ddcdf0daa14533c2065749c9075001090e6", size = 753613, upload-time = "2025-09-25T21:32:30.178Z" }, + { url = "https://files.pythonhosted.org/packages/3d/e0/7966e1a7bfc0a45bf0a7fb6b98ea03fc9b8d84fa7f2229e9659680b69ee3/pyyaml-6.0.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:eda16858a3cab07b80edaf74336ece1f986ba330fdb8ee0d6c0d68fe82bc96be", size = 794115, upload-time = "2025-09-25T21:32:31.353Z" }, + { url = "https://files.pythonhosted.org/packages/de/94/980b50a6531b3019e45ddeada0626d45fa85cbe22300844a7983285bed3b/pyyaml-6.0.3-cp313-cp313-win32.whl", hash = "sha256:d0eae10f8159e8fdad514efdc92d74fd8d682c933a6dd088030f3834bc8e6b26", size = 137427, upload-time = "2025-09-25T21:32:32.58Z" }, + { url = "https://files.pythonhosted.org/packages/97/c9/39d5b874e8b28845e4ec2202b5da735d0199dbe5b8fb85f91398814a9a46/pyyaml-6.0.3-cp313-cp313-win_amd64.whl", hash = "sha256:79005a0d97d5ddabfeeea4cf676af11e647e41d81c9a7722a193022accdb6b7c", size = 154090, upload-time = "2025-09-25T21:32:33.659Z" }, + { url = "https://files.pythonhosted.org/packages/73/e8/2bdf3ca2090f68bb3d75b44da7bbc71843b19c9f2b9cb9b0f4ab7a5a4329/pyyaml-6.0.3-cp313-cp313-win_arm64.whl", hash = "sha256:5498cd1645aa724a7c71c8f378eb29ebe23da2fc0d7a08071d89469bf1d2defb", size = 140246, upload-time = "2025-09-25T21:32:34.663Z" }, + { url = "https://files.pythonhosted.org/packages/9d/8c/f4bd7f6465179953d3ac9bc44ac1a8a3e6122cf8ada906b4f96c60172d43/pyyaml-6.0.3-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:8d1fab6bb153a416f9aeb4b8763bc0f22a5586065f86f7664fc23339fc1c1fac", size = 181814, upload-time = "2025-09-25T21:32:35.712Z" }, + { url = "https://files.pythonhosted.org/packages/bd/9c/4d95bb87eb2063d20db7b60faa3840c1b18025517ae857371c4dd55a6b3a/pyyaml-6.0.3-cp314-cp314-macosx_11_0_arm64.whl", hash = 
"sha256:34d5fcd24b8445fadc33f9cf348c1047101756fd760b4dacb5c3e99755703310", size = 173809, upload-time = "2025-09-25T21:32:36.789Z" }, + { url = "https://files.pythonhosted.org/packages/92/b5/47e807c2623074914e29dabd16cbbdd4bf5e9b2db9f8090fa64411fc5382/pyyaml-6.0.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:501a031947e3a9025ed4405a168e6ef5ae3126c59f90ce0cd6f2bfc477be31b7", size = 766454, upload-time = "2025-09-25T21:32:37.966Z" }, + { url = "https://files.pythonhosted.org/packages/02/9e/e5e9b168be58564121efb3de6859c452fccde0ab093d8438905899a3a483/pyyaml-6.0.3-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:b3bc83488de33889877a0f2543ade9f70c67d66d9ebb4ac959502e12de895788", size = 836355, upload-time = "2025-09-25T21:32:39.178Z" }, + { url = "https://files.pythonhosted.org/packages/88/f9/16491d7ed2a919954993e48aa941b200f38040928474c9e85ea9e64222c3/pyyaml-6.0.3-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c458b6d084f9b935061bc36216e8a69a7e293a2f1e68bf956dcd9e6cbcd143f5", size = 794175, upload-time = "2025-09-25T21:32:40.865Z" }, + { url = "https://files.pythonhosted.org/packages/dd/3f/5989debef34dc6397317802b527dbbafb2b4760878a53d4166579111411e/pyyaml-6.0.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:7c6610def4f163542a622a73fb39f534f8c101d690126992300bf3207eab9764", size = 755228, upload-time = "2025-09-25T21:32:42.084Z" }, + { url = "https://files.pythonhosted.org/packages/d7/ce/af88a49043cd2e265be63d083fc75b27b6ed062f5f9fd6cdc223ad62f03e/pyyaml-6.0.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:5190d403f121660ce8d1d2c1bb2ef1bd05b5f68533fc5c2ea899bd15f4399b35", size = 789194, upload-time = "2025-09-25T21:32:43.362Z" }, + { url = "https://files.pythonhosted.org/packages/23/20/bb6982b26a40bb43951265ba29d4c246ef0ff59c9fdcdf0ed04e0687de4d/pyyaml-6.0.3-cp314-cp314-win_amd64.whl", hash = 
"sha256:4a2e8cebe2ff6ab7d1050ecd59c25d4c8bd7e6f400f5f82b96557ac0abafd0ac", size = 156429, upload-time = "2025-09-25T21:32:57.844Z" }, + { url = "https://files.pythonhosted.org/packages/f4/f4/a4541072bb9422c8a883ab55255f918fa378ecf083f5b85e87fc2b4eda1b/pyyaml-6.0.3-cp314-cp314-win_arm64.whl", hash = "sha256:93dda82c9c22deb0a405ea4dc5f2d0cda384168e466364dec6255b293923b2f3", size = 143912, upload-time = "2025-09-25T21:32:59.247Z" }, + { url = "https://files.pythonhosted.org/packages/7c/f9/07dd09ae774e4616edf6cda684ee78f97777bdd15847253637a6f052a62f/pyyaml-6.0.3-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:02893d100e99e03eda1c8fd5c441d8c60103fd175728e23e431db1b589cf5ab3", size = 189108, upload-time = "2025-09-25T21:32:44.377Z" }, + { url = "https://files.pythonhosted.org/packages/4e/78/8d08c9fb7ce09ad8c38ad533c1191cf27f7ae1effe5bb9400a46d9437fcf/pyyaml-6.0.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:c1ff362665ae507275af2853520967820d9124984e0f7466736aea23d8611fba", size = 183641, upload-time = "2025-09-25T21:32:45.407Z" }, + { url = "https://files.pythonhosted.org/packages/7b/5b/3babb19104a46945cf816d047db2788bcaf8c94527a805610b0289a01c6b/pyyaml-6.0.3-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6adc77889b628398debc7b65c073bcb99c4a0237b248cacaf3fe8a557563ef6c", size = 831901, upload-time = "2025-09-25T21:32:48.83Z" }, + { url = "https://files.pythonhosted.org/packages/8b/cc/dff0684d8dc44da4d22a13f35f073d558c268780ce3c6ba1b87055bb0b87/pyyaml-6.0.3-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:a80cb027f6b349846a3bf6d73b5e95e782175e52f22108cfa17876aaeff93702", size = 861132, upload-time = "2025-09-25T21:32:50.149Z" }, + { url = "https://files.pythonhosted.org/packages/b1/5e/f77dc6b9036943e285ba76b49e118d9ea929885becb0a29ba8a7c75e29fe/pyyaml-6.0.3-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = 
"sha256:00c4bdeba853cc34e7dd471f16b4114f4162dc03e6b7afcc2128711f0eca823c", size = 839261, upload-time = "2025-09-25T21:32:51.808Z" }, + { url = "https://files.pythonhosted.org/packages/ce/88/a9db1376aa2a228197c58b37302f284b5617f56a5d959fd1763fb1675ce6/pyyaml-6.0.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:66e1674c3ef6f541c35191caae2d429b967b99e02040f5ba928632d9a7f0f065", size = 805272, upload-time = "2025-09-25T21:32:52.941Z" }, + { url = "https://files.pythonhosted.org/packages/da/92/1446574745d74df0c92e6aa4a7b0b3130706a4142b2d1a5869f2eaa423c6/pyyaml-6.0.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:16249ee61e95f858e83976573de0f5b2893b3677ba71c9dd36b9cf8be9ac6d65", size = 829923, upload-time = "2025-09-25T21:32:54.537Z" }, + { url = "https://files.pythonhosted.org/packages/f0/7a/1c7270340330e575b92f397352af856a8c06f230aa3e76f86b39d01b416a/pyyaml-6.0.3-cp314-cp314t-win_amd64.whl", hash = "sha256:4ad1906908f2f5ae4e5a8ddfce73c320c2a1429ec52eafd27138b7f1cbe341c9", size = 174062, upload-time = "2025-09-25T21:32:55.767Z" }, + { url = "https://files.pythonhosted.org/packages/f1/12/de94a39c2ef588c7e6455cfbe7343d3b2dc9d6b6b2f40c4c6565744c873d/pyyaml-6.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:ebc55a14a21cb14062aa4162f906cd962b28e2e9ea38f9b4391244cd8de4ae0b", size = 149341, upload-time = "2025-09-25T21:32:56.828Z" }, +] + +[[package]] +name = "requests" +version = "2.32.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "charset-normalizer" }, + { name = "idna" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c9/74/b3ff8e6c8446842c3f5c837e9c3dfcfe2018ea6ecef224c710c85ef728f4/requests-2.32.5.tar.gz", hash = "sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf", size = 134517, upload-time = "2025-08-18T20:46:02.573Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" }, +] + +[[package]] +name = "rich" +version = "14.3.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markdown-it-py" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/74/99/a4cab2acbb884f80e558b0771e97e21e939c5dfb460f488d19df485e8298/rich-14.3.2.tar.gz", hash = "sha256:e712f11c1a562a11843306f5ed999475f09ac31ffb64281f73ab29ffdda8b3b8", size = 230143, upload-time = "2026-02-01T16:20:47.908Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ef/45/615f5babd880b4bd7d405cc0dc348234c5ffb6ed1ea33e152ede08b2072d/rich-14.3.2-py3-none-any.whl", hash = "sha256:08e67c3e90884651da3239ea668222d19bea7b589149d8014a21c633420dbb69", size = 309963, upload-time = "2026-02-01T16:20:46.078Z" }, +] + +[[package]] +name = "roman-numerals" +version = "4.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ae/f9/41dc953bbeb056c17d5f7a519f50fdf010bd0553be2d630bc69d1e022703/roman_numerals-4.1.0.tar.gz", hash = "sha256:1af8b147eb1405d5839e78aeb93131690495fe9da5c91856cb33ad55a7f1e5b2", size = 9077, upload-time = "2025-12-17T18:25:34.381Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/54/6f679c435d28e0a568d8e8a7c0a93a09010818634c3c3907fc98d8983770/roman_numerals-4.1.0-py3-none-any.whl", hash = "sha256:647ba99caddc2cc1e55a51e4360689115551bf4476d90e8162cf8c345fe233c7", size = 7676, upload-time = "2025-12-17T18:25:33.098Z" }, +] + +[[package]] +name = "setuptools" +version = "82.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/82/f3/748f4d6f65d1756b9ae577f329c951cda23fb900e4de9f70900ced962085/setuptools-82.0.0.tar.gz", hash = "sha256:22e0a2d69474c6ae4feb01951cb69d515ed23728cf96d05513d36e42b62b37cb", size = 1144893, upload-time = "2026-02-08T15:08:40.206Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e1/c6/76dc613121b793286a3f91621d7b75a2b493e0390ddca50f11993eadf192/setuptools-82.0.0-py3-none-any.whl", hash = "sha256:70b18734b607bd1da571d097d236cfcfacaf01de45717d59e6e04b96877532e0", size = 1003468, upload-time = "2026-02-08T15:08:38.723Z" }, +] + +[[package]] +name = "shellingham" +version = "1.5.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/58/15/8b3609fd3830ef7b27b655beb4b4e9c62313a4e8da8c676e142cc210d58e/shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de", size = 10310, upload-time = "2023-10-24T04:13:40.426Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686", size = 9755, upload-time = "2023-10-24T04:13:38.866Z" }, +] + +[[package]] +name = "snowballstemmer" +version = "3.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/75/a7/9810d872919697c9d01295633f5d574fb416d47e535f258272ca1f01f447/snowballstemmer-3.0.1.tar.gz", hash = "sha256:6d5eeeec8e9f84d4d56b847692bacf79bc2c8e90c7f80ca4444ff8b6f2e52895", size = 105575, upload-time = "2025-05-09T16:34:51.843Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c8/78/3565d011c61f5a43488987ee32b6f3f656e7f107ac2782dd57bdd7d91d9a/snowballstemmer-3.0.1-py3-none-any.whl", hash = "sha256:6cd7b3897da8d6c9ffb968a6781fa6532dce9c3618a4b127d920dab764a19064", size = 103274, upload-time = 
"2025-05-09T16:34:50.371Z" }, +] + +[[package]] +name = "soupsieve" +version = "2.8.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/7b/ae/2d9c981590ed9999a0d91755b47fc74f74de286b0f5cee14c9269041e6c4/soupsieve-2.8.3.tar.gz", hash = "sha256:3267f1eeea4251fb42728b6dfb746edc9acaffc4a45b27e19450b676586e8349", size = 118627, upload-time = "2026-01-20T04:27:02.457Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/46/2c/1462b1d0a634697ae9e55b3cecdcb64788e8b7d63f54d923fcd0bb140aed/soupsieve-2.8.3-py3-none-any.whl", hash = "sha256:ed64f2ba4eebeab06cc4962affce381647455978ffc1e36bb79a545b91f45a95", size = 37016, upload-time = "2026-01-20T04:27:01.012Z" }, +] + +[[package]] +name = "sphinx" +version = "9.0.4" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.12'", +] +dependencies = [ + { name = "alabaster", marker = "python_full_version < '3.12'" }, + { name = "babel", marker = "python_full_version < '3.12'" }, + { name = "colorama", marker = "python_full_version < '3.12' and sys_platform == 'win32'" }, + { name = "docutils", marker = "python_full_version < '3.12'" }, + { name = "imagesize", marker = "python_full_version < '3.12'" }, + { name = "jinja2", marker = "python_full_version < '3.12'" }, + { name = "packaging", marker = "python_full_version < '3.12'" }, + { name = "pygments", marker = "python_full_version < '3.12'" }, + { name = "requests", marker = "python_full_version < '3.12'" }, + { name = "roman-numerals", marker = "python_full_version < '3.12'" }, + { name = "snowballstemmer", marker = "python_full_version < '3.12'" }, + { name = "sphinxcontrib-applehelp", marker = "python_full_version < '3.12'" }, + { name = "sphinxcontrib-devhelp", marker = "python_full_version < '3.12'" }, + { name = "sphinxcontrib-htmlhelp", marker = "python_full_version < '3.12'" }, + { name = "sphinxcontrib-jsmath", marker = "python_full_version < '3.12'" }, 
+ { name = "sphinxcontrib-qthelp", marker = "python_full_version < '3.12'" }, + { name = "sphinxcontrib-serializinghtml", marker = "python_full_version < '3.12'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/42/50/a8c6ccc36d5eacdfd7913ddccd15a9cee03ecafc5ee2bc40e1f168d85022/sphinx-9.0.4.tar.gz", hash = "sha256:594ef59d042972abbc581d8baa577404abe4e6c3b04ef61bd7fc2acbd51f3fa3", size = 8710502, upload-time = "2025-12-04T07:45:27.343Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c6/3f/4bbd76424c393caead2e1eb89777f575dee5c8653e2d4b6afd7a564f5974/sphinx-9.0.4-py3-none-any.whl", hash = "sha256:5bebc595a5e943ea248b99c13814c1c5e10b3ece718976824ffa7959ff95fffb", size = 3917713, upload-time = "2025-12-04T07:45:24.944Z" }, +] + +[[package]] +name = "sphinx" +version = "9.1.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.12'", +] +dependencies = [ + { name = "alabaster", marker = "python_full_version >= '3.12'" }, + { name = "babel", marker = "python_full_version >= '3.12'" }, + { name = "colorama", marker = "python_full_version >= '3.12' and sys_platform == 'win32'" }, + { name = "docutils", marker = "python_full_version >= '3.12'" }, + { name = "imagesize", marker = "python_full_version >= '3.12'" }, + { name = "jinja2", marker = "python_full_version >= '3.12'" }, + { name = "packaging", marker = "python_full_version >= '3.12'" }, + { name = "pygments", marker = "python_full_version >= '3.12'" }, + { name = "requests", marker = "python_full_version >= '3.12'" }, + { name = "roman-numerals", marker = "python_full_version >= '3.12'" }, + { name = "snowballstemmer", marker = "python_full_version >= '3.12'" }, + { name = "sphinxcontrib-applehelp", marker = "python_full_version >= '3.12'" }, + { name = "sphinxcontrib-devhelp", marker = "python_full_version >= '3.12'" }, + { name = "sphinxcontrib-htmlhelp", marker = "python_full_version >= '3.12'" }, + { name = "sphinxcontrib-jsmath", 
marker = "python_full_version >= '3.12'" }, + { name = "sphinxcontrib-qthelp", marker = "python_full_version >= '3.12'" }, + { name = "sphinxcontrib-serializinghtml", marker = "python_full_version >= '3.12'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/cd/bd/f08eb0f4eed5c83f1ba2a3bd18f7745a2b1525fad70660a1c00224ec468a/sphinx-9.1.0.tar.gz", hash = "sha256:7741722357dd75f8190766926071fed3bdc211c74dd2d7d4df5404da95930ddb", size = 8718324, upload-time = "2025-12-31T15:09:27.646Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/73/f7/b1884cb3188ab181fc81fa00c266699dab600f927a964df02ec3d5d1916a/sphinx-9.1.0-py3-none-any.whl", hash = "sha256:c84fdd4e782504495fe4f2c0b3413d6c2bf388589bb352d439b2a3bb99991978", size = 3921742, upload-time = "2025-12-31T15:09:25.561Z" }, +] + +[[package]] +name = "sphinx-reredirects" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "sphinx", version = "9.0.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.12'" }, + { name = "sphinx", version = "9.1.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/1b/8d/0e39fe2740d7d71417edf9a6424aa80ca2c27c17fc21282cdc39f90d5a40/sphinx_reredirects-1.1.0.tar.gz", hash = "sha256:fb9b195335ab14b43f8273287d0c7eeb637ba6c56c66581c11b47202f6718b29", size = 614624, upload-time = "2025-12-22T08:28:02.792Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/51/81/b5dd07067f3daac6d23687ec737b2d593740671ebcd145830c8f92d381c5/sphinx_reredirects-1.1.0-py3-none-any.whl", hash = "sha256:4b5692273c72cd2d4d917f4c6f87d5919e4d6114a752d4be033f7f5f6310efd9", size = 6351, upload-time = "2025-12-22T08:27:59.724Z" }, +] + +[[package]] +name = "sphinxcontrib-applehelp" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/ba/6e/b837e84a1a704953c62ef8776d45c3e8d759876b4a84fe14eba2859106fe/sphinxcontrib_applehelp-2.0.0.tar.gz", hash = "sha256:2f29ef331735ce958efa4734873f084941970894c6090408b079c61b2e1c06d1", size = 20053, upload-time = "2024-07-29T01:09:00.465Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5d/85/9ebeae2f76e9e77b952f4b274c27238156eae7979c5421fba91a28f4970d/sphinxcontrib_applehelp-2.0.0-py3-none-any.whl", hash = "sha256:4cd3f0ec4ac5dd9c17ec65e9ab272c9b867ea77425228e68ecf08d6b28ddbdb5", size = 119300, upload-time = "2024-07-29T01:08:58.99Z" }, +] + +[[package]] +name = "sphinxcontrib-devhelp" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f6/d2/5beee64d3e4e747f316bae86b55943f51e82bb86ecd325883ef65741e7da/sphinxcontrib_devhelp-2.0.0.tar.gz", hash = "sha256:411f5d96d445d1d73bb5d52133377b4248ec79db5c793ce7dbe59e074b4dd1ad", size = 12967, upload-time = "2024-07-29T01:09:23.417Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/35/7a/987e583882f985fe4d7323774889ec58049171828b58c2217e7f79cdf44e/sphinxcontrib_devhelp-2.0.0-py3-none-any.whl", hash = "sha256:aefb8b83854e4b0998877524d1029fd3e6879210422ee3780459e28a1f03a8a2", size = 82530, upload-time = "2024-07-29T01:09:21.945Z" }, +] + +[[package]] +name = "sphinxcontrib-htmlhelp" +version = "2.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/43/93/983afd9aa001e5201eab16b5a444ed5b9b0a7a010541e0ddfbbfd0b2470c/sphinxcontrib_htmlhelp-2.1.0.tar.gz", hash = "sha256:c9e2916ace8aad64cc13a0d233ee22317f2b9025b9cf3295249fa985cc7082e9", size = 22617, upload-time = "2024-07-29T01:09:37.889Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0a/7b/18a8c0bcec9182c05a0b3ec2a776bba4ead82750a55ff798e8d406dae604/sphinxcontrib_htmlhelp-2.1.0-py3-none-any.whl", hash = 
"sha256:166759820b47002d22914d64a075ce08f4c46818e17cfc9470a9786b759b19f8", size = 98705, upload-time = "2024-07-29T01:09:36.407Z" }, +] + +[[package]] +name = "sphinxcontrib-jsmath" +version = "1.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b2/e8/9ed3830aeed71f17c026a07a5097edcf44b692850ef215b161b8ad875729/sphinxcontrib-jsmath-1.0.1.tar.gz", hash = "sha256:a9925e4a4587247ed2191a22df5f6970656cb8ca2bd6284309578f2153e0c4b8", size = 5787, upload-time = "2019-01-21T16:10:16.347Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c2/42/4c8646762ee83602e3fb3fbe774c2fac12f317deb0b5dbeeedd2d3ba4b77/sphinxcontrib_jsmath-1.0.1-py2.py3-none-any.whl", hash = "sha256:2ec2eaebfb78f3f2078e73666b1415417a116cc848b72e5172e596c871103178", size = 5071, upload-time = "2019-01-21T16:10:14.333Z" }, +] + +[[package]] +name = "sphinxcontrib-qthelp" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/68/bc/9104308fc285eb3e0b31b67688235db556cd5b0ef31d96f30e45f2e51cae/sphinxcontrib_qthelp-2.0.0.tar.gz", hash = "sha256:4fe7d0ac8fc171045be623aba3e2a8f613f8682731f9153bb2e40ece16b9bbab", size = 17165, upload-time = "2024-07-29T01:09:56.435Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/27/83/859ecdd180cacc13b1f7e857abf8582a64552ea7a061057a6c716e790fce/sphinxcontrib_qthelp-2.0.0-py3-none-any.whl", hash = "sha256:b18a828cdba941ccd6ee8445dbe72ffa3ef8cbe7505d8cd1fa0d42d3f2d5f3eb", size = 88743, upload-time = "2024-07-29T01:09:54.885Z" }, +] + +[[package]] +name = "sphinxcontrib-serializinghtml" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/3b/44/6716b257b0aa6bfd51a1b31665d1c205fb12cb5ad56de752dfa15657de2f/sphinxcontrib_serializinghtml-2.0.0.tar.gz", hash = "sha256:e9d912827f872c029017a53f0ef2180b327c3f7fd23c87229f7a8e8b70031d4d", size = 16080, 
upload-time = "2024-07-29T01:10:09.332Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/52/a7/d2782e4e3f77c8450f727ba74a8f12756d5ba823d81b941f1b04da9d033a/sphinxcontrib_serializinghtml-2.0.0-py3-none-any.whl", hash = "sha256:6e2cb0eef194e10c27ec0023bfeb25badbbb5868244cf5bc5bdc04e4464bf331", size = 92072, upload-time = "2024-07-29T01:10:08.203Z" }, +] + +[[package]] +name = "tomlkit" +version = "0.14.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c3/af/14b24e41977adb296d6bd1fb59402cf7d60ce364f90c890bd2ec65c43b5a/tomlkit-0.14.0.tar.gz", hash = "sha256:cf00efca415dbd57575befb1f6634c4f42d2d87dbba376128adb42c121b87064", size = 187167, upload-time = "2026-01-13T01:14:53.304Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b5/11/87d6d29fb5d237229d67973a6c9e06e048f01cf4994dee194ab0ea841814/tomlkit-0.14.0-py3-none-any.whl", hash = "sha256:592064ed85b40fa213469f81ac584f67a4f2992509a7c3ea2d632208623a3680", size = 39310, upload-time = "2026-01-13T01:14:51.965Z" }, +] + +[[package]] +name = "typer" +version = "0.23.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "annotated-doc" }, + { name = "click" }, + { name = "rich" }, + { name = "shellingham" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d3/ae/93d16574e66dfe4c2284ffdaca4b0320ade32858cb2cc586c8dd79f127c5/typer-0.23.2.tar.gz", hash = "sha256:a99706a08e54f1aef8bb6a8611503808188a4092808e86addff1828a208af0de", size = 120162, upload-time = "2026-02-16T18:52:40.354Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/14/2c/dee705c427875402200fe779eb8a3c00ccb349471172c41178336e9599cc/typer-0.23.2-py3-none-any.whl", hash = "sha256:e9c8dc380f82450b3c851a9b9d5a0edf95d1d6456ae70c517d8b06a50c7a9978", size = 56834, upload-time = "2026-02-16T18:52:39.308Z" }, +] + +[[package]] +name = "typing-extensions" +version = "4.15.0" +source = { registry = "https://pypi.org/simple" 
} +sdist = { url = "https://files.pythonhosted.org/packages/72/94/1a15dd82efb362ac84269196e94cf00f187f7ed21c242792a923cdb1c61f/typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466", size = 109391, upload-time = "2025-08-25T13:49:26.313Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614, upload-time = "2025-08-25T13:49:24.86Z" }, +] + +[[package]] +name = "urllib3" +version = "2.6.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c7/24/5f1b3bdffd70275f6661c76461e25f024d5a38a46f04aaca912426a2b1d3/urllib3-2.6.3.tar.gz", hash = "sha256:1b62b6884944a57dbe321509ab94fd4d3b307075e0c2eae991ac71ee15ad38ed", size = 435556, upload-time = "2026-01-07T16:24:43.925Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/39/08/aaaad47bc4e9dc8c725e68f9d04865dbcb2052843ff09c97b08904852d84/urllib3-2.6.3-py3-none-any.whl", hash = "sha256:bf272323e553dfb2e87d9bfd225ca7b0f467b919d7bbd355436d3fd37cb0acd4", size = 131584, upload-time = "2026-01-07T16:24:42.685Z" }, +]