Skip to content

Commit

Permalink
[Rust] Handle double equals in filter (lancedb#639)
Browse files Browse the repository at this point in the history

---------

Co-authored-by: Chang She <759245+changhiskhan@users.noreply.github.com>
  • Loading branch information
eddyxu and changhiskhan committed Feb 25, 2023
1 parent 175bc47 commit cdbba23
Show file tree
Hide file tree
Showing 8 changed files with 226 additions and 26 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run_tests/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ runs:
working-directory: python
shell: bash
run: |
pip3 install $(ls target/wheels/pylance-*.whl) pytest
pip3 install $(ls target/wheels/pylance-*.whl) pytest duckdb
- name: Run python tests
shell: bash
working-directory: python
Expand Down
1 change: 1 addition & 0 deletions python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,5 @@ build-backend = "maturin"
[project.optional-dependencies]
tests = [
"pytest",
"duckdb"
]
26 changes: 19 additions & 7 deletions python/python/lance/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,10 @@ def scanner(
List of column names to be fetched.
All columns if None or unspecified.
filter : pa.compute.Expression or str
Not enabled just yet. Soon...
Expression or str that is a valid SQL where clause.
Currently only >, <, >=, <=, ==, !=, |, & are supported.
is_null, is_valid, ~, and others are not yet supported.
Specifying these will result in an expression parsing error
limit: int, default 0
Fetch up to this many rows. All rows if 0 or unspecified.
offset: int, default None
Expand All @@ -75,6 +78,12 @@ def scanner(
"nprobes": 1,
"refine_factor": 1
}
Notes
-----
For now, if BOTH filter and nearest is specified, then:
1. nearest is executed first.
2. The results are filtered afterwards.
"""
return (
ScannerBuilder(self)
Expand Down Expand Up @@ -109,7 +118,10 @@ def to_table(
List of column names to be fetched.
All columns if None or unspecified.
filter : pa.compute.Expression or str
Scan will return only the rows matching the filter.
Expression or str that is a valid SQL where clause.
Currently only >, <, >=, <=, ==, !=, |, & are supported.
is_null, is_valid, ~, and others are not yet supported.
Specifying these will result in an expression parsing error
limit: int, default 0
Fetch up to this many rows. All rows if 0 or unspecified.
offset: int, default None
Expand All @@ -125,7 +137,11 @@ def to_table(
"refine_factor": 1
}
See `scanner()` for more details.
Notes
-----
For now, if BOTH filter and nearest is specified, then:
1. nearest is executed first.
2. The results are filtered afterwards.
"""
return self.scanner(
columns=columns, filter=filter, limit=limit, offset=offset, nearest=nearest
Expand Down Expand Up @@ -384,10 +400,6 @@ def columns(self, cols: Optional[list[str]] = None) -> ScannerBuilder:
return self

def filter(self, filter: Union[str, pa.compute.Expression]) -> ScannerBuilder:
if filter is not None:
raise NotImplementedError(
"Allllmost ready. For now, please do `to_table().filter(...)`"
)
if isinstance(filter, pa.compute.Expression):
filter = str(filter)
self._filter = filter
Expand Down
137 changes: 137 additions & 0 deletions python/python/tests/test_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright 2023 Lance Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2023 Lance Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for predicate pushdown"""

import random
import string

import lance
import numpy as np
import pandas as pd
import pandas.testing as tm
import pyarrow as pa
import pyarrow.compute as pc
import pytest

from lance.vector import vec_to_table


def create_table(nrows=100):
intcol = pa.array(range(nrows))
floatcol = pa.array(np.arange(nrows) * 2/3, type=pa.float32())
arr = np.arange(nrows) < nrows / 2
structcol = pa.StructArray.from_arrays([pa.array(arr, type=pa.bool_())], names=["bool"])

def gen_str(n):
return ''.join(random.choices("abc"))

stringcol = pa.array([gen_str(2) for _ in range(nrows)])

tbl = pa.Table.from_arrays([
intcol, floatcol, structcol, stringcol
], names=["int", "float", "rec", "str"])
return tbl


@pytest.fixture()
def dataset(tmp_path):
tbl = create_table()
yield lance.write_dataset(tbl, tmp_path)


def test_simple_predicates(dataset):
predicates = [
pc.field("int") >= 50,
pc.field("int") == 50,
pc.field("int") != 50,
pc.field("float") < 90.0,
pc.field("float") > 90.0,
pc.field("float") <= 90.0,
pc.field("float") >= 90.0,
pc.field("str") != "aa",
pc.field("str") == "aa",
]
# test simple
for expr in predicates:
assert dataset.to_table(filter=expr) == dataset.to_table().filter(expr)


def test_compound(dataset):
predicates = [
pc.field("int") >= 50,
pc.field("float") < 90.0,
pc.field("str") == "aa",
]
# test compound
for expr in predicates:
for other_expr in predicates:
compound = expr & other_expr
assert dataset.to_table(filter=compound) == dataset.to_table().filter(compound)
compound = expr | other_expr
assert dataset.to_table(filter=compound) == dataset.to_table().filter(compound)


def create_table_for_duckdb(nvec=10000, ndim=768):
mat = np.random.randn(nvec, ndim)
price = (np.random.rand(nvec) + 1) * 100

def gen_str(n):
return "".join(random.choices("abc"))

meta = np.array([gen_str(1) for _ in range(nvec)])
tbl = (
vec_to_table(data=mat)
.append_column("price", pa.array(price))
.append_column("meta", pa.array(meta))
.append_column("id", pa.array(range(nvec)))
)
return tbl


def test_duckdb(tmp_path):
duckdb = pytest.importorskip("duckdb")
tbl = create_table_for_duckdb()
ds = lance.write_dataset(tbl, str(tmp_path))

actual = duckdb.query("SELECT id, meta, price FROM ds WHERE id==1000").to_df()
expected = duckdb.query("SELECT id, meta, price FROM ds").to_df()
expected = expected[expected.id == 1000].reset_index(drop=True)
tm.assert_frame_equal(actual, expected)

actual = duckdb.query("SELECT id, meta, price FROM ds WHERE id=1000").to_df()
expected = duckdb.query("SELECT id, meta, price FROM ds").to_df()
expected = expected[expected.id == 1000].reset_index(drop=True)
tm.assert_frame_equal(actual, expected)

actual = duckdb.query("SELECT id, meta, price FROM ds WHERE price>20.0 and price<=90").to_df()
expected = duckdb.query("SELECT id, meta, price FROM ds").to_df()
expected = expected[(expected.price > 20.0) & (expected.price <= 90)].reset_index(drop=True)
tm.assert_frame_equal(actual, expected)

actual = duckdb.query("SELECT id, meta, price FROM ds WHERE meta=='aa'").to_df()
expected = duckdb.query("SELECT id, meta, price FROM ds").to_df()
expected = expected[expected.meta == "aa"].reset_index(drop=True)
tm.assert_frame_equal(actual, expected)
2 changes: 1 addition & 1 deletion rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ uuid = { version = "1.2", features = ["v4"] }
path-absolutize = "3.0.14"
arrow = { version = "32.0.0", features = ["prettyprint"] }
num_cpus = "1.0"
sqlparser = "0.30.0"
sqlparser = { git = "https://github.com/eto-ai/sqlparser-rs.git", branch = "lei/double_eq" }
# TODO: use datafusion sub-modules to reduce build size?
datafusion = { version = "18.0.0", default-features = false }
faiss = { version = "0.11.0", features = ["gpu"], optional = true }
Expand Down
2 changes: 1 addition & 1 deletion rust/src/encodings/plain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ impl<'a> PlainDecoder<'a> {
// TODO: optimize boolean access
let start = indices.value(0) as usize;
let end = indices.value(indices.len() - 1) as usize;
let array = self.get(start..end).await?;
let array = self.get(start..end + 1).await?;
let array_byte_boundray = (start / 8 * 8) as u32;
let shifted_indices = subtract_scalar(indices, array_byte_boundray)?;
Ok(take(array.as_ref(), &shifted_indices, None)?)
Expand Down
74 changes: 60 additions & 14 deletions rust/src/io/exec/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@ use datafusion::{
scalar::ScalarValue,
};
use sqlparser::{
ast::{BinaryOperator, Expr as SQLExpr, Ident, SetExpr, Statement, Value},
ast::{
BinaryOperator, Expr as SQLExpr, Function, FunctionArg, FunctionArgExpr, Ident, SetExpr,
Statement, Value,
},
dialect::GenericDialect,
parser::Parser,
};
Expand Down Expand Up @@ -100,17 +103,53 @@ impl Planner {
Value::EscapedStringLiteral(_) => todo!(),
Value::NationalStringLiteral(_) => todo!(),
Value::HexStringLiteral(_) => todo!(),
Value::DoubleQuotedString(_) => todo!(),
Value::DoubleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone()))),
Value::Boolean(v) => Expr::Literal(ScalarValue::Boolean(Some(*v))),
Value::Null => Expr::Literal(ScalarValue::Null),
Value::Placeholder(_) => todo!(),
Value::UnQuotedString(_) => todo!(),
Value::SingleQuotedByteStringLiteral(_) => todo!(),
Value::DoubleQuotedByteStringLiteral(_) => todo!(),
})
}

fn parse_function_args(&self, func_args: &FunctionArg) -> Result<Expr> {
match func_args {
FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => self.parse_sql_expr(expr),
_ => Err(Error::IO(format!(
"Unsuppoted function args: {:?}",
func_args
))),
}
}

fn parse_function(&self, func: &Function) -> Result<Expr> {
if func.name.to_string() == "is_valid" {
if func.args.len() != 1 {
return Err(Error::IO(format!(
"is_valid only support 1 args, got {}",
func.args.len()
)));
}
return Ok(Expr::IsNotNull(Box::new(
self.parse_function_args(&func.args[0])?,
)));
}
Err(Error::IO(format!(
"function '{}' is not supported",
func.name
)))
}

fn parse_sql_expr(&self, expr: &SQLExpr) -> Result<Expr> {
match expr {
SQLExpr::Identifier(id) => self.column(vec![id.clone()].as_slice()),
SQLExpr::Identifier(id) => {
if id.quote_style == Some('"') {
Ok(Expr::Literal(ScalarValue::Utf8(Some(id.value.clone()))))
} else {
self.column(vec![id.clone()].as_slice())
}
}
SQLExpr::CompoundIdentifier(ids) => self.column(ids.as_slice()),
SQLExpr::BinaryOp { left, op, right } => self.binary_expr(left, op, right),
SQLExpr::Value(value) => self.value(value),
Expand All @@ -133,6 +172,7 @@ impl Planner {
Ok(value_expr.in_list(list_exprs, *negated))
}
SQLExpr::Nested(inner) => self.parse_sql_expr(inner.as_ref()),
SQLExpr::Function(func) => self.parse_function(func),
_ => {
return Err(Error::IO(format!(
"Expression '{expr}' is not supported as filter in lance"
Expand Down Expand Up @@ -216,6 +256,7 @@ mod tests {
};
use arrow_schema::{DataType, Field, Schema};
use datafusion::logical_expr::{col, lit};
use datafusion::prelude::exp;

#[test]
fn test_parse_filter_simple() {
Expand All @@ -234,20 +275,25 @@ mod tests {

let planner = Planner::new(schema.clone());

let expected = col("i")
.gt(lit(3_i32))
.and(col("st.x").lt_eq(lit(5.0_f32)))
.and(
col("s")
.eq(lit("str-4"))
.or(col("s").in_list(vec![lit("str-4"), lit("str-5")], false)),
);

// double quotes
let expr = planner
.parse_filter("i > 3 AND st.x <= 5.0 AND (s == 'str-4' OR s in ('str-4', 'str-5'))")
.unwrap();
assert_eq!(expr, expected);

// single quote
let expr = planner
.parse_filter("i > 3 AND st.x <= 5.0 AND (s = 'str-4' OR s in ('str-4', 'str-5'))")
.unwrap();
assert_eq!(
expr,
col("i")
.gt(lit(3_i32))
.and(col("st.x").lt_eq(lit(5.0_f32)))
.and(
col("s")
.eq(lit("str-4"))
.or(col("s").in_list(vec![lit("str-4"), lit("str-5")], false))
)
);

let physical_expr = planner.create_physical_expr(&expr).unwrap();
println!("Physical expr: {:#?}", physical_expr);
Expand Down
8 changes: 6 additions & 2 deletions rust/src/io/exec/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,12 +255,16 @@ impl LocalTake {
if take_schema.fields.is_empty() {
return Ok(batch);
};
let projection_schema = ArrowSchema::from(projection.as_ref());
if batch.num_rows() == 0 {
return Ok(RecordBatch::new_empty(Arc::new(projection_schema)));
}

let row_id_arr = batch.column_by_name(ROW_ID).unwrap();
let row_ids: &UInt64Array = as_primitive_array(row_id_arr);

let remaining_columns =
dataset.take_rows(row_ids.values(), &take_schema).await?;
let projection_schema = ArrowSchema::from(projection.as_ref());

let batch = batch
.merge(&remaining_columns)?
.project_by_schema(&projection_schema)?;
Expand Down

0 comments on commit cdbba23

Please sign in to comment.