New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Rust bindings for Schema #650
Changes from 2 commits
02da455
398d659
1f41d9f
a017894
d677206
7584ad1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Iterator, TypeVar | ||
|
||
from daft.daft import PyField as _PyField | ||
from daft.daft import PySchema as _PySchema | ||
from daft.datatype import DataType | ||
from daft.expressions2 import Expression, col | ||
|
||
ExpressionType = TypeVar("ExpressionType", bound=Expression) | ||
|
||
|
||
class Field:
    """Python-side wrapper around a native ``_PyField`` handle.

    Instances are built through :meth:`_from_pyfield`; calling the
    constructor directly raises ``NotImplementedError``.
    """

    _field: _PyField

    def __init__(self) -> None:
        raise NotImplementedError("We do not support creating a Field via __init__ ")

    @staticmethod
    def _from_pyfield(field: _PyField) -> Field:
        """Wrap ``field`` in a ``Field`` without going through ``__init__``."""
        wrapped = Field.__new__(Field)
        wrapped._field = field
        return wrapped

    @property
    def name(self):
        """Name reported by the wrapped field handle."""
        return self._field.name()

    @property
    def dtype(self) -> DataType:
        """The wrapped handle's dtype, converted via ``DataType._from_pydatatype``."""
        native_dtype = self._field.dtype()
        return DataType._from_pydatatype(native_dtype)

    def __eq__(self, other: object) -> bool:
        """Return the handle's ``eq`` result for another ``Field``; ``False`` for anything else."""
        if isinstance(other, Field):
            return self._field.eq(other._field)
        return False
