Skip to content

Commit

Permalink
chore: simplify the code (#147)
Browse files (browse the repository at this point in the history)
---------

Co-authored-by: Jackson Newhouse <jackson@arroyo.systems>
  • Loading branch information
chenquan and jacksonrnewhouse committed May 31, 2023
1 parent f015a0b commit eb81077
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 95 deletions.
83 changes: 43 additions & 40 deletions arroyo-api/src/optimizations.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
use arroyo_datastream::{
AggregateBehavior, EdgeType, ExpressionReturnType, ExpressionReturnType::*, Operator,
StreamEdge, StreamNode, WindowAgg,
};
use petgraph::data::DataMap;
use petgraph::graph::DiGraph;
use petgraph::prelude::EdgeRef;
Expand All @@ -12,6 +8,11 @@ use proc_macro2::TokenStream;
use quote::quote;
use syn::{parse_str, Type};

use arroyo_datastream::{
AggregateBehavior, EdgeType, ExpressionReturnType, ExpressionReturnType::*, Operator,
StreamEdge, StreamNode, WindowAgg,
};

pub fn optimize(graph: &mut DiGraph<StreamNode, StreamEdge>) {
WasmFusionOptimizer {}.optimize(graph);
fuse_window_aggregation(graph);
Expand All @@ -20,17 +21,14 @@ pub fn optimize(graph: &mut DiGraph<StreamNode, StreamEdge>) {
}

fn remove_in_place(graph: &mut DiGraph<StreamNode, StreamEdge>, node: NodeIndex) {
let incoming = graph
.edges_directed(node, Direction::Incoming)
.next()
.unwrap();
let incoming = graph.edges_directed(node, Incoming).next().unwrap();

let parent = incoming.source().id();
let incoming = incoming.id();
graph.remove_edge(incoming);

let outgoing: Vec<_> = graph
.edges_directed(node, Direction::Outgoing)
.edges_directed(node, Outgoing)
.map(|e| (e.id(), e.target().id()))
.collect();

Expand All @@ -44,39 +42,43 @@ fn remove_in_place(graph: &mut DiGraph<StreamNode, StreamEdge>, node: NodeIndex)

fn fuse_window_aggregation(graph: &mut DiGraph<StreamNode, StreamEdge>) {
'outer: loop {
let sources: Vec<_> = graph.externals(Direction::Incoming).collect();
let sources: Vec<_> = graph.externals(Incoming).collect();

for source in sources {
let mut dfs = Dfs::new(&(*graph), source);

while let Some(idx) = dfs.next(&(*graph)) {
let operator = graph.node_weight(idx).unwrap().operator.clone();
let mut ins = graph.edges_directed(idx, Direction::Incoming);
let mut ins = graph.edges_directed(idx, Incoming);

let in_degree = ins.clone().count();
let no_shuffles = ins.all(|e| e.weight().typ == EdgeType::Forward);

let mut ins = graph.edges_directed(idx, Direction::Incoming);
if no_shuffles && in_degree == 1 {
let source_idx = ins.next().unwrap().source();
let in_node = graph.node_weight_mut(source_idx).unwrap();
if let Operator::Window { agg, .. } = &mut in_node.operator {
if agg.is_none() {
let new_agg = match &operator {
Operator::Count => Some(WindowAgg::Count),
Operator::Aggregate(AggregateBehavior::Min) => Some(WindowAgg::Min),
Operator::Aggregate(AggregateBehavior::Max) => Some(WindowAgg::Max),
Operator::Aggregate(AggregateBehavior::Sum) => Some(WindowAgg::Sum),
_ => None,
};

if let Some(new_agg) = new_agg {
*agg = Some(new_agg);
remove_in_place(graph, idx);
// restart the loop if we change something
continue 'outer;
}
}
let mut ins = graph.edges_directed(idx, Incoming);
if !(no_shuffles && in_degree == 1) {
continue;
}

let source_idx = ins.next().unwrap().source();
let in_node = graph.node_weight_mut(source_idx).unwrap();
if let Operator::Window { agg, .. } = &mut in_node.operator {
if agg.is_some() {
continue;
}

let new_agg = match &operator {
Operator::Count => Some(WindowAgg::Count),
Operator::Aggregate(AggregateBehavior::Min) => Some(WindowAgg::Min),
Operator::Aggregate(AggregateBehavior::Max) => Some(WindowAgg::Max),
Operator::Aggregate(AggregateBehavior::Sum) => Some(WindowAgg::Sum),
_ => None,
};

if let Some(new_agg) = new_agg {
*agg = Some(new_agg);
remove_in_place(graph, idx);
// restart the loop if we change something
continue 'outer;
}
}
}
Expand All @@ -95,13 +97,13 @@ pub trait Optimizer {

fn optimize_once(&self, graph: &mut DiGraph<StreamNode, StreamEdge>) -> bool {
let mut to_fuse = vec![vec![]];
for source in graph.externals(Direction::Incoming) {
for source in graph.externals(Incoming) {
let mut dfs = Dfs::new(&(*graph), source);
let mut current_chain = vec![];

while let Some(idx) = dfs.next(&(*graph)) {
let node = graph.node_weight(idx).unwrap();
let mut ins = graph.edges_directed(idx, Direction::Incoming);
let mut ins = graph.edges_directed(idx, Incoming);

let in_degree = ins.clone().count();
let no_shuffles = ins.all(|e| e.weight().typ == EdgeType::Forward);
Expand Down Expand Up @@ -243,6 +245,7 @@ impl FusedExpressionOperatorBuilder {
#return_value
})
.to_string();

let expr: syn::Expr = parse_str(&body).expect(&body);
let expression = quote!(#expr).to_string();
Operator::ExpressionOperator {
Expand All @@ -261,11 +264,9 @@ impl FusedExpressionOperatorBuilder {
) -> bool {
let out_type = format!("arroyo_types::Record<{},{}>", edge.key, edge.value);
match return_type {
ExpressionReturnType::Predicate => self.fuse_predicate(name, expression),
ExpressionReturnType::Record => self.fuse_map(name, expression, out_type),
ExpressionReturnType::OptionalRecord => {
self.fuse_option_map(name, expression, out_type)
}
Predicate => self.fuse_predicate(name, expression),
Record => self.fuse_map(name, expression, out_type),
OptionalRecord => self.fuse_option_map(name, expression, out_type),
}
}

Expand Down Expand Up @@ -395,6 +396,7 @@ impl Optimizer for FlatMapFusionOptimizer {
}

struct WasmFusionOptimizer {}

impl Optimizer for WasmFusionOptimizer {
fn can_optimize(&self, node: &StreamNode, _current_chain: &[StreamNode]) -> bool {
matches!(&node.operator, Operator::FusedWasmUDFs { .. })
Expand Down Expand Up @@ -429,9 +431,10 @@ impl Optimizer for WasmFusionOptimizer {
mod tests {
use std::{path::PathBuf, time::Duration};

use arroyo_datastream::{StreamEdge, StreamNode, WindowAgg};
use petgraph::prelude::DiGraph;

use arroyo_datastream::{StreamEdge, StreamNode, WindowAgg};

use super::fuse_window_aggregation;

#[test]
Expand Down
22 changes: 11 additions & 11 deletions arroyo-api/src/pipelines.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,15 @@
use std::str::FromStr;

use anyhow::Context;
use cornucopia_async::GenericClient;
use deadpool_postgres::Transaction;
use petgraph::Direction;
use prost::Message;
use serde_json::Value;
use tonic::Status;
use tracing::log::info;
use tracing::warn;

use arroyo_datastream::{auth_config_to_hashmap, Operator, Program, SinkConfig};
use arroyo_rpc::grpc::api::create_sql_job::Sink;
use arroyo_rpc::grpc::api::sink::SinkType;
Expand All @@ -9,17 +20,6 @@ use arroyo_rpc::grpc::api::{
};
use arroyo_sql::{ArroyoSchemaProvider, SqlConfig};

use cornucopia_async::GenericClient;
use deadpool_postgres::Transaction;
use petgraph::Direction;
use prost::Message;
use serde_json::Value;
use tracing::log::info;

use std::str::FromStr;
use tonic::Status;
use tracing::warn;

use crate::queries::api_queries;
use crate::queries::api_queries::DbPipeline;
use crate::types::public::PipelineType;
Expand Down
82 changes: 38 additions & 44 deletions arroyo-api/src/sources.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
use std::fmt::{Display, Formatter};

use arrow::datatypes::TimeUnit;
use cornucopia_async::GenericClient;
use deadpool_postgres::Pool;
use http::StatusCode;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tonic::Status;
use tracing::warn;

use arroyo_datastream::{SerializationMode, SourceConfig};
use arroyo_rpc::grpc::api::{
self,
Expand All @@ -14,13 +23,6 @@ use arroyo_sql::{
types::{StructDef, StructField, TypeDef},
ArroyoSchemaProvider,
};
use cornucopia_async::GenericClient;
use deadpool_postgres::Pool;
use http::StatusCode;
use std::fmt::{Display, Formatter};
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tonic::Status;
use tracing::warn;

use crate::types::public::SchemaType;
use crate::{
Expand Down Expand Up @@ -350,16 +352,16 @@ fn builtin_for_name(name: &str) -> Result<SourceSchema, String> {
impl SourceSchema {
pub fn try_from(name: &str, s: api::SourceSchema) -> Result<Self, String> {
match s.schema.unwrap() {
api::source_schema::Schema::Builtin(name) => builtin_for_name(&name),
api::source_schema::Schema::JsonSchema(def) => {
Schema::Builtin(name) => builtin_for_name(&name),
Schema::JsonSchema(def) => {
let fields = json_schema::convert_json_schema(name, &def.json_schema)?;
Ok(SourceSchema {
format: SourceFormat::JsonSchema(def.json_schema),
fields,
kafka_schema: s.kafka_schema_registry,
})
}
api::source_schema::Schema::JsonFields(def) => {
Schema::JsonFields(def) => {
let fields: Result<Vec<_>, String> =
def.fields.into_iter().map(|f| f.try_into()).collect();
Ok(SourceSchema {
Expand All @@ -369,9 +371,7 @@ impl SourceSchema {
})
}
Schema::RawJson(_) => Ok(raw_schema()),
api::source_schema::Schema::Protobuf(_) => {
Err("protobuf not supported yet".to_string())
}
Schema::Protobuf(_) => Err("protobuf not supported yet".to_string()),
}
}

Expand All @@ -396,23 +396,19 @@ impl TryFrom<&SourceSchema> for api::SourceSchema {
fn try_from(s: &SourceSchema) -> Result<Self, Self::Error> {
Ok(api::SourceSchema {
schema: Some(match &s.format {
SourceFormat::Native(s) => api::source_schema::Schema::Builtin(s.clone()),
SourceFormat::JsonFields => {
api::source_schema::Schema::JsonFields(api::JsonFieldDef {
fields: s
.fields
.clone()
.into_iter()
.filter_map(|f| f.try_into().ok())
.collect(),
})
}
SourceFormat::JsonSchema(s) => {
api::source_schema::Schema::JsonSchema(JsonSchemaDef {
json_schema: s.clone(),
})
}
SourceFormat::RawJson => api::source_schema::Schema::RawJson(api::RawJsonDef {}),
SourceFormat::Native(s) => Schema::Builtin(s.clone()),
SourceFormat::JsonFields => Schema::JsonFields(api::JsonFieldDef {
fields: s
.fields
.clone()
.into_iter()
.filter_map(|f| f.try_into().ok())
.collect(),
}),
SourceFormat::JsonSchema(s) => Schema::JsonSchema(JsonSchemaDef {
json_schema: s.clone(),
}),
SourceFormat::RawJson => Schema::RawJson(api::RawJsonDef {}),
}),
kafka_schema_registry: s.kafka_schema,
})
Expand All @@ -430,7 +426,7 @@ impl TryFrom<SourceDef> for Source {
type Error = String;

fn try_from(value: SourceDef) -> Result<Self, Self::Error> {
let schema = if let api::source_schema::Schema::Builtin(name) =
let schema = if let Schema::Builtin(name) =
value.schema.as_ref().unwrap().schema.as_ref().unwrap()
{
builtin_for_name(name)?
Expand Down Expand Up @@ -496,7 +492,7 @@ pub(crate) async fn create_source(
req: CreateSourceReq,
auth: AuthData,
pool: &Pool,
) -> core::result::Result<(), Status> {
) -> Result<(), Status> {
let schema_name = format!("{}_schema", req.name);

let mut c = pool.get().await.map_err(log_and_map)?;
Expand All @@ -518,13 +514,13 @@ pub(crate) async fn create_source(
.or_else(|| {
match req.type_oneof {
Some(create_source_req::TypeOneof::Impulse { .. }) => {
Some(source_schema::Schema::Builtin("impulse".to_string()))
Some(Schema::Builtin("impulse".to_string()))
}
Some(create_source_req::TypeOneof::Nexmark { .. }) => {
Some(source_schema::Schema::Builtin("nexmark".to_string()))
Some(Schema::Builtin("nexmark".to_string()))
}
Some(create_source_req::TypeOneof::Kafka { .. }) => {
Some(source_schema::Schema::RawJson(RawJsonDef {}))
Some(Schema::RawJson(RawJsonDef {}))
}
_ => None,
}
Expand All @@ -539,25 +535,23 @@ pub(crate) async fn create_source(
.schema
.ok_or_else(|| required_field("schema.schema"))?
{
source_schema::Schema::Builtin(name) => {
Schema::Builtin(name) => {
builtin_for_name(&name).map_err(Status::invalid_argument)?;
(SchemaType::builtin, serde_json::to_value(&name).unwrap())
}
source_schema::Schema::JsonSchema(js) => {
Schema::JsonSchema(js) => {
// try to convert the schema to ensure it's valid
convert_json_schema(&req.name, &js.json_schema).map_err(Status::invalid_argument)?;

// parse the schema into a value
(SchemaType::json_schema, serde_json::to_value(&js).unwrap())
}
source_schema::Schema::JsonFields(fields) => (
Schema::JsonFields(fields) => (
SchemaType::json_fields,
serde_json::to_value(fields).unwrap(),
),
source_schema::Schema::RawJson(_) => {
(SchemaType::raw_json, serde_json::to_value(()).unwrap())
}
source_schema::Schema::Protobuf(_) => todo!(),
Schema::RawJson(_) => (SchemaType::raw_json, serde_json::to_value(()).unwrap()),
Schema::Protobuf(_) => todo!(),
};

let schema_id = api_queries::create_schema()
Expand Down Expand Up @@ -771,7 +765,7 @@ pub(crate) async fn test_schema(req: CreateSourceReq) -> Result<Vec<String>, Sta

match schema {
Schema::JsonSchema(schema) => {
if let Err(e) = json_schema::convert_json_schema(&req.name, &schema.json_schema) {
if let Err(e) = convert_json_schema(&req.name, &schema.json_schema) {
Ok(vec![e])
} else {
Ok(vec![])
Expand Down Expand Up @@ -962,7 +956,7 @@ pub(crate) async fn get_confluent_schema(
)
})?;

if let Err(e) = json_schema::convert_json_schema(&req.topic, schema) {
if let Err(e) = convert_json_schema(&req.topic, schema) {
warn!(
"Schema from schema registry is not valid: '{}': {}",
schema, e
Expand Down

0 comments on commit eb81077

Please sign in to comment.