In [85]:
from typing import Any

TABLE_POLICIES: dict[str, dict[str, Any]] = {
    "products": {
        "scope": "GLOBAL",
        "allowed_columns": {"id", "name", "price"}
    },
    "salary": {
        "scope": "USER_SCOPED",
        "user_key": "employee_id",
        "allowed_columns": {"employee_id", "amount", "from_date", "to_date"}
    }
}

# Functions rejected anywhere in a query. Names must be lower-case: the checker
# lower-cases the function name before the membership test.
BLOCKED_FUNCTIONS: set[str] = {
    # Sequence mutation.
    "nextval",
    "setval",
    # Advisory locks (can block other sessions).
    "pg_advisory_lock",
    "pg_advisory_lock_shared",
    "pg_advisory_xact_lock",
    "pg_advisory_xact_lock_shared",
    "pg_try_advisory_lock",
    "pg_try_advisory_lock_shared",
    "pg_try_advisory_xact_lock",
    "pg_try_advisory_xact_lock_shared",
    # Sleeping: trivial denial of service.
    "pg_sleep",
    "pg_sleep_for",
    "pg_sleep_until",
    # These evaluate arbitrary SQL text or dump whole tables/schemas, which would
    # bypass the table/column policies enforced on the parsed query.
    "query_to_xml",
    "query_to_xmlschema",
    "query_to_xml_and_xmlschema",
    "table_to_xml",
    "table_to_xml_and_xmlschema",
    "cursor_to_xml",
    "schema_to_xml",
    "database_to_xml",
    "dblink",
    "dblink_exec",
    # Server-side file and large-object access.
    "pg_read_file",
    "pg_read_binary_file",
    "pg_ls_dir",
    "pg_stat_file",
    "lo_import",
    "lo_export",
    # Session / server state changes and notifications.
    "set_config",
    "pg_terminate_backend",
    "pg_cancel_backend",
    "pg_reload_conf",
    "pg_notify",
}

# Row limit used by enforce_limit when no custom limit is supplied.
MAX_LIMIT = 1000

In [86]:
import sqlglot
from sqlglot import exp

def parse_sql(sql: str) -> sqlglot.Expression:
    """
    Parse exactly one PostgreSQL statement into a sqlglot AST.

    `sqlglot.parse_one` silently returns only the first statement of a
    multi-statement string, so input like "SELECT 1; SELECT 2" would be truncated
    without notice. This function rejects anything other than a single statement.

    Args:
        sql: The raw SQL text.

    Returns:
        The parsed statement.

    Raises:
        ValueError: if the text does not parse, or holds zero or several statements.
    """
    try:
        parsed = sqlglot.parse(sql, dialect="postgres")
    except Exception as e:
        raise ValueError(f"Invalid SQL: {e}") from e

    # parse() yields None for empty statements (e.g. ";;"); ignore them.
    statements = [statement for statement in parsed if statement is not None]
    if len(statements) != 1:
        raise ValueError(
            f"Invalid SQL: expected exactly one statement, got {len(statements)}"
        )
    return statements[0]


In [87]:
def enforce_read_only(ast: sqlglot.Expression):
    """
    Reject anything that is not a pure read query.

    Two checks run in order:

    1. No node anywhere in the tree may be one of the following. Because the whole
       tree is searched, data-modifying CTEs such as
       ``WITH d AS (DELETE ... RETURNING ...) SELECT ...`` are caught too.
       - a write/DDL statement;
       - ``SELECT ... INTO`` (creates a table);
       - a row-locking clause (``FOR UPDATE`` / ``FOR SHARE``);
       - a raw ``Command`` (statements sqlglot does not model, e.g. COPY/GRANT).
    2. The root must be a query (a SELECT or a set operation of SELECTs). Any other
       statement (SET, TRUNCATE, COMMENT ON, ...) is rejected.

    Node classes are looked up by name, so sqlglot versions lacking some of them
    still work.

    Raises:
        PermissionError: if the statement is not read-only.
    """
    forbidden = tuple(
        getattr(exp, name)
        for name in (
            "Insert",
            "Update",
            "Delete",
            "Merge",
            "Create",
            "Drop",
            "Alter",
            "AlterTable",  # name used by older sqlglot releases
            "TruncateTable",
            "Into",  # SELECT ... INTO new_table
            "Lock",  # FOR UPDATE / FOR SHARE
            "Command",  # unmodelled statements (COPY, GRANT, ...)
        )
        if hasattr(exp, name)
    )
    # find_all yields nodes in every sqlglot version (walk() yielded tuples in old ones).
    if next(ast.find_all(*forbidden), None) is not None:
        raise PermissionError("Write or DDL operation detected")

    read_roots = tuple(
        getattr(exp, name)
        for name in ("Select", "Union", "Intersect", "Except", "Subquery")
        if hasattr(exp, name)
    )
    if not isinstance(ast, read_roots):
        raise PermissionError("Only SELECT queries are allowed")


