-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
storage infrastructure,pyarrow to rust
- Loading branch information
Showing
11 changed files
with
1,744 additions
and
40 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
# PyArrow to rust | ||
|
||
Related work: https://medium.com/@niklas.molin/0-copy-you-pyarrow-array-to-rust-23b138cb5bf2 | ||
Source code: https://github.com/NiklasMolin/python-rust-arrow | ||
|
||
I have now implemented data exchange using Polars and PyO3, based on the binding implementation shown below. | ||
|
||
|
||
bind_code: | ||
```rust | ||
use log::debug; | ||
use polars_core::export::rayon::prelude::*; | ||
use polars_core::prelude::*; | ||
use polars_core::utils::accumulate_dataframes_vertical_unchecked; | ||
use polars_core::utils::arrow::ffi; | ||
use polars_core::POOL; | ||
use pyo3::exceptions::PyRuntimeError; | ||
use pyo3::ffi::Py_uintptr_t; | ||
use pyo3::prelude::*; | ||
|
||
pub fn array_to_rust(obj: &PyAny) -> PyResult<ArrayRef> { | ||
// prepare a pointer to receive the Array struct | ||
let array = Box::new(ffi::ArrowArray::empty()); | ||
let schema = Box::new(ffi::ArrowSchema::empty()); | ||
|
||
let array_ptr = &*array as *const ffi::ArrowArray; | ||
let schema_ptr = &*schema as *const ffi::ArrowSchema; | ||
|
||
// make the conversion through PyArrow's private API | ||
// this changes the pointer's memory and is thus unsafe. In particular, `_export_to_c` can go out of bounds | ||
obj.call_method1( | ||
"_export_to_c", | ||
(array_ptr as Py_uintptr_t, schema_ptr as Py_uintptr_t), | ||
)?; | ||
|
||
unsafe { | ||
let field = ffi::import_field_from_c(schema.as_ref()) | ||
.map_err(|err| PyRuntimeError::new_err(format!("{:?}", &err)))?; | ||
let array = ffi::import_array_from_c(*array, field.data_type) | ||
.map_err(|err| PyRuntimeError::new_err(format!("{:?}", &err)))?; | ||
Ok(array) | ||
} | ||
} | ||
|
||
/// | ||
/// copy from `py-polars/src/arrow_interop/to_rust.rs` | ||
/// UseAge: | ||
/// ```rust | ||
/// #[cfg(not(doctest))] | ||
/// #[pyfunction] | ||
/// fn process_pyarrow_table(arrow_table: &PyAny) -> PyResult<()> { | ||
/// let pl_df: DataFrame = pyarrow_to_polars_df(&arrow_table)?; | ||
/// println!("Arrow2 DataFrame: {}", pl_df); | ||
/// Ok(()) | ||
/// } | ||
/// ``` | ||
pub fn pyarrow_to_polars_df(arrow_table: &PyAny) -> PyResult<DataFrame> { | ||
let schema = arrow_table | ||
.getattr("schema") | ||
.expect("pyarrow.Table has no schema"); | ||
debug!("arrow_table schema:{:?}", schema); | ||
let columns = arrow_table | ||
.getattr("columns") | ||
.expect("pyarrow.Table has no columns"); | ||
debug!("arrow_table columns:{:?}", columns); | ||
|
||
let mut rb: Vec<&PyAny> = vec![]; | ||
for item in arrow_table | ||
.call_method0("to_batches") | ||
.expect("pyarrow.Table has no method to_batches") | ||
.iter()? | ||
{ | ||
rb.push(item?); | ||
} | ||
|
||
let names = schema.getattr("names")?.extract::<Vec<String>>()?; | ||
|
||
let dfs = rb | ||
.iter() | ||
.map(|rb| { | ||
let mut run_parallel = false; | ||
|
||
let columns = (0..names.len()) | ||
.map(|i| { | ||
let array = rb.call_method1("column", (i,))?; | ||
let arr = array_to_rust(array)?; | ||
run_parallel |= matches!( | ||
arr.data_type(), | ||
ArrowDataType::Utf8 | ArrowDataType::Dictionary(_, _, _) | ||
); | ||
Ok(arr) | ||
}) | ||
.collect::<PyResult<Vec<_>>>()?; | ||
|
||
// we parallelize this part because we can have dtypes that are not zero copy | ||
// for instance utf8 -> large-utf8 | ||
// dict encoded to categorical | ||
let columns = if run_parallel { | ||
POOL.install(|| { | ||
columns | ||
.into_par_iter() | ||
.enumerate() | ||
.map(|(i, arr)| { | ||
let s = Series::try_from((names[i].as_str(), arr)) | ||
.map_err(|err| PyRuntimeError::new_err(format!("{:?}", &err)))?; | ||
Ok(s) | ||
}) | ||
.collect::<PyResult<Vec<_>>>() | ||
}) | ||
} else { | ||
columns | ||
.into_iter() | ||
.enumerate() | ||
.map(|(i, arr)| { | ||
let s = Series::try_from((names[i].as_str(), arr)) | ||
.map_err(|err| PyRuntimeError::new_err(format!("{:?}", &err)))?; | ||
Ok(s) | ||
}) | ||
.collect::<PyResult<Vec<_>>>() | ||
}?; | ||
|
||
// no need to check as a record batch has the same guarantees | ||
Ok(DataFrame::new_no_checks(columns)) | ||
}) | ||
.collect::<PyResult<Vec<_>>>()?; | ||
|
||
Ok(accumulate_dataframes_vertical_unchecked(dfs)) | ||
} | ||
|
||
``` | ||
|
||
|
||
python_code: | ||
```python | ||
def test_df(self):
    """Build a small polars DataFrame, convert it to a pyarrow Table and pass it to Rust.

    Prints the input data, the DataFrame, several views of the Arrow table, and finally
    the value returned by ``process_pyarrow_table``.
    """
    import polars as pl
    import pyarrow as pa

    data = {
        'Name': ['Alice', 'Bob', 'Charlie'],
        'Age': [25, 30, 22],
        'City': ['New York', 'San Francisco', 'Seattle']
    }
    print(data)
    df = pl.DataFrame(data)
    print(df)
    arrow_table = df.to_arrow()
    record_batches = arrow_table.to_batches()

    # Each view is printed as: blank line, heading, blank line, value.
    for heading, value in (
        ("arrow_table", arrow_table),
        ("schema", arrow_table.schema),
        ("columns", arrow_table.columns),
        ("record_batches", record_batches),
    ):
        print("")
        print(heading)
        print("")
        print(value)

    print("")
    print("#" * 100)
    print("call in rust")
    print("")

    result = process_pyarrow_table(arrow_table)

    print("")
    print("#" * 100)
    print("")
    print(result)
