Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 95 additions & 64 deletions parquet-variant-compute/src/shred_variant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ pub fn shred_variant(array: &VariantArray, as_type: &DataType) -> Result<Variant
let (value, typed_value, nulls) = builder.finish()?;
Ok(VariantArray::from_parts(
array.metadata_field().clone(),
Some(value),
Some(Arc::new(value) as ArrayRef),
Some(typed_value),
nulls,
))
Expand Down Expand Up @@ -408,8 +408,11 @@ impl<'a> VariantToShreddedObjectVariantRowBuilder<'a> {
let mut builder = StructArrayBuilder::new();
for (field_name, typed_value_builder) in self.typed_value_builders {
let (value, typed_value, nulls) = typed_value_builder.finish()?;
let array =
ShreddedVariantFieldArray::from_parts(Some(value), Some(typed_value), nulls);
let array = ShreddedVariantFieldArray::from_parts(
Some(Arc::new(value) as ArrayRef),
Some(typed_value),
nulls,
);
builder = builder.with_field(field_name, ArrayRef::from(array), false);
}
if let Some(nulls) = self.typed_value_nulls.finish() {
Expand Down Expand Up @@ -654,6 +657,7 @@ impl VariantSchemaNode {
mod tests {
use super::*;
use crate::VariantArrayBuilder;
use crate::variant_array::binary_array_value;
use arrow::array::{
Array, BinaryViewArray, FixedSizeBinaryArray, Float64Array, GenericListArray,
GenericListViewArray, Int64Array, LargeBinaryArray, LargeStringArray, ListArray,
Expand Down Expand Up @@ -826,7 +830,8 @@ mod tests {
) {
assert_eq!(array.len(), expected_len);

let fallbacks = (array.value_field().unwrap(), Some(array.metadata_field()));
let fallback_value = array.value_field().unwrap();
let fallback_metadata = array.metadata_field();
let array = downcast_list_like_array::<O>(array);

assert_eq!(
Expand All @@ -846,7 +851,7 @@ mod tests {
);
assert_eq!(
array.len(),
fallbacks.0.len(),
fallback_value.len(),
"fallbacks value field should match array length"
);

Expand All @@ -861,23 +866,28 @@ mod tests {
// Successfully shredded: typed list value present, no fallback value
assert!(array.is_valid(idx));
assert_eq!(array.value_size(idx), *len);
assert!(fallbacks.0.is_null(idx));
assert!(fallback_value.is_null(idx));
}
None => {
// Unable to shred: typed list value absent, fallback should carry the variant
assert!(array.is_null(idx));
assert_eq!(array.value_size(idx), O::zero());
match expected_fallback {
Some(expected_variant) => {
assert!(fallbacks.0.is_valid(idx));
let metadata_bytes = fallbacks
.1
.filter(|m| m.is_valid(idx))
.map(|m| m.value(idx))
.filter(|bytes| !bytes.is_empty())
.unwrap_or(EMPTY_VARIANT_METADATA_BYTES);
assert!(fallback_value.is_valid(idx));
let metadata_bytes =
binary_array_value(fallback_metadata.as_ref(), idx);
let metadata_bytes =
if fallback_metadata.is_valid(idx) && !metadata_bytes.is_empty() {
metadata_bytes
} else {
EMPTY_VARIANT_METADATA_BYTES
};
assert_eq!(
Variant::new(metadata_bytes, fallbacks.0.value(idx)),
Variant::new(
metadata_bytes,
binary_array_value(fallback_value.as_ref(), idx)
),
expected_variant.clone()
);
}
Expand Down Expand Up @@ -940,7 +950,10 @@ mod tests {
Some(expected_variant) => {
assert!(element_fallbacks.is_valid(idx));
assert_eq!(
Variant::new(EMPTY_VARIANT_METADATA_BYTES, element_fallbacks.value(idx)),
Variant::new(
EMPTY_VARIANT_METADATA_BYTES,
binary_array_value(element_fallbacks.as_ref(), idx)
),
expected_variant.clone()
);
}
Expand Down Expand Up @@ -971,7 +984,7 @@ mod tests {
#[test]
fn test_all_null_input() {
// Create VariantArray with no value field (all null case)
let metadata = BinaryViewArray::from_iter_values([&[1u8, 0u8]]); // minimal valid metadata
let metadata = Arc::new(BinaryViewArray::from_iter_values([&[1u8, 0u8]])) as ArrayRef; // minimal valid metadata
let all_null_array = VariantArray::from_parts(metadata, None, None, None);
let result = shred_variant(&all_null_array, &DataType::Int64).unwrap();

Expand Down Expand Up @@ -1085,7 +1098,10 @@ mod tests {
assert!(!value_field.is_null(1)); // value should contain original
assert!(typed_value_field.is_null(1)); // typed_value should be null
assert_eq!(
Variant::new(metadata_field.value(1), value_field.value(1)),
Variant::new(
binary_array_value(metadata_field.as_ref(), 1),
binary_array_value(value_field.as_ref(), 1)
),
Variant::from("hello")
);

Expand All @@ -1101,7 +1117,10 @@ mod tests {
assert!(!result.is_null(4));
assert!(!value_field.is_null(4)); // should contain Variant::Null
assert_eq!(
Variant::new(metadata_field.value(4), value_field.value(4)),
Variant::new(
binary_array_value(metadata_field.as_ref(), 4),
binary_array_value(value_field.as_ref(), 4)
),
Variant::Null
);
assert!(typed_value_field.is_null(4));
Expand Down Expand Up @@ -1178,7 +1197,10 @@ mod tests {
assert!(value.is_valid(1));
assert!(typed_value.is_null(1));
assert_eq!(
Variant::new(metadata.value(1), value.value(1)),
Variant::new(
binary_array_value(metadata.as_ref(), 1),
binary_array_value(value.as_ref(), 1)
),
Variant::from(42i64)
);

Expand All @@ -1192,7 +1214,10 @@ mod tests {
assert!(value.is_valid(3));
assert!(typed_value.is_null(3));
assert_eq!(
Variant::new(metadata.value(3), value.value(3)),
Variant::new(
binary_array_value(metadata.as_ref(), 3),
binary_array_value(value.as_ref(), 3)
),
Variant::Null
);

Expand Down Expand Up @@ -1234,7 +1259,10 @@ mod tests {
assert!(value.is_valid(1));
assert!(typed_value.is_null(1));
assert_eq!(
Variant::new(metadata.value(1), value.value(1)),
Variant::new(
binary_array_value(metadata.as_ref(), 1),
binary_array_value(value.as_ref(), 1)
),
Variant::from("not_binary")
);

Expand All @@ -1248,7 +1276,10 @@ mod tests {
assert!(value.is_valid(3));
assert!(typed_value.is_null(3));
assert_eq!(
Variant::new(metadata.value(3), value.value(3)),
Variant::new(
binary_array_value(metadata.as_ref(), 3),
binary_array_value(value.as_ref(), 3)
),
Variant::Null
);

Expand Down Expand Up @@ -1541,14 +1572,14 @@ mod tests {
.unwrap();
let outer_fallbacks = outer_elements.value_field().unwrap();

let outer_metadata = BinaryViewArray::from_iter_values(std::iter::repeat_n(
let outer_metadata = Arc::new(BinaryViewArray::from_iter_values(std::iter::repeat_n(
EMPTY_VARIANT_METADATA_BYTES,
outer_elements.len(),
));
))) as ArrayRef;
let outer_variant = VariantArray::from_parts(
outer_metadata,
Some(outer_fallbacks.clone()),
Some(Arc::new(outer_values.clone())),
Some(Arc::new(outer_values.clone()) as ArrayRef),
None,
);

Expand Down Expand Up @@ -1651,7 +1682,10 @@ mod tests {
// null is stored as Variant::Null in values
assert!(id_values.is_valid(1));
assert_eq!(
Variant::new(EMPTY_VARIANT_METADATA_BYTES, id_values.value(1)),
Variant::new(
EMPTY_VARIANT_METADATA_BYTES,
binary_array_value(id_values.as_ref(), 1)
),
Variant::Null
);
assert!(id_typed_values.is_null(1));
Expand Down Expand Up @@ -1725,7 +1759,6 @@ mod tests {
assert_eq!(result.len(), 9);

let metadata = result.metadata_field();

let value = result.value_field().unwrap();
let typed_value = result
.typed_value_field()
Expand All @@ -1741,24 +1774,14 @@ mod tests {
let age_field =
ShreddedVariantFieldArray::try_new(typed_value.column_by_name("age").unwrap()).unwrap();

let score_value = score_field
.value_field()
.unwrap()
.as_any()
.downcast_ref::<BinaryViewArray>()
.unwrap();
let score_value = score_field.value_field().unwrap();
let score_typed_value = score_field
.typed_value_field()
.unwrap()
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
let age_value = age_field
.value_field()
.unwrap()
.as_any()
.downcast_ref::<BinaryViewArray>()
.unwrap();
let age_value = age_field.value_field().unwrap();
let age_typed_value = age_field
.typed_value_field()
.unwrap()
Expand All @@ -1777,10 +1800,13 @@ mod tests {
}
fn get_value<'m, 'v>(
i: usize,
metadata: &'m BinaryViewArray,
value: &'v BinaryViewArray,
metadata: &'m dyn Array,
value: &'v dyn Array,
) -> Variant<'m, 'v> {
Variant::new(metadata.value(i), value.value(i))
Variant::new(
binary_array_value(metadata, i),
binary_array_value(value, i),
)
}
let expect = |i, expected_result: Option<ShreddedValue<ShreddedStruct>>| {
match expected_result {
Expand All @@ -1792,7 +1818,10 @@ mod tests {
match expected_value {
Some(expected_value) => {
assert!(value.is_valid(i));
assert_eq!(expected_value, get_value(i, metadata, value));
assert_eq!(
expected_value,
get_value(i, metadata.as_ref(), value.as_ref())
);
}
None => {
assert!(value.is_null(i));
Expand All @@ -1811,7 +1840,7 @@ mod tests {
assert!(score_value.is_valid(i));
assert_eq!(
expected_score_value,
get_value(i, metadata, score_value)
get_value(i, metadata.as_ref(), score_value.as_ref())
);
}
None => {
Expand All @@ -1832,7 +1861,7 @@ mod tests {
assert!(age_value.is_valid(i));
assert_eq!(
expected_age_value,
get_value(i, metadata, age_value)
get_value(i, metadata.as_ref(), age_value.as_ref())
);
}
None => {
Expand Down Expand Up @@ -1973,7 +2002,7 @@ mod tests {
// Helper to correctly create a variant object using a row's existing metadata
let object_with_foo_field = |i| {
use parquet_variant::{ParentState, ValueBuilder, VariantMetadata};
let metadata = VariantMetadata::new(metadata.value(i));
let metadata = VariantMetadata::new(binary_array_value(metadata.as_ref(), i));
let mut metadata_builder = ReadOnlyMetadataBuilder::new(&metadata);
let mut value_builder = ValueBuilder::new();
let state = ParentState::variant(&mut value_builder, &mut metadata_builder);
Expand Down Expand Up @@ -2072,7 +2101,10 @@ mod tests {
assert!(value_field.is_null(2));
assert!(value_field.is_valid(3));
assert_eq!(
Variant::new(result.metadata_field().value(3), value_field.value(3)),
Variant::new(
binary_array_value(result.metadata_field().as_ref(), 3),
binary_array_value(value_field.as_ref(), 3)
),
Variant::from("not an object")
);
assert!(value_field.is_null(4));
Expand All @@ -2090,10 +2122,10 @@ mod tests {
.unwrap();
assert_list_structure_and_elements::<Int64Type, i32>(
&VariantArray::from_parts(
BinaryViewArray::from_iter_values(std::iter::repeat_n(
Arc::new(BinaryViewArray::from_iter_values(std::iter::repeat_n(
EMPTY_VARIANT_METADATA_BYTES,
scores_field.len(),
)),
))) as ArrayRef,
Some(scores_field.value_field().unwrap().clone()),
Some(scores_field.typed_value_field().unwrap().clone()),
None,
Expand Down Expand Up @@ -2215,24 +2247,14 @@ mod tests {
ShreddedVariantFieldArray::try_new(typed_value.column_by_name("session_id").unwrap())
.unwrap();

let id_value = id_field
.value_field()
.unwrap()
.as_any()
.downcast_ref::<BinaryViewArray>()
.unwrap();
let id_value = id_field.value_field().unwrap();
let id_typed_value = id_field
.typed_value_field()
.unwrap()
.as_any()
.downcast_ref::<FixedSizeBinaryArray>()
.unwrap();
let session_id_value = session_id_field
.value_field()
.unwrap()
.as_any()
.downcast_ref::<BinaryViewArray>()
.unwrap();
let session_id_value = session_id_field.value_field().unwrap();
let session_id_typed_value = session_id_field
.typed_value_field()
.unwrap()
Expand Down Expand Up @@ -2269,7 +2291,10 @@ mod tests {
assert_eq!(session_id_typed_value.value(1), mock_uuid_3.as_bytes());

// Verify the value field contains the name field
let row_1_variant = Variant::new(metadata.value(1), value.value(1));
let row_1_variant = Variant::new(
binary_array_value(metadata.as_ref(), 1),
binary_array_value(value.as_ref(), 1),
);
let Variant::Object(obj) = row_1_variant else {
panic!("Expected object");
};
Expand Down Expand Up @@ -2301,7 +2326,10 @@ mod tests {

assert!(session_id_value.is_valid(3)); // type mismatch, stored in value
assert!(session_id_typed_value.is_null(3));
let session_id_variant = Variant::new(metadata.value(3), session_id_value.value(3));
let session_id_variant = Variant::new(
binary_array_value(metadata.as_ref(), 3),
binary_array_value(session_id_value.as_ref(), 3),
);
assert_eq!(session_id_variant, Variant::from("not-a-uuid"));

// Row 4: Type mismatch - id is int64, not UUID
Expand All @@ -2312,7 +2340,10 @@ mod tests {

assert!(id_value.is_valid(4)); // type mismatch, stored in value
assert!(id_typed_value.is_null(4));
let id_variant = Variant::new(metadata.value(4), id_value.value(4));
let id_variant = Variant::new(
binary_array_value(metadata.as_ref(), 4),
binary_array_value(id_value.as_ref(), 4),
);
assert_eq!(id_variant, Variant::from(12345i64));

assert!(session_id_value.is_null(4));
Expand Down
Loading
Loading