Skip to content

Commit

Permalink
Clean the code in field.rs and add more tests (#2345)
Browse files Browse the repository at this point in the history
* clean up the field

Signed-off-by: remzi <13716567376yh@gmail.com>

* test to check same field

Signed-off-by: remzi <13716567376yh@gmail.com>

* fix nit

Signed-off-by: remzi <13716567376yh@gmail.com>

* fix fmt

Signed-off-by: remzi <13716567376yh@gmail.com>
  • Loading branch information
HaoYang670 committed Aug 10, 2022
1 parent 195d9c5 commit d4ad4b7
Showing 1 changed file with 108 additions and 70 deletions.
178 changes: 108 additions & 70 deletions arrow/src/datatypes/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,23 +209,17 @@ impl Field {
}

fn _fields<'a>(&'a self, dt: &'a DataType) -> Vec<&Field> {
let mut collected_fields = vec![];

match dt {
DataType::Struct(fields) | DataType::Union(fields, _, _) => {
collected_fields.extend(fields.iter().flat_map(|f| f.fields()))
fields.iter().flat_map(|f| f.fields()).collect()
}
DataType::List(field)
| DataType::LargeList(field)
| DataType::FixedSizeList(field, _)
| DataType::Map(field, _) => collected_fields.extend(field.fields()),
DataType::Dictionary(_, value_field) => {
collected_fields.append(&mut self._fields(value_field.as_ref()))
}
_ => (),
| DataType::Map(field, _) => field.fields(),
DataType::Dictionary(_, value_field) => self._fields(value_field.as_ref()),
_ => vec![],
}

collected_fields
}