In [88]:
def enforce_safe_functions(ast: sqlglot.Expression):
    """
    Reject queries that call any function listed in BLOCKED_FUNCTIONS.

    Functions sqlglot does not model (``exp.Anonymous``) keep their source name as a
    plain string in ``this``. For modelled functions, ``.name`` would instead return
    the text of the first argument: ``LOWER('nextval')`` would look like ``nextval``
    and be wrongly blocked. Their registered SQL names are checked instead.

    Raises:
        PermissionError: on the first blocked function found.
    """
    for func in ast.find_all(exp.Func):
        if isinstance(func, exp.Anonymous) or isinstance(func.args.get("this"), str):
            candidate_names = [func.name]
        else:
            candidate_names = list(type(func).sql_names())

        for name in candidate_names:
            if name and name.lower() in BLOCKED_FUNCTIONS:
                raise PermissionError(f"Blocked function: {name}")


In [89]:
def extract_tables(ast: sqlglot.Expression) -> set[str]:
    """
    Collect the bare names of every table node in the tree.

    Schema/catalog qualifiers are dropped. CTE names appear here too, because
    references to a CTE are also table nodes.
    """
    return {table_node.name for table_node in ast.find_all(exp.Table)}


In [90]:
def enforce_table_access(ast: sqlglot.Expression):
    """
    Ensure every referenced table has an entry in TABLE_POLICIES.

    Raises:
        PermissionError: naming the first referenced table without a policy.
    """
    unknown = (name for name in extract_tables(ast) if name not in TABLE_POLICIES)
    offender = next(unknown, None)
    if offender is not None:
        raise PermissionError(f"Table not allowed: {offender}")

