test: unit tests for ParquetSink, with and without partitioned file writes
wiedld committed Mar 11, 2024
1 parent bc2add7 commit 1cdea3e
Showing 1 changed file with 186 additions and 1 deletion.
187 changes: 186 additions & 1 deletion datafusion/core/src/datasource/file_format/parquet.rs
@@ -1102,6 +1102,7 @@ pub(crate) mod test_util {
#[cfg(test)]
mod tests {
    use super::super::test_util::scan_format;
    use crate::datasource::listing::{ListingTableUrl, PartitionedFile};
    use crate::physical_plan::collect;
    use std::fmt::{Display, Formatter};
    use std::sync::atomic::{AtomicUsize, Ordering};
@@ -1113,13 +1114,19 @@ mod tests {
    use crate::prelude::{SessionConfig, SessionContext};
    use arrow::array::{Array, ArrayRef, StringArray};
    use arrow::record_batch::RecordBatch;
    use arrow_schema::Field;
    use async_trait::async_trait;
    use bytes::Bytes;
    use datafusion_common::cast::{
        as_binary_array, as_boolean_array, as_float32_array, as_float64_array,
        as_int32_array, as_timestamp_nanosecond_array,
    };
    use datafusion_common::ScalarValue;
    use datafusion_common::config::ParquetOptions;
    use datafusion_common::file_options::parquet_writer::ParquetWriterOptions;
    use datafusion_common::{FileTypeWriterOptions, ScalarValue};
    use datafusion_execution::object_store::ObjectStoreUrl;
    use datafusion_execution::runtime_env::RuntimeEnv;
    use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
    use futures::stream::BoxStream;
    use futures::StreamExt;
    use log::error;
@@ -1814,4 +1821,182 @@ mod tests {
        let format = ParquetFormat::default();
        scan_format(state, &format, &testdata, file_name, projection, limit).await
    }

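    /// Builds a `TaskContext` whose runtime has a local filesystem object
    /// store (backed by a temporary directory) registered under `store_url`,
    /// with parquet single-file parallelism enabled, for the `ParquetSink`
    /// tests below.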
    fn build_ctx(store_url: &url::Url) -> Arc<TaskContext> {
        let tmp_dir = tempfile::TempDir::new().unwrap();
        let local = Arc::new(
            LocalFileSystem::new_with_prefix(&tmp_dir)
                .expect("should create object store"),
        );

        let mut session = SessionConfig::default();
        let mut parquet_opts = ParquetOptions::default();
        parquet_opts.allow_single_file_parallelism = true;
        session.options_mut().execution.parquet = parquet_opts;

        let runtime = RuntimeEnv::default();
        runtime
            .object_store_registry
            .register_store(store_url, local);

        Arc::new(
            TaskContext::default()
                .with_session_config(session)
                .with_runtime(Arc::new(runtime)),
        )
    }

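    // Writes a single record batch through the ParquetSink with no table
    // partitioning, and checks that exactly one file is produced at the table
    // root containing both columns and both rows.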
    #[tokio::test]
    async fn parquet_sink_write() -> Result<()> {
        let field_a = Field::new("a", DataType::Utf8, false);
        let field_b = Field::new("b", DataType::Utf8, false);
        let schema = Arc::new(Schema::new(vec![field_a, field_b]));
        let object_store_url = ObjectStoreUrl::local_filesystem();

        let file_sink_config = FileSinkConfig {
            object_store_url: object_store_url.clone(),
            file_groups: vec![PartitionedFile::new("/tmp".to_string(), 1)],
            table_paths: vec![ListingTableUrl::parse("file:///")?],
            output_schema: schema.clone(),
            table_partition_cols: vec![],
            overwrite: true,
            file_type_writer_options: FileTypeWriterOptions::Parquet(
                ParquetWriterOptions::new(WriterProperties::default()),
            ),
        };
        let parquet_sink = Arc::new(ParquetSink::new(file_sink_config));

        // create data
        let col_a: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar"]));
        let col_b: ArrayRef = Arc::new(StringArray::from(vec!["baz", "baz"]));
        let batch = RecordBatch::try_from_iter(vec![("a", col_a), ("b", col_b)]).unwrap();

        // write stream
        parquet_sink
            .write_all(
                Box::pin(RecordBatchStreamAdapter::new(
                    schema,
                    futures::stream::iter(vec![Ok(batch)]),
                )),
                &build_ctx(object_store_url.as_ref()),
            )
            .await
            .unwrap();

        // assert written
        let mut written = parquet_sink.written();
        let written = written.drain();
        assert_eq!(
            written.len(),
            1,
            "expected a single parquet file to be written, instead found {}",
            written.len()
        );

        // check the file metadata
        for (
            path,
            FileMetaData {
                num_rows, schema, ..
            },
        ) in written.take(1)
        {
            let path_parts = path.parts().collect::<Vec<_>>();
            assert_eq!(path_parts.len(), 1, "should not have path prefix");

            assert_eq!(num_rows, 2, "file metadata should have 2 rows");
            assert!(
                schema.iter().any(|col_schema| col_schema.name == "a"),
                "output file metadata should contain col a"
            );
            assert!(
                schema.iter().any(|col_schema| col_schema.name == "b"),
                "output file metadata should contain col b"
            );
        }

        Ok(())
    }

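    // Same write as above, but with hive-style partitioning on column `a`:
    // each distinct value of `a` should produce its own file under an
    // `a=<value>/` path prefix, with the partition column dropped from the
    // file contents.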
    #[tokio::test]
    async fn parquet_sink_write_partitions() -> Result<()> {
        let field_a = Field::new("a", DataType::Utf8, false);
        let field_b = Field::new("b", DataType::Utf8, false);
        let schema = Arc::new(Schema::new(vec![field_a, field_b]));
        let object_store_url = ObjectStoreUrl::local_filesystem();

        // set file config to include partitioning on field_a
        let file_sink_config = FileSinkConfig {
            object_store_url: object_store_url.clone(),
            file_groups: vec![PartitionedFile::new("/tmp".to_string(), 1)],
            table_paths: vec![ListingTableUrl::parse("file:///")?],
            output_schema: schema.clone(),
            table_partition_cols: vec![("a".to_string(), DataType::Utf8)], // add partitioning
            overwrite: true,
            file_type_writer_options: FileTypeWriterOptions::Parquet(
                ParquetWriterOptions::new(WriterProperties::default()),
            ),
        };
        let parquet_sink = Arc::new(ParquetSink::new(file_sink_config));

        // create data with 2 partitions
        let col_a: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar"]));
        let col_b: ArrayRef = Arc::new(StringArray::from(vec!["baz", "baz"]));
        let batch = RecordBatch::try_from_iter(vec![("a", col_a), ("b", col_b)]).unwrap();

        // write stream
        parquet_sink
            .write_all(
                Box::pin(RecordBatchStreamAdapter::new(
                    schema,
                    futures::stream::iter(vec![Ok(batch)]),
                )),
                &build_ctx(object_store_url.as_ref()),
            )
            .await
            .unwrap();

        // assert written
        let mut written = parquet_sink.written();
        let written = written.drain();
        assert_eq!(
            written.len(),
            2,
            "expected two parquet files to be written, instead found {}",
            written.len()
        );

        // check the file metadata includes partitions
        let mut expected_partitions = std::collections::HashSet::from(["a=foo", "a=bar"]);
        for (
            path,
            FileMetaData {
                num_rows, schema, ..
            },
        ) in written.take(2)
        {
            let path_parts = path.parts().collect::<Vec<_>>();
            assert_eq!(path_parts.len(), 2, "should have path prefix");

            let prefix = path_parts[0].as_ref();
            assert!(
                expected_partitions.contains(prefix),
                "expected path prefix to match partition, instead found {:?}",
                prefix
            );
            expected_partitions.remove(prefix);

            assert_eq!(num_rows, 1, "file metadata should have 1 row");
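            // the partition column is encoded in the file path (e.g. `a=foo/`),
            // so it should not appear in the written file's schema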
            assert!(
                !schema.iter().any(|col_schema| col_schema.name == "a"),
                "output file metadata should not contain partitioned col a"
            );
            assert!(
                schema.iter().any(|col_schema| col_schema.name == "b"),
                "output file metadata should contain col b"
            );
        }

        Ok(())
    }
}
