Skip to content

Commit

Permalink
ARROW-11341: [Python] [Gandiva] Add NULL/None checks to Gandiva build…
Browse files Browse the repository at this point in the history
…er functions (#9289)

Previously, passing None as an argument to any of the Gandiva builder functions caused a segfault instead of raising a Python exception. For example:

```python
import pyarrow
import pyarrow.gandiva as gandiva

builder = gandiva.TreeExprBuilder()
field = pyarrow.field('whatever', type=pyarrow.date64())
date_col = builder.make_field(field)

func = builder.make_function('less_than_or_equal_to', [date_col, None], pyarrow.bool_())

condition = builder.make_condition(func)

# Will segfault on this line:
gandiva.make_filter(pyarrow.schema([field]), condition)
```

Lead-authored-by: Will Jones <willjones127@gmail.com>
Co-authored-by: Will Jones <will.jones@mscience.com>
Signed-off-by: David Li <li.davidm96@gmail.com>
  • Loading branch information
wjones127 and Will Jones committed Jul 13, 2022
1 parent 861f237 commit 03e80dc
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 15 deletions.
41 changes: 26 additions & 15 deletions python/pyarrow/gandiva.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,8 @@ cdef class TreeExprBuilder(_Weakrefable):

return Node.create(r)

def make_expression(self, Node root_node, Field return_field):
def make_expression(self, Node root_node not None,
Field return_field not None):
cdef shared_ptr[CGandivaExpression] r = TreeExprBuilder_MakeExpression(
root_node.node, return_field.sp_field)
cdef Expression expression = Expression()
Expand All @@ -303,17 +304,19 @@ cdef class TreeExprBuilder(_Weakrefable):
cdef c_vector[shared_ptr[CNode]] c_children
cdef Node child
for child in children:
if child is None:
raise TypeError("Child nodes must not be None")
c_children.push_back(child.node)
cdef shared_ptr[CNode] r = TreeExprBuilder_MakeFunction(
name.encode(), c_children, return_type.sp_type)
return Node.create(r)

def make_field(self, Field field):
# Build a Gandiva node for `field` and return it wrapped via Node.create.
# The `not None` clause makes the call fail with TypeError when None is
# passed, before `field.sp_field` is read.
def make_field(self, Field field not None):
cdef shared_ptr[CNode] r = TreeExprBuilder_MakeField(field.sp_field)
return Node.create(r)

def make_if(self, Node condition, Node this_node,
Node else_node, DataType return_type):
def make_if(self, Node condition not None, Node this_node not None,
Node else_node not None, DataType return_type not None):
cdef shared_ptr[CNode] r = TreeExprBuilder_MakeIf(
condition.node, this_node.node, else_node.node,
return_type.sp_type)
Expand All @@ -323,6 +326,8 @@ cdef class TreeExprBuilder(_Weakrefable):
cdef c_vector[shared_ptr[CNode]] c_children
cdef Node child
for child in children:
if child is None:
raise TypeError("Child nodes must not be None")
c_children.push_back(child.node)
cdef shared_ptr[CNode] r = TreeExprBuilder_MakeAnd(c_children)
return Node.create(r)
Expand All @@ -331,11 +336,13 @@ cdef class TreeExprBuilder(_Weakrefable):
cdef c_vector[shared_ptr[CNode]] c_children
cdef Node child
for child in children:
if child is None:
raise TypeError("Child nodes must not be None")
c_children.push_back(child.node)
cdef shared_ptr[CNode] r = TreeExprBuilder_MakeOr(c_children)
return Node.create(r)

def _make_in_expression_int32(self, Node node, values):
def _make_in_expression_int32(self, Node node not None, values):
cdef shared_ptr[CNode] r
cdef c_unordered_set[int32_t] c_values
cdef int32_t v
Expand All @@ -344,7 +351,7 @@ cdef class TreeExprBuilder(_Weakrefable):
r = TreeExprBuilder_MakeInExpressionInt32(node.node, c_values)
return Node.create(r)

def _make_in_expression_int64(self, Node node, values):
def _make_in_expression_int64(self, Node node not None, values):
cdef shared_ptr[CNode] r
cdef c_unordered_set[int64_t] c_values
cdef int64_t v
Expand All @@ -353,7 +360,7 @@ cdef class TreeExprBuilder(_Weakrefable):
r = TreeExprBuilder_MakeInExpressionInt64(node.node, c_values)
return Node.create(r)

def _make_in_expression_time32(self, Node node, values):
def _make_in_expression_time32(self, Node node not None, values):
cdef shared_ptr[CNode] r
cdef c_unordered_set[int32_t] c_values
cdef int32_t v
Expand All @@ -362,7 +369,7 @@ cdef class TreeExprBuilder(_Weakrefable):
r = TreeExprBuilder_MakeInExpressionTime32(node.node, c_values)
return Node.create(r)

def _make_in_expression_time64(self, Node node, values):
def _make_in_expression_time64(self, Node node not None, values):
cdef shared_ptr[CNode] r
cdef c_unordered_set[int64_t] c_values
cdef int64_t v
Expand All @@ -371,7 +378,7 @@ cdef class TreeExprBuilder(_Weakrefable):
r = TreeExprBuilder_MakeInExpressionTime64(node.node, c_values)
return Node.create(r)

def _make_in_expression_date32(self, Node node, values):
def _make_in_expression_date32(self, Node node not None, values):
cdef shared_ptr[CNode] r
cdef c_unordered_set[int32_t] c_values
cdef int32_t v
Expand All @@ -380,7 +387,7 @@ cdef class TreeExprBuilder(_Weakrefable):
r = TreeExprBuilder_MakeInExpressionDate32(node.node, c_values)
return Node.create(r)

def _make_in_expression_date64(self, Node node, values):
def _make_in_expression_date64(self, Node node not None, values):
cdef shared_ptr[CNode] r
cdef c_unordered_set[int64_t] c_values
cdef int64_t v
Expand All @@ -389,7 +396,7 @@ cdef class TreeExprBuilder(_Weakrefable):
r = TreeExprBuilder_MakeInExpressionDate64(node.node, c_values)
return Node.create(r)

def _make_in_expression_timestamp(self, Node node, values):
def _make_in_expression_timestamp(self, Node node not None, values):
cdef shared_ptr[CNode] r
cdef c_unordered_set[int64_t] c_values
cdef int64_t v
Expand All @@ -398,7 +405,7 @@ cdef class TreeExprBuilder(_Weakrefable):
r = TreeExprBuilder_MakeInExpressionTimeStamp(node.node, c_values)
return Node.create(r)

def _make_in_expression_binary(self, Node node, values):
def _make_in_expression_binary(self, Node node not None, values):
cdef shared_ptr[CNode] r
cdef c_unordered_set[c_string] c_values
cdef c_string v
Expand All @@ -407,7 +414,7 @@ cdef class TreeExprBuilder(_Weakrefable):
r = TreeExprBuilder_MakeInExpressionString(node.node, c_values)
return Node.create(r)

def _make_in_expression_string(self, Node node, values):
def _make_in_expression_string(self, Node node not None, values):
cdef shared_ptr[CNode] r
cdef c_unordered_set[c_string] c_values
cdef c_string _v
Expand All @@ -417,7 +424,7 @@ cdef class TreeExprBuilder(_Weakrefable):
r = TreeExprBuilder_MakeInExpressionString(node.node, c_values)
return Node.create(r)

def make_in_expression(self, Node node, values, dtype):
def make_in_expression(self, Node node not None, values, dtype):
cdef DataType type = ensure_type(dtype)

if type.id == _Type_INT32:
Expand All @@ -441,7 +448,7 @@ cdef class TreeExprBuilder(_Weakrefable):
else:
raise TypeError("Data type " + str(dtype) + " not supported.")

def make_condition(self, Node condition):
# Build a Gandiva condition from the node `condition` and return it wrapped
# via Condition.create. The `not None` clause makes the call fail with
# TypeError when None is passed, before `condition.node` is read.
def make_condition(self, Node condition not None):
cdef shared_ptr[CCondition] r = TreeExprBuilder_MakeCondition(
condition.node)
return Condition.create(r)
Expand Down Expand Up @@ -476,6 +483,8 @@ cpdef make_projector(Schema schema, children, MemoryPool pool,
shared_ptr[CProjector] result

for child in children:
if child is None:
raise TypeError("Expressions must not be None")
c_children.push_back(child.expression)

check_status(
Expand Down Expand Up @@ -505,6 +514,8 @@ cpdef make_filter(Schema schema, Condition condition):
Filter instance
"""
cdef shared_ptr[CFilter] result
if condition is None:
raise TypeError("Condition must not be None")
check_status(
Filter_Make(schema.sp_schema, condition.condition, &result))
return Filter.create(result)
Expand Down
41 changes: 41 additions & 0 deletions python/pyarrow/tests/test_gandiva.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,3 +389,44 @@ def test_to_string():
field_y = builder.make_field(pa.field('y', pa.bool_()))
and_node = builder.make_and([func_node, field_y])
assert str(and_node) == 'bool not((bool) z) && (bool) y'


@pytest.mark.gandiva
def test_rejects_none():
    """Check that the Gandiva builder API raises TypeError for None.

    Regression test for ARROW-11341: each call below passes None where a
    node, field, expression or condition is expected and must fail with
    TypeError.
    """
    import pyarrow.gandiva as gandiva

    tree_builder = gandiva.TreeExprBuilder()

    # Valid fixtures, used so that None is the only invalid input per call.
    x_field = pa.field('x', pa.int32())
    x_schema = pa.schema([x_field])
    true_node = tree_builder.make_literal(True, pa.bool_())

    # --- TreeExprBuilder methods taking typed arguments ---
    with pytest.raises(TypeError):
        tree_builder.make_field(None)

    with pytest.raises(TypeError):
        tree_builder.make_if(true_node, None, None, None)

    # --- None hidden inside a list of children, at either position ---
    with pytest.raises(TypeError):
        tree_builder.make_and([true_node, None])

    with pytest.raises(TypeError):
        tree_builder.make_or([None, true_node])

    with pytest.raises(TypeError):
        tree_builder.make_in_expression(None, [1, 2, 3], pa.int32())

    with pytest.raises(TypeError):
        tree_builder.make_expression(None, x_field)

    with pytest.raises(TypeError):
        tree_builder.make_condition(None)

    with pytest.raises(TypeError):
        tree_builder.make_function(
            'less_than', [true_node, None], pa.bool_())

    # --- Module-level factories ---
    with pytest.raises(TypeError):
        gandiva.make_projector(x_schema, [None])

    with pytest.raises(TypeError):
        gandiva.make_filter(x_schema, None)

0 comments on commit 03e80dc

Please sign in to comment.