In [91]:
def enforce_column_access(ast: sqlglot.Expression):
    """
    Reject any column reference that is not on the allow-list of the table it reads from.

    Qualified columns (``q.col``) are resolved through table aliases:
    - When ``q`` could denote several tables (an alias reused in different scopes),
      the column must be allowed on every candidate.
    - A qualifier that is itself a policy table name is always checked against that
      table.
    - Qualifiers that name no table (e.g. a derived-table alias) are skipped. That
      derived table's own SELECT is checked where its columns are referenced.

    Unqualified columns are resolved against the FROM/JOIN sources of their nearest
    enclosing SELECT. They pass only if the name:
    - is allowed on a policy table in that scope, or
    - is an output column of a derived table in that scope, or
    - is a bare ORDER BY item naming an output column of the query being ordered.

    Everything else fails closed. That includes unqualified correlated references to
    an outer query; qualify those instead.

    Args:
        ast: The parsed SQL tree (not modified).

    Raises:
        PermissionError: for the first column reference that is not allowed.
    """
    query_wrappers = (exp.Union, exp.Intersect, exp.Except, exp.Subquery)

    def _output_names(query) -> list[str]:
        # Positional output names of a query; "" where no name can be trusted
        # (e.g. literals/expressions without an explicit alias).
        while isinstance(query, query_wrappers):
            query = query.this  # the leftmost branch names a set operation's columns
        if not isinstance(query, exp.Select):
            return []
        names: list[str] = []
        for projection in query.expressions:
            if isinstance(projection, exp.Alias):
                names.append(projection.alias)
            elif isinstance(projection, exp.Column):
                names.append(projection.name)
            else:
                names.append("")
        return names

    scope_cache: dict[int, set[str]] = {}

    def _scope_names(select: exp.Select) -> set[str]:
        # Column names an unqualified reference directly inside `select` may use.
        key = id(select)
        if key not in scope_cache:
            names: set[str] = set()
            for table in select.find_all(exp.Table):
                if table.find_ancestor(exp.Select) is select:
                    policy = TABLE_POLICIES.get(table.name)
                    if policy:
                        names |= set(policy["allowed_columns"])
            for derived in select.find_all(exp.Subquery):
                if derived.find_ancestor(exp.Select) is select and isinstance(
                    derived.parent, (exp.From, exp.Join)
                ):
                    outputs = _output_names(derived.this)
                    alias = derived.args.get("alias")
                    renamed = (
                        [c.name for c in (alias.args.get("columns") or [])] if alias else []
                    )
                    # `t(a, b)` renames the first outputs positionally.
                    names |= set(renamed) | set(outputs[len(renamed):])
            scope_cache[key] = names
        return scope_cache[key]

    # Alias (or bare name for unaliased tables) -> every real table it may denote.
    qualifier_targets: dict[str, set[str]] = {}
    for table in ast.find_all(exp.Table):
        qualifier_targets.setdefault(table.alias_or_name, set()).add(table.name)

    for col in ast.find_all(exp.Column):
        column = col.name
        qualifier = col.table

        if qualifier:
            candidates = set(qualifier_targets.get(qualifier, ()))
            if qualifier in TABLE_POLICIES:
                candidates.add(qualifier)
            for table_name in sorted(candidates):
                policy = TABLE_POLICIES.get(table_name)
                if policy and column not in policy["allowed_columns"]:
                    raise PermissionError(
                        f"Column '{column}' not allowed on table '{table_name}'"
                    )
            continue

        # A bare ORDER BY item may name an output column of the ordered query.
        parent = col.parent
        if isinstance(parent, exp.Ordered) and isinstance(parent.parent, exp.Order):
            if column in _output_names(parent.parent.parent):
                continue

        owner = col.find_ancestor(exp.Select)
        if owner is not None and column in _scope_names(owner):
            continue

        raise PermissionError(
            f"Column '{column}' is not allowed or could not be resolved"
        )


In [92]:
def block_select_star(ast: sqlglot.Expression):
    """
    Reject star projections such as ``SELECT *``, ``SELECT t.*`` or ``f(t.*)``.

    ``COUNT(*)`` is allowed. The star there only counts rows and exposes no column
    values, so blocking it rejects harmless aggregate queries.

    Raises:
        PermissionError: if any other star appears in the query.
    """
    for star in ast.find_all(exp.Star):
        if isinstance(star.parent, exp.Count):
            continue
        raise PermissionError("SELECT * is not allowed")


