feat: helper functions for RLS #19055

Merged
merged 9 commits into from
Mar 11, 2022
Changes from 4 commits
177 changes: 177 additions & 0 deletions superset/sql_parse.py
@@ -27,8 +27,10 @@
IdentifierList,
Parenthesis,
remove_quotes,
Statement,
Token,
TokenList,
Where,
)
from sqlparse.tokens import (
CTE,
@@ -458,3 +460,178 @@ def validate_filter_clause(clause: str) -> None:
)
if open_parens > 0:
raise QueryClauseValidationException("Unclosed parenthesis in filter clause")


def has_table_query(statement: Statement) -> bool:
Member

@betodealmeida there's also this example which has logic for identifying tables.

Member Author

I don't remember the details, but I've had issues with that example code before. I think it failed to identify table names when they were considered keywords (even though the example calls it out).

    """
    Return if a statement has a query reading from a table.

        >>> has_table_query(sqlparse.parse("COUNT(*)")[0])
        False
        >>> has_table_query(sqlparse.parse("SELECT * FROM table")[0])
        True

    Note that queries reading from constant values return false:

        >>> has_table_query(sqlparse.parse("SELECT * FROM (SELECT 1)")[0])
        False

    """
    seen_source = False
    tokens = statement.tokens[:]
    while tokens:
Member

You likely could just do for token in stmt.flatten(): and remove the logic from lines 483–485.

Member Author

.flatten() is a bit different in that it returns the leaf nodes only, converting an identifier into 1+ Name tokens:

>>> list(sqlparse.parse('SELECT * FROM my_table')[0].flatten())
[<DML 'SELECT' at 0x10FF019A0>, <Whitespace ' ' at 0x10FF01D00>, <Wildcard '*' at 0x10FF01D60>, <Whitespace ' ' at 0x10FF01DC0>, <Keyword 'FROM' at 0x10FF01E20>, <Whitespace ' ' at 0x10FF01E80>, <Name 'my_tab...' at 0x10FF01EE0>]

Since I'm looking for identifiers after a FROM or JOIN I thought it was easier to implement a traversal logic that actually inspects the parents, not just the leaves.
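The difference between visiting only leaves and visiting parent groups as well can be illustrated without sqlparse. The sketch below uses a toy `Node` class as a stand-in for a parse-tree node; `Node`, `flatten` and `walk` are hypothetical names for this illustration and are not sqlparse's real classes or methods.

```python
from dataclasses import dataclass, field
from typing import Iterator, List


@dataclass
class Node:
    """Toy stand-in for a parse-tree node (not sqlparse's real classes)."""

    kind: str
    value: str = ""
    children: List["Node"] = field(default_factory=list)


def flatten(node: Node) -> Iterator[Node]:
    """Yield leaf nodes only, the way a flattening traversal would."""
    if not node.children:
        yield node
    else:
        for child in node.children:
            yield from flatten(child)


def walk(node: Node) -> Iterator[Node]:
    """Yield every node, groups included, in breadth-first order."""
    queue = [node]
    while queue:
        current = queue.pop(0)
        yield current
        queue.extend(current.children)


# Rough shape of "SELECT * FROM my_table" (whitespace omitted for brevity).
statement = Node(
    "Statement",
    children=[
        Node("DML", "SELECT"),
        Node("Wildcard", "*"),
        Node("Keyword", "FROM"),
        Node("Identifier", children=[Node("Name", "my_table")]),
    ],
)

leaf_kinds = [node.kind for node in flatten(statement)]
all_kinds = [node.kind for node in walk(statement)]
```

In this toy tree the `Identifier` group only shows up in `all_kinds`; the leaf-only traversal sees just the `Name` underneath it, which is the behavior described in the comment above.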

        token = tokens.pop(0)
        if isinstance(token, TokenList):
            tokens.extend(token.tokens)

        if token.ttype == Keyword and token.value.lower() in ("from", "join"):
            seen_source = True
        elif seen_source and (
Member

The challenge here is that there's no strong connection ensuring that the consecutive (or near-consecutive) tokens are the ones being identified here. I guess the question is how robust we want this logic to be. The proposed solution may well suffice.

The correct way of doing this is more of a tree traversal (as opposed to a flattened list), where one checks the next token (which could be a group) after the FROM or JOIN keyword and iterates from there.

My sense is that this can likely be addressed later. We probably need to clean up the sqlparse logic, junking it completely in favor of something else, given that the package seems to be somewhat on life support.

Member Author

Yeah, in the insert_rls function I had to implement tree traversal to get it right. Let me give it a try rewriting this one.

Member Author

@john-bodley I reimplemented it following the same logic as insert_rls (recursive tree traversal instead of flattening).
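The reimplemented function is not part of the four commits shown here, so the following is only a rough sketch of what a recursive traversal of this kind could look like. It assumes a toy representation in which tokens are plain strings and a parenthesized group is a nested list; `reads_from_table` is a hypothetical name, and real sqlparse tokens carry types and whitespace that this ignores.

```python
from typing import List, Union

# Toy parse tree: a token is a string, a group (e.g. a subquery) is a nested list.
Tree = List[Union[str, "Tree"]]


def reads_from_table(tokens: Tree) -> bool:
    """Return True if a FROM/JOIN is directly followed by a plain name."""
    seen_source = False
    for token in tokens:
        if isinstance(token, list):
            # Recurse into the group; a subquery may itself read from a table.
            if reads_from_table(token):
                return True
            # The group was the source, and it did not read from a table.
            seen_source = False
        elif token.upper() in ("FROM", "JOIN"):
            seen_source = True
        elif seen_source:
            # A plain name right after FROM/JOIN is treated as a table.
            return True
    return False


constant = ["SELECT", "*", "FROM", ["SELECT", "1"]]
direct = ["SELECT", "*", "FROM", "my_table"]
joined = ["SELECT", "a", "FROM", ["SELECT", "1", "AS", "a"], "JOIN", "my_table"]
nested = ["SELECT", "*", "FROM", ["SELECT", "*", "FROM", "my_table"]]
```

Because each level only looks at its own direct children, the token following FROM or JOIN is always the one that is checked, which addresses the robustness concern raised in this thread.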

            isinstance(token, sqlparse.sql.Identifier) or token.ttype == Keyword
        ):
            return True
        elif seen_source and token.ttype not in (Whitespace, Punctuation):
            seen_source = False

    return False


def add_table_name(rls: TokenList, table: str) -> None:
    """
    Modify an RLS expression ensuring columns are fully qualified.
    """
    tokens = rls.tokens[:]
    while tokens:
Member

You likely could use flatten here. It uses a generator so likely a copy should be made given you're mutating the tokens, i.e.,

for token in list(rls.flatten()):
    if imt(token, i=Identifier) and token.get_parent_name() is None:
        ...

Member Author

Same issue, if we call .flatten() we would never get an Identifier.

        token = tokens.pop(0)

        if isinstance(token, Identifier) and token.get_parent_name() is None:
            token.tokens = [
                Token(Name, table),
                Token(Punctuation, "."),
                Token(Name, token.get_name()),
            ]
        elif isinstance(token, TokenList):
            tokens.extend(token.tokens)


class InsertRLSState(str, Enum):
    """
    State machine that scans for WHERE and ON clauses referencing tables.
    """

    SCANNING = "SCANNING"
    SEEN_SOURCE = "SEEN_SOURCE"
    FOUND_TABLE = "FOUND_TABLE"
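The three states can be traced on a simplified input. The sketch below is an illustration only: it assumes the SQL is split on whitespace into plain words, compares table names exactly, does not recurse into subqueries, and does not rewrite anything. `State` and `trace_states` are hypothetical names for this example.

```python
from enum import Enum
from typing import List


class State(str, Enum):
    SCANNING = "SCANNING"
    SEEN_SOURCE = "SEEN_SOURCE"
    FOUND_TABLE = "FOUND_TABLE"


def trace_states(words: List[str], table: str) -> List[str]:
    """Return the state the scanner is in after consuming each word."""
    state = State.SCANNING
    trace = []
    for word in words:
        if word.upper() in ("FROM", "JOIN"):
            # A source keyword: the next word may name a table.
            state = State.SEEN_SOURCE
        elif state == State.SEEN_SOURCE:
            # Either the table we care about, or some other source.
            state = State.FOUND_TABLE if word == table else State.SCANNING
        elif state == State.FOUND_TABLE:
            # This is the point where a predicate would be added to a
            # WHERE/ON clause (or a new WHERE inserted); then resume scanning.
            state = State.SCANNING
        trace.append(state.value)
    return trace


matching = trace_states("SELECT * FROM t WHERE 1=1".split(), "t")
other = trace_states("SELECT * FROM t WHERE 1=1".split(), "u")
```

For the matching table the scanner passes through `FOUND_TABLE` exactly once, right before the `WHERE` keyword; for a different table it never leaves the `SCANNING`/`SEEN_SOURCE` pair.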


def matches_table_name(token: Token, table: str) -> bool:
    """
    Return whether the token matches the given table name.

    A table should be represented as an identifier, but due to sqlparse's aggressive list
    of keywords (spanning multiple dialects) often it gets classified as a keyword.
    """
    candidate = token.value

    # match from right to left, splitting on the period, eg, schema.table == table
    for left, right in zip(candidate.split(".")[::-1], table.split(".")[::-1]):
        if left != right:
            return False

    return True
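The right-to-left comparison can be tried on plain strings. This is a minimal sketch of the same matching idea, with `matches` as a hypothetical string-only helper rather than the function in the diff (which takes a token).

```python
def matches(candidate: str, table: str) -> bool:
    """Compare dotted names from the rightmost part towards the left."""
    candidate_parts = candidate.split(".")[::-1]
    table_parts = table.split(".")[::-1]
    # zip stops at the shorter list, so only the shared trailing parts
    # (table name, then schema, ...) are compared.
    return all(left == right for left, right in zip(candidate_parts, table_parts))


qualified_vs_bare = matches("schema.table_name", "table_name")
bare_vs_qualified = matches("table_name", "schema.table_name")
different_schema = matches("s1.table_name", "s2.table_name")
different_table = matches("other_table", "table")
```

Since `zip` truncates to the shorter name, a bare name matches a qualified one in either direction, while names that disagree on any shared part do not match.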


def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList:
    """
    Update a statement in place applying an RLS associated with a given table.
    """
    # make sure the identifier has the table name
    add_table_name(rls, table)

    state = InsertRLSState.SCANNING
    for token in token_list.tokens:

        # Recurse into child token list
        if isinstance(token, TokenList):
            i = token_list.tokens.index(token)
            token_list.tokens[i] = insert_rls(token, table, rls)

        # Found a source keyword (FROM/JOIN)
        if token.ttype == Keyword and token.value.lower() in ("from", "join"):
            state = InsertRLSState.SEEN_SOURCE

        # Found identifier/keyword after FROM/JOIN, test for table
        elif state == InsertRLSState.SEEN_SOURCE and (
            isinstance(token, Identifier) or token.ttype == Keyword
        ):
            if matches_table_name(token, table):
                state = InsertRLSState.FOUND_TABLE

                # found table at the end of the statement; append a WHERE clause
                if token == token_list[-1]:
                    token_list.tokens.extend(
                        [
                            Token(Whitespace, " "),
                            Where(
                                [Token(Keyword, "WHERE"), Token(Whitespace, " "), rls]
                            ),
                        ]
                    )
                    return token_list

        # Found WHERE clause, insert RLS if not present
        elif state == InsertRLSState.FOUND_TABLE and isinstance(token, Where):
            if str(rls) not in {str(t) for t in token.tokens}:
                token.tokens.extend(
                    [
                        Token(Whitespace, " "),
                        Token(Keyword, "AND"),
                        Token(Whitespace, " "),
                    ]
                    + rls.tokens
                )
            state = InsertRLSState.SCANNING

        # Found ON clause, insert RLS if not present
        elif (
            state == InsertRLSState.FOUND_TABLE
            and token.ttype == Keyword
            and token.value.upper() == "ON"
        ):
            i = token_list.tokens.index(token)
            token.parent.tokens[i + 1 : i + 1] = [
                Token(Whitespace, " "),
                rls,
                Token(Whitespace, " "),
                Token(Keyword, "AND"),
            ]
            state = InsertRLSState.SCANNING

        # Found table but no WHERE clause found, insert one
        elif state == InsertRLSState.FOUND_TABLE and token.ttype != Whitespace:
            i = token_list.tokens.index(token)

            # Left pad with space, if needed
            if i > 0 and token_list.tokens[i - 1].ttype != Whitespace:
                token_list.tokens.insert(i, Token(Whitespace, " "))
                i += 1

            # Insert predicate
            token_list.tokens.insert(
                i, Where([Token(Keyword, "WHERE"), Token(Whitespace, " "), rls]),
            )

            # Right pad with space, if needed
Member

why does sqlparse even tokenize whitespace?

Member Author

I think it's because it makes it easier to convert the parse tree back to a string. Not sure.
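If that guess is right, the benefit is a lossless round trip: joining the tokens reproduces the original text exactly. The sketch below shows the idea with a toy regex tokenizer; `tokenize` is a hypothetical helper and is not how sqlparse is implemented.

```python
import re
from typing import List


def tokenize(sql: str) -> List[str]:
    """Split into whitespace runs, word runs, or single other characters."""
    return re.findall(r"\s+|\w+|[^\w\s]", sql)


sql = "SELECT *  FROM my_table\nWHERE id = 42"
tokens = tokenize(sql)

# Keeping whitespace tokens makes the round trip exact.
round_trip = "".join(tokens)

# Dropping them loses the original spacing and line breaks.
without_whitespace = "".join(t for t in tokens if not t.isspace())
```

Here `round_trip` equals the input, including the double space and the newline, whereas `without_whitespace` cannot be turned back into the original statement without guessing where the spaces were.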

            if (
                i < len(token_list.tokens) - 2
                and token_list.tokens[i + 2] != Whitespace
            ):
                token_list.tokens.insert(i + 1, Token(Whitespace, " "))

            state = InsertRLSState.SCANNING

        # Found nothing, leaving source
        elif state == InsertRLSState.SEEN_SOURCE and token.ttype != Whitespace:
            state = InsertRLSState.SCANNING

    return token_list
158 changes: 156 additions & 2 deletions tests/unit_tests/sql_parse_tests.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.

# pylint: disable=invalid-name
# pylint: disable=invalid-name, too-many-lines

import unittest
from typing import Set
@@ -25,6 +25,8 @@

from superset.exceptions import QueryClauseValidationException
from superset.sql_parse import (
has_table_query,
insert_rls,
ParsedQuery,
strip_comments_from_sql,
Table,
@@ -1111,7 +1113,8 @@ def test_sqlparse_formatting():

"""
assert sqlparse.format(
"SELECT extract(HOUR from from_unixtime(hour_ts) AT TIME ZONE 'America/Los_Angeles') from table",
"SELECT extract(HOUR from from_unixtime(hour_ts) "
"AT TIME ZONE 'America/Los_Angeles') from table",
reindent=True,
) == (
"SELECT extract(HOUR\n from from_unixtime(hour_ts) "
@@ -1189,3 +1192,154 @@ def test_sqlparse_issue_652():
    stmt = sqlparse.parse(r"foo = '\' AND bar = 'baz'")[0]
    assert len(stmt.tokens) == 5
    assert str(stmt.tokens[0]) == "foo = '\\'"


@pytest.mark.parametrize(
"sql,expected",
[
("SELECT * FROM table", True),
("SELECT a FROM (SELECT 1 AS a) JOIN (SELECT * FROM table)", True),
("(SELECT COUNT(DISTINCT name) AS foo FROM birth_names)", True),
("COUNT(*)", False),
("SELECT a FROM (SELECT 1 AS a)", False),
("SELECT a FROM (SELECT 1 AS a) JOIN table", True),
("SELECT * FROM (SELECT 1 AS foo, 2 AS bar) ORDER BY foo ASC, bar", False),
("SELECT * FROM other_table", True),
],
)
def test_has_table_query(sql: str, expected: bool) -> None:
    """
    Test if a given statement queries a table.

    This is used to prevent ad-hoc metrics from querying unauthorized tables, bypassing
    row-level security.
    """
    statement = sqlparse.parse(sql)[0]
    assert has_table_query(statement) == expected


@pytest.mark.parametrize(
"sql,table,rls,expected",
[
# append RLS to an existing WHERE clause
(
"SELECT * FROM other_table WHERE 1=1",
"other_table",
"id=42",
"SELECT * FROM other_table WHERE 1=1 AND other_table.id=42",
),
# "table" is a reserved word; since sqlparse is too aggressive when characterizing
# reserved words we need to support them even when not quoted
(
"SELECT * FROM table WHERE 1=1",
"table",
"id=42",
"SELECT * FROM table WHERE 1=1 AND table.id=42",
),
# RLS applies to a different table
(
"SELECT * FROM table WHERE 1=1",
"other_table",
"id=42",
"SELECT * FROM table WHERE 1=1",
),
(
"SELECT * FROM other_table WHERE 1=1",
"table",
"id=42",
"SELECT * FROM other_table WHERE 1=1",
),
# insert the WHERE clause if there isn't one
(
"SELECT * FROM table",
"table",
"id=42",
"SELECT * FROM table WHERE table.id=42",
),
(
"SELECT * FROM other_table",
"other_table",
"id=42",
"SELECT * FROM other_table WHERE other_table.id=42",
),
(
"SELECT * FROM table ORDER BY id",
"table",
"id=42",
"SELECT * FROM table WHERE table.id=42 ORDER BY id",
),
# do not add RLS if already present...
(
"SELECT * FROM table WHERE 1=1 AND table.id=42",
"table",
"id=42",
"SELECT * FROM table WHERE 1=1 AND table.id=42",
),
# ...but when in doubt add it
(
"SELECT * FROM table WHERE 1=1 AND id=42",
"table",
"id=42",
"SELECT * FROM table WHERE 1=1 AND id=42 AND table.id=42",
),
# test with joins
(
"SELECT * FROM table JOIN other_table ON table.id = other_table.id",
"other_table",
"id=42",
(
"SELECT * FROM table JOIN other_table ON other_table.id=42 "
"AND table.id = other_table.id"
),
),
# test with inner selects
(
"SELECT * FROM (SELECT * FROM other_table)",
"other_table",
"id=42",
"SELECT * FROM (SELECT * FROM other_table WHERE other_table.id=42)",
),
# union
(
"SELECT * FROM table UNION ALL SELECT * FROM other_table",
"table",
"id=42",
"SELECT * FROM table WHERE table.id=42 UNION ALL SELECT * FROM other_table",
),
(
"SELECT * FROM table UNION ALL SELECT * FROM other_table",
"other_table",
"id=42",
(
"SELECT * FROM table UNION ALL "
"SELECT * FROM other_table WHERE other_table.id=42"
),
),
# fully qualified table names
(
"SELECT * FROM schema.table_name",
"table_name",
"id=42",
"SELECT * FROM schema.table_name WHERE table_name.id=42",
),
(
"SELECT * FROM schema.table_name",
"schema.table_name",
"id=42",
"SELECT * FROM schema.table_name WHERE schema.table_name.id=42",
),
(
"SELECT * FROM table_name",
"schema.table_name",
"id=42",
"SELECT * FROM table_name WHERE schema.table_name.id=42",
),
],
)
def test_insert_rls(sql, table, rls, expected) -> None:
    """
    Insert into a statement a given RLS condition associated with a table.
    """
    statement = sqlparse.parse(sql)[0]
    condition = sqlparse.parse(rls)[0]
    assert str(insert_rls(statement, table, condition)).strip() == expected.strip()