Add Rust bindings for Schema #650

Merged (6 commits) on Mar 3, 2023
Changes from 2 commits
87 changes: 87 additions & 0 deletions daft/logical/schema2.py
@@ -0,0 +1,87 @@
from __future__ import annotations

from typing import Iterator, TypeVar

from daft.daft import PyField as _PyField
from daft.daft import PySchema as _PySchema
from daft.datatype import DataType
from daft.expressions2 import Expression, col

ExpressionType = TypeVar("ExpressionType", bound=Expression)


class Field:
    _field: _PyField

    def __init__(self) -> None:
        raise NotImplementedError("We do not support creating a Field via __init__ ")

    @staticmethod
    def _from_pyfield(field: _PyField) -> Field:
        f = Field.__new__(Field)
        f._field = field
        return f

    @property
    def name(self):
        return self._field.name()

    @property
    def dtype(self) -> DataType:
        return DataType._from_pydatatype(self._field.dtype())

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Field):
            return False
        return self._field.eq(other._field)


class Schema:
    _schema: _PySchema

    def __init__(self) -> None:
        raise NotImplementedError("We do not support creating a Schema via __init__ ")

    @staticmethod
    def _from_pyschema(schema: _PySchema) -> Schema:
        s = Schema.__new__(Schema)
        s._schema = schema
        return s

    def __getitem__(self, key: str) -> Field:

Review comment (Member): runtime check that key is a str to avoid panic in pyo3

        if key not in self._schema.names():
            raise ValueError(f"{key} was not found in Schema of fields {self._schema.field_names()}")
        pyfield = self._schema[key]
        return Field._from_pyfield(pyfield)
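
For illustration, a minimal sketch of the runtime guard the reviewer is suggesting above (not part of this diff). The isinstance check and the TypeError are assumptions about how the suggestion could be applied, and the error message uses names(), the accessor the Rust binding exposes:

    def __getitem__(self, key: str) -> Field:
        # Hypothetical guard: reject non-str keys before they reach pyo3,
        # so the Rust side never panics on an unexpected key type.
        if not isinstance(key, str):
            raise TypeError(f"Expected str for key, got {type(key)}")
        if key not in self._schema.names():
            raise ValueError(f"{key} was not found in Schema of fields {self._schema.names()}")
        pyfield = self._schema[key]
        return Field._from_pyfield(pyfield)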

    def __len__(self) -> int:
        return len(self._schema.names())

    def column_names(self) -> list[str]:
        return list(self._schema.names())

    def __iter__(self) -> Iterator[Field]:
        col_names = self.column_names()
        return (self[name] for name in col_names)

Review comment (Member): return iter of the tuple

Reply (Contributor, Author): This actually returns a generator rather than a tuple:

>>> x = (i for i in range(10))
>>> type(x)
<class 'generator'>

I can do yield from instead if it's clearer?

Reply (Member): nah, it's fine then
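For reference, a minimal sketch of the yield from variant mentioned above (not part of this diff); it is behaviorally equivalent to the generator expression that was kept:

    def __iter__(self) -> Iterator[Field]:
        # A generator method instead of returning a generator expression.
        yield from (self[name] for name in self.column_names())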

    def __eq__(self, other: object) -> bool:
        return isinstance(other, Schema) and self._schema.eq(other._schema)

    def to_name_set(self) -> set[str]:
        return set(self.column_names())

    def __repr__(self) -> str:
        return repr([(field.name, field.dtype) for field in self])

    def to_column_expressions(self) -> list[Expression]:
        return [col(f.name) for f in self]

    def union(self, other: Schema) -> Schema:
        if not isinstance(other, Schema):
            raise ValueError(f"Expected Schema, got other: {type(other)}")

        intersecting_names = self.to_name_set().intersection(other.to_name_set())
        if intersecting_names:
            raise ValueError(f"Cannot union schemas with overlapping names: {intersecting_names}")

        return Schema._from_pyschema(self._schema.union(other._schema))
4 changes: 4 additions & 0 deletions daft/table.py
@@ -4,6 +4,7 @@

from daft.daft import PyTable as _PyTable
from daft.expressions2 import Expression
from daft.logical.schema2 import Schema
from daft.series import Series


@@ -35,6 +36,9 @@ def from_pydict(data: dict) -> Table:
        pya_table = pa.Table.from_pydict(data)
        return Table.from_arrow(pya_table)

    def schema(self) -> Schema:
        return Schema._from_pyschema(self._table.schema())

    def to_arrow(self) -> pa.Table:
        return pa.Table.from_batches([self._table.to_arrow_record_batch()])

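To show how the new binding is exercised end to end from Python, a minimal usage sketch (not part of this diff); the column name and values are placeholders, and the printed values are assumptions based on the repr behavior exercised in the tests below:

from daft.table import Table

table = Table.from_pydict({"a": [1, 2, None]})
schema = table.schema()          # a Schema backed by the Rust PySchema
print(schema.column_names())     # ['a']
print(schema["a"].dtype)         # DataType(Int64)
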
3 changes: 3 additions & 0 deletions src/python/mod.rs
@@ -2,6 +2,7 @@ use pyo3::prelude::*;
mod datatype;
mod error;
mod expr;
mod schema;
mod series;
mod table;

@@ -10,6 +11,8 @@ pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> {
    parent.add_class::<table::PyTable>()?;
    parent.add_class::<series::PySeries>()?;
    parent.add_class::<datatype::PyDataType>()?;
    parent.add_class::<schema::PySchema>()?;
    parent.add_class::<schema::PyField>()?;

    parent.add_wrapped(wrap_pyfunction!(expr::col))?;
    parent.add_wrapped(wrap_pyfunction!(expr::lit))?;
66 changes: 66 additions & 0 deletions src/python/schema.rs
@@ -0,0 +1,66 @@
use pyo3::prelude::*;

use crate::datatypes;
use crate::python::datatype;
use crate::schema;

#[pyclass]
pub struct PySchema {
    pub schema: schema::SchemaRef,
}

#[pyclass]
pub struct PyField {
    pub field: datatypes::Field,
}

#[pymethods]
impl PySchema {
    pub fn __getitem__(&self, name: &str) -> PyResult<PyField> {
        Ok(self.schema.get_field(name)?.clone().into())
    }

    pub fn names(&self) -> PyResult<Vec<String>> {
        Ok(self.schema.names()?)
    }

    pub fn union(&self, other: &PySchema) -> PyResult<PySchema> {

Review comment (Member): this union method should be implemented in Schema rather than having the logic in the pybindings

        let fields: Vec<datatypes::Field> = self
            .schema
            .fields
            .values()
            .cloned()
            .chain(other.schema.fields.values().cloned())
            .collect();
        let new_schema = schema::Schema::new(fields);
        let new_pyschema = PySchema {
            schema: new_schema.into(),
        };
        Ok(new_pyschema)
    }

    pub fn eq(&self, other: &PySchema) -> PyResult<bool> {
        Ok(self.schema.fields.eq(&other.schema.fields))
    }
}
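
Responding to the review comment above, a minimal sketch of how the merge logic could move into Schema itself so that the pyo3 binding only delegates (not part of this diff); the Schema::union signature and its placement in src/schema.rs are assumptions:

// src/schema.rs (hypothetical addition): keep the field-merging logic with Schema.
impl Schema {
    pub fn union(&self, other: &Schema) -> Schema {
        // Chain the fields of both schemas; overlapping-name handling stays in
        // the Python layer in this PR, so none is added here.
        let fields: Vec<Field> = self
            .fields
            .values()
            .cloned()
            .chain(other.fields.values().cloned())
            .collect();
        Schema::new(fields)
    }
}

// src/python/schema.rs: the binding then reduces to a delegation.
pub fn union(&self, other: &PySchema) -> PyResult<PySchema> {
    Ok(PySchema {
        schema: self.schema.union(&other.schema).into(),
    })
}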

#[pymethods]
impl PyField {
    pub fn name(&self) -> PyResult<String> {
        Ok(self.field.name.clone())
    }

    pub fn dtype(&self) -> PyResult<datatype::PyDataType> {
        Ok(self.field.dtype.clone().into())
    }

    pub fn eq(&self, other: &PyField) -> PyResult<bool> {
        Ok(self.field.eq(&other.field))
    }
}

impl From<datatypes::Field> for PyField {
    fn from(item: datatypes::Field) -> Self {
        PyField { field: item }
    }
}
7 changes: 7 additions & 0 deletions src/python/table.rs
@@ -7,6 +7,7 @@ use crate::table;

use crate::python::expr::PyExpr;

use super::schema::PySchema;
use super::series::PySeries;

#[pyclass]
@@ -16,6 +17,12 @@ pub struct PyTable {

#[pymethods]
impl PyTable {
    pub fn schema(&self) -> PyResult<PySchema> {
        Ok(PySchema {
            schema: self.table.schema.clone(),
        })
    }

    pub fn eval_expression_list(&self, exprs: Vec<PyExpr>) -> PyResult<Self> {
        let converted_exprs: Vec<dsl::Expr> = exprs.into_iter().map(|e| e.into()).collect();
        Ok(self
2 changes: 1 addition & 1 deletion src/schema.rs
@@ -10,7 +10,7 @@ use crate::{
    error::{DaftError, DaftResult},
};

type SchemaRef = Arc<Schema>;
pub type SchemaRef = Arc<Schema>;

pub struct Schema {
    pub fields: indexmap::IndexMap<String, Field>,
5 changes: 2 additions & 3 deletions src/table/mod.rs
@@ -1,15 +1,14 @@
use std::fmt::{Display, Formatter, Result};
use std::sync::Arc;

use crate::datatypes::{BooleanType, DataType, Field};
use crate::dsl::Expr;
use crate::error::{DaftError, DaftResult};
use crate::schema::Schema;
use crate::schema::{Schema, SchemaRef};
use crate::series::Series;

#[derive(Clone)]
pub struct Table {
    schema: Arc<Schema>,
    pub schema: SchemaRef,
    columns: Vec<Series>,
}

88 changes: 88 additions & 0 deletions tests/logical/test_schema2.py
@@ -0,0 +1,88 @@
from __future__ import annotations

import pytest

from daft.datatype import DataType
from daft.expressions2 import col
from daft.table import Table

DATA = {
"int": ([1, 2, None], DataType.int64()),
"float": ([1.0, 2.0, None], DataType.float64()),
"string": (["a", "b", None], DataType.string()),
"bool": ([True, True, None], DataType.bool()),
}

TABLE = Table.from_pydict({k: data for k, (data, _) in DATA.items()})
EXPECTED_TYPES = {k: t for k, (_, t) in DATA.items()}


def test_schema_len():
    schema = TABLE.schema()
    assert len(schema) == len(DATA)


def test_schema_column_names():
    schema = TABLE.schema()
    assert schema.column_names() == list(DATA.keys())


def test_schema_field_types():
    schema = TABLE.schema()
    for key in EXPECTED_TYPES:
        assert schema[key].name == key
        assert schema[key].dtype == EXPECTED_TYPES[key]


def test_schema_iter():
    schema = TABLE.schema()
    for expected_name, field in zip(EXPECTED_TYPES, schema):
        assert field.name == expected_name
        assert field.dtype == EXPECTED_TYPES[expected_name]


def test_schema_eq():
    t1, t2 = Table.from_pydict({k: data for k, (data, _) in DATA.items()}), Table.from_pydict(
        {k: data for k, (data, _) in DATA.items()}
    )
    s1, s2 = t1.schema(), t2.schema()
    assert s1 == s2

    t_empty = Table.empty()
    assert s1 != t_empty.schema()


def test_schema_to_name_set():
    schema = TABLE.schema()
    assert schema.to_name_set() == set(DATA.keys())


def test_repr():
    schema = TABLE.schema()
    assert (
        repr(schema)
        == "[('int', DataType(Int64)), ('float', DataType(Float64)), ('string', DataType(Utf8)), ('bool', DataType(Boolean))]"
    )


def test_to_col_expr():
    schema = TABLE.schema()
    schema_col_exprs = schema.to_column_expressions()
    expected_col_exprs = [col(n) for n in schema.column_names()]

    assert len(schema_col_exprs) == len(expected_col_exprs)
    for sce, ece in zip(schema_col_exprs, expected_col_exprs):
        assert sce.name() == ece.name()


def test_union_err():
    schema = TABLE.schema()
    with pytest.raises(ValueError):
        schema.union(schema)

    new_data = {f"{k}_": d for k, (d, _) in DATA.items()}
    new_table = Table.from_pydict(new_data)
    unioned_schema = schema.union(new_table.schema())

    assert unioned_schema.column_names() == list(DATA.keys()) + list(new_data.keys())
    assert list(unioned_schema) == list(schema) + list(new_table.schema())