diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs index a75354cf9b3..0775392b7d6 100644 --- a/arrow-cast/src/cast.rs +++ b/arrow-cast/src/cast.rs @@ -160,6 +160,12 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (Decimal128(_, _) | Decimal256(_, _), Utf8 | LargeUtf8) => true, // Utf8 to decimal (Utf8 | LargeUtf8, Decimal128(_, _) | Decimal256(_, _)) => true, + (Struct(from_fields), Struct(to_fields)) => { + from_fields.len() == to_fields.len() && + from_fields.iter().zip(to_fields.iter()).all(|(f1, f2)| { + can_cast_types(f1.data_type(), f2.data_type()) + }) + } (Struct(_), _) => false, (_, Struct(_)) => false, (_, Boolean) => { @@ -1138,11 +1144,22 @@ pub fn cast_with_options( ))), } } + (Struct(_), Struct(to_fields)) => { + let array = array.as_struct(); + let fields = array + .columns() + .iter() + .zip(to_fields.iter()) + .map(|(l, field)| cast_with_options(l, field.data_type(), cast_options)) + .collect::, ArrowError>>()?; + let array = StructArray::new(to_fields.clone(), fields, array.nulls().cloned()); + Ok(Arc::new(array) as ArrayRef) + } (Struct(_), _) => Err(ArrowError::CastError( - "Cannot cast from struct to other types".to_string(), + "Cannot cast from struct to other types except struct".to_string(), )), (_, Struct(_)) => Err(ArrowError::CastError( - "Cannot cast to struct from other types".to_string(), + "Cannot cast to struct from other types except struct".to_string(), )), (_, Boolean) => match from_type { UInt8 => cast_numeric_to_bool::(array), @@ -9447,4 +9464,65 @@ mod tests { ); } } + #[test] + fn test_cast_struct_to_struct() { + let struct_type = DataType::Struct( + vec![ + Field::new("a", DataType::Boolean, false), + Field::new("b", DataType::Int32, false), + ] + .into(), + ); + let to_type = DataType::Struct( + vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + ] + .into(), + ); + let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true])); + let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31])); + let struct_array = StructArray::from(vec![ + ( + Arc::new(Field::new("b", DataType::Boolean, false)), + boolean.clone() as ArrayRef, + ), + ( + Arc::new(Field::new("c", DataType::Int32, false)), + int.clone() as ArrayRef, + ), + ]); + let casted_array = cast(&struct_array, &to_type).unwrap(); + let casted_array = casted_array.as_struct(); + assert_eq!(casted_array.data_type(), &to_type); + let casted_boolean_array = casted_array + .column(0) + .as_string::() + .into_iter() + .flatten() + .collect::>(); + let casted_int_array = casted_array + .column(1) + .as_string::() + .into_iter() + .flatten() + .collect::>(); + assert_eq!(casted_boolean_array, vec!["false", "false", "true", "true"]); + assert_eq!(casted_int_array, vec!["42", "28", "19", "31"]); + + // test for can't cast + let to_type = DataType::Struct( + vec![ + Field::new("a", DataType::Date32, false), + Field::new("b", DataType::Utf8, false), + ] + .into(), + ); + assert!(!can_cast_types(&struct_type, &to_type)); + let result = cast(&struct_array, &to_type); + assert_eq!( + "Cast error: Casting from Boolean to Date32 not supported", + result.unwrap_err().to_string() + ); + } }