Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Fix sql injection issues #807

Merged
merged 11 commits into from
Sep 13, 2022
12 changes: 5 additions & 7 deletions python-sdk/src/astro/databases/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,15 +490,13 @@ def export_table_to_pandas_dataframe(self, source_table: Table) -> pd.DataFrame:

:param source_table: An existing table in the database
"""
table_qualified_name = self.get_table_qualified_name(source_table)
if self.table_exists(source_table):
return pd.read_sql(
# We are avoiding SQL injection by confirming the table exists before this statement
f"SELECT * FROM {table_qualified_name}", # skipcq BAN-B608
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for catching this.

con=self.sqlalchemy_engine,
)
sqla_table = self.get_sqla_table(source_table)
return pd.read_sql(sql=select(sqla_table), con=self.sqlalchemy_engine)

table_qualified_name = self.get_table_qualified_name(source_table)
raise NonExistentTableException(
"The table %s does not exist" % table_qualified_name
f"The table {table_qualified_name} does not exist"
)

def export_table_to_file(
Expand Down
63 changes: 63 additions & 0 deletions python-sdk/tests/databases/test_all_databases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""Tests all Database implementations."""
import pathlib

import pandas as pd
import pytest

from astro.constants import Database
from astro.files import File
from astro.settings import SCHEMA
from astro.sql.table import Metadata, Table

CWD = pathlib.Path(__file__).parent


# Shared CSV fixture loaded into each database's test table; built once
# instead of repeating the identical File(...) expression per database.
SAMPLE_CSV_FILE = File(str(pathlib.Path(CWD.parent, "data/sample.csv")))


@pytest.mark.integration
@pytest.mark.parametrize(
    "database_table_fixture",
    [
        {
            "database": Database.BIGQUERY,
            "file": SAMPLE_CSV_FILE,
            "table": Table(metadata=Metadata(schema=SCHEMA)),
        },
        {
            "database": Database.POSTGRES,
            "file": SAMPLE_CSV_FILE,
            # Postgres schema names are lower-cased, mirroring the other entries' convention.
            "table": Table(metadata=Metadata(schema=SCHEMA.lower())),
        },
        {
            "database": Database.REDSHIFT,
            "file": SAMPLE_CSV_FILE,
            "table": Table(metadata=Metadata(schema=SCHEMA.lower())),
        },
        {
            "database": Database.SNOWFLAKE,
            "file": SAMPLE_CSV_FILE,
            "table": Table(metadata=Metadata(schema=SCHEMA)),
        },
        {
            "database": Database.SQLITE,
            "file": SAMPLE_CSV_FILE,
            # SQLite uses no explicit schema; default Table metadata suffices.
            "table": Table(),
        },
    ],
    indirect=True,
    ids=["bigquery", "postgres", "redshift", "snowflake", "sqlite"],
)
def test_export_table_to_pandas_dataframe(
    database_table_fixture,
):
    """Test export_table_to_pandas_dataframe() where the table exists.

    :param database_table_fixture: ``(database, table)`` pair supplied by the
        parametrized fixture; the table is pre-loaded from ``data/sample.csv``.
    """
    database, table = database_table_fixture

    df = database.export_table_to_pandas_dataframe(table)
    assert len(df) == 3
    expected = pd.DataFrame(
        [
            {"id": 1, "name": "First"},
            {"id": 2, "name": "Second"},
            {"id": 3, "name": "Third with unicode पांचाल"},
        ]
    )
    # Some backends (e.g. Snowflake) return upper-cased column names;
    # normalise to lower case before the exact-equality comparison.
    assert df.rename(columns=str.lower).equals(expected)