Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix sqlite hook - insert and replace functions #17695

Merged
merged 4 commits into from
Aug 18, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions airflow/hooks/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ def insert_rows(self, table, rows, target_fields=None, commit_every=1000, replac
lst.append(self._serialize_cell(cell, conn))
values = tuple(lst)
sql = self._generate_insert_sql(table, values, target_fields, replace, **kwargs)
self.log.info(f"Generated sql: {sql}")
potiuk marked this conversation as resolved.
Show resolved Hide resolved
cur.execute(sql, values)
if commit_every and i % commit_every == 0:
conn.commit()
Expand Down
26 changes: 25 additions & 1 deletion airflow/providers/sqlite/example_dags/example_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
"""

from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.providers.sqlite.hooks.sqlite import SqliteHook
from airflow.providers.sqlite.operators.sqlite import SqliteOperator
from airflow.utils.dates import days_ago

Expand Down Expand Up @@ -52,6 +54,28 @@

# [END howto_operator_sqlite]


def insert_sqlite_hook():
    """Insert three sample rows into the ``Customer`` table through ``SqliteHook``.

    Serves as the ``python_callable`` of ``insert_sqlite_task`` in this example DAG.
    """
    # No explicit ``get_conn()`` here: ``insert_rows`` manages its own connection, and a
    # connection opened only to be discarded would never be closed.
    sqlite_hook = SqliteHook("sqlite_default")

    rows = [('James', '11'), ('James', '22'), ('James', '33')]
    target_fields = ['first_name', 'last_name']
    sqlite_hook.insert_rows(table='Customer', rows=rows, target_fields=target_fields)


def replace_sqlite_hook():
    """Write three sample rows into the ``Customer`` table using ``replace=True``.

    Serves as the ``python_callable`` of ``replace_sqlite_task`` in this example DAG and
    exercises the ``REPLACE INTO`` variant of ``SqliteHook.insert_rows``.
    """
    # No explicit ``get_conn()`` here: ``insert_rows`` manages its own connection, and a
    # connection opened only to be discarded would never be closed.
    sqlite_hook = SqliteHook("sqlite_default")

    rows = [('James', '11'), ('James', '22'), ('James', '33')]
    target_fields = ['first_name', 'last_name']
    sqlite_hook.insert_rows(table='Customer', rows=rows, target_fields=target_fields, replace=True)


# Wrap the two hook-based callables in PythonOperator tasks so they can be chained
# after the table-creation tasks.
insert_sqlite_task = PythonOperator(
    task_id="insert_sqlite_task",
    python_callable=insert_sqlite_hook,
)
replace_sqlite_task = PythonOperator(
    task_id="replace_sqlite_task",
    python_callable=replace_sqlite_hook,
)

# [START howto_operator_sqlite_external_file]

# Example of creating a task that calls an sql command from an external file.
Expand All @@ -64,4 +88,4 @@

# [END howto_operator_sqlite_external_file]

create_table_sqlite_task >> external_create_table_sqlite_task
create_table_sqlite_task >> external_create_table_sqlite_task >> insert_sqlite_task >> replace_sqlite_task
34 changes: 34 additions & 0 deletions airflow/providers/sqlite/hooks/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,37 @@ def get_conn(self) -> sqlite3.dbapi2.Connection:
airflow_conn = self.get_connection(conn_id)
conn = sqlite3.connect(airflow_conn.host)
return conn

@staticmethod
def _generate_insert_sql(table, values, target_fields, replace, **kwargs):
    """
    Static helper method that generates a parameterized INSERT statement for SQLite.

    One ``?`` placeholder is emitted per cell in ``values``; the cell values themselves are
    never interpolated into the SQL text, so the statement is meant to be executed together
    with ``values`` as bound parameters. With ``replace=True`` the SQLite ``REPLACE INTO``
    form is produced instead of ``INSERT INTO``.

    :param table: Name of the target table
    :type table: str
    :param values: The row to insert into the table
    :type values: tuple of cell values
    :param target_fields: The names of the columns to fill in the table; when empty or
        ``None`` no column list is emitted
    :type target_fields: iterable of strings
    :param replace: Whether to replace instead of insert
    :type replace: bool
    :param kwargs: Accepted for signature compatibility with the caller; not used here
    :return: The generated INSERT or REPLACE SQL statement
    :rtype: str
    """
    placeholders = ",".join(["?"] * len(values))
    columns = f"({', '.join(target_fields)})" if target_fields else ""
    verb = "REPLACE" if replace else "INSERT"

    # With no column list this leaves two spaces before VALUES; that is valid SQL and is
    # kept so the generated text stays identical to what callers already see.
    return f"{verb} INTO {table} {columns} VALUES ({placeholders})"
2 changes: 1 addition & 1 deletion docs/apache-airflow-providers-sqlite/operators.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ the connection metadata is structured as follows:
* - Parameter
- Input
* - Host: string
- MySql hostname
- Sqlite database file
* - Schema: string
- Set schema to execute Sql operations on by default
* - Login: string
Expand Down
20 changes: 20 additions & 0 deletions tests/providers/sqlite/hooks/test_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,23 @@ def test_run_log(self):
statement = 'SQL'
self.db_hook.run(statement)
assert self.db_hook.log.info.call_count == 2

def test_generate_insert_sql_replace_false(self):
    """With ``replace=False`` the hook must render ``INSERT INTO`` with one ``?`` per value."""
    row = ('James', '1')
    columns = ['first_name', 'last_name']

    generated = self.db_hook._generate_insert_sql(
        table='Customer', values=row, target_fields=columns, replace=False
    )

    assert generated == "INSERT INTO Customer (first_name, last_name) VALUES (?,?)"

def test_generate_insert_sql_replace_true(self):
    """With ``replace=True`` the hook must render ``REPLACE INTO`` with one ``?`` per value."""
    row = ('James', '1')
    columns = ['first_name', 'last_name']

    generated = self.db_hook._generate_insert_sql(
        table='Customer', values=row, target_fields=columns, replace=True
    )

    assert generated == "REPLACE INTO Customer (first_name, last_name) VALUES (?,?)"