Skip to content

Commit

Permalink
SQL Alchemy 2.x.x (demisto#29436)
Browse files Browse the repository at this point in the history
* MySQL and Postgres work

* MSSQL, MySQL and Postgres work with bind_variables from the second form

* resolve conflicts

* fix CR's comments

* pre commit

* parsing the results

* Add UT

* same name and right docker

* RN

* sourcery

* another docker image

* revert docker image

* Update Packs/GenericSQL/ReleaseNotes/1_0_25.md

Co-authored-by: dorschw <81086590+dorschw@users.noreply.github.com>

* Update Packs/GenericSQL/Integrations/GenericSQL/GenericSQL.py

Co-authored-by: dorschw <81086590+dorschw@users.noreply.github.com>

* Update Packs/GenericSQL/Integrations/GenericSQL/GenericSQL.py

Co-authored-by: dorschw <81086590+dorschw@users.noreply.github.com>

* Update Packs/GenericSQL/Integrations/GenericSQL/GenericSQL.py

Co-authored-by: dorschw <81086590+dorschw@users.noreply.github.com>

* fix variable name

* constants

* mapping instead of conditions

* unskip Oracle TPB

* resolve conflicts

* resolve conflicts

* Constants

* Update Packs/GenericSQL/Integrations/GenericSQL/GenericSQL.py

Co-authored-by: dorschw <81086590+dorschw@users.noreply.github.com>

* CR fixes

* Update Packs/GenericSQL/ReleaseNotes/1_1_0.md

Co-authored-by: dorschw <81086590+dorschw@users.noreply.github.com>

* add commit after executing a query

* fix UT

* remove autocommit true from MSSQL

* fix UT

* autocommit for MSSQL, commit for the others

* commit for the other DBs, since in MSSQL it is done automatically

* docker image

---------

Co-authored-by: dorschw <81086590+dorschw@users.noreply.github.com>
  • Loading branch information
2 people authored and maimorag committed Sep 28, 2023
1 parent 07df767 commit 2bed962
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 57 deletions.
157 changes: 103 additions & 54 deletions Packs/GenericSQL/Integrations/GenericSQL/GenericSQL.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from CommonServerPython import *
from CommonServerUserPython import *

from typing import Any, Tuple, Dict, List, Callable, Optional
from typing import Any
from collections.abc import Callable
import sqlalchemy
import pymysql
import hashlib
Expand All @@ -11,6 +12,12 @@
from sqlalchemy.engine.url import URL
from urllib.parse import parse_qsl
import dateparser

ORACLE = "Oracle"
POSTGRES_SQL = "PostgreSQL"
MY_SQL = "MySQL"
MS_ODBC_DRIVER = "Microsoft SQL Server - MS ODBC Driver"
MICROSOFT_SQL_SERVER = "Microsoft SQL Server"
FETCH_DEFAULT_LIMIT = '50'

try:
Expand Down Expand Up @@ -64,13 +71,13 @@ def parse_connect_parameters(connect_parameters: str, dialect: str, verify_certi
A dict with the keys and values.
"""
connect_parameters_tuple_list = parse_qsl(connect_parameters, keep_blank_values=True)
connect_parameters_dict = dict()
connect_parameters_dict = {}
for key, value in connect_parameters_tuple_list:
connect_parameters_dict[key] = value
if dialect == "Microsoft SQL Server":
if dialect == MICROSOFT_SQL_SERVER:
connect_parameters_dict['driver'] = 'FreeTDS'
connect_parameters_dict.setdefault('autocommit', 'True')
elif dialect == 'Microsoft SQL Server - MS ODBC Driver':
elif dialect == MS_ODBC_DRIVER:
connect_parameters_dict['driver'] = 'ODBC Driver 18 for SQL Server'
connect_parameters_dict.setdefault('autocommit', 'True')
if not verify_certificate:
Expand All @@ -84,13 +91,13 @@ def _convert_dialect_to_module(dialect: str) -> str:
:param dialect: the SQL db
:return: a key string needed for the connection
"""
if dialect == "MySQL":
if dialect == MY_SQL:
module = "mysql"
elif dialect == "PostgreSQL":
elif dialect == POSTGRES_SQL:
module = "postgresql"
elif dialect == "Oracle":
elif dialect == ORACLE:
module = "oracle"
elif dialect in {"Microsoft SQL Server", 'Microsoft SQL Server - MS ODBC Driver'}:
elif dialect in {MICROSOFT_SQL_SERVER, MS_ODBC_DRIVER}:
module = "mssql+pyodbc"
else:
module = str(dialect)
Expand Down Expand Up @@ -119,11 +126,11 @@ def _create_engine_and_connect(self) -> sqlalchemy.engine.base.Connection:
username=self.username,
password=self.password,
host=self.host,
port=self.port,
port=arg_to_number(self.port),
database=self.dbname,
query=self.connect_parameters)
if self.ssl_connect:
if self.dialect == 'PostgreSQL':
if self.dialect == POSTGRES_SQL:
ssl_connection = {'sslmode': 'require'}
else:
ssl_connection = {'ssl': {'ssl-mode': 'preferred'}} # type: ignore[dict-item]
Expand All @@ -143,7 +150,7 @@ def _create_engine_and_connect(self) -> sqlalchemy.engine.base.Connection:
poolclass=sqlalchemy.pool.NullPool)
return engine.connect()

