Merged
1 change: 0 additions & 1 deletion crates/control_plane/Cargo.toml
@@ -20,7 +20,6 @@ datafusion-functions-json = { workspace = true }
datafusion-physical-plan = { workspace = true }
datafusion_iceberg = { workspace = true }

flatbuffers = { version = "24.3.25" }
futures = { workspace = true }
iceberg-rest-catalog = { workspace = true }
iceberg-rust = { workspace = true }
48 changes: 34 additions & 14 deletions crates/control_plane/src/models/mod.rs
@@ -1,5 +1,5 @@
use arrow::array::RecordBatch;
use arrow::datatypes::{DataType, Field};
use arrow::datatypes::{DataType, Field, TimeUnit};
use chrono::{NaiveDateTime, Utc};
use iceberg_rust::object_store::ObjectStoreBuilder;
use object_store::aws::AmazonS3Builder;
@@ -481,17 +481,25 @@ impl ColumnInfo {
DataType::Date32 | DataType::Date64 => {
column_info.r#type = "date".to_string();
}
DataType::Timestamp(_, _) => {
DataType::Timestamp(unit, _) => {
column_info.r#type = "timestamp_ntz".to_string();
column_info.precision = Some(0);
column_info.scale = Some(9);
let scale = match unit {
TimeUnit::Second => 0,
TimeUnit::Millisecond => 3,
TimeUnit::Microsecond => 6,
TimeUnit::Nanosecond => 9,
};
column_info.scale = Some(scale);
}
DataType::Binary => {
column_info.r#type = "binary".to_string();
column_info.byte_length = Some(8_388_608);
column_info.length = Some(8_388_608);
}
_ => {}
_ => {
column_info.r#type = "text".to_string();
}
}
column_info
}
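
Note: the new TimeUnit match simply counts the fractional-second digits each Arrow unit carries. A minimal standalone sketch of the same mapping (the timestamp_scale helper name and i32 return type are hypothetical, not part of this PR):

use arrow::datatypes::TimeUnit;

// Scale = number of fractional-second digits for the given Arrow time unit.
fn timestamp_scale(unit: &TimeUnit) -> i32 {
    match unit {
        TimeUnit::Second => 0,      // no fractional digits
        TimeUnit::Millisecond => 3, // 10^-3 s
        TimeUnit::Microsecond => 6, // 10^-6 s
        TimeUnit::Nanosecond => 9,  // 10^-9 s
    }
}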
@@ -704,23 +712,35 @@ mod tests {
assert_eq!(column_info.name, "test_field");
assert_eq!(column_info.r#type, "date");

let field = Field::new(
"test_field",
DataType::Timestamp(TimeUnit::Second, None),
false,
);
let column_info = ColumnInfo::from_field(&field);
assert_eq!(column_info.name, "test_field");
assert_eq!(column_info.r#type, "timestamp_ntz");
assert_eq!(column_info.precision.unwrap(), 0);
assert_eq!(column_info.scale.unwrap(), 9);
let units = [
(TimeUnit::Second, 0),
(TimeUnit::Millisecond, 3),
(TimeUnit::Microsecond, 6),
(TimeUnit::Nanosecond, 9),
];
for (unit, scale) in units {
let field = Field::new("test_field", DataType::Timestamp(unit, None), false);
let column_info = ColumnInfo::from_field(&field);
assert_eq!(column_info.name, "test_field");
assert_eq!(column_info.r#type, "timestamp_ntz");
assert_eq!(column_info.precision.unwrap(), 0);
assert_eq!(column_info.scale.unwrap(), scale);
}

let field = Field::new("test_field", DataType::Binary, false);
let column_info = ColumnInfo::from_field(&field);
assert_eq!(column_info.name, "test_field");
assert_eq!(column_info.r#type, "binary");
assert_eq!(column_info.byte_length.unwrap(), 8_388_608);
assert_eq!(column_info.length.unwrap(), 8_388_608);

// Any other type
let field = Field::new("test_field", DataType::Utf8View, false);
let column_info = ColumnInfo::from_field(&field);
assert_eq!(column_info.name, "test_field");
assert_eq!(column_info.r#type, "text");
assert_eq!(column_info.byte_length, None);
assert_eq!(column_info.length, None);
}

#[tokio::test]
15 changes: 12 additions & 3 deletions crates/control_plane/src/service.rs
@@ -2,7 +2,7 @@ use crate::error::{self, ControlPlaneError, ControlPlaneResult};
use crate::models::{ColumnInfo, Credentials, StorageProfile, StorageProfileCreateRequest};
use crate::models::{Warehouse, WarehouseCreateRequest};
use crate::repository::{StorageProfileRepository, WarehouseRepository};
use crate::utils::convert_record_batches;
use crate::utils::{convert_record_batches, Config};
use arrow::record_batch::RecordBatch;
use arrow_json::writer::JsonArray;
use arrow_json::WriterBuilder;
@@ -80,12 +80,14 @@ pub trait ControlService: Send + Sync {
async fn create_session(&self, session_id: String) -> ControlPlaneResult<()>;

async fn delete_session(&self, session_id: String) -> ControlPlaneResult<()>;
fn config(&self) -> &Config;
}

pub struct ControlServiceImpl {
storage_profile_repo: Arc<dyn StorageProfileRepository>,
warehouse_repo: Arc<dyn WarehouseRepository>,
df_sessions: Arc<RwLock<HashMap<String, SqlExecutor>>>,
config: Config,
}

impl ControlServiceImpl {
@@ -98,6 +100,7 @@ impl ControlServiceImpl {
storage_profile_repo,
warehouse_repo,
df_sessions,
config: Config::default(),
}
}
}
@@ -327,8 +330,11 @@ impl ControlService for ControlServiceImpl {
.context(crate::error::ExecutionSnafu)?
.into_iter()
.collect::<Vec<_>>();

let serialization_format = self.config().dbt_serialization_format;
// Add columns dbt metadata to each field
convert_record_batches(records).context(crate::error::DataFusionQuerySnafu { query })
convert_record_batches(records, serialization_format)
.context(error::DataFusionQuerySnafu { query })
}

#[tracing::instrument(level = "debug", skip(self))]
@@ -538,12 +544,15 @@ impl ControlService for ControlServiceImpl {

Ok(())
}

fn config(&self) -> &Config {
&self.config
}
}

#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {

use super::*;
use crate::error::ControlPlaneError;
use crate::models::{
183 changes: 127 additions & 56 deletions crates/control_plane/src/utils.rs
@@ -1,6 +1,6 @@
use crate::models::ColumnInfo;
use arrow::array::{
Array, StringArray, TimestampMicrosecondArray, TimestampMillisecondArray,
Array, Int64Array, StringArray, TimestampMicrosecondArray, TimestampMillisecondArray,
TimestampNanosecondArray, TimestampSecondArray, UnionArray,
};
use arrow::datatypes::{Field, Schema, TimeUnit};
@@ -9,7 +9,45 @@ use chrono::DateTime;
use datafusion::arrow::array::ArrayRef;
use datafusion::arrow::datatypes::DataType;
use datafusion::common::Result as DataFusionResult;
use std::fmt::Display;
use std::sync::Arc;
use std::{env, fmt};

pub struct Config {
pub dbt_serialization_format: SerializationFormat,
}

impl Default for Config {
fn default() -> Self {
Self {
dbt_serialization_format: SerializationFormat::new(),
}
}
}
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum SerializationFormat {
Arrow,
Json,
}

impl Display for SerializationFormat {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Arrow => write!(f, "arrow"),
Self::Json => write!(f, "json"),
}
}
}

impl SerializationFormat {
fn new() -> Self {
let var = env::var("DBT_SERIALIZATION_FORMAT").unwrap_or_else(|_| "json".to_string());
match var.to_lowercase().as_str() {
"arrow" => Self::Arrow,
_ => Self::Json,
}
}
}
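
The format defaults to JSON and is switched through the DBT_SERIALIZATION_FORMAT environment variable; any value other than "arrow" (case-insensitive) falls back to Json. A quick usage sketch, assuming the Config and SerializationFormat types above are in scope:

use std::env;

// Select the Arrow wire format for timestamps via the environment,
// then let Config::default() pick it up through SerializationFormat::new().
env::set_var("DBT_SERIALIZATION_FORMAT", "arrow");
let config = Config::default();
assert!(matches!(
    config.dbt_serialization_format,
    SerializationFormat::Arrow
));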

#[must_use]
pub fn first_non_empty_type(union_array: &UnionArray) -> Option<(DataType, ArrayRef)> {
@@ -25,6 +63,7 @@ pub fn first_non_empty_type(union_array: &UnionArray) -> Option<(DataType, ArrayRef)> {

pub fn convert_record_batches(
records: Vec<RecordBatch>,
serialization_format: SerializationFormat,
) -> DataFusionResult<(Vec<RecordBatch>, Vec<ColumnInfo>)> {
let mut converted_batches = Vec::new();
let column_infos = ColumnInfo::from_batch(&records);
@@ -54,7 +93,8 @@ pub fn convert_record_batches(
}
}
DataType::Timestamp(unit, _) => {
let converted_column = convert_timestamp_to_struct(column, *unit);
let converted_column =
convert_timestamp_to_struct(column, *unit, serialization_format);
fields.push(
Field::new(
field.name(),
@@ -82,63 +122,80 @@
Ok((converted_batches.clone(), column_infos))
}

macro_rules! downcast_and_iter {
($column:expr, $array_type:ty) => {
$column
.as_any()
.downcast_ref::<$array_type>()
.unwrap()
.into_iter()
};
}

#[allow(
clippy::unwrap_used,
clippy::as_conversions,
clippy::cast_possible_truncation
)]
fn convert_timestamp_to_struct(column: &ArrayRef, unit: TimeUnit) -> ArrayRef {
let timestamps: Vec<_> = match unit {
TimeUnit::Second => column
.as_any()
.downcast_ref::<TimestampSecondArray>()
.unwrap()
.iter()
.map(|x| {
x.map(|ts| {
let ts = DateTime::from_timestamp(ts, 0).unwrap();
format!("{}", ts.timestamp())
})
})
.collect(),
TimeUnit::Millisecond => column
.as_any()
.downcast_ref::<TimestampMillisecondArray>()
.unwrap()
.iter()
.map(|x| {
x.map(|ts| {
let ts = DateTime::from_timestamp_millis(ts).unwrap();
format!("{}.{}", ts.timestamp(), ts.timestamp_subsec_millis())
})
})
.collect(),
TimeUnit::Microsecond => column
.as_any()
.downcast_ref::<TimestampMicrosecondArray>()
.unwrap()
.iter()
.map(|x| {
x.map(|ts| {
let ts = DateTime::from_timestamp_micros(ts).unwrap();
format!("{}.{}", ts.timestamp(), ts.timestamp_subsec_micros())
})
})
.collect(),
TimeUnit::Nanosecond => column
.as_any()
.downcast_ref::<TimestampNanosecondArray>()
.unwrap()
.iter()
.map(|x| {
x.map(|ts| {
let ts = DateTime::from_timestamp_nanos(ts);
format!("{}.{}", ts.timestamp(), ts.timestamp_subsec_nanos())
})
})
.collect(),
};
Arc::new(StringArray::from(timestamps)) as ArrayRef
fn convert_timestamp_to_struct(
column: &ArrayRef,
unit: TimeUnit,
ser: SerializationFormat,
) -> ArrayRef {
match ser {
SerializationFormat::Arrow => {
let timestamps: Vec<_> = match unit {
TimeUnit::Second => downcast_and_iter!(column, TimestampSecondArray).collect(),
TimeUnit::Millisecond => {
downcast_and_iter!(column, TimestampMillisecondArray).collect()
}
TimeUnit::Microsecond => {
downcast_and_iter!(column, TimestampMicrosecondArray).collect()
}
TimeUnit::Nanosecond => {
downcast_and_iter!(column, TimestampNanosecondArray).collect()
}
};
Arc::new(Int64Array::from(timestamps)) as ArrayRef
}
SerializationFormat::Json => {
let timestamps: Vec<_> = match unit {
TimeUnit::Second => downcast_and_iter!(column, TimestampSecondArray)
.map(|x| {
x.map(|ts| {
let ts = DateTime::from_timestamp(ts, 0).unwrap();
format!("{}", ts.timestamp())
})
})
.collect(),
TimeUnit::Millisecond => downcast_and_iter!(column, TimestampMillisecondArray)
.map(|x| {
x.map(|ts| {
let ts = DateTime::from_timestamp_millis(ts).unwrap();
format!("{}.{}", ts.timestamp(), ts.timestamp_subsec_millis())
})
})
.collect(),
TimeUnit::Microsecond => downcast_and_iter!(column, TimestampMicrosecondArray)
.map(|x| {
x.map(|ts| {
let ts = DateTime::from_timestamp_micros(ts).unwrap();
format!("{}.{}", ts.timestamp(), ts.timestamp_subsec_micros())
})
})
.collect(),
TimeUnit::Nanosecond => downcast_and_iter!(column, TimestampNanosecondArray)
.map(|x| {
x.map(|ts| {
let ts = DateTime::from_timestamp_nanos(ts);
format!("{}.{}", ts.timestamp(), ts.timestamp_subsec_nanos())
})
})
.collect(),
};
Arc::new(StringArray::from(timestamps)) as ArrayRef
}
}
}
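
On the Json path above, each timestamp is rendered as epoch seconds plus a bare subsecond component. A minimal, self-contained check of that formatting, using only chrono calls already present in this file:

use chrono::DateTime;

// 1_627_846_261_123 ms since the epoch renders as "1627846261.123".
let ts = DateTime::from_timestamp_millis(1_627_846_261_123).unwrap();
assert_eq!(
    format!("{}.{}", ts.timestamp(), ts.timestamp_subsec_millis()),
    "1627846261.123"
);

Note that the {} placeholder does not zero-pad the subsecond part, so 5 ms would render as "1627846261.5" rather than "1627846261.005"; whether that is the intended representation may be worth confirming.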

#[cfg(test)]
@@ -209,7 +266,8 @@ mod tests {
Arc::new(TimestampNanosecondArray::from(values)) as ArrayRef
}
};
let result = convert_timestamp_to_struct(&timestamp_array, *unit);
let result =
convert_timestamp_to_struct(&timestamp_array, *unit, SerializationFormat::Json);
let string_array = result.as_any().downcast_ref::<StringArray>().unwrap();
assert_eq!(string_array.len(), 2);
assert_eq!(string_array.value(0), *expected);
@@ -235,7 +293,8 @@
])) as ArrayRef;
let batch = RecordBatch::try_new(schema, vec![int_array, timestamp_array]).unwrap();
let records = vec![batch];
let (converted_batches, column_infos) = convert_record_batches(records).unwrap();
let (converted_batches, column_infos) =
convert_record_batches(records.clone(), SerializationFormat::Json).unwrap();

let converted_batch = &converted_batches[0];
assert_eq!(converted_batches.len(), 1);
@@ -255,5 +314,17 @@
assert_eq!(column_infos[0].r#type, "fixed");
assert_eq!(column_infos[1].name, "timestamp_col");
assert_eq!(column_infos[1].r#type, "timestamp_ntz");

let (converted_batches, _) =
convert_record_batches(records, SerializationFormat::Arrow).unwrap();
let converted_batch = &converted_batches[0];
let converted_timestamp_array = converted_batch
.column(1)
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
assert_eq!(converted_timestamp_array.value(0), 1_627_846_261);
assert!(converted_timestamp_array.is_null(1));
assert_eq!(converted_timestamp_array.value(2), 1_627_846_262);
}
}