From 8234395466d0edaead3c36b57fd0b81621cc1c5c Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Fri, 11 Mar 2022 14:47:11 -0800 Subject: [PATCH] feat: helper functions for RLS (#19055) * feat: helper functions for RLS * Add function to inject RLS * Add UNION tests * Add tests for schema * Add more tests; cleanup * has_table_query via tree traversal * Wrap existing predicate in parenthesis * Clean up logic * Improve table matching --- superset/sql_parse.py | 202 +++++++++++++++++++++++ tests/unit_tests/sql_parse_tests.py | 247 +++++++++++++++++++++++++++- 2 files changed, 447 insertions(+), 2 deletions(-) diff --git a/superset/sql_parse.py b/superset/sql_parse.py index b5b614cf25ac..f5523bab71e8 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -29,6 +29,7 @@ remove_quotes, Token, TokenList, + Where, ) from sqlparse.tokens import ( CTE, @@ -458,3 +459,204 @@ def validate_filter_clause(clause: str) -> None: ) if open_parens > 0: raise QueryClauseValidationException("Unclosed parenthesis in filter clause") + + +class InsertRLSState(str, Enum): + """ + State machine that scans for WHERE and ON clauses referencing tables. + """ + + SCANNING = "SCANNING" + SEEN_SOURCE = "SEEN_SOURCE" + FOUND_TABLE = "FOUND_TABLE" + + +def has_table_query(token_list: TokenList) -> bool: + """ + Return if a statement has a query reading from a table. 
+ + >>> has_table_query(sqlparse.parse("COUNT(*)")[0]) + False + >>> has_table_query(sqlparse.parse("SELECT * FROM table")[0]) + True + + Note that queries reading from constant values return false: + + >>> has_table_query(sqlparse.parse("SELECT * FROM (SELECT 1)")[0]) + False + + """ + state = InsertRLSState.SCANNING + for token in token_list.tokens: + + # Recurse into child token list + if isinstance(token, TokenList) and has_table_query(token): + return True + + # Found a source keyword (FROM/JOIN) + if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]): + state = InsertRLSState.SEEN_SOURCE + + # Found identifier/keyword after FROM/JOIN + elif state == InsertRLSState.SEEN_SOURCE and ( + isinstance(token, sqlparse.sql.Identifier) or token.ttype == Keyword + ): + return True + + # Found nothing, leaving source + elif state == InsertRLSState.SEEN_SOURCE and token.ttype != Whitespace: + state = InsertRLSState.SCANNING + + return False + + +def add_table_name(rls: TokenList, table: str) -> None: + """ + Modify a RLS expression ensuring columns are fully qualified. + """ + tokens = rls.tokens[:] + while tokens: + token = tokens.pop(0) + + if isinstance(token, Identifier) and token.get_parent_name() is None: + token.tokens = [ + Token(Name, table), + Token(Punctuation, "."), + Token(Name, token.get_name()), + ] + elif isinstance(token, TokenList): + tokens.extend(token.tokens) + + +def matches_table_name(candidate: Token, table: str) -> bool: + """ + Returns if the token represents a reference to the table. + + Tables can be fully qualified with periods. + + Note that in theory a table should be represented as an identifier, but due to + sqlparse's aggressive list of keywords (spanning multiple dialects) often it gets + classified as a keyword. 
+ """ + if not isinstance(candidate, Identifier): + candidate = Identifier([Token(Name, candidate.value)]) + + target = sqlparse.parse(table)[0].tokens[0] + if not isinstance(target, Identifier): + target = Identifier([Token(Name, target.value)]) + + # match from right to left, splitting on the period, eg, schema.table == table + for left, right in zip(candidate.tokens[::-1], target.tokens[::-1]): + if left.value != right.value: + return False + + return True + + +def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList: + """ + Update a statement inplace applying an RLS associated with a given table. + """ + # make sure the identifier has the table name + add_table_name(rls, table) + + state = InsertRLSState.SCANNING + for token in token_list.tokens: + + # Recurse into child token list + if isinstance(token, TokenList): + i = token_list.tokens.index(token) + token_list.tokens[i] = insert_rls(token, table, rls) + + # Found a source keyword (FROM/JOIN) + if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]): + state = InsertRLSState.SEEN_SOURCE + + # Found identifier/keyword after FROM/JOIN, test for table + elif state == InsertRLSState.SEEN_SOURCE and ( + isinstance(token, Identifier) or token.ttype == Keyword + ): + if matches_table_name(token, table): + state = InsertRLSState.FOUND_TABLE + + # Found WHERE clause, insert RLS. Note that we insert it even it already exists, + # to be on the safe side: it could be present in a clause like `1=1 OR RLS`. + elif state == InsertRLSState.FOUND_TABLE and isinstance(token, Where): + token.tokens[1:1] = [Token(Whitespace, " "), Token(Punctuation, "(")] + token.tokens.extend( + [ + Token(Punctuation, ")"), + Token(Whitespace, " "), + Token(Keyword, "AND"), + Token(Whitespace, " "), + ] + + rls.tokens + ) + state = InsertRLSState.SCANNING + + # Found ON clause, insert RLS. 
The logic for ON is more complicated than the logic + # for WHERE because in the former the comparisons are siblings, while on the + # latter they are children. + elif ( + state == InsertRLSState.FOUND_TABLE + and token.ttype == Keyword + and token.value.upper() == "ON" + ): + tokens = [ + Token(Whitespace, " "), + rls, + Token(Whitespace, " "), + Token(Keyword, "AND"), + Token(Whitespace, " "), + Token(Punctuation, "("), + ] + i = token_list.tokens.index(token) + token.parent.tokens[i + 1 : i + 1] = tokens + i += len(tokens) + 2 + + # close parenthesis after last existing comparison + j = 0 + for j, sibling in enumerate(token_list.tokens[i:]): + # scan until we hit a non-comparison keyword (like ORDER BY) or a WHERE + if ( + sibling.ttype == Keyword + and not imt( + sibling, m=[(Keyword, "AND"), (Keyword, "OR"), (Keyword, "NOT")] + ) + or isinstance(sibling, Where) + ): + j -= 1 + break + token.parent.tokens[i + j + 1 : i + j + 1] = [ + Token(Whitespace, " "), + Token(Punctuation, ")"), + Token(Whitespace, " "), + ] + + state = InsertRLSState.SCANNING + + # Found table but no WHERE clause found, insert one + elif state == InsertRLSState.FOUND_TABLE and token.ttype != Whitespace: + i = token_list.tokens.index(token) + token_list.tokens[i:i] = [ + Token(Whitespace, " "), + Where([Token(Keyword, "WHERE"), Token(Whitespace, " "), rls]), + Token(Whitespace, " "), + ] + + state = InsertRLSState.SCANNING + + # Found nothing, leaving source + elif state == InsertRLSState.SEEN_SOURCE and token.ttype != Whitespace: + state = InsertRLSState.SCANNING + + # found table at the end of the statement; append a WHERE clause + if state == InsertRLSState.FOUND_TABLE: + token_list.tokens.extend( + [ + Token(Whitespace, " "), + Where([Token(Keyword, "WHERE"), Token(Whitespace, " "), rls]), + ] + ) + + return token_list diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index 9026eab212ac..aa811bdef757 100644 --- a/tests/unit_tests/sql_parse_tests.py 
+++ b/tests/unit_tests/sql_parse_tests.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name +# pylint: disable=invalid-name, too-many-lines import unittest from typing import Set @@ -25,6 +25,10 @@ from superset.exceptions import QueryClauseValidationException from superset.sql_parse import ( + add_table_name, + has_table_query, + insert_rls, + matches_table_name, ParsedQuery, strip_comments_from_sql, Table, @@ -1111,7 +1115,8 @@ def test_sqlparse_formatting(): """ assert sqlparse.format( - "SELECT extract(HOUR from from_unixtime(hour_ts) AT TIME ZONE 'America/Los_Angeles') from table", + "SELECT extract(HOUR from from_unixtime(hour_ts) " + "AT TIME ZONE 'America/Los_Angeles') from table", reindent=True, ) == ( "SELECT extract(HOUR\n from from_unixtime(hour_ts) " @@ -1189,3 +1194,241 @@ def test_sqlparse_issue_652(): stmt = sqlparse.parse(r"foo = '\' AND bar = 'baz'")[0] assert len(stmt.tokens) == 5 assert str(stmt.tokens[0]) == "foo = '\\'" + + +@pytest.mark.parametrize( + "sql,expected", + [ + ("SELECT * FROM table", True), + ("SELECT a FROM (SELECT 1 AS a) JOIN (SELECT * FROM table)", True), + ("(SELECT COUNT(DISTINCT name) AS foo FROM birth_names)", True), + ("COUNT(*)", False), + ("SELECT a FROM (SELECT 1 AS a)", False), + ("SELECT a FROM (SELECT 1 AS a) JOIN table", True), + ("SELECT * FROM (SELECT 1 AS foo, 2 AS bar) ORDER BY foo ASC, bar", False), + ("SELECT * FROM other_table", True), + ("extract(HOUR from from_unixtime(hour_ts)", False), + ], +) +def test_has_table_query(sql: str, expected: bool) -> None: + """ + Test if a given statement queries a table. + + This is used to prevent ad-hoc metrics from querying unauthorized tables, bypassing + row-level security. 
+ """ + statement = sqlparse.parse(sql)[0] + assert has_table_query(statement) == expected + + +@pytest.mark.parametrize( + "sql,table,rls,expected", + [ + # Basic test: append RLS (some_table.id=42) to an existing WHERE clause. + ( + "SELECT * FROM some_table WHERE 1=1", + "some_table", + "id=42", + "SELECT * FROM some_table WHERE ( 1=1) AND some_table.id=42", + ), + # Any existing predicates MUST be wrapped in parentheses because AND has higher + # precedence than OR. If the RLS is `1=0` and we didn't add parentheses a user + # could bypass it by crafting a query with `WHERE TRUE OR FALSE`, since + # `WHERE TRUE OR FALSE AND 1=0` evaluates to `WHERE TRUE OR (FALSE AND 1=0)`. + ( + "SELECT * FROM some_table WHERE TRUE OR FALSE", + "some_table", + "1=0", + "SELECT * FROM some_table WHERE ( TRUE OR FALSE) AND 1=0", + ), + # Here "table" is a reserved word; since sqlparse is too aggressive when + # characterizing reserved words we need to support them even when not quoted. + ( + "SELECT * FROM table WHERE 1=1", + "table", + "id=42", + "SELECT * FROM table WHERE ( 1=1) AND table.id=42", + ), + # RLS is only applied to queries reading from the associated table. + ( + "SELECT * FROM table WHERE 1=1", + "other_table", + "id=42", + "SELECT * FROM table WHERE 1=1", + ), + ( + "SELECT * FROM other_table WHERE 1=1", + "table", + "id=42", + "SELECT * FROM other_table WHERE 1=1", + ), + # If there's no pre-existing WHERE clause we create one. 
+ ( + "SELECT * FROM table", + "table", + "id=42", + "SELECT * FROM table WHERE table.id=42", + ), + ( + "SELECT * FROM some_table", + "some_table", + "id=42", + "SELECT * FROM some_table WHERE some_table.id=42", + ), + ( + "SELECT * FROM table ORDER BY id", + "table", + "id=42", + "SELECT * FROM table WHERE table.id=42 ORDER BY id", + ), + ( + "SELECT * FROM some_table;", + "some_table", + "id=42", + "SELECT * FROM some_table WHERE some_table.id=42 ;", + ), + ( + "SELECT * FROM some_table ;", + "some_table", + "id=42", + "SELECT * FROM some_table WHERE some_table.id=42 ;", + ), + ( + "SELECT * FROM some_table ", + "some_table", + "id=42", + "SELECT * FROM some_table WHERE some_table.id=42", + ), + # We add the RLS even if it's already present, to be conservative. It should have + # no impact on the query, and it's easier than testing if the RLS is already + # present (it could be present in an OR clause, eg). + ( + "SELECT * FROM table WHERE 1=1 AND table.id=42", + "table", + "id=42", + "SELECT * FROM table WHERE ( 1=1 AND table.id=42) AND table.id=42", + ), + ( + ( + "SELECT * FROM table JOIN other_table ON " + "table.id = other_table.id AND other_table.id=42" + ), + "other_table", + "id=42", + ( + "SELECT * FROM table JOIN other_table ON other_table.id=42 " + "AND ( table.id = other_table.id AND other_table.id=42 )" + ), + ), + ( + "SELECT * FROM table WHERE 1=1 AND id=42", + "table", + "id=42", + "SELECT * FROM table WHERE ( 1=1 AND id=42) AND table.id=42", + ), + # For joins we apply the RLS to the ON clause, since it's easier and prevents + # leaking information about number of rows on OUTER JOINs. 
+ ( + "SELECT * FROM table JOIN other_table ON table.id = other_table.id", + "other_table", + "id=42", + ( + "SELECT * FROM table JOIN other_table ON other_table.id=42 " + "AND ( table.id = other_table.id )" + ), + ), + ( + ( + "SELECT * FROM table JOIN other_table ON table.id = other_table.id " + "WHERE 1=1" + ), + "other_table", + "id=42", + ( + "SELECT * FROM table JOIN other_table ON other_table.id=42 " + "AND ( table.id = other_table.id ) WHERE 1=1" + ), + ), + # Subqueries also work, as expected. + ( + "SELECT * FROM (SELECT * FROM other_table)", + "other_table", + "id=42", + "SELECT * FROM (SELECT * FROM other_table WHERE other_table.id=42 )", + ), + # As well as UNION. + ( + "SELECT * FROM table UNION ALL SELECT * FROM other_table", + "table", + "id=42", + "SELECT * FROM table WHERE table.id=42 UNION ALL SELECT * FROM other_table", + ), + ( + "SELECT * FROM table UNION ALL SELECT * FROM other_table", + "other_table", + "id=42", + ( + "SELECT * FROM table UNION ALL " + "SELECT * FROM other_table WHERE other_table.id=42" + ), + ), + # When comparing fully qualified table names (eg, schema.table) to simple names + # (eg, table) we are also conservative, assuming the schema is the same, since + # we don't have information on the default schema. + ( + "SELECT * FROM schema.table_name", + "table_name", + "id=42", + "SELECT * FROM schema.table_name WHERE table_name.id=42", + ), + ( + "SELECT * FROM schema.table_name", + "schema.table_name", + "id=42", + "SELECT * FROM schema.table_name WHERE schema.table_name.id=42", + ), + ( + "SELECT * FROM table_name", + "schema.table_name", + "id=42", + "SELECT * FROM table_name WHERE schema.table_name.id=42", + ), + ], +) +def test_insert_rls(sql: str, table: str, rls: str, expected: str) -> None: + """ + Insert into a statement a given RLS condition associated with a table. 
+ """ + statement = sqlparse.parse(sql)[0] + condition = sqlparse.parse(rls)[0] + assert str(insert_rls(statement, table, condition)).strip() == expected.strip() + + +@pytest.mark.parametrize( + "rls,table,expected", + [ + ("id=42", "users", "users.id=42"), + ("users.id=42", "users", "users.id=42"), + ("schema.users.id=42", "users", "schema.users.id=42"), + ("false", "users", "false"), + ], +) +def test_add_table_name(rls: str, table: str, expected: str) -> None: + condition = sqlparse.parse(rls)[0] + add_table_name(condition, table) + assert str(condition) == expected + + +@pytest.mark.parametrize( + "candidate,table,expected", + [ + ("table", "table", True), + ("schema.table", "table", True), + ("table", "schema.table", True), + ('schema."my table"', '"my table"', True), + ('schema."my.table"', '"my.table"', True), + ], +) +def test_matches_table_name(candidate: str, table: str, expected: bool) -> None: + token = sqlparse.parse(candidate)[0].tokens[0] + assert matches_table_name(token, table) == expected