def sql_query_execute_request(self, sql_query: str, bind_vars: Any, fetch_limit=0) -> Tuple[Dict, List]:
def sql_query_execute_request(self, sql_query: str, bind_vars: Any, fetch_limit=0) -> tuple[list[dict], list]:
"""Execute query in DB via engine
:param bind_vars: in case there are names and values - a bind_var dict, in case there are only values - list
:param sql_query: the SQL query
Expand All @@ -154,16 +161,26 @@ def sql_query_execute_request(self, sql_query: str, bind_vars: Any, fetch_limit=
sql_query = text(sql_query)

result = self.connection.execute(sql_query, bind_vars)
# For avoiding responses with lots of records
results = result.fetchmany(fetch_limit) if fetch_limit else result.fetchall()

# for MSSQL autocommit is True, so no need to commit again here
if self.dialect not in {MICROSOFT_SQL_SERVER, MS_ODBC_DRIVER}:
self.connection.commit()
# extracting the table from the response
if fetch_limit:
table = result.mappings().fetchmany(fetch_limit)
else:
table = result.mappings().fetchall()
results = [dict(row) for row in table]

headers = []
if results:
# if the table isn't empty
headers = list(results[0].keys() if results[0].keys() else '')
headers = list(results[0].keys() or '')

return results, headers


def generate_default_port_by_dialect(dialect: str) -> Optional[str]:
def generate_default_port_by_dialect(dialect: str) -> str | None:
"""
In case no port was chosen, a default port will be chosen according to the SQL db type. Only return a port for
Microsoft SQL Server and ODBC Driver 18 for SQL Server where it seems to be required.
Expand All @@ -176,27 +193,60 @@ def generate_default_port_by_dialect(dialect: str) -> Optional[str]:
return None


def generate_bind_vars(bind_variables_names: str, bind_variables_values: str) -> Any:
def generate_variable_names_and_mapping(bind_variables_values_list: list, query: str, dialect: str) ->\
tuple[dict[str, Any], str | Any]:
"""
In case of passing just bind_variables_values, since it's no longer supported in SQL Alchemy v2.,
this function generates names for those variables and return an edited query with a mapping.
Args:
bind_variables_values_list: Values to put in the bind variables
query: The given query which contains chars to replace
dialect: The DB dialect
Returns: A mapping (dict) and an edited query.
"""
# For counting and replacing, re.findall needs "\\?", whereas replace needs "?"
mapping_dialect_regex = {MICROSOFT_SQL_SERVER: ("\\?", "?"),
MS_ODBC_DRIVER: ("\\?", "?"),
POSTGRES_SQL: ("%s", "%s"),
MY_SQL: ("%s", "%s"),
ORACLE: ("%s", "%s")
}

# dialect is a configuration parameter with multiple choices, so it should be one of the keys in the mapping
char_to_count, char_to_replace = mapping_dialect_regex[dialect]

bind_variables_names_list = []
for i in range(len(re.findall(char_to_count, query))):
query = query.replace(char_to_replace, f":bind_variable_{i+1}", 1)
bind_variables_names_list.append(f"bind_variable_{i+1}")
return dict(zip(bind_variables_names_list, bind_variables_values_list)), query


def generate_bind_vars(bind_variables_names: str, bind_variables_values: str, query: str, dialect: str) -> tuple[dict, str]:
"""
The bind variables can be given in 2 legal ways: as 2 lists - names and values, or only values
any way defines a different executing way, therefore there are 2 legal return types
:param bind_variables_names: the names of the bind variables, must be in the length of the values list
:param bind_variables_values: the values of the bind variables, can be in the length of the names list
or in case there is no name lists - at any length
:return: a dict or lists of the bind variables
:param query: the given sql query
:param dialect: the given dialect
:return: a dict or lists of the bind variables and the edited query
"""
bind_variables_names_list = argToList(bind_variables_names)
bind_variables_values_list = argToList(bind_variables_values)

