Skip to content

Commit

Permalink
Merge pull request #10 from amoffat/feature/fix-unqualified-columns
Browse files Browse the repository at this point in the history
Feature/fix unqualified columns
  • Loading branch information
amoffat committed Jul 16, 2023
2 parents d3b2279 + 9f75462 commit 4d3e998
Show file tree
Hide file tree
Showing 46 changed files with 524 additions and 238 deletions.
5 changes: 4 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,8 @@
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
"editor.rulers": [88],
"notebook.formatOnSave.enabled": true
"notebook.formatOnSave.enabled": true,
"python.linting.flake8Enabled": false,
"python.linting.mypyEnabled": true,
"python.linting.enabled": true
}
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
# Changelog

## 0.3.0 - 7/15/23

- Autofix non-qualified column names
- Bugfix where aliases composed of multiple columns would not correctly resolve
- Allow customization of SQL dialect placeholder format
- Bugfix to allow required conditions in either order (`column = :placeholder` or `:placeholder = column`)
- Add required whitespace to SQL grammars
- SQL pretty-printer should return everything on one line.
- Bugfix with Mysql prompt envelope mentioning sqlite

## 0.2.1 - 7/10/23

- Automated Github releases
Expand Down
24 changes: 23 additions & 1 deletion docs/source/reconstruction/sql.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,26 @@ will pass validation:

Only selected columns are examined for reconstruction in this soft-validation pass.
Other clauses of the query where columns are referenced, like the ``WHERE`` clause,
are left alone until the hard-validation pass, where they could cause a failure.
are left alone until the hard-validation pass, where they could cause a failure.

Fully-qualifying columns
************************

Despite efforts to convince the LLM to fully-qualify columns, it may still produce
queries that use columns that are not prefixed by a table name or table alias:

.. code-block:: sql

    SELECT u.email
    FROM users u
    WHERE id=:user_id

The LLM typically produces these queries when the column is unambiguous because a single
table is being selected. In these cases, the reconstructor will fully-qualify the column
based on the name of the selected table, so that the above query becomes:

.. code-block:: sql

    SELECT u.email
    FROM users u
    WHERE users.id=:user_id

