diff --git a/awswrangler/mysql.py b/awswrangler/mysql.py
index de20f6a3f..bf026cc94 100644
--- a/awswrangler/mysql.py
+++ b/awswrangler/mysql.py
@@ -266,6 +266,7 @@ def to_sql(
     index: bool = False,
     dtype: Optional[Dict[str, str]] = None,
     varchar_lengths: Optional[Dict[str, int]] = None,
+    use_column_names: bool = False,
 ) -> None:
     """Write records stored in a DataFrame into MySQL.
 
@@ -290,6 +291,10 @@ def to_sql(
         (e.g. {'col name': 'TEXT', 'col2 name': 'FLOAT'})
     varchar_lengths : Dict[str, int], optional
        Dict of VARCHAR length by columns. (e.g. {"col1": 10, "col5": 200}).
+    use_column_names : bool
+        If set to True, use the DataFrame's column names when generating the INSERT SQL query.
+        E.g. if the DataFrame has only the columns `col1` and `col3`, data will be inserted only
+        into the database columns `col1` and `col3`.
 
     Returns
     -------
@@ -329,7 +334,10 @@ def to_sql(
             if index:
                 df.reset_index(level=df.index.names, inplace=True)
             placeholders: str = ", ".join(["%s"] * len(df.columns))
-            sql: str = f"INSERT INTO `{schema}`.`{table}` VALUES ({placeholders})"
+            insertion_columns = ""
+            if use_column_names:
+                insertion_columns = f"({', '.join(df.columns)})"
+            sql: str = f"INSERT INTO `{schema}`.`{table}` {insertion_columns} VALUES ({placeholders})"
             _logger.debug("sql: %s", sql)
             parameters: List[List[Any]] = _db_utils.extract_parameters(df=df)
             cursor.executemany(sql, parameters)
diff --git a/awswrangler/postgresql.py b/awswrangler/postgresql.py
index 6a1461079..34ec780d8 100644
--- a/awswrangler/postgresql.py
+++ b/awswrangler/postgresql.py
@@ -272,6 +272,7 @@ def to_sql(
     index: bool = False,
     dtype: Optional[Dict[str, str]] = None,
     varchar_lengths: Optional[Dict[str, int]] = None,
+    use_column_names: bool = False,
 ) -> None:
     """Write records stored in a DataFrame into PostgreSQL.
 
@@ -296,6 +297,10 @@ def to_sql(
         (e.g. {'col name': 'TEXT', 'col2 name': 'FLOAT'})
     varchar_lengths : Dict[str, int], optional
         Dict of VARCHAR length by columns. (e.g. {"col1": 10, "col5": 200}).
+    use_column_names : bool
+        If set to True, use the DataFrame's column names when generating the INSERT SQL query.
+        E.g. if the DataFrame has only the columns `col1` and `col3`, data will be inserted only
+        into the database columns `col1` and `col3`.
 
     Returns
     -------
@@ -335,7 +340,10 @@ def to_sql(
             if index:
                 df.reset_index(level=df.index.names, inplace=True)
             placeholders: str = ", ".join(["%s"] * len(df.columns))
-            sql: str = f'INSERT INTO "{schema}"."{table}" VALUES ({placeholders})'
+            insertion_columns = ""
+            if use_column_names:
+                insertion_columns = f"({', '.join(df.columns)})"
+            sql: str = f'INSERT INTO "{schema}"."{table}" {insertion_columns} VALUES ({placeholders})'
             _logger.debug("sql: %s", sql)
             parameters: List[List[Any]] = _db_utils.extract_parameters(df=df)
             cursor.executemany(sql, parameters)
diff --git a/awswrangler/redshift.py b/awswrangler/redshift.py
index 6c7c7d912..d1c889e85 100644
--- a/awswrangler/redshift.py
+++ b/awswrangler/redshift.py
@@ -644,6 +644,7 @@ def to_sql(
     primary_keys: Optional[List[str]] = None,
     varchar_lengths_default: int = 256,
     varchar_lengths: Optional[Dict[str, int]] = None,
+    use_column_names: bool = False,
 ) -> None:
     """Write records stored in a DataFrame into Redshift.
 
@@ -688,6 +689,10 @@ def to_sql(
         The size that will be set for all VARCHAR columns not specified with varchar_lengths.
     varchar_lengths : Dict[str, int], optional
         Dict of VARCHAR length by columns. (e.g. {"col1": 10, "col5": 200}).
+    use_column_names : bool
+        If set to True, use the DataFrame's column names when generating the INSERT SQL query.
+        E.g. if the DataFrame has only the columns `col1` and `col3`, data will be inserted only
+        into the database columns `col1` and `col3`.
 
     Returns
     -------
@@ -737,7 +742,10 @@ def to_sql(
                 df.reset_index(level=df.index.names, inplace=True)
             placeholders: str = ", ".join(["%s"] * len(df.columns))
             schema_str = f'"{created_schema}".' if created_schema else ""
-            sql: str = f'INSERT INTO {schema_str}"{created_table}" VALUES ({placeholders})'
+            insertion_columns = ""
+            if use_column_names:
+                insertion_columns = f"({', '.join(df.columns)})"
+            sql: str = f'INSERT INTO {schema_str}"{created_table}" {insertion_columns} VALUES ({placeholders})'
             _logger.debug("sql: %s", sql)
             parameters: List[List[Any]] = _db_utils.extract_parameters(df=df)
             cursor.executemany(sql, parameters)
diff --git a/awswrangler/sqlserver.py b/awswrangler/sqlserver.py
index 10c3ab8ef..f7957a338 100644
--- a/awswrangler/sqlserver.py
+++ b/awswrangler/sqlserver.py
@@ -290,6 +290,7 @@ def to_sql(
     index: bool = False,
     dtype: Optional[Dict[str, str]] = None,
     varchar_lengths: Optional[Dict[str, int]] = None,
+    use_column_names: bool = False,
 ) -> None:
     """Write records stored in a DataFrame into Microsoft SQL Server.
 
@@ -314,6 +315,10 @@ def to_sql(
         (e.g. {'col name': 'TEXT', 'col2 name': 'FLOAT'})
     varchar_lengths : Dict[str, int], optional
         Dict of VARCHAR length by columns. (e.g. {"col1": 10, "col5": 200}).
+    use_column_names : bool
+        If set to True, use the DataFrame's column names when generating the INSERT SQL query.
+        E.g. if the DataFrame has only the columns `col1` and `col3`, data will be inserted only
+        into the database columns `col1` and `col3`.
 
     Returns
     -------
@@ -354,7 +359,10 @@ def to_sql(
                 df.reset_index(level=df.index.names, inplace=True)
             placeholders: str = ", ".join(["?"] * len(df.columns))
             table_identifier = _get_table_identifier(schema, table)
-            sql: str = f"INSERT INTO {table_identifier} VALUES ({placeholders})"
+            insertion_columns = ""
+            if use_column_names:
+                insertion_columns = f"({', '.join(df.columns)})"
+            sql: str = f"INSERT INTO {table_identifier} {insertion_columns} VALUES ({placeholders})"
             _logger.debug("sql: %s", sql)
             parameters: List[List[Any]] = _db_utils.extract_parameters(df=df)
             cursor.executemany(sql, parameters)
diff --git a/tests/test_mysql.py b/tests/test_mysql.py
index 6cccbcfb6..0aea31e7d 100644
--- a/tests/test_mysql.py
+++ b/tests/test_mysql.py
@@ -180,3 +180,29 @@ def test_connect_secret_manager(dbname):
     df = wr.mysql.read_sql_query("SELECT 1", con=con)
     con.close()
     assert df.shape == (1, 1)
+
+
+def test_insert_with_column_names(mysql_table):
+    con = wr.mysql.connect(connection="aws-data-wrangler-mysql")
+    create_table_sql = (
+        f"CREATE TABLE test.{mysql_table} " "(c0 varchar(100) NULL, " "c1 INT DEFAULT 42 NULL, " "c2 INT NOT NULL);"
+    )
+    with con.cursor() as cursor:
+        cursor.execute(create_table_sql)
+        con.commit()
+
+    df = pd.DataFrame({"c0": ["foo", "bar"], "c2": [1, 2]})
+
+    with pytest.raises(pymysql.err.OperationalError):
+        wr.mysql.to_sql(df=df, con=con, schema="test", table=mysql_table, mode="append", use_column_names=False)
+
+    wr.mysql.to_sql(df=df, con=con, schema="test", table=mysql_table, mode="append", use_column_names=True)
+
+    df2 = wr.mysql.read_sql_table(con=con, schema="test", table=mysql_table)
+
+    df["c1"] = 42
+    df["c0"] = df["c0"].astype("string")
+    df["c1"] = df["c1"].astype("Int64")
+    df["c2"] = df["c2"].astype("Int64")
+    df = df.reindex(sorted(df.columns), axis=1)
+    assert df.equals(df2)
diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py
index c697a3685..6f92cae54 100644
--- a/tests/test_postgresql.py
+++ b/tests/test_postgresql.py
@@ -180,3 +180,31 @@ def test_connect_secret_manager(dbname):
     df = wr.postgresql.read_sql_query("SELECT 1", con=con)
     con.close()
     assert df.shape == (1, 1)
+
+
+def test_insert_with_column_names(postgresql_table):
+    con = wr.postgresql.connect(connection="aws-data-wrangler-postgresql")
+    create_table_sql = (
+        f"CREATE TABLE public.{postgresql_table} " "(c0 varchar NULL," "c1 int NULL DEFAULT 42," "c2 int NOT NULL);"
+    )
+    with con.cursor() as cursor:
+        cursor.execute(create_table_sql)
+        con.commit()
+
+    df = pd.DataFrame({"c0": ["foo", "bar"], "c2": [1, 2]})
+
+    with pytest.raises(pg8000.exceptions.ProgrammingError):
+        wr.postgresql.to_sql(
+            df=df, con=con, schema="public", table=postgresql_table, mode="append", use_column_names=False
+        )
+
+    wr.postgresql.to_sql(df=df, con=con, schema="public", table=postgresql_table, mode="append", use_column_names=True)
+
+    df2 = wr.postgresql.read_sql_table(con=con, schema="public", table=postgresql_table)
+
+    df["c1"] = 42
+    df["c0"] = df["c0"].astype("string")
+    df["c1"] = df["c1"].astype("Int64")
+    df["c2"] = df["c2"].astype("Int64")
+    df = df.reindex(sorted(df.columns), axis=1)
+    assert df.equals(df2)
diff --git a/tests/test_redshift.py b/tests/test_redshift.py
index defa245c1..fbed0754a 100644
--- a/tests/test_redshift.py
+++ b/tests/test_redshift.py
@@ -911,3 +911,29 @@ def test_failed_keep_files(path, redshift_table, databases_parameters):
         varchar_lengths={"c1": 2},
     )
     assert len(wr.s3.list_objects(path)) == 0
+
+
+def test_insert_with_column_names(redshift_table):
+    con = wr.redshift.connect(connection="aws-data-wrangler-redshift")
+    create_table_sql = (
+        f"CREATE TABLE public.{redshift_table} " "(c0 varchar(100), " "c1 integer default 42, " "c2 integer not null);"
+    )
+    with con.cursor() as cursor:
+        cursor.execute(create_table_sql)
+        con.commit()
+
+    df = pd.DataFrame({"c0": ["foo", "bar"], "c2": [1, 2]})
+
+    with pytest.raises(redshift_connector.error.ProgrammingError):
+        wr.redshift.to_sql(df=df, con=con, schema="public", table=redshift_table, mode="append", use_column_names=False)
+
+    wr.redshift.to_sql(df=df, con=con, schema="public", table=redshift_table, mode="append", use_column_names=True)
+
+    df2 = wr.redshift.read_sql_table(con=con, schema="public", table=redshift_table)
+
+    df["c1"] = 42
+    df["c0"] = df["c0"].astype("string")
+    df["c1"] = df["c1"].astype("Int64")
+    df["c2"] = df["c2"].astype("Int64")
+    df = df.reindex(sorted(df.columns), axis=1)
+    assert df.equals(df2)
diff --git a/tests/test_sqlserver.py b/tests/test_sqlserver.py
index b661df1a3..923e70784 100644
--- a/tests/test_sqlserver.py
+++ b/tests/test_sqlserver.py
@@ -194,3 +194,29 @@ def test_connect_secret_manager(dbname):
         assert df.shape == (1, 1)
     except boto3.client("secretsmanager").exceptions.ResourceNotFoundException:
         pass  # Workaround for secretmanager inconsistance
+
+
+def test_insert_with_column_names(sqlserver_table):
+    con = wr.sqlserver.connect(connection="aws-data-wrangler-sqlserver")
+    create_table_sql = (
+        f"CREATE TABLE dbo.{sqlserver_table} " "(c0 varchar(100) NULL," "c1 INT DEFAULT 42 NULL," "c2 INT NOT NULL);"
+    )
+    with con.cursor() as cursor:
+        cursor.execute(create_table_sql)
+        con.commit()
+
+    df = pd.DataFrame({"c0": ["foo", "bar"], "c2": [1, 2]})
+
+    with pytest.raises(pyodbc.ProgrammingError):
+        wr.sqlserver.to_sql(df=df, con=con, schema="dbo", table=sqlserver_table, mode="append", use_column_names=False)
+
+    wr.sqlserver.to_sql(df=df, con=con, schema="dbo", table=sqlserver_table, mode="append", use_column_names=True)
+
+    df2 = wr.sqlserver.read_sql_table(con=con, schema="dbo", table=sqlserver_table)
+
+    df["c1"] = 42
+    df["c0"] = df["c0"].astype("string")
+    df["c1"] = df["c1"].astype("Int64")
+    df["c2"] = df["c2"].astype("Int64")
+    df = df.reindex(sorted(df.columns), axis=1)
+    assert df.equals(df2)
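
For reference, a minimal usage sketch of the new flag, in the spirit of the tests above. The Glue connection name matches the test fixtures, while the schema and table (`test.my_table`) are hypothetical placeholders. With `use_column_names=True` a DataFrame covering only a subset of the table's columns can be appended, and the omitted columns fall back to their database defaults; note that `insertion_columns` interpolates the names unquoted, so plain column identifiers are assumed.

import awswrangler as wr
import pandas as pd

con = wr.mysql.connect(connection="aws-data-wrangler-mysql")  # Glue connection, as in the tests

# `c1` is omitted on purpose. With use_column_names=True the generated SQL is
#   INSERT INTO `test`.`my_table` (c0, c2) VALUES (%s, %s)
# so MySQL fills `c1` from its column default instead of raising a
# column-count mismatch (the use_column_names=False behavior).
df = pd.DataFrame({"c0": ["foo", "bar"], "c2": [1, 2]})
wr.mysql.to_sql(df=df, con=con, schema="test", table="my_table", mode="append", use_column_names=True)
con.close()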