Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion tests/test_parser_formatter_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,19 @@ def test_basic_e2e():
formatted_sql = formatter.format(parsed_ast)

# test our output is semantically equivalent to input using mo_sql_parsing
assert parse(formatted_sql) == parse(original_sql)
assert parse(formatted_sql) == parse(original_sql)


def test_subquery_e2e():
    """End-to-end fixture for a query with an ``IN`` subquery in its WHERE clause.

    The parse -> format -> compare round trip is currently commented out, so
    this test only defines the input SQL and asserts nothing.
    """
    # Written line by line so the exact text (leading/trailing newline
    # included) is explicit.
    subquery_sql = (
        "\n"
        "SELECT empno, firstnme, lastname, phoneno\n"
        "FROM employee\n"
        "WHERE workdept IN\n"
        "(SELECT deptno\n"
        "FROM department\n"
        "WHERE deptname = 'OPERATIONS')\n"
        "AND 1=1\n"
    )

    # TODO: enable the round-trip check below (presumably waiting on subquery
    # support in the parser/formatter -- confirm before enabling).
    # parsed_ast = parser.parse(subquery_sql)
    # formatted_sql = formatter.format(parsed_ast)
    # assert parse(formatted_sql) == parse(subquery_sql)
67 changes: 66 additions & 1 deletion tests/test_query_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,69 @@ def test_basic_format():
sql = formatter.format(ast)
sql = sql.strip()

assert normalize_sql(sql) == normalize_sql(expected_sql)
assert normalize_sql(sql) == normalize_sql(expected_sql)


def test_subquery_format():
    """Build the AST for a query whose WHERE clause contains an ``IN`` subquery.

    The query modelled is::

        SELECT empno, firstnme, lastname, phoneno
        FROM employee
        WHERE workdept IN (SELECT deptno FROM department
                           WHERE deptname = 'OPERATIONS')
          AND 1 = 1

    The formatter call and the final comparison are currently commented out,
    so this test only checks that the AST can be constructed.
    """
    # Leaf nodes: tables first, then the outer and inner columns.
    employee, department = TableNode("employee"), TableNode("department")
    empno, firstnme, lastname, phoneno, workdept = (
        ColumnNode(name)
        for name in ("empno", "firstnme", "lastname", "phoneno", "workdept")
    )
    deptno, deptname = ColumnNode("deptno"), ColumnNode("deptname")

    # Outer SELECT / FROM.
    outer_select = SelectNode([empno, firstnme, lastname, phoneno])
    outer_from = FromNode([employee])

    # Inner query: SELECT deptno FROM department WHERE deptname = 'OPERATIONS'
    inner_query = QueryNode(
        _select=SelectNode([deptno]),
        _from=FromNode([department]),
        _where=WhereNode([OperatorNode(deptname, "=", LiteralNode("OPERATIONS"))]),
    )

    # Outer WHERE: workdept IN (<inner query>) AND 1 = 1
    in_subquery = OperatorNode(workdept, "IN", SubqueryNode(inner_query))
    always_true = OperatorNode(LiteralNode(1), "=", LiteralNode(1))

    ast = QueryNode(
        _select=outer_select,
        _from=outer_from,
        _where=WhereNode([OperatorNode(in_subquery, "AND", always_true)]),
    )

    # Desired canonical formatting; the current formatter may not support
    # this yet, which is why the comparison below is disabled.
    expected_sql = "\n".join([
        "",
        "SELECT empno, firstnme, lastname, phoneno",
        "FROM employee",
        "WHERE workdept IN (",
        "SELECT deptno",
        "FROM department",
        "WHERE deptname = 'OPERATIONS'",
        ")",
        "AND 1 = 1",
        "",
    ])

    # TODO: enable once the formatter can emit subqueries.
    # expected_sql = expected_sql.strip()
    # sql = formatter.format(ast)
    # sql = sql.strip()
    # assert normalize_sql(sql) == normalize_sql(expected_sql)
80 changes: 59 additions & 21 deletions tests/test_query_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from core.ast.node import (
QueryNode, SelectNode, FromNode, WhereNode, TableNode, ColumnNode,
LiteralNode, OperatorNode, FunctionNode, GroupByNode, HavingNode,
OrderByNode, OrderByItemNode, LimitNode, OffsetNode, JoinNode
OrderByNode, OrderByItemNode, LimitNode, OffsetNode, JoinNode, SubqueryNode
)
from core.ast.enums import JoinType, SortOrder
from data.queries import get_query
Expand Down Expand Up @@ -81,6 +81,64 @@ def test_basic_parse():
assert ast == expected_ast


def test_subquery_parse():
    """
    Parsing of a SQL query with a subquery in the WHERE clause (IN operator).

    Loads query 9 and builds the AST it is expected to parse into. The
    comparison against ``parser.parse`` is currently commented out, so this
    test only checks that the fixture loads and the expected AST can be built.
    """
    sql = get_query(9)['pattern']

    # Leaf nodes: tables first, then the outer and inner columns.
    employee, department = TableNode("employee"), TableNode("department")
    empno, firstnme, lastname, phoneno, workdept = (
        ColumnNode(name)
        for name in ("empno", "firstnme", "lastname", "phoneno", "workdept")
    )
    deptno, deptname = ColumnNode("deptno"), ColumnNode("deptname")

    # Outer SELECT / FROM.
    outer_select = SelectNode([empno, firstnme, lastname, phoneno])
    outer_from = FromNode([employee])

    # Inner query: SELECT deptno FROM department WHERE deptname = 'OPERATIONS'
    inner_query = QueryNode(
        _select=SelectNode([deptno]),
        _from=FromNode([department]),
        _where=WhereNode([OperatorNode(deptname, "=", LiteralNode("OPERATIONS"))]),
    )

    # Outer WHERE: workdept IN (<inner query>) AND 1=1
    in_subquery = OperatorNode(workdept, "IN", SubqueryNode(inner_query))
    always_true = OperatorNode(LiteralNode(1), "=", LiteralNode(1))

    expected_ast = QueryNode(
        _select=outer_select,
        _from=outer_from,
        _where=WhereNode([OperatorNode(in_subquery, "AND", always_true)]),
    )

    # TODO: enable once the parser output matches the expected AST.
    # qb_ast = parser.parse(sql)
    # assert qb_ast == expected_ast


def test_parse_1():
query = get_query(1)
sql = query['pattern']
Expand Down Expand Up @@ -161,26 +219,6 @@ def test_parse_2():
# assert isinstance(condition, OperatorNode)


def test_parse_3():
    """Smoke-test the parser on query 9.

    Only the ``parser.parse`` call is live, so the test passes as long as
    parsing does not raise; the structural checks on the WHERE clause are
    currently commented out.
    """
    sql = get_query(9)['pattern']

    qb_ast = parser.parse(sql)

    # TODO: enable the structural checks below.
    # assert isinstance(qb_ast, QueryNode)

    # The WHERE clause should hold an IN-with-subquery condition under AND.
    # where_clause = None
    # for child in qb_ast.children:
    #     if child.type == NodeType.WHERE:
    #         where_clause = child
    #         break

    # assert where_clause is not None
    # condition = next(iter(where_clause.children))
    # assert isinstance(condition, OperatorNode)
    # assert condition.name == "AND"


def test_parse_4():
query = get_query(12)
sql = query['pattern']
Expand Down