23 changes: 20 additions & 3 deletions heimdallm/bifrost.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def traverse(
validator=validator.__class__.__name__,
)
try:
trusted_llm_output = self._try_validator(
trusted_llm_output, tree = self._try_validator(
log,
validator,
autofix,
Expand All @@ -131,6 +131,8 @@ def traverse(
log = log.bind(trusted=untrusted_llm_output)
log.info("Validation succeeded")

trusted_llm_output = self.post_transform(trusted_llm_output, tree)

return trusted_llm_output

def _try_validator(
Expand All @@ -140,7 +142,7 @@ def _try_validator(
autofix: bool,
untrusted_llm_output: str,
tree: ParseTree,
) -> str:
) -> tuple[str, ParseTree]:
"""Attempt validation with an individual constraint validator."""

if autofix:
Expand All @@ -163,7 +165,22 @@ def _try_validator(
validator.validate(self, untrusted_llm_output, tree)
log.info("Validation succeeded")

return untrusted_llm_output
return untrusted_llm_output, tree

def post_transform(self, trusted_llm_output: str, tree: ParseTree) -> str:
    """
    Subclass hook that runs on the trusted output after validation has succeeded.

    Use it for adjustments that cannot happen during :doc:`/reconstruction`
    because the adjusted text would no longer be accepted by the grammar.

    For example, swapping the generic ``:placeholder`` for a SQL-specific
    placeholder field (e.g. ``%(placeholder)s``) conflicts with the grammar, so it
    has to happen in a separate step, once the input has been reconstructed and
    the constraint validators have been satisfied.

    :param trusted_llm_output: The validated output.
    :param tree: The parse tree associated with that output.
    :return: The output to hand back to the caller. This base implementation
        returns ``trusted_llm_output`` unchanged.
    """
    return trusted_llm_output

def parse(self, untrusted_llm_output: str) -> ParseTree:
"""Converts the :term:`LLM` output into a parse tree. Override it in a subclass
Expand Down
38 changes: 36 additions & 2 deletions heimdallm/bifrosts/sql/bifrost.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Callable, Sequence, Union
from typing import TYPE_CHECKING, Callable, Sequence, Union, cast

import lark
from lark import Lark, ParseTree
from lark import Lark, ParseTree, Token
from lark.exceptions import VisitError

from heimdallm.bifrost import Bifrost as _BaseBifrost
Expand Down Expand Up @@ -117,6 +117,40 @@ def parse(grammar: Lark, untrusted_query: str) -> ParseTree:

return parse

@classmethod
def placeholder(cls, name: str) -> str:
    """
    Render the named-parameter placeholder for a specific database and library.

    Databases and libraries disagree on how a named parameter is written: sqlite
    uses ``:param`` while mysql uses ``%(param)s``. This default produces the
    colon-prefixed form.

    :param name: The name of the placeholder to be replaced.
    :return: The db-specific placeholder.

    :meta private:
    """
    return f":{name}"

def post_transform(self, trusted_llm_output: str, tree: ParseTree) -> str:
    """
    Rewrite every generic ``:name`` placeholder into the form produced by
    :meth:`placeholder`.

    Each ``placeholder`` node in ``tree`` supplies the character span
    (``meta.start_pos`` / ``meta.end_pos``) to replace and, through its first
    child, the placeholder name.

    :param trusted_llm_output: The validated query text that the spans index into.
    :param tree: Parse tree of ``trusted_llm_output``. Its nodes must carry
        position metadata.
    :return: The query with every placeholder span replaced.
    """
    # Key replacements by span so that a placeholder reported more than once is
    # rewritten only once. NOTE(review): duplicates look possible if the tree
    # still holds alternative derivations from ``ambiguity="explicit"`` --
    # confirm. Rewriting the same span twice would corrupt the output, because
    # the second slice would land on already-replaced text.
    replacements: dict[tuple[int, int], str] = {}
    for node in tree.find_data("placeholder"):
        name = cast(Token, node.children[0]).value
        span = (node.meta.start_pos, node.meta.end_pos)
        replacements[span] = self.placeholder(name)

    # Stitch the result left-to-right in a single pass: untouched text between
    # spans is copied verbatim, so offsets always refer to the original string
    # and no index bookkeeping is needed.
    pieces: list[str] = []
    cursor = 0
    for (start, end), replacement in sorted(replacements.items()):
        pieces.append(trusted_llm_output[cursor:start])
        pieces.append(replacement)
        cursor = end
    pieces.append(trusted_llm_output[cursor:])
    return "".join(pieces)

@staticmethod
@abstractmethod
def build_grammar() -> Lark:
Expand Down
4 changes: 2 additions & 2 deletions heimdallm/bifrosts/sql/envelopes/sql/mysql/select.j2
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Consider the following Mysql 8.0 SQL schema:
{%- endblock %}

{% block requester_context -%}
A requester has written a free-form query that they want translated into a SQL query, compatible with sqlite that operates on the above database schema.
A requester has written a free-form query that they want translated into a SQL query, compatible with Mysql 8.0 that operates on the above database schema.
{%- endblock %}

{% block requester_identities -%}
Expand All @@ -18,7 +18,7 @@ Assume that the requester that wrote the query has unrestricted access to all da
{%- endblock %}

{% block current_time -%}
If the current time is needed, assume that the placeholder `:timestamp` holds the current unix timestamp.
If the current time is needed, assume that the placeholder `:timestamp` holds the current unix timestamp as an integer. Use the necessary functions to convert it to a datetime in order to compare it to datetime columns.
{%- endblock %}

{% block delimiters -%}
Expand Down
5 changes: 5 additions & 0 deletions heimdallm/bifrosts/sql/mysql/select/bifrost.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,15 @@ def build_grammar() -> Lark:
grammar = Lark(
ambiguity="explicit",
maybe_placeholders=False,
propagate_positions=True,
grammar=h,
)
return grammar

@classmethod
def reserved_keywords(cls) -> set[str]:
    """
    Return the reserved keywords exposed by ``presets.reserved_keywords``.

    :return: The ``presets.reserved_keywords`` set itself, not a copy.
    """
    # The first parameter of a classmethod is the class, so it is named ``cls``.
    return presets.reserved_keywords

@classmethod
def placeholder(cls, name: str) -> str:
    """
    Render ``name`` as a MySQL-style ``%(name)s`` named-parameter placeholder.

    :param name: The name of the placeholder to be replaced.
    :return: The placeholder text, e.g. ``%(user_id)s`` for ``user_id``.
    """
    return "%(" + name + ")s"
62 changes: 32 additions & 30 deletions heimdallm/bifrosts/sql/mysql/select/grammar.lark
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ unions : union+
union : UNION (ALL | DISTINCT)? select_statement

select_statement : \
SELECT DISTINCT? selected_columns \
FROM selected_table \
joins? \
where_clause? \
group_by_clause? \
having_clause? \
order_by_clause? \
SELECT _WS DISTINCT? selected_columns \
_WS FROM selected_table \
(_WS joins)? \
(_WS where_clause)? \
(_WS group_by_clause)? \
(_WS having_clause)? \
(_WS order_by_clause)? \
limit_placeholder \
SEMICOLON?

Expand All @@ -37,7 +37,7 @@ selected_columns : selected_column ("," selected_column)*
// a selected column can be a value, not just a column name. a value encapsulates
// fully qualified column names, and also functions that use column names
selected_column : aliased_column | COUNT_STAR | value | ALL_COLUMNS
aliased_column : (value | COUNT_STAR) as column_alias
aliased_column : (value | COUNT_STAR) as generic_alias
// this is a "fully-qualified" column name, meaning it just has a table name prefixing
// it. keep in mind that the table may be an alias.
fq_column : table_name "." column_name
Expand All @@ -48,7 +48,7 @@ column_alias : generic_alias
generic_alias : quoted_identifier | unquoted_identifier

joins : join+
join : join_type joined_table "on"i join_condition ("and"i join_condition)*
join : join_type _WS joined_table _WS "on"i join_condition (_WS "and"i join_condition)*
// note we do not allow outer joins. this is because we cannot control
// the rows that are returned by the outer join, and we cannot depend on the
// join conditions to restrict the returned rows.
Expand All @@ -61,22 +61,22 @@ join_condition : connecting_join_condition | required_comparison
connecting_join_condition : fq_column EQUI_JOIN value
joined_table : aliased_table | table_name

group_by_clause : GROUP_BY group_by_column ("," group_by_column)*
group_by_clause : GROUP_BY _WS group_by_column ("," group_by_column)*
group_by_column : value

having_clause : HAVING having_condition
having_clause : HAVING _WS having_condition
having_condition : value comparison value

order_by_clause : ORDER_BY order_column ("," order_column)*
order_column : (COUNT_STAR | value) SORT_ORDER?
order_by_clause : ORDER_BY _WS order_column ("," order_column)*
order_column : (COUNT_STAR | value) (_WS SORT_ORDER)?
SORT_ORDER : ASC | DESC

limit_clause : LIMIT (((offset ",")? limit) | (limit OFFSET offset))
limit_clause : LIMIT (_WS ((offset ",")? limit) | (limit _WS OFFSET _WS offset))
limit : NUMBER
offset : NUMBER

where_clause : WHERE where_conditions
where_conditions : where_condition (WHERE_TYPE where_condition)*
where_clause : WHERE _WS where_conditions
where_conditions : where_condition (_WS WHERE_TYPE _WS where_condition)*
where_condition : \
| required_comparison
| relational_comparison
Expand All @@ -86,7 +86,9 @@ where_condition : \

// a required comparison is a parameterized, equality-based comparison that
// enforces a constraint on the returned rows.
required_comparison.1 : (fq_column | column_alias) "=" placeholder
required_comparison.1 : lhs_req_comparison | rhs_req_comparison
lhs_req_comparison : (fq_column | column_alias) "=" placeholder
rhs_req_comparison : placeholder "=" (fq_column | column_alias)
relational_comparison : value comparison value
comparison : \
EQ
Expand All @@ -99,9 +101,9 @@ comparison : \
| IS NOT?
| SOUNDS_LIKE
| NOT? (REGEXP | RLIKE)
in_comparison : value NOT? IN in_list
in_comparison : value _WS NOT? _WS IN _WS in_list
in_list : "(" value ("," value)* ")"
between_comparison : value NOT? BETWEEN value AND value
between_comparison : value _WS NOT? _WS BETWEEN _WS value _WS AND _WS value

// for everywhere but the SELECT clause, because that's where column aliases
// are declared, so we cannot use this there
Expand All @@ -120,7 +122,7 @@ wrapped_value : LPAREN value RPAREN

// a function with any number of arguments
function : function_name "(" \
AGG_FN_MODIFIER? (value ("," value)*)? \
(AGG_FN_MODIFIER _WS)? (value ("," value)*)? \
")"
function_name : /[a-zA-Z_]+/

Expand All @@ -130,7 +132,6 @@ function_name : /[a-zA-Z_]+/
// a placeholder for a value passed in as a parameter at query execution time
placeholder: ":" IDENTIFIER


SELECT : "select"i
FROM : "from"i
DISTINCT : "distinct"i
Expand All @@ -140,22 +141,22 @@ SEMICOLON : ";"
ALL_COLUMNS : "*"
WHERE : "where"i
WHERE_TYPE : AND | OR
ORDER_BY : "order"i WS "by"i
ORDER_BY : "order"i _WS "by"i
LIMIT : "limit"i
OFFSET : "offset"i
GROUP_BY : "group"i WS "by"i
GROUP_BY : "group"i _WS "by"i
HAVING : "having"i

// this is the only way that a query can have a "*" column spec, because it
// doesn't reveal any information that would be restricted.
COUNT_STAR : "count"i "(" AGG_FN_MODIFIER? ("*" | "1") ")"
COUNT_STAR : "count"i "(" (AGG_FN_MODIFIER _WS)? ("*" | "1") ")"
AGG_FN_MODIFIER : DISTINCT | ALL

INNER_JOIN : ("inner"i WS)? "join"i
CROSS_JOIN : "cross"i WS "join"i
NATURAL_JOIN : "natural"i WS (("left"i | "right"i | "full"i | "inner"i) WS)? "join"i
OUTER_JOIN : ("left"i | "right"i | "full"i) WS ("outer"i WS)? "join"i
NATURAL_OUTER_JOIN : "natural"i WS (("left"i | "right"i | "full"i) WS)? "outer"i WS "join"i
INNER_JOIN : ("inner"i _WS)? "join"i
CROSS_JOIN : "cross"i _WS "join"i
NATURAL_JOIN : "natural"i _WS (("left"i | "right"i | "full"i | "inner"i) _WS)? "join"i
OUTER_JOIN : ("left"i | "right"i | "full"i) _WS ("outer"i _WS)? "join"i
NATURAL_OUTER_JOIN : "natural"i _WS (("left"i | "right"i | "full"i) _WS)? "outer"i _WS "join"i

ASC : "asc"i
DESC : "desc"i
Expand All @@ -181,7 +182,7 @@ GTE : ">="
IS : "is"i
BETWEEN : "between"i
IN : "in"i
SOUNDS_LIKE : "sounds"i WS "like"i
SOUNDS_LIKE : "sounds"i _WS "like"i
LIKE : "like"i
NOT : "not"i
REGEXP : "regexp"i
Expand Down Expand Up @@ -211,6 +212,7 @@ ARITH_OP : \

// NOTE this is dialect specific
ESCAPED_STRING : "'" _STRING_ESC_INNER "'" | "\"" _STRING_ESC_INNER "\""
_WS : WS

%import common.NUMBER
%import common.DIGIT
Expand Down

0 comments on commit 4d3e998

Please sign in to comment.