From 9b2e593fd4c79366681162a1da43595584bd1abd Mon Sep 17 00:00:00 2001 From: Kanthi Date: Wed, 18 Aug 2021 16:29:19 -0400 Subject: [PATCH] Fix sqlite hook - insert and replace functions (#17695) --- airflow/hooks/dbapi.py | 1 + .../sqlite/example_dags/example_sqlite.py | 26 +++++++++++++- airflow/providers/sqlite/hooks/sqlite.py | 34 +++++++++++++++++++ .../operators.rst | 2 +- tests/providers/sqlite/hooks/test_sqlite.py | 20 +++++++++++ 5 files changed, 81 insertions(+), 2 deletions(-) diff --git a/airflow/hooks/dbapi.py b/airflow/hooks/dbapi.py index 4156500085ada..3c51e6f8a3572 100644 --- a/airflow/hooks/dbapi.py +++ b/airflow/hooks/dbapi.py @@ -324,6 +324,7 @@ def insert_rows(self, table, rows, target_fields=None, commit_every=1000, replac lst.append(self._serialize_cell(cell, conn)) values = tuple(lst) sql = self._generate_insert_sql(table, values, target_fields, replace, **kwargs) + self.log.debug("Generated sql: %s", sql) cur.execute(sql, values) if commit_every and i % commit_every == 0: conn.commit() diff --git a/airflow/providers/sqlite/example_dags/example_sqlite.py b/airflow/providers/sqlite/example_dags/example_sqlite.py index 88b4ba2293469..2ff5114b80c19 100644 --- a/airflow/providers/sqlite/example_dags/example_sqlite.py +++ b/airflow/providers/sqlite/example_dags/example_sqlite.py @@ -24,6 +24,8 @@ """ from airflow import DAG +from airflow.operators.python import PythonOperator +from airflow.providers.sqlite.hooks.sqlite import SqliteHook from airflow.providers.sqlite.operators.sqlite import SqliteOperator from airflow.utils.dates import days_ago @@ -52,6 +54,28 @@ # [END howto_operator_sqlite] + +def insert_sqlite_hook(): + sqlite_hook = SqliteHook("sqlite_default") + sqlite_hook.get_conn() + + rows = [('James', '11'), ('James', '22'), ('James', '33')] + target_fields = ['first_name', 'last_name'] + sqlite_hook.insert_rows(table='Customer', rows=rows, target_fields=target_fields) + + +def replace_sqlite_hook(): + sqlite_hook = SqliteHook("sqlite_default") + sqlite_hook.get_conn() + + rows = [('James', '11'), ('James', '22'), ('James', '33')] + target_fields = ['first_name', 'last_name'] + sqlite_hook.insert_rows(table='Customer', rows=rows, target_fields=target_fields, replace=True) + + +insert_sqlite_task = PythonOperator(task_id="insert_sqlite_task", python_callable=insert_sqlite_hook) +replace_sqlite_task = PythonOperator(task_id="replace_sqlite_task", python_callable=replace_sqlite_hook) + # [START howto_operator_sqlite_external_file] # Example of creating a task that calls an sql command from an external file. @@ -64,4 +88,4 @@ # [END howto_operator_sqlite_external_file] -create_table_sqlite_task >> external_create_table_sqlite_task +create_table_sqlite_task >> external_create_table_sqlite_task >> insert_sqlite_task >> replace_sqlite_task diff --git a/airflow/providers/sqlite/hooks/sqlite.py b/airflow/providers/sqlite/hooks/sqlite.py index 82c9dd0468edc..e4a43317859b5 100644 --- a/airflow/providers/sqlite/hooks/sqlite.py +++ b/airflow/providers/sqlite/hooks/sqlite.py @@ -35,3 +35,37 @@ def get_conn(self) -> sqlite3.dbapi2.Connection: airflow_conn = self.get_connection(conn_id) conn = sqlite3.connect(airflow_conn.host) return conn + + @staticmethod + def _generate_insert_sql(table, values, target_fields, replace, **kwargs): + """ + Static helper method that generate the INSERT SQL statement. + The REPLACE variant is specific to MySQL syntax. + + :param table: Name of the target table + :type table: str + :param values: The row to insert into the table + :type values: tuple of cell values + :param target_fields: The names of the columns to fill in the table + :type target_fields: iterable of strings + :param replace: Whether to replace instead of insert + :type replace: bool + :return: The generated INSERT or REPLACE SQL statement + :rtype: str + """ + placeholders = [ + "?", + ] * len(values) + + if target_fields: + target_fields = ", ".join(target_fields) + target_fields = f"({target_fields})" + else: + target_fields = '' + + if not replace: + sql = "INSERT INTO " + else: + sql = "REPLACE INTO " + sql += f"{table} {target_fields} VALUES ({','.join(placeholders)})" + return sql diff --git a/docs/apache-airflow-providers-sqlite/operators.rst b/docs/apache-airflow-providers-sqlite/operators.rst index 100156e9fcf29..6cbe5f6d3062e 100644 --- a/docs/apache-airflow-providers-sqlite/operators.rst +++ b/docs/apache-airflow-providers-sqlite/operators.rst @@ -39,7 +39,7 @@ the connection metadata is structured as follows: * - Parameter - Input * - Host: string - - MySql hostname + - Sqlite database file * - Schema: string - Set schema to execute Sql operations on by default * - Login: string diff --git a/tests/providers/sqlite/hooks/test_sqlite.py b/tests/providers/sqlite/hooks/test_sqlite.py index d618fed97af74..3c5d977037499 100644 --- a/tests/providers/sqlite/hooks/test_sqlite.py +++ b/tests/providers/sqlite/hooks/test_sqlite.py @@ -106,3 +106,23 @@ def test_run_log(self): statement = 'SQL' self.db_hook.run(statement) assert self.db_hook.log.info.call_count == 2 + + def test_generate_insert_sql_replace_false(self): + expected_sql = "INSERT INTO Customer (first_name, last_name) VALUES (?,?)" + rows = ('James', '1') + target_fields = ['first_name', 'last_name'] + sql = self.db_hook._generate_insert_sql( + table='Customer', values=rows, target_fields=target_fields, replace=False + ) + + assert sql == expected_sql + + def test_generate_insert_sql_replace_true(self): + expected_sql = "REPLACE INTO Customer (first_name, last_name) VALUES (?,?)" + rows = ('James', '1') + target_fields = ['first_name', 'last_name'] + sql = self.db_hook._generate_insert_sql( + table='Customer', values=rows, target_fields=target_fields, replace=True + ) + + assert sql == expected_sql