if bind_variables_values and not bind_variables_names:
return [var for var in argToList(bind_variables_values)]
elif len(bind_variables_names_list) is len(bind_variables_values_list):
return dict(zip(bind_variables_names_list, bind_variables_values_list))
return generate_variable_names_and_mapping(bind_variables_values_list, query, dialect)
elif len(bind_variables_names_list) == len(bind_variables_values_list):
return dict(zip(bind_variables_names_list, bind_variables_values_list)), query
else:
raise Exception("The bind variables lists are not is the same length")


def test_module(client: Client, *_) -> Tuple[str, Dict[Any, Any], List[Any]]:
def test_module(client: Client, *_) -> tuple[str, dict[Any, Any], list[Any]]:
"""
If the connection in the client was successful the test will return OK
if it wasn't an exception will be raised
Expand All @@ -215,10 +265,9 @@ def test_module(client: Client, *_) -> Tuple[str, Dict[Any, Any], List[Any]]:
if limit < 1 or limit > 50:
msg += 'Fetch Limit value should be between 1 and 50. '

if params.get('fetch_parameters') == 'ID and timestamp':
if not (params.get('column_name') and params.get('id_column')):
msg += 'Missing Fetch Column or ID Column name (when ID and timestamp are chosen,' \
' fill in both). '
if params.get('fetch_parameters') == 'ID and timestamp' and not (params.get('column_name') and params.get('id_column')):
msg += 'Missing Fetch Column or ID Column name (when ID and timestamp are chosen,' \
' fill in both). '

if params.get('fetch_parameters') in ['Unique ascending', 'Unique timestamp']:
if not params.get('column_name'):
Expand Down Expand Up @@ -268,7 +317,7 @@ def test_module(client: Client, *_) -> Tuple[str, Dict[Any, Any], List[Any]]:
return msg if msg else 'ok', {}, []


def sql_query_execute(client: Client, args: dict, *_) -> Tuple[str, Dict[str, Any], List[Dict[str, Any]]]:
def sql_query_execute(client: Client, args: dict, *_) -> tuple[str, dict[str, Any], list[dict[str, Any]]]:
"""
Executes the sql query with the connection that was configured in the client
:param client: the client object with the db connection
Expand All @@ -281,24 +330,22 @@ def sql_query_execute(client: Client, args: dict, *_) -> Tuple[str, Dict[str, An
skip = int(args.get('skip', 0))
bind_variables_names = args.get('bind_variables_names', "")
bind_variables_values = args.get('bind_variables_values', "")
bind_variables = generate_bind_vars(bind_variables_names, bind_variables_values)
bind_variables, sql_query = generate_bind_vars(bind_variables_names, bind_variables_values, sql_query, client.dialect)

result, headers = client.sql_query_execute_request(sql_query, bind_variables)
# converting an sqlalchemy object to a table
converted_table = [dict(row) for row in result]
# converting b'' and datetime objects to readable ones
table = [{str(key): str(value) for key, value in dictionary.items()} for dictionary in converted_table]
table = table[skip:skip + limit]
human_readable = tableToMarkdown(name="Query result:", t=table, headers=headers,
removeNull=True)
result = convert_sqlalchemy_to_readable_table(result)
result = result[skip:skip + limit]

human_readable = tableToMarkdown(name="Query result", t=result, headers=headers, removeNull=True)

context = {
"Result": table,
"Result": result,
"Headers": headers,
"Query": sql_query,
"InstanceName": f"{client.dialect}_{client.dbname}",
}
entry_context: Dict = {'GenericSQL(val.Query && val.Query === obj.Query)': {'GenericSQL': context}}
return human_readable, entry_context, table
entry_context: dict = {'GenericSQL(val.Query && val.Query === obj.Query)': {'GenericSQL': context}}
return human_readable, entry_context, result

except Exception as err:
# In case there is no query executed and only an action e.g - insert, delete, update
Expand Down Expand Up @@ -330,7 +377,7 @@ def initialize_last_run(fetch_parameters: str, first_fetch: str):
last_run = {'last_timestamp': False, 'last_id': first_fetch}

# for the case when we get timestamp and id - need to maintain an id's list
last_run['ids'] = list()
last_run['ids'] = []

return last_run

Expand Down Expand Up @@ -371,7 +418,7 @@ def create_sql_query(last_run: dict, query: str, column_name: str, max_fetch: st
return sql_query


def convert_sqlalchemy_to_readable_table(result: dict):
def convert_sqlalchemy_to_readable_table(result: list[dict]):
"""
Args:
Expand All @@ -380,22 +427,19 @@ def convert_sqlalchemy_to_readable_table(result: dict):
Returns:
"""
# converting a sqlalchemy object to a table
converted_table = [dict(row) for row in result]
# converting b'' and datetime objects to readable ones
incidents = [{str(key): str(value) for key, value in dictionary.items()} for dictionary in converted_table]
return incidents
return [{str(key): str(value) for key, value in dictionary.items()} for dictionary in result]


def update_last_run_after_fetch(table: List[dict], last_run: dict, fetch_parameters: str, column_name: str,
def update_last_run_after_fetch(table: list[dict], last_run: dict, fetch_parameters: str, column_name: str,
id_column: str):
is_timestamp_and_id = True if fetch_parameters == 'ID and timestamp' else False
is_timestamp_and_id = fetch_parameters == 'ID and timestamp'
if last_run.get('last_timestamp'):
last_record_timestamp = table[-1].get(column_name, '')

# keep the id's for the next fetch cycle for avoiding duplicates
if is_timestamp_and_id:
new_ids_list = list()
new_ids_list = []
for record in table:
if record.get(column_name) == last_record_timestamp:
new_ids_list.append(record.get(id_column))
Expand All @@ -417,19 +461,24 @@ def update_last_run_after_fetch(table: List[dict], last_run: dict, fetch_paramet
return last_run


def table_to_incidents(table: List[dict], last_run: dict, fetch_parameters: str, column_name: str, id_column: str,
incident_name: str) -> List[Dict[str, Any]]:
def table_to_incidents(table: list[dict], last_run: dict, fetch_parameters: str, column_name: str, id_column: str,
incident_name: str) -> list[dict[str, Any]]:
incidents = []
is_timestamp_and_id = True if fetch_parameters == 'ID and timestamp' else False
is_timestamp_and_id = fetch_parameters == 'ID and timestamp'
for record in table:

timestamp = record.get(column_name) if last_run.get('last_timestamp') else None
date_time = dateparser.parse(timestamp) if timestamp else datetime.now()

# for avoiding duplicate incidents
if is_timestamp_and_id and record.get(column_name, '').startswith(last_run.get('last_timestamp')):
if record.get(id_column, '') in last_run.get('ids', []):
continue
if (
is_timestamp_and_id
and record.get(column_name, '').startswith(
last_run.get('last_timestamp')
)
and record.get(id_column, '') in last_run.get('ids', [])
):
continue

record['type'] = 'GenericSQL Record'
incident_context = {
Expand Down Expand Up @@ -476,7 +525,7 @@ def fetch_incidents(client: Client, params: dict):
table = convert_sqlalchemy_to_readable_table(result)
table = table[:limit_fetch]

incidents: List[Dict[str, Any]] = table_to_incidents(table, last_run, params.get('fetch_parameters', ''),
incidents: list[dict[str, Any]] = table_to_incidents(table, last_run, params.get('fetch_parameters', ''),
params.get('column_name', ''), params.get('id_column', ''),
params.get('incident_name', ''))

Expand Down Expand Up @@ -537,7 +586,7 @@ def main():
port=port, database=database, connect_parameters=connect_parameters,
ssl_connect=ssl_connect, use_pool=use_pool, verify_certificate=verify_certificate,
pool_ttl=pool_ttl)
commands: Dict[str, Callable[[Client, Dict[str, str], str], Tuple[str, Dict[Any, Any], List[Any]]]] = {
commands: dict[str, Callable[[Client, dict[str, str], str], tuple[str, dict[Any, Any], list[Any]]]] = {
'test-module': test_module,
'query': sql_query_execute,
'pgsql-query': sql_query_execute,
Expand Down
2 changes: 1 addition & 1 deletion Packs/GenericSQL/Integrations/GenericSQL/GenericSQL.yml
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ script:
name: bind_variables_values
description: Running a sql query
name: sql-command
dockerimage: demisto/genericsql:1.1.0.62758
dockerimage: demisto/genericsql:1.1.0.75523
isfetch: true
runonce: false
script: '-'
Expand Down

0 comments on commit 2bed962

Please sign in to comment.