Skip to content

Commit

Permalink
Fail on queries that use floats as keys
Browse files Browse the repository at this point in the history
  • Loading branch information
mwylde committed Dec 8, 2023
1 parent 587514f commit a7b9099
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 4 deletions.
30 changes: 26 additions & 4 deletions arroyo-sql/src/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -666,8 +666,14 @@ impl<'a> SqlPipelineBuilder<'a> {
(window_type_def(), AggregateResultExtraction::WindowTake)
} else {
let data_type = expression.expression_type(&ValuePointerContext::new());

(data_type, AggregateResultExtraction::KeyColumn)
};

data_type
.try_as_key()
.map_err(|e| anyhow!("cannot group by {}: {}", field.name(), e))?;

if let TypeDef::DataType(DataType::Struct(_), _) = &data_type {
bail!("structs should be struct-defs {:?}", expr);
}
Expand Down Expand Up @@ -825,11 +831,13 @@ impl<'a> SqlPipelineBuilder<'a> {
}
_ => {}
}

let mut join_pairs = join.on.clone();
if let Some(Expr::BinaryExpr(BinaryExpr { left, op, right })) = &join.filter {
if *op != datafusion_expr::Operator::Eq {
bail!("only equality joins are supported");
}

// check which side each column comes from. Assumes there's at least one field
let left_relation = join
.schema
Expand All @@ -849,15 +857,18 @@ impl<'a> SqlPipelineBuilder<'a> {
.as_ref()
.unwrap()
.to_string();

let left_table = Column::convert_expr(left)?.relation.unwrap();
let right_table = Column::convert_expr(right)?.relation.unwrap();

let pair = if right_table == right_relation && left_table == left_relation {
(left.as_ref().clone(), right.as_ref().clone())
} else if left_table == right_relation && right_table == left_relation {
((right.as_ref()).clone(), left.as_ref().clone())
} else {
bail!("join filter must contain at least one column from each side of the join");
};

join_pairs.push(pair);
} else if join.filter.is_some() {
bail!(
Expand All @@ -870,14 +881,25 @@ impl<'a> SqlPipelineBuilder<'a> {
.iter()
.map(|(left, _right)| Column::convert_expr(left))
.collect::<Result<Vec<_>>>()?;

let (left_computations, right_computations): (Vec<_>, Vec<_>) = join
.on
.iter()
.map(|(left, right)| {
Ok((
self.ctx(&left_input.return_type()).compile_expr(left)?,
self.ctx(&right_input.return_type()).compile_expr(right)?,
))
let left_expr = self.ctx(&left_input.return_type()).compile_expr(left)?;
let right_expr = self.ctx(&right_input.return_type()).compile_expr(right)?;

left_expr
.expression_type(&ValuePointerContext::new())
.try_as_key()
.map_err(|e| anyhow!("cannot join on {}: {}", left, e))?;

right_expr
.expression_type(&ValuePointerContext::new())
.try_as_key()
.map_err(|e| anyhow!("cannot join on {}: {}", right, e))?;

Ok((left_expr, right_expr))
})
.collect::<Result<Vec<_>>>()?
.into_iter()
Expand Down
47 changes: 47 additions & 0 deletions arroyo-sql/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,50 @@ async fn test_udf() {
.await
.unwrap();
}

#[tokio::test]
async fn test_no_float_group_by() {
let schema_provider = get_test_schema_provider();

let sql = "create table nexmark with (
connector = 'nexmark',
event_rate = '5'
);
select cast(bid.price as float)
from nexmark
group by 1;";

let _ = parse_and_get_program(sql, schema_provider, SqlConfig::default())
.await
.unwrap_err();
}

#[tokio::test]
async fn test_no_float_join_key() {
let schema_provider = get_test_schema_provider();

let sql = "create table test1 (
a FLOAT
) with (
connector = 'websocket',
endpoint = 'ws://blah',
format = 'json'
);
create table test2 (
b FLOAT
) with (
connector = 'websocket',
endpoint = 'ws://blah',
format = 'json'
);
select a from test1
LEFT JOIN test2 ON test1.a = test2.b";

let _ = parse_and_get_program(sql, schema_provider, SqlConfig::default())
.await
.unwrap_err();
}
16 changes: 16 additions & 0 deletions arroyo-sql/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -925,6 +925,22 @@ impl TypeDef {
}
}

pub fn try_as_key(&self) -> Result<()> {
match self {
TypeDef::StructDef(sd, _) => {
sd.fields
.iter()
.map(|f| f.data_type.try_as_key())
.collect::<Result<Vec<_>>>()?;
Ok(())
}
TypeDef::DataType(DataType::Float16 | DataType::Float32 | DataType::Float64, _) => {
bail!("FLOAT field cannot be used as key")
}
_ => Ok(()),
}
}

pub fn get_literal(scalar: &ScalarValue) -> syn::Expr {
if scalar.is_null() {
return parse_quote!(None);
Expand Down

0 comments on commit a7b9099

Please sign in to comment.