From 6c9be7ea6a7aa3f7f0b0d2ed949340f1c80761f6 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 28 Sep 2025 18:10:54 +0000 Subject: [PATCH 1/3] Initial plan From 4e07257c489e8a8ab1eacbe98eb4b5c0f7a8d37e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 28 Sep 2025 18:18:29 +0000 Subject: [PATCH 2/3] Implement upsert return_keys support and safe execution pattern for all methods Co-authored-by: faizanazim11 <20454506+faizanazim11@users.noreply.github.com> --- sql_db_utils/asyncio/sql_utils.py | 27 ++++++++++++++++++++++----- sql_db_utils/sql_utils.py | 27 ++++++++++++++++++++++----- 2 files changed, 44 insertions(+), 10 deletions(-) diff --git a/sql_db_utils/asyncio/sql_utils.py b/sql_db_utils/asyncio/sql_utils.py index 0ea6ce0..a7250d8 100644 --- a/sql_db_utils/asyncio/sql_utils.py +++ b/sql_db_utils/asyncio/sql_utils.py @@ -62,7 +62,11 @@ async def insert(self, data: Union[dict, list[dict]], return_keys: List[str] = N table = table if table is not None else self.table return_keys = return_keys or [] try: - insert_stmt = insert(table).values(data).returning(*(getattr(table.c, key) for key in return_keys)) + insert_stmt = ( + insert(table) + .values(data) + .returning(*(getattr(table.c if isinstance(table, Table) else table, key) for key in return_keys)) + ) return_values = await self.session.execute(insert_stmt) await self.session.commit() if return_keys: @@ -93,7 +97,7 @@ async def update_with_where( update(table) .values(data) .where(*where_conditions) - .returning(*(getattr(table.c, key) for key in return_keys)) + .returning(*(getattr(table.c if isinstance(table, Table) else table, key) for key in return_keys)) ) return_values = await self.session.execute(update_stmt) await self.session.commit() @@ -115,7 +119,10 @@ async def update(self, data: Union[dict, list[dict]], return_keys: List[str] = N return_keys = return_keys or [] try: return_values = await self.session.execute( - update(table).returning(*(getattr(table.c, key) for key in return_keys)), data + update(table).returning( + *(getattr(table.c if isinstance(table, Table) else table, key) for key in return_keys) + ), + data, ) await self.session.commit() if return_keys: @@ -124,24 +131,34 @@ async def update(self, data: Union[dict, list[dict]], return_keys: List[str] = N logger.error(f"Error occurred while updating: {e}", exc_info=True) raise e - async def upsert(self, insert_json: dict, primary_keys: List[str] = None, table: TableType = None): + async def upsert( + self, insert_json: dict, primary_keys: List[str] = None, return_keys: List[str] = None, table: TableType = None + ): """ Inserts or updates a row in the database. Args: insert_json (dict): A dictionary containing the data to be inserted or updated. primary_keys (List[str], optional): A list of primary key column names. Defaults to None. + return_keys (List[str], optional): A list of column names to return after the upsert. Defaults to None. table (TableType, optional): The SQLAlchemy declarative base object. Defaults to None. + + Returns: + A list of dictionaries containing the upserted data if return_keys is provided. """ table = table if table is not None else self.table + return_keys = return_keys or [] try: insert_statement = ( postgres_insert(table) .values(**insert_json) .on_conflict_do_update(index_elements=primary_keys, set_=insert_json) + .returning(*(getattr(table.c if isinstance(table, Table) else table, key) for key in return_keys)) ) - await self.session.execute(insert_statement) + return_values = await self.session.execute(insert_statement) await self.session.commit() + if return_keys: + return jsonable_encoder(return_values.mappings().all()) except Exception as e: logger.error(f"Error while upserting the record {e}", exc_info=True) raise e diff --git a/sql_db_utils/sql_utils.py b/sql_db_utils/sql_utils.py index 59d2899..82799de 100644 --- a/sql_db_utils/sql_utils.py +++ b/sql_db_utils/sql_utils.py @@ -52,7 +52,11 @@ def insert(self, data: Union[dict, list[dict]], return_keys: List[str] = None, t table = table if table is not None else self.table return_keys = return_keys or [] try: - insert_stmt = insert(table).values(data).returning(*(getattr(table.c, key) for key in return_keys)) + insert_stmt = ( + insert(table) + .values(data) + .returning(*(getattr(table.c if isinstance(table, Table) else table, key) for key in return_keys)) + ) return_values = self.session.execute(insert_stmt) self.session.commit() if return_keys: @@ -83,7 +87,7 @@ def update_with_where( update(table) .values(data) .where(*where_conditions) - .returning(*(getattr(table.c, key) for key in return_keys)) + .returning(*(getattr(table.c if isinstance(table, Table) else table, key) for key in return_keys)) ) return_values = self.session.execute(update_stmt) self.session.commit() @@ -105,7 +109,10 @@ def update(self, data: Union[dict, list[dict]], return_keys: List[str] = None, t return_keys = return_keys or [] try: return_values = self.session.execute( - update(table).returning(*(getattr(table.c, key) for key in return_keys)), data + update(table).returning( + *(getattr(table.c if isinstance(table, Table) else table, key) for key in return_keys) + ), + data, ) self.session.commit() if return_keys: @@ -114,24 +121,34 @@ def update(self, data: Union[dict, list[dict]], return_keys: List[str] = None, t logger.error(f"Error occurred while updating: {e}", exc_info=True) raise e - def upsert(self, insert_json: dict, primary_keys: List[str] = None, table: TableType = None): + def upsert( + self, insert_json: dict, primary_keys: List[str] = None, return_keys: List[str] = None, table: TableType = None + ): """ Inserts or updates a row in the database. Args: insert_json (dict): A dictionary containing the data to be inserted or updated. primary_keys (List[str], optional): A list of primary key column names. Defaults to None. + return_keys (List[str], optional): A list of column names to return after the upsert. Defaults to None. table (TableType, optional): The SQLAlchemy declarative base object. Defaults to None. + + Returns: + A list of dictionaries containing the upserted data if return_keys is provided. """ table = table if table is not None else self.table + return_keys = return_keys or [] try: insert_statement = ( postgres_insert(table) .values(**insert_json) .on_conflict_do_update(index_elements=primary_keys, set_=insert_json) + .returning(*(getattr(table.c if isinstance(table, Table) else table, key) for key in return_keys)) ) - self.session.execute(insert_statement) + return_values = self.session.execute(insert_statement) self.session.commit() + if return_keys: + return jsonable_encoder(return_values.mappings().all()) except Exception as e: logger.error(f"Error while upserting the record {e}", exc_info=True) raise e From 3488da7c1003cb6cfa5062719ea7ef2bc38bd57c Mon Sep 17 00:00:00 2001 From: Faizan Azim Date: Sun, 28 Sep 2025 18:23:05 +0000 Subject: [PATCH 3/3] chore(version): :bookmark: update version to 1.3.0 --- sql_db_utils/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql_db_utils/__version__.py b/sql_db_utils/__version__.py index c68196d..67bc602 100644 --- a/sql_db_utils/__version__.py +++ b/sql_db_utils/__version__.py @@ -1 +1 @@ -__version__ = "1.2.0" +__version__ = "1.3.0"