|
||
|
||
class Schema:
    """Python-side wrapper around a native ``_PySchema`` handle.

    Instances are built through :meth:`_from_pyschema`; calling the
    constructor directly raises ``NotImplementedError``.
    """

    _schema: _PySchema

    def __init__(self) -> None:
        raise NotImplementedError("We do not support creating a Schema via __init__ ")

    @staticmethod
    def _from_pyschema(schema: _PySchema) -> Schema:
        """Wrap ``schema`` in a ``Schema`` without going through ``__init__``."""
        s = Schema.__new__(Schema)
        s._schema = schema
        return s

    def __getitem__(self, key: str) -> Field:
        """Return the field named ``key``.

        Raises:
            TypeError: if ``key`` is not a ``str``.
            ValueError: if no field named ``key`` exists in this schema.
        """
        # Validate the key type here so a bad key is reported as a Python
        # error instead of being handed to the native binding.
        if not isinstance(key, str):
            raise TypeError(f"Expected str for key, but received: {type(key)}")
        names = self._schema.names()
        if key not in names:
            # The handle only exposes `names()`, so the message is built from that list.
            raise ValueError(f"{key} was not found in Schema of fields {names}")
        return Field._from_pyfield(self._schema[key])

    def __len__(self) -> int:
        """Number of fields in the schema."""
        return len(self._schema.names())

    def column_names(self) -> list[str]:
        """Field names as a new list, in the order the handle reports them."""
        return list(self._schema.names())

    def __iter__(self) -> Iterator[Field]:
        """Iterate over the schema's fields.

        This returns a lazy generator rather than a ``list``; wrap the result
        in ``list(...)`` when a materialized sequence is needed.
        """
        # Every name comes straight from the handle, so the per-item membership
        # check performed by ``__getitem__`` is skipped here.
        return (Field._from_pyfield(self._schema[name]) for name in self.column_names())

    def __eq__(self, other: object) -> bool:
        """Schemas are equal only to other ``Schema``s whose handles report ``eq``."""
        return isinstance(other, Schema) and self._schema.eq(other._schema)

    def to_name_set(self) -> set[str]:
        """Field names as a set."""
        return set(self.column_names())

    def __repr__(self) -> str:
        """``repr`` of the list of ``(name, dtype)`` pairs."""
        return repr([(field.name, field.dtype) for field in self])

    def to_column_expressions(self) -> list[Expression]:
        """One ``col(name)`` expression per field, in schema order."""
        return [col(f.name) for f in self]

    def union(self, other: Schema) -> Schema:
        """Return a new schema with this schema's fields followed by ``other``'s.

        Raises:
            ValueError: if ``other`` is not a ``Schema`` or the two schemas
                share any field name.
        """
        if not isinstance(other, Schema):
            raise ValueError(f"Expected Schema, got other: {type(other)}")

        intersecting_names = self.to_name_set().intersection(other.to_name_set())
        if intersecting_names:
            raise ValueError(f"Cannot union schemas with overlapping names: {intersecting_names}")

        return Schema._from_pyschema(self._schema.union(other._schema))
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
use pyo3::prelude::*; | ||
|
||
use crate::datatypes; | ||
use crate::python::datatype; | ||
use crate::schema; | ||
|
||
#[pyclass] | ||
pub struct PySchema { | ||
pub schema: schema::SchemaRef, | ||
} | ||
|
||
#[pyclass] | ||
pub struct PyField { | ||
pub field: datatypes::Field, | ||
} | ||
|
||
#[pymethods] | ||
impl PySchema { | ||
pub fn __getitem__(&self, name: &str) -> PyResult<PyField> { | ||
Ok(self.schema.get_field(name)?.clone().into()) | ||
} | ||
|
||
pub fn names(&self) -> PyResult<Vec<String>> { | ||
Ok(self.schema.names()?) | ||
} | ||
|
||
pub fn union(&self, other: &PySchema) -> PyResult<PySchema> { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this union method should be implemented in |
||
let fields: Vec<datatypes::Field> = self | ||
.schema | ||
.fields | ||
.values() | ||
.cloned() | ||
.chain(other.schema.fields.values().cloned()) | ||
.collect(); | ||
let new_schema = schema::Schema::new(fields); | ||
let new_pyschema = PySchema { | ||
schema: new_schema.into(), | ||
}; | ||
Ok(new_pyschema) | ||
} | ||
|
||
pub fn eq(&self, other: &PySchema) -> PyResult<bool> { | ||
Ok(self.schema.fields.eq(&other.schema.fields)) | ||
} | ||
} | ||
|
||
#[pymethods] | ||
impl PyField { | ||
pub fn name(&self) -> PyResult<String> { | ||
Ok(self.field.name.clone()) | ||
} | ||
|
||
pub fn dtype(&self) -> PyResult<datatype::PyDataType> { | ||
Ok(self.field.dtype.clone().into()) | ||
} | ||
|
||
pub fn eq(&self, other: &PyField) -> PyResult<bool> { | ||
Ok(self.field.eq(&other.field)) | ||
} | ||
} | ||
|
||
impl From<datatypes::Field> for PyField { | ||
fn from(item: datatypes::Field) -> Self { | ||
PyField { field: item } | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
from __future__ import annotations | ||
|
||
import pytest | ||
|
||
from daft.datatype import DataType | ||
from daft.expressions2 import col | ||
from daft.table import Table | ||
|
||
# Column name -> (column values, expected dtype) for the shared test table.
DATA = {
    "int": ([1, 2, None], DataType.int64()),
    "float": ([1.0, 2.0, None], DataType.float64()),
    "string": (["a", "b", None], DataType.string()),
    "bool": ([True, True, None], DataType.bool()),
}

# Table built from the value half of each DATA entry.
TABLE = Table.from_pydict({name: values for name, (values, _) in DATA.items()})
# Expected dtype per column, taken from the dtype half of each DATA entry.
EXPECTED_TYPES = {name: dtype for name, (_, dtype) in DATA.items()}
|
||
|
||
def test_schema_len():
    """The schema's length equals the number of columns in DATA."""
    n_fields = len(TABLE.schema())
    assert n_fields == len(DATA)
|
||
|
||
def test_schema_column_names():
    """Column names come back in the same order as DATA's keys."""
    names = TABLE.schema().column_names()
    assert names == list(DATA.keys())
|
||
|
||
def test_schema_field_types():
    """Looking a field up by name yields the matching name and dtype."""
    schema = TABLE.schema()
    for column, expected_dtype in EXPECTED_TYPES.items():
        assert schema[column].name == column
        assert schema[column].dtype == expected_dtype
|
||
|
||
def test_schema_iter():
    """Iterating a schema yields every field, in order, with the expected dtype."""
    schema = TABLE.schema()
    fields = list(schema)
    # zip() stops at the shorter input, so without this check an empty or
    # truncated iteration would let the loop below pass vacuously.
    assert len(fields) == len(EXPECTED_TYPES)
    for expected_name, field in zip(EXPECTED_TYPES, fields):
        assert field.name == expected_name
        assert field.dtype == EXPECTED_TYPES[expected_name]
|
||
|
||
def test_schema_eq():
    """Schemas of identically built tables are equal; an empty table's schema differs."""
    columns = {k: data for k, (data, _) in DATA.items()}
    first = Table.from_pydict(columns)
    second = Table.from_pydict({k: data for k, (data, _) in DATA.items()})
    s1 = first.schema()
    s2 = second.schema()
    assert s1 == s2

    empty_schema = Table.empty().schema()
    assert s1 != empty_schema
|
||
|
||
def test_schema_to_name_set():
    """to_name_set() returns exactly DATA's keys as a set."""
    name_set = TABLE.schema().to_name_set()
    assert name_set == set(DATA.keys())
|
||
|
||
def test_repr():
    """repr(schema) is the repr of its (name, dtype) pairs in column order."""
    expected = (
        "[('int', DataType(Int64)), ('float', DataType(Float64)), ('string', DataType(Utf8)), ('bool', DataType(Boolean))]"
    )
    assert repr(TABLE.schema()) == expected
|
||
|
||
def test_to_col_expr():
    """to_column_expressions() yields one col() per column, with matching names."""
    schema = TABLE.schema()
    actual = schema.to_column_expressions()
    expected = [col(name) for name in schema.column_names()]

    assert len(actual) == len(expected)
    for got, want in zip(actual, expected):
        assert got.name() == want.name()
|
||
|
||
def test_union_err():
    """Union rejects overlapping names and concatenates disjoint schemas in order."""
    schema = TABLE.schema()
    # A schema always overlaps with itself.
    with pytest.raises(ValueError):
        schema.union(schema)

    # Suffix each column name so the second schema shares no names with the first.
    new_data = {f"{k}_": d for k, (d, _) in DATA.items()}
    new_schema_source = Table.from_pydict(new_data)
    unioned_schema = schema.union(new_schema_source.schema())

    expected_names = list(DATA.keys()) + list(new_data.keys())
    assert unioned_schema.column_names() == expected_names
    assert list(unioned_schema) == list(schema) + list(new_schema_source.schema())
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a runtime check that `key` is a `str` to avoid a panic in pyo3.