diff --git a/sql_db_utils/__version__.py b/sql_db_utils/__version__.py index c68196d..67bc602 100644 --- a/sql_db_utils/__version__.py +++ b/sql_db_utils/__version__.py @@ -1 +1 @@ -__version__ = "1.2.0" +__version__ = "1.3.0" diff --git a/sql_db_utils/asyncio/sql_utils.py b/sql_db_utils/asyncio/sql_utils.py index 0ea6ce0..a7250d8 100644 --- a/sql_db_utils/asyncio/sql_utils.py +++ b/sql_db_utils/asyncio/sql_utils.py @@ -62,7 +62,11 @@ async def insert(self, data: Union[dict, list[dict]], return_keys: List[str] = N table = table if table is not None else self.table return_keys = return_keys or [] try: - insert_stmt = insert(table).values(data).returning(*(getattr(table.c, key) for key in return_keys)) + insert_stmt = ( + insert(table) + .values(data) + .returning(*(getattr(table.c if isinstance(table, Table) else table, key) for key in return_keys)) + ) return_values = await self.session.execute(insert_stmt) await self.session.commit() if return_keys: @@ -93,7 +97,7 @@ async def update_with_where( update(table) .values(data) .where(*where_conditions) - .returning(*(getattr(table.c, key) for key in return_keys)) + .returning(*(getattr(table.c if isinstance(table, Table) else table, key) for key in return_keys)) ) return_values = await self.session.execute(update_stmt) await self.session.commit() @@ -115,7 +119,10 @@ async def update(self, data: Union[dict, list[dict]], return_keys: List[str] = N return_keys = return_keys or [] try: return_values = await self.session.execute( - update(table).returning(*(getattr(table.c, key) for key in return_keys)), data + update(table).returning( + *(getattr(table.c if isinstance(table, Table) else table, key) for key in return_keys) + ), + data, ) await self.session.commit() if return_keys: @@ -124,24 +131,34 @@ async def update(self, data: Union[dict, list[dict]], return_keys: List[str] = N logger.error(f"Error occurred while updating: {e}", exc_info=True) raise e - async def upsert(self, insert_json: dict, primary_keys: List[str] = None, table: TableType = None): + async def upsert( + self, insert_json: dict, primary_keys: List[str] = None, return_keys: List[str] = None, table: TableType = None + ): """ Inserts or updates a row in the database. Args: insert_json (dict): A dictionary containing the data to be inserted or updated. primary_keys (List[str], optional): A list of primary key column names. Defaults to None. + return_keys (List[str], optional): A list of column names to return after the upsert. Defaults to None. table (TableType, optional): The SQLAlchemy declarative base object. Defaults to None. + + Returns: + A list of dictionaries containing the upserted data if return_keys is provided. """ table = table if table is not None else self.table + return_keys = return_keys or [] try: insert_statement = ( postgres_insert(table) .values(**insert_json) .on_conflict_do_update(index_elements=primary_keys, set_=insert_json) + .returning(*(getattr(table.c if isinstance(table, Table) else table, key) for key in return_keys)) ) - await self.session.execute(insert_statement) + return_values = await self.session.execute(insert_statement) await self.session.commit() + if return_keys: + return jsonable_encoder(return_values.mappings().all()) except Exception as e: logger.error(f"Error while upserting the record {e}", exc_info=True) raise e diff --git a/sql_db_utils/sql_utils.py b/sql_db_utils/sql_utils.py index 59d2899..82799de 100644 --- a/sql_db_utils/sql_utils.py +++ b/sql_db_utils/sql_utils.py @@ -52,7 +52,11 @@ def insert(self, data: Union[dict, list[dict]], return_keys: List[str] = None, t table = table if table is not None else self.table return_keys = return_keys or [] try: - insert_stmt = insert(table).values(data).returning(*(getattr(table.c, key) for key in return_keys)) + insert_stmt = ( + insert(table) + .values(data) + .returning(*(getattr(table.c if isinstance(table, Table) else table, key) for key in return_keys)) + ) return_values = self.session.execute(insert_stmt) self.session.commit() if return_keys: @@ -83,7 +87,7 @@ def update_with_where( update(table) .values(data) .where(*where_conditions) - .returning(*(getattr(table.c, key) for key in return_keys)) + .returning(*(getattr(table.c if isinstance(table, Table) else table, key) for key in return_keys)) ) return_values = self.session.execute(update_stmt) self.session.commit() @@ -105,7 +109,10 @@ def update(self, data: Union[dict, list[dict]], return_keys: List[str] = None, t return_keys = return_keys or [] try: return_values = self.session.execute( - update(table).returning(*(getattr(table.c, key) for key in return_keys)), data + update(table).returning( + *(getattr(table.c if isinstance(table, Table) else table, key) for key in return_keys) + ), + data, ) self.session.commit() if return_keys: @@ -114,24 +121,34 @@ def update(self, data: Union[dict, list[dict]], return_keys: List[str] = None, t logger.error(f"Error occurred while updating: {e}", exc_info=True) raise e - def upsert(self, insert_json: dict, primary_keys: List[str] = None, table: TableType = None): + def upsert( + self, insert_json: dict, primary_keys: List[str] = None, return_keys: List[str] = None, table: TableType = None + ): """ Inserts or updates a row in the database. Args: insert_json (dict): A dictionary containing the data to be inserted or updated. primary_keys (List[str], optional): A list of primary key column names. Defaults to None. + return_keys (List[str], optional): A list of column names to return after the upsert. Defaults to None. table (TableType, optional): The SQLAlchemy declarative base object. Defaults to None. + + Returns: + A list of dictionaries containing the upserted data if return_keys is provided. """ table = table if table is not None else self.table + return_keys = return_keys or [] try: insert_statement = ( postgres_insert(table) .values(**insert_json) .on_conflict_do_update(index_elements=primary_keys, set_=insert_json) + .returning(*(getattr(table.c if isinstance(table, Table) else table, key) for key in return_keys)) ) - self.session.execute(insert_statement) + return_values = self.session.execute(insert_statement) self.session.commit() + if return_keys: + return jsonable_encoder(return_values.mappings().all()) except Exception as e: logger.error(f"Error while upserting the record {e}", exc_info=True) raise e