|
||
``` | ||
|
||
![result](https://user-images.githubusercontent.com/34028978/258007062-d14bc0ad-6de7-4439-951c-b60007e79421.png) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,13 @@ | ||
from typing import Any, Union, List, Iterator, Tuple, Dict | ||
import pyarrow as pa | ||
|
||
|
||
def sum_as_string(a: int, b: int) -> str: ... | ||
|
||
|
||
__all__ = ["sum_as_string"] | ||
# The Rust binding takes its argument as `arrow_table` and returns `()` (None in Python).
def process_pyarrow_table(arrow_table: pa.Table) -> None: ...
|
||
|
||
__all__ = ["sum_as_string", | ||
"process_pyarrow_table" | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
use clap::Parser; | ||
use env_logger::Env; | ||
use log::Level; | ||
use std::path::PathBuf; | ||
use std::str::FromStr; | ||
|
||
/// Information Multi-Objective and Multi-Directional RRT* System for Path Planning.\n | ||
/// This work reimplemented an anytime iterative system to concurrently solve the multi-objective path planning problem \n | ||
/// and determine the visiting order of destinations using rust-lang | ||
#[derive(Parser, Debug)] | ||
#[command(author, bin_name = "IMOMD_RRTStar", version, about)] | ||
pub struct Command { | ||
/// Path of the config file. eg:config.yaml | ||
#[arg(short, long, value_name = "FILE")] | ||
pub config: Option<PathBuf>, | ||
/// log level | ||
#[arg(short, long,action = clap::ArgAction::Count)] | ||
pub verbose: u8, | ||
} | ||
|
||
impl Command { | ||
pub fn init_log(verbose: u8) { | ||
// The logging level is set through the environment variable RUST_LOG, | ||
// which defaults to the info level | ||
env_logger::Builder::from_env(Env::default().default_filter_or({ | ||
let level_env = match std::env::var("RUST_LOG") { | ||
Ok(val) => val, | ||
Err(_) => "info".to_string(), | ||
}; | ||
println!("RUST_ENV_LOG: {}", level_env); | ||
let level_verbose = match verbose { | ||
0 => "ERROR", | ||
1 => "INFO", | ||
2 => "Debug", | ||
_ => "Trace", | ||
}; | ||
println!("RUST_VERBOSE_LOG: {}", level_verbose); | ||
// Converts a string to the corresponding log-level enumeration | ||
let level = if let (Ok(level1), Ok(level2)) = | ||
(Level::from_str(&level_env), Level::from_str(&level_verbose)) | ||
{ | ||
// Use the cmp method to compare the sizes of the two log levels | ||
match level1.cmp(&level2) { | ||
std::cmp::Ordering::Less => level2.to_string(), | ||
std::cmp::Ordering::Equal => level2.to_string(), | ||
std::cmp::Ordering::Greater => level1.to_string(), | ||
} | ||
} else { | ||
level_env | ||
}; | ||
println!("RUST_LOG: {}", level); | ||
level | ||
})) | ||
.init(); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
use serde::{Deserialize, Serialize}; | ||
|
||
/// Configuration record that can be serialized and deserialized with serde.
#[derive(Debug, Deserialize, Serialize)] | ||
pub struct Config { | ||
/// Name stored in the configuration; its meaning is not established in this file.
pub name: String, | ||
} | ||
|
||
impl Config { | ||
|
||
} | ||
|
||
|
||
|
||
// // parse config from yaml | ||
// fn parse_config() -> Config { | ||
// None | ||
// } | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,31 @@ | ||
use pyo3::prelude::*; | ||
mod command; | ||
pub mod config; | ||
pub mod prelude; | ||
mod storage; | ||
|
||
/// Native Rust function: adds `a` and `b` and returns the total as a decimal string.
///
/// The operands are echoed to stdout first. The addition is carried out in `u128`, so the
/// result is exact even when `a + b` does not fit in `usize`.
pub fn sum(a: usize, b: usize) -> String {
    println!("call in rust {:?} {:?}", a, b);
    // Widening `usize` to `u128` is lossless and rules out overflow of the sum.
    (a as u128 + b as u128).to_string()
}
use log::debug; | ||
use polars_core::prelude::DataFrame; | ||
use prelude::*; | ||
use pyo3::prelude::*; | ||
|
||
/// Formats the sum of two numbers as string.
///
/// Marked `#[pyfunction]`; delegates to `sum(a, b)` and wraps the returned string in `Ok`.
#[pyfunction] | ||
fn sum_as_string(a: usize, b: usize) -> PyResult<String> { | ||
Ok(sum(a, b)) | ||
} | ||
|
||
// 定义一个Python绑定函数,该函数将接受PyArrow表作为参数,并将其传递给Rust的process_arrow_table函数 | ||
#[pyfunction] | ||
fn process_pyarrow_table(arrow_table: &PyAny) -> PyResult<()> { | ||
let pl_df: DataFrame = storage::pyarrow_to_polars_df(&arrow_table)?; | ||
println!("Arrow2 DataFrame: {}", pl_df); | ||
Ok(()) | ||
} | ||
|
||
/// A Python module implemented in Rust.
///
/// Registers the `sum_as_string` and `process_pyarrow_table` functions on the module `m`.
#[pymodule] | ||
fn IMOMD_RRTStar(_py: Python, m: &PyModule) -> PyResult<()> { | ||
m.add_function(wrap_pyfunction!(sum_as_string, m)?)?; | ||
m.add_function(wrap_pyfunction!(process_pyarrow_table, m)?)?; | ||
Ok(()) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
use clap::Parser; | ||
use log::info; | ||
use std::env; | ||
use IMOMD_RRTStar::prelude::Command; | ||
|
||
fn main() { | ||
let args = Command::parse(); | ||
Command::init_log(args.verbose); | ||
|
||
if let Some(config_path) = args.config.as_deref() { | ||
println!("Value for config: {}", config_path.display()); | ||
} | ||
|
||
let cargo_path = env!("CARGO_MANIFEST_DIR"); | ||
let exe_path = env::current_exe().unwrap(); | ||
let exe_dir = exe_path.parent().unwrap(); | ||
let work_dir = env::current_dir().unwrap(); | ||
info!("cargo_path: {:?}", cargo_path); | ||
info!("exe_path: {:?}", exe_path); | ||
info!("exe_dir: {:?}", exe_dir); | ||
info!("work_dir: {:?}", work_dir); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
pub use crate::command::Command; | ||
use crate::config::Config; | ||
/// Native Rust function: adds `a` and `b` and returns the total as a decimal string.
///
/// The operands are echoed to stdout first. The addition is carried out in `u128`, so the
/// result is exact even when `a + b` does not fit in `usize`.
pub fn sum(a: usize, b: usize) -> String {
    println!("call in rust {:?} {:?}", a, b);
    // Widening `usize` to `u128` is lossless and rules out overflow of the sum.
    (a as u128 + b as u128).to_string()
}
Oops, something went wrong.