Skip to content

Commit

Permalink
[rust] Rust Data structures pickling (#716)
Browse files Browse the repository at this point in the history
* Adds pickling methods for DataType, Field, Expr, Series and Table
* adds tests to ensure correctness
  • Loading branch information
samster25 committed Mar 17, 2023
1 parent cf99ea1 commit a31099d
Show file tree
Hide file tree
Showing 13 changed files with 231 additions and 34 deletions.
62 changes: 32 additions & 30 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ version = "0.14.2"
version = "1.3.3"

[dependencies.indexmap]
features = ["serde"]
version = "1.9.2"

[dependencies.num-traits]
Expand Down
3 changes: 3 additions & 0 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,9 @@ def str(self) -> SeriesStringNamespace:
def dt(self) -> SeriesDateNamespace:
return SeriesDateNamespace.from_series(self)

def __reduce__(self) -> tuple:
return (Series.from_arrow, (self.to_arrow(), self.name()))


SomeSeriesNamespace = TypeVar("SomeSeriesNamespace", bound="SeriesNamespace")

Expand Down
4 changes: 4 additions & 0 deletions daft/table/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,3 +248,7 @@ def argsort(self, sort_keys: ExpressionsProjection, descending: bool | list[bool
else:
raise TypeError(f"Expected a bool, list[bool] or None for `descending` but got {type(descending)}")
return Series._from_pyseries(self._table.argsort(pyexprs, descending))

def __reduce__(self) -> tuple:
names = self.column_names()
return Table.from_pydict, ({name: self.get_column(name) for name in names},)
32 changes: 31 additions & 1 deletion src/python/datatype.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
use crate::datatypes::DataType;
use pyo3::prelude::*;
use pyo3::{
exceptions::PyValueError,
prelude::*,
types::{PyBytes, PyTuple},
};

#[pyclass]
#[derive(Clone)]
Expand All @@ -9,6 +13,18 @@ pub struct PyDataType {

#[pymethods]
impl PyDataType {
#[new]
#[args(args = "*")]
fn new(args: &PyTuple) -> PyResult<Self> {
match args.len() {
0 => Ok(DataType::new_null().into()),
_ => Err(PyValueError::new_err(format!(
"expected no arguments to make new PyDataType, got : {}",
args.len()
))),
}
}

pub fn __repr__(&self) -> PyResult<String> {
Ok(format!("{}", self.dtype))
}
Expand Down Expand Up @@ -96,6 +112,20 @@ impl PyDataType {
Ok(false)
}
}

pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
match state.extract::<&PyBytes>(py) {
Ok(s) => {
self.dtype = bincode::deserialize(s.as_bytes()).unwrap();
Ok(())
}
Err(e) => Err(e),
}
}

pub fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
Ok(PyBytes::new(py, &bincode::serialize(&self.dtype).unwrap()).to_object(py))
}
}

impl From<DataType> for PyDataType {
Expand Down
34 changes: 32 additions & 2 deletions src/python/field.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
use pyo3::prelude::*;
use pyo3::{
exceptions::PyValueError,
prelude::*,
types::{PyBytes, PyTuple},
};

use super::datatype::PyDataType;
use crate::datatypes;
use crate::datatypes::{self, DataType, Field};

#[pyclass]
pub struct PyField {
Expand All @@ -10,6 +14,18 @@ pub struct PyField {

#[pymethods]
impl PyField {
#[new]
#[args(args = "*")]
fn new(args: &PyTuple) -> PyResult<Self> {
match args.len() {
0 => Ok(Field::new("null", DataType::new_null()).into()),
_ => Err(PyValueError::new_err(format!(
"expected no arguments to make new PyDataType, got : {}",
args.len()
))),
}
}

pub fn name(&self) -> PyResult<String> {
Ok(self.field.name.clone())
}
Expand All @@ -21,6 +37,20 @@ impl PyField {
pub fn eq(&self, other: &PyField) -> PyResult<bool> {
Ok(self.field.eq(&other.field))
}

pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
match state.extract::<&PyBytes>(py) {
Ok(s) => {
self.field = bincode::deserialize(s.as_bytes()).unwrap();
Ok(())
}
Err(e) => Err(e),
}
}

pub fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
Ok(PyBytes::new(py, &bincode::serialize(&self.field).unwrap()).to_object(py))
}
}

impl From<datatypes::Field> for PyField {
Expand Down
32 changes: 32 additions & 0 deletions src/python/schema.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
use std::sync::Arc;

use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyBytes;
use pyo3::types::PyTuple;

use super::datatype::PyDataType;
use super::field::PyField;
use crate::datatypes;
use crate::schema;
use crate::schema::Schema;

#[pyclass]
pub struct PySchema {
Expand All @@ -14,6 +18,20 @@ pub struct PySchema {

#[pymethods]
impl PySchema {
#[new]
#[args(args = "*")]
fn new(args: &PyTuple) -> PyResult<Self> {
match args.len() {
0 => Ok(Self {
schema: Schema::empty().into(),
}),
_ => Err(PyValueError::new_err(format!(
"expected no arguments to make new PyDataType, got : {}",
args.len()
))),
}
}

pub fn __getitem__(&self, name: &str) -> PyResult<PyField> {
Ok(self.schema.get_field(name)?.clone().into())
}
Expand Down Expand Up @@ -44,6 +62,20 @@ impl PySchema {
schema: schema.into(),
})
}

pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
match state.extract::<&PyBytes>(py) {
Ok(s) => {
self.schema = bincode::deserialize(s.as_bytes()).unwrap();
Ok(())
}
Err(e) => Err(e),
}
}

pub fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
Ok(PyBytes::new(py, &bincode::serialize(&self.schema).unwrap()).to_object(py))
}
}

impl From<schema::SchemaRef> for PySchema {
Expand Down
5 changes: 4 additions & 1 deletion src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::{
};

use indexmap::IndexMap;
use serde::{Deserialize, Serialize};

use crate::{
datatypes::Field,
Expand All @@ -13,8 +14,10 @@ use crate::{

pub type SchemaRef = Arc<Schema>;

#[derive(Debug, PartialEq)]
#[derive(Debug, PartialEq, Serialize, Deserialize)]
#[serde(transparent)]
pub struct Schema {
#[serde(with = "indexmap::serde_seq")]
pub fields: indexmap::IndexMap<String, Field>,
}

Expand Down
Loading

0 comments on commit a31099d

Please sign in to comment.