Skip to content

Commit

Permalink
Fix sqlite hook - insert and replace functions (#17695)
Browse files Browse the repository at this point in the history
  • Loading branch information
subkanthi committed Aug 18, 2021
1 parent 9863811 commit 9b2e593
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 2 deletions.
1 change: 1 addition & 0 deletions airflow/hooks/dbapi.py
Expand Up @@ -324,6 +324,7 @@ def insert_rows(self, table, rows, target_fields=None, commit_every=1000, replac
lst.append(self._serialize_cell(cell, conn))
values = tuple(lst)
sql = self._generate_insert_sql(table, values, target_fields, replace, **kwargs)
self.log.debug("Generated sql: %s", sql)
cur.execute(sql, values)
if commit_every and i % commit_every == 0:
conn.commit()
Expand Down
26 changes: 25 additions & 1 deletion airflow/providers/sqlite/example_dags/example_sqlite.py
Expand Up @@ -24,6 +24,8 @@
"""

from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.providers.sqlite.hooks.sqlite import SqliteHook
from airflow.providers.sqlite.operators.sqlite import SqliteOperator
from airflow.utils.dates import days_ago

Expand Down Expand Up @@ -52,6 +54,28 @@

# [END howto_operator_sqlite]


def insert_sqlite_hook():
    """Insert three sample rows into the ``Customer`` table via ``SqliteHook.insert_rows``.

    No connection is opened here: ``insert_rows`` is not handed one, so a
    connection obtained with ``get_conn()`` in this function would only be
    left open and unused.
    """
    sqlite_hook = SqliteHook("sqlite_default")

    rows = [('James', '11'), ('James', '22'), ('James', '33')]
    target_fields = ['first_name', 'last_name']
    sqlite_hook.insert_rows(table='Customer', rows=rows, target_fields=target_fields)


def replace_sqlite_hook():
    """Write three sample rows into ``Customer`` using ``insert_rows(..., replace=True)``.

    No connection is opened here: ``insert_rows`` is not handed one, so a
    connection obtained with ``get_conn()`` in this function would only be
    left open and unused.
    """
    sqlite_hook = SqliteHook("sqlite_default")

    rows = [('James', '11'), ('James', '22'), ('James', '33')]
    target_fields = ['first_name', 'last_name']
    sqlite_hook.insert_rows(table='Customer', rows=rows, target_fields=target_fields, replace=True)


# Expose the two hook-based callables as tasks: one plain insert, one replace.
insert_sqlite_task = PythonOperator(
    task_id="insert_sqlite_task",
    python_callable=insert_sqlite_hook,
)
replace_sqlite_task = PythonOperator(
    task_id="replace_sqlite_task",
    python_callable=replace_sqlite_hook,
)

# [START howto_operator_sqlite_external_file]

# Example of creating a task that calls an sql command from an external file.
Expand All @@ -64,4 +88,4 @@

# [END howto_operator_sqlite_external_file]

create_table_sqlite_task >> external_create_table_sqlite_task
create_table_sqlite_task >> external_create_table_sqlite_task >> insert_sqlite_task >> replace_sqlite_task
34 changes: 34 additions & 0 deletions airflow/providers/sqlite/hooks/sqlite.py
Expand Up @@ -35,3 +35,37 @@ def get_conn(self) -> sqlite3.dbapi2.Connection:
airflow_conn = self.get_connection(conn_id)
conn = sqlite3.connect(airflow_conn.host)
return conn

@staticmethod
def _generate_insert_sql(table, values, target_fields, replace, **kwargs):
    """
    Build a parameterized INSERT or REPLACE statement for a single row.

    One ``?`` placeholder is emitted per entry in ``values``; the values
    themselves are not interpolated into the SQL text.

    :param table: Name of the target table
    :type table: str
    :param values: The row to insert into the table
    :type values: tuple of cell values
    :param target_fields: The names of the columns to fill in the table;
        when empty or ``None`` no column list is emitted
    :type target_fields: iterable of strings
    :param replace: Whether to emit ``REPLACE INTO`` instead of ``INSERT INTO``
    :type replace: bool
    :return: The generated INSERT or REPLACE SQL statement
    :rtype: str
    """
    verb = "REPLACE INTO " if replace else "INSERT INTO "
    # An absent column list still leaves the separating space in place.
    column_clause = f"({', '.join(target_fields)})" if target_fields else ''
    value_marks = ','.join(["?"] * len(values))
    return f"{verb}{table} {column_clause} VALUES ({value_marks})"
2 changes: 1 addition & 1 deletion docs/apache-airflow-providers-sqlite/operators.rst
Expand Up @@ -39,7 +39,7 @@ the connection metadata is structured as follows:
* - Parameter
- Input
* - Host: string
- MySql hostname
- Path to the SQLite database file
* - Schema: string
- Schema on which SQL operations are executed by default
* - Login: string
Expand Down
20 changes: 20 additions & 0 deletions tests/providers/sqlite/hooks/test_sqlite.py
Expand Up @@ -106,3 +106,23 @@ def test_run_log(self):
statement = 'SQL'
self.db_hook.run(statement)
assert self.db_hook.log.info.call_count == 2

def test_generate_insert_sql_replace_false(self):
    """With ``replace=False`` the hook must produce an ``INSERT INTO`` statement."""
    generated = self.db_hook._generate_insert_sql(
        table='Customer',
        values=('James', '1'),
        target_fields=['first_name', 'last_name'],
        replace=False,
    )

    # One "?" placeholder per value, comma-separated without spaces.
    assert generated == "INSERT INTO Customer (first_name, last_name) VALUES (?,?)"

def test_generate_insert_sql_replace_true(self):
    """With ``replace=True`` the hook must produce a ``REPLACE INTO`` statement."""
    generated = self.db_hook._generate_insert_sql(
        table='Customer',
        values=('James', '1'),
        target_fields=['first_name', 'last_name'],
        replace=True,
    )

    # One "?" placeholder per value, comma-separated without spaces.
    assert generated == "REPLACE INTO Customer (first_name, last_name) VALUES (?,?)"

0 comments on commit 9b2e593

Please sign in to comment.