Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 28 additions & 22 deletions src/binder/create_view.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,30 +24,36 @@ impl<T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'_, '_, T, A>
let view_name = Arc::new(lower_case_name(name)?);
let mut plan = self.bind_query(query)?;

if !columns.is_empty() {
let mapping_schema = plan.output_schema();
let exprs = columns
.iter()
.enumerate()
.map(|(i, ident)| {
let mapping_column = &mapping_schema[i];
let mut column = ColumnCatalog::new(
lower_ident(ident),
mapping_column.nullable(),
mapping_column.desc().clone(),
);
column.set_ref_table(view_name.clone(), Ulid::new(), true);
let mapping_schema = plan.output_schema();

ScalarExpression::Alias {
expr: Box::new(ScalarExpression::ColumnRef(mapping_column.clone())),
alias: AliasType::Expr(Box::new(ScalarExpression::ColumnRef(
ColumnRef::from(column),
))),
}
})
.collect_vec();
plan = self.bind_project(plan, exprs)?;
let exprs = if columns.is_empty() {
Box::new(
mapping_schema
.iter()
.map(|column| column.name().to_string()),
) as Box<dyn Iterator<Item = String>>
} else {
Box::new(columns.iter().map(lower_ident)) as Box<dyn Iterator<Item = String>>
}
.enumerate()
.map(|(i, column_name)| {
let mapping_column = &mapping_schema[i];
let mut column = ColumnCatalog::new(
column_name,
mapping_column.nullable(),
mapping_column.desc().clone(),
);
column.set_ref_table(view_name.clone(), Ulid::new(), true);

ScalarExpression::Alias {
expr: Box::new(ScalarExpression::ColumnRef(mapping_column.clone())),
alias: AliasType::Expr(Box::new(ScalarExpression::ColumnRef(ColumnRef::from(
column,
)))),
}
})
.collect_vec();
plan = self.bind_project(plan, exprs)?;

Ok(LogicalPlan::new(
Operator::CreateView(CreateViewOperator {
Expand Down
15 changes: 12 additions & 3 deletions src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ pub struct BinderContext<'a, T: Transaction> {
group_by_exprs: Vec<ScalarExpression>,
pub(crate) agg_calls: Vec<ScalarExpression>,
// join
using: HashSet<String>,
using: HashSet<ColumnRef>,

bind_step: QueryBindStep,
sub_queries: HashMap<QueryBindStep, Vec<SubQueryType>>,
Expand Down Expand Up @@ -295,8 +295,17 @@ impl<'a, T: Transaction> BinderContext<'a, T> {
}
}

pub fn add_using(&mut self, name: String) {
self.using.insert(name);
pub fn add_using(
&mut self,
join_type: JoinType,
left_expr: &ColumnRef,
right_expr: &ColumnRef,
) {
self.using.insert(if join_type.is_right() {
left_expr.clone()
} else {
right_expr.clone()
});
}

pub fn add_alias(
Expand Down
124 changes: 58 additions & 66 deletions src/binder/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ use crate::types::tuple::{Schema, SchemaRef};
use crate::types::value::Utf8Type;
use crate::types::{ColumnId, LogicalType};
use itertools::Itertools;
use sqlparser::ast::CharLengthUnits::Characters;
use sqlparser::ast::{
CharLengthUnits, Distinct, Expr, Ident, Join, JoinConstraint, JoinOperator, Offset,
OrderByExpr, Query, Select, SelectInto, SelectItem, SetExpr, SetOperator, SetQuantifier,
Expand Down Expand Up @@ -189,14 +188,14 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'
}
}

if left_cast.len() > 0 {
if !left_cast.is_empty() {
left_plan = LogicalPlan::new(
Operator::Project(ProjectOperator { exprs: left_cast }),
Childrens::Only(left_plan),
);
}

if right_cast.len() > 0 {
if !right_cast.is_empty() {
right_plan = LogicalPlan::new(
Operator::Project(ProjectOperator { exprs: right_cast }),
Childrens::Only(right_plan),
Expand Down Expand Up @@ -393,7 +392,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'
unreachable!()
}
}
_ => unimplemented!(),
table => return Err(DatabaseError::UnsupportedStmt(format!("{:#?}", table))),
};

Ok(plan)
Expand Down Expand Up @@ -517,8 +516,6 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'
}
continue;
}
let mut join_used = HashSet::with_capacity(self.context.using.len());

for (table_name, alias, _) in self.context.bind_table.keys() {
let schema_buf =
self.table_schema_buf.entry(table_name.clone()).or_default();
Expand All @@ -527,7 +524,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'
schema_buf,
&mut select_items,
alias.as_ref().unwrap_or(table_name).clone(),
Some(&mut join_used),
false,
)?;
}
}
Expand All @@ -540,7 +537,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'
schema_buf,
&mut select_items,
table_name,
None,
true,
)?;
}
};
Expand All @@ -555,58 +552,48 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'
schema_buf: &mut Option<SchemaOutput>,
exprs: &mut Vec<ScalarExpression>,
table_name: TableName,
mut join_used: Option<&mut HashSet<String>>,
is_qualified_wildcard: bool,
) -> Result<(), DatabaseError> {
let mut is_bound_alias = false;

let fn_used =
|column_name: &str, context: &BinderContext<T>, join_used: Option<&HashSet<_>>| {
context.using.contains(column_name)
&& matches!(join_used.map(|used| used.contains(column_name)), Some(true))
};
for (_, alias_expr) in context.expr_aliases.iter().filter(|(_, expr)| {
if let ScalarExpression::ColumnRef(col) = expr.unpack_alias_ref() {
let column_name = col.name();
let fn_not_on_using = |column: &ColumnRef| {
if context.using.is_empty() {
return Some(&table_name) == column.table_name();
}
is_qualified_wildcard
|| Some(&table_name) == column.table_name() && !context.using.contains(column)
};

if Some(&table_name) == col.table_name()
&& !fn_used(column_name, context, join_used.as_deref())
{
if let Some(used) = join_used.as_mut() {
used.insert(column_name.to_string());
let bound_alias = context
.expr_aliases
.iter()
.filter(|(_, expr)| {
if let ScalarExpression::ColumnRef(col) = expr.unpack_alias_ref() {
if fn_not_on_using(col) {
exprs.push(ScalarExpression::clone(expr));
return true;
}
return true;
}
}
false
}) {
is_bound_alias = true;
exprs.push(alias_expr.clone());
}
if is_bound_alias {
false
})
.count()
> 0;

if bound_alias {
return Ok(());
}

let mut source = None;

source = context.table(table_name.clone())?.map(Source::Table);
if source.is_none() {
source = context.view(table_name)?.map(Source::View);
source = context.view(table_name.clone())?.map(Source::View);
}
for column in source
.ok_or(DatabaseError::SourceNotFound)?
.columns(schema_buf)
{
let column_name = column.name();

if fn_used(column_name, context, join_used.as_deref()) {
if !fn_not_on_using(column) {
continue;
}
let expr = ScalarExpression::ColumnRef(column.clone());

if let Some(used) = join_used.as_mut() {
used.insert(column_name.to_string());
}
exprs.push(expr);
exprs.push(ScalarExpression::ColumnRef(column.clone()));
}
Ok(())
}
Expand Down Expand Up @@ -654,9 +641,12 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'
self.extend(binder.context);

let on = match joint_condition {
Some(constraint) => {
self.bind_join_constraint(left.output_schema(), right.output_schema(), constraint)?
}
Some(constraint) => self.bind_join_constraint(
join_type,
left.output_schema(),
right.output_schema(),
constraint,
)?,
None => JoinCondition::None,
};

Expand Down Expand Up @@ -902,6 +892,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'

fn bind_join_constraint<'c>(
&mut self,
join_type: JoinType,
left_schema: &'c SchemaRef,
right_schema: &'c SchemaRef,
constraint: &JoinConstraint,
Expand Down Expand Up @@ -938,24 +929,25 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'
})
}
JoinConstraint::Using(idents) => {
fn find_column<'a>(schema: &'a Schema, name: &'a str) -> Option<&'a ColumnRef> {
schema.iter().find(|column| column.name() == name)
}

let mut on_keys: Vec<(ScalarExpression, ScalarExpression)> = Vec::new();
let fn_column = |schema: &Schema, name: &str| {
schema
.iter()
.find(|column| column.name() == name)
.map(|column| ScalarExpression::ColumnRef(column.clone()))
};

for ident in idents {
let name = lower_ident(ident);
if let (Some(left_column), Some(right_column)) = (
fn_column(left_schema, &name),
fn_column(right_schema, &name),
) {
on_keys.push((left_column, right_column));
} else {
return Err(DatabaseError::InvalidColumn("not found column".to_string()))?;
}
self.context.add_using(name);
let (Some(left_column), Some(right_column)) = (
find_column(left_schema, &name),
find_column(right_schema, &name),
) else {
return Err(DatabaseError::InvalidColumn("not found column".to_string()));
};
self.context.add_using(join_type, left_column, right_column);
on_keys.push((
ScalarExpression::ColumnRef(left_column.clone()),
ScalarExpression::ColumnRef(right_column.clone()),
));
}
Ok(JoinCondition::On {
on: on_keys,
Expand All @@ -970,15 +962,15 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'
let mut on_keys: Vec<(ScalarExpression, ScalarExpression)> = Vec::new();

for name in fn_names(left_schema).intersection(&fn_names(right_schema)) {
self.context.add_using(name.to_string());
if let (Some(left_column), Some(right_column)) = (
left_schema.iter().find(|column| column.name() == *name),
right_schema.iter().find(|column| column.name() == *name),
) {
on_keys.push((
ScalarExpression::ColumnRef(left_column.clone()),
ScalarExpression::ColumnRef(right_column.clone()),
));
let left_expr = ScalarExpression::ColumnRef(left_column.clone());
let right_expr = ScalarExpression::ColumnRef(right_column.clone());

self.context.add_using(join_type, left_column, right_column);
on_keys.push((left_expr, right_expr));
}
}
Ok(JoinCondition::On {
Expand Down
7 changes: 7 additions & 0 deletions src/planner/operator/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ pub enum JoinType {
Full,
Cross,
}

impl JoinType {
pub fn is_right(&self) -> bool {
matches!(self, JoinType::RightOuter)
}
}

#[derive(Debug, Clone, PartialEq, Eq, Hash, ReferenceSerialization)]
pub enum JoinCondition {
On {
Expand Down
Loading
Loading