Skip to content

Commit

Permalink
Fix sql injection issues (#807)
Browse files Browse the repository at this point in the history

Currently, we have some SQL injection issues in our code that DeepSource detects for us, some of which we have ignored.

## What is the new behavior?

This PR should fix them.
  • Loading branch information
feluelle committed Sep 13, 2022
1 parent bef9b9f commit 3caf210
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 19 deletions.
27 changes: 20 additions & 7 deletions python-sdk/src/astro/databases/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,21 @@ def run_sql(
result = self.connection.execute(sql_statement, parameters)
return result

def columns_exist(self, table: Table, columns: list[str]) -> bool:
    """
    Check that a list of columns exist in the given table.
    :param table: The table to check in.
    :param columns: The columns to check.
    :returns: whether the columns exist in the table or not.
    """
    sqla_table = self.get_sqla_table(table)
    # Build the set of existing column names once and use a subset test:
    # O(len(columns) + len(table columns)) instead of the quadratic
    # nested any()/all() scan over every (column, table column) pair.
    existing_columns = {sqla_column.name for sqla_column in sqla_table.columns}
    return set(columns).issubset(existing_columns)

def table_exists(self, table: Table) -> bool:
"""
Check if a table exists in the database.
Expand Down Expand Up @@ -492,15 +507,13 @@ def export_table_to_pandas_dataframe(self, source_table: Table) -> pd.DataFrame:
:param source_table: An existing table in the database
"""
table_qualified_name = self.get_table_qualified_name(source_table)
if self.table_exists(source_table):
return pd.read_sql(
# We are avoiding SQL injection by confirming the table exists before this statement
f"SELECT * FROM {table_qualified_name}", # skipcq BAN-B608
con=self.sqlalchemy_engine,
)
sqla_table = self.get_sqla_table(source_table)
return pd.read_sql(sql=sqla_table.select(), con=self.sqlalchemy_engine)

table_qualified_name = self.get_table_qualified_name(source_table)
raise NonExistentTableException(
"The table %s does not exist" % table_qualified_name
f"The table {table_qualified_name} does not exist"
)

def export_table_to_file(
Expand Down
35 changes: 23 additions & 12 deletions python-sdk/src/astro/databases/google/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,20 +197,31 @@ def merge_table(
target_table_name = self.get_table_qualified_name(target_table)
source_table_name = self.get_table_qualified_name(source_table)

statement = f"MERGE {target_table_name} T USING {source_table_name} S\
ON {' AND '.join(['T.' + col + '= S.' + col for col in target_conflict_columns])}\
WHEN NOT MATCHED BY TARGET THEN INSERT ({','.join(target_columns)}) VALUES ({','.join(source_columns)})"

update_statement_map = ", ".join(
[
f"T.{target_columns[idx]}=S.{source_columns[idx]}"
for idx in range(len(target_columns))
]
insert_statement = (
f"INSERT ({', '.join(target_columns)}) VALUES ({', '.join(source_columns)})"
)
merge_statement = (
f"MERGE {target_table_name} T USING {source_table_name} S"
f" ON {' AND '.join(f'T.{col}=S.{col}' for col in target_conflict_columns)}"
f" WHEN NOT MATCHED BY TARGET THEN {insert_statement}"
)
if if_conflicts == "update":
update_statement = f"UPDATE SET {update_statement_map}" # skipcq: BAN-B608
statement += f" WHEN MATCHED THEN {update_statement}"
self.run_sql(sql_statement=statement)
update_statement_map = ", ".join(
f"T.{col}=S.{source_columns[idx]}"
for idx, col in enumerate(target_columns)
)
if not self.columns_exist(source_table, source_columns):
raise ValueError(
f"Not all the columns provided exist for {source_table_name}!"
)
if not self.columns_exist(target_table, target_columns):
raise ValueError(
f"Not all the columns provided exist for {target_table_name}!"
)
# Note: Ignoring below sql injection warning, as we validate that the table columns exist beforehand.
update_statement = f"UPDATE SET {update_statement_map}" # skipcq BAN-B608
merge_statement += f" WHEN MATCHED THEN {update_statement}"
self.run_sql(sql_statement=merge_statement)

def is_native_load_file_available(
self, source_file: File, target_table: Table
Expand Down
63 changes: 63 additions & 0 deletions python-sdk/tests/databases/test_all_databases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""Tests all Database implementations."""
import pathlib

import pandas as pd
import pytest

from astro.constants import Database
from astro.files import File
from astro.settings import SCHEMA
from astro.sql.table import Metadata, Table

CWD = pathlib.Path(__file__).parent


@pytest.mark.integration
@pytest.mark.parametrize(
    "database_table_fixture",
    [
        # Each case loads the same sample CSV; only the backend and the
        # schema casing expected by that backend differ per database.
        {
            "database": database,
            "file": File(str(pathlib.Path(CWD.parent, "data/sample.csv"))),
            "table": table,
        }
        for database, table in (
            (Database.BIGQUERY, Table(metadata=Metadata(schema=SCHEMA))),
            (Database.POSTGRES, Table(metadata=Metadata(schema=SCHEMA.lower()))),
            (Database.REDSHIFT, Table(metadata=Metadata(schema=SCHEMA.lower()))),
            (Database.SNOWFLAKE, Table(metadata=Metadata(schema=SCHEMA))),
            (Database.SQLITE, Table()),
        )
    ],
    indirect=True,
    ids=["bigquery", "postgres", "redshift", "snowflake", "sqlite"],
)
def test_export_table_to_pandas_dataframe(
    database_table_fixture,
):
    """Test export_table_to_pandas_dataframe() where the table exists"""
    database, table = database_table_fixture

    dataframe = database.export_table_to_pandas_dataframe(table)
    assert len(dataframe) == 3
    expected_rows = [
        {"id": 1, "name": "First"},
        {"id": 2, "name": "Second"},
        {"id": 3, "name": "Third with unicode पांचाल"},
    ]
    # Column names come back in backend-specific casing; lowercase them
    # before comparing against the expected frame.
    assert dataframe.rename(columns=str.lower).equals(pd.DataFrame(expected_rows))

0 comments on commit 3caf210

Please sign in to comment.