diff --git a/arrow/src/datatypes/schema.rs b/arrow/src/datatypes/schema.rs index 22b1ceb39a7..8c3014490d3 100644 --- a/arrow/src/datatypes/schema.rs +++ b/arrow/src/datatypes/schema.rs @@ -87,6 +87,24 @@ impl Schema { Self { fields, metadata } } + /// Returns a new schema with only the specified columns in the new schema + /// This carries metadata from the parent schema over as well + pub fn project(&self, indices: &[usize]) -> Result { + let new_fields = indices + .iter() + .map(|i| { + self.fields.get(*i).cloned().ok_or_else(|| { + ArrowError::SchemaError(format!( + "project index {} out of bounds, max field {}", + i, + self.fields().len() + )) + }) + }) + .collect::>>()?; + Ok(Self::new_with_metadata(new_fields, self.metadata.clone())) + } + /// Merge schema into self if it is compatible. Struct fields will be merged recursively. /// /// Example: @@ -369,4 +387,51 @@ mod tests { assert_eq!(schema, de_schema); } + + #[test] + fn test_projection() { + let mut metadata = HashMap::new(); + metadata.insert("meta".to_string(), "data".to_string()); + + let schema = Schema::new_with_metadata( + vec![ + Field::new("name", DataType::Utf8, false), + Field::new("address", DataType::Utf8, false), + Field::new("priority", DataType::UInt8, false), + ], + metadata, + ); + + let projected: Schema = schema.project(&[0, 2]).unwrap(); + + assert_eq!(projected.fields().len(), 2); + assert_eq!(projected.fields()[0].name(), "name"); + assert_eq!(projected.fields()[1].name(), "priority"); + assert_eq!(projected.metadata.get("meta").unwrap(), "data") + } + + #[test] + fn test_oob_projection() { + let mut metadata = HashMap::new(); + metadata.insert("meta".to_string(), "data".to_string()); + + let schema = Schema::new_with_metadata( + vec![ + Field::new("name", DataType::Utf8, false), + Field::new("address", DataType::Utf8, false), + Field::new("priority", DataType::UInt8, false), + ], + metadata, + ); + + let projected: Result = schema.project(&[0, 3]); + + assert!(projected.is_err()); + if let Err(e) = projected { + assert_eq!( + e.to_string(), + "Schema error: project index 3 out of bounds, max field 3".to_string() + ) + } + } } diff --git a/arrow/src/record_batch.rs b/arrow/src/record_batch.rs index b441f6cf295..9faba7ddce1 100644 --- a/arrow/src/record_batch.rs +++ b/arrow/src/record_batch.rs @@ -175,6 +175,25 @@ impl RecordBatch { self.schema.clone() } + /// Projects the schema onto the specified columns + pub fn project(&self, indices: &[usize]) -> Result { + let projected_schema = self.schema.project(indices)?; + let batch_fields = indices + .iter() + .map(|f| { + self.columns.get(*f).cloned().ok_or_else(|| { + ArrowError::SchemaError(format!( + "project index {} out of bounds, max field {}", + f, + self.columns.len() + )) + }) + }) + .collect::>>()?; + + RecordBatch::try_new(SchemaRef::new(projected_schema), batch_fields) + } + /// Returns the number of columns in the record batch. /// /// # Example @@ -900,4 +919,23 @@ mod tests { assert_ne!(batch1, batch2); } + + #[test] + fn project() { + let a: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])); + let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"])); + let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"])); + + let record_batch = RecordBatch::try_from_iter(vec![ + ("a", a.clone()), + ("b", b.clone()), + ("c", c.clone()), + ]) + .expect("valid conversion"); + + let expected = RecordBatch::try_from_iter(vec![("a", a), ("c", c)]) + .expect("valid conversion"); + + assert_eq!(expected, record_batch.project(&[0, 2]).unwrap()); + } }