
Commit d6b9051

paleolimbot and pitrou authored
GH-40066: [Python] Support requested_schema in __arrow_c_stream__() (#40070)
### Rationale for this change

The `requested_schema` portion of the `__arrow_c_stream__()` protocol methods errored in all cases if passed an unequal schema. There was a note about figuring out how to check the cast before doing it, and a comment in #40066 about how it should be done lazily. This PR (hopefully) solves both!

### What changes are included in this PR?

- Added `arrow::py::CastingRecordBatchReader`, which wraps an `arrow::RecordBatchReader`, casting each batch as it is pulled.

### Are these changes tested?

Yes.

### Are there any user-facing changes?

Yes: the current approach adds `RecordBatchReader.cast()` as the way to access the casting reader.

* Closes: #40066
* GitHub Issue: #40066

Lead-authored-by: Dewey Dunnington <dewey@fishandwhistle.net>
Co-authored-by: Dewey Dunnington <dewey@voltrondata.com>
Co-authored-by: Antoine Pitrou <pitrou@free.fr>
Signed-off-by: Joris Van den Bossche <jorisvandenbossche@gmail.com>
1 parent 99c5412 commit d6b9051
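To make the user-facing change concrete, here is a minimal sketch of the protocol path (a hypothetical example assuming a pyarrow build containing this commit; `_import_from_c_capsule` is the private helper the new tests below also use):

```python
import pyarrow as pa

# A stream of int64 batches
batch = pa.record_batch([pa.array([1, 2, 3], type=pa.int64())], names=["a"])
reader = pa.RecordBatchReader.from_batches(batch.schema, [batch])

# Passing a castable requested_schema through __arrow_c_stream__() now
# yields a lazily casting stream instead of raising NotImplementedError
requested = pa.schema([pa.field("a", pa.int32())])
capsule = reader.__arrow_c_stream__(requested.__arrow_c_schema__())
imported = pa.RecordBatchReader._import_from_c_capsule(capsule)
assert imported.schema == requested
assert imported.read_all() == pa.Table.from_batches([batch]).cast(requested)
```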

8 files changed: +261 −13 lines changed


python/pyarrow/includes/libarrow_python.pxd

Lines changed: 8 additions & 0 deletions
@@ -283,6 +283,14 @@ cdef extern from "arrow/python/ipc.h" namespace "arrow::py":
                                          object)
 
 
+cdef extern from "arrow/python/ipc.h" namespace "arrow::py" nogil:
+    cdef cppclass CCastingRecordBatchReader" arrow::py::CastingRecordBatchReader" \
+            (CRecordBatchReader):
+        @staticmethod
+        CResult[shared_ptr[CRecordBatchReader]] Make(shared_ptr[CRecordBatchReader],
+                                                     shared_ptr[CSchema])
+
+
 cdef extern from "arrow/python/extension_type.h" namespace "arrow::py":
     cdef cppclass CPyExtensionType \
             " arrow::py::PyExtensionType"(CExtensionType):

python/pyarrow/ipc.pxi

Lines changed: 33 additions & 6 deletions
@@ -772,6 +772,38 @@ cdef class RecordBatchReader(_Weakrefable):
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.close()
 
+    def cast(self, target_schema):
+        """
+        Wrap this reader with one that casts each batch lazily as it is pulled.
+        Currently only a safe cast to target_schema is implemented.
+
+        Parameters
+        ----------
+        target_schema : Schema
+            Schema to cast to; the names and order of fields must match.
+
+        Returns
+        -------
+        RecordBatchReader
+        """
+        cdef:
+            shared_ptr[CSchema] c_schema
+            shared_ptr[CRecordBatchReader] c_reader
+            RecordBatchReader out
+
+        if self.schema.names != target_schema.names:
+            raise ValueError("Target schema's field names do not match "
+                             f"the table's field names: {self.schema.names}, "
+                             f"{target_schema.names}")
+
+        c_schema = pyarrow_unwrap_schema(target_schema)
+        c_reader = GetResultValue(CCastingRecordBatchReader.Make(
+            self.reader, c_schema))
+
+        out = RecordBatchReader.__new__(RecordBatchReader)
+        out.reader = c_reader
+        return out
+
     def _export_to_c(self, out_ptr):
         """
         Export to a C ArrowArrayStream struct, given its pointer.
@@ -827,8 +859,6 @@ cdef class RecordBatchReader(_Weakrefable):
             The schema to which the stream should be casted, passed as a
             PyCapsule containing a C ArrowSchema representation of the
             requested schema.
-            Currently, this is not supported and will raise a
-            NotImplementedError if the schema doesn't match the current schema.
 
         Returns
         -------
@@ -840,11 +870,8 @@ cdef class RecordBatchReader(_Weakrefable):
 
         if requested_schema is not None:
            out_schema = Schema._import_from_c_capsule(requested_schema)
-            # TODO: figure out a way to check if one schema is castable to
-            # another. Once we have that, we can perform validation here and
-            # if successful creating a wrapping reader that casts each batch.
             if self.schema != out_schema:
-                raise NotImplementedError("Casting to requested_schema")
+                return self.cast(out_schema).__arrow_c_stream__()
 
         stream_capsule = alloc_c_stream(&c_stream)
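For reference, a short usage sketch of the new method (schemas and data invented for illustration):

```python
import pyarrow as pa

src = pa.schema([pa.field("a", pa.int64())])
dst = pa.schema([pa.field("a", pa.int32())])
batches = [pa.record_batch([pa.array([1, 2, 3], type=pa.int64())], names=["a"])]

reader = pa.RecordBatchReader.from_batches(src, batches)
casting_reader = reader.cast(dst)  # wraps the reader; nothing is read yet
assert casting_reader.schema == dst
table = casting_reader.read_all()  # each batch is cast safely as it is pulled
assert table.schema == dst
```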

python/pyarrow/src/arrow/python/ipc.cc

Lines changed: 66 additions & 0 deletions
@@ -19,6 +19,7 @@
 
 #include <memory>
 
+#include "arrow/compute/cast.h"
 #include "arrow/python/pyarrow.h"
 
 namespace arrow {
@@ -63,5 +64,70 @@ Result<std::shared_ptr<RecordBatchReader>> PyRecordBatchReader::Make(
   return reader;
 }
 
+CastingRecordBatchReader::CastingRecordBatchReader() = default;
+
+Status CastingRecordBatchReader::Init(std::shared_ptr<RecordBatchReader> parent,
+                                      std::shared_ptr<Schema> schema) {
+  std::shared_ptr<Schema> src = parent->schema();
+
+  // The check for names has already been done in Python where it's easier to
+  // generate a nice error message.
+  int num_fields = schema->num_fields();
+  if (src->num_fields() != num_fields) {
+    return Status::Invalid("Number of fields not equal");
+  }
+
+  // Ensure all columns can be cast before succeeding
+  for (int i = 0; i < num_fields; i++) {
+    if (!compute::CanCast(*src->field(i)->type(), *schema->field(i)->type())) {
+      return Status::TypeError("Field ", i, " cannot be cast from ",
+                               src->field(i)->type()->ToString(), " to ",
+                               schema->field(i)->type()->ToString());
+    }
+  }
+
+  parent_ = std::move(parent);
+  schema_ = std::move(schema);
+
+  return Status::OK();
+}
+
+std::shared_ptr<Schema> CastingRecordBatchReader::schema() const { return schema_; }
+
+Status CastingRecordBatchReader::ReadNext(std::shared_ptr<RecordBatch>* batch) {
+  std::shared_ptr<RecordBatch> out;
+  ARROW_RETURN_NOT_OK(parent_->ReadNext(&out));
+  if (!out) {
+    batch->reset();
+    return Status::OK();
+  }
+
+  auto num_columns = out->num_columns();
+  auto options = compute::CastOptions::Safe();
+  ArrayVector columns(num_columns);
+  for (int i = 0; i < num_columns; i++) {
+    const Array& src = *out->column(i);
+    if (!schema_->field(i)->nullable() && src.null_count() > 0) {
+      return Status::Invalid(
+          "Can't cast array that contains nulls to non-nullable field at index ", i);
+    }
+
+    ARROW_ASSIGN_OR_RAISE(columns[i],
+                          compute::Cast(src, schema_->field(i)->type(), options));
+  }
+
+  *batch = RecordBatch::Make(schema_, out->num_rows(), std::move(columns));
+  return Status::OK();
+}
+
+Result<std::shared_ptr<RecordBatchReader>> CastingRecordBatchReader::Make(
+    std::shared_ptr<RecordBatchReader> parent, std::shared_ptr<Schema> schema) {
+  auto reader = std::shared_ptr<CastingRecordBatchReader>(new CastingRecordBatchReader());
+  ARROW_RETURN_NOT_OK(reader->Init(parent, schema));
+  return reader;
+}
+
+Status CastingRecordBatchReader::Close() { return parent_->Close(); }
+
 } // namespace py
 } // namespace arrow
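Note the split above: `Init()` verifies castability once with `compute::CanCast`, while the cast itself (and the null/non-nullable check) runs per batch in `ReadNext()`, so data-dependent failures surface only when the offending batch is pulled. A rough Python equivalent of that loop, as an illustrative sketch only (`make_casting_reader` is a hypothetical stand-in, not part of the API):

```python
import pyarrow as pa

def make_casting_reader(parent: pa.RecordBatchReader,
                        target_schema: pa.Schema) -> pa.RecordBatchReader:
    def casted_batches():
        for src_batch in parent:
            # Cast lazily, one batch at a time; RecordBatch.cast (added in
            # this PR) raises if the data cannot be safely cast.
            yield src_batch.cast(target_schema)

    # from_batches consumes the generator lazily, mirroring ReadNext()
    return pa.RecordBatchReader.from_batches(target_schema, casted_batches())
```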

python/pyarrow/src/arrow/python/ipc.h

Lines changed: 20 additions & 0 deletions
@@ -48,5 +48,25 @@ class ARROW_PYTHON_EXPORT PyRecordBatchReader : public RecordBatchReader {
   OwnedRefNoGIL iterator_;
 };
 
+class ARROW_PYTHON_EXPORT CastingRecordBatchReader : public RecordBatchReader {
+ public:
+  std::shared_ptr<Schema> schema() const override;
+
+  Status ReadNext(std::shared_ptr<RecordBatch>* batch) override;
+
+  static Result<std::shared_ptr<RecordBatchReader>> Make(
+      std::shared_ptr<RecordBatchReader> parent, std::shared_ptr<Schema> schema);
+
+  Status Close() override;
+
+ protected:
+  CastingRecordBatchReader();
+
+  Status Init(std::shared_ptr<RecordBatchReader> parent, std::shared_ptr<Schema> schema);
+
+  std::shared_ptr<RecordBatchReader> parent_;
+  std::shared_ptr<Schema> schema_;
+};
+
 } // namespace py
 } // namespace arrow

python/pyarrow/table.pxi

Lines changed: 23 additions & 0 deletions
@@ -2742,6 +2742,29 @@ cdef class RecordBatch(_Tabular):
 
         return pyarrow_wrap_batch(c_batch)
 
+    def cast(self, Schema target_schema, safe=None, options=None):
+        """
+        Cast batch values to another schema.
+
+        Parameters
+        ----------
+        target_schema : Schema
+            Schema to cast to; the names and order of fields must match.
+        safe : bool, default True
+            Check for overflows or other unsafe conversions.
+        options : CastOptions, default None
+            Additional checks passed by CastOptions
+
+        Returns
+        -------
+        RecordBatch
+        """
+        # Wrap the more general Table cast implementation
+        tbl = Table.from_batches([self])
+        casted_tbl = tbl.cast(target_schema, safe=safe, options=options)
+        casted_batch, = casted_tbl.to_batches()
+        return casted_batch
+
     def _to_pandas(self, options, **kwargs):
         return Table.from_batches([self])._to_pandas(options, **kwargs)
 
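A brief usage sketch of `RecordBatch.cast()` (values invented for illustration; the overflow behavior follows `Table.cast`, which this method wraps):

```python
import pyarrow as pa

batch = pa.record_batch([pa.array([1, 2, 3])], names=["a"])  # int64 column
target = pa.schema([pa.field("a", pa.int32())])

assert batch.cast(target).schema == target

# The default is a safe cast: values that overflow int32 raise an error,
# while safe=False permits the lossy conversion, as with Table.cast
big = pa.record_batch([pa.array([2**40])], names=["a"])
lossy = big.cast(target, safe=False)
```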

python/pyarrow/tests/test_cffi.py

Lines changed: 16 additions & 2 deletions
@@ -633,9 +633,8 @@ def test_roundtrip_reader_capsule(constructor):
 
     obj = constructor(schema, batches)
 
-    # TODO: turn this to ValueError once we implement validation.
     bad_schema = pa.schema({'ints': pa.int32()})
-    with pytest.raises(NotImplementedError):
+    with pytest.raises(pa.lib.ArrowTypeError, match="Field 0 cannot be cast"):
         obj.__arrow_c_stream__(bad_schema.__arrow_c_schema__())
 
     # Can work with matching schema
@@ -647,6 +646,21 @@ def test_roundtrip_reader_capsule(constructor):
     assert batch.equals(expected)
 
 
+def test_roundtrip_batch_reader_capsule_requested_schema():
+    batch = make_batch()
+    requested_schema = pa.schema([('ints', pa.list_(pa.int64()))])
+    requested_capsule = requested_schema.__arrow_c_schema__()
+    batch_as_requested = batch.cast(requested_schema)
+
+    capsule = batch.__arrow_c_stream__(requested_capsule)
+    assert PyCapsule_IsValid(capsule, b"arrow_array_stream") == 1
+    imported_reader = pa.RecordBatchReader._import_from_c_capsule(capsule)
+    assert imported_reader.schema == requested_schema
+    assert imported_reader.read_next_batch().equals(batch_as_requested)
+    with pytest.raises(StopIteration):
+        imported_reader.read_next_batch()
+
+
 def test_roundtrip_batch_reader_capsule():
     batch = make_batch()

python/pyarrow/tests/test_ipc.py

Lines changed: 65 additions & 3 deletions
@@ -1226,10 +1226,15 @@ def __arrow_c_stream__(self, requested_schema=None):
     reader = pa.RecordBatchReader.from_stream(wrapper, schema=data[0].schema)
     assert reader.read_all() == expected
 
-    # If schema doesn't match, raises NotImplementedError
-    with pytest.raises(NotImplementedError):
+    # Passing a different but castable schema works
+    good_schema = pa.schema([pa.field("a", pa.int32())])
+    reader = pa.RecordBatchReader.from_stream(wrapper, schema=good_schema)
+    assert reader.read_all() == expected.cast(good_schema)
+
+    # If schema doesn't match, raises TypeError
+    with pytest.raises(pa.lib.ArrowTypeError, match='Field 0 cannot be cast'):
         pa.RecordBatchReader.from_stream(
-            wrapper, schema=pa.schema([pa.field('a', pa.int32())])
+            wrapper, schema=pa.schema([pa.field('a', pa.list_(pa.int32()))])
         )
 
     # Proper type errors for wrong input
@@ -1238,3 +1243,60 @@ def __arrow_c_stream__(self, requested_schema=None):
 
     with pytest.raises(TypeError):
         pa.RecordBatchReader.from_stream(expected, schema=data[0])
+
+
+def test_record_batch_reader_cast():
+    schema_src = pa.schema([pa.field('a', pa.int64())])
+    data = [
+        pa.record_batch([pa.array([1, 2, 3], type=pa.int64())], names=['a']),
+        pa.record_batch([pa.array([4, 5, 6], type=pa.int64())], names=['a']),
+    ]
+    table_src = pa.Table.from_batches(data)
+
+    # Cast to same type should always work
+    reader = pa.RecordBatchReader.from_batches(schema_src, data)
+    assert reader.cast(schema_src).read_all() == table_src
+
+    # Check non-trivial cast
+    schema_dst = pa.schema([pa.field('a', pa.int32())])
+    reader = pa.RecordBatchReader.from_batches(schema_src, data)
+    assert reader.cast(schema_dst).read_all() == table_src.cast(schema_dst)
+
+    # Check error for field name/length mismatch
+    reader = pa.RecordBatchReader.from_batches(schema_src, data)
+    with pytest.raises(ValueError, match="Target schema's field names"):
+        reader.cast(pa.schema([]))
+
+    # Check error for impossible cast in call to .cast()
+    reader = pa.RecordBatchReader.from_batches(schema_src, data)
+    with pytest.raises(pa.lib.ArrowTypeError, match='Field 0 cannot be cast'):
+        reader.cast(pa.schema([pa.field('a', pa.list_(pa.int32()))]))
+
+
+def test_record_batch_reader_cast_nulls():
+    schema_src = pa.schema([pa.field('a', pa.int64())])
+    data_with_nulls = [
+        pa.record_batch([pa.array([1, 2, None], type=pa.int64())], names=['a']),
+    ]
+    data_without_nulls = [
+        pa.record_batch([pa.array([1, 2, 3], type=pa.int64())], names=['a']),
+    ]
+    table_with_nulls = pa.Table.from_batches(data_with_nulls)
+    table_without_nulls = pa.Table.from_batches(data_without_nulls)
+
+    # Cast to nullable destination should work
+    reader = pa.RecordBatchReader.from_batches(schema_src, data_with_nulls)
+    schema_dst = pa.schema([pa.field('a', pa.int32())])
+    assert reader.cast(schema_dst).read_all() == table_with_nulls.cast(schema_dst)
+
+    # Cast to non-nullable destination should work if there are no nulls
+    reader = pa.RecordBatchReader.from_batches(schema_src, data_without_nulls)
+    schema_dst = pa.schema([pa.field('a', pa.int32(), nullable=False)])
+    assert reader.cast(schema_dst).read_all() == table_without_nulls.cast(schema_dst)
+
+    # Cast to non-nullable destination should error if there are nulls
+    # when the batch is pulled
+    reader = pa.RecordBatchReader.from_batches(schema_src, data_with_nulls)
+    casted_reader = reader.cast(schema_dst)
+    with pytest.raises(pa.lib.ArrowInvalid, match="Can't cast array"):
+        casted_reader.read_all()

python/pyarrow/tests/test_table.py

Lines changed: 30 additions & 2 deletions
@@ -635,9 +635,18 @@ def __arrow_c_stream__(self, requested_schema=None):
     result = pa.table(wrapper, schema=data[0].schema)
     assert result == expected
 
+    # Passing a different schema will cast
+    good_schema = pa.schema([pa.field('a', pa.int32())])
+    result = pa.table(wrapper, schema=good_schema)
+    assert result == expected.cast(good_schema)
+
     # If schema doesn't match, raises NotImplementedError
-    with pytest.raises(NotImplementedError):
-        pa.table(wrapper, schema=pa.schema([pa.field('a', pa.int32())]))
+    with pytest.raises(
+        pa.lib.ArrowTypeError, match="Field 0 cannot be cast"
+    ):
+        pa.table(
+            wrapper, schema=pa.schema([pa.field('a', pa.list_(pa.int32()))])
+        )
 
 
 def test_recordbatch_itercolumns():
@@ -2620,6 +2629,25 @@ def test_record_batch_sort():
     assert sorted_rb_dict["c"] == ["foobar", "bar", "foo", "car"]
 
 
+def test_record_batch_cast():
+    rb = pa.RecordBatch.from_arrays([
+        pa.array([None, 1]),
+        pa.array([False, True])
+    ], names=["a", "b"])
+    new_schema = pa.schema([pa.field("a", "int64", nullable=True),
+                            pa.field("b", "bool", nullable=False)])
+
+    assert rb.cast(new_schema).schema == new_schema
+
+    # Casting a nullable field to non-nullable is invalid
+    rb = pa.RecordBatch.from_arrays([
+        pa.array([None, 1]),
+        pa.array([None, True])
+    ], names=["a", "b"])
+    with pytest.raises(ValueError):
+        rb.cast(new_schema)
+
+
 @pytest.mark.parametrize("constructor", [pa.table, pa.record_batch])
 def test_numpy_asarray(constructor):
     table = constructor([[1, 2, 3], [4.0, 5.0, 6.0]], names=["a", "b"])
