diff --git a/README.md b/README.md
index 26e4569..f1840bd 100644
--- a/README.md
+++ b/README.md
@@ -2,37 +2,81 @@
 
 Provides an API on top of [`arrow2`](https://github.com/jorgecarleitao/arrow2) to convert between rust types and Arrow.
 
-The Arrow ecosystem provides many ways to convert between Arrow and other popular formats across several languages. This project aims to serve the need for rust-centric data pipelines to easily convert to/from Arrow via an auto-generated compile-time schema.
+The Arrow ecosystem provides many ways to convert between Arrow and other popular formats across several languages. This project aims to serve the need for rust-centric data pipelines to easily convert to/from Arrow with strong typing and arbitrary nesting.
+
+## Example
+
+The example below performs a round trip conversion of a struct with a single field.
+
+Please see the [complex_example.rs](./arrow2_convert/tests/complex_example.rs) for usage of the full functionality.
+
+```rust
+/// Simple example
+
+use arrow2::array::Array;
+use arrow2_convert::{deserialize::TryIntoCollection, serialize::TryIntoArrow, ArrowField};
+
+#[derive(Debug, Clone, PartialEq, ArrowField)]
+pub struct Foo {
+    name: String,
+}
+
+#[test]
+fn test_simple_roundtrip() {
+    // an item
+    let original_array = [
+        Foo { name: "hello".to_string() },
+        Foo { name: "one more".to_string() },
+        Foo { name: "good bye".to_string() },
+    ];
+
+    // serialize to an arrow array. try_into_arrow() is enabled by the TryIntoArrow trait
+    let arrow_array: Box<dyn Array> = original_array.try_into_arrow().unwrap();
+
+    // which can be cast to an Arrow StructArray and be used for all kinds of IPC, FFI, etc.
+    // supported by `arrow2`
+    let struct_array= arrow_array.as_any().downcast_ref::<arrow2::array::StructArray>().unwrap();
+    assert_eq!(struct_array.len(), 3);
+
+    // deserialize back to our original vector via TryIntoCollection trait.
+    let round_trip_array: Vec<Foo> = arrow_array.try_into_collection().unwrap();
+    assert_eq!(round_trip_array, original_array);
+}
+```
 
 ## API
 
-Types that implement the `ArrowField`, `ArrowSerialize` and `ArrowDeserialize` traits can be converted to/from Arrow. The `ArrowField` implementation for a type defines the Arrow schema. The `ArrowSerialize` and `ArrowDeserialize` implementations provide the conversion logic via arrow2's data structures.
+Types that implement the `ArrowField`, `ArrowSerialize` and `ArrowDeserialize` traits can be converted to/from Arrow via the `try_into_arrow` and the `try_into_collection` methods.
+
+The `ArrowField` derive macro can be used to generate implementations of these traits for structs and enums. Custom implementations can also be defined for any type that needs to convert to/from Arrow by manually implementing the traits.
 
 For serializing to arrow, `TryIntoArrow::try_into_arrow` can be used to serialize any iterable into an `arrow2::Array` or an `arrow2::Chunk`. `arrow2::Array` represents the in-memory Arrow layout. `arrow2::Chunk` represents a column group and can be used with the `arrow2` API for other functionality such as converting to parquet and arrow flight RPC.
 
 For deserializing from arrow, the `TryIntoCollection::try_into_collection` can be used to deserialize from an `arrow2::Array` representation into any container that implements `FromIterator`.
 
-## Features
-
-- A derive macro, `ArrowField`, can generate implementations of the above traits for structures. Support for enums is in progress.
-- Implementations are provided for Arrow primitives
-  - Numeric types
-    - [`u8`], [`u16`], [`u32`], [`u64`], [`i8`], [`i16`], [`i32`], [`i64`], [`f32`], [`f64`]
-    - [`i128`] is supported via the `override` attribute. Please see the [i128 section](#i128) for more details.
-  - Other types:
-    - [`bool`], [`String`], [`Binary`]
-  - Temporal types:
-    - [`chrono::NaiveDate`], [`chrono::NaiveDateTime`]
-- Blanket implementations are provided for types that implement the above traits:
-  - `Option<T>`
-  - `Vec<T>`
-- Large Arrow types [`LargeBinary`], [`LargeString`], [`LargeList`] are supported via the `override` attribute. Please see the [complex_example.rs](./arrow2_convert/tests/complex_example.rs) for usage.
+### Default implementations
+
+Default implementations of the above traits are provided for the following:
+
+- Numeric types
+  - [`u8`], [`u16`], [`u32`], [`u64`], [`i8`], [`i16`], [`i32`], [`i64`], [`f32`], [`f64`]
+  - [`i128`] is supported via the `type` attribute. Please see the [i128 section](#i128) for more details.
+- Other types:
+  - [`bool`], [`String`], [`Binary`]
+- Temporal types:
+  - [`chrono::NaiveDate`], [`chrono::NaiveDateTime`]
+- `Option<T>` if `T` implements `ArrowField`
+- `Vec<T>` if `T` implements `ArrowField`
+- Large Arrow types [`LargeBinary`], [`LargeString`], [`LargeList`] are supported via the `type` attribute. Please see the [complex_example.rs](./arrow2_convert/tests/complex_example.rs) for usage.
 - Fixed size types [`FixedSizeBinary`], [`FixedSizeList`] are supported via the `FixedSizeVec` type override.
   - Note: nesting of [`FixedSizeList`] is not supported.
-- Scalars and enums are in progress
-- Support for generics, slices and reference is currently missing.
 
-This is not an exhaustive list. Please open an issue if you need a feature.
+### Enums
+
+Enums are still an experimental feature and need to be integration tested. Rust enum arrays are converted to an `Arrow::UnionArray`. Some additional notes on enums:
+
+- Rust unit variants are represented using the `bool` data type.
+- Enum slices currently [don't deserialize correctly](https://github.com/DataEngineeringLabs/arrow2-convert/issues/53).
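The enum mapping described above can be pictured with a small model that uses only the standard library. This is a sketch under stated assumptions, not the code the `ArrowField` derive macro generates: `Shape`, `UnionColumns` and `serialize` are hypothetical names, and the layout follows the dense/sparse handling in the generated code later in this change (dense mode records one offset per item, sparse mode pads every other child with a null).

```rust
// Std-only model of an Arrow union layout. Hypothetical names throughout;
// this is NOT the code that the `ArrowField` derive macro generates.

#[derive(Debug, Clone, PartialEq)]
enum Shape {
    Empty, // unit variant, modelled as a `bool` child column
    Count(i32),
    Ratio(f64),
}

/// One child column per variant plus the type-id buffer.
#[derive(Debug, Default, PartialEq)]
struct UnionColumns {
    types: Vec<i8>,    // which variant each item holds
    offsets: Vec<i32>, // dense mode only: index into that variant's child
    empty: Vec<Option<bool>>,
    count: Vec<Option<i32>>,
    ratio: Vec<Option<f64>>,
}

fn serialize(items: &[Shape], dense: bool) -> UnionColumns {
    let mut cols = UnionColumns::default();
    for item in items {
        // Push the value into the child column of the matching variant.
        let (type_id, child_len) = match item {
            Shape::Empty => {
                cols.empty.push(Some(true));
                (0i8, cols.empty.len())
            }
            Shape::Count(v) => {
                cols.count.push(Some(*v));
                (1i8, cols.count.len())
            }
            Shape::Ratio(v) => {
                cols.ratio.push(Some(*v));
                (2i8, cols.ratio.len())
            }
        };
        cols.types.push(type_id);
        if dense {
            // Dense: children stay compact, so record where the value landed.
            cols.offsets.push((child_len - 1) as i32);
        } else {
            // Sparse: every child has one slot per item, so pad the others.
            if type_id != 0 {
                cols.empty.push(None);
            }
            if type_id != 1 {
                cols.count.push(None);
            }
            if type_id != 2 {
                cols.ratio.push(None);
            }
        }
    }
    cols
}
```

Calling `serialize(&items, true)` keeps each child column compact and fills `offsets`, while `serialize(&items, false)` leaves `offsets` empty and gives every child column one slot per item.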
 
 ### i128
 
@@ -46,7 +90,7 @@ use arrow2_convert::ArrowField;
 
 #[derive(Debug, ArrowField)]
 struct S {
-    #[arrow_field(override = "I128<32, 32>")]
+    #[arrow_field(type = "I128<32, 32>")]
     field: i128,
 }
 ```
@@ -72,51 +116,16 @@ fn convert_i128() {
 ```
 
 ### Nested Option Types
 
-Since the Arrow format only supports one level of validity, nested option types such as `Option<Option<T>>` after serialization to Arrow will lose intermediate nesting of None values. For example, `Some(None)` will be serialized to `None`,
-
-## Memory
-
-Pass-thru conversions perform a single memory copy. Deserialization performs a copy from arrow2 to the destination. Serialization performs a copy from the source to arrow2. In-place deserialization is theoretically possible but currently not supported.
-
-## Example
-
-Below is a bare-bones example that does a round trip conversion of a struct with a single field.
-
-Please see the [complex_example.rs](./arrow2_convert/tests/complex_example.rs) for usage of the full functionality.
+Since the Arrow format only supports one level of validity, nested option types such as `Option<Option<T>>`, after serialization to Arrow, will lose any intermediate nesting of None values. For example, `Some(None)` will be serialized to `None`.
 
-```rust
-/// Simple example
+### Missing Features
 
-use arrow2::array::Array;
-use arrow2_convert::{deserialize::TryIntoCollection, serialize::TryIntoArrow, ArrowField};
-
-#[derive(Debug, Clone, PartialEq, ArrowField)]
-pub struct Foo {
-    name: String,
-}
-
-#[test]
-fn test_simple_roundtrip() {
-    // an item
-    let original_array = [
-        Foo { name: "hello".to_string() },
-        Foo { name: "one more".to_string() },
-        Foo { name: "good bye".to_string() },
-    ];
-
-    // serialize to an arrow array. try_into_arrow() is enabled by the TryIntoArrow trait
-    let arrow_array: Box<dyn Array> = original_array.try_into_arrow().unwrap();
+- Support for generics, slices and references is currently missing.
 
-    // which can be cast to an Arrow StructArray and be used for all kinds of IPC, FFI, etc.
-    // supported by `arrow2`
-    let struct_array= arrow_array.as_any().downcast_ref::<arrow2::array::StructArray>().unwrap();
-    assert_eq!(struct_array.len(), 3);
+This is not an exhaustive list. Please open an issue if you need a feature.
 
+## Memory
 
-    // deserialize back to our original vector via TryIntoCollection trait.
-    let round_trip_array: Vec<Foo> = arrow_array.try_into_collection().unwrap();
-    assert_eq!(round_trip_array, original_array);
-}
-```
+Pass-thru conversions perform a single memory copy. Deserialization performs a copy from arrow2 to the destination. Serialization performs a copy from the source to arrow2. In-place deserialization is theoretically possible but currently not supported.
 
 ## Internals
 
@@ -128,7 +137,6 @@ However unlike serde's traits provide an exhaustive and flexible mapping to the
 
 Specifically, the `ArrowSerialize` trait provides the logic to serialize a type to the corresponding `arrow2::array::MutableArray`. The `ArrowDeserialize` trait deserializes a type from the corresponding `arrow2::array::ArrowArray`.
 
-
 ### Workarounds
 
 Features such as partial implementation specialization and generic associated types (currently only available in nightly builds) can greatly simplify the underlying implementation.
diff --git a/arrow2_convert/tests/complex_example.rs b/arrow2_convert/tests/complex_example.rs index 844b9ad..77fae51 100644 --- a/arrow2_convert/tests/complex_example.rs +++ b/arrow2_convert/tests/complex_example.rs @@ -1,6 +1,5 @@ use arrow2::array::*; use arrow2_convert::deserialize::{arrow_array_deserialize_iterator, TryIntoCollection}; -use arrow2_convert::field::{FixedSizeBinary, FixedSizeVec, LargeBinary, LargeString, LargeVec}; use arrow2_convert::serialize::TryIntoArrow; /// Complex example that uses the following features: /// @@ -40,19 +39,19 @@ pub struct Root { // int 32 array int32_array: Vec, // large binary - #[arrow_field(override = "LargeBinary")] + #[arrow_field(type = "arrow2_convert::field::LargeBinary")] large_binary: Vec, // fixed size binary - #[arrow_field(override = "FixedSizeBinary<3>")] + #[arrow_field(type = "arrow2_convert::field::FixedSizeBinary<3>")] fixed_size_binary: Vec, // large string - #[arrow_field(override = "LargeString")] + #[arrow_field(type = "arrow2_convert::field::LargeString")] large_string: String, // large vec - #[arrow_field(override = "LargeVec")] + #[arrow_field(type = "arrow2_convert::field::LargeVec")] large_vec: Vec, // fixed size vec - #[arrow_field(override = "FixedSizeVec")] + #[arrow_field(type = "arrow2_convert::field::FixedSizeVec")] fixed_size_vec: Vec, } diff --git a/arrow2_convert/tests/test_deserialize.rs b/arrow2_convert/tests/test_deserialize.rs index 1b7c704..7e32b53 100644 --- a/arrow2_convert/tests/test_deserialize.rs +++ b/arrow2_convert/tests/test_deserialize.rs @@ -1,7 +1,6 @@ use arrow2::array::*; use arrow2::error::Result; use arrow2_convert::deserialize::*; -use arrow2_convert::field::LargeString; use arrow2_convert::serialize::*; use arrow2_convert::ArrowField; @@ -63,7 +62,7 @@ fn test_deserialize_large_types_schema_mismatch_error() { } #[derive(Debug, Clone, PartialEq, ArrowField)] struct S2 { - #[arrow_field(override = "LargeString")] + #[arrow_field(type = 
"arrow2_convert::field::LargeString")] a: String, } diff --git a/arrow2_convert/tests/test_enum.rs b/arrow2_convert/tests/test_enum.rs new file mode 100644 index 0000000..5749f93 --- /dev/null +++ b/arrow2_convert/tests/test_enum.rs @@ -0,0 +1,130 @@ +use arrow2::array::*; +use arrow2_convert::{deserialize::TryIntoCollection, serialize::TryIntoArrow, ArrowField}; + +#[test] +fn test_dense_enum_unit_variant() { + #[derive(Debug, PartialEq, ArrowField)] + #[arrow_field(type = "dense")] + enum TestEnum { + VAL1, + VAL2, + VAL3, + VAL4, + } + + let enums = vec![ + TestEnum::VAL1, + TestEnum::VAL2, + TestEnum::VAL3, + TestEnum::VAL4, + ]; + let b: Box = enums.try_into_arrow().unwrap(); + let round_trip: Vec = b.try_into_collection().unwrap(); + assert_eq!(round_trip, enums); +} + +#[test] +fn test_sparse_enum_unit_variant() { + #[derive(Debug, PartialEq, ArrowField)] + #[arrow_field(type = "sparse")] + enum TestEnum { + VAL1, + VAL2, + VAL3, + VAL4, + } + + let enums = vec![ + TestEnum::VAL1, + TestEnum::VAL2, + TestEnum::VAL3, + TestEnum::VAL4, + ]; + let b: Box = enums.try_into_arrow().unwrap(); + let round_trip: Vec = b.try_into_collection().unwrap(); + assert_eq!(round_trip, enums); +} + +#[test] +fn test_nested_unit_variant() { + #[derive(Debug, PartialEq, ArrowField)] + struct TestStruct { + a1: i64, + } + + #[derive(Debug, PartialEq, ArrowField)] + #[arrow_field(type = "dense")] + enum TestEnum { + VAL1, + VAL2(i32), + VAL3(f64), + VAL4(TestStruct), + VAL5(ChildEnum), + } + + #[derive(Debug, PartialEq, ArrowField)] + #[arrow_field(type = "sparse")] + enum ChildEnum { + VAL1, + VAL2(i32), + VAL3(f64), + VAL4(TestStruct), + } + + let enums = vec![ + TestEnum::VAL1, + TestEnum::VAL2(2), + TestEnum::VAL3(1.2), + TestEnum::VAL4(TestStruct { a1: 10 }), + ]; + + let b: Box = enums.try_into_arrow().unwrap(); + let round_trip: Vec = b.try_into_collection().unwrap(); + assert_eq!(round_trip, enums); +} + +// TODO: reenable this test once slices for enums is fixed. 
+//#[test] +#[allow(unused)] +fn test_slice() { + #[derive(Debug, PartialEq, ArrowField)] + struct TestStruct { + a1: i64, + } + + #[derive(Debug, PartialEq, ArrowField)] + #[arrow_field(type = "dense")] + enum TestEnum { + VAL1, + VAL2(i32), + VAL3(f64), + VAL4(TestStruct), + VAL5(ChildEnum), + } + + #[derive(Debug, PartialEq, ArrowField)] + #[arrow_field(type = "sparse")] + enum ChildEnum { + VAL1, + VAL2(i32), + VAL3(f64), + VAL4(TestStruct), + } + + let enums = vec![ + TestEnum::VAL4(TestStruct { a1: 11 }), + TestEnum::VAL1, + TestEnum::VAL2(2), + TestEnum::VAL3(1.2), + TestEnum::VAL4(TestStruct { a1: 10 }), + ]; + + let b: Box = enums.try_into_arrow().unwrap(); + + for i in 0..enums.len() { + let arrow_slice = b.slice(i, enums.len() - i); + let original_slice = &enums[i..enums.len()]; + let round_trip: Vec = arrow_slice.try_into_collection().unwrap(); + assert_eq!(round_trip, original_slice); + } +} diff --git a/arrow2_convert/tests/test_schema.rs b/arrow2_convert/tests/test_schema.rs index 3cef4b4..e4c1349 100644 --- a/arrow2_convert/tests/test_schema.rs +++ b/arrow2_convert/tests/test_schema.rs @@ -1,7 +1,4 @@ use arrow2::datatypes::*; -use arrow2_convert::field::{ - FixedSizeBinary, FixedSizeVec, LargeBinary, LargeString, LargeVec, I128, -}; use arrow2_convert::ArrowField; #[test] @@ -21,7 +18,7 @@ fn test_schema_types() { // timestamp(ns, None) a6: Option, // i128(precision, scale) - #[arrow_field(override = "I128<32, 32>")] + #[arrow_field(type = "arrow2_convert::field::I128<32, 32>")] a7: i128, // array of date times date_time_list: Vec, @@ -40,19 +37,19 @@ fn test_schema_types() { // int 32 array int32_array: Vec, // large binary - #[arrow_field(override = "LargeBinary")] + #[arrow_field(type = "arrow2_convert::field::LargeBinary")] large_binary: Vec, // fixed size binary - #[arrow_field(override = "FixedSizeBinary<3>")] + #[arrow_field(type = "arrow2_convert::field::FixedSizeBinary<3>")] fixed_size_binary: Vec, // large string - #[arrow_field(override 
= "LargeString")] + #[arrow_field(type = "arrow2_convert::field::LargeString")] large_string: String, // large vec - #[arrow_field(override = "LargeVec")] + #[arrow_field(type = "arrow2_convert::field::LargeVec")] large_vec: Vec, // fixed size vec - #[arrow_field(override = "FixedSizeVec")] + #[arrow_field(type = "arrow2_convert::field::FixedSizeVec")] fixed_size_vec: Vec, } diff --git a/arrow2_convert/tests/test_struct.rs b/arrow2_convert/tests/test_struct.rs new file mode 100644 index 0000000..c7aa06d --- /dev/null +++ b/arrow2_convert/tests/test_struct.rs @@ -0,0 +1,94 @@ +use arrow2::array::*; +use arrow2_convert::deserialize::*; +use arrow2_convert::serialize::*; +use arrow2_convert::ArrowField; + +#[test] +fn test_nested_optional_struct_array() { + #[derive(Debug, Clone, ArrowField, PartialEq)] + struct Top { + child_array: Vec>, + } + #[derive(Debug, Clone, ArrowField, PartialEq)] + struct Child { + a1: i64, + } + + let original_array = vec![ + Top { + child_array: vec![ + Some(Child { a1: 10 }), + None, + Some(Child { a1: 12 }), + Some(Child { a1: 14 }), + ], + }, + Top { + child_array: vec![None, None, None, None], + }, + Top { + child_array: vec![None, None, Some(Child { a1: 12 }), None], + }, + ]; + + let b: Box = original_array.try_into_arrow().unwrap(); + let round_trip: Vec = b.try_into_collection().unwrap(); + assert_eq!(original_array, round_trip); +} + +#[test] +fn test_slice() { + #[derive(Debug, Clone, ArrowField, PartialEq)] + struct T { + a1: i64, + } + + let original = vec![T { a1: 1 }, T { a1: 2 }, T { a1: 3 }, T { a1: 4 }]; + + let b: Box = original.try_into_arrow().unwrap(); + + for i in 0..original.len() { + let arrow_slice = b.slice(i, original.len() - i); + let original_slice = &original[i..original.len()]; + let round_trip: Vec = arrow_slice.try_into_collection().unwrap(); + assert_eq!(round_trip, original_slice); + } +} + +#[test] +fn test_nested_slice() { + #[derive(Debug, Clone, ArrowField, PartialEq)] + struct Top { + child_array: 
Vec>, + } + #[derive(Debug, Clone, ArrowField, PartialEq)] + struct Child { + a1: i64, + } + + let original = vec![ + Top { + child_array: vec![ + Some(Child { a1: 10 }), + None, + Some(Child { a1: 12 }), + Some(Child { a1: 14 }), + ], + }, + Top { + child_array: vec![None, None, None, None], + }, + Top { + child_array: vec![None, None, Some(Child { a1: 12 }), None], + }, + ]; + + let b: Box = original.try_into_arrow().unwrap(); + + for i in 0..original.len() { + let arrow_slice = b.slice(i, original.len() - i); + let original_slice = &original[i..original.len()]; + let round_trip: Vec = arrow_slice.try_into_collection().unwrap(); + assert_eq!(round_trip, original_slice); + } +} diff --git a/arrow2_convert/tests/ui/struct_incorrect_override.rs b/arrow2_convert/tests/ui/struct_incorrect_type.rs similarity index 77% rename from arrow2_convert/tests/ui/struct_incorrect_override.rs rename to arrow2_convert/tests/ui/struct_incorrect_type.rs index 7fbebe8..b1b0b55 100644 --- a/arrow2_convert/tests/ui/struct_incorrect_override.rs +++ b/arrow2_convert/tests/ui/struct_incorrect_type.rs @@ -3,7 +3,7 @@ use arrow2_convert::field::LargeBinary; #[derive(Debug, ArrowField)] struct Test { - #[arrow_field(override="LargeBinary")] + #[arrow_field(type="LargeBinary")] s: String } diff --git a/arrow2_convert/tests/ui/struct_incorrect_override.stderr b/arrow2_convert/tests/ui/struct_incorrect_type.stderr similarity index 88% rename from arrow2_convert/tests/ui/struct_incorrect_override.stderr rename to arrow2_convert/tests/ui/struct_incorrect_type.stderr index ab04400..802aa6e 100644 --- a/arrow2_convert/tests/ui/struct_incorrect_override.stderr +++ b/arrow2_convert/tests/ui/struct_incorrect_type.stderr @@ -1,5 +1,5 @@ error[E0277]: the trait bound `String: Borrow>` is not satisfied - --> tests/ui/struct_incorrect_override.rs:4:17 + --> tests/ui/struct_incorrect_type.rs:4:17 | 4 | #[derive(Debug, ArrowField)] | ^^^^^^^^^^ the trait `Borrow>` is not implemented for `String` @@ -9,7 
+9,7 @@ error[E0277]: the trait bound `String: Borrow>` is not satisfied = note: this error originates in the derive macro `ArrowField` (in Nightly builds, run with -Z macro-backtrace for more info) error[E0308]: mismatched types - --> tests/ui/struct_incorrect_override.rs:4:17 + --> tests/ui/struct_incorrect_type.rs:4:17 | 4 | #[derive(Debug, ArrowField)] | ^^^^^^^^^^ expected struct `String`, found struct `Vec` diff --git a/arrow2_convert_derive/src/attr.rs b/arrow2_convert_derive/src/attr.rs new file mode 100644 index 0000000..9786020 --- /dev/null +++ b/arrow2_convert_derive/src/attr.rs @@ -0,0 +1,73 @@ +use proc_macro_error::{abort, ResultExt}; +use syn::{Lit, Meta, MetaNameValue}; +use syn::spanned::Spanned; + +pub const ARROW_FIELD: &'static str = "arrow_field"; +pub const FIELD_TYPE: &'static str = "type"; +pub const UNION_MODE: &'static str = "mode"; + +pub fn field_type(field: &syn::Field) -> syn::Type { + for attr in &field.attrs { + if let Ok(meta) = attr.parse_meta() { + if meta.path().is_ident(ARROW_FIELD) { + if let Meta::List(list) = meta { + for nested in list.nested { + if let syn::NestedMeta::Meta(meta) = nested { + match meta { + Meta::NameValue(MetaNameValue { + lit: Lit::Str(string), + path, + .. + }) => { + if path.is_ident(FIELD_TYPE) { + return syn::parse_str(&string.value()).unwrap_or_abort() + } + }, + _ => { + abort!(meta.span(), "Unexpected attribute"); + } + } + } + } + } + } + } + } + + field.ty.clone() +} + +pub fn union_type(input: &syn::DeriveInput) -> bool { + for attr in &input.attrs { + if let Ok(meta) = attr.parse_meta() { + if meta.path().is_ident(ARROW_FIELD) { + if let Meta::List(list) = meta { + for nested in list.nested { + if let syn::NestedMeta::Meta(meta) = nested { + match meta { + Meta::NameValue(MetaNameValue { + lit: Lit::Str(string), + path, + .. 
+ }) => { + if path.is_ident(UNION_MODE) { + match string.value().as_ref() { + "sparse" => { return false; }, + "dense" => { return true; }, + _ => { abort!(path.span(), "Unexpected value for mode") } + } + } + }, + _ => { + abort!(meta.span(), "Unexpected attribute"); + } + } + } + } + } + } + } + } + + abort!(input.span(), "Missing mode attribute for enum"); +} \ No newline at end of file diff --git a/arrow2_convert_derive/src/derive_enum.rs b/arrow2_convert_derive/src/derive_enum.rs new file mode 100644 index 0000000..8a2b1fd --- /dev/null +++ b/arrow2_convert_derive/src/derive_enum.rs @@ -0,0 +1,539 @@ +use proc_macro2::TokenStream; +use proc_macro_error::abort; +use quote::{quote, quote_spanned}; +use syn::spanned::Spanned; + +use super::input::*; + +pub fn expand(input: DeriveEnum) -> TokenStream { + let original_name = &input.common.name; + let original_name_str = format!("{}", original_name); + let visibility = &input.common.visibility; + let is_dense = input.is_dense; + let variants = &input.variants; + + let union_type = if is_dense { + quote!(arrow2::datatypes::UnionMode::Dense) + } else { + quote!(arrow2::datatypes::UnionMode::Sparse) + }; + + let (gen_serialize, gen_deserialize) = input.common.traits_to_derive.to_flags(); + + let variant_names = variants + .iter() + .map(|v| v.syn.ident.clone()) + .collect::>(); + + if variant_names.is_empty() { + abort!( + original_name.span(), + "Expected enum to have more than one field" + ); + } + + let first_variant = &variant_names[0]; + + let variant_names_str = variant_names + .iter() + .map(|v| syn::LitStr::new(&format!("{}", v), proc_macro2::Span::call_site())) + .collect::>(); + + let variant_indices = variant_names + .iter() + .enumerate() + .map(|(idx, _ident)| syn::LitInt::new(&format!("{}", idx), proc_macro2::Span::call_site())) + .collect::>(); + + let variant_types: Vec<&syn::TypePath> = variants + .iter() + .map(|v| match &v.field_type { + syn::Type::Path(path) => path, + _ => panic!("Only types are 
supported atm"), + }) + .collect::>(); + + let mut generated = quote! { + impl arrow2_convert::field::ArrowField for #original_name { + type Type = Self; + + fn data_type() -> arrow2::datatypes::DataType { + arrow2::datatypes::DataType::Union( + vec![ + #( + <#variant_types as arrow2_convert::field::ArrowField>::field(#variant_names_str), + )* + ], + None, + #union_type, + ) + } + } + + arrow2_convert::arrow_enable_vec_for_type!(#original_name); + }; + + if gen_serialize { + let mutable_array_name = &input.common.mutable_array_name(); + let mutable_variant_array_types = variant_types + .iter() + .map(|field_type| quote_spanned!( field_type.span() => <#field_type as arrow2_convert::serialize::ArrowSerialize>::MutableArrayType)) + .collect::>(); + + let (offsets_decl, offsets_init, offsets_reserve, offsets_take, offsets_shrink_to_fit) = + if is_dense { + ( + quote! { offsets: Vec, }, + quote! { offsets: vec![], }, + quote! { self.offsets.reserve(additional); }, + quote! { Some(std::mem::take(&mut self.offsets).into()), }, + quote! { self.offsets.shrink_to_fit(); }, + ) + } else { + (quote! {}, quote! {}, quote! {}, quote! {None}, quote! {}) + }; + + let try_push_match_blocks = variants + .iter() + .enumerate() + .zip(&variant_indices) + .zip(&variant_types) + .map(|(((idx, v), lit_idx), variant_type)| { + let name = &v.syn.ident; + // - For dense unions, update the mutable array of the matched variant and also the offset. + // - For sparse unions, update the mutable array of the matched variant, and push null for all + // the other variants. This unfortunately results in some large code blocks per match arm. + // There might be a better way of doing this. + if is_dense { + let update_offset = quote! { + self.types.push(#lit_idx); + self.offsets.push((self.#name.len() - 1) as i32); + }; + if v.is_unit { + quote! 
{ + #original_name::#name => { + <#variant_type as arrow2_convert::serialize::ArrowSerialize>::arrow_serialize(&true, &mut self.#name)?; + #update_offset + } + } + } + else { + quote! { + #original_name::#name(v) => { + <#variant_type as arrow2_convert::serialize::ArrowSerialize>::arrow_serialize(v, &mut self.#name)?; + #update_offset + } + } + } + } + else { + let push_none = variants + .iter() + .enumerate() + .zip(&variant_types) + .map(|((nested_idx,y), variant_type)| { + let name = &y.syn.ident; + if nested_idx != idx { + quote! { + <<#variant_type as arrow2_convert::serialize::ArrowSerialize>::MutableArrayType as MutableArray>::push_null(&mut self.#name); + } + } + else { + quote!{} + } + }) + .collect::>(); + + let update_offset = quote! { + self.types.push(#lit_idx); + }; + + if v.is_unit { + quote! { + #original_name::#name => { + <#variant_type as arrow2_convert::serialize::ArrowSerialize>::arrow_serialize(&true, &mut self.#name)?; + #( + #push_none + )* + #update_offset + } + } + } + else { + quote! { + #original_name::#name(v) => { + <#variant_type as arrow2_convert::serialize::ArrowSerialize>::arrow_serialize(v, &mut self.#name)?; + #( + #push_none + )* + #update_offset + } + } + } + } + }) + .collect::>(); + + let try_push_none = if is_dense { + let first_array_type = &mutable_variant_array_types[0]; + let first_name = &variant_names[0]; + quote! { + self.types.push(0); + <#first_array_type as MutableArray>::push_null(&mut self.#first_name); + } + } else { + quote! { + self.types.push(0); + #( + <#mutable_variant_array_types as MutableArray>::push_null(&mut self.#variant_names); + )* + } + }; + + let array_decl = quote! { + #[allow(non_snake_case)] + #[derive(Debug)] + #visibility struct #mutable_array_name { + #( + #variant_names: #mutable_variant_array_types, + )* + data_type: arrow2::datatypes::DataType, + types: Vec, + #offsets_decl + } + }; + + let array_impl = quote! 
{ + impl #mutable_array_name { + pub fn new() -> Self { + Self { + #(#variant_names: <#variant_types as arrow2_convert::serialize::ArrowSerialize>::new_array(),)* + data_type: <#original_name as arrow2_convert::field::ArrowField>::data_type(), + types: vec![], + #offsets_init + } + } + } + }; + + let array_arrow_mutable_array_impl = quote! { + impl arrow2_convert::serialize::ArrowMutableArray for #mutable_array_name { + fn reserve(&mut self, additional: usize, _additional_values: usize) { + #(<<#variant_types as arrow2_convert::serialize::ArrowSerialize>::MutableArrayType as arrow2_convert::serialize::ArrowMutableArray>::reserve(&mut self.#variant_names, additional, _additional_values);)* + self.types.reserve(additional); + #offsets_reserve + } + } + }; + + let array_try_push_impl = quote! { + impl<__T: std::borrow::Borrow<#original_name>> arrow2::array::TryPush> for #mutable_array_name { + fn try_push(&mut self, item: Option<__T>) -> arrow2::error::Result<()> { + use arrow2::array::MutableArray; + + match item { + Some(i) => { + match i.borrow() { + #( + #try_push_match_blocks + )* + } + }, + None => { + #try_push_none + } + } + Ok(()) + } + } + }; + + let array_default_impl = quote! { + impl Default for #mutable_array_name { + fn default() -> Self { + Self::new() + } + } + }; + + let array_try_extend_impl = quote! { + impl<__T: std::borrow::Borrow<#original_name>> arrow2::array::TryExtend> for #mutable_array_name { + fn try_extend>>(&mut self, iter: I) -> arrow2::error::Result<()> { + use arrow2::array::TryPush; + for i in iter { + self.try_push(i)?; + } + Ok(()) + } + } + }; + + let array_mutable_array_impl = quote! 
{ + impl arrow2::array::MutableArray for #mutable_array_name { + fn data_type(&self) -> &arrow2::datatypes::DataType { + &self.data_type + } + + fn len(&self) -> usize { + self.#first_variant.len() + } + + fn validity(&self) -> Option<&arrow2::bitmap::MutableBitmap> { + None + } + + fn as_box(&mut self) -> Box { + let values = vec![#( + <#mutable_variant_array_types as arrow2::array::MutableArray>::as_arc(&mut self.#variant_names), + )*]; + + Box::new(arrow2::array::UnionArray::from_data( + <#original_name as arrow2_convert::field::ArrowField>::data_type().clone(), + std::mem::take(&mut self.types).into(), + values, + #offsets_take + )) + } + + fn as_arc(&mut self) -> std::sync::Arc { + let values = vec![#( + <#mutable_variant_array_types as arrow2::array::MutableArray>::as_arc(&mut self.#variant_names), + )*]; + + std::sync::Arc::new(arrow2::array::UnionArray::from_data( + <#original_name as arrow2_convert::field::ArrowField>::data_type().clone(), + std::mem::take(&mut self.types).into(), + values, + #offsets_take + )) + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + fn push_null(&mut self) { + use arrow2::array::TryPush; + self.try_push(None::<#original_name>).unwrap(); + } + + fn shrink_to_fit(&mut self) { + #( + <#mutable_variant_array_types as arrow2::array::MutableArray>::shrink_to_fit(&mut self.#variant_names); + )* + self.types.shrink_to_fit(); + #offsets_shrink_to_fit + } + } + }; + + let field_arrow_serialize_impl = quote! 
{ + impl arrow2_convert::serialize::ArrowSerialize for #original_name { + type MutableArrayType = #mutable_array_name; + + #[inline] + fn new_array() -> Self::MutableArrayType { + Self::MutableArrayType::default() + } + + #[inline] + fn arrow_serialize(v: &Self, array: &mut Self::MutableArrayType) -> arrow2::error::Result<()> { + use arrow2::array::TryPush; + array.try_push(Some(v)) + } + } + }; + + generated.extend([ + array_decl, + array_impl, + array_arrow_mutable_array_impl, + array_try_push_impl, + array_default_impl, + array_try_extend_impl, + array_mutable_array_impl, + field_arrow_serialize_impl, + ]) + } + + if gen_deserialize { + let array_name = &input.common.array_name(); + let iterator_name = &input.common.iterator_name(); + + // - For dense unions, return the value of the variant that corresponds to the matched arm. Since + // deserialization is sequential rather than via random access, the offset is not used even + // for dense unions. + // - For sparse unions, return the value of the variant that corresponds to the matched arm, and + // consume the iterators of the rest of the variants. + let iter_next_match_block = if is_dense { + let candidates = variants.iter() + .zip(&variant_indices) + .zip(&variant_types) + .map(|((v, lit_idx), variant_type)| { + let name = &v.syn.ident; + if v.is_unit { + quote! { + #lit_idx => { + let v = self.#name.next() + .unwrap_or_else(|| panic!("Invalid offset for {}", #original_name_str)); + assert!(v.unwrap()); + Some(Some(#original_name::#name)) + } + } + } + else { + quote! { + #lit_idx => { + let v = self.#name.next() + .unwrap_or_else(|| panic!("Invalid offset for {}", #original_name_str)); + Some(<#variant_type as arrow2_convert::deserialize::ArrowDeserialize>::arrow_deserialize(v).map(|v| #original_name::#name(v))) + } + } + } + }) + .collect::>(); + quote! 
{ #(#candidates)* } + } else { + let candidates = variants.iter() + .enumerate() + .zip(variant_indices.iter()) + .zip(&variant_types) + .map(|(((i, v), lit_idx), variant_type)| { + let consume = variants.iter() + .enumerate() + .map(|(n, v)| { + let name = &v.syn.ident; + if i != n { + quote! { + let _ = self.#name.next(); + } + } + else { + quote! {} + } + }) + .collect::>(); + let consume = quote! { #(#consume)* }; + + let name = &v.syn.ident; + if v.is_unit { + quote! { + #lit_idx => { + #consume + let v = self.#name.next() + .unwrap_or_else(|| panic!("Invalid offset for {}", #original_name_str)); + assert!(v.unwrap()); + Some(Some(#original_name::#name)) + } + } + } + else { + quote! { + #lit_idx => { + #consume + let v = self.#name.next() + .unwrap_or_else(|| panic!("Invalid offset for {}", #original_name_str)); + Some(<#variant_type as arrow2_convert::deserialize::ArrowDeserialize>::arrow_deserialize(v).map(|v| #original_name::#name(v))) + } + } + } + }) + .collect::>(); + quote! { #(#candidates)* } + }; + + let array_decl = quote! { + #visibility struct #array_name + {} + }; + + let array_impl = quote! { + impl arrow2_convert::deserialize::ArrowArray for #array_name + { + type BaseArrayType = arrow2::array::UnionArray; + + #[inline] + fn iter_from_array_ref<'a>(b: &'a dyn arrow2::array::Array) -> <&'a Self as IntoIterator>::IntoIter + { + use core::ops::Deref; + let arr = b.as_any().downcast_ref::().unwrap(); + let fields = arr.fields(); + + #iterator_name { + #( + #variant_names: <<#variant_types as arrow2_convert::deserialize::ArrowDeserialize>::ArrayType as arrow2_convert::deserialize::ArrowArray>::iter_from_array_ref(fields[#variant_indices].deref()), + )* + types_iter: arr.types().iter(), + } + } + } + }; + + let array_into_iterator_impl = quote! 
{ + impl<'a> IntoIterator for &'a #array_name + { + type Item = Option<#original_name>; + type IntoIter = #iterator_name<'a>; + + fn into_iter(self) -> Self::IntoIter { + unimplemented!("Use iter_from_array_ref"); + } + } + }; + + let array_iterator_decl = quote! { + #[allow(non_snake_case)] + #visibility struct #iterator_name<'a> { + #( + #variant_names: <&'a <#variant_types as arrow2_convert::deserialize::ArrowDeserialize>::ArrayType as IntoIterator>::IntoIter, + )* + types_iter: std::slice::Iter<'a, i8>, + } + }; + + let array_iterator_iterator_impl = quote! { + impl<'a> Iterator for #iterator_name<'a> { + type Item = Option<#original_name>; + + #[inline] + fn next(&mut self) -> Option<Self::Item> { + match self.types_iter.next() { + Some(type_idx) => { + match type_idx { + #iter_next_match_block + _ => panic!("Invalid type for {}", #original_name_str) + } + } + None => None + } + } + } + }; + + let field_arrow_deserialize_impl = quote! { + impl arrow2_convert::deserialize::ArrowDeserialize for #original_name { + type ArrayType = #array_name; + + #[inline] + fn arrow_deserialize<'a>(v: Option<Self>) -> Option<Self> { + v + } + } + }; + + generated.extend([ + array_decl, + array_impl, + array_into_iterator_impl, + array_iterator_decl, + array_iterator_iterator_impl, + field_arrow_deserialize_impl, + ]); + } + + generated +} diff --git a/arrow2_convert_derive/src/_struct.rs b/arrow2_convert_derive/src/derive_struct.rs similarity index 81% rename from arrow2_convert_derive/src/_struct.rs rename to arrow2_convert_derive/src/derive_struct.rs index aa8ba7b..3ceaa4b 100644 --- a/arrow2_convert_derive/src/_struct.rs +++ b/arrow2_convert_derive/src/derive_struct.rs @@ -5,43 +5,14 @@ use syn::spanned::Spanned; use super::input::*; -// Helper method for identifying the traits to derive -fn traits_to_derive(t: &TraitsToDerive) -> (bool, bool) { - let mut gen_serialize = true; - let mut gen_deserialize = true; - - // setup the flags - match t { - TraitsToDerive::All => { /* do nothing */ } -
TraitsToDerive::DeserializeOnly => { - gen_serialize = false; - } - TraitsToDerive::SerializeOnly => { - gen_deserialize = false; - } - TraitsToDerive::FieldOnly => { - gen_deserialize = false; - gen_serialize = false; - } - } - - (gen_serialize, gen_deserialize) -} - -pub fn expand_derive(input: &Input) -> TokenStream { - let original_name = &input.name; - let original_name_str = format!("{}", original_name); - let vec_name_str = format!("Vec<{}>", original_name); - let visibility = &input.visibility; - - let mutable_array_name = &input.mutable_array_name(); - let array_name = &input.array_name(); - let iterator_name = &input.iterator_name(); +pub fn expand(input: DeriveStruct) -> TokenStream { + let original_name = &input.common.name; + let visibility = &input.common.visibility; + let fields = &input.fields; - let (gen_serialize, gen_deserialize) = traits_to_derive(&input.traits_to_derive); + let (gen_serialize, gen_deserialize) = input.common.traits_to_derive.to_flags(); - let field_names = input - .fields + let field_names = fields .iter() .map(|field| field.syn.ident.as_ref().unwrap()) .collect::<Vec<_>>(); @@ -66,18 +37,7 @@ pub fn expand_derive(input: &Input) -> TokenStream { .map(|(idx, _ident)| syn::LitInt::new(&format!("{}", idx), proc_macro2::Span::call_site())) .collect::<Vec<_>>(); - let field_docs = field_names - .iter() - .map(|field| { - format!( - "A vector of `{0}` from a [`{1}`](struct.{1}.html)", - field, mutable_array_name - ) - }) - .collect::<Vec<_>>(); - - let field_types: Vec<&syn::TypePath> = input - .fields + let field_types: Vec<&syn::TypePath> = fields .iter() .map(|field| match &field.field_type { syn::Type::Path(path) => path, @@ -85,11 +45,6 @@ pub fn expand_derive(input: &Input) -> TokenStream { }) .collect::<Vec<_>>(); - let mutable_field_array_types = field_types - .iter() - .map(|field_type| quote_spanned!( field_type.span() => <#field_type as arrow2_convert::serialize::ArrowSerialize>::MutableArrayType)) - .collect::<Vec<_>>(); - let mut generated = quote!( impl
arrow2_convert::field::ArrowField for #original_name { type Type = Self; @@ -109,21 +64,24 @@ pub fn expand_derive(input: &Input) -> TokenStream { ); if gen_serialize { - generated.extend(quote! { - /// A mutable [`arrow2::StructArray`] for elements of - #[doc = #original_name_str] - /// which is logically equivalent to a - #[doc = #vec_name_str] + let mutable_array_name = &input.common.mutable_array_name(); + let mutable_field_array_types = field_types + .iter() + .map(|field_type| quote_spanned!( field_type.span() => <#field_type as arrow2_convert::serialize::ArrowSerialize>::MutableArrayType)) + .collect::<Vec<_>>(); + + let array_decl = quote! { #[derive(Debug)] #visibility struct #mutable_array_name { #( - #[doc = #field_docs] #field_names: #mutable_field_array_types, )* data_type: arrow2::datatypes::DataType, validity: Option<arrow2::bitmap::MutableBitmap>, } + }; + let array_impl = quote! { impl #mutable_array_name { pub fn new() -> Self { Self { @@ -140,13 +98,17 @@ pub fn expand_derive(input: &Input) -> TokenStream { self.validity = Some(validity) } } + }; + let array_default_impl = quote! { impl Default for #mutable_array_name { fn default() -> Self { Self::new() } } + }; + let array_arrow_mutable_array_impl = quote! { impl arrow2_convert::serialize::ArrowMutableArray for #mutable_array_name { fn reserve(&mut self, additional: usize, _additional_values: usize) { if let Some(x) = self.validity.as_mut() { @@ -155,9 +117,11 @@ pub fn expand_derive(input: &Input) -> TokenStream { #(<<#field_types as arrow2_convert::serialize::ArrowSerialize>::MutableArrayType as arrow2_convert::serialize::ArrowMutableArray>::reserve(&mut self.#field_names, additional, _additional_values);)* } } + }; - impl<T: std::borrow::Borrow<#original_name>> arrow2::array::TryPush<Option<T>> for #mutable_array_name { - fn try_push(&mut self, item: Option<T>) -> arrow2::error::Result<()> { + let array_try_push_impl = quote!
{ + impl<__T: std::borrow::Borrow<#original_name>> arrow2::array::TryPush<Option<__T>> for #mutable_array_name { + fn try_push(&mut self, item: Option<__T>) -> arrow2::error::Result<()> { use arrow2::array::MutableArray; use std::borrow::Borrow; @@ -187,9 +151,11 @@ pub fn expand_derive(input: &Input) -> TokenStream { Ok(()) } } + }; - impl<T: std::borrow::Borrow<#original_name>> arrow2::array::TryExtend<Option<T>> for #mutable_array_name { - fn try_extend<I: IntoIterator<Item = Option<T>>>(&mut self, iter: I) -> arrow2::error::Result<()> { + let array_try_extend_impl = quote! { + impl<__T: std::borrow::Borrow<#original_name>> arrow2::array::TryExtend<Option<__T>> for #mutable_array_name { + fn try_extend<I: IntoIterator<Item = Option<__T>>>(&mut self, iter: I) -> arrow2::error::Result<()> { use arrow2::array::TryPush; for i in iter { self.try_push(i)?; @@ -197,7 +163,9 @@ pub fn expand_derive(input: &Input) -> TokenStream { Ok(()) } } + }; + let array_mutable_array_impl = quote! { impl arrow2::array::MutableArray for #mutable_array_name { fn data_type(&self) -> &arrow2::datatypes::DataType { &self.data_type @@ -257,7 +225,9 @@ pub fn expand_derive(input: &Input) -> TokenStream { } } } + }; + let field_arrow_serialize_impl = quote! { impl arrow2_convert::serialize::ArrowSerialize for #original_name { type MutableArrayType = #mutable_array_name; @@ -266,25 +236,41 @@ pub fn expand_derive(input: &Input) -> TokenStream { Self::MutableArrayType::default() } + #[inline] fn arrow_serialize(v: &Self, array: &mut Self::MutableArrayType) -> arrow2::error::Result<()> { use arrow2::array::TryPush; array.try_push(Some(v)) } } - }); + }; + + generated.extend([ + array_decl, + array_impl, + array_default_impl, + array_arrow_mutable_array_impl, + array_try_push_impl, + array_try_extend_impl, + array_mutable_array_impl, + field_arrow_serialize_impl, + ]) } if gen_deserialize { - generated.extend(quote! { + let array_name = &input.common.array_name(); + let iterator_name = &input.common.iterator_name(); + + let array_decl = quote! { #visibility struct #array_name - { - array: Box<dyn arrow2::array::Array> - } + {} + }; + let array_impl = quote!
{ impl arrow2_convert::deserialize::ArrowArray for #array_name { type BaseArrayType = arrow2::array::StructArray; + #[inline] fn iter_from_array_ref<'a>(b: &'a dyn arrow2::array::Array) -> <&'a Self as IntoIterator>::IntoIter { use core::ops::Deref; @@ -301,7 +287,9 @@ pub fn expand_derive(input: &Input) -> TokenStream { } } } + }; + let array_into_iterator_impl = quote! { impl<'a> IntoIterator for &'a #array_name { type Item = Option<#original_name>; @@ -311,7 +299,9 @@ pub fn expand_derive(input: &Input) -> TokenStream { unimplemented!("Use iter_from_array_ref"); } } + }; + let iterator_decl = quote! { #visibility struct #iterator_name<'a> { #( #field_names: <&'a <#field_types as arrow2_convert::deserialize::ArrowDeserialize>::ArrayType as IntoIterator>::IntoIter, @@ -319,8 +309,11 @@ pub fn expand_derive(input: &Input) -> TokenStream { validity_iter: arrow2::bitmap::utils::BitmapIter<'a>, has_validity: bool } + }; + let iterator_impl = quote! { impl<'a> #iterator_name<'a> { + #[inline] fn return_next(&mut self) -> Option<#original_name> { if let (#( Some(#field_names), @@ -337,14 +330,18 @@ pub fn expand_derive(input: &Input) -> TokenStream { } } + #[inline] fn consume_next(&mut self) { #(let _ = self.#field_names.next();)* } } + }; + let iterator_iterator_impl = quote! { impl<'a> Iterator for #iterator_name<'a> { type Item = Option<#original_name>; + #[inline] fn next(&mut self) -> Option<Self::Item> { if !self.has_validity { self.return_next().map(|y| Some(y)) @@ -355,15 +352,28 @@ pub fn expand_derive(input: &Input) -> TokenStream { } } } + }; + let field_arrow_deserialize_impl = quote!
{ impl arrow2_convert::deserialize::ArrowDeserialize for #original_name { type ArrayType = #array_name; + #[inline] fn arrow_deserialize<'a>(v: Option<Self>) -> Option<Self> { v } } - }); + }; + + generated.extend([ + array_decl, + array_impl, + array_into_iterator_impl, + iterator_decl, + iterator_impl, + iterator_iterator_impl, + field_arrow_deserialize_impl, + ]) } generated diff --git a/arrow2_convert_derive/src/input.rs b/arrow2_convert_derive/src/input.rs index 52ab012..18d9191 100644 --- a/arrow2_convert_derive/src/input.rs +++ b/arrow2_convert_derive/src/input.rs @@ -1,8 +1,18 @@ use proc_macro2::Span; use proc_macro_error::{abort, ResultExt}; -use syn::{Data, DeriveInput, Ident, Lit, Meta, MetaNameValue, Visibility}; +use syn::spanned::Spanned; +use syn::{DeriveInput, Ident, Lit, Meta, MetaNameValue, Visibility}; -#[derive(PartialEq)] +pub const ARROW_FIELD: &str = "arrow_field"; +pub const FIELD_TYPE: &str = "type"; +pub const UNION_TYPE: &str = "type"; +pub const UNION_TYPE_SPARSE: &str = "sparse"; +pub const UNION_TYPE_DENSE: &str = "dense"; +pub const FIELD_ONLY: &str = "field_only"; +pub const SERIALIZE_ONLY: &str = "serialize_only"; +pub const DESERIALIZE_ONLY: &str = "deserialize_only"; + +#[derive(PartialEq, Clone)] pub enum TraitsToDerive { FieldOnly, SerializeOnly, @@ -10,43 +20,145 @@ pub enum TraitsToDerive { All, } -/// Representing the struct we are deriving -pub struct Input { - /// The input struct name +pub struct DeriveCommon { + /// The input name pub name: Ident, /// The traits to derive pub traits_to_derive: TraitsToDerive, - /// The list of fields in the struct - pub fields: Vec<Field>, - /// The struct overall visibility + /// The overall visibility pub visibility: Visibility, } -pub struct Field { +pub struct DeriveStruct { + pub common: DeriveCommon, + /// The list of fields in the struct + pub fields: Vec<DeriveField>, +} + +pub struct DeriveEnum { + pub common: DeriveCommon, + /// The list of variants in the enum + pub variants: Vec<DeriveVariant>, + pub is_dense: bool, +}
+/// All container attributes +pub struct ContainerAttrs { + pub traits_to_derive: Option<TraitsToDerive>, + pub is_dense: Option<bool>, +} + +/// All field attributes +pub struct FieldAttrs { + pub field_type: Option<syn::Type>, +} + +pub struct DeriveField { pub syn: syn::Field, pub field_type: syn::Type, } -fn arrow_field(field: &syn::Field) -> syn::Type { - for attr in &field.attrs { - if let Ok(meta) = attr.parse_meta() { - if meta.path().is_ident("arrow_field") { - if let Meta::List(list) = meta { - for nested in list.nested { - if let syn::NestedMeta::Meta(meta) = nested { - match meta { - Meta::NameValue(MetaNameValue { - lit: Lit::Str(string), - path, - .. - }) => { - if path.is_ident("override") { - return syn::parse_str(&string.value()).unwrap_or_abort(); +pub struct DeriveVariant { + pub syn: syn::Variant, + pub field_type: syn::Type, + pub is_unit: bool, +} + +impl DeriveCommon { + pub fn from_ast(input: &DeriveInput, container_attrs: &ContainerAttrs) -> DeriveCommon { + DeriveCommon { + name: input.ident.clone(), + traits_to_derive: container_attrs + .traits_to_derive + .clone() + .unwrap_or(TraitsToDerive::All), + visibility: input.vis.clone(), + } + } + + pub fn mutable_array_name(&self) -> Ident { + Ident::new(&format!("Mutable{}Array", self.name), Span::call_site()) + } + + pub fn array_name(&self) -> Ident { + Ident::new(&format!("{}Array", self.name), Span::call_site()) + } + + pub fn iterator_name(&self) -> Ident { + Ident::new(&format!("{}ArrayIterator", self.name), Span::call_site()) + } +} + +impl ContainerAttrs { + pub fn from_ast(attrs: &[syn::Attribute]) -> ContainerAttrs { + let mut traits_to_derive: Option<TraitsToDerive> = None; + let mut is_dense: Option<bool> = None; + + for attr in attrs { + if let Ok(meta) = attr.parse_meta() { + if meta.path().is_ident(ARROW_FIELD) { + if let Meta::List(list) = meta { + for nested in list.nested { + if let syn::NestedMeta::Meta(meta) = nested { + match meta { + syn::Meta::NameValue(MetaNameValue { + lit: Lit::Str(string), + path, + ..
+ }) => { + if path.is_ident(UNION_TYPE) { + match string.value().as_ref() { + UNION_TYPE_DENSE => { + is_dense = Some(true); + } + UNION_TYPE_SPARSE => { + is_dense = Some(false); + } + _ => { + abort!( + path.span(), + "Unexpected value for mode" + ); + } + } + } else { + for value in string.value().split(',') { + match value { + FIELD_ONLY | SERIALIZE_ONLY + | DESERIALIZE_ONLY => { + if traits_to_derive.is_some() { + abort!(string.span(), "Only one of field_only, serialize_only or deserialize_only can be specified"); + } + + match value { + FIELD_ONLY => { + traits_to_derive = + Some(TraitsToDerive::FieldOnly); + } + SERIALIZE_ONLY => { + traits_to_derive = Some( + TraitsToDerive::SerializeOnly, + ); + } + DESERIALIZE_ONLY => { + traits_to_derive = Some( + TraitsToDerive::DeserializeOnly, + ); + } + _ => panic!("Unexpected {}", value), // intentionally leave as panic since we should never get here + } + } + _ => abort!( + string.span(), + "Unexpected {}", + value + ), + } + } + } + } + _ => { + abort!(meta.span(), "Unexpected attribute"); } - } - _ => { - use syn::spanned::Spanned; - abort!(meta.span(), "Unexpected attribute"); } } } @@ -54,90 +166,143 @@ fn arrow_field(field: &syn::Field) -> syn::Type { } } } - } - field.ty.clone() + ContainerAttrs { + traits_to_derive, + is_dense, + } + } } -impl Input { - pub fn new(input: DeriveInput) -> Input { - let mut traits_to_derive = TraitsToDerive::All; - - let fields = match input.data { - Data::Struct(s) => s - .fields - .iter() - .map(|f| Field { - syn: f.clone(), - field_type: arrow_field(f), - }) - .collect::<Vec<_>>(), - _ => abort!( - input.ident.span(), - "#[derive(ArrowField)] only supports structs."
- ), - }; + impl FieldAttrs { + pub fn from_ast(input: &[syn::Attribute]) -> FieldAttrs { + let mut field_type: Option<syn::Type> = None; - let mut derives: Vec<Ident> = vec![]; - for attr in input.attrs { + for attr in input { if let Ok(meta) = attr.parse_meta() { - if meta.path().is_ident("arrow2_convert") { - match meta { - Meta::NameValue(MetaNameValue { - lit: Lit::Str(string), - .. - }) => { - for value in string.value().split(',') { - match value { - "field_only" | "serialize_only" | "deserialize_only" => { - if traits_to_derive != TraitsToDerive::All { - abort!(string.span(), "Only one of field_only, serialize-only or deserialize_only can be specified"); - } - - match value { - "field_only" => { - traits_to_derive = TraitsToDerive::FieldOnly; - } - "serialize_only" => { - traits_to_derive = TraitsToDerive::SerializeOnly; - } - "deserialize_only" => { - traits_to_derive = TraitsToDerive::DeserializeOnly; - } - _ => panic!("Unexpected {}", value), // intentionally leave as panic since we should never get here + if meta.path().is_ident(ARROW_FIELD) { + if let Meta::List(list) = meta { + for nested in list.nested { + if let syn::NestedMeta::Meta(meta) = nested { + match meta { + Meta::NameValue(MetaNameValue { + lit: Lit::Str(string), + path, + ..
+ }) => { + if path.is_ident(FIELD_TYPE) { + field_type = Some( + syn::parse_str(&string.value()).unwrap_or_abort(), + ); } } - _ => abort!(string.span(), "Unexpected {}", value), + _ => { + abort!(meta.span(), "Unexpected attribute"); + } } - derives.push(Ident::new(value.trim(), Span::call_site())); } } - _ => { - use syn::spanned::Spanned; - abort!(meta.span(), "Unexpected attribute"); - } } } } } - Input { - name: input.ident, - fields, - visibility: input.vis, - traits_to_derive, + FieldAttrs { field_type } + } +} + +impl DeriveStruct { + pub fn from_ast(input: &DeriveInput, ast: &syn::DataStruct) -> DeriveStruct { + let container_attrs = ContainerAttrs::from_ast(&input.attrs); + let common = DeriveCommon::from_ast(input, &container_attrs); + + DeriveStruct { + common, + fields: ast + .fields + .iter() + .map(DeriveField::from_ast) + .collect::<Vec<_>>(), } } +} - pub fn mutable_array_name(&self) -> Ident { - Ident::new(&format!("Mutable{}Array", self.name), Span::call_site()) +impl DeriveEnum { + pub fn from_ast(input: &DeriveInput, ast: &syn::DataEnum) -> DeriveEnum { + let container_attrs = ContainerAttrs::from_ast(&input.attrs); + let common = DeriveCommon::from_ast(input, &container_attrs); + + DeriveEnum { + common, + variants: ast + .variants + .iter() + .map(DeriveVariant::from_ast) + .collect::<Vec<_>>(), + is_dense: container_attrs + .is_dense + .unwrap_or_else(|| abort!(input.span(), "Missing mode attribute for enum")), + } } +} - pub fn array_name(&self) -> Ident { - Ident::new(&format!("{}Array", self.name), Span::call_site()) +impl DeriveField { + pub fn from_ast(input: &syn::Field) -> DeriveField { + let attrs = FieldAttrs::from_ast(&input.attrs); + + DeriveField { + syn: input.clone(), + field_type: attrs.field_type.unwrap_or_else(|| input.ty.clone()), + } } +} - pub fn iterator_name(&self) -> Ident { - Ident::new(&format!("{}ArrayIterator", self.name), Span::call_site()) +impl DeriveVariant { + pub fn from_ast(input: &syn::Variant) -> DeriveVariant { + let
attrs = FieldAttrs::from_ast(&input.attrs); + + let (is_unit, field_type) = match &input.fields { + syn::Fields::Named(_f) => { + unimplemented!() + } + syn::Fields::Unnamed(f) => { + if f.unnamed.len() > 1 { + unimplemented!() + } else { + (false, f.unnamed[0].ty.clone()) + } + } + syn::Fields::Unit => (true, syn::parse_str("bool").unwrap_or_abort()), + }; + DeriveVariant { + syn: input.clone(), + field_type: attrs.field_type.unwrap_or_else(|| field_type.clone()), + is_unit, + } + } +} + +impl TraitsToDerive { + // Helper method for identifying the traits to derive + pub fn to_flags(&self) -> (bool, bool) { + let mut gen_serialize = true; + let mut gen_deserialize = true; + + // setup the flags + match self { + TraitsToDerive::All => { /* do nothing */ } + TraitsToDerive::DeserializeOnly => { + gen_serialize = false; + } + TraitsToDerive::SerializeOnly => { + gen_deserialize = false; + } + TraitsToDerive::FieldOnly => { + gen_deserialize = false; + gen_serialize = false; + } + } + + (gen_serialize, gen_deserialize) } } diff --git a/arrow2_convert_derive/src/lib.rs b/arrow2_convert_derive/src/lib.rs index 22b0df1..0ab8de4 100644 --- a/arrow2_convert_derive/src/lib.rs +++ b/arrow2_convert_derive/src/lib.rs @@ -1,19 +1,22 @@ -use proc_macro2::TokenStream; -use proc_macro_error::proc_macro_error; -use quote::TokenStreamExt; +use proc_macro_error::{abort, proc_macro_error}; -mod _struct; +mod derive_enum; +mod derive_struct; mod input; -/// Derive macro for the Array trait. 
+use input::*; + +/// Derive macro for arrow fields #[proc_macro_error] #[proc_macro_derive(ArrowField, attributes(arrow_field))] pub fn arrow2_convert_derive_field(input: proc_macro::TokenStream) -> proc_macro::TokenStream { - let ast = syn::parse(input).unwrap(); - let input = input::Input::new(ast); + let ast: syn::DeriveInput = syn::parse(input).unwrap(); - // Build the output, possibly using quasi-quotation - let mut generated = TokenStream::new(); - generated.append_all(_struct::expand_derive(&input)); - generated.into() + match &ast.data { + syn::Data::Enum(e) => derive_enum::expand(DeriveEnum::from_ast(&ast, e)).into(), + syn::Data::Struct(s) => derive_struct::expand(DeriveStruct::from_ast(&ast, s)).into(), + _ => { + abort!(ast.ident.span(), "Only structs and enums supported"); + } + } }
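
Note on the dense/sparse comment in `derive_enum.rs`: the generated iterator reads rows sequentially, so for sparse unions every child iterator advances on each row, while for dense unions only the matched child advances and the offsets buffer is never consulted. A minimal standalone sketch of that idea follows. It does not use `arrow2` or the derive macro; `Value`, `decode_sparse` and `decode_dense` are hypothetical names introduced only for illustration, and the dense case assumes each child stores its values in row order.

```rust
/// Hypothetical two-variant enum standing in for a derived union type.
#[derive(Debug, Clone, PartialEq)]
pub enum Value {
    Int(i32),
    Text(String),
}

/// Sparse layout: every child holds one slot per row, so each row advances
/// all child iterators and keeps only the slot selected by the type id.
pub fn decode_sparse(types: &[i8], ints: &[i32], texts: &[&str]) -> Vec<Value> {
    let mut ints = ints.iter();
    let mut texts = texts.iter();
    types
        .iter()
        .map(|t| {
            let i = ints.next().expect("int child shorter than types");
            let s = texts.next().expect("text child shorter than types");
            match *t {
                0 => Value::Int(*i),
                1 => Value::Text(s.to_string()),
                other => panic!("invalid type id {}", other),
            }
        })
        .collect()
}

/// Dense layout: a child holds slots only for its own rows, so each row
/// advances just the matched child. Offsets are not needed as long as rows
/// are read sequentially and each child stores its values in row order
/// (an assumption of this sketch).
pub fn decode_dense(types: &[i8], ints: &[i32], texts: &[&str]) -> Vec<Value> {
    let mut ints = ints.iter();
    let mut texts = texts.iter();
    types
        .iter()
        .map(|t| match *t {
            0 => Value::Int(*ints.next().expect("missing int slot")),
            1 => Value::Text(texts.next().expect("missing text slot").to_string()),
            other => panic!("invalid type id {}", other),
        })
        .collect()
}
```

With type ids `[0, 1, 1, 0]`, sparse children `[10, 0, 0, 40]` and `["", "b", "c", ""]` decode to the same four values as dense children `[10, 40]` and `["b", "c"]`. The sparse variant mirrors the `#consume` tokens in the generated match arms.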