Skip to content

Commit

Permalink
support loading custom dictionary (#301)
Browse files Browse the repository at this point in the history
TT-1198: Adds support for loading multiple dictionaries concurrently
  • Loading branch information
kulgan committed May 6, 2020
1 parent fbd13f3 commit 5375e46
Show file tree
Hide file tree
Showing 12 changed files with 453 additions and 120 deletions.
2 changes: 1 addition & 1 deletion dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ jupyter-client==4.2.2
jupyter-console==4.1.1
jupyter-core==4.1.0
pytest==4.6.5
pytest-cov==2.7.1
pytest-cov==2.8.1
59 changes: 34 additions & 25 deletions gdcdatamodel/gdc_postgres_admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import time

from collections import namedtuple

from psqlgraph.base import ORMBase
from sqlalchemy import create_engine
from sqlalchemy.exc import OperationalError

Expand All @@ -21,8 +23,7 @@

from psqlgraph import (
create_all,
Node,
Edge,
ext,
)

logging.basicConfig()
Expand Down Expand Up @@ -119,34 +120,37 @@ def get_engine(host, user, password, database):
return create_engine(con_str, connect_args=connect_args)


def execute_for_all_graph_tables(engine, sql, *args, **kwargs):
def execute_for_all_graph_tables(engine, sql, namespace=None, *args, **kwargs):
"""Execute a SQL statment that has a python format variable {table}
to be replaced with the tablename for all Node and Edge tables
"""
for cls in Node.__subclasses__() + Edge.__subclasses__():
node_cls = ext.get_abstract_node(namespace)
edge_cls = ext.get_abstract_edge(namespace)

for cls in node_cls.get_subclasses() + edge_cls.get_subclasses():
_kwargs = dict(kwargs, **{'table': cls.__tablename__})
statement = sql.format(**_kwargs)
execute(engine, statement)


def grant_read_permissions_to_graph(engine, user):
execute_for_all_graph_tables(engine, GRANT_READ_PRIVS_SQL, user=user)
def grant_read_permissions_to_graph(engine, user, namespace=None):
execute_for_all_graph_tables(engine, GRANT_READ_PRIVS_SQL, namespace, user=user)


def grant_write_permissions_to_graph(engine, user):
execute_for_all_graph_tables(engine, GRANT_WRITE_PRIVS_SQL, user=user)
def grant_write_permissions_to_graph(engine, user, namespace=None):
execute_for_all_graph_tables(engine, GRANT_WRITE_PRIVS_SQL, namespace, user=user)


def revoke_read_permissions_to_graph(engine, user):
execute_for_all_graph_tables(engine, REVOKE_READ_PRIVS_SQL, user=user)
def revoke_read_permissions_to_graph(engine, user, namespace=None):
execute_for_all_graph_tables(engine, REVOKE_READ_PRIVS_SQL, namespace, user=user)


def revoke_write_permissions_to_graph(engine, user):
execute_for_all_graph_tables(engine, REVOKE_WRITE_PRIVS_SQL, user=user)
def revoke_write_permissions_to_graph(engine, user, namespace=None):
execute_for_all_graph_tables(engine, REVOKE_WRITE_PRIVS_SQL, namespace, user=user)


def create_graph_tables(engine, timeout):
def create_graph_tables(engine, timeout, namespace=None):
"""
create a table
"""
Expand All @@ -159,7 +163,8 @@ def create_graph_tables(engine, timeout):
timeout_str = '{}s'.format(int(timeout+1))
connection.execute("SET LOCAL lock_timeout = %s;", timeout_str)

create_all(connection)
orm_base = ext.get_orm_base(namespace) if namespace else ORMBase
create_all(connection, base=orm_base)
trans.commit()


Expand Down Expand Up @@ -208,7 +213,7 @@ def kill_blocking_psql_backend_processes(engine):
blockers = lookup_blocking_psql_backend_processes(engine)

if is_blocked_by_no_kill(blockers):
logger.warn("Process blocked by a 'no-kill' process. "
logger.warning("Process blocked by a 'no-kill' process. "
"Refusing to kill it")
return

Expand All @@ -232,7 +237,7 @@ def kill_blocking_psql_backend_processes(engine):
execute(engine, sql_cmd)


def create_tables_force(engine, delay, retries):
def create_tables_force(engine, delay, retries, namespace=None):
"""Create the tables and **KILL ANY BLOCKING PROCESSES**.
This command will spawn a process to create the new tables in
Expand All @@ -247,7 +252,7 @@ def create_tables_force(engine, delay, retries):
logger.warning('Running with force=True option %s', app_name)

from multiprocessing import Process
p = Process(target=create_graph_tables, args=(engine, delay))
p = Process(target=create_graph_tables, args=(engine, delay, namespace))
p.start()
time.sleep(delay)

Expand All @@ -263,10 +268,10 @@ def create_tables_force(engine, delay, retries):
raise RuntimeError('Max retries exceeded.')

logger.warning('Table creation failed, retrying.')
return create_tables_force(engine, delay, retries-1)
return create_tables_force(engine, delay, retries-1, namespace=namespace)


def create_tables(engine, delay, retries):
def create_tables(engine, delay, retries, namespace=None):
"""Create the tables but do not kill any blocking processes.
This command will catch OperationalErrors signalling timeouts from
Expand All @@ -277,7 +282,7 @@ def create_tables(engine, delay, retries):

logger.info('Running table creator named %s', app_name)
try:
return create_graph_tables(engine, delay)
return create_graph_tables(engine, delay, namespace=namespace)

except OperationalError as e:
if 'timeout' in str(e):
Expand All @@ -293,7 +298,7 @@ def create_tables(engine, delay, retries):
.format(delay, retries))
time.sleep(delay)

create_tables(engine, delay, retries-1)
create_tables(engine, delay, retries-1, namespace=namespace)


def subcommand_create(args):
Expand All @@ -308,6 +313,7 @@ def subcommand_create(args):
engine=engine,
delay=args.delay,
retries=args.retries,
namespace=args.namespace
)