In [93]:
def inject_user_filters(ast: sqlglot.Expression, current_user_id: int, skip_tables: set[str] | None = None):
    """
    Inject row-level user filters for user-scoped tables.

    Every reference to a ``USER_SCOPED`` table gets the predicate
    ``<ref>.<user_key> = <current_user_id>``. Here ``<ref>`` is the table's alias if
    it has one, otherwise its name. The predicate is ANDed onto the WHERE clause of
    the SELECT that owns the reference (its nearest enclosing SELECT), so
    subqueries, UNION branches and aliased tables are all filtered. An existing
    OR-style condition is parenthesized first, so ``a OR b`` cannot escape the
    filter. The tree is modified in place.

    Args:
        ast: The SQL AST to modify
        current_user_id: The current user's ID to filter by
        skip_tables: Optional set of table names to skip user filter injection

    Raises:
        PermissionError: if a user-scoped table is referenced outside any SELECT,
            where no filter can be attached.
    """
    skip_tables = skip_tables or set()

    def _conjuncts(node):
        # Yield the top-level AND-ed terms of a condition, looking through parentheses.
        if isinstance(node, exp.Paren):
            yield from _conjuncts(node.this)
        elif isinstance(node, exp.And):
            yield from _conjuncts(node.this)
            yield from _conjuncts(node.expression)
        else:
            yield node

    # Materialize first: the loop rewrites WHERE clauses inside the tree it searches.
    for table in list(ast.find_all(exp.Table)):
        policy = TABLE_POLICIES.get(table.name)
        if policy is None or policy["scope"] != "USER_SCOPED":
            continue
        if table.name in skip_tables:
            continue

        owner = table.find_ancestor(exp.Select)
        if owner is None:
            raise PermissionError(f"Cannot apply user filter to table: {table.name}")

        # Once a table is aliased, PostgreSQL only accepts the alias as qualifier.
        alias = table.args.get("alias")
        qualifier = alias.this if alias is not None and alias.this is not None else table.this
        predicate = exp.EQ(
            this=exp.Column(this=exp.to_identifier(policy["user_key"]), table=qualifier.copy()),
            expression=exp.Literal.number(current_user_id)
        )

        where = owner.args.get("where")
        if where is None:
            owner.set("where", exp.Where(this=predicate))
            continue

        # Skip only if the identical filter is already a top-level AND term. A
        # substring test is unsafe: "... = 1" is a substring of "... = 10001", and
        # "filter OR 1 = 1" contains the filter without being restricted by it.
        predicate_sql = predicate.sql(dialect="postgres")
        if any(term.sql(dialect="postgres") == predicate_sql for term in _conjuncts(where.this)):
            continue

        condition = where.this
        if isinstance(condition, exp.Connector):
            # Without parentheses, `a OR b` + filter would render as `a OR b AND filter`.
            condition = exp.Paren(this=condition)
        owner.set("where", exp.Where(this=exp.And(this=condition, expression=predicate)))


In [94]:
def enforce_limit(ast: sqlglot.Expression, max_limit: int | None = None):
    """
    Enforce a LIMIT clause on the query, capped at the maximum.

    An existing LIMIT is kept only when it is a plain integer literal no larger than
    the cap. Anything else is replaced by ``LIMIT <cap>``: a larger number, a
    non-literal expression, or a FETCH clause. Previously ``LIMIT 99999999`` passed
    straight through. Other clauses such as OFFSET are left untouched.

    Args:
        ast: The SQL AST to modify
        max_limit: Optional custom max limit (defaults to MAX_LIMIT constant)
    """
    limit_value = max_limit if max_limit is not None else MAX_LIMIT

    existing = ast.args.get("limit")
    if existing is not None:
        requested = existing.args.get("expression")
        if (
            isinstance(requested, exp.Literal)
            and not requested.args.get("is_string")
            and isinstance(requested.this, str)
            and requested.this.isdigit()
            and int(requested.this) <= limit_value
        ):
            return

    ast.set("limit", exp.Limit(
        expression=exp.Literal.number(limit_value)
    ))


In [95]:
def validate_and_rewrite(
    sql: str,
    current_user_id: int,
    skip_user_filter: bool = False,
    skip_user_filter_tables: set[str] | None = None,
    override_user_id: int | None = None,
    custom_limit: int | None = None
) -> str:
    """
    Parse `sql`, run every security check, apply the rewrites and return PostgreSQL text.

    The checks run in a fixed order, and the first failure propagates:
    read-only, function deny-list, table allow-list, no SELECT *, column allow-list.
    Then user row filters are injected, unless `skip_user_filter` is set, and a LIMIT
    is enforced.

    NOTE(review): `skip_user_filter`, `skip_user_filter_tables` and
    `override_user_id` disable or redirect row-level security. Presumably only
    trusted/admin callers should be able to set them; verify at the call sites.

    Args:
        sql: The SQL query to validate
        current_user_id: The current user's ID
        skip_user_filter: If True, skips user filter injection entirely
        skip_user_filter_tables: Set of specific tables to skip user filter injection
        override_user_id: If provided, uses this user_id instead of current_user_id for filters
        custom_limit: If provided, uses this limit instead of the default MAX_LIMIT

    Returns:
        The validated and rewritten SQL query
    """
    tree = parse_sql(sql)

    checks = (
        enforce_read_only,
        enforce_safe_functions,
        enforce_table_access,
        block_select_star,
        enforce_column_access,
    )
    for check in checks:
        check(tree)

    if not skip_user_filter:
        filter_user_id = current_user_id if override_user_id is None else override_user_id
        inject_user_filters(tree, filter_user_id, skip_tables=skip_user_filter_tables)

    enforce_limit(tree, max_limit=custom_limit)
    return tree.sql(dialect="postgres")


