fix(Copy From): fix incorrect type casts (#3264)
* refactor: refactor RecordBatchStreamTypeAdapter

* fix(Copy From): fix incorrect type casts

* fix: unit tests

* chore: add comment
WenyXu committed Jan 30, 2024
1 parent a079955 commit ddbd0ab
Showing 6 changed files with 144 additions and 135 deletions.
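In short: the old per-format adapters looked up both the column and its target field with indices into the full output schema while assembling batches against a projected schema, which could pair a column with the wrong target field and produce the incorrect casts this PR fixes. The shared `RecordBatchStreamTypeAdapter` introduced here projects each incoming batch first and then casts column-by-column against the projected schema. A minimal standalone sketch of that per-batch step (the helper name and setup are illustrative, not code from this commit):

```rust
use datafusion::arrow::array::ArrayRef;
use datafusion::arrow::compute::cast;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::arrow::error::ArrowError;
use datafusion::arrow::record_batch::RecordBatch;

/// Hypothetical helper mirroring the adapter's per-batch work:
/// `projection` selects columns out of `batch`, and `projected_schema`
/// gives the desired type of each selected column, in the same order.
fn project_and_cast(
    batch: &RecordBatch,
    projection: &[usize],
    projected_schema: SchemaRef,
) -> Result<RecordBatch, ArrowError> {
    // Project first, so column i always lines up with field i below.
    let projected = batch.project(projection)?;
    let columns = projected
        .columns()
        .iter()
        .zip(projected_schema.fields().iter())
        .map(|(column, field)| {
            if column.data_type() != field.data_type() {
                cast(column, field.data_type())
            } else {
                Ok(column.clone())
            }
        })
        .collect::<Result<Vec<ArrayRef>, ArrowError>>()?;
    RecordBatch::try_new(projected_schema, columns)
}
```

The adapter in this commit does the same work inside `poll_next`, generically over any fallible `RecordBatch` stream.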
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions src/common/datasource/Cargo.toml
@@ -19,6 +19,7 @@ async-trait.workspace = true
bytes.workspace = true
common-error.workspace = true
common-macro.workspace = true
common-recordbatch.workspace = true
common-runtime.workspace = true
datafusion.workspace = true
datatypes.workspace = true
88 changes: 13 additions & 75 deletions src/common/datasource/src/file_format/orc.rs
@@ -12,18 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use arrow::compute::cast;
use arrow_schema::{ArrowError, Schema, SchemaRef};
use async_trait::async_trait;
use datafusion::arrow::record_batch::RecordBatch as DfRecordBatch;
use common_recordbatch::adapter::RecordBatchStreamTypeAdapter;
use datafusion::datasource::physical_plan::{FileMeta, FileOpenFuture, FileOpener};
use datafusion::error::{DataFusionError, Result as DfResult};
use datafusion::physical_plan::RecordBatchStream;
use futures::{Stream, StreamExt, TryStreamExt};
use futures::{StreamExt, TryStreamExt};
use object_store::ObjectStore;
use orc_rust::arrow_reader::{create_arrow_schema, Cursor};
use orc_rust::async_arrow_reader::ArrowStreamReader;
@@ -61,73 +57,6 @@ pub async fn infer_orc_schema<R: AsyncRead + AsyncSeek + Unpin + Send + 'static>
Ok(create_arrow_schema(&cursor))
}

pub struct OrcArrowStreamReaderAdapter<T: AsyncRead + AsyncSeek + Unpin + Send + 'static> {
output_schema: SchemaRef,
projection: Vec<usize>,
stream: ArrowStreamReader<T>,
}

impl<T: AsyncRead + AsyncSeek + Unpin + Send + 'static> OrcArrowStreamReaderAdapter<T> {
pub fn new(
output_schema: SchemaRef,
stream: ArrowStreamReader<T>,
projection: Option<Vec<usize>>,
) -> Self {
let projection = if let Some(projection) = projection {
projection
} else {
(0..output_schema.fields().len()).collect()
};

Self {
output_schema,
projection,
stream,
}
}
}

impl<T: AsyncRead + AsyncSeek + Unpin + Send + 'static> RecordBatchStream
for OrcArrowStreamReaderAdapter<T>
{
fn schema(&self) -> SchemaRef {
self.output_schema.clone()
}
}

impl<T: AsyncRead + AsyncSeek + Unpin + Send + 'static> Stream for OrcArrowStreamReaderAdapter<T> {
type Item = DfResult<DfRecordBatch>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let batch = futures::ready!(Pin::new(&mut self.stream).poll_next(cx))
.map(|r| r.map_err(|e| DataFusionError::External(Box::new(e))));

let projected_schema = self.output_schema.project(&self.projection)?;
let batch = batch.map(|b| {
b.and_then(|b| {
let mut columns = Vec::with_capacity(self.projection.len());
for idx in self.projection.iter() {
let column = b.column(*idx);
let field = self.output_schema.field(*idx);

if column.data_type() != field.data_type() {
let output = cast(&column, field.data_type())?;
columns.push(output)
} else {
columns.push(column.clone())
}
}

let record_batch = DfRecordBatch::try_new(projected_schema.into(), columns)?;

Ok(record_batch)
})
});

Poll::Ready(batch)
}
}

#[async_trait]
impl FileFormat for OrcFormat {
async fn infer_schema(&self, store: &ObjectStore, path: &str) -> Result<Schema> {
@@ -166,7 +95,15 @@ impl OrcOpener {
impl FileOpener for OrcOpener {
fn open(&self, meta: FileMeta) -> DfResult<FileOpenFuture> {
let object_store = self.object_store.clone();
let output_schema = self.output_schema.clone();
let projected_schema = if let Some(projection) = &self.projection {
let projected_schema = self
.output_schema
.project(projection)
.map_err(|e| DataFusionError::External(Box::new(e)))?;
Arc::new(projected_schema)
} else {
self.output_schema.clone()
};
let projection = self.projection.clone();
Ok(Box::pin(async move {
let reader = object_store
@@ -178,7 +115,8 @@ impl FileOpener for OrcOpener {
.await
.map_err(|e| DataFusionError::External(Box::new(e)))?;

let stream = OrcArrowStreamReaderAdapter::new(output_schema, stream_reader, projection);
let stream =
RecordBatchStreamTypeAdapter::new(projected_schema, stream_reader, projection);

let adopted = stream.map_err(|e| ArrowError::ExternalError(Box::new(e)));
Ok(adopted.boxed())
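With the adapter now expecting an already-projected schema, `OrcOpener::open` resolves it eagerly from the projection indices via `Schema::project`. A small illustration of what that projection yields (the field names are made up):

```rust
use std::sync::Arc;

use datafusion::arrow::datatypes::{DataType, Field, Schema};

fn main() {
    // Made-up three-column file schema.
    let file_schema = Schema::new(vec![
        Field::new("ts", DataType::Int64, false),
        Field::new("host", DataType::Utf8, true),
        Field::new("cpu", DataType::Float64, true),
    ]);

    // Projecting indices [2, 0] keeps exactly those fields, in that
    // order; the opener wraps the result in an Arc and hands it to
    // the adapter as the target schema.
    let projected = Arc::new(file_schema.project(&[2, 0]).unwrap());
    assert_eq!(projected.field(0).name(), "cpu");
    assert_eq!(projected.field(1).name(), "ts");
}
```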
1 change: 1 addition & 0 deletions src/common/recordbatch/Cargo.toml
@@ -14,6 +14,7 @@ datafusion.workspace = true
datatypes.workspace = true
futures.workspace = true
paste = "1.0"
pin-project.workspace = true
serde.workspace = true
serde_json.workspace = true
snafu.workspace = true
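The new `pin-project` dependency is what lets the generic adapter below poll its inner stream field without demanding `T: Unpin`. A minimal sketch of the pattern (illustrative names, not code from this commit):

```rust
use std::pin::Pin;
use std::task::{Context, Poll};

use futures::Stream;
use pin_project::pin_project;

// Illustrative wrapper: `#[pin]` marks `inner` as structurally pinned,
// so `project()` hands back a `Pin<&mut S>` that can be polled even
// when `S` itself is not `Unpin`.
#[pin_project]
struct Wrapper<S> {
    #[pin]
    inner: S,
}

impl<S: Stream> Stream for Wrapper<S> {
    type Item = S::Item;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        this.inner.poll_next(cx)
    }
}
```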
74 changes: 47 additions & 27 deletions src/common/recordbatch/src/adapter.rs
@@ -13,19 +13,21 @@
// limitations under the License.

use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use datafusion::arrow::compute::cast;
use datafusion::arrow::datatypes::SchemaRef as DfSchemaRef;
use datafusion::error::Result as DfResult;
use datafusion::parquet::arrow::async_reader::{AsyncFileReader, ParquetRecordBatchStream};
use datafusion::physical_plan::metrics::{BaselineMetrics, MetricValue};
use datafusion::physical_plan::{ExecutionPlan, RecordBatchStream as DfRecordBatchStream};
use datafusion_common::arrow::error::ArrowError;
use datafusion_common::DataFusionError;
use datatypes::schema::{Schema, SchemaRef};
use futures::ready;
use pin_project::pin_project;
use snafu::ResultExt;

use crate::error::{self, Result};
@@ -37,66 +39,84 @@ use crate::{
type FutureStream =
Pin<Box<dyn std::future::Future<Output = Result<SendableRecordBatchStream>> + Send>>;

/// ParquetRecordBatchStream -> DataFusion RecordBatchStream
pub struct ParquetRecordBatchStreamAdapter<T> {
stream: ParquetRecordBatchStream<T>,
output_schema: DfSchemaRef,
/// Casts the `RecordBatch`es of `stream` so they match the `projected_schema`.
#[pin_project]
pub struct RecordBatchStreamTypeAdapter<T, E> {
#[pin]
stream: T,
projected_schema: DfSchemaRef,
projection: Vec<usize>,
phantom: PhantomData<E>,
}

impl<T: Unpin + AsyncFileReader + Send + 'static> ParquetRecordBatchStreamAdapter<T> {
pub fn new(
output_schema: DfSchemaRef,
stream: ParquetRecordBatchStream<T>,
projection: Option<Vec<usize>>,
) -> Self {
impl<T, E> RecordBatchStreamTypeAdapter<T, E>
where
T: Stream<Item = std::result::Result<DfRecordBatch, E>>,
E: std::error::Error + Send + Sync + 'static,
{
pub fn new(projected_schema: DfSchemaRef, stream: T, projection: Option<Vec<usize>>) -> Self {
let projection = if let Some(projection) = projection {
projection
} else {
(0..output_schema.fields().len()).collect()
(0..projected_schema.fields().len()).collect()
};

Self {
stream,
output_schema,
projected_schema,
projection,
phantom: Default::default(),
}
}
}

impl<T: Unpin + AsyncFileReader + Send + 'static> DfRecordBatchStream
for ParquetRecordBatchStreamAdapter<T>
impl<T, E> DfRecordBatchStream for RecordBatchStreamTypeAdapter<T, E>
where
T: Stream<Item = std::result::Result<DfRecordBatch, E>>,
E: std::error::Error + Send + Sync + 'static,
{
fn schema(&self) -> DfSchemaRef {
self.stream.schema().clone()
self.projected_schema.clone()
}
}

impl<T: Unpin + AsyncFileReader + Send + 'static> Stream for ParquetRecordBatchStreamAdapter<T> {
impl<T, E> Stream for RecordBatchStreamTypeAdapter<T, E>
where
T: Stream<Item = std::result::Result<DfRecordBatch, E>>,
E: std::error::Error + Send + Sync + 'static,
{
type Item = DfResult<DfRecordBatch>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let batch = futures::ready!(Pin::new(&mut self.stream).poll_next(cx))
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.project();

let batch = futures::ready!(this.stream.poll_next(cx))
.map(|r| r.map_err(|e| DataFusionError::External(Box::new(e))));

let projected_schema = self.output_schema.project(&self.projection)?;
let projected_schema = this.projected_schema.clone();
let projection = this.projection.clone();
let batch = batch.map(|b| {
b.and_then(|b| {
let mut columns = Vec::with_capacity(self.projection.len());
for idx in self.projection.iter() {
let column = b.column(*idx);
let field = self.output_schema.field(*idx);
let projected_column = b.project(&projection)?;
if projected_column.schema().fields.len() != projected_schema.fields.len() {
return Err(DataFusionError::ArrowError(ArrowError::SchemaError(format!(
"Trying to cast a RecordBatch into an incompatible schema. RecordBatch: {}, Target: {}",
projected_column.schema(),
projected_schema,
))));
}

let mut columns = Vec::with_capacity(projected_schema.fields.len());
for (idx, field) in projected_schema.fields.iter().enumerate() {
let column = projected_column.column(idx);
if column.data_type() != field.data_type() {
let output = cast(&column, field.data_type())?;
columns.push(output)
} else {
columns.push(column.clone())
}
}

let record_batch = DfRecordBatch::try_new(projected_schema.into(), columns)?;

let record_batch = DfRecordBatch::try_new(projected_schema, columns)?;
Ok(record_batch)
})
});
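To see the refactored adapter end to end, here is a hedged usage sketch (it assumes the crate paths shown in this diff and invents the input stream and schemas): wrap any fallible `RecordBatch` stream together with the projected target schema, and the adapter casts each batch as it is polled.

```rust
use std::sync::Arc;

use common_recordbatch::adapter::RecordBatchStreamTypeAdapter;
use datafusion::arrow::array::{ArrayRef, Int32Array};
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use futures::stream::{self, StreamExt};

async fn demo() -> Result<(), Box<dyn std::error::Error>> {
    // Source batch: one Int32 column named "v".
    let input_schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, true)]));
    let batch = RecordBatch::try_new(
        input_schema,
        vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef],
    )?;
    let input = stream::iter(vec![Ok::<_, std::io::Error>(batch)]);

    // Target (projected) schema declares "v" as Int64; the adapter
    // casts each incoming batch to match it.
    let target = Arc::new(Schema::new(vec![Field::new("v", DataType::Int64, true)]));
    let mut adapted = RecordBatchStreamTypeAdapter::new(target.clone(), input, None);

    while let Some(batch) = adapted.next().await {
        assert_eq!(batch?.schema(), target);
    }
    Ok(())
}
```

Passing `None` for the projection, as here, falls back to all columns of the projected schema, matching the `(0..fields.len())` default in `new`.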
