diff --git a/python-sdk/src/astro/databases/base.py b/python-sdk/src/astro/databases/base.py
index 091bd36476..8ff99cbd11 100644
--- a/python-sdk/src/astro/databases/base.py
+++ b/python-sdk/src/astro/databases/base.py
@@ -101,6 +101,21 @@ def run_sql(
         result = self.connection.execute(sql_statement, parameters)
         return result
 
+    def columns_exist(self, table: Table, columns: list[str]) -> bool:
+        """
+        Check that all the given columns exist in the table.
+
+        :param table: The table to check in.
+        :param columns: The columns to check.
+
+        :returns: Whether all the given columns exist in the table.
+        """
+        sqla_table = self.get_sqla_table(table)
+        return all(
+            any(sqla_column.name == column for sqla_column in sqla_table.columns)
+            for column in columns
+        )
+
     def table_exists(self, table: Table) -> bool:
         """
         Check if a table exists in the database.
@@ -490,15 +505,13 @@ def export_table_to_pandas_dataframe(self, source_table: Table) -> pd.DataFrame:
 
         :param source_table: An existing table in the database
         """
-        table_qualified_name = self.get_table_qualified_name(source_table)
         if self.table_exists(source_table):
-            return pd.read_sql(
-                # We are avoiding SQL injection by confirming the table exists before this statement
-                f"SELECT * FROM {table_qualified_name}",  # skipcq BAN-B608
-                con=self.sqlalchemy_engine,
-            )
+            sqla_table = self.get_sqla_table(source_table)
+            return pd.read_sql(sql=sqla_table.select(), con=self.sqlalchemy_engine)
+
+        table_qualified_name = self.get_table_qualified_name(source_table)
         raise NonExistentTableException(
-            "The table %s does not exist" % table_qualified_name
+            f"The table {table_qualified_name} does not exist"
         )
 
     def export_table_to_file(
diff --git a/python-sdk/src/astro/databases/google/bigquery.py b/python-sdk/src/astro/databases/google/bigquery.py
index 0efe9ed5b4..fe6fa58b03 100644
--- a/python-sdk/src/astro/databases/google/bigquery.py
+++ b/python-sdk/src/astro/databases/google/bigquery.py
@@ -197,20 +197,31 @@ def merge_table(
         target_table_name = self.get_table_qualified_name(target_table)
         source_table_name = self.get_table_qualified_name(source_table)
 
-        statement = f"MERGE {target_table_name} T USING {source_table_name} S\
-            ON {' AND '.join(['T.' + col + '= S.' + col for col in target_conflict_columns])}\
-            WHEN NOT MATCHED BY TARGET THEN INSERT ({','.join(target_columns)}) VALUES ({','.join(source_columns)})"
-
-        update_statement_map = ", ".join(
-            [
-                f"T.{target_columns[idx]}=S.{source_columns[idx]}"
-                for idx in range(len(target_columns))
-            ]
+        insert_statement = (
+            f"INSERT ({', '.join(target_columns)}) VALUES ({', '.join(source_columns)})"
+        )
+        merge_statement = (
+            f"MERGE {target_table_name} T USING {source_table_name} S"
+            f" ON {' AND '.join(f'T.{col}=S.{col}' for col in target_conflict_columns)}"
+            f" WHEN NOT MATCHED BY TARGET THEN {insert_statement}"
         )
         if if_conflicts == "update":
-            update_statement = f"UPDATE SET {update_statement_map}"  # skipcq: BAN-B608
-            statement += f" WHEN MATCHED THEN {update_statement}"
-        self.run_sql(sql_statement=statement)
+            update_statement_map = ", ".join(
+                f"T.{col}=S.{source_columns[idx]}"
+                for idx, col in enumerate(target_columns)
+            )
+            if not self.columns_exist(source_table, source_columns):
+                raise ValueError(
+                    f"Not all the columns provided exist in {source_table_name}!"
+                )
+            if not self.columns_exist(target_table, target_columns):
+                raise ValueError(
+                    f"Not all the columns provided exist in {target_table_name}!"
+                )
+            # Note: Ignoring the SQL injection warning below, since we validate beforehand that the columns exist.
+            update_statement = f"UPDATE SET {update_statement_map}"  # skipcq BAN-B608
+            merge_statement += f" WHEN MATCHED THEN {update_statement}"
+        self.run_sql(sql_statement=merge_statement)
 
     def is_native_load_file_available(
         self, source_file: File, target_table: Table
diff --git a/python-sdk/tests/databases/test_all_databases.py b/python-sdk/tests/databases/test_all_databases.py
new file mode 100644
index 0000000000..9456a6db8e
--- /dev/null
+++ b/python-sdk/tests/databases/test_all_databases.py
@@ -0,0 +1,63 @@
+"""Tests all Database implementations."""
+import pathlib
+
+import pandas as pd
+import pytest
+
+from astro.constants import Database
+from astro.files import File
+from astro.settings import SCHEMA
+from astro.sql.table import Metadata, Table
+
+CWD = pathlib.Path(__file__).parent
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "database_table_fixture",
+    [
+        {
+            "database": Database.BIGQUERY,
+            "file": File(str(pathlib.Path(CWD.parent, "data/sample.csv"))),
+            "table": Table(metadata=Metadata(schema=SCHEMA)),
+        },
+        {
+            "database": Database.POSTGRES,
+            "file": File(str(pathlib.Path(CWD.parent, "data/sample.csv"))),
+            "table": Table(metadata=Metadata(schema=SCHEMA.lower())),
+        },
+        {
+            "database": Database.REDSHIFT,
+            "file": File(str(pathlib.Path(CWD.parent, "data/sample.csv"))),
+            "table": Table(metadata=Metadata(schema=SCHEMA.lower())),
+        },
+        {
+            "database": Database.SNOWFLAKE,
+            "file": File(str(pathlib.Path(CWD.parent, "data/sample.csv"))),
+            "table": Table(metadata=Metadata(schema=SCHEMA)),
+        },
+        {
+            "database": Database.SQLITE,
+            "file": File(str(pathlib.Path(CWD.parent, "data/sample.csv"))),
+            "table": Table(),
+        },
+    ],
+    indirect=True,
+    ids=["bigquery", "postgres", "redshift", "snowflake", "sqlite"],
+)
+def test_export_table_to_pandas_dataframe(
+    database_table_fixture,
+):
+    """Test export_table_to_pandas_dataframe() when the table exists."""
+    database, table = database_table_fixture
+
+    df = database.export_table_to_pandas_dataframe(table)
+    assert len(df) == 3
+    expected = pd.DataFrame(
+        [
+            {"id": 1, "name": "First"},
+            {"id": 2, "name": "Second"},
+            {"id": 3, "name": "Third with unicode पांचाल"},
+        ]
+    )
+    assert df.rename(columns=str.lower).equals(expected)