-
Notifications
You must be signed in to change notification settings - Fork 17.1k
PostgresHook: Add upsert rows support using ON CONFLICT #67045
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,7 +18,7 @@ | |
| from __future__ import annotations | ||
|
|
||
| import os | ||
| from collections.abc import Mapping | ||
| from collections.abc import Iterable, Mapping | ||
| from contextlib import closing | ||
| from copy import deepcopy | ||
| from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeAlias, cast, overload | ||
|
|
@@ -230,6 +230,29 @@ def _get_cursor(self, raw_cursor: str) -> CursorType: | |
| valid_cursors = ", ".join(cursor_types.keys()) | ||
| raise ValueError(f"Invalid cursor passed {_cursor}. Valid options are: {valid_cursors}") | ||
|
|
||
| def _get_cursor_config(self, raw_cursor: str) -> tuple[str, Any]: | ||
| cursor = self._get_cursor(raw_cursor) | ||
|
|
||
| if USE_PSYCOPG3: | ||
| return "row_factory", cursor | ||
|
|
||
| return "cursor_factory", cursor | ||
|
|
||
| def _create_connection(self, conn_args: dict[str, Any]) -> CompatConnection: | ||
| if USE_PSYCOPG3: | ||
| from psycopg.connection import Connection as pgConnection | ||
|
|
||
| connection = pgConnection.connect(**cast("Any", conn_args)) | ||
|
|
||
| register_default_adapters(connection) | ||
|
|
||
| if self.enable_log_db_messages and hasattr(connection, "add_notice_handler"): | ||
| connection.add_notice_handler(self._notice_handler) | ||
|
|
||
| return connection | ||
|
|
||
| return ppg2_connect(**conn_args) | ||
|
|
||
| def _generate_cursor_name(self): | ||
| """Generate a unique name for server-side cursor.""" | ||
| import uuid | ||
|
|
@@ -262,30 +285,13 @@ def get_conn(self) -> CompatConnection: | |
| if arg_name not in self.ignored_extra_options: | ||
| conn_args[arg_name] = arg_val | ||
|
|
||
| if USE_PSYCOPG3: | ||
| from psycopg.connection import Connection as pgConnection | ||
| raw_cursor = conn.extra_dejson.get("cursor") | ||
|
|
||
| raw_cursor = conn.extra_dejson.get("cursor") | ||
| if raw_cursor: | ||
| conn_args["row_factory"] = self._get_cursor(raw_cursor) | ||
| if raw_cursor: | ||
| key, value = self._get_cursor_config(raw_cursor) | ||
| conn_args[key] = value | ||
|
|
||
| # Use Any type for the connection args to avoid type conflicts | ||
| connection = pgConnection.connect(**cast("Any", conn_args)) | ||
| self.conn = cast("CompatConnection", connection) | ||
|
|
||
| # Register JSON handlers for both json and jsonb types | ||
| # This ensures JSON data is properly decoded from bytes to Python objects | ||
| register_default_adapters(connection) | ||
|
|
||
| # Add the notice handler AFTER the connection is established | ||
| if self.enable_log_db_messages and hasattr(self.conn, "add_notice_handler"): | ||
| self.conn.add_notice_handler(self._notice_handler) | ||
| else: # psycopg2 | ||
| raw_cursor = conn.extra_dejson.get("cursor", False) | ||
| if raw_cursor: | ||
| conn_args["cursor_factory"] = self._get_cursor(raw_cursor) | ||
|
|
||
| self.conn = cast("CompatConnection", ppg2_connect(**conn_args)) | ||
| self.conn = self._create_connection(conn_args) | ||
|
|
||
| return self.conn | ||
|
|
||
|
|
@@ -720,3 +726,137 @@ def insert_rows( | |
|
|
||
| self.log.info("Done loading. Loaded a total of %s rows into %s", nb_rows, table) | ||
| return None | ||
|
|
||
| def _generate_upsert_sql( | ||
| self, | ||
| table: str, | ||
| values: tuple[Any, ...] | list[Any], | ||
| target_fields: list[str], | ||
| conflict_fields: list[str], | ||
| update_fields: list[str] | None = None, | ||
| **kwargs, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think we need the kwargs here as there is nothing consuming them. |
||
| ) -> str: | ||
| """ | ||
| Generate PostgreSQL UPSERT SQL using ``ON CONFLICT``. | ||
|
|
||
| :param table: Name of target table. | ||
| :param values: Row values used for placeholder generation. | ||
| :param target_fields: Non-empty column names used in the ``INSERT`` statement. | ||
| :param conflict_fields: Non-empty column names used in the ``ON CONFLICT`` clause. | ||
| :param update_fields: Columns to update on conflict. If omitted or empty, | ||
| ``DO NOTHING`` is used. | ||
| """ | ||
| placeholders = ", ".join(["%s"] * len(values)) | ||
|
|
||
| columns = ", ".join(target_fields) | ||
|
|
||
| conflict_clause = ", ".join(conflict_fields) | ||
|
|
||
| sql = f"INSERT INTO {table} ({columns}) VALUES ({placeholders}) " | ||
|
|
||
| if update_fields: | ||
| update_clause = ", ".join(f"{field} = EXCLUDED.{field}" for field in update_fields) | ||
|
|
||
| sql += f"ON CONFLICT ({conflict_clause}) DO UPDATE SET {update_clause}" | ||
| else: | ||
| sql += f"ON CONFLICT ({conflict_clause}) DO NOTHING" | ||
|
|
||
| return sql.strip() | ||
|
|
||
| def upsert_rows( | ||
| self, | ||
| table: str, | ||
| rows: Iterable[tuple[Any, ...]], | ||
| target_fields: list[str], | ||
| conflict_fields: list[str], | ||
| update_fields: list[str] | None = None, | ||
| commit_every: int = 1000, | ||
| *, | ||
| fast_executemany: bool = False, | ||
| autocommit: bool = False, | ||
| ) -> None: | ||
| """ | ||
| Upsert rows into a PostgreSQL table using ``ON CONFLICT``. | ||
|
|
||
| :param table: Name of the target table. | ||
| :param rows: Rows to upsert. | ||
| :param target_fields: Non-empty column names used in the ``INSERT`` statement. | ||
| :param conflict_fields: Non-empty column names used in the ``ON CONFLICT`` clause. | ||
| :param update_fields: Columns updated on conflict. If omitted or empty, | ||
| conflicting rows are ignored via ``DO NOTHING``. | ||
| :param commit_every: Maximum number of rows per transaction. Default value is 1000. | ||
| :param fast_executemany: Use ``psycopg2.extras.execute_batch`` for improved | ||
| batch performance. | ||
| :param autocommit: Connection autocommit setting. | ||
| """ | ||
| if not target_fields or any(not field for field in target_fields): | ||
| raise ValueError("target_fields must be provided and must not be empty.") | ||
|
|
||
| if not conflict_fields or any(not field for field in conflict_fields): | ||
| raise ValueError("conflict_fields must be provided and must not be empty.") | ||
|
|
||
| rows = iter(rows) | ||
|
|
||
| nb_rows = 0 | ||
| sql = None | ||
|
|
||
| with self._create_autocommit_connection(autocommit) as conn: | ||
| conn.commit() | ||
|
|
||
| with closing(conn.cursor()) as cur: | ||
| for chunked_rows in chunked(rows, commit_every): | ||
| values = [self._serialize_cells(row, conn) for row in chunked_rows] | ||
|
|
||
| if not values: | ||
| continue | ||
|
|
||
| sql = self._generate_upsert_sql( | ||
| table=table, | ||
| values=values[0], | ||
| target_fields=target_fields, | ||
| conflict_fields=conflict_fields, | ||
| update_fields=update_fields, | ||
| ) | ||
|
Comment on lines
+813
to
+819
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Related to my comment above - if we don't need to pass values since it's just used for the number of placeholders, I don't think we need to regenerate the same SQL string on every chunk. This can just be done once outside of the loop. |
||
|
|
||
| self.log.debug("Generated sql: %s", sql) | ||
|
|
||
| try: | ||
| if fast_executemany: | ||
| # execute_batch reduces round trips by batching parameter sets. | ||
| execute_batch( | ||
|
Comment on lines
+824
to
+826
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need a guard here since psycopg3 does not support |
||
| cur, | ||
| sql, | ||
| values, | ||
| page_size=commit_every, | ||
| ) | ||
| else: | ||
| cur.executemany(sql, values) | ||
|
|
||
| except Exception: | ||
| self.log.error("Generated sql: %s", sql) | ||
| self.log.error("Parameters: %s", values) | ||
| raise | ||
|
|
||
| conn.commit() | ||
|
|
||
| nb_rows += len(values) | ||
|
|
||
| self.log.info( | ||
| "Upserted %s rows into %s so far", | ||
| nb_rows, | ||
| table, | ||
| ) | ||
|
|
||
| if sql: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also related to my comment above, if we construct the query once outside the loop then this would need to be updated. This could be |
||
| send_sql_hook_lineage( | ||
| context=self, | ||
| sql=sql, | ||
| row_count=nb_rows, | ||
| ) | ||
|
|
||
| self.log.info( | ||
| "Done upserting. Upserted a total of %s rows into %s", | ||
| nb_rows, | ||
| table, | ||
| ) | ||
| return None | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to pass in the values here? The only thing it's used for is to produce the right number of placeholders but I think that can just be done with
len(target_fields).