Skip to content

Commit

Permalink
Set service-based Gremlin protocol defaults for module-generated configs
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelnchin committed Jun 18, 2024
1 parent 9c176c0 commit 7c64ae5
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 13 deletions.
14 changes: 11 additions & 3 deletions src/graph_notebook/configuration/generate_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
DEFAULT_NEO4J_USERNAME, DEFAULT_NEO4J_PASSWORD, DEFAULT_NEO4J_DATABASE,
NEPTUNE_CONFIG_HOST_IDENTIFIERS, is_allowed_neptune_host, false_str_variants,
GRAPHSONV3_VARIANTS, GRAPHSONV2_VARIANTS, GRAPHBINARYV1_VARIANTS,
NEPTUNE_DB_SERVICE_NAME, normalize_service_name)
NEPTUNE_DB_SERVICE_NAME, NEPTUNE_ANALYTICS_SERVICE_NAME,
normalize_service_name)

DEFAULT_CONFIG_LOCATION = os.path.expanduser('~/graph_notebook_config.json')

Expand Down Expand Up @@ -299,7 +300,7 @@ def generate_default_config():
default=DEFAULT_GREMLIN_SERIALIZER)
parser.add_argument("--gremlin_connection_protocol",
help="the connection protocol to use for Gremlin connections",
default=DEFAULT_GREMLIN_PROTOCOL)
default='')
parser.add_argument("--neo4j_username", help="the username to use for Neo4J connections",
default=DEFAULT_NEO4J_USERNAME)
parser.add_argument("--neo4j_password", help="the password to use for Neo4J connections",
Expand All @@ -311,6 +312,13 @@ def generate_default_config():
args = parser.parse_args()

auth_mode_arg = args.auth_mode if args.auth_mode != '' else AuthModeEnum.DEFAULT.value
protocol_arg = args.gremlin_connection_protocol
include_protocol = False
if is_allowed_neptune_host(args.host, args.neptune_hosts):
include_protocol = True
if not protocol_arg:
protocol_arg = DEFAULT_HTTP_PROTOCOL \
if args.neptune_service == NEPTUNE_ANALYTICS_SERVICE_NAME else DEFAULT_WS_PROTOCOL
config = generate_config(args.host, int(args.port),
AuthModeEnum(auth_mode_arg),
args.ssl, args.ssl_verify,
Expand All @@ -319,7 +327,7 @@ def generate_default_config():
SparqlSection(args.sparql_path, ''),
GremlinSection(args.gremlin_traversal_source, args.gremlin_username,
args.gremlin_password, args.gremlin_serializer,
args.gremlin_connection_protocol),
protocol_arg, include_protocol),
Neo4JSection(args.neo4j_username, args.neo4j_password,
args.neo4j_auth, args.neo4j_database),
args.neptune_hosts)
Expand Down
100 changes: 90 additions & 10 deletions test/unit/configuration/test_configuration_from_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from graph_notebook.configuration.generate_config import AuthModeEnum, Configuration, GremlinSection
from graph_notebook.configuration.get_config import get_config
from graph_notebook.neptune.client import (NEPTUNE_DB_SERVICE_NAME, NEPTUNE_ANALYTICS_SERVICE_NAME,
DEFAULT_HTTP_PROTOCOL, DEFAULT_WS_PROTOCOL)


class TestGenerateConfigurationMain(unittest.TestCase):
Expand Down Expand Up @@ -49,25 +51,61 @@ def test_generate_configuration_main_defaults_generic(self):
self.generate_config_from_main_and_test(expected_config)

def test_generate_configuration_main_override_defaults_neptune_reg(self):
expected_config = Configuration(self.neptune_host_reg, self.port, neptune_service='neptune-graph',
auth_mode=AuthModeEnum.IAM, load_from_s3_arn='loader_arn', ssl=False)
expected_config = Configuration(self.neptune_host_reg,
self.port,
neptune_service='neptune-graph',
auth_mode=AuthModeEnum.IAM,
load_from_s3_arn='loader_arn',
ssl=False,
gremlin_section=GremlinSection(
connection_protocol=DEFAULT_HTTP_PROTOCOL,
include_protocol=True
)
)
self.generate_config_from_main_and_test(expected_config, host_type='neptune')