/// Returns a vector containing all (potentially nested) `Field` instances selected by the
Expand Down Expand Up @@ -506,12 +500,10 @@ impl Field {
pub fn to_json(&self) -> Value {
let children: Vec<Value> = match self.data_type() {
DataType::Struct(fields) => fields.iter().map(|f| f.to_json()).collect(),
DataType::List(field) => vec![field.to_json()],
DataType::LargeList(field) => vec![field.to_json()],
DataType::FixedSizeList(field, _) => vec![field.to_json()],
DataType::Map(field, _) => {
vec![field.to_json()]
}
DataType::List(field)
| DataType::LargeList(field)
| DataType::FixedSizeList(field, _)
| DataType::Map(field, _) => vec![field.to_json()],
_ => vec![],
};
match self.data_type() {
Expand Down Expand Up @@ -550,6 +542,17 @@ impl Field {
/// assert!(field.is_nullable());
/// ```
pub fn try_merge(&mut self, from: &Field) -> Result<()> {
if from.dict_id != self.dict_id {
return Err(ArrowError::SchemaError(
"Fail to merge schema Field due to conflicting dict_id".to_string(),
));
}
if from.dict_is_ordered != self.dict_is_ordered {
return Err(ArrowError::SchemaError(
"Fail to merge schema Field due to conflicting dict_is_ordered"
.to_string(),
));
}
// merge metadata
match (self.metadata(), from.metadata()) {
(Some(self_metadata), Some(from_metadata)) => {
Expand All @@ -572,31 +575,16 @@ impl Field {
}
_ => {}
}
if from.dict_id != self.dict_id {
return Err(ArrowError::SchemaError(
"Fail to merge schema Field due to conflicting dict_id".to_string(),
));
}
if from.dict_is_ordered != self.dict_is_ordered {
return Err(ArrowError::SchemaError(
"Fail to merge schema Field due to conflicting dict_is_ordered"
.to_string(),
));
}
match &mut self.data_type {
DataType::Struct(nested_fields) => match &from.data_type {
DataType::Struct(from_nested_fields) => {
for from_field in from_nested_fields {
let mut is_new_field = true;
for self_field in nested_fields.iter_mut() {
if self_field.name != from_field.name {
continue;
}
is_new_field = false;
self_field.try_merge(from_field)?;
}
if is_new_field {
nested_fields.push(from_field.clone());
match nested_fields
.iter_mut()
.find(|self_field| self_field.name == from_field.name)
{
Some(self_field) => self_field.try_merge(from_field)?,
None => nested_fields.push(from_field.clone()),
}
}
}
Expand Down Expand Up @@ -685,9 +673,7 @@ impl Field {
}
}
}
if from.nullable {
self.nullable = from.nullable;
}
self.nullable |= from.nullable;

Ok(())
}
Expand All @@ -698,41 +684,25 @@ impl Field {
/// * self.metadata is a superset of other.metadata
/// * all other fields are equal
pub fn contains(&self, other: &Field) -> bool {
if self.name != other.name
|| self.data_type != other.data_type
|| self.dict_id != other.dict_id
|| self.dict_is_ordered != other.dict_is_ordered
{
return false;
}

if self.nullable != other.nullable && !self.nullable {
return false;
}

self.name == other.name
&& self.data_type == other.data_type
&& self.dict_id == other.dict_id
&& self.dict_is_ordered == other.dict_is_ordered
// self need to be nullable or both of them are not nullable
&& (self.nullable || !other.nullable)
// make sure self.metadata is a superset of other.metadata
match (&self.metadata, &other.metadata) {
(None, Some(_)) => {
return false;
}
&& match (&self.metadata, &other.metadata) {
(_, None) => true,
(None, Some(_)) => false,
(Some(self_meta), Some(other_meta)) => {
for (k, v) in other_meta.iter() {
other_meta.iter().all(|(k, v)| {
match self_meta.get(k) {
Some(s) => {
if s != v {
return false;
}
}
None => {
return false;
}
Some(s) => s == v,
None => false
}
}
})
}
_ => {}
}

true
}
}

Expand All @@ -745,7 +715,7 @@ impl std::fmt::Display for Field {

#[cfg(test)]
mod test {
use super::{DataType, Field};
use super::*;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

Expand Down Expand Up @@ -840,4 +810,72 @@ mod test {
assert_ne!(dict1, dict2);
assert_ne!(get_field_hash(&dict1), get_field_hash(&dict2));
}

#[test]
fn test_contains_reflexivity() {
let mut field = Field::new("field1", DataType::Float16, false);
field.set_metadata(Some(BTreeMap::from([
(String::from("k0"), String::from("v0")),
(String::from("k1"), String::from("v1")),
])));
assert!(field.contains(&field))
}

#[test]
fn test_contains_transitivity() {
let child_field = Field::new("child1", DataType::Float16, false);

let mut field1 = Field::new("field1", DataType::Struct(vec![child_field]), false);
field1.set_metadata(Some(BTreeMap::from([(
String::from("k1"),
String::from("v1"),
)])));

let mut field2 = Field::new("field1", DataType::Struct(vec![]), true);
field2.set_metadata(Some(BTreeMap::from([(
String::from("k2"),
String::from("v2"),
)])));
field2.try_merge(&field1).unwrap();

let mut field3 = Field::new("field1", DataType::Struct(vec![]), false);
field3.set_metadata(Some(BTreeMap::from([(
String::from("k3"),
String::from("v3"),
)])));
field3.try_merge(&field2).unwrap();

assert!(field2.contains(&field1));
assert!(field3.contains(&field2));
assert!(field3.contains(&field1));

assert!(!field1.contains(&field2));
assert!(!field1.contains(&field3));
assert!(!field2.contains(&field3));
}

#[test]
fn test_contains_nullable() {
let field1 = Field::new("field1", DataType::Boolean, true);
let field2 = Field::new("field1", DataType::Boolean, false);
assert!(field1.contains(&field2));
assert!(!field2.contains(&field1));
}

#[test]
fn test_contains_must_have_same_fields() {
let child_field1 = Field::new("child1", DataType::Float16, false);
let child_field2 = Field::new("child2", DataType::Float16, false);

let field1 =
Field::new("field1", DataType::Struct(vec![child_field1.clone()]), true);
let field2 = Field::new(
"field1",
DataType::Struct(vec![child_field1, child_field2]),
true,
);

assert!(!field1.contains(&field2));
assert!(!field2.contains(&field1));
}
}

0 comments on commit d4ad4b7

Please sign in to comment.