Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SQL injection issues #807

Merged
merged 11 commits into from
Sep 13, 2022
27 changes: 20 additions & 7 deletions python-sdk/src/astro/databases/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,21 @@ def run_sql(
result = self.connection.execute(sql_statement, parameters)
return result

def columns_exist(self, table: Table, columns: list[str]) -> bool:
    """
    Check that a list of columns exist in the given table.

    :param table: The table to check in.
    :param columns: The column names to check for.

    :returns: True if every requested column exists in the table, False otherwise
        (an empty ``columns`` list trivially returns True).
    """
    sqla_table = self.get_sqla_table(table)
    # Build the set of existing column names once so each membership test is
    # O(1), instead of rescanning the SQLAlchemy column collection for every
    # requested column (the nested all/any form is O(n*m)).
    existing_columns = {sqla_column.name for sqla_column in sqla_table.columns}
    return set(columns).issubset(existing_columns)

def table_exists(self, table: Table) -> bool:
"""
Check if a table exists in the database.
def export_table_to_pandas_dataframe(self, source_table: Table) -> pd.DataFrame:
    """
    Copy the content of a table to an in-memory Pandas dataframe.

    :param source_table: An existing table in the database
    :raises NonExistentTableException: if the table does not exist.
    """
    if self.table_exists(source_table):
        # Let SQLAlchemy build the SELECT from the reflected table object
        # instead of interpolating the table name into a raw SQL string —
        # this removes the SQL-injection surface entirely.
        sqla_table = self.get_sqla_table(source_table)
        return pd.read_sql(sql=sqla_table.select(), con=self.sqlalchemy_engine)

    table_qualified_name = self.get_table_qualified_name(source_table)
    raise NonExistentTableException(
        f"The table {table_qualified_name} does not exist"
    )

def export_table_to_file(
Expand Down
35 changes: 23 additions & 12 deletions python-sdk/src/astro/databases/google/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,20 +197,31 @@ def merge_table(
target_table_name = self.get_table_qualified_name(target_table)
source_table_name = self.get_table_qualified_name(source_table)

insert_statement = (
    f"INSERT ({', '.join(target_columns)}) VALUES ({', '.join(source_columns)})"
)
merge_statement = (
    f"MERGE {target_table_name} T USING {source_table_name} S"
    f" ON {' AND '.join(f'T.{col}=S.{col}' for col in target_conflict_columns)}"
    f" WHEN NOT MATCHED BY TARGET THEN {insert_statement}"
)
if if_conflicts == "update":
    # Positional pairing: target_columns[idx] is updated from source_columns[idx].
    update_statement_map = ", ".join(
        f"T.{col}=S.{source_columns[idx]}"
        for idx, col in enumerate(target_columns)
    )
    # Reject column names that are not real columns of the tables before
    # interpolating them into the statement below.
    if not self.columns_exist(source_table, source_columns):
        raise ValueError(
            f"Not all the columns provided exist for {source_table_name}!"
        )
    if not self.columns_exist(target_table, target_columns):
        raise ValueError(
            f"Not all the columns provided exist for {target_table_name}!"
        )
    # Note: Ignoring below sql injection warning, as we validate that the table columns exist beforehand.
    update_statement = f"UPDATE SET {update_statement_map}"  # skipcq BAN-B608
    merge_statement += f" WHEN MATCHED THEN {update_statement}"
self.run_sql(sql_statement=merge_statement)

def is_native_load_file_available(
self, source_file: File, target_table: Table
Expand Down
63 changes: 63 additions & 0 deletions python-sdk/tests/databases/test_all_databases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""Tests all Database implementations."""
import pathlib

import pandas as pd
import pytest

from astro.constants import Database
from astro.files import File
from astro.settings import SCHEMA
from astro.sql.table import Metadata, Table

CWD = pathlib.Path(__file__).parent


@pytest.mark.integration
@pytest.mark.parametrize(
    "database_table_fixture",
    [
        {
            "database": database,
            "file": File(str(pathlib.Path(CWD.parent, "data/sample.csv"))),
            "table": table,
        }
        for database, table in [
            (Database.BIGQUERY, Table(metadata=Metadata(schema=SCHEMA))),
            (Database.POSTGRES, Table(metadata=Metadata(schema=SCHEMA.lower()))),
            (Database.REDSHIFT, Table(metadata=Metadata(schema=SCHEMA.lower()))),
            (Database.SNOWFLAKE, Table(metadata=Metadata(schema=SCHEMA))),
            (Database.SQLITE, Table()),
        ]
    ],
    indirect=True,
    ids=["bigquery", "postgres", "redshift", "snowflake", "sqlite"],
)
def test_export_table_to_pandas_dataframe(
    database_table_fixture,
):
    """Test export_table_to_pandas_dataframe() where the table exists"""
    database, table = database_table_fixture

    df = database.export_table_to_pandas_dataframe(table)
    assert len(df) == 3
    expected = pd.DataFrame(
        [
            {"id": 1, "name": "First"},
            {"id": 2, "name": "Second"},
            {"id": 3, "name": "Third with unicode पांचाल"},
        ]
    )
    # Column-name case differs per backend (e.g. Snowflake upper-cases),
    # so normalize to lower-case before comparing.
    assert df.rename(columns=str.lower).equals(expected)