
Merge 794dbce into 9ba214a
andygrove committed Apr 24, 2021
2 parents 9ba214a + 794dbce commit c73c1f3
Showing 2 changed files with 40 additions and 181 deletions.
217 changes: 38 additions & 179 deletions ballista/rust/client/src/context.rs
@@ -29,16 +29,14 @@ use ballista_core::serde::protobuf::{
GetJobStatusResult,
};
use ballista_core::{
client::BallistaClient,
datasource::DfTableAdapter,
error::{BallistaError, Result},
memory_stream::MemoryStream,
client::BallistaClient, datasource::DfTableAdapter, memory_stream::MemoryStream,
utils::create_datafusion_context,
};

use arrow::datatypes::Schema;
use datafusion::catalog::TableReference;
use datafusion::logical_plan::{DFSchema, Expr, LogicalPlan, Partitioning};
use datafusion::error::{DataFusionError, Result};
use datafusion::logical_plan::LogicalPlan;
use datafusion::physical_plan::csv::CsvReadOptions;
use datafusion::{dataframe::DataFrame, physical_plan::RecordBatchStream};
use log::{error, info};
@@ -88,15 +86,15 @@ impl BallistaContext {

/// Create a DataFrame representing a Parquet table scan

pub fn read_parquet(&self, path: &str) -> Result<BallistaDataFrame> {
pub fn read_parquet(&self, path: &str) -> Result<Arc<dyn DataFrame>> {
// convert to absolute path because the executor likely has a different working directory
let path = PathBuf::from(path);
let path = fs::canonicalize(&path)?;

// use local DataFusion context for now but later this might call the scheduler
let mut ctx = create_datafusion_context();
let df = ctx.read_parquet(path.to_str().unwrap())?;
Ok(BallistaDataFrame::from(self.state.clone(), df))
Ok(df)
}

/// Create a DataFrame representing a CSV table scan
@@ -105,19 +103,19 @@ impl BallistaContext {
&self,
path: &str,
options: CsvReadOptions,
) -> Result<BallistaDataFrame> {
) -> Result<Arc<dyn DataFrame>> {
// convert to absolute path because the executor likely has a different working directory
let path = PathBuf::from(path);
let path = fs::canonicalize(&path)?;

// use local DataFusion context for now but later this might call the scheduler
let mut ctx = create_datafusion_context();
let df = ctx.read_csv(path.to_str().unwrap(), options)?;
Ok(BallistaDataFrame::from(self.state.clone(), df))
Ok(df)
}

/// Register a DataFrame as a table that can be referenced from a SQL query
pub fn register_table(&self, name: &str, table: &BallistaDataFrame) -> Result<()> {
pub fn register_table(&self, name: &str, table: &dyn DataFrame) -> Result<()> {
let mut state = self.state.lock().unwrap();
state
.tables
@@ -132,16 +130,16 @@ impl BallistaContext {
options: CsvReadOptions,
) -> Result<()> {
let df = self.read_csv(path, options)?;
self.register_table(name, &df)
self.register_table(name, df.as_ref())
}

pub fn register_parquet(&self, name: &str, path: &str) -> Result<()> {
let df = self.read_parquet(path)?;
self.register_table(name, &df)
self.register_table(name, df.as_ref())
}

/// Create a DataFrame from a SQL statement
pub fn sql(&self, sql: &str) -> Result<BallistaDataFrame> {
pub fn sql(&self, sql: &str) -> Result<Arc<dyn DataFrame>> {
// use local DataFusion context for now but later this might call the scheduler
let mut ctx = create_datafusion_context();
// register tables
@@ -154,27 +152,13 @@ impl BallistaContext {
Arc::new(DfTableAdapter::new(plan, execution_plan)),
)?;
}
let df = ctx.sql(sql)?;
Ok(BallistaDataFrame::from(self.state.clone(), df))
ctx.sql(sql)
}
}

/// The Ballista DataFrame is a wrapper around the DataFusion DataFrame and overrides the
/// `collect` method so that the query is executed against Ballista and not DataFusion.

pub struct BallistaDataFrame {
/// Ballista context state
state: Arc<Mutex<BallistaContextState>>,
/// DataFusion DataFrame representing logical query plan
df: Arc<dyn DataFrame>,
}

impl BallistaDataFrame {
fn from(state: Arc<Mutex<BallistaContextState>>, df: Arc<dyn DataFrame>) -> Self {
Self { state, df }
}

pub async fn collect(&self) -> Result<Pin<Box<dyn RecordBatchStream + Send + Sync>>> {
pub async fn collect(
&self,
plan: &LogicalPlan,
) -> Result<Pin<Box<dyn RecordBatchStream + Send + Sync>>> {
let scheduler_url = {
let state = self.state.lock().unwrap();

@@ -183,16 +167,22 @@ impl BallistaDataFrame {

info!("Connecting to Ballista scheduler at {}", scheduler_url);

let mut scheduler = SchedulerGrpcClient::connect(scheduler_url).await?;
let mut scheduler = SchedulerGrpcClient::connect(scheduler_url)
.await
.map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?;

let plan = self.df.to_logical_plan();
let schema: Schema = plan.schema().as_ref().clone().into();

let job_id = scheduler
.execute_query(ExecuteQueryParams {
query: Some(Query::LogicalPlan((&plan).try_into()?)),
query: Some(Query::LogicalPlan(
(plan)
.try_into()
.map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?,
)),
})
.await?
.await
.map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?
.into_inner()
.job_id;

@@ -201,10 +191,11 @@ impl BallistaDataFrame {
.get_job_status(GetJobStatusParams {
job_id: job_id.clone(),
})
.await?
.await
.map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?
.into_inner();
let status = status.and_then(|s| s.status).ok_or_else(|| {
BallistaError::Internal("Received empty status message".to_owned())
DataFusionError::Internal("Received empty status message".to_owned())
})?;
let wait_future = tokio::time::sleep(Duration::from_millis(100));
match status {
@@ -219,34 +210,38 @@ impl BallistaDataFrame {
job_status::Status::Failed(err) => {
let msg = format!("Job {} failed: {}", job_id, err.error);
error!("{}", msg);
break Err(BallistaError::General(msg));
break Err(DataFusionError::Execution(msg));
}
job_status::Status::Completed(completed) => {
// TODO: use streaming. Probably need to change the signature of fetch_partition to achieve that
let mut result = vec![];
for location in completed.partition_location {
let metadata = location.executor_meta.ok_or_else(|| {
BallistaError::Internal(
DataFusionError::Internal(
"Received empty executor metadata".to_owned(),
)
})?;
let partition_id = location.partition_id.ok_or_else(|| {
BallistaError::Internal(
DataFusionError::Internal(
"Received empty partition id".to_owned(),
)
})?;
let mut ballista_client = BallistaClient::try_new(
metadata.host.as_str(),
metadata.port as u16,
)
.await?;
.await
.map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?;
let stream = ballista_client
.fetch_partition(
&partition_id.job_id,
partition_id.stage_id as usize,
partition_id.partition_id as usize,
)
.await?;
.await
.map_err(|e| {
DataFusionError::Execution(format!("{:?}", e))
})?;
result.append(
&mut datafusion::physical_plan::common::collect(stream)
.await?,
@@ -261,140 +256,4 @@
};
}
}

pub fn select_columns(&self, columns: &[&str]) -> Result<BallistaDataFrame> {
Ok(Self::from(
self.state.clone(),
self.df
.select_columns(columns)
.map_err(BallistaError::from)?,
))
}

pub fn select(&self, expr: Vec<Expr>) -> Result<BallistaDataFrame> {
Ok(Self::from(
self.state.clone(),
self.df.select(expr).map_err(BallistaError::from)?,
))
}

pub fn filter(&self, expr: Expr) -> Result<BallistaDataFrame> {
Ok(Self::from(
self.state.clone(),
self.df.filter(expr).map_err(BallistaError::from)?,
))
}

pub fn aggregate(
&self,
group_expr: Vec<Expr>,
aggr_expr: Vec<Expr>,
) -> Result<BallistaDataFrame> {
Ok(Self::from(
self.state.clone(),
self.df
.aggregate(group_expr, aggr_expr)
.map_err(BallistaError::from)?,
))
}

pub fn limit(&self, n: usize) -> Result<BallistaDataFrame> {
Ok(Self::from(
self.state.clone(),
self.df.limit(n).map_err(BallistaError::from)?,
))
}

pub fn sort(&self, expr: Vec<Expr>) -> Result<BallistaDataFrame> {
Ok(Self::from(
self.state.clone(),
self.df.sort(expr).map_err(BallistaError::from)?,
))
}

// TODO lifetime issue
// pub fn join(&self, right: Arc<dyn DataFrame>, join_type: JoinType, left_cols: &[&str], right_cols: &[&str]) ->
// Result<BallistaDataFrame> { Ok(Self::from(self.state.clone(), self.df.join(right, join_type, &left_cols,
// &right_cols).map_err(BallistaError::from)?)) }

pub fn repartition(
&self,
partitioning_scheme: Partitioning,
) -> Result<BallistaDataFrame> {
Ok(Self::from(
self.state.clone(),
self.df
.repartition(partitioning_scheme)
.map_err(BallistaError::from)?,
))
}

pub fn schema(&self) -> &DFSchema {
self.df.schema()
}

pub fn to_logical_plan(&self) -> LogicalPlan {
self.df.to_logical_plan()
}

pub fn explain(&self, verbose: bool) -> Result<BallistaDataFrame> {
Ok(Self::from(
self.state.clone(),
self.df.explain(verbose).map_err(BallistaError::from)?,
))
}
}

// #[async_trait]
// impl ExecutionContext for BallistaContext {
// async fn get_executor_ids(&self) -> Result<Vec<ExecutorMeta>> {
// match &self.config.discovery_mode {
// DiscoveryMode::Etcd => etcd_get_executors(&self.config.etcd_urls, "default").await,
// DiscoveryMode::Kubernetes => k8s_get_executors("default", "ballista").await,
// DiscoveryMode::Standalone => Err(ballista_error("Standalone mode not implemented yet")),
// }
// }
//
// async fn execute_task(
// &self,
// executor_meta: ExecutorMeta,
// task: ExecutionTask,
// ) -> Result<ShuffleId> {
// // TODO what is the point of returning this info since it is based on input arg?
// let shuffle_id = ShuffleId::new(task.job_uuid, task.stage_id, task.partition_id);
//
// let _ = execute_action(
// &executor_meta.host,
// executor_meta.port,
// &Action::Execute(task),
// )
// .await?;
//
// Ok(shuffle_id)
// }
//
// async fn read_shuffle(&self, shuffle_id: &ShuffleId) -> Result<Vec<ColumnarBatch>> {
// match self.shuffle_locations.get(shuffle_id) {
// Some(executor_meta) => {
// let batches = execute_action(
// &executor_meta.host,
// executor_meta.port,
// &Action::FetchShuffle(*shuffle_id),
// )
// .await?;
// Ok(batches
// .iter()
// .map(|b| ColumnarBatch::from_arrow(b))
// .collect())
// }
// _ => Err(ballista_error(&format!(
// "Failed to resolve executor UUID for shuffle ID {:?}",
// shuffle_id
// ))),
// }
// }
//
// fn config(&self) -> ExecutorConfig {
// self.config.clone()
// }
// }
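
A note on the recurring pattern above: collect now returns DataFusion's Result, so each Ballista-side failure (scheduler connect, execute_query, get_job_status, partition fetch) is adapted with map_err(|e| DataFusionError::Execution(format!("{:?}", e))). Below is a minimal sketch of that pattern in isolation; the adapt helper is illustrative and not part of this commit.

use datafusion::error::{DataFusionError, Result};

// Illustrative helper (not in this commit): wrap any Debug-printable error
// into DataFusion's error enum so callers can stay on
// datafusion::error::Result end to end.
fn adapt<T, E: std::fmt::Debug>(res: std::result::Result<T, E>) -> Result<T> {
    res.map_err(|e| DataFusionError::Execution(format!("{:?}", e)))
}

fn main() -> Result<()> {
    // Any error type that implements Debug can be adapted this way.
    let n = adapt("42".parse::<i32>())?;
    assert_eq!(n, 42);
    Ok(())
}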
4 changes: 2 additions & 2 deletions benchmarks/src/bin/tpch.rs
@@ -259,8 +259,8 @@ async fn benchmark_ballista(opt: BenchmarkOpt) -> Result<()> {
.sql(&sql)
.map_err(|e| DataFusionError::Plan(format!("{:?}", e)))?;
let mut batches = vec![];
let mut stream = df
.collect()
let mut stream = ctx
.collect(&df.to_logical_plan())
.await
.map_err(|e| DataFusionError::Plan(format!("{:?}", e)))?;
while let Some(result) = stream.next().await {

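For reference, a hedged end-to-end sketch of the reworked client API. The BallistaContext::remote constructor and the ballista::context import path are assumptions (neither appears in this diff); register_csv, sql, to_logical_plan, and collect(&LogicalPlan) are the calls shown in the hunks above.

use ballista::context::BallistaContext; // assumed import path, not shown in this diff
use datafusion::error::Result;
use datafusion::physical_plan::csv::CsvReadOptions;
use futures::StreamExt;

#[tokio::main]
async fn main() -> Result<()> {
    // Assumed constructor; this commit only shows the context's methods.
    let ctx = BallistaContext::remote("localhost", 50050);

    // sql now returns DataFusion's Arc<dyn DataFrame> rather than the
    // removed BallistaDataFrame wrapper.
    ctx.register_csv("lineitem", "./data/lineitem.csv", CsvReadOptions::new())?;
    let df = ctx.sql("SELECT count(*) FROM lineitem")?;

    // Execution moves to the context: pass the logical plan to collect,
    // which submits the job to the Ballista scheduler and streams results,
    // as the tpch.rs hunk above now does.
    let mut stream = ctx.collect(&df.to_logical_plan()).await?;
    while let Some(batch) = stream.next().await {
        println!("rows: {}", batch?.num_rows());
    }
    Ok(())
}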