In [96]:
def _show_example(title: str, rewritten_sql: str, blank_line_after: bool = True) -> None:
    """Print an example heading followed by its rewritten SQL."""
    print(title)
    print(rewritten_sql)
    if blank_line_after:
        print()


# Example 1: Default behavior (injects user_id filter)
query = """SELECT amount
FROM salary
LIMIT 1000
"""
validated_query = validate_and_rewrite(query, current_user_id=10001)
_show_example("Example 1 - Default behavior:", validated_query)

# Examples 2-5 all start from the same query without a LIMIT.
query2 = query3 = query4 = query5 = """SELECT amount
FROM salary
"""

# Example 2: Override user_id to use a different value
validated_query2 = validate_and_rewrite(query2, current_user_id=10001, override_user_id=45)
_show_example("Example 2 - Override user_id to 45:", validated_query2)

# Example 3: Skip user filter entirely (for admin queries)
validated_query3 = validate_and_rewrite(query3, current_user_id=10001, skip_user_filter=True)
_show_example("Example 3 - Skip user filter entirely:", validated_query3)

# Example 4: Skip user filter for specific tables
validated_query4 = validate_and_rewrite(
    query4,
    current_user_id=10001,
    skip_user_filter_tables={"salary"}
)
_show_example("Example 4 - Skip user filter for 'salary' table:", validated_query4)

# Example 5: Custom limit
validated_query5 = validate_and_rewrite(query5, current_user_id=10001, custom_limit=5000)
_show_example("Example 5 - Custom limit of 5000:", validated_query5, blank_line_after=False)


Example 1 - Default behavior:
SELECT amount FROM salary WHERE salary.employee_id = 10001 LIMIT 1000

Example 2 - Override user_id to 45:
SELECT amount FROM salary WHERE salary.employee_id = 45 LIMIT 1000

Example 3 - Skip user filter entirely:
SELECT amount FROM salary LIMIT 1000

Example 4 - Skip user filter for 'salary' table:
SELECT amount FROM salary LIMIT 1000

Example 5 - Custom limit of 5000:
SELECT amount FROM salary WHERE salary.employee_id = 10001 LIMIT 5000


In [97]:
from pydantic import BaseModel

class WhereExpression(BaseModel):
    """One simple predicate taken from a WHERE clause by extract_where_expressions."""
    # Table qualifier written on the column; "" when the column was unqualified.
    table: str
    # Column name on the left-hand side of the comparison.
    column: str
    # Operator text, e.g. "=", "LIKE", "IN", "IS".
    expression: str
    # Right-hand side: literal text, identifier name, or SQL text of a complex operand.
    # NOTE(review): union member order can affect pydantic coercion; keep as-is unless verified.
    value: str | int | float | None

