dataframely/columns/array.py (9 changes: 5 additions & 4 deletions)

@@ -99,16 +99,17 @@ def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
         # NOTE: We might want to add support for PostgreSQL's ARRAY type or use JSON in the future.
         raise NotImplementedError("SQL column cannot have 'Array' type.")
 
-    def _pyarrow_dtype_of_shape(self, shape: Sequence[int]) -> pa.DataType:
+    def _pyarrow_field_of_shape(self, shape: Sequence[int]) -> pa.Field:
         if shape:
             size, *rest = shape
-            return pa.list_(self._pyarrow_dtype_of_shape(rest), size)
+            inner_type = self._pyarrow_field_of_shape(rest)
+            return pa.field("item", pa.list_(inner_type, size), nullable=True)
         else:
-            return self.inner.pyarrow_dtype
+            return self.inner.pyarrow_field("item")
 
     @property
     def pyarrow_dtype(self) -> pa.DataType:
-        return self._pyarrow_dtype_of_shape(self.shape)
+        return self._pyarrow_field_of_shape(self.shape).type
 
     def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
         # Sample the inner elements in a flat series
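For context, a minimal PyArrow-only sketch (separate from the diff) of why swapping bare pa.DataTypes for pa.Fields matters here, assuming that Column.pyarrow_field(name) returns a pa.Field carrying the column's nullability:

import pyarrow as pa

# With a bare dtype, the element field defaults to nullable=True, so any
# "not null" information on the inner column is lost (pa.large_list behaves
# the same way).
lax = pa.list_(pa.int32(), 3)
assert lax.value_field.nullable

# With a pa.Field, the element's nullability is preserved in the list type.
strict = pa.list_(pa.field("item", pa.int32(), nullable=False), 3)
assert not strict.value_field.nullable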
dataframely/columns/list.py (2 changes: 1 addition & 1 deletion)

@@ -145,7 +145,7 @@ def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
     @property
     def pyarrow_dtype(self) -> pa.DataType:
         # NOTE: Polars uses `large_list`s by default.
-        return pa.large_list(self.inner.pyarrow_dtype)
+        return pa.large_list(self.inner.pyarrow_field("item"))
 
     def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
         # First, sample the number of items per list element
dataframely/columns/struct.py (2 changes: 1 addition & 1 deletion)

@@ -112,7 +112,7 @@ def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
 
     @property
     def pyarrow_dtype(self) -> pa.DataType:
-        return pa.struct({name: col.pyarrow_dtype for name, col in self.inner.items()})
+        return pa.struct([col.pyarrow_field(name) for name, col in self.inner.items()])
 
     def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
         series = (
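The struct case is analogous: pa.struct accepts either a mapping of names to dtypes or a list of pa.Fields, and only the latter keeps per-field nullability. A short illustrative sketch in plain PyArrow (not dataframely code):

import pyarrow as pa

# Dict form: every child field comes out nullable.
from_dtypes = pa.struct({"required": pa.string()})
assert from_dtypes[0].nullable

# Field form: the nullable=False flag survives into the struct type.
from_fields = pa.struct([pa.field("required", pa.string(), nullable=False)])
assert not from_fields[0].nullable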
tests/columns/test_pyarrow.py (103 changes: 96 additions & 7 deletions)

@@ -59,13 +59,10 @@ def test_equal_polars_schema_enum(categories: list[str]) -> None:
 
 @pytest.mark.parametrize(
     "inner",
-    [c() for c in ALL_COLUMN_TYPES]
-    + [dy.List(t()) for t in ALL_COLUMN_TYPES]
-    + [
-        dy.Array(t() if t == dy.Any else t(nullable=True), 1)
-        for t in NO_VALIDATION_COLUMN_TYPES
-    ]
-    + [dy.Struct({"a": t()}) for t in ALL_COLUMN_TYPES],
+    [_nullable(c) for c in ALL_COLUMN_TYPES]
+    + [dy.List(_nullable(t), nullable=True) for t in ALL_COLUMN_TYPES]
+    + [dy.Array(_nullable(t), 1, nullable=True) for t in NO_VALIDATION_COLUMN_TYPES]
+    + [dy.Struct({"a": _nullable(t)}, nullable=True) for t in ALL_COLUMN_TYPES],
 )
 def test_equal_polars_schema_list(inner: Column) -> None:
     schema = create_schema("test", {"a": dy.List(inner, nullable=True)})

@@ -161,6 +158,98 @@ def test_nullability_information_struct(inner: Column, nullable: bool) -> None:
     assert ("not null" in str(schema.to_pyarrow_schema())) != nullable
 
 
+@pytest.mark.parametrize("column_type", COLUMN_TYPES)
+@pytest.mark.parametrize("inner_nullable", [True, False])
+def test_inner_nullability_struct(
+    column_type: type[Column], inner_nullable: bool
+) -> None:
+    inner = column_type(nullable=inner_nullable)
+    schema = create_schema("test", {"a": dy.Struct({"a": inner})})
+    pa_schema = schema.to_pyarrow_schema()
+    struct_field = pa_schema.field("a")
+    inner_field = struct_field.type[0]
+    assert inner_field.nullable == inner_nullable
+
+
+@pytest.mark.parametrize("column_type", COLUMN_TYPES)
+@pytest.mark.parametrize("inner_nullable", [True, False])
+def test_inner_nullability_list(
+    column_type: type[Column], inner_nullable: bool
+) -> None:
+    inner = column_type(nullable=inner_nullable)
+    schema = create_schema("test", {"a": dy.List(inner)})
+    pa_schema = schema.to_pyarrow_schema()
+    list_field = pa_schema.field("a")
+    inner_field = list_field.type.value_field
+    assert inner_field.nullable == inner_nullable
+
+
+def test_nested_struct_in_list_preserves_nullability() -> None:
+    """Test that nested struct fields in lists preserve nullability."""
+    schema = create_schema(
+        "test",
+        {
+            "a": dy.List(
+                dy.Struct(
+                    {
+                        "required": dy.String(nullable=False),
+                        "optional": dy.String(nullable=True),
+                    },
+                    nullable=True,
+                ),
+                nullable=True,
+            )
+        },
+    )
+    pa_schema = schema.to_pyarrow_schema()
+    list_field = pa_schema.field("a")
+    struct_type = list_field.type.value_field.type
+    assert not struct_type[0].nullable
+    assert struct_type[1].nullable
+
+
+def test_nested_list_in_struct_preserves_nullability() -> None:
+    """Test that nested list fields in structs preserve nullability."""
+    schema = create_schema(
+        "test",
+        {
+            "a": dy.Struct(
+                {"list_field": dy.List(dy.String(nullable=False), nullable=True)},
+                nullable=True,
+            )
+        },
+    )
+    pa_schema = schema.to_pyarrow_schema()
+    struct_field = pa_schema.field("a")
+    list_type = struct_field.type[0].type
+    assert not list_type.value_field.nullable
+
+
+def test_deeply_nested_nullability() -> None:
+    schema = create_schema(
+        "test",
+        {
+            "a": dy.Struct(
+                {
+                    "nested": dy.Struct(
+                        {
+                            "required": dy.String(nullable=False),
+                            "optional": dy.String(nullable=True),
+                        },
+                        nullable=True,
+                    ),
+                },
+                nullable=True,
+            )
+        },
+    )
+    pa_schema = schema.to_pyarrow_schema()
+    outer_struct = pa_schema.field("a").type
+    inner_struct = outer_struct[0].type
+    assert not inner_struct[0].nullable  # required field
+    assert inner_struct[1].nullable  # optional field
+
+
 def test_multiple_columns() -> None:
     schema = create_schema(
         "test", {"a": dy.Int32(nullable=False), "b": dy.Integer(nullable=True)}