diff --git a/src/DatabaseLibrary/query.py b/src/DatabaseLibrary/query.py index 9c979b7..80fa73d 100644 --- a/src/DatabaseLibrary/query.py +++ b/src/DatabaseLibrary/query.py @@ -22,7 +22,7 @@ class Query(object): Query handles all the querying done by the Database Library. """ - def query(self, selectStatement, sansTran=False, returnAsDict=False): + def query(self, selectStatement, sansTran=False, returnAsDict=False, parameters=None): """ Uses the input `selectStatement` to query for the values that will be returned as a list of tuples. Set optional input `sansTran` to True to run command without an explicit transaction commit or rollback. @@ -60,7 +60,7 @@ def query(self, selectStatement, sansTran=False, returnAsDict=False): try: cur = self._dbconnection.cursor() logger.info('Executing : Query | %s ' % selectStatement) - self.__execute_sql(cur, selectStatement) + self.__execute_sql(cur, selectStatement, parameters=parameters) allRows = cur.fetchall() if returnAsDict: @@ -80,7 +80,7 @@ def query(self, selectStatement, sansTran=False, returnAsDict=False): if not sansTran: self._dbconnection.rollback() - def row_count(self, selectStatement, sansTran=False): + def row_count(self, selectStatement, sansTran=False, parameters=None): """ Uses the input `selectStatement` to query the database and returns the number of rows from the query. Set optional input `sansTran` to True to run command without an explicit transaction commit or rollback. @@ -111,7 +111,7 @@ def row_count(self, selectStatement, sansTran=False): try: cur = self._dbconnection.cursor() logger.info('Executing : Row Count | %s ' % selectStatement) - self.__execute_sql(cur, selectStatement) + self.__execute_sql(cur, selectStatement, parameters=parameters) data = cur.fetchall() if self.db_api_module_name in ["sqlite3", "ibm_db", "ibm_db_dbi", "pyodbc"]: rowCount = len(data) @@ -123,7 +123,7 @@ def row_count(self, selectStatement, sansTran=False): if not sansTran: self._dbconnection.rollback() - def description(self, selectStatement, sansTran=False): + def description(self, selectStatement, sansTran=False, parameters=None): """ Uses the input `selectStatement` to query a table in the db which will be used to determine the description. Set optional input `sansTran` to True to run command without an explicit transaction commit or rollback. @@ -148,7 +148,7 @@ def description(self, selectStatement, sansTran=False): try: cur = self._dbconnection.cursor() logger.info('Executing : Description | %s ' % selectStatement) - self.__execute_sql(cur, selectStatement) + self.__execute_sql(cur, selectStatement, parameters=parameters) description = list(cur.description) if sys.version_info[0] < 3: for row in range(0, len(description)): @@ -319,7 +319,7 @@ def execute_sql_script(self, sqlScriptFileName, sansTran=False): if not sansTran: self._dbconnection.rollback() - def execute_sql_string(self, sqlString, sansTran=False): + def execute_sql_string(self, sqlString, sansTran=False, parameters=None): """ Executes the sqlString as SQL commands. Useful to pass arguments to your sql. Set optional input `sansTran` to True to run command without an explicit transaction commit or rollback. @@ -339,7 +339,7 @@ def execute_sql_string(self, sqlString, sansTran=False): try: cur = self._dbconnection.cursor() logger.info('Executing : Execute SQL String | %s ' % sqlString) - self.__execute_sql(cur, sqlString) + self.__execute_sql(cur, sqlString, parameters=parameters) if not sansTran: self._dbconnection.commit() finally: @@ -484,7 +484,7 @@ def call_stored_procedure(self, spName, spParams=None, sansTran=False): if not sansTran: self._dbconnection.rollback() - def __execute_sql(self, cur, sql_statement, omit_trailing_semicolon=None): + def __execute_sql(self, cur, sql_statement, omit_trailing_semicolon=None, parameters=None): """ Runs the `sql_statement` using `cur` as Cursor object. Use `omit_trailing_semicolon` parameter (bool) for explicite instruction, @@ -496,5 +496,7 @@ def __execute_sql(self, cur, sql_statement, omit_trailing_semicolon=None): omit_trailing_semicolon = self.omit_trailing_semicolon if omit_trailing_semicolon: sql_statement = sql_statement.rstrip(";") + if parameters is None: + parameters = [] logger.debug(f"Executing sql: {sql_statement}") - return cur.execute(sql_statement) + return cur.execute(sql_statement, parameters) diff --git a/test/tests/common_tests/basic_tests.robot b/test/tests/common_tests/basic_tests.robot index 2fe17bc..4db398c 100644 --- a/test/tests/common_tests/basic_tests.robot +++ b/test/tests/common_tests/basic_tests.robot @@ -16,6 +16,19 @@ SQL Statement Ending With Semicolon Works SQL Statement Ending Without Semicolon Works Query SELECT * FROM person; +SQL Statement With Parameters Works + @{params}= Create List 2 + + IF "${DB_MODULE}" in ["oracledb"] + ${output}= Query SELECT * FROM person WHERE id < :id parameters=${params} + ELSE IF "${DB_MODULE}" in ["sqlite3", "pyodbc"] + ${output}= Query SELECT * FROM person WHERE id < ? parameters=${params} + ELSE + ${output}= Query SELECT * FROM person WHERE id < %s parameters=${params} + END + + Length Should Be ${output} 1 + Create Person Table [Setup] Log No setup for this test ${output}= Create Person Table