def extract_where_expressions(ast: sqlglot.Expression) -> list[WhereExpression]:
    """
    Flatten the root WHERE clause of `ast` into (table, column, operator, value) records.

    AND/OR trees and parenthesized groups are walked recursively. Each comparison
    listed in OPERATORS whose left operand is a column yields one WhereExpression.
    The value is one of:
    - the literal's text;
    - an identifier's name;
    - for IN, the parenthesized candidate list or the subquery SQL;
    - otherwise, the operand's SQL text.

    Other predicates are ignored (NOT, BETWEEN, the column on the right-hand side,
    ...). Only the root WHERE is inspected, not the WHEREs of subqueries.

    Returns:
        The extracted predicates in left-to-right order; empty if there is no WHERE.
    """
    expressions: list[WhereExpression] = []

    where = ast.args.get("where")
    if not where:
        return expressions

    # Map of sqlglot expression types to operator strings
    OPERATORS: dict[type[sqlglot.Expression], str] = {
        exp.EQ: "=",
        exp.NEQ: "!=",
        exp.GT: ">",
        exp.GTE: ">=",
        exp.LT: "<",
        exp.LTE: "<=",
        exp.Like: "LIKE",
        exp.ILike: "ILIKE",
        exp.In: "IN",
        exp.Is: "IS",
    }

    def _value(expr: sqlglot.Expression) -> str | None:
        # IN keeps its candidates in `expressions` (or `query` for IN (SELECT ...)),
        # not in `expression`; reading `expression` would give None and crash.
        if isinstance(expr, exp.In):
            query = expr.args.get("query")
            if query is not None:
                return query.sql(dialect="postgres")
            if expr.expressions:
                return "(" + ", ".join(e.sql(dialect="postgres") for e in expr.expressions) + ")"
            return None

        right = expr.expression
        if right is None:
            return None
        if isinstance(right, exp.Identifier):
            return right.name
        if isinstance(right, exp.Literal):
            return right.this
        # For complex expressions, use the SQL representation
        return right.sql(dialect="postgres")

    def _extract(expr: sqlglot.Expression):
        if isinstance(expr, exp.Paren):
            # Parenthesized groups, e.g. `(a = 1 AND b = 2)`, were previously skipped.
            _extract(expr.this)
        elif isinstance(expr, (exp.And, exp.Or)):
            _extract(expr.this)
            _extract(expr.expression)
        else:
            for op_type, op_str in OPERATORS.items():
                if isinstance(expr, op_type):
                    left = expr.this
                    if isinstance(left, exp.Column):
                        expressions.append(
                            WhereExpression(
                                table=left.table or "",
                                column=left.name,
                                expression=op_str,
                                value=_value(expr)
                            )
                        )
                    break

    _extract(where.this)
    return expressions

# Re-parse the Example 1 output and list the predicates found in its WHERE clause.
ast = parse_sql(validated_query)
where_expressions = extract_where_expressions(ast)
report_lines = ["Extracted WHERE expressions:", *map(str, where_expressions)]
print("\n".join(report_lines))

Extracted WHERE expressions:
table='salary' column='employee_id' expression='=' value='10001'


In [None]:
import psycopg2
from psycopg2 import sql
import os

# Connection string comes from the environment. NOTE(review): the fallback embeds
# default credentials and is only suitable for a local development database.
CONN_STRING = os.environ.get("CONN_STRING", "postgresql://postgres:postgres@localhost:5432/employees")

def execute_query(query: str) -> list[tuple[Any, ...]]:
    """
    Run a read-only query against CONN_STRING and return all result rows.

    The SELECT-only check now happens *before* anything is sent to the database.
    Previously the statement was executed first, and the resulting PermissionError
    was swallowed by a blanket `except Exception`. The session is also opened
    read-only, so PostgreSQL itself refuses any write that slips past the validator.

    Args:
        query: SQL text, expected to be the output of validate_and_rewrite.

    Returns:
        The fetched rows, or [] if a database error occurred (the error is printed).

    Raises:
        PermissionError: if `query` does not start with SELECT.
    """
    if not query.strip().lower().startswith("select"):
        raise PermissionError("Only read methods are allowed.")

    conn = None
    cursor = None
    try:
        conn = psycopg2.connect(CONN_STRING)
        # Must be set before the first statement opens a transaction.
        conn.set_session(readonly=True)
        cursor = conn.cursor()
        # The query is already complete SQL text; no parameters are interpolated.
        cursor.execute(query)
        return cursor.fetchall()
    except psycopg2.Error as e:
        print(f"An error occurred: {e}")
        return []
    finally:
        # Close the database connection; the transaction is never committed.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
            
# Example usage: run the Example 1 query. NOTE(review): output contains salary data.
print(validated_query)
results = execute_query(validated_query)
if results:
    print("\n".join(str(record) for record in results))

SELECT amount FROM salary WHERE salary.employee_id = 10001 LIMIT 1000
(60117,)
(62102,)
(66074,)
(66596,)
(66961,)
(71046,)
(74333,)
(75286,)
(75994,)
(76884,)
(80013,)
(81025,)
(81097,)
(84917,)
(85112,)
(85097,)
(88958,)