def test_generate_configuration_main_override_defaults_neptune_no_verify(self):
expected_config = Configuration(self.neptune_host_reg, self.port, neptune_service='neptune-graph',
auth_mode=AuthModeEnum.IAM, load_from_s3_arn='loader_arn',
ssl=True, ssl_verify=False)
expected_config = Configuration(self.neptune_host_reg,
self.port,
neptune_service='neptune-graph',
auth_mode=AuthModeEnum.IAM,
load_from_s3_arn='loader_arn',
ssl=True,
ssl_verify=False,
gremlin_section=GremlinSection(
connection_protocol=DEFAULT_HTTP_PROTOCOL,
include_protocol=True
)
)
self.generate_config_from_main_and_test(expected_config, host_type='neptune')

def test_generate_configuration_main_override_defaults_neptune_with_serializer(self):
expected_config = Configuration(self.neptune_host_reg, self.port, neptune_service='neptune-graph',
auth_mode=AuthModeEnum.IAM, load_from_s3_arn='loader_arn', ssl=False,
gremlin_section=GremlinSection(message_serializer='graphbinary'))
expected_config = Configuration(self.neptune_host_reg,
self.port,
neptune_service='neptune-graph',
auth_mode=AuthModeEnum.IAM,
load_from_s3_arn='loader_arn',
ssl=False,
gremlin_section=GremlinSection(
message_serializer='graphbinary',
connection_protocol=DEFAULT_HTTP_PROTOCOL,
include_protocol=True
)
)
self.generate_config_from_main_and_test(expected_config, host_type='neptune')

def test_generate_configuration_main_override_defaults_neptune_cn(self):
expected_config = Configuration(self.neptune_host_cn, self.port, neptune_service='neptune-graph',
auth_mode=AuthModeEnum.IAM, load_from_s3_arn='loader_arn', ssl=False)
expected_config = Configuration(self.neptune_host_cn,
self.port,
neptune_service='neptune-graph',
auth_mode=AuthModeEnum.IAM,
load_from_s3_arn='loader_arn',
ssl=False,
gremlin_section=GremlinSection(
connection_protocol=DEFAULT_HTTP_PROTOCOL,
include_protocol=True
)
)
self.generate_config_from_main_and_test(expected_config, host_type='neptune')

def test_generate_configuration_main_override_defaults_generic(self):
Expand All @@ -85,6 +123,48 @@ def test_generate_configuration_main_empty_args_neptune(self):
config = get_config(self.test_file_path)
self.assertEqual(expected_config.to_dict(), config.to_dict())

def test_generate_configuration_main_gremlin_protocol_no_service(self):
result = os.system(f'{self.python_cmd} -m graph_notebook.configuration.generate_config '
f'--host "{self.neptune_host_reg}" '
f'--port "{self.port}" '
f'--neptune_service "" '
f'--auth_mode "" '
f'--ssl "" '
f'--load_from_s3_arn "" '
f'--config_destination="{self.test_file_path}" ')
self.assertEqual(0, result)
config = get_config(self.test_file_path)
config_dict = config.to_dict()
self.assertEqual(DEFAULT_WS_PROTOCOL, config_dict['gremlin']['connection_protocol'])

def test_generate_configuration_main_gremlin_protocol_db(self):
result = os.system(f'{self.python_cmd} -m graph_notebook.configuration.generate_config '
f'--host "{self.neptune_host_reg}" '
f'--port "{self.port}" '
f'--neptune_service "{NEPTUNE_DB_SERVICE_NAME}" '
f'--auth_mode "" '
f'--ssl "" '
f'--load_from_s3_arn "" '
f'--config_destination="{self.test_file_path}" ')
self.assertEqual(0, result)
config = get_config(self.test_file_path)
config_dict = config.to_dict()
self.assertEqual(DEFAULT_WS_PROTOCOL, config_dict['gremlin']['connection_protocol'])

def test_generate_configuration_main_gremlin_protocol_analytics(self):
result = os.system(f'{self.python_cmd} -m graph_notebook.configuration.generate_config '
f'--host "{self.neptune_host_reg}" '
f'--port "{self.port}" '
f'--neptune_service "{NEPTUNE_ANALYTICS_SERVICE_NAME}" '
f'--auth_mode "" '
f'--ssl "" '
f'--load_from_s3_arn "" '
f'--config_destination="{self.test_file_path}" ')
self.assertEqual(0, result)
config = get_config(self.test_file_path)
config_dict = config.to_dict()
self.assertEqual(DEFAULT_HTTP_PROTOCOL, config_dict['gremlin']['connection_protocol'])

def test_generate_configuration_main_empty_args_custom(self):
expected_config = Configuration(self.neptune_host_custom, self.port, neptune_hosts=self.custom_hosts_list)
result = os.system(f'{self.python_cmd} -m graph_notebook.configuration.generate_config '
Expand Down

0 comments on commit 7c64ae5

Please sign in to comment.