if args.force:
Expand All @@ -331,12 +337,12 @@ def subcommand_grant(args):
if args.read:
users_read = [u for u in args.read.split(',') if u]
for user in users_read:
grant_read_permissions_to_graph(engine, user)
grant_read_permissions_to_graph(engine, user, args.namespace)

if args.write:
users_write = [u for u in args.write.split(',') if u]
for user in users_write:
grant_write_permissions_to_graph(engine, user)
grant_write_permissions_to_graph(engine, user, args.namespace)


def subcommand_revoke(args):
Expand All @@ -352,12 +358,12 @@ def subcommand_revoke(args):
if args.read:
users_read = [u for u in args.read.split(',') if u]
for user in users_read:
revoke_read_permissions_to_graph(engine, user)
revoke_read_permissions_to_graph(engine, user, args.namespace)

if args.write:
users_write = [u for u in args.write.split(',') if u]
for user in users_write:
revoke_write_permissions_to_graph(engine, user)
revoke_write_permissions_to_graph(engine, user, args.namespace)


def add_base_args(subparser):
Expand All @@ -369,6 +375,8 @@ def add_base_args(subparser):
required=True, help="psql test database")
subparser.add_argument("-P", "--password", type=str, action="store",
default='', help="psql test password")
subparser.add_argument("-N", "--namespace", type=lambda x: x if x else None,
help="psqlgraph model namespace")
return subparser


Expand Down Expand Up @@ -437,6 +445,7 @@ def main(args=None):
logger.info("[ HOST : %-10s ]", args.host)
logger.info("[ DATABASE : %-10s ]", args.database)
logger.info("[ USER : %-10s ]", args.user)
logger.info("[ NAMESPACE : %-10s ]", args.namespace or "default")

return_value = {
'graph-create': subcommand_create,
Expand Down
Loading

0 comments on commit 5375e46

Please sign in to comment.