diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index de3fb15e..b0b0f40d 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -81,37 +81,27 @@ jobs: - name: Give user /etc/hosts permission run: | sudo chmod 777 /etc/hosts - - name: Run Basic Tests + - name: Generate test configuration run: | - python test/integration/NeptuneIntegrationWorkflowSteps.py \ - run-tests \ - --pattern "*without_iam.py" \ + python test/integration/NeptuneIntegrationWorkflowSteps.py generate-config \ --cfn-stack-name ${{ needs.generate-stack-name.outputs.stack-name }} \ --aws-region ${{ secrets.AWS_REGION }} - - name: Run Networkx Tests + - name: Run Basic Tests + env: + GRAPH_NOTEBOK_CONFIG: /tmp/graph_notebook_config_integration_test.json run: | - python test/integration/NeptuneIntegrationWorkflowSteps.py \ - run-tests \ - --pattern "*network*.py" \ - --cfn-stack-name ${{ needs.generate-stack-name.outputs.stack-name }} \ - --aws-region ${{ secrets.AWS_REGION }} - - name: Run Notebook Tests + pytest test/integration/without_iam + - name: Generate iam test configuration run: | - python test/integration/NeptuneIntegrationWorkflowSteps.py \ - run-tests \ - --pattern "*graph_notebook.py" \ + python test/integration/NeptuneIntegrationWorkflowSteps.py generate-config \ --cfn-stack-name ${{ needs.generate-stack-name.outputs.stack-name }} \ - --aws-region ${{ secrets.AWS_REGION }} + --aws-region ${{ secrets.AWS_REGION }} \ + --iam - name: Run IAM Tests env: GRAPH_NOTEBOK_CONFIG: /tmp/graph_notebook_config_integration_test.json run: | - python test/integration/NeptuneIntegrationWorkflowSteps.py \ - run-tests \ - --pattern "*with_iam.py" \ - --iam \ - --cfn-stack-name ${{ needs.generate-stack-name.outputs.stack-name }} \ - --aws-region ${{ secrets.AWS_REGION }} + pytest test/integration/iam - name: Cleanup run: | python test/integration/NeptuneIntegrationWorkflowSteps.py \ diff --git a/.github/workflows/unit.yml b/.github/workflows/unit.yml index 8f7b5e7b..5f044413 100644 --- a/.github/workflows/unit.yml +++ b/.github/workflows/unit.yml @@ -35,4 +35,4 @@ jobs: python -m graph_notebook.notebooks.install - name: Test with pytest run: | - pytest \ No newline at end of file + pytest test/unit \ No newline at end of file diff --git a/ChangeLog.md b/ChangeLog.md index 524a81ad..68d89427 100644 --- a/ChangeLog.md +++ b/ChangeLog.md @@ -4,14 +4,15 @@ Starting with v1.31.6, this file will contain a record of major features and upd ## Upcoming -- Add support for Mode, queueRequest, and Dependencies parameters when running %load command -- Add support for list and dict as map keys in Python Gremlin +- Add support for Mode, queueRequest, and Dependencies parameters when running %load command ([Link to PR](https://github.com/aws/graph-notebook/pull/91)) +- Add support for list and dict as map keys in Python Gremlin ([Link to PR](https://github.com/aws/graph-notebook/pull/100)) +- Refactor modules that call to Neptune or other SPARQL/Gremlin endpoints to use a unified client object ([Link to PR](https://github.com/aws/graph-notebook/pull/104)) ## Release 2.0.12 (Mar 25, 2021) - - Add default parameters for `get_load_status` - - Add ipython as a dependency in `setup.py` ([Link to PT](https://github.com/aws/graph-notebook/pull/95)) - - Add parameters in `load_status` for `details`, `errors`, `page`, and `errorsPerPage` + - Add default parameters for `get_load_status` ([Link to PR](https://github.com/aws/graph-notebook/pull/96)) + - Add ipython as a dependency in `setup.py` ([Link to PR](https://github.com/aws/graph-notebook/pull/95)) + - Add parameters in `load_status` for `details`, `errors`, `page`, and `errorsPerPage` ([Link to PR](https://github.com/aws/graph-notebook/pull/88)) ## Release 2.0.10 (Mar 18, 2021) diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 00000000..4fee5c5f --- /dev/null +++ b/pytest.ini @@ -0,0 +1,11 @@ +[pytest] +markers = + neptune: tests which have to run against neptune + iam: tests which require iam authentication + gremlin: tests which run against a gremlin endpoint + sparql: tests which run against SPARQL1.1 endpoint + neptuneml: tests which run Neptune ML workloads + jupyter: tests which run against ipython/jupyter frameworks + reset: test which performs a fast reset against Neptune, running this will wipe your database! + + diff --git a/requirements.txt b/requirements.txt index 04e0b7e6..1852a511 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,8 +9,9 @@ jupyter-contrib-nbextensions widgetsnbextension gremlinpython requests==2.24.0 +ipython==7.16.1 # requirements for testing boto3==1.15.15 botocore==1.18.18 -ipython==7.16.1 \ No newline at end of file +pytest==6.2.2 \ No newline at end of file diff --git a/setup.py b/setup.py index f7de0895..bb08f9c5 100644 --- a/setup.py +++ b/setup.py @@ -93,4 +93,7 @@ def get_version(): 'License :: OSI Approved :: Apache Software License' ], keywords='jupyter neptune gremlin sparql', + tests_require=[ + 'pytest' + ] ) diff --git a/src/graph_notebook/authentication/iam_credentials_provider/__init__.py b/src/graph_notebook/authentication/iam_credentials_provider/__init__.py deleted file mode 100644 index fa84f3bc..00000000 --- a/src/graph_notebook/authentication/iam_credentials_provider/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" \ No newline at end of file diff --git a/src/graph_notebook/authentication/iam_credentials_provider/credentials_factory.py b/src/graph_notebook/authentication/iam_credentials_provider/credentials_factory.py deleted file mode 100644 index c849b595..00000000 --- a/src/graph_notebook/authentication/iam_credentials_provider/credentials_factory.py +++ /dev/null @@ -1,24 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -from enum import Enum - -from graph_notebook.authentication.iam_credentials_provider.credentials_provider import CredentialsProviderBase -from graph_notebook.authentication.iam_credentials_provider.env_credentials_provider import EnvCredentialsProvider -from graph_notebook.authentication.iam_credentials_provider.ec2_metadata_credentials_provider import MetadataCredentialsProvider - - -class IAMAuthCredentialsProvider(Enum): - ROLE = "ROLE" - ENV = "ENV" - - -def credentials_provider_factory(mode: IAMAuthCredentialsProvider) -> CredentialsProviderBase: - if mode == IAMAuthCredentialsProvider.ENV: - return EnvCredentialsProvider() - elif mode == IAMAuthCredentialsProvider.ROLE: - return MetadataCredentialsProvider() - else: - raise NotImplementedError(f'the provided mode of {mode} has not been implemented by credentials_provider_factory') diff --git a/src/graph_notebook/authentication/iam_credentials_provider/credentials_provider.py b/src/graph_notebook/authentication/iam_credentials_provider/credentials_provider.py deleted file mode 100644 index 8e6f34cf..00000000 --- a/src/graph_notebook/authentication/iam_credentials_provider/credentials_provider.py +++ /dev/null @@ -1,20 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -from abc import ABC, abstractmethod - - -class Credentials(object): - def __init__(self, key, secret, region, token=''): - self.key = key - self.secret = secret - self.token = token - self.region = region - - -class CredentialsProviderBase(ABC): - @abstractmethod - def get_iam_credentials(self) -> Credentials: - pass diff --git a/src/graph_notebook/authentication/iam_credentials_provider/ec2_metadata_credentials_provider.py b/src/graph_notebook/authentication/iam_credentials_provider/ec2_metadata_credentials_provider.py deleted file mode 100644 index 8dd9bf58..00000000 --- a/src/graph_notebook/authentication/iam_credentials_provider/ec2_metadata_credentials_provider.py +++ /dev/null @@ -1,26 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -import botocore.session -import requests - - -from graph_notebook.authentication.iam_credentials_provider.credentials_provider import CredentialsProviderBase, \ - Credentials - -region_url = 'http://169.254.169.254/latest/meta-data/placement/availability-zone' - - -class MetadataCredentialsProvider(CredentialsProviderBase): - def __init__(self): - res = requests.get(region_url) - zone = res.content.decode('utf-8') - region = zone[0:len(zone) - 1] - self.region = region - - def get_iam_credentials(self) -> Credentials: - session = botocore.session.get_session() - creds = session.get_credentials() - return Credentials(key=creds.access_key, secret=creds.secret_key, token=creds.token, region=self.region) diff --git a/src/graph_notebook/authentication/iam_credentials_provider/env_credentials_provider.py b/src/graph_notebook/authentication/iam_credentials_provider/env_credentials_provider.py deleted file mode 100644 index 0ba39e9d..00000000 --- a/src/graph_notebook/authentication/iam_credentials_provider/env_credentials_provider.py +++ /dev/null @@ -1,35 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -import os - -from graph_notebook.authentication.iam_credentials_provider.credentials_provider import CredentialsProviderBase, \ - Credentials - -ACCESS_ENV_KEY = 'AWS_ACCESS_KEY_ID' -SECRET_ENV_KEY = 'AWS_SECRET_ACCESS_KEY' -REGION_ENV_KEY = 'AWS_REGION' -AWS_TOKEN_ENV_KEY = 'AWS_SESSION_TOKEN' - - -class EnvCredentialsProvider(CredentialsProviderBase): - def __init__(self): - self.creds = Credentials(key='', secret='', region='', token='') - self.loaded = False - - def load_iam_credentials(self): - access_key = os.environ.get(ACCESS_ENV_KEY, '') - secret_key = os.environ.get(SECRET_ENV_KEY, '') - region = os.environ.get(REGION_ENV_KEY, '') - token = os.environ.get(AWS_TOKEN_ENV_KEY, '') - self.creds = Credentials(access_key, secret_key, region, token) - self.loaded = True - return - - def get_iam_credentials(self, service=None) -> Credentials: - if not self.loaded: - self.load_iam_credentials() - - return self.creds diff --git a/src/graph_notebook/authentication/iam_headers.py b/src/graph_notebook/authentication/iam_headers.py deleted file mode 100644 index 6c39038a..00000000 --- a/src/graph_notebook/authentication/iam_headers.py +++ /dev/null @@ -1,217 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -import datetime -import hashlib -import hmac -import json -import logging -import urllib - -logging.basicConfig() -logger = logging.getLogger("graph_magic") - - -# Key derivation functions. See: -# https://docs.aws.amazon.com/general/latest/gr/signature-v4-examples.html#signature-v4-examples-python -def sign(key, msg): - return hmac.new(key, msg.encode('utf-8'), hashlib.sha256).digest() - - -def get_signature_key(key, dateStamp, regionName, serviceName): - k_date = sign(('AWS4' + key).encode('utf-8'), dateStamp) - k_region = sign(k_date, regionName) - k_service = sign(k_region, serviceName) - k_signing = sign(k_service, 'aws4_request') - return k_signing - - -def get_canonical_uri_and_payload(query_type, query): - # Set the stack and payload depending on query_type. - if query_type == 'sparql': - canonical_uri = '/sparql/' - payload = query - - elif query_type == 'sparqlupdate': - canonical_uri = '/sparql/' - payload = query - - elif query_type == 'sparql/status': - canonical_uri = '/sparql/status/' - payload = query - - elif query_type == 'gremlin': - canonical_uri = '/gremlin' - payload = {} - - elif query_type == 'gremlin/status': - canonical_uri = '/gremlin/status/' - payload = query - - elif query_type == "loader": - canonical_uri = "/loader/" - payload = query - - elif query_type == "status": - canonical_uri = "/status/" - payload = {} - - elif query_type == "gremlin/explain": - canonical_uri = "/gremlin/explain/" - payload = query - - elif query_type == "gremlin/profile": - canonical_uri = "/gremlin/profile/" - payload = query - - elif query_type == "system": - canonical_uri = "/system/" - payload = query - - elif query_type.startswith("ml"): - canonical_uri = f'/{query_type}' - payload = query - - elif query_type.startswith("ml/dataprocessing"): - canonical_uri = f'/{query_type}' - payload = query - - elif query_type.startswith("ml/endpoints"): - canonical_uri = f'/{query_type}' - payload = query - - else: - raise ValueError('query_type %s is not valid' % query_type) - - return canonical_uri, payload - - -def normalize_query_string(query): - kv = (list(map(str.strip, s.split("="))) - for s in query.split('&') - if len(s) > 0) - - normalized = '&'.join('%s=%s' % (p[0], p[1] if len(p) > 1 else '') - for p in sorted(kv)) - return normalized - - -def make_signed_request(method, query_type, query, host, port, signing_access_key, signing_secret, signing_region, - use_ssl=False, signing_token='', additional_headers=None): - if additional_headers is None: - additional_headers = [] - - signing_region = signing_region.lower() - service = 'neptune-db' - - if use_ssl: - protocol = 'https' - else: - protocol = 'http' - - # this is always http right now - endpoint = f'{protocol}://{host}:{port}' - - # get canonical_uri and payload - canonical_uri, payload = get_canonical_uri_and_payload(query_type, query) - - if 'content-type' in additional_headers and additional_headers['content-type'] == 'application/json': - request_parameters = payload if type(payload) is str else json.dumps(payload) - else: - request_parameters = urllib.parse.urlencode(payload, quote_via=urllib.parse.quote) - request_parameters = request_parameters.replace('%27', '%22') - t = datetime.datetime.utcnow() - amz_date = t.strftime('%Y%m%dT%H%M%SZ') - date_stamp = t.strftime('%Y%m%d') # Date w/o time, used in credential scope - - method = method.upper() - if method == 'GET' or method == 'DELETE': - canonical_querystring = normalize_query_string(request_parameters) - elif method == 'POST': - canonical_querystring = '' - else: - raise ValueError('method %s is not valid when creating canonical request' % method) - - # Step 4: Create the canonical headers and signed headers. Header names - # must be trimmed and lowercase, and sorted in code point order from - # low to high. Note that there is a trailing \n. - canonical_headers = f'host:{host}:{port}\nx-amz-date:{amz_date}\n' - - # Step 5: Create the list of signed headers. This lists the headers - # in the canonical_headers list, delimited with ";" and in alpha order. - # Note: The request can include any headers; canonical_headers and - # signed_headers lists those that you want to be included in the - # hash of the request. "Host" and "x-amz-date" are always required. - signed_headers = 'host;x-amz-date' - - # Step 6: Create payload hash (hash of the request body content). For GET and DELETE - # requests, the payload is an empty string (""). - if method == 'GET' or method == 'DELETE': - post_payload = '' - elif method == 'POST': - post_payload = request_parameters - else: - raise ValueError('method %s is not supported' % method) - - payload_hash = hashlib.sha256(post_payload.encode('utf-8')).hexdigest() - - # Step 7: Combine elements to create canonical request. - canonical_request = method + '\n' + canonical_uri + '\n' + canonical_querystring + '\n' + canonical_headers + '\n' + signed_headers + '\n' + payload_hash - - # ************* TASK 2: CREATE THE STRING TO SIGN************* - # Match the algorithm to the hashing algorithm you use, either SHA-1 or - # SHA-256 (recommended) - algorithm = 'AWS4-HMAC-SHA256' - credential_scope = date_stamp + '/' + signing_region + '/' + service + '/' + 'aws4_request' - string_to_sign = algorithm + '\n' + amz_date + '\n' + credential_scope + '\n' + hashlib.sha256( - canonical_request.encode('utf-8')).hexdigest() - - # ************* TASK 3: CALCULATE THE SIGNATURE ************* - # Create the signing key using the function defined above. - signing_key = get_signature_key(signing_secret, date_stamp, signing_region, service) - - # Sign the string_to_sign using the signing_key - signature = hmac.new(signing_key, string_to_sign.encode('utf-8'), hashlib.sha256).hexdigest() - - # ************* TASK 4: ADD SIGNING INFORMATION TO THE REQUEST ************* - # The signing information can be either in a query string value or in - # a header named Authorization. This code shows how to use a header. - # Create authorization header and add to request headers - authorization_header = algorithm + ' ' + 'Credential=' + signing_access_key + '/' + credential_scope + ', ' + 'SignedHeaders=' + signed_headers + ', ' + 'Signature=' + signature - - # The request can include any headers, but MUST include "host", "x-amz-date", - # and (for this scenario) "Authorization". "host" and "x-amz-date" must - # be included in the canonical_headers and signed_headers, as noted - # earlier. Order here is not significant. - # Python note: The 'host' header is added automatically by the Python 'requests' library. - if method == 'GET' or method == 'DELETE': - headers = { - 'x-amz-date': amz_date, - 'Authorization': authorization_header - } - elif method == 'POST': - headers = { - 'content-type': 'application/x-www-form-urlencoded', - 'x-amz-date': amz_date, - 'Authorization': authorization_header, - } - else: - raise ValueError('method %s is not valid while creating request headers' % method) - - if additional_headers is not None: - for key in additional_headers: - headers[key] = additional_headers[key] - - if signing_token != '': - headers['X-Amz-Security-Token'] = signing_token - - # ************* SEND THE REQUEST ************* - request_url = endpoint + canonical_uri - - return { - 'url': request_url, - 'headers': headers, - 'params': request_parameters - } \ No newline at end of file diff --git a/src/graph_notebook/configuration/generate_config.py b/src/graph_notebook/configuration/generate_config.py index 496455f0..f57aaa11 100644 --- a/src/graph_notebook/configuration/generate_config.py +++ b/src/graph_notebook/configuration/generate_config.py @@ -8,10 +8,7 @@ import os from enum import Enum -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider -from graph_notebook.sparql.query import SPARQL_ACTION - -DEFAULT_IAM_CREDENTIALS_PROVIDER = IAMAuthCredentialsProvider.ROLE +from graph_notebook.neptune.client import SPARQL_ACTION DEFAULT_CONFIG_LOCATION = os.path.expanduser('~/graph_notebook_config.json') @@ -38,6 +35,9 @@ def __init__(self, path: str = SPARQL_ACTION, endpoint_prefix: str = ''): print('endpoint_prefix has been deprecated and will be removed in version 2.0.20 or greater.') if path == '': path = f'{endpoint_prefix}/sparql' + elif path == '': + path = SPARQL_ACTION + self.path = path def to_dict(self): @@ -47,13 +47,11 @@ def to_dict(self): class Configuration(object): def __init__(self, host: str, port: int, auth_mode: AuthModeEnum = AuthModeEnum.DEFAULT, - iam_credentials_provider_type: IAMAuthCredentialsProvider = DEFAULT_IAM_CREDENTIALS_PROVIDER, load_from_s3_arn='', ssl: bool = True, aws_region: str = 'us-east-1', sparql_section: SparqlSection = None): self.host = host self.port = port self.auth_mode = auth_mode - self.iam_credentials_provider_type = iam_credentials_provider_type self.load_from_s3_arn = load_from_s3_arn self.ssl = ssl self.aws_region = aws_region @@ -64,7 +62,6 @@ def to_dict(self) -> dict: 'host': self.host, 'port': self.port, 'auth_mode': self.auth_mode.value, - 'iam_credentials_provider_type': self.iam_credentials_provider_type.value, 'load_from_s3_arn': self.load_from_s3_arn, 'ssl': self.ssl, 'aws_region': self.aws_region, @@ -79,21 +76,15 @@ def write_to_file(self, file_path=DEFAULT_CONFIG_LOCATION): return -def generate_config(host, port, auth_mode, ssl, iam_credentials_provider_type, load_from_s3_arn, aws_region): +def generate_config(host, port, auth_mode, ssl, load_from_s3_arn, aws_region): use_ssl = False if ssl in [False, 'False', 'false', 'FALSE'] else True - - if iam_credentials_provider_type not in [IAMAuthCredentialsProvider.ENV, - IAMAuthCredentialsProvider.ROLE]: - iam_credentials_provider_type = DEFAULT_IAM_CREDENTIALS_PROVIDER - - config = Configuration(host, port, auth_mode, iam_credentials_provider_type, load_from_s3_arn, use_ssl, aws_region) - return config + c = Configuration(host, port, auth_mode, load_from_s3_arn, use_ssl, aws_region) + return c def generate_default_config(): - config = generate_config('change-me', 8182, AuthModeEnum.DEFAULT, True, DEFAULT_IAM_CREDENTIALS_PROVIDER, '', - 'us-east-1') - return config + c = generate_config('change-me', 8182, AuthModeEnum.DEFAULT, True, '', 'us-east-1') + return c if __name__ == "__main__": @@ -102,6 +93,8 @@ def generate_default_config(): parser.add_argument("--port", help="the port to use when creating a connection", default="8182") parser.add_argument("--auth_mode", default=AuthModeEnum.DEFAULT.value, help="type of authentication the cluster being connected to is using. Can be DEFAULT or IAM") + + # TODO: this can now be removed. parser.add_argument("--iam_credentials_provider", default='ROLE', help="The mode used to obtain credentials for IAM Authentication. Can be ROLE or ENV") parser.add_argument("--ssl", @@ -110,14 +103,11 @@ def generate_default_config(): parser.add_argument("--config_destination", help="location to put generated config", default=DEFAULT_CONFIG_LOCATION) parser.add_argument("--load_from_s3_arn", help="arn of role to use for bulk loader", default='') - parser.add_argument("--aws_region", help="aws region your neptune cluster is in.", default='us-east-1') + parser.add_argument("--aws_region", help="aws region your ml cluster is in.", default='us-east-1') args = parser.parse_args() auth_mode_arg = args.auth_mode if args.auth_mode != '' else AuthModeEnum.DEFAULT.value - iam_credentials_provider_arg = args.iam_credentials_provider if args.iam_credentials_provider != '' else IAMAuthCredentialsProvider.ROLE.value - - config = generate_config(args.host, int(args.port), AuthModeEnum(auth_mode_arg), args.ssl, - IAMAuthCredentialsProvider(iam_credentials_provider_arg), + config = generate_config(args.host, int(args.port), AuthModeEnum(auth_mode_arg), args.ssl , args.load_from_s3_arn, args.aws_region) config.write_to_file(args.config_destination) diff --git a/src/graph_notebook/configuration/get_config.py b/src/graph_notebook/configuration/get_config.py index 35698391..72ab829a 100644 --- a/src/graph_notebook/configuration/get_config.py +++ b/src/graph_notebook/configuration/get_config.py @@ -5,7 +5,6 @@ import json -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider from graph_notebook.configuration.generate_config import DEFAULT_CONFIG_LOCATION, Configuration, AuthModeEnum, \ SparqlSection @@ -13,8 +12,6 @@ def get_config_from_dict(data: dict) -> Configuration: sparql_section = SparqlSection(**data['sparql']) if 'sparql' in data else SparqlSection('') config = Configuration(host=data['host'], port=data['port'], auth_mode=AuthModeEnum(data['auth_mode']), - iam_credentials_provider_type=IAMAuthCredentialsProvider( - data['iam_credentials_provider_type']), ssl=data['ssl'], load_from_s3_arn=data['load_from_s3_arn'], aws_region=data['aws_region'], sparql_section=sparql_section) return config diff --git a/src/graph_notebook/gremlin/client_provider/default_client.py b/src/graph_notebook/gremlin/client_provider/default_client.py deleted file mode 100644 index 0d6bd4b8..00000000 --- a/src/graph_notebook/gremlin/client_provider/default_client.py +++ /dev/null @@ -1,21 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -import graph_notebook.gremlin.client_provider.graphsonV3d0_MapType_objectify_patch # noqa F401 -import logging - -from gremlin_python.driver import client - -logging.basicConfig() -logger = logging.getLogger("default_client") - - -class ClientProvider(object): - @staticmethod - def get_client(host, port, use_ssl): - protocol = 'wss' if use_ssl else 'ws' - url = f'{protocol}://{host}:{port}/gremlin' - c = client.Client(url, 'g') - return c diff --git a/src/graph_notebook/gremlin/client_provider/factory.py b/src/graph_notebook/gremlin/client_provider/factory.py deleted file mode 100644 index 0f1a459e..00000000 --- a/src/graph_notebook/gremlin/client_provider/factory.py +++ /dev/null @@ -1,21 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -from graph_notebook.configuration.generate_config import AuthModeEnum -from graph_notebook.gremlin.client_provider.default_client import ClientProvider -from graph_notebook.gremlin.client_provider.iam_client import IamClientProvider -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import credentials_provider_factory, \ - IAMAuthCredentialsProvider - - -def create_client_provider(mode: AuthModeEnum, - credentials_provider_mode: IAMAuthCredentialsProvider = IAMAuthCredentialsProvider.ROLE): - if mode == AuthModeEnum.DEFAULT: - return ClientProvider() - elif mode == AuthModeEnum.IAM: - credentials_provider = credentials_provider_factory(credentials_provider_mode) - return IamClientProvider(credentials_provider) - else: - raise NotImplementedError(f"invalid client mode {mode} provided") diff --git a/src/graph_notebook/gremlin/client_provider/iam_client.py b/src/graph_notebook/gremlin/client_provider/iam_client.py deleted file mode 100644 index aca981ab..00000000 --- a/src/graph_notebook/gremlin/client_provider/iam_client.py +++ /dev/null @@ -1,52 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -import graph_notebook.gremlin.client_provider.graphsonV3d0_MapType_objectify_patch # noqa F401 -import hashlib -import hmac -import logging - -from gremlin_python.driver import client -from gremlin_python.driver.client import Client -from tornado import httpclient - -from graph_notebook.authentication.iam_credentials_provider.credentials_provider import CredentialsProviderBase -from graph_notebook.authentication.iam_headers import make_signed_request - -logging.basicConfig() -logger = logging.getLogger("iam_client") - - -def sign(key, msg): - return hmac.new(key, msg.encode('utf-8'), hashlib.sha256).digest() - - -def get_signature_key(key, date_stamp, region_name, service_name): - k_date = sign(('AWS4' + key).encode('utf-8'), date_stamp) - k_region = sign(k_date, region_name) - k_service = sign(k_region, service_name) - k_signing = sign(k_service, 'aws4_request') - return k_signing - - -class IamClientProvider(object): - def __init__(self, credentials_provider: CredentialsProviderBase): - self.credentials_provider = credentials_provider - - def get_client(self, host, port, use_ssl) -> Client: - credentials = self.credentials_provider.get_iam_credentials() - request_params = make_signed_request('get', 'gremlin', '', host, port, credentials.key, - credentials.secret, credentials.region, use_ssl, - credentials.token) - ws_url = request_params['url'].strip('/').replace('http', 'ws') - signed_ws_request = httpclient.HTTPRequest(ws_url, headers=request_params['headers']) - - try: - c = client.Client(signed_ws_request, 'g') - return c - # TODO: handle exception explicitly - except Exception as e: - logger.error(f'error while creating client {e}') - raise e diff --git a/src/graph_notebook/gremlin/query.py b/src/graph_notebook/gremlin/query.py deleted file mode 100644 index 00e60010..00000000 --- a/src/graph_notebook/gremlin/query.py +++ /dev/null @@ -1,57 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -import logging - -from graph_notebook.gremlin.client_provider.default_client import ClientProvider -from graph_notebook.request_param_generator.default_request_generator import DefaultRequestGenerator -from graph_notebook.request_param_generator.call_and_get_response import call_and_get_response - -logging.basicConfig() -logger = logging.getLogger("gremlin") - - -def do_gremlin_query(query_str, host, port, use_ssl, client_provider=ClientProvider()): - c = client_provider.get_client(host, port, use_ssl) - - try: - result = c.submit(query_str) - future_results = result.all() - results = future_results.result() - except Exception as e: - raise e # let the upstream decide what to do with this error. - finally: - c.close() # no matter the outcome we need to close the websocket connection - - return results - - -def do_gremlin_explain(query_str, host, port, use_ssl, request_param_generator=DefaultRequestGenerator()): - data = { - 'gremlin': query_str - } - - action = 'gremlin/explain' - res = call_and_get_response('get', action, host, port, request_param_generator, use_ssl, data) - content = res.content.decode('utf-8') - result = { - 'explain': content - } - return result - - -def do_gremlin_profile(query_str, host, port, use_ssl, request_param_generator=DefaultRequestGenerator()): - data = { - 'gremlin': query_str - } - - action = 'gremlin/profile' - res = call_and_get_response('get', action, host, port, request_param_generator, use_ssl, data) - content = res.content.decode('utf-8').strip(' ') - result = { - 'profile': content - } - - return result diff --git a/src/graph_notebook/gremlin/status.py b/src/graph_notebook/gremlin/status.py deleted file mode 100644 index 4817206d..00000000 --- a/src/graph_notebook/gremlin/status.py +++ /dev/null @@ -1,44 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -from graph_notebook.configuration.generate_config import AuthModeEnum -from graph_notebook.request_param_generator.call_and_get_response import call_and_get_response - -GREMLIN_STATUS_ACTION = 'gremlin/status' - - -def do_gremlin_status(host, port, use_ssl, mode, request_param_generator, query_id: str, include_waiting: bool): - data = {'includeWaiting': include_waiting} - if query_id != '': - data['queryId'] = query_id - - headers = {} - if mode == AuthModeEnum.DEFAULT: - """Add correct content-type header for the request. - This is needed because call_and_get_response requires custom headers to be set. - """ - headers['Content-Type'] = 'application/x-www-form-urlencoded' - res = call_and_get_response('post', GREMLIN_STATUS_ACTION, host, port, request_param_generator, use_ssl, data, - headers) - content = res.json() - return content - - -def do_gremlin_cancel(host, port, use_ssl, mode, request_param_generator, query_id): - if type(query_id) != str or query_id == '': - raise ValueError("query id must be a non-empty string") - - data = {'cancelQuery': True, 'queryId': query_id} - - headers = {} - if mode == AuthModeEnum.DEFAULT: - """Add correct content-type header for the request. - This is needed because call_and_get_response requires custom headers to be set. - """ - headers['Content-Type'] = 'application/x-www-form-urlencoded' - res = call_and_get_response('post', GREMLIN_STATUS_ACTION, host, port, request_param_generator, use_ssl, data, - headers) - content = res.json() - return content diff --git a/src/graph_notebook/loader/load.py b/src/graph_notebook/loader/load.py deleted file mode 100644 index 3eaa8dbe..00000000 --- a/src/graph_notebook/loader/load.py +++ /dev/null @@ -1,90 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -import json -from graph_notebook.request_param_generator.call_and_get_response import call_and_get_response - -FORMAT_CSV = 'csv' -FORMAT_NTRIPLE = 'ntriples' -FORMAT_NQUADS = 'nquads' -FORMAT_RDFXML = 'rdfxml' -FORMAT_TURTLE = 'turtle' - -PARALLELISM_LOW = 'LOW' -PARALLELISM_MEDIUM = 'MEDIUM' -PARALLELISM_HIGH = 'HIGH' -PARALLELISM_OVERSUBSCRIBE = 'OVERSUBSCRIBE' - -MODE_RESUME = 'RESUME' -MODE_NEW = 'NEW' -MODE_AUTO = 'AUTO' - -LOAD_JOB_MODES = [MODE_RESUME, MODE_NEW, MODE_AUTO] -VALID_FORMATS = [FORMAT_CSV, FORMAT_NTRIPLE, FORMAT_NQUADS, FORMAT_RDFXML, FORMAT_TURTLE] -PARALLELISM_OPTIONS = [PARALLELISM_LOW, PARALLELISM_MEDIUM, PARALLELISM_HIGH, PARALLELISM_OVERSUBSCRIBE] -LOADER_ACTION = 'loader' - -FINAL_LOAD_STATUSES = ['LOAD_COMPLETED', - 'LOAD_COMMITTED_W_WRITE_CONFLICTS', - 'LOAD_CANCELLED_BY_USER', - 'LOAD_CANCELLED_DUE_TO_ERRORS', - 'LOAD_FAILED', - 'LOAD_UNEXPECTED_ERROR', - 'LOAD_DATA_DEADLOCK', - 'LOAD_DATA_FAILED_DUE_TO_FEED_MODIFIED_OR_DELETED', - 'LOAD_S3_READ_ERROR', - 'LOAD_S3_ACCESS_DENIED_ERROR', - 'LOAD_IN_QUEUE', - 'LOAD_FAILED_BECAUSE_DEPENDENCY_NOT_SATISFIED', - 'LOAD_FAILED_INVALID_REQUEST', ] - - -def do_load(host, port, load_format, use_ssl, source, region, arn, fail_on_error, request_param_generator, mode="AUTO", - parallelism="HIGH", update_single_cardinality="FALSE", queue_request="FALSE", dependencies=[]): - payload = { - 'source': source, - 'format': load_format, - 'mode': mode, - 'region': region, - 'failOnError': fail_on_error, - 'parallelism': parallelism, - 'updateSingleCardinalityProperties': update_single_cardinality, - 'queueRequest': queue_request - } - - if arn != '': - payload['iamRoleArn'] = arn - - if dependencies: - payload['dependencies'] = json.dumps(dependencies) - - res = call_and_get_response('post', LOADER_ACTION, host, port, request_param_generator, use_ssl, payload) - return res.json() - - -def get_loader_jobs(host, port, use_ssl, request_param_generator): - res = call_and_get_response('get', LOADER_ACTION, host, port, request_param_generator, use_ssl) - return res.json() - - -def get_load_status(host, port, use_ssl, request_param_generator, id, loader_details="FALSE", loader_errors="FALSE", loader_page=1, loader_epp=10): - payload = { - 'loadId': id, - 'details': loader_details, - 'errors': loader_errors, - 'page': loader_page, - 'errorsPerPage': loader_epp - } - res = call_and_get_response('get', LOADER_ACTION, host, port, request_param_generator, use_ssl, payload) - return res.json() - - -def cancel_load(host, port, use_ssl, request_param_generator, load_id): - payload = { - 'loadId': load_id - } - - res = call_and_get_response('get', LOADER_ACTION, host, port, request_param_generator, use_ssl, payload) - return res.status_code == 200 diff --git a/src/graph_notebook/magics/graph_magic.py b/src/graph_notebook/magics/graph_magic.py index 90b30a16..12b8a9be 100644 --- a/src/graph_notebook/magics/graph_magic.py +++ b/src/graph_notebook/magics/graph_magic.py @@ -14,6 +14,8 @@ from enum import Enum import ipywidgets as widgets +from SPARQLWrapper import SPARQLWrapper +from botocore.session import get_session from gremlin_python.driver.protocol import GremlinServerError from IPython.core.display import HTML, display_html, display from IPython.core.magic import (Magics, magics_class, cell_magic, line_magic, line_cell_magic, needs_local_scope) @@ -21,25 +23,18 @@ from requests import HTTPError import graph_notebook -from graph_notebook.configuration.generate_config import generate_default_config, DEFAULT_CONFIG_LOCATION +from graph_notebook.configuration.generate_config import generate_default_config, DEFAULT_CONFIG_LOCATION, AuthModeEnum, \ + Configuration from graph_notebook.decorators.decorators import display_exceptions from graph_notebook.magics.ml import neptune_ml_magic_handler, generate_neptune_ml_parser +from graph_notebook.neptune.client import ClientBuilder, Client, VALID_FORMATS, PARALLELISM_OPTIONS, PARALLELISM_HIGH, \ + LOAD_JOB_MODES, MODE_AUTO, FINAL_LOAD_STATUSES, SPARQL_ACTION from graph_notebook.network import SPARQLNetwork from graph_notebook.network.gremlin.GremlinNetwork import parse_pattern_list_str, GremlinNetwork -from graph_notebook.sparql.table import get_rows_and_columns -from graph_notebook.gremlin.query import do_gremlin_query, do_gremlin_explain, do_gremlin_profile -from graph_notebook.gremlin.status import do_gremlin_status, do_gremlin_cancel -from graph_notebook.sparql.query import get_query_type, do_sparql_query, do_sparql_explain, SPARQL_ACTION -from graph_notebook.sparql.status import do_sparql_status, do_sparql_cancel -from graph_notebook.system.database_reset import perform_database_reset, initiate_database_reset +from graph_notebook.visualization.sparql_rows_and_columns import get_rows_and_columns from graph_notebook.visualization.template_retriever import retrieve_template -from graph_notebook.gremlin.client_provider.factory import create_client_provider -from graph_notebook.request_param_generator.factory import create_request_generator -from graph_notebook.loader.load import do_load, get_loader_jobs, get_load_status, cancel_load, VALID_FORMATS, \ - PARALLELISM_OPTIONS, PARALLELISM_HIGH, FINAL_LOAD_STATUSES, LOAD_JOB_MODES, MODE_AUTO from graph_notebook.configuration.get_config import get_config, get_config_from_dict from graph_notebook.seed.load_query import get_data_sets, get_queries -from graph_notebook.status.get_status import get_status from graph_notebook.widgets import Force from graph_notebook.options import OPTIONS_DEFAULT_DIRECTED, vis_options_merge @@ -91,6 +86,28 @@ def str_to_query_mode(s: str) -> QueryMode: return QueryMode.DEFAULT +ACTION_TO_QUERY_TYPE = { + 'sparql': 'application/sparql-query', + 'sparqlupdate': 'application/sparql-update' +} + + +def get_query_type(query): + s = SPARQLWrapper('') + s.setQuery(query) + return s.queryType + + +def query_type_to_action(query_type): + query_type = query_type.upper() + if query_type in ['SELECT', 'CONSTRUCT', 'ASK', 'DESCRIBE']: + return 'sparql' + else: + # TODO: check explicitly for all query types, raise exception for invalid query + return 'sparqlupdate' + + +# TODO: refactor large magic commands into their own modules like what we do with %neptune_ml # noinspection PyTypeChecker @magics_class class Graph(Magics): @@ -98,17 +115,35 @@ def __init__(self, shell): # You must call the parent constructor super(Graph, self).__init__(shell) + self.graph_notebook_config = generate_default_config() try: self.config_location = os.getenv('GRAPH_NOTEBOOK_CONFIG', DEFAULT_CONFIG_LOCATION) + self.client: Client = None self.graph_notebook_config = get_config(self.config_location) except FileNotFoundError: - self.graph_notebook_config = generate_default_config() print( 'Could not find a valid configuration. Do not forgot to validate your settings using %graph_notebook_config') + self.max_results = DEFAULT_MAX_RESULTS self.graph_notebook_vis_options = OPTIONS_DEFAULT_DIRECTED + self._generate_client_from_config(self.graph_notebook_config) logger.setLevel(logging.ERROR) + def _generate_client_from_config(self, config: Configuration): + if self.client: + self.client.close() + + builder = ClientBuilder() \ + .with_host(config.host) \ + .with_port(config.port) \ + .with_region(config.aws_region) \ + .with_tls(config.ssl) \ + .with_sparql_path(config.sparql.path) + if config.auth_mode == AuthModeEnum.IAM: + builder = builder.with_iam(get_session()) + + self.client = builder.build() + @line_cell_magic @display_exceptions def graph_notebook_config(self, line='', cell=''): @@ -116,6 +151,7 @@ def graph_notebook_config(self, line='', cell=''): data = json.loads(cell) config = get_config_from_dict(data) self.graph_notebook_config = config + self._generate_client_from_config(config) print('set notebook config to:') print(json.dumps(self.graph_notebook_config.to_dict(), indent=2)) elif line == 'reset': @@ -143,6 +179,7 @@ def graph_notebook_host(self, line): # TODO: we should attempt to make a status call to this host before we set the config to this value. self.graph_notebook_config.host = line + self._generate_client_from_config(self.graph_notebook_config) print(f'set host to {line}') @cell_magic @@ -155,26 +192,24 @@ def sparql(self, line='', cell='', local_ns: dict = None): parser.add_argument('--path', '-p', default='', help='prefix path to sparql endpoint. For example, if "foo/bar" were specified, the endpoint called would be host:port/foo/bar') parser.add_argument('--expand-all', action='store_true') - - request_generator = create_request_generator(self.graph_notebook_config.auth_mode, - self.graph_notebook_config.iam_credentials_provider_type, - command='sparql') + parser.add_argument('--explain-type', default='dynamic', + help='explain mode to use when using the explain query mode', + choices=['dynamic', 'static', 'details']) + parser.add_argument('--explain-format', default='text/html', help='response format for explain query mode', + choices=['text/csv', 'text/html']) parser.add_argument('--store-to', type=str, default='', help='store query result to this variable') args = parser.parse_args(line.split()) mode = str_to_query_mode(args.query_mode) tab = widgets.Tab() path = args.path if args.path != '' else self.graph_notebook_config.sparql.path - logger.debug(f'using mode={mode}') if mode == QueryMode.EXPLAIN: - res = do_sparql_explain(cell, self.graph_notebook_config.host, self.graph_notebook_config.port, - self.graph_notebook_config.ssl, request_generator, path=path) - store_to_ns(args.store_to, res, local_ns) - if 'error' in res: - html = error_template.render(error=json.dumps(res['error'], indent=2)) - else: - html = sparql_explain_template.render(table=res) + res = self.client.sparql_explain(cell, args.explain_type, args.explain_format, path=path) + res.raise_for_status() + explain = res.content.decode('utf-8') + store_to_ns(args.store_to, explain, local_ns) + html = sparql_explain_template.render(table=explain) explain_output = widgets.Output(layout=DEFAULT_LAYOUT) with explain_output: display(HTML(html)) @@ -186,8 +221,10 @@ def sparql(self, line='', cell='', local_ns: dict = None): query_type = get_query_type(cell) headers = {} if query_type not in ['SELECT', 'CONSTRUCT', 'DESCRIBE'] else { 'Accept': 'application/sparql-results+json'} - res = do_sparql_query(cell, self.graph_notebook_config.host, self.graph_notebook_config.port, - self.graph_notebook_config.ssl, request_generator, headers, path=path) + + query_res = self.client.sparql(cell, path=path, headers=headers) + query_res.raise_for_status() + res = query_res.json() store_to_ns(args.store_to, res, local_ns) titles = [] children = [] @@ -207,8 +244,7 @@ def sparql(self, line='', cell='', local_ns: dict = None): titles.append('Table') children.append(hbox) - expand_all = line == '--expand-all' - sn = SPARQLNetwork(expand_all=expand_all) + sn = SPARQLNetwork(expand_all=args.expand_all) sn.extract_prefix_declarations_from_query(cell) try: sn.add_results(res) @@ -268,20 +304,18 @@ def sparql_status(self, line='', local_ns: dict = None): parser.add_argument('--store-to', type=str, default='', help='store query result to this variable') args = parser.parse_args(line.split()) - request_generator = create_request_generator(self.graph_notebook_config.auth_mode, - self.graph_notebook_config.iam_credentials_provider_type) - if not args.cancelQuery: - res = do_sparql_status(self.graph_notebook_config.host, self.graph_notebook_config.port, - self.graph_notebook_config.ssl, request_generator, args.queryId) - + status_res = self.client.sparql_cancel(args.queryId) + status_res.raise_for_status() + res = status_res.json() else: if args.queryId == '': print(SPARQL_CANCEL_HINT_MSG) return else: - res = do_sparql_cancel(self.graph_notebook_config.host, self.graph_notebook_config.port, - self.graph_notebook_config.ssl, request_generator, args.queryId, args.silent) + cancel_res = self.client.sparql_cancel(args.queryId, args.silent) + cancel_res.raise_for_status() + res = cancel_res.json() store_to_ns(args.store_to, res, local_ns) print(json.dumps(res, indent=2)) @@ -304,13 +338,11 @@ def gremlin(self, line, cell, local_ns: dict = None): tab = widgets.Tab() if mode == QueryMode.EXPLAIN: - request_generator = create_request_generator(self.graph_notebook_config.auth_mode, - self.graph_notebook_config.iam_credentials_provider_type) - - query_res = do_gremlin_explain(cell, self.graph_notebook_config.host, self.graph_notebook_config.port, - self.graph_notebook_config.ssl, request_generator) - if 'explain' in query_res: - html = pre_container_template.render(content=query_res['explain']) + res = self.client.gremlin_explain(cell) + res.raise_for_status() + query_res = res.content.decode('utf-8') + if 'Neptune Gremlin Explain' in query_res: + html = pre_container_template.render(content=query_res) else: html = pre_container_template.render(content='No explain found') @@ -321,13 +353,11 @@ def gremlin(self, line, cell, local_ns: dict = None): tab.set_title(0, 'Explain') display(tab) elif mode == QueryMode.PROFILE: - request_generator = create_request_generator(self.graph_notebook_config.auth_mode, - self.graph_notebook_config.iam_credentials_provider_type) - - query_res = do_gremlin_profile(cell, self.graph_notebook_config.host, self.graph_notebook_config.port, - self.graph_notebook_config.ssl, request_generator) - if 'profile' in query_res: - html = pre_container_template.render(content=query_res['profile']) + res = self.client.gremlin_profile(cell) + res.raise_for_status() + query_res = res.content.decode('utf-8') + if 'Neptune Gremlin Profile' in query_res: + html = pre_container_template.render(content=query_res) else: html = pre_container_template.render(content='No profile found') profile_output = widgets.Output(layout=DEFAULT_LAYOUT) @@ -337,10 +367,7 @@ def gremlin(self, line, cell, local_ns: dict = None): tab.set_title(0, 'Profile') display(tab) else: - client_provider = create_client_provider(self.graph_notebook_config.auth_mode, - self.graph_notebook_config.iam_credentials_provider_type) - query_res = do_gremlin_query(cell, self.graph_notebook_config.host, self.graph_notebook_config.port, - self.graph_notebook_config.ssl, client_provider) + query_res = self.client.gremlin_query(cell) children = [] titles = [] @@ -393,24 +420,18 @@ def gremlin_status(self, line='', local_ns: dict = None): parser.add_argument('--store-to', type=str, default='', help='store query result to this variable') args = parser.parse_args(line.split()) - request_generator = create_request_generator(self.graph_notebook_config.auth_mode, - self.graph_notebook_config.iam_credentials_provider_type) - if not args.cancelQuery: - res = do_gremlin_status(self.graph_notebook_config.host, - self.graph_notebook_config.port, - self.graph_notebook_config.ssl, self.graph_notebook_config.auth_mode, - request_generator, args.queryId, args.includeWaiting) - + status_res = self.client.gremlin_status(args.queryId) + status_res.raise_for_status() + res = status_res.json() else: if args.queryId == '': print(GREMLIN_CANCEL_HINT_MSG) return else: - res = do_gremlin_cancel(self.graph_notebook_config.host, self.graph_notebook_config.port, - self.graph_notebook_config.ssl, self.graph_notebook_config.auth_mode, - request_generator, args.queryId) - + cancel_res = self.client.gremlin_cancel(args.queryId) + cancel_res.raise_for_status() + res = cancel_res.json() print(json.dumps(res, indent=2)) store_to_ns(args.store_to, res, local_ns) @@ -418,41 +439,34 @@ def gremlin_status(self, line='', local_ns: dict = None): @display_exceptions def status(self, line): logger.info(f'calling for status on endpoint {self.graph_notebook_config.host}') - request_generator = create_request_generator(self.graph_notebook_config.auth_mode, - self.graph_notebook_config.iam_credentials_provider_type) - logger.info( - f'used credentials_provider_mode={self.graph_notebook_config.iam_credentials_provider_type.name} and auth_mode={self.graph_notebook_config.auth_mode.name} to make status request') - - res = get_status(self.graph_notebook_config.host, self.graph_notebook_config.port, - self.graph_notebook_config.ssl, request_generator) + status_res = self.client.status() + status_res.raise_for_status() + res = status_res.json() logger.info(f'got the response {res}') return res @line_magic @display_exceptions def db_reset(self, line): - host = self.graph_notebook_config.host - port = self.graph_notebook_config.port - ssl = self.graph_notebook_config.ssl - - logger.info(f'calling system endpoint {host}') + logger.info(f'calling system endpoint {self.client.host}') parser = argparse.ArgumentParser() parser.add_argument('-g', '--generate-token', action='store_true', help='generate token for database reset') - parser.add_argument('-t', '--token', nargs=1, default='', help='perform database reset with given token') + parser.add_argument('-t', '--token', default='', help='perform database reset with given token') parser.add_argument('-y', '--yes', action='store_true', help='skip the prompt and perform database reset') args = parser.parse_args(line.split()) generate_token = args.generate_token skip_prompt = args.yes - request_generator = create_request_generator(self.graph_notebook_config.auth_mode, - self.graph_notebook_config.iam_credentials_provider_type) - logger.info( - f'used credentials_provider_mode={self.graph_notebook_config.iam_credentials_provider_type.name} and auth_mode={self.graph_notebook_config.auth_mode.name} to make system request') if generate_token is False and args.token == '': if skip_prompt: - res = initiate_database_reset(host, port, ssl, request_generator) + initiate_res = self.client.initiate_reset() + initiate_res.raise_for_status() + res = initiate_res.json() token = res['payload']['token'] - res = perform_database_reset(token, host, port, ssl, request_generator) + + perform_reset_res = self.client.perform_reset(token) + perform_reset_res.raise_for_status() logger.info(f'got the response {res}') + res = perform_reset_res.json() return res output = widgets.Output() @@ -473,7 +487,9 @@ def db_reset(self, line): display(text_hbox, check_box, button_hbox, output) def on_button_delete_clicked(b): - result = initiate_database_reset(host, port, ssl, request_generator) + initiate_res = self.client.initiate_reset() + initiate_res.raise_for_status() + result = initiate_res.json() text_hbox.close() check_box.close() @@ -492,7 +508,9 @@ def on_button_delete_clicked(b): print(result) return - result = perform_database_reset(token, host, port, ssl, request_generator) + perform_reset_res = self.client.perform_reset(token) + perform_reset_res.raise_for_status() + result = perform_reset_res.json() if 'status' not in result or result['status'] != '200 OK': with output: @@ -522,7 +540,9 @@ def on_button_delete_clicked(b): display_html(HTML(loading_wheel_html)) try: retry -= 1 - interval_check_response = get_status(host, port, ssl, request_generator) + status_res = self.client.status() + status_res.raise_for_status() + interval_check_response = status_res.json() except Exception as e: # Exception is expected when database is resetting, continue waiting with job_status_output: @@ -560,10 +580,14 @@ def on_button_cancel_clicked(b): button_cancel.on_click(on_button_cancel_clicked) return elif generate_token: - res = initiate_database_reset(host, port, ssl, request_generator) + initiate_res = self.client.initiate_reset() + initiate_res.raise_for_status() + res = initiate_res.json() else: # args.token is an array of a single string, e.g., args.token=['ade-23-c23'], use index 0 to take the string - res = perform_database_reset(args.token[0], host, port, ssl, request_generator) + perform_res = self.client.perform_reset(args.token) + perform_res.raise_for_status() + res = perform_res.json() logger.info(f'got the response {res}') return res @@ -572,6 +596,7 @@ def on_button_cancel_clicked(b): @needs_local_scope @display_exceptions def load(self, line='', local_ns: dict = None): + # TODO: change widgets to let any arbitrary inputs be added by users parser = argparse.ArgumentParser() parser.add_argument('-s', '--source', default='s3://') parser.add_argument('-l', '--loader-arn', default=self.graph_notebook_config.load_from_s3_arn) @@ -588,16 +613,7 @@ def load(self, line='', local_ns: dict = None): args = parser.parse_args(line.split()) - # since this can be a long-running task, freezing variables in the case - # that a user alters them in another command. - host = self.graph_notebook_config.host - port = self.graph_notebook_config.port - ssl = self.graph_notebook_config.ssl - - credentials_provider_mode = self.graph_notebook_config.iam_credentials_provider_type - request_generator = create_request_generator(self.graph_notebook_config.auth_mode, credentials_provider_mode) region = self.graph_notebook_config.aws_region - button = widgets.Button(description="Submit") output = widgets.Output() source = widgets.Text( @@ -708,7 +724,8 @@ def on_button_clicked(b): if not len(dependencies_list) < 64: validated = False - dep_validation_label = widgets.HTML('
A maximum of 64 jobs may be queued at once.
') + dep_validation_label = widgets.HTML( + 'A maximum of 64 jobs may be queued at once.
') dep_hbox.children += (dep_validation_label,) if not validated: @@ -718,11 +735,19 @@ def on_button_clicked(b): source.value) # replace any env variables in source.value with their values, can use $foo or ${foo}. Particularly useful for ${AWS_REGION} logger.info(f'using source_exp: {source_exp}') try: - load_result = do_load(host, port, source_format.value, ssl, str(source_exp), region_box.value, - arn.value, fail_on_error.value, request_generator, mode=mode.value, - parallelism=parallelism.value, update_single_cardinality=update_single_cardinality.value, - queue_request=queue_request.value, dependencies=dependencies_list) - + kwargs = { + 'failOnError': fail_on_error.value, + 'parallelism': parallelism.value, + 'updateSingleCardinalityProperties': update_single_cardinality.value, + 'queueRequest': queue_request.value + } + + if dependencies: + kwargs['dependencies'] = dependencies_list + + load_res = self.client.load(source.value, source_format.value, arn.value, region_box.value, **kwargs) + load_res.raise_for_status() + load_result = load_res.json() store_to_ns(args.store_to, load_result, local_ns) source_hbox.close() @@ -767,8 +792,9 @@ def on_button_clicked(b): with job_status_output: display_html(HTML(loading_wheel_html)) try: - interval_check_response = get_load_status(host, port, ssl, request_generator, - load_result['payload']['loadId']) + load_status_res = self.client.load_status(load_result['payload']['loadId']) + load_status_res.raise_for_status() + interval_check_response = load_status_res.json() except Exception as e: logger.error(e) with job_status_output: @@ -806,10 +832,9 @@ def load_ids(self, line, local_ns: dict = None): parser.add_argument('--store-to', type=str, default='') args = parser.parse_args(line.split()) - credentials_provider_mode = self.graph_notebook_config.iam_credentials_provider_type - request_generator = create_request_generator(self.graph_notebook_config.auth_mode, credentials_provider_mode) - res = get_loader_jobs(self.graph_notebook_config.host, self.graph_notebook_config.port, - self.graph_notebook_config.ssl, request_generator) + load_status = self.client.load_status() + load_status.raise_for_status() + res = load_status.json() ids = [] if 'payload' in res and 'loadIds' in res['payload']: ids = res['payload']['loadIds'] @@ -834,15 +859,21 @@ def load_status(self, line, local_ns: dict = None): parser.add_argument('--store-to', type=str, default='') parser.add_argument('--details', action='store_true', default=False) parser.add_argument('--errors', action='store_true', default=False) - parser.add_argument('--page', '-p', default='1', help='The error page number. Only valid when the --errors option is set.') - parser.add_argument('--errorsPerPage', '-e', default='10', help='The number of errors per each page. Only valid when the --errors option is set.') + parser.add_argument('--page', '-p', default='1', + help='The error page number. Only valid when the --errors option is set.') + parser.add_argument('--errorsPerPage', '-e', default='10', + help='The number of errors per each page. Only valid when the --errors option is set.') args = parser.parse_args(line.split()) - credentials_provider_mode = self.graph_notebook_config.iam_credentials_provider_type - request_generator = create_request_generator(self.graph_notebook_config.auth_mode, credentials_provider_mode) - res = get_load_status(self.graph_notebook_config.host, self.graph_notebook_config.port, - self.graph_notebook_config.ssl, request_generator, args.load_id, args.details, args.errors, - args.page, args.errorsPerPage) + payload = { + 'details': args.details, + 'errors': args.errors, + 'page': args.page, + 'errorsPerPage': args.errorsPerPage + } + load_status_res = self.client.load_status(args.load_id, **payload) + load_status_res.raise_for_status() + res = load_status_res.json() print(json.dumps(res, indent=2)) if args.store_to != '' and local_ns is not None: @@ -857,10 +888,9 @@ def cancel_load(self, line, local_ns: dict = None): parser.add_argument('--store-to', type=str, default='') args = parser.parse_args(line.split()) - credentials_provider_mode = self.graph_notebook_config.iam_credentials_provider_type - request_generator = create_request_generator(self.graph_notebook_config.auth_mode, credentials_provider_mode) - res = cancel_load(self.graph_notebook_config.host, self.graph_notebook_config.port, - self.graph_notebook_config.ssl, request_generator, args.load_id) + cancel_res = self.client.cancel_load(args.load_id) + cancel_res.raise_for_status() + res = cancel_res.json() if res: print('Cancelled successfully.') else: @@ -875,7 +905,7 @@ def seed(self, line): parser = argparse.ArgumentParser() parser.add_argument('--language', type=str, default='', choices=SEED_LANGUAGE_OPTIONS) parser.add_argument('--dataset', type=str, default='') - # TODO: Gremlin paths are not yet supported. + # TODO: Gremlin api paths are not yet supported. parser.add_argument('--path', '-p', default=SPARQL_ACTION, help='prefix path to query endpoint. For example, "foo/bar". The queried path would then be host:port/foo/bar for sparql seed commands') parser.add_argument('--run', action='store_true') @@ -938,21 +968,11 @@ def on_button_clicked(b=None): for q in queries: with output: print(f'{progress.value}/{len(queries)}:\t{q["name"]}') - # Just like with the load command, seed is long-running - # as such, we want to obtain the values of host, port, etc. in case they - # change during execution. - host = self.graph_notebook_config.host - port = self.graph_notebook_config.port - auth_mode = self.graph_notebook_config.auth_mode - ssl = self.graph_notebook_config.ssl - if language == 'gremlin': - client_provider = create_client_provider(auth_mode, - self.graph_notebook_config.iam_credentials_provider_type) # IMPORTANT: We treat each line as its own query! for line in q['content'].splitlines(): try: - do_gremlin_query(line, host, port, ssl, client_provider) + self.client.gremlin_query(line) except GremlinServerError as gremlinEx: try: error = json.loads(gremlinEx.args[0][5:]) # remove the leading error code. @@ -975,10 +995,8 @@ def on_button_clicked(b=None): progress.close() return else: - request_generator = create_request_generator(auth_mode, - self.graph_notebook_config.iam_credentials_provider_type) try: - do_sparql_query(q['content'], host, port, ssl, request_generator, path=args.path) + self.client.sparql(q['content'], path=args.path) except HTTPError as httpEx: # attempt to turn response into json try: @@ -1053,11 +1071,9 @@ def neptune_ml(self, line, cell='', local_ns: dict = None): parser = generate_neptune_ml_parser() args = parser.parse_args(line.split()) logger.info(f'received call to neptune_ml with details: {args.__dict__}, cell={cell}, local_ns={local_ns}') - request_generator = create_request_generator(self.graph_notebook_config.auth_mode, - self.graph_notebook_config.iam_credentials_provider_type) main_output = widgets.Output() display(main_output) - res = neptune_ml_magic_handler(args, request_generator, self.graph_notebook_config, main_output, cell, local_ns) + res = neptune_ml_magic_handler(args, self.client, main_output, cell, local_ns) message = json.dumps(res, indent=2) if type(res) is dict else res store_to_ns(args.store_to, res, local_ns) with main_output: diff --git a/src/graph_notebook/magics/ml.py b/src/graph_notebook/magics/ml.py index b40db040..91e42474 100644 --- a/src/graph_notebook/magics/ml.py +++ b/src/graph_notebook/magics/ml.py @@ -6,12 +6,8 @@ from IPython.core.display import display from ipywidgets import widgets -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import credentials_provider_factory -from graph_notebook.authentication.iam_credentials_provider.credentials_provider import Credentials -from graph_notebook.configuration.generate_config import Configuration, AuthModeEnum from graph_notebook.magics.parsing import str_to_namespace_var -from graph_notebook.ml.sagemaker import start_export, get_export_status, start_processing_job, get_processing_status, \ - start_training, get_training_status, start_create_endpoint, get_endpoint_status, EXPORT_SERVICE_NAME +from graph_notebook.neptune.client import Client logger = logging.getLogger("neptune_ml_magic_handler") @@ -146,17 +142,26 @@ def generate_neptune_ml_parser(): return parser -def neptune_ml_export_start(params, export_url: str, export_ssl: bool = True, creds: Credentials = None): +def neptune_ml_export_start(client: Client, params, export_url: str, export_ssl: bool = True): if type(params) is str: params = json.loads(params) - job = start_export(export_url, params, export_ssl, creds) + export_res = client.export(export_url, params, export_ssl) + export_res.raise_for_status() + job = export_res.json() return job -def wait_for_export(export_url: str, job_id: str, output: widgets.Output, +def neptune_ml_export_status(client: Client, export_url: str, job_id: str, export_ssl: bool = True): + res = client.export_status(export_url, job_id, export_ssl) + res.raise_for_status() + job = res.json() + return job + + +def wait_for_export(client: Client, export_url: str, job_id: str, output: widgets.Output, export_ssl: bool = True, wait_interval: int = DEFAULT_WAIT_INTERVAL, - wait_timeout: int = DEFAULT_WAIT_TIMEOUT, creds: Credentials = None): + wait_timeout: int = DEFAULT_WAIT_TIMEOUT): job_id_output = widgets.Output() update_widget_output = widgets.Output() with output: @@ -170,7 +175,9 @@ def wait_for_export(export_url: str, job_id: str, output: widgets.Output, while datetime.datetime.utcnow() - beginning_time < (datetime.timedelta(seconds=wait_timeout)): update_widget_output.clear_output() print('Checking for latest status...') - export_status = get_export_status(export_url, export_ssl, job_id, creds) + status_res = client.export_status(export_url, job_id, export_ssl) + status_res.raise_for_status() + export_status = status_res.json() if export_status['status'] in ['succeeded', 'failed']: print('Export is finished') return export_status @@ -180,32 +187,30 @@ def wait_for_export(export_url: str, job_id: str, output: widgets.Output, time.sleep(wait_interval) -def neptune_ml_export(args: argparse.Namespace, config: Configuration, output: widgets.Output, cell: str): - auth_mode = AuthModeEnum.IAM if args.export_iam else AuthModeEnum.DEFAULT - creds = None - if auth_mode == AuthModeEnum.IAM: - creds = credentials_provider_factory(config.iam_credentials_provider_type).get_iam_credentials() - +def neptune_ml_export(args: argparse.Namespace, client: Client, output: widgets.Output, + cell: str): export_ssl = not args.export_no_ssl if args.which_sub == 'start': if cell == '': return 'Cell body must have json payload or reference notebook variable using syntax ${payload_var}' - export_job = neptune_ml_export_start(cell, args.export_url, export_ssl, creds) + export_job = neptune_ml_export_start(client, cell, args.export_url, export_ssl) if args.wait: - return wait_for_export(args.export_url, export_job['jobId'], - output, export_ssl, args.wait_interval, args.wait_timeout, creds) + return wait_for_export(client, args.export_url, export_job['jobId'], + output, export_ssl, args.wait_interval, args.wait_timeout) else: return export_job elif args.which_sub == 'status': if args.wait: - status = wait_for_export(args.export_url, args.job_id, output, export_ssl, args.wait_interval, - args.wait_timeout, creds) + status = wait_for_export(client, args.export_url, args.job_id, output, export_ssl, args.wait_interval, + args.wait_timeout) else: - status = get_export_status(args.export_url, export_ssl, args.job_id, creds) + status_res = client.export_status(args.export_url, args.job_id, export_ssl) + status_res.raise_for_status() + status = status_res.json() return status -def wait_for_dataprocessing(job_id: str, config: Configuration, request_param_generator, output: widgets.Output, +def wait_for_dataprocessing(job_id: str, client: Client, output: widgets.Output, wait_interval: int = DEFAULT_WAIT_INTERVAL, wait_timeout: int = DEFAULT_WAIT_TIMEOUT): job_id_output = widgets.Output() update_status_output = widgets.Output() @@ -219,7 +224,9 @@ def wait_for_dataprocessing(job_id: str, config: Configuration, request_param_ge beginning_time = datetime.datetime.utcnow() while datetime.datetime.utcnow() - beginning_time < (datetime.timedelta(seconds=wait_timeout)): update_status_output.clear_output() - status = get_processing_status(config.host, str(config.port), config.ssl, request_param_generator, job_id) + status_res = client.dataprocessing_job_status(job_id) + status_res.raise_for_status() + status = status_res.json() if status['status'] in ['Completed', 'Failed']: print('Data processing is finished') return status @@ -229,37 +236,34 @@ def wait_for_dataprocessing(job_id: str, config: Configuration, request_param_ge time.sleep(wait_interval) -def neptune_ml_dataprocessing(args: argparse.Namespace, request_param_generator, output: widgets.Output, - config: Configuration, params: dict = None): +def neptune_ml_dataprocessing(args: argparse.Namespace, client, output: widgets.Output, params: dict = None): if args.which_sub == 'start': if params is None or params == '' or params == {}: params = { - 'inputDataS3Location': args.s3_input_uri, - 'processedDataS3Location': args.s3_processed_uri, 'id': args.job_id, 'configFileName': args.config_file_name } - processing_job = start_processing_job(config.host, str(config.port), config.ssl, - request_param_generator, params) + processing_job_res = client.dataprocessing_start(args.s3_input_uri, args.s3_processed_uri, **params) + processing_job_res.raise_for_status() + processing_job = processing_job_res.json() job_id = params['id'] if args.wait: - return wait_for_dataprocessing(job_id, config, request_param_generator, - output, args.wait_interval, args.wait_timeout) + return wait_for_dataprocessing(job_id, client, output, args.wait_interval, args.wait_timeout) else: return processing_job elif args.which_sub == 'status': if args.wait: - return wait_for_dataprocessing(args.job_id, config, request_param_generator, output, args.wait_interval, - args.wait_timeout) + return wait_for_dataprocessing(args.job_id, client, output, args.wait_interval, args.wait_timeout) else: - return get_processing_status(config.host, str(config.port), config.ssl, request_param_generator, - args.job_id) + processing_status = client.dataprocessing_job_status(args.job_id) + processing_status.raise_for_status() + return processing_status.json() else: return f'Sub parser "{args.which} {args.which_sub}" was not recognized' -def wait_for_training(job_id: str, config: Configuration, request_param_generator, output: widgets.Output, +def wait_for_training(job_id: str, client: Client, output: widgets.Output, wait_interval: int = DEFAULT_WAIT_INTERVAL, wait_timeout: int = DEFAULT_WAIT_TIMEOUT): job_id_output = widgets.Output() update_status_output = widgets.Output() @@ -273,7 +277,9 @@ def wait_for_training(job_id: str, config: Configuration, request_param_generato beginning_time = datetime.datetime.utcnow() while datetime.datetime.utcnow() - beginning_time < (datetime.timedelta(seconds=wait_timeout)): update_status_output.clear_output() - status = get_training_status(config.host, str(config.port), config.ssl, request_param_generator, job_id) + training_status_res = client.modeltraining_job_status(job_id) + training_status_res.raise_for_status() + status = training_status_res.json() if status['status'] in ['Completed', 'Failed']: print('Training is finished') return status @@ -283,35 +289,34 @@ def wait_for_training(job_id: str, config: Configuration, request_param_generato time.sleep(wait_interval) -def neptune_ml_training(args: argparse.Namespace, request_param_generator, config: Configuration, - output: widgets.Output, params): +def neptune_ml_training(args: argparse.Namespace, client: Client, output: widgets.Output, params): if args.which_sub == 'start': if params is None or params == '' or params == {}: params = { "id": args.job_id, "dataProcessingJobId": args.data_processing_id, "trainingInstanceType": args.instance_type, - "trainModelS3Location": args.s3_output_uri } - training_job = start_training(config.host, str(config.port), config.ssl, request_param_generator, params) + start_training_res = client.modeltraining_start(args.job_id, args.s3_output_uri, **params) + start_training_res.raise_for_status() + training_job = start_training_res.json() if args.wait: - return wait_for_training(training_job['id'], config, request_param_generator, output, args.wait_interval, - args.wait_timeout) + return wait_for_training(training_job['id'], client, output, args.wait_interval, args.wait_timeout) else: return training_job elif args.which_sub == 'status': if args.wait: - return wait_for_training(args.job_id, config, request_param_generator, output, args.wait_interval, - args.wait_timeout) + return wait_for_training(args.job_id, client, output, args.wait_interval, args.wait_timeout) else: - return get_training_status(config.host, str(config.port), config.ssl, request_param_generator, - args.job_id) + training_status_res = client.modeltraining_job_status(args.job_id) + training_status_res.raise_for_status() + return training_status_res.json() else: return f'Sub parser "{args.which} {args.which_sub}" was not recognized' -def wait_for_endpoint(job_id: str, config: Configuration, request_param_generator, output: widgets.Output, +def wait_for_endpoint(job_id: str, client: Client, output: widgets.Output, wait_interval: int = DEFAULT_WAIT_INTERVAL, wait_timeout: int = DEFAULT_WAIT_TIMEOUT): job_id_output = widgets.Output() update_status_output = widgets.Output() @@ -325,7 +330,9 @@ def wait_for_endpoint(job_id: str, config: Configuration, request_param_generato beginning_time = datetime.datetime.utcnow() while datetime.datetime.utcnow() - beginning_time < (datetime.timedelta(seconds=wait_timeout)): update_status_output.clear_output() - status = get_endpoint_status(config.host, str(config.port), config.ssl, request_param_generator, job_id) + endpoint_status_res = client.endpoints_status(job_id) + endpoint_status_res.raise_for_status() + status = endpoint_status_res.json() if status['status'] in ['InService', 'Failed']: print('Endpoint creation is finished') return status @@ -335,47 +342,44 @@ def wait_for_endpoint(job_id: str, config: Configuration, request_param_generato time.sleep(wait_interval) -def neptune_ml_endpoint(args: argparse.Namespace, request_param_generator, - config: Configuration, output: widgets.Output, params): +def neptune_ml_endpoint(args: argparse.Namespace, client: Client, output: widgets.Output, params): if args.which_sub == 'create': if params is None or params == '' or params == {}: params = { "id": args.job_id, - "mlModelTrainingJobId": args.model_job_id, 'instanceType': args.instance_type } - create_endpoint_job = start_create_endpoint(config.host, str(config.port), config.ssl, - request_param_generator, params) - + create_endpoint_res = client.endpoints_create(args.model_job_id, **params) + create_endpoint_res.raise_for_status() + create_endpoint_job = create_endpoint_res.json() if args.wait: - return wait_for_endpoint(create_endpoint_job['id'], config, request_param_generator, output, - args.wait_interval, args.wait_timeout) + return wait_for_endpoint(create_endpoint_job['id'], client, output, args.wait_interval, args.wait_timeout) else: return create_endpoint_job elif args.which_sub == 'status': if args.wait: - return wait_for_endpoint(args.job_id, config, request_param_generator, output, - args.wait_interval, args.wait_timeout) + return wait_for_endpoint(args.job_id, client, output, args.wait_interval, args.wait_timeout) else: - return get_endpoint_status(config.host, str(config.port), config.ssl, request_param_generator, args.job_id) + endpoint_status = client.endpoints_status(args.job_id) + endpoint_status.raise_for_status() + return endpoint_status.json() else: return f'Sub parser "{args.which} {args.which_sub}" was not recognized' -def neptune_ml_magic_handler(args, request_param_generator, config: Configuration, output: widgets.Output, - cell: str = '', local_ns: dict = None) -> any: +def neptune_ml_magic_handler(args, client: Client, output: widgets.Output, cell: str = '', local_ns: dict = None): if local_ns is None: local_ns = {} cell = str_to_namespace_var(cell, local_ns) if args.which == 'export': - return neptune_ml_export(args, config, output, cell) + return neptune_ml_export(args, client, output, cell) elif args.which == 'dataprocessing': - return neptune_ml_dataprocessing(args, request_param_generator, output, config, cell) + return neptune_ml_dataprocessing(args, client, output, cell) elif args.which == 'training': - return neptune_ml_training(args, request_param_generator, config, output, cell) + return neptune_ml_training(args, client, output, cell) elif args.which == 'endpoint': - return neptune_ml_endpoint(args, request_param_generator, config, output, cell) + return neptune_ml_endpoint(args, client, output, cell) else: return f'sub parser {args.which} was not recognized' diff --git a/src/graph_notebook/ml/sagemaker.py b/src/graph_notebook/ml/sagemaker.py deleted file mode 100644 index 71c4a59b..00000000 --- a/src/graph_notebook/ml/sagemaker.py +++ /dev/null @@ -1,86 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -import json -import requests -from requests_aws4auth import AWS4Auth - -from graph_notebook.authentication.iam_credentials_provider.credentials_provider import Credentials -from graph_notebook.request_param_generator.call_and_get_response import call_and_get_response - -EXPORT_SERVICE_NAME = 'execute-api' -EXPORT_ACTION = 'neptune-export' -EXTRA_HEADERS = {'content-type': 'application/json'} -UPDATE_DELAY_SECONDS = 60 - - -def start_export(export_host: str, export_params: dict, use_ssl: bool, - creds: Credentials = None) -> dict: - auth = None - if creds is not None: - auth = AWS4Auth(creds.key, creds.secret, creds.region, EXPORT_SERVICE_NAME, - session_token=creds.token) - - protocol = 'https' if use_ssl else 'http' - url = f'{protocol}://{export_host}/{EXPORT_ACTION}' - res = requests.post(url, json=export_params, headers=EXTRA_HEADERS, auth=auth) - res.raise_for_status() - job = res.json() - return job - - -def get_export_status(export_host: str, use_ssl: bool, job_id: str, creds: Credentials = None): - auth = None - if creds is not None: - auth = AWS4Auth(creds.key, creds.secret, creds.region, EXPORT_SERVICE_NAME, - session_token=creds.token) - - protocol = 'https' if use_ssl else 'http' - url = f'{protocol}://{export_host}/{EXPORT_ACTION}/{job_id}' - res = requests.get(url, headers=EXTRA_HEADERS, auth=auth) - res.raise_for_status() - job = res.json() - return job - - -def get_processing_status(host: str, port: str, use_ssl: bool, request_param_generator, job_name: str): - res = call_and_get_response('get', f'ml/dataprocessing/{job_name}', host, port, request_param_generator, - use_ssl, extra_headers=EXTRA_HEADERS) - status = res.json() - return status - - -def start_processing_job(host: str, port: str, use_ssl: bool, request_param_generator, params: dict): - params_raw = json.dumps(params) if type(params) is dict else params - res = call_and_get_response('post', 'ml/dataprocessing', host, port, request_param_generator, use_ssl, params_raw, - EXTRA_HEADERS) - job = res.json() - return job - - -def start_training(host: str, port: str, use_ssl: bool, request_param_generator, params): - params_raw = json.dumps(params) if type(params) is dict else params - res = call_and_get_response('post', 'ml/modeltraining', host, port, request_param_generator, use_ssl, params_raw, - EXTRA_HEADERS) - return res.json() - - -def get_training_status(host: str, port: str, use_ssl: bool, request_param_generator, training_job_name: str): - res = call_and_get_response('get', f'ml/modeltraining/{training_job_name}', host, port, - request_param_generator, use_ssl, extra_headers=EXTRA_HEADERS) - return res.json() - - -def start_create_endpoint(host: str, port: str, use_ssl: bool, request_param_generator, params): - params_raw = json.dumps(params) if type(params) is dict else params - res = call_and_get_response('post', 'ml/endpoints', host, port, request_param_generator, use_ssl, params_raw, - EXTRA_HEADERS) - return res.json() - - -def get_endpoint_status(host: str, port: str, use_ssl: bool, request_param_generator, training_job_name: str): - res = call_and_get_response('get', f'ml/endpoints/{training_job_name}', host, port, request_param_generator, - use_ssl, extra_headers=EXTRA_HEADERS) - return res.json() diff --git a/src/graph_notebook/authentication/__init__.py b/src/graph_notebook/neptune/__init__.py similarity index 100% rename from src/graph_notebook/authentication/__init__.py rename to src/graph_notebook/neptune/__init__.py diff --git a/src/graph_notebook/neptune/client.py b/src/graph_notebook/neptune/client.py new file mode 100644 index 00000000..5bad2701 --- /dev/null +++ b/src/graph_notebook/neptune/client.py @@ -0,0 +1,497 @@ +""" +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +SPDX-License-Identifier: Apache-2.0 +""" + +import json + +import botocore +import requests +from SPARQLWrapper import SPARQLWrapper +from boto3 import Session +from botocore.auth import SigV4Auth +from botocore.awsrequest import AWSRequest +from gremlin_python.driver import client +from tornado import httpclient + +import graph_notebook.neptune.gremlin.graphsonV3d0_MapType_objectify_patch # noqa F401 + +DEFAULT_SPARQL_CONTENT_TYPE = 'application/x-www-form-urlencoded' +DEFAULT_PORT = 8182 +DEFAULT_REGION = 'us-east-1' + +NEPTUNE_SERVICE_NAME = 'neptune-db' + +# TODO: Constants for states of each long-running job +# TODO: add doc links to each command + +FORMAT_CSV = 'csv' +FORMAT_NTRIPLE = 'ntriples' +FORMAT_NQUADS = 'nquads' +FORMAT_RDFXML = 'rdfxml' +FORMAT_TURTLE = 'turtle' + +PARALLELISM_LOW = 'LOW' +PARALLELISM_MEDIUM = 'MEDIUM' +PARALLELISM_HIGH = 'HIGH' +PARALLELISM_OVERSUBSCRIBE = 'OVERSUBSCRIBE' + +MODE_RESUME = 'RESUME' +MODE_NEW = 'NEW' +MODE_AUTO = 'AUTO' + +LOAD_JOB_MODES = [MODE_RESUME, MODE_NEW, MODE_AUTO] +VALID_FORMATS = [FORMAT_CSV, FORMAT_NTRIPLE, FORMAT_NQUADS, FORMAT_RDFXML, FORMAT_TURTLE] +PARALLELISM_OPTIONS = [PARALLELISM_LOW, PARALLELISM_MEDIUM, PARALLELISM_HIGH, PARALLELISM_OVERSUBSCRIBE] +LOADER_ACTION = 'loader' + +FINAL_LOAD_STATUSES = ['LOAD_COMPLETED', + 'LOAD_COMMITTED_W_WRITE_CONFLICTS', + 'LOAD_CANCELLED_BY_USER', + 'LOAD_CANCELLED_DUE_TO_ERRORS', + 'LOAD_FAILED', + 'LOAD_UNEXPECTED_ERROR', + 'LOAD_DATA_DEADLOCK', + 'LOAD_DATA_FAILED_DUE_TO_FEED_MODIFIED_OR_DELETED', + 'LOAD_S3_READ_ERROR', + 'LOAD_S3_ACCESS_DENIED_ERROR', + 'LOAD_IN_QUEUE', + 'LOAD_FAILED_BECAUSE_DEPENDENCY_NOT_SATISFIED', + 'LOAD_FAILED_INVALID_REQUEST', ] + +EXPORT_SERVICE_NAME = 'execute-api' +EXPORT_ACTION = 'neptune-export' +EXTRA_HEADERS = {'content-type': 'application/json'} +SPARQL_ACTION = 'sparql' + + +class Client(object): + def __init__(self, host: str, port: int = DEFAULT_PORT, ssl: bool = True, region: str = DEFAULT_REGION, + sparql_path: str = '/sparql', auth=None, session: Session = None): + self.host = host + self.port = port + self.ssl = ssl + self.sparql_path = sparql_path + self.region = region + self._auth = auth + self._session = session + + self._http_protocol = 'https' if self.ssl else 'http' + self._ws_protocol = 'wss' if self.ssl else 'ws' + + self._http_session = None + + def sparql_query(self, query: str, headers=None, explain: str = '', path: str = '') -> requests.Response: + if headers is None: + headers = {} + + data = {'query': query} + return self.do_sparql_request(data, headers, explain, path=path) + + def sparql_update(self, update: str, headers=None, explain: str = '', path: str = '') -> requests.Response: + if headers is None: + headers = {} + + data = {'update': update} + return self.do_sparql_request(data, headers, explain, path=path) + + def do_sparql_request(self, data: dict, headers=None, explain: str = '', path: str = ''): + if 'content-type' not in headers: + headers['content-type'] = DEFAULT_SPARQL_CONTENT_TYPE + + explain = explain.lower() + if explain != '': + if explain not in ['static', 'dynamic', 'details']: + raise ValueError('explain mode not valid, must be one of "static", "dynamic", or "details"') + else: + data['explain'] = explain + + sparql_path = path if path != '' else self.sparql_path + uri = f'{self._http_protocol}://{self.host}:{self.port}/{sparql_path}' + req = self._prepare_request('POST', uri, data=data, headers=headers) + res = self._http_session.send(req) + return res + + def sparql(self, query: str, headers=None, explain: str = '', path: str = '') -> requests.Response: + if headers is None: + headers = {} + + s = SPARQLWrapper('') + s.setQuery(query) + query_type = s.queryType.upper() + if query_type in ['SELECT', 'CONSTRUCT', 'ASK', 'DESCRIBE']: + return self.sparql_query(query, headers, explain, path=path) + else: + return self.sparql_update(query, headers, explain, path=path) + + # TODO: enum/constants for supported types + def sparql_explain(self, query: str, explain: str = 'dynamic', output_format: str = 'text/html', + headers=None, path: str = '') -> requests.Response: + if headers is None: + headers = {} + + if 'Accept' not in headers: + headers['Accept'] = output_format + + return self.sparql(query, headers, explain, path=path) + + def sparql_status(self, query_id: str = ''): + return self._query_status('sparql', query_id=query_id) + + def sparql_cancel(self, query_id: str, silent: bool = False): + if type(query_id) is not str or query_id == '': + raise ValueError('query_id must be a non-empty string') + return self._query_status('sparql', query_id=query_id, silent=silent, cancelQuery=True) + + def get_gremlin_connection(self) -> client.Client: + uri = f'{self._http_protocol}://{self.host}:{self.port}/gremlin' + request = self._prepare_request('GET', uri) + + ws_url = f'{self._ws_protocol}://{self.host}:{self.port}/gremlin' + ws_request = httpclient.HTTPRequest(ws_url, headers=dict(request.headers)) + return client.Client(ws_request, 'g') + + def gremlin_query(self, query, bindings=None): + c = self.get_gremlin_connection() + try: + result = c.submit(query, bindings) + future_results = result.all() + results = future_results.result() + c.close() + return results + except Exception as e: + c.close() + raise e + + def gremlin_http_query(self, query, headers=None) -> requests.Response: + if headers is None: + headers = {} + + uri = f'{self._http_protocol}://{self.host}:{self.port}/gremlin' + data = {'gremlin': query} + req = self._prepare_request('POST', uri, data=json.dumps(data), headers=headers) + res = self._http_session.send(req) + return res + + def gremlin_status(self, query_id: str = '', include_waiting: bool = False): + kwargs = {} + if include_waiting: + kwargs['includeWaiting'] = True + return self._query_status('gremlin', query_id=query_id, **kwargs) + + def gremlin_cancel(self, query_id: str): + if type(query_id) is not str or query_id == '': + raise ValueError('query_id must be a non-empty string') + return self._query_status('gremlin', query_id=query_id, cancelQuery=True) + + def gremlin_explain(self, query: str) -> requests.Response: + return self._gremlin_query_plan(query, 'explain') + + def gremlin_profile(self, query: str) -> requests.Response: + return self._gremlin_query_plan(query, 'profile') + + def _gremlin_query_plan(self, query: str, plan_type: str, ) -> requests.Response: + url = f'{self._http_protocol}://{self.host}:{self.port}/gremlin/{plan_type}' + data = {'gremlin': query} + req = self._prepare_request('POST', url, data=json.dumps(data)) + res = self._http_session.send(req) + return res + + def status(self) -> requests.Response: + url = f'{self._http_protocol}://{self.host}:{self.port}/status' + req = self._prepare_request('GET', url, data='') + res = self._http_session.send(req) + return res + + def load(self, source: str, source_format: str, iam_role_arn: str, region: str, **kwargs) -> requests.Response: + """ + For a full list of allowed parameters, see aws documentation on the Neptune loader + endpoint: https://docs.aws.amazon.com/neptune/latest/userguide/load-api-reference-load.html + """ + payload = { + 'source': source, + 'format': source_format, + 'region': self.region, + 'iamRoleArn': iam_role_arn + } + + for key, value in kwargs.items(): + payload[key] = value + + url = f'{self._http_protocol}://{self.host}:{self.port}/loader' + raw = json.dumps(payload) + req = self._prepare_request('POST', url, data=raw, headers={'content-type': 'application/json'}) + res = self._http_session.send(req) + return res + + def load_status(self, load_id: str = '', **kwargs) -> requests.Response: + params = {} + for k, v in kwargs.items(): + params[k] = v + + if load_id != '': + params['loadId'] = load_id + + url = f'{self._http_protocol}://{self.host}:{self.port}/loader' + req = self._prepare_request('GET', url, params=params) + res = self._http_session.send(req) + return res + + def cancel_load(self, load_id: str) -> requests.Response: + url = f'{self._http_protocol}://{self.host}:{self.port}/loader' + params = {'loadId': load_id} + req = self._prepare_request('DELETE', url, params=params) + res = self._http_session.send(req) + return res + + def initiate_reset(self) -> requests.Response: + data = { + 'action': 'initiateDatabaseReset' + } + url = f'{self._http_protocol}://{self.host}:{self.port}/system' + req = self._prepare_request('POST', url, data=data) + res = self._http_session.send(req) + return res + + def perform_reset(self, token: str) -> requests.Response: + data = { + 'action': 'performDatabaseReset', + 'token': token + } + url = f'{self._http_protocol}://{self.host}:{self.port}/system' + req = self._prepare_request('POST', url, data=data) + res = self._http_session.send(req) + return res + + def dataprocessing_start(self, s3_input_uri: str, s3_output_uri: str, **kwargs) -> requests.Response: + data = { + 'inputDataS3Location': s3_input_uri, + 'processedDataS3Location': s3_output_uri, + } + + for k, v in kwargs.items(): + data[k] = v + + url = f'{self._http_protocol}://{self.host}:{self.port}/ml/dataprocessing' + req = self._prepare_request('POST', url, data=json.dumps(data), headers={'content-type': 'application/json'}) + res = self._http_session.send(req) + return res + + def dataprocessing_job_status(self, job_id: str, neptune_iam_role_arn: str = '') -> requests.Response: + url = f'{self._http_protocol}://{self.host}:{self.port}/ml/dataprocessing/{job_id}' + data = {} + if neptune_iam_role_arn != '': + data['neptuneIamRoleArn'] = neptune_iam_role_arn + req = self._prepare_request('GET', url, params=data) + res = self._http_session.send(req) + return res + + def dataprocessing_status(self, max_items: int = 10, neptune_iam_role_arn: str = '') -> requests.Response: + url = f'{self._http_protocol}://{self.host}:{self.port}/ml/dataprocessing' + data = { + 'maxItems': max_items + } + + if neptune_iam_role_arn != '': + data['neptuneIamRoleArn'] = neptune_iam_role_arn + req = self._prepare_request('GET', url, params=data) + res = self._http_session.send(req) + return res + + def dataprocessing_stop(self, job_id: str, clean=False, neptune_iam_role_arn: str = '') -> requests.Response: + url = f'{self._http_protocol}://{self.host}:{self.port}/ml/dataprocessing/{job_id}' + data = { + 'clean': clean + } + if neptune_iam_role_arn != '': + data['neptuneIamRoleArn'] = neptune_iam_role_arn + + req = self._prepare_request('DELETE', url, params=data) + res = self._http_session.send(req) + return res + + def modeltraining_start(self, data_processing_job_id: str, train_model_s3_location: str, + **kwargs) -> requests.Response: + """ + for a full list of supported parameters, see: + https://docs.aws.amazon.com/neptune/latest/userguide/machine-learning-api-modeltraining.html + """ + data = { + 'dataProcessingJobId': data_processing_job_id, + 'trainModelS3Location': train_model_s3_location + } + + for k, v in kwargs.items(): + data[k] = v + + url = f'{self._http_protocol}://{self.host}:{self.port}/ml/modeltraining' + req = self._prepare_request('POST', url, data=json.dumps(data), headers={'content-type': 'application/json'}) + res = self._http_session.send(req) + return res + + def modeltraining_status(self, max_items: int = 10, neptune_iam_role_arn: str = '') -> requests.Response: + data = { + 'maxItems': max_items + } + + if neptune_iam_role_arn != '': + data['neptuneIamRoleArn'] = neptune_iam_role_arn + + url = f'{self._http_protocol}://{self.host}:{self.port}/ml/modeltraining' + req = self._prepare_request('GET', url, params=data) + res = self._http_session.send(req) + return res + + def modeltraining_job_status(self, training_job_id: str, neptune_iam_role_arn: str = '') -> requests.Response: + data = {} if neptune_iam_role_arn == '' else {'neptuneIamRoleArn': neptune_iam_role_arn} + url = f'{self._http_protocol}://{self.host}:{self.port}/ml/modeltraining/{training_job_id}' + req = self._prepare_request('GET', url, params=data) + res = self._http_session.send(req) + return res + + def modeltraining_stop(self, training_job_id: str, neptune_iam_role_arn: str = '', + clean: bool = False) -> requests.Response: + data = { + 'clean': "TRUE" if clean else "FALSE", + } + + if neptune_iam_role_arn != '': + data['neptuneIamRoleArn'] = neptune_iam_role_arn + + url = f'{self._http_protocol}://{self.host}:{self.port}/ml/modeltraining/{training_job_id}' + req = self._prepare_request('DELETE', url, params=data) + res = self._http_session.send(req) + return res + + def endpoints_create(self, training_job_id: str, **kwargs) -> requests.Response: + data = { + 'mlModelTrainingJobId': training_job_id + } + + for k, v in kwargs.items(): + data[k] = v + + url = f'{self._http_protocol}://{self.host}:{self.port}/ml/endpoints' + req = self._prepare_request('POST', url, data=json.dumps(data), headers={'content-type': 'application/json'}) + res = self._http_session.send(req) + return res + + def endpoints_status(self, endpoint_id: str, neptune_iam_role_arn: str = '') -> requests.Response: + data = {} if neptune_iam_role_arn == '' else {'neptuneIamRoleArn': neptune_iam_role_arn} + url = f'{self._http_protocol}://{self.host}:{self.port}/ml/endpoints/{endpoint_id}' + req = self._prepare_request('GET', url, params=data) + res = self._http_session.send(req) + return res + + def endpoints_delete(self, endpoint_id: str, neptune_iam_role_arn: str = '') -> requests.Response: + data = {} if neptune_iam_role_arn == '' else {'neptuneIamRoleArn': neptune_iam_role_arn} + url = f'{self._http_protocol}://{self.host}:{self.port}/ml/endpoints/{endpoint_id}' + req = self._prepare_request('DELETE', url, params=data) + res = self._http_session.send(req) + return res + + def endpoints(self, max_items: int = 10, neptune_iam_role_arn: str = ''): + data = { + 'maxItems': max_items + } + if neptune_iam_role_arn != '': + data['neptuneIamRoleArn'] = neptune_iam_role_arn + + url = f'{self._http_protocol}://{self.host}:{self.port}/ml/endpoints' + req = self._prepare_request('GET', url, params=data) + res = self._http_session.send(req) + return res + + def export(self, host: str, params: dict, ssl: bool = True) -> requests.Response: + protocol = 'https' if ssl else 'http' + url = f'{protocol}://{host}/{EXPORT_ACTION}' + req = self._prepare_request('POST', url, data=json.dumps(params), service="execute-api") + res = self._http_session.send(req) + return res + + def export_status(self, host, job_id, ssl: bool = True) -> requests.Response: + protocol = 'https' if ssl else 'http' + url = f'{protocol}://{host}/{EXPORT_ACTION}/{job_id}' + req = self._prepare_request('GET', url, service="execute-api") + res = self._http_session.send(req) + return res + + def _query_status(self, language: str, *, query_id: str = '', **kwargs) -> requests.Response: + data = {} + if query_id != '': + data['queryId'] = query_id + + for k, v in kwargs.items(): + data[k] = v + + headers = { + 'Content-Type': 'application/x-www-form-urlencoded' + } + url = f'{self._http_protocol}://{self.host}:{self.port}/{language}/status' + req = self._prepare_request('POST', url, data=data, headers=headers) + res = self._http_session.send(req) + return res + + def _prepare_request(self, method, url, *, data=None, params=None, headers=None, service=NEPTUNE_SERVICE_NAME): + self._ensure_http_session() + request = requests.Request(method=method, url=url, data=data, params=params, headers=headers, auth=self._auth) + if self._session is not None: + credentials = self._session.get_credentials() + frozen_creds = credentials.get_frozen_credentials() + + req = AWSRequest(method=method, url=url, data=data, params=params, headers=headers) + SigV4Auth(frozen_creds, service, self.region).add_auth(req) + prepared_iam_req = req.prepare() + request.headers = dict(prepared_iam_req.headers) + + return request.prepare() + + def _ensure_http_session(self): + if not self._http_session: + self._http_session = requests.Session() + + def set_session(self, session: Session): + self._session = session + + def close(self): + if self._http_session: + self._http_session.close() + self._http_session = None + + @property + def iam_enabled(self): + return type(self._session) is botocore.session.Session + + +class ClientBuilder(object): + def __init__(self, args: dict = None): + if args is None: + args = {} + self.args = args + + def with_host(self, host: str): + self.args['host'] = host + return ClientBuilder(self.args) + + def with_port(self, port: int): + self.args['port'] = port + return ClientBuilder(self.args) + + def with_sparql_path(self, path: str): + self.args['sparql_path'] = path + return ClientBuilder(self.args) + + def with_tls(self, tls: bool): + self.args['ssl'] = tls + return ClientBuilder(self.args) + + def with_region(self, region: str): + self.args['region'] = region + return ClientBuilder(self.args) + + def with_iam(self, session: Session): + self.args['session'] = session + return ClientBuilder(self.args) + + def build(self) -> Client: + return Client(**self.args) diff --git a/src/graph_notebook/ml/__init__.py b/src/graph_notebook/neptune/gremlin/__init__.py similarity index 100% rename from src/graph_notebook/ml/__init__.py rename to src/graph_notebook/neptune/gremlin/__init__.py diff --git a/src/graph_notebook/gremlin/client_provider/graphsonV3d0_MapType_objectify_patch.py b/src/graph_notebook/neptune/gremlin/graphsonV3d0_MapType_objectify_patch.py similarity index 94% rename from src/graph_notebook/gremlin/client_provider/graphsonV3d0_MapType_objectify_patch.py rename to src/graph_notebook/neptune/gremlin/graphsonV3d0_MapType_objectify_patch.py index ab0b6896..a28b85a7 100644 --- a/src/graph_notebook/gremlin/client_provider/graphsonV3d0_MapType_objectify_patch.py +++ b/src/graph_notebook/neptune/gremlin/graphsonV3d0_MapType_objectify_patch.py @@ -4,7 +4,8 @@ """ from gremlin_python.structure.io.graphsonV3d0 import MapType -from graph_notebook.gremlin.client_provider.hashable_dict_patch import HashableDict +from graph_notebook.neptune.gremlin.hashable_dict_patch import HashableDict + # Original code from Tinkerpop 3.4.1 # diff --git a/src/graph_notebook/gremlin/client_provider/hashable_dict_patch.py b/src/graph_notebook/neptune/gremlin/hashable_dict_patch.py similarity index 100% rename from src/graph_notebook/gremlin/client_provider/hashable_dict_patch.py rename to src/graph_notebook/neptune/gremlin/hashable_dict_patch.py diff --git a/src/graph_notebook/notebooks/03-Sample-Applications/00-Sample-Applications-Overview.ipynb b/src/graph_notebook/notebooks/03-Sample-Applications/00-Sample-Applications-Overview.ipynb index efbd443e..208830e1 100644 --- a/src/graph_notebook/notebooks/03-Sample-Applications/00-Sample-Applications-Overview.ipynb +++ b/src/graph_notebook/notebooks/03-Sample-Applications/00-Sample-Applications-Overview.ipynb @@ -56,4 +56,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/src/graph_notebook/notebooks/04-Machine-Learning/neptune_ml_utils.py b/src/graph_notebook/notebooks/04-Machine-Learning/neptune_ml_utils.py index 7517ab6d..5541ca68 100644 --- a/src/graph_notebook/notebooks/04-Machine-Learning/neptune_ml_utils.py +++ b/src/graph_notebook/notebooks/04-Machine-Learning/neptune_ml_utils.py @@ -15,6 +15,8 @@ # How often to check the status UPDATE_DELAY_SECONDS = 15 +HOME_DIRECTORY = os.path.expanduser("~") + def signed_request(method, url, data=None, params=None, headers=None, service=None): creds = boto3.Session().get_credentials().get_frozen_credentials() @@ -24,7 +26,7 @@ def signed_request(method, url, data=None, params=None, headers=None, service=No def load_configuration(): - with open('/home/ec2-user/graph_notebook_config.json') as f: + with open(f'{HOME_DIRECTORY}/graph_notebook_config.json') as f: data = json.load(f) host = data['host'] port = data['port'] @@ -34,62 +36,69 @@ def load_configuration(): iam = False return host, port, iam + def get_host(): host, port, iam = load_configuration() return host + def get_iam(): host, port, iam = load_configuration() return iam + def get_training_job_name(prefix: str): return f'{prefix}-{int(time.time())}' + def check_ml_enabled(): - host, port, use_iam = load_configuration() + host, port, use_iam = load_configuration() response = signed_request("GET", url=f'https://{host}:{port}/ml/modeltraining', service='neptune-db') if response.status_code != 200: print('''This Neptune cluster \033[1mis not\033[0m configured to use Neptune ML. Please configure the cluster according to the Amazpnm Neptune ML documentation before proceeding.''') else: print("This Neptune cluster is configured to use Neptune ML") - + + def get_export_service_host(): - with open('/home/ec2-user/.bashrc') as f: + with open(f'{HOME_DIRECTORY}/.bashrc') as f: data = f.readlines() for d in data: if str.startswith(d, 'export NEPTUNE_EXPORT_API_URI'): parts = d.split('=') - if len(parts)==2: - path=urlparse(parts[1].rstrip()) + if len(parts) == 2: + path = urlparse(parts[1].rstrip()) return path.hostname + "/v1" - logging.error("Unable to determine the Neptune Export Service Endpoint. You will need to enter this assign this manually.") + logging.error( + "Unable to determine the Neptune Export Service Endpoint. You will need to enter this or assign it manually.") return None + def delete_pretrained_data(setup_node_classification: bool, setup_node_regression: bool, setup_link_prediction: bool): - host, port, use_iam = load_configuration() if setup_node_classification: response = signed_request("POST", service='neptune-db', - url=f'https://{host}:{port}/gremlin', - headers={'content-type': 'application/json'}, - data=json.dumps({'gremlin': "g.V('movie_1', 'movie_7', 'movie_15').properties('genre').drop()"})) - + url=f'https://{host}:{port}/gremlin', + headers={'content-type': 'application/json'}, + data=json.dumps( + {'gremlin': "g.V('movie_1', 'movie_7', 'movie_15').properties('genre').drop()"})) + if response.status_code != 200: print(response.content.decode('utf-8')) if setup_node_regression: response = signed_request("POST", service='neptune-db', - url=f'https://{host}:{port}/gremlin', - headers={'content-type': 'application/json'}, - data=json.dumps({'gremlin': "g.V('user_1').out('wrote').properties('score').drop()"})) + url=f'https://{host}:{port}/gremlin', + headers={'content-type': 'application/json'}, + data=json.dumps({'gremlin': "g.V('user_1').out('wrote').properties('score').drop()"})) if response.status_code != 200: print(response.content.decode('utf-8')) if setup_link_prediction: response = signed_request("POST", service='neptune-db', - url=f'https://{host}:{port}/gremlin', - headers={'content-type': 'application/json'}, - data=json.dumps({'gremlin': "g.V('user_1').outE('rated').drop()"})) + url=f'https://{host}:{port}/gremlin', + headers={'content-type': 'application/json'}, + data=json.dumps({'gremlin': "g.V('user_1').outE('rated').drop()"})) if response.status_code != 200: print(response.content.decode('utf-8')) @@ -114,7 +123,8 @@ def delete_endpoint(training_job_name: str, neptune_iam_role_arn=None): query_string = f'?neptuneIamRoleArn={neptune_iam_role_arn}' host, port, use_iam = load_configuration() response = signed_request("DELETE", service='neptune-db', - url=f'https://{host}:{port}/ml/endpoints/{training_job_name}{query_string}', headers={'content-type': 'application/json'}) + url=f'https://{host}:{port}/ml/endpoints/{training_job_name}{query_string}', + headers={'content-type': 'application/json'}) if response.status_code != 200: print(response.content.decode('utf-8')) else: @@ -129,28 +139,28 @@ def prepare_movielens_data(s3_bucket_uri: str): logging.error(e) - def setup_pretrained_endpoints(s3_bucket_uri: str, setup_node_classification: bool, setup_node_regression: bool, setup_link_prediction: bool): delete_pretrained_data(setup_node_classification, setup_node_regression, setup_link_prediction) try: - return PretrainedModels().setup_pretrained_endpoints(s3_bucket_uri, setup_node_classification, setup_node_regression, setup_link_prediction) + return PretrainedModels().setup_pretrained_endpoints(s3_bucket_uri, setup_node_classification, + setup_node_regression, setup_link_prediction) except Exception as e: logging.error(e) class MovieLensProcessor: - raw_directory = r'/home/ec2-user/data/raw' - formatted_directory = r'/home/ec2-user/data/formatted' + raw_directory = fr'{HOME_DIRECTORY}/data/raw' + formatted_directory = fr'{HOME_DIRECTORY}/data/formatted' def __download_and_unzip(self): - if not os.path.exists('/home/ec2-user/data'): - os.makedirs('/home/ec2-user/data') - if not os.path.exists('/home/ec2-user/data/raw'): - os.makedirs('/home/ec2-user/data/raw') - if not os.path.exists('/home/ec2-user/data/formatted'): - os.makedirs('/home/ec2-user/data/formatted') + if not os.path.exists(f'{HOME_DIRECTORY}/data'): + os.makedirs(f'{HOME_DIRECTORY}/data') + if not os.path.exists(f'{HOME_DIRECTORY}/data/raw'): + os.makedirs(f'{HOME_DIRECTORY}/data/raw') + if not os.path.exists(f'{HOME_DIRECTORY}/data/formatted'): + os.makedirs(f'{HOME_DIRECTORY}/data/formatted') # Download the MovieLens dataset url = 'http://files.grouplens.org/datasets/movielens/ml-100k.zip' r = requests.get(url, allow_redirects=True) @@ -163,7 +173,7 @@ def __process_movies_genres(self): # process the movies_vertex.csv print('Processing Movies', end='\r') movies_df = pd.read_csv(os.path.join( - self.raw_directory, 'ml-100k/u.item'), sep='|', encoding='ISO-8859-1', + self.raw_directory, 'ml-100k/u.item'), sep='|', encoding='ISO-8859-1', names=['~id', 'title', 'release_date', 'video_release_date', 'imdb_url', 'unknown', 'Action', 'Adventure', 'Animation', 'Childrens', 'Comedy', 'Crime', 'Documentary', 'Drama', 'Fantasy', 'Film-Noir', 'Horror', 'Musical', @@ -218,7 +228,7 @@ def __process_ratings_users(self): # Create ratings vertices and add edges on both sides print('Processing Ratings', end='\r') ratings_vertices = pd.read_csv(os.path.join( - self.raw_directory, 'ml-100k/u.data'), sep='\t', encoding='ISO-8859-1', + self.raw_directory, 'ml-100k/u.data'), sep='\t', encoding='ISO-8859-1', names=['~from', '~to', 'score:Int', 'timestamp']) ratings_vertices['~from'] = ratings_vertices['~from'].apply( lambda x: f'user_{x}') @@ -232,10 +242,10 @@ def __process_ratings_users(self): dict = {} for index, row in ratings_vertices.iterrows(): - dict[index*2] = {'~id': uuid.uuid4(), '~label': 'wrote', - '~from': row['~from'], '~to': row['~id']} - dict[index*2 + 1] = {'~id': uuid.uuid4(), '~label': 'about', - '~from': row['~id'], '~to': row['~to']} + dict[index * 2] = {'~id': uuid.uuid4(), '~label': 'wrote', + '~from': row['~from'], '~to': row['~id']} + dict[index * 2 + 1] = {'~id': uuid.uuid4(), '~label': 'about', + '~from': row['~id'], '~to': row['~to']} rating_edges_df = pd.DataFrame.from_dict(dict, "index") # Remove the from and to columns and write this out as a vertex now @@ -259,7 +269,7 @@ def __process_users(self): # User Vertices - Load, rename column with type, and save user_df = pd.read_csv(os.path.join( - self.raw_directory, 'ml-100k/u.user'), sep='|', encoding='ISO-8859-1', + self.raw_directory, 'ml-100k/u.user'), sep='|', encoding='ISO-8859-1', names=['~id', 'age:Int', 'gender', 'occupation', 'zip_code']) user_df['~id'] = user_df['~id'].apply( lambda x: f'user_{x}') @@ -372,12 +382,12 @@ def __create_model(self, name: str, model_s3_location: str): return name def __get_neptune_ml_role(self): - with open('/home/ec2-user/.bashrc') as f: + with open(f'{HOME_DIRECTORY}/.bashrc') as f: data = f.readlines() for d in data: if str.startswith(d, 'export NEPTUNE_ML_ROLE_ARN'): parts = d.split('=') - if len(parts)==2: + if len(parts) == 2: return parts[1].rstrip() logging.error("Unable to determine the Neptune ML IAM Role.") return None diff --git a/src/graph_notebook/request_param_generator/call_and_get_response.py b/src/graph_notebook/request_param_generator/call_and_get_response.py deleted file mode 100644 index f0bc3e84..00000000 --- a/src/graph_notebook/request_param_generator/call_and_get_response.py +++ /dev/null @@ -1,32 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -import requests - - -def call_and_get_response(method: str, action: str, host: str, port: str, request_param_generator, use_ssl: bool, - query='', extra_headers=None): - if extra_headers is None: - extra_headers = {} - - method = method.upper() - protocol = 'https' if use_ssl else 'http' - - request_params = request_param_generator.generate_request_params(method=method, action=action, query=query, - host=host, port=port, protocol=protocol, - headers=extra_headers) - headers = request_params['headers'] if request_params['headers'] is not None else {} - - if method == 'GET': - res = requests.get(url=request_params['url'], params=request_params['params'], headers=headers) - elif method == 'DELETE': - res = requests.delete(url=request_params['url'], params=request_params['params'], headers=headers) - elif method == 'POST': - res = requests.post(url=request_params['url'], data=request_params['params'], headers=headers) - else: - raise NotImplementedError(f'Use of method {method} has not been implemented in call_and_get_response') - - res.raise_for_status() - return res diff --git a/src/graph_notebook/request_param_generator/default_request_generator.py b/src/graph_notebook/request_param_generator/default_request_generator.py deleted file mode 100644 index 0fb2b8cf..00000000 --- a/src/graph_notebook/request_param_generator/default_request_generator.py +++ /dev/null @@ -1,18 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - - -class DefaultRequestGenerator(object): - @staticmethod - def generate_request_params(method, action, query, host, port, protocol, headers=None): - url = f'{protocol}://{host}:{port}/{action}' if port != '' else f'{protocol}://{host}/{action}' - params = { - 'method': method, - 'url': url, - 'headers': headers, - 'params': query, - } - - return params diff --git a/src/graph_notebook/request_param_generator/factory.py b/src/graph_notebook/request_param_generator/factory.py deleted file mode 100644 index 2ad79dae..00000000 --- a/src/graph_notebook/request_param_generator/factory.py +++ /dev/null @@ -1,23 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -from graph_notebook.configuration.generate_config import AuthModeEnum -from graph_notebook.request_param_generator.default_request_generator import DefaultRequestGenerator -from graph_notebook.request_param_generator.iam_request_generator import IamRequestGenerator -from graph_notebook.request_param_generator.sparql_request_generator import SPARQLRequestGenerator -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider, credentials_provider_factory - - -def create_request_generator(mode: AuthModeEnum, - credentials_provider_mode: IAMAuthCredentialsProvider = IAMAuthCredentialsProvider.ROLE, - command: str = ''): - - if mode == AuthModeEnum.DEFAULT and command == 'sparql': - return SPARQLRequestGenerator() - elif mode == AuthModeEnum.IAM: - credentials_provider_mode = credentials_provider_factory(credentials_provider_mode) - return IamRequestGenerator(credentials_provider_mode) - else: - return DefaultRequestGenerator() diff --git a/src/graph_notebook/request_param_generator/iam_request_generator.py b/src/graph_notebook/request_param_generator/iam_request_generator.py deleted file mode 100644 index fc88e809..00000000 --- a/src/graph_notebook/request_param_generator/iam_request_generator.py +++ /dev/null @@ -1,21 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -from graph_notebook.authentication.iam_headers import make_signed_request - - -class IamRequestGenerator(object): - def __init__(self, credentials_provider): - self.credentials_provider = credentials_provider - - def generate_request_params(self, method, action, query, host, port, protocol, headers=None): - credentials = self.credentials_provider.get_iam_credentials() - if protocol in ['https', 'wss']: - use_ssl = True - else: - use_ssl = False - - return make_signed_request(method, action, query, host, port, credentials.key, credentials.secret, - credentials.region, use_ssl, credentials.token, additional_headers=headers) diff --git a/src/graph_notebook/request_param_generator/sparql_request_generator.py b/src/graph_notebook/request_param_generator/sparql_request_generator.py deleted file mode 100644 index f63ae479..00000000 --- a/src/graph_notebook/request_param_generator/sparql_request_generator.py +++ /dev/null @@ -1,22 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - - -class SPARQLRequestGenerator(object): - @staticmethod - def generate_request_params(method, action, query, host, port, protocol, headers=None): - if headers is None: - headers = {} - - if 'Content-Type' not in headers: - headers['Content-Type'] = "application/x-www-form-urlencoded" - - url = f'{protocol}://{host}:{port}/{action}' - return { - 'method': method, - 'url': url, - 'headers': headers, - 'params': query, - } diff --git a/src/graph_notebook/sparql/query.py b/src/graph_notebook/sparql/query.py deleted file mode 100644 index 6d104600..00000000 --- a/src/graph_notebook/sparql/query.py +++ /dev/null @@ -1,82 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -import logging - -from SPARQLWrapper import SPARQLWrapper -from graph_notebook.request_param_generator.call_and_get_response import call_and_get_response - -logging.basicConfig() -logger = logging.getLogger("sparql") - -ACTION_TO_QUERY_TYPE = { - 'sparql': 'application/sparql-query', - 'sparqlupdate': 'application/sparql-update' -} - -SPARQL_ACTION = 'sparql' - - -def get_query_type(query): - s = SPARQLWrapper('') - s.setQuery(query) - return s.queryType - - -def query_type_to_action(query_type): - query_type = query_type.upper() - if query_type in ['SELECT', 'CONSTRUCT', 'ASK', 'DESCRIBE']: - return 'sparql' - else: - # TODO: check explicitly for all query types, raise exception for invalid query - return 'sparqlupdate' - - -def do_sparql_query(query, host, port, use_ssl, request_param_generator, extra_headers=None, path: str = SPARQL_ACTION): - path = SPARQL_ACTION if path == '' else path - - if extra_headers is None: - extra_headers = {} - logger.debug(f'query={query}, endpoint={host}, port={port}') - query_type = get_query_type(query) - action = query_type_to_action(query_type) - - data = {} - if action == 'sparql': - data['query'] = query - elif action == 'sparqlupdate': - data['update'] = query - - res = call_and_get_response('post', path, host, port, request_param_generator, use_ssl, data, extra_headers) - try: - content = res.json() # attempt to return json, otherwise we will return the content string. - except Exception: - content = res.content.decode('utf-8') - return content - - -def do_sparql_explain(query: str, host: str, port: str, use_ssl: bool, request_param_generator, - accept_type='text/html', path: str = ''): - path = SPARQL_ACTION if path == '' else path - - query_type = get_query_type(query) - action = query_type_to_action(query_type) - - data = { - 'explain': 'dynamic', - } - - if action == 'sparql': - data['query'] = query - elif action == 'sparqlupdate': - data['update'] = query - - extra_headers = { - 'Accept': accept_type - } - - res = call_and_get_response('post', path, host, port, request_param_generator, use_ssl, data, - extra_headers) - return res.content.decode('utf-8') diff --git a/src/graph_notebook/sparql/status.py b/src/graph_notebook/sparql/status.py deleted file mode 100644 index 737d61b7..00000000 --- a/src/graph_notebook/sparql/status.py +++ /dev/null @@ -1,49 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -from graph_notebook.request_param_generator.call_and_get_response import call_and_get_response - -SPARQL_STATUS_ACTION = 'sparql/status' - - -def do_sparql_status(host, port, use_ssl, request_param_generator, query_id=None): - data = {} - if query_id != '' and query_id is not None: - data['queryId'] = query_id - - headers = { - 'Content-Type': 'application/x-www-form-urlencoded' - } - res = call_and_get_response('post', SPARQL_STATUS_ACTION, host, port, request_param_generator, use_ssl, data, - headers) - try: - content = res.json() # attempt to return json, otherwise we will return the content string. - except Exception: - """When a invalid UUID is supplied, status servlet returns an empty string. - See https://sim.amazon.com/issues/NEPTUNE-16137 - """ - content = 'UUID is invalid.' - return content - - -def do_sparql_cancel(host, port, use_ssl, request_param_generator, query_id, silent=False): - if type(query_id) is not str or query_id == '': - raise ValueError("query id must be a non-empty string") - - data = {'cancelQuery': True, 'queryId': query_id, 'silent': silent} - - headers = { - 'Content-Type': 'application/x-www-form-urlencoded' - } - res = call_and_get_response('post', SPARQL_STATUS_ACTION, host, port, request_param_generator, use_ssl, data, - headers) - try: - content = res.json() - except Exception: - """When a invalid UUID is supplied, status servlet returns an empty string. - See https://sim.amazon.com/issues/NEPTUNE-16137 - """ - content = 'UUID is invalid.' - return content diff --git a/src/graph_notebook/status/get_status.py b/src/graph_notebook/status/get_status.py deleted file mode 100644 index bd86128c..00000000 --- a/src/graph_notebook/status/get_status.py +++ /dev/null @@ -1,17 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" -from json import JSONDecodeError - -from graph_notebook.request_param_generator.call_and_get_response import call_and_get_response -from graph_notebook.request_param_generator.default_request_generator import DefaultRequestGenerator - - -def get_status(host, port, use_ssl, request_param_generator=DefaultRequestGenerator()): - res = call_and_get_response('get', 'status', host, port, request_param_generator, use_ssl) - try: - js = res.json() - except JSONDecodeError: - js = res.content - return js diff --git a/src/graph_notebook/system/database_reset.py b/src/graph_notebook/system/database_reset.py deleted file mode 100644 index 7e6fb46e..00000000 --- a/src/graph_notebook/system/database_reset.py +++ /dev/null @@ -1,21 +0,0 @@ -from graph_notebook.request_param_generator.call_and_get_response import call_and_get_response -from graph_notebook.request_param_generator.default_request_generator import DefaultRequestGenerator - -SYSTEM_ACTION = 'system' - - -def initiate_database_reset(host, port, use_ssl, request_param_generator=DefaultRequestGenerator()): - data = { - 'action': 'initiateDatabaseReset' - } - res = call_and_get_response('post', SYSTEM_ACTION, host, port, request_param_generator, use_ssl, data) - return res.json() - - -def perform_database_reset(token, host, port, use_ssl, request_param_generator=DefaultRequestGenerator()): - data = { - 'action': 'performDatabaseReset', - 'token': token - } - res = call_and_get_response('post', SYSTEM_ACTION, host, port, request_param_generator, use_ssl, data) - return res.json() diff --git a/src/graph_notebook/sparql/table.py b/src/graph_notebook/visualization/sparql_rows_and_columns.py similarity index 89% rename from src/graph_notebook/sparql/table.py rename to src/graph_notebook/visualization/sparql_rows_and_columns.py index 0029a7ad..14ac1548 100644 --- a/src/graph_notebook/sparql/table.py +++ b/src/graph_notebook/visualization/sparql_rows_and_columns.py @@ -3,11 +3,13 @@ SPDX-License-Identifier: Apache-2.0 """ + def get_rows_and_columns(sparql_results): if type(sparql_results) is not dict: return None - if 'head' in sparql_results and 'vars' in sparql_results['head'] and 'results' in sparql_results and 'bindings' in sparql_results['results']: + if 'head' in sparql_results and 'vars' in sparql_results['head'] and 'results' in sparql_results and 'bindings' in \ + sparql_results['results']: columns = [] for v in sparql_results['head']['vars']: columns.append(v) diff --git a/test/integration/DataDrivenGremlinTest.py b/test/integration/DataDrivenGremlinTest.py index 2a794091..3ada56e5 100644 --- a/test/integration/DataDrivenGremlinTest.py +++ b/test/integration/DataDrivenGremlinTest.py @@ -5,9 +5,7 @@ import logging -from graph_notebook.gremlin.client_provider.factory import create_client_provider from graph_notebook.seed.load_query import get_queries -from graph_notebook.gremlin.query import do_gremlin_query from test.integration import IntegrationTest @@ -16,9 +14,9 @@ class DataDrivenGremlinTest(IntegrationTest): def setUp(self): super().setUp() - self.client_provider = create_client_provider(self.auth_mode, self.iam_credentials_provider_type) + self.client = self.client_builder.build() query_check_for_airports = "g.V('3684').outE().inV().has(id, '3444')" - res = do_gremlin_query(query_check_for_airports, self.host, self.port, self.ssl, self.client_provider) + res = self.client.gremlin_query(query_check_for_airports) if len(res) < 1: logging.info('did not find final airports edge, seeding database now...') airport_queries = get_queries('gremlin', 'airports') @@ -30,7 +28,7 @@ def setUp(self): # we are deciding to try except because we do not know if the database # we are connecting to has a partially complete set of airports data or not. try: - do_gremlin_query(line, self.host, self.port, self.ssl, self.client_provider) + self.client.gremlin_query(line) except Exception as e: logging.error(f'query {q} failed due to {e}') continue diff --git a/test/integration/DataDrivenSparqlTest.py b/test/integration/DataDrivenSparqlTest.py index 5ce77861..a013c01e 100644 --- a/test/integration/DataDrivenSparqlTest.py +++ b/test/integration/DataDrivenSparqlTest.py @@ -5,11 +5,7 @@ import logging -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider from graph_notebook.seed.load_query import get_queries -from graph_notebook.request_param_generator.factory import create_request_generator -from graph_notebook.sparql.query import do_sparql_query - from test.integration import IntegrationTest logger = logging.getLogger('DataDrivenSparqlTest') @@ -17,15 +13,14 @@ class DataDrivenSparqlTest(IntegrationTest): - @classmethod - def setUpClass(cls): - super().setUpClass() - cls.request_generator = create_request_generator(cls.auth_mode, IAMAuthCredentialsProvider.ENV) + def setUp(self) -> None: + super().setUp() airport_queries = get_queries('sparql', 'epl') for q in airport_queries: try: # we are deciding to try except because we do not know if the database we are connecting to has a partially complete set of airports data or not. - do_sparql_query(q['content'], cls.host, cls.port, cls.ssl, cls.request_generator) + res = self.client.sparql(q['content'].strip()) + print(res) except Exception as e: logger.error(f'query {q["content"]} failed due to {e}') continue diff --git a/test/integration/notebook/GraphNotebookIntegrationTest.py b/test/integration/GraphNotebookIntegrationTest.py similarity index 100% rename from test/integration/notebook/GraphNotebookIntegrationTest.py rename to test/integration/GraphNotebookIntegrationTest.py diff --git a/test/integration/IntegrationTest.py b/test/integration/IntegrationTest.py index 49a87ad9..058c3bc7 100644 --- a/test/integration/IntegrationTest.py +++ b/test/integration/IntegrationTest.py @@ -5,19 +5,34 @@ import unittest +from botocore.session import get_session + +from graph_notebook.configuration.generate_config import Configuration, AuthModeEnum from graph_notebook.configuration.get_config import get_config +from graph_notebook.neptune.client import ClientBuilder from test.integration.NeptuneIntegrationWorkflowSteps import TEST_CONFIG_PATH +def setup_client_builder(config: Configuration) -> ClientBuilder: + builder = ClientBuilder() \ + .with_host(config.host) \ + .with_port(config.port) \ + .with_region(config.aws_region) \ + .with_tls(config.ssl) \ + .with_sparql_path(config.sparql.path) + + if config.auth_mode == AuthModeEnum.IAM: + builder = builder.with_iam(get_session()) + + return builder + + class IntegrationTest(unittest.TestCase): @classmethod def setUpClass(cls): super().setUpClass() - config = get_config(TEST_CONFIG_PATH) - cls.config = config - cls.host = config.host - cls.port = config.port - cls.auth_mode = config.auth_mode - cls.ssl = config.ssl - cls.iam_credentials_provider_type = config.iam_credentials_provider_type - cls.load_from_s3_arn = config.load_from_s3_arn + cls.config = get_config(TEST_CONFIG_PATH) + cls.client_builder = setup_client_builder(cls.config) + + def setUp(self) -> None: + self.client = self.client_builder.build() diff --git a/test/integration/NeptuneIntegrationWorkflowSteps.py b/test/integration/NeptuneIntegrationWorkflowSteps.py index 5149330a..36f4ce78 100644 --- a/test/integration/NeptuneIntegrationWorkflowSteps.py +++ b/test/integration/NeptuneIntegrationWorkflowSteps.py @@ -15,12 +15,12 @@ import boto3 as boto3 import requests -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider from graph_notebook.configuration.generate_config import AuthModeEnum, Configuration SUBPARSER_CREATE_CFN = 'create-cfn-stack' SUBPARSER_DELETE_CFN = 'delete-cfn-stack' SUBPARSER_RUN_TESTS = 'run-tests' +SUBPARSER_GENERATE_CONFIG = 'generate-config' SUBPARSER_ENABLE_IAM = 'toggle-cluster-iam' sys.path.insert(0, os.path.abspath('..')) @@ -177,8 +177,8 @@ def get_stack_details_to_run(stack: dict, region: str = 'us-east-1', timeout_min ip = network_interface['PrivateIpAddresses'][0]['Association']['PublicIp'] logging.info(f'checking if ip {ip} can be used ') + url = f'https://{ip}:80/status' try: - url = f'https://{ip}:80/status' logging.info(f'checking ip address {ip}, url={url}') # hard-coded to port 80 since that's what this CFN stack uses for its load balancer requests.get(url, verify=False, timeout=5) # an exception is thrown if the host cannot be reached. @@ -268,8 +268,7 @@ def generate_config_from_stack(stack: dict, region: str, iam: bool) -> Configura file.writelines(new_lines) auth = AuthModeEnum.IAM if iam else AuthModeEnum.DEFAULT - conf = Configuration(details['endpoint'], 80, auth, IAMAuthCredentialsProvider.ENV, details['loader_arn'], - ssl=True, aws_region=region) + conf = Configuration(details['endpoint'], 80, auth, details['loader_arn'], ssl=True, aws_region=region) logging.info(f'generated configuration for test run: {conf.to_dict()}') return conf @@ -323,15 +322,12 @@ def main(): delete_parser.add_argument('--cfn-stack-name', type=str, default='') delete_parser.add_argument('--aws-region', type=str, default='us-east-1') - # sub parser for running tests - parser_run_tests = subparsers.add_parser(SUBPARSER_RUN_TESTS, - help='run tests with the pattern *_without_iam.py') - parser_run_tests.add_argument('--pattern', type=str), - parser_run_tests.add_argument('--iam', action='store_true') - parser_run_tests.add_argument('--cfn-stack-name', type=str, default='') - parser_run_tests.add_argument('--aws-region', type=str, default='us-east-1') - parser_run_tests.add_argument('--skip-config-generation', action='store_true', - help=f'skips config generation for testing, using the one found under {TEST_CONFIG_PATH}') + # sub parser generate config + config_parser = subparsers.add_parser(SUBPARSER_GENERATE_CONFIG, + help='generate test configuration from supplied cfn stack') + config_parser.add_argument('--cfn-stack-name', type=str, default='') + config_parser.add_argument('--aws-region', type=str, default='us-east-1') + config_parser.add_argument('--iam', action='store_true') args = parser.parse_args() @@ -342,20 +338,17 @@ def main(): handle_create_cfn_stack(stack_name, args.cfn_template_url, args.cfn_s3_bucket, cfn_client, args.cfn_runner_role) elif args.which == SUBPARSER_DELETE_CFN: delete_stack(args.cfn_stack_name, cfn_client) - elif args.which == SUBPARSER_RUN_TESTS: - if not args.skip_config_generation: - loop_until_stack_is_complete(args.cfn_stack_name, cfn_client) - stack = get_cfn_stack_details(args.cfn_stack_name, cfn_client) - cluster_identifier = get_neptune_identifier_from_cfn(args.cfn_stack_name, cfn_client) - set_iam_auth_on_neptune_cluster(cluster_identifier, args.iam, neptune_client) - config = generate_config_from_stack(stack, args.aws_region, args.iam) - config.write_to_file(TEST_CONFIG_PATH) - run_integration_tests(args.pattern) elif args.which == SUBPARSER_ENABLE_IAM: cluster_identifier = get_neptune_identifier_from_cfn(args.cfn_stack_name, cfn_client) set_iam_auth_on_neptune_cluster(cluster_identifier, True, neptune_client) logging.info('waiting for one minute while change is applied...') time.sleep(60) + elif args.which == SUBPARSER_GENERATE_CONFIG: + stack = get_cfn_stack_details(args.cfn_stack_name, cfn_client) + cluster_identifier = get_neptune_identifier_from_cfn(args.cfn_stack_name, cfn_client) + set_iam_auth_on_neptune_cluster(cluster_identifier, args.iam, neptune_client) + config = generate_config_from_stack(stack, args.aws_region, args.iam) + config.write_to_file(TEST_CONFIG_PATH) if __name__ == '__main__': diff --git a/test/integration/__init__.py b/test/integration/__init__.py index d0f15c29..cf393a60 100644 --- a/test/integration/__init__.py +++ b/test/integration/__init__.py @@ -6,4 +6,5 @@ from .IntegrationTest import IntegrationTest # noqa F401 from .DataDrivenGremlinTest import DataDrivenGremlinTest # noqa F401 from .DataDrivenSparqlTest import DataDrivenSparqlTest # noqa F401 +from .GraphNotebookIntegrationTest import GraphNotebookIntegrationTest # noqa F401 from .NeptuneIntegrationWorkflowSteps import TEST_CONFIG_PATH # noqa F401 diff --git a/test/integration/gremlin/client_provider/__init__.py b/test/integration/gremlin/client_provider/__init__.py deleted file mode 100644 index 9049dd04..00000000 --- a/test/integration/gremlin/client_provider/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" diff --git a/test/integration/gremlin/client_provider/client_provider_factory.py b/test/integration/gremlin/client_provider/client_provider_factory.py deleted file mode 100644 index 5b3d6b31..00000000 --- a/test/integration/gremlin/client_provider/client_provider_factory.py +++ /dev/null @@ -1,22 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -import unittest - -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider -from graph_notebook.configuration.generate_config import AuthModeEnum -from graph_notebook.gremlin.client_provider.default_client import ClientProvider -from graph_notebook.gremlin.client_provider.factory import create_client_provider -from graph_notebook.gremlin.client_provider.iam_client import IamClientProvider - - -class TestClientProviderFactory(unittest.TestCase): - def test_create_default_client(self): - client_provider = create_client_provider(AuthModeEnum.DEFAULT) - self.assertEqual(ClientProvider, type(client_provider)) - - def test_create_iam_client_from_env(self): - client_provider = create_client_provider(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - self.assertEqual(IamClientProvider, type(client_provider)) diff --git a/test/integration/gremlin/gremlin_query_with_iam.py b/test/integration/gremlin/gremlin_query_with_iam.py deleted file mode 100644 index a3cac15e..00000000 --- a/test/integration/gremlin/gremlin_query_with_iam.py +++ /dev/null @@ -1,37 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider -from graph_notebook.configuration.generate_config import AuthModeEnum -from graph_notebook.gremlin.query import do_gremlin_query, do_gremlin_explain, do_gremlin_profile -from graph_notebook.gremlin.client_provider.factory import create_client_provider -from graph_notebook.request_param_generator.factory import create_request_generator - -from test.integration import IntegrationTest - - -class TestGremlinWithIam(IntegrationTest): - def test_do_gremlin_query_with_iam(self): - client_provider = create_client_provider(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - query = 'g.V().limit(1)' - results = do_gremlin_query(query, self.host, self.port, self.ssl, client_provider) - - self.assertEqual(type(results), list) - - def test_do_gremlin_explain_with_iam(self): - query = 'g.V().limit(1)' - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - results = do_gremlin_explain(query, self.host, self.port, self.ssl, request_generator) - - self.assertEqual(type(results), dict) - self.assertTrue('explain' in results) - - def test_do_gremlin_profile_with_iam(self): - query = 'g.V().limit(1)' - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - results = do_gremlin_profile(query, self.host, self.port, self.ssl, request_generator) - - self.assertEqual(type(results), dict) - self.assertTrue('profile' in results) diff --git a/test/integration/gremlin/gremlin_query_without_iam.py b/test/integration/gremlin/gremlin_query_without_iam.py deleted file mode 100644 index 7f61d020..00000000 --- a/test/integration/gremlin/gremlin_query_without_iam.py +++ /dev/null @@ -1,35 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -from graph_notebook.gremlin.client_provider.default_client import ClientProvider -from graph_notebook.gremlin.query import do_gremlin_query, do_gremlin_explain, do_gremlin_profile -from graph_notebook.request_param_generator.default_request_generator import DefaultRequestGenerator - -from test.integration import IntegrationTest - - -class TestGremlin(IntegrationTest): - def test_do_gremlin_query(self): - client_provider = ClientProvider() - query = 'g.V().limit(1)' - results = do_gremlin_query(query, self.host, self.port, self.ssl, client_provider) - - self.assertEqual(type(results), list) - - def test_do_gremlin_explain(self): - query = 'g.V().limit(1)' - request_generator = DefaultRequestGenerator() - results = do_gremlin_explain(query, self.host, self.port, self.ssl, request_generator) - - self.assertEqual(type(results), dict) - self.assertTrue('explain' in results) - - def test_do_gremlin_profile(self): - query = 'g.V().limit(1)' - request_generator = DefaultRequestGenerator() - results = do_gremlin_profile(query, self.host, self.port, self.ssl, request_generator) - - self.assertEqual(type(results), dict) - self.assertTrue('profile' in results) diff --git a/test/integration/gremlin/gremlin_status_with_iam.py b/test/integration/gremlin/gremlin_status_with_iam.py deleted file mode 100644 index 0066db90..00000000 --- a/test/integration/gremlin/gremlin_status_with_iam.py +++ /dev/null @@ -1,134 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -import threading -import time -import requests -from os import cpu_count - -from graph_notebook.configuration.generate_config import AuthModeEnum -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider -from graph_notebook.gremlin.query import do_gremlin_query -from graph_notebook.gremlin.status import do_gremlin_status, do_gremlin_cancel -from graph_notebook.gremlin.client_provider.factory import create_client_provider -from graph_notebook.request_param_generator.factory import create_request_generator -from gremlin_python.driver.protocol import GremlinServerError - -from test.integration import DataDrivenGremlinTest - - -class TestGremlinStatusWithIam(DataDrivenGremlinTest): - def do_gremlin_query_save_results(self, query, res): - client_provider = create_client_provider(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - try: - res['result'] = do_gremlin_query(query, self.host, self.port, self.ssl, client_provider) - except GremlinServerError as exception: - res['error'] = str(exception) - - def test_do_gremlin_status_nonexistent(self): - with self.assertRaises(requests.HTTPError): - query_id = "some-guid-here" - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - try: - do_gremlin_status(self.host, self.port, self.ssl, AuthModeEnum.IAM, request_generator, query_id, False) - except requests.HTTPError as exception: - content = exception.response.json() - self.assertTrue('requestId' in content) - self.assertTrue('code' in content) - self.assertTrue('detailedMessage' in content) - self.assertEqual('InvalidParameterException', content['code']) - raise exception - - def test_do_gremlin_cancel_nonexistent(self): - with self.assertRaises(requests.HTTPError): - query_id = "some-guid-here" - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - try: - do_gremlin_cancel(self.host, self.port, self.ssl, AuthModeEnum.IAM, request_generator, query_id) - except requests.HTTPError as exception: - content = exception.response.json() - self.assertTrue('requestId' in content) - self.assertTrue('code' in content) - self.assertTrue('detailedMessage' in content) - self.assertEqual('InvalidParameterException', content['code']) - raise exception - - def test_do_gremlin_cancel_empty_query_id(self): - with self.assertRaises(ValueError): - query_id = '' - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - do_gremlin_cancel(self.host, self.port, self.ssl, AuthModeEnum.IAM, request_generator, query_id) - - def test_do_gremlin_cancel_non_str_query_id(self): - with self.assertRaises(ValueError): - query_id = 42 - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - do_gremlin_cancel(self.host, self.port, self.ssl, AuthModeEnum.IAM, request_generator, query_id) - - def test_do_gremlin_status_and_cancel(self): - query = "g.V().out().out().out().out()" - query_res = {} - gremlin_query_thread = threading.Thread(target=self.do_gremlin_query_save_results, args=(query, query_res,)) - gremlin_query_thread.start() - time.sleep(3) - - query_id = '' - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - status_res = do_gremlin_status(self.host, self.port, self.ssl, AuthModeEnum.IAM, - request_generator, query_id, False) - self.assertEqual(type(status_res), dict) - self.assertTrue('acceptedQueryCount' in status_res) - self.assertTrue('runningQueryCount' in status_res) - self.assertTrue(status_res['runningQueryCount'] == 1) - self.assertTrue('queries' in status_res) - - query_id = '' - for q in status_res['queries']: - if query in q['queryString']: - query_id = q['queryId'] - - self.assertNotEqual(query_id, '') - - cancel_res = do_gremlin_cancel(self.host, self.port, self.ssl, AuthModeEnum.IAM, request_generator, query_id) - self.assertEqual(type(cancel_res), dict) - self.assertTrue('status' in cancel_res) - self.assertTrue('payload' in cancel_res) - self.assertEqual('200 OK', cancel_res['status']) - - gremlin_query_thread.join() - self.assertFalse('result' in query_res) - self.assertTrue('error' in query_res) - self.assertTrue('code' in query_res['error']) - self.assertTrue('requestId' in query_res['error']) - self.assertTrue('detailedMessage' in query_res['error']) - self.assertTrue('TimeLimitExceededException' in query_res['error']) - - def test_do_gremlin_status_include_waiting(self): - query = "g.V().out().out().out().out()" - num_threads = 4 * cpu_count() - threads = [] - query_results = [] - for x in range(0, num_threads): - query_res = {} - gremlin_query_thread = threading.Thread(target=self.do_gremlin_query_save_results, args=(query, query_res,)) - threads.append(gremlin_query_thread) - query_results.append(query_res) - gremlin_query_thread.start() - - time.sleep(5) - - query_id = '' - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - status_res = do_gremlin_status(self.host, self.port, self.ssl, AuthModeEnum.IAM, - request_generator, query_id, True) - - self.assertEqual(type(status_res), dict) - self.assertTrue('acceptedQueryCount' in status_res) - self.assertTrue('runningQueryCount' in status_res) - self.assertTrue('queries' in status_res) - self.assertEqual(status_res['acceptedQueryCount'], len(status_res['queries'])) - - for gremlin_query_thread in threads: - gremlin_query_thread.join() diff --git a/test/integration/gremlin/gremlin_status_without_iam.py b/test/integration/gremlin/gremlin_status_without_iam.py deleted file mode 100644 index 03f6d7b1..00000000 --- a/test/integration/gremlin/gremlin_status_without_iam.py +++ /dev/null @@ -1,129 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -import threading -import time -import requests -from os import cpu_count - -from gremlin_python.driver.protocol import GremlinServerError - -from graph_notebook.configuration.generate_config import AuthModeEnum -from graph_notebook.gremlin.query import do_gremlin_query -from graph_notebook.gremlin.status import do_gremlin_status, do_gremlin_cancel -from graph_notebook.request_param_generator.factory import create_request_generator - -from test.integration import DataDrivenGremlinTest - - -class TestGremlinStatusWithoutIam(DataDrivenGremlinTest): - def do_gremlin_query_save_results(self, query, res): - try: - res['result'] = do_gremlin_query(query, self.host, self.port, self.ssl, self.client_provider) - except GremlinServerError as exception: - res['error'] = str(exception) - - def test_do_gremlin_status_nonexistent(self): - with self.assertRaises(requests.HTTPError): - query_id = "ac7d5a03-00cf-4280-b464-edbcbf51ffce" - request_generator = create_request_generator(AuthModeEnum.DEFAULT) - try: - do_gremlin_status(self.host, self.port, self.ssl, self.auth_mode, request_generator, query_id, False) - except requests.HTTPError as exception: - content = exception.response.json() - self.assertTrue('requestId' in content) - self.assertTrue('code' in content) - self.assertTrue('detailedMessage' in content) - self.assertEqual('InvalidParameterException', content['code']) - raise exception - - def test_do_gremlin_cancel_nonexistent(self): - with self.assertRaises(requests.HTTPError): - query_id = "ac7d5a03-00cf-4280-b464-edbcbf51ffce" - request_generator = create_request_generator(AuthModeEnum.DEFAULT) - try: - do_gremlin_cancel(self.host, self.port, self.ssl, self.auth_mode, request_generator, query_id) - except requests.HTTPError as exception: - content = exception.response.json() - self.assertTrue('requestId' in content) - self.assertTrue('code' in content) - self.assertTrue('detailedMessage' in content) - self.assertEqual('InvalidParameterException', content['code']) - raise exception - - def test_do_gremlin_cancel_empty_query_id(self): - with self.assertRaises(ValueError): - query_id = '' - request_generator = create_request_generator(AuthModeEnum.DEFAULT) - do_gremlin_cancel(self.host, self.port, self.ssl, self.auth_mode, request_generator, query_id) - - def test_do_gremlin_cancel_non_str_query_id(self): - with self.assertRaises(ValueError): - query_id = 42 - request_generator = create_request_generator(AuthModeEnum.DEFAULT) - do_gremlin_cancel(self.host, self.port, self.ssl, self.auth_mode, request_generator, query_id) - - def test_do_gremlin_status_and_cancel(self): - query = "g.V().out().out().out().out()" - query_res = {} - gremlin_query_thread = threading.Thread(target=self.do_gremlin_query_save_results, args=(query, query_res,)) - gremlin_query_thread.start() - time.sleep(3) - - query_id = '' - request_generator = create_request_generator(AuthModeEnum.DEFAULT) - status_res = do_gremlin_status(self.host, self.port, self.ssl, self.auth_mode, - request_generator, query_id, False) - self.assertEqual(type(status_res), dict) - self.assertTrue('acceptedQueryCount' in status_res) - self.assertTrue('runningQueryCount' in status_res) - self.assertTrue(status_res['runningQueryCount'] == 1) - self.assertTrue('queries' in status_res) - - query_id = '' - for q in status_res['queries']: - if query in q['queryString']: - query_id = q['queryId'] - - self.assertNotEqual(query_id, '') - - cancel_res = do_gremlin_cancel(self.host, self.port, self.ssl, self.auth_mode, request_generator, query_id) - self.assertEqual(type(cancel_res), dict) - self.assertTrue('status' in cancel_res) - self.assertTrue('payload' in cancel_res) - self.assertEqual('200 OK', cancel_res['status']) - - gremlin_query_thread.join() - self.assertFalse('result' in query_res) - self.assertTrue('error' in query_res) - self.assertTrue('code' in query_res['error']) - self.assertTrue('requestId' in query_res['error']) - self.assertTrue('detailedMessage' in query_res['error']) - self.assertTrue('TimeLimitExceededException' in query_res['error']) - - def test_do_gremlin_status_include_waiting(self): - query = "g.V().out().out().out().out()" - num_threads = 4 * cpu_count() - threads = [] - for x in range(0, num_threads): - gremlin_query_thread = threading.Thread(target=self.do_gremlin_query_save_results, args=(query, {})) - threads.append(gremlin_query_thread) - gremlin_query_thread.start() - - time.sleep(5) - - query_id = '' - request_generator = create_request_generator(AuthModeEnum.DEFAULT) - status_res = do_gremlin_status(self.host, self.port, self.ssl, self.auth_mode, - request_generator, query_id, True) - - self.assertEqual(type(status_res), dict) - self.assertTrue('acceptedQueryCount' in status_res) - self.assertTrue('runningQueryCount' in status_res) - self.assertTrue('queries' in status_res) - self.assertEqual(status_res['acceptedQueryCount'], len(status_res['queries'])) - - for gremlin_query_thread in threads: - gremlin_query_thread.join() diff --git a/src/graph_notebook/gremlin/__init__.py b/test/integration/iam/__init__.py similarity index 100% rename from src/graph_notebook/gremlin/__init__.py rename to test/integration/iam/__init__.py diff --git a/src/graph_notebook/gremlin/client_provider/__init__.py b/test/integration/iam/gremlin/__init__.py similarity index 100% rename from src/graph_notebook/gremlin/client_provider/__init__.py rename to test/integration/iam/gremlin/__init__.py diff --git a/test/integration/iam/gremlin/test_gremlin_status_with_iam.py b/test/integration/iam/gremlin/test_gremlin_status_with_iam.py new file mode 100644 index 00000000..502b7c9a --- /dev/null +++ b/test/integration/iam/gremlin/test_gremlin_status_with_iam.py @@ -0,0 +1,122 @@ +""" +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +SPDX-License-Identifier: Apache-2.0 +""" +import concurrent.futures +import threading +import time + +import pytest +from os import cpu_count + +from botocore.session import get_session +from gremlin_python.driver.protocol import GremlinServerError + +from graph_notebook.neptune.client import Client +from test.integration import DataDrivenGremlinTest + + +def long_running_gremlin_query(c: Client, query: str): + with pytest.raises(GremlinServerError): + c.gremlin_query(query) + return + + +class TestGremlinStatusWithIam(DataDrivenGremlinTest): + def setUp(self) -> None: + super().setUp() + if not self.client.iam_enabled: + self.client = self.client_builder.with_iam(get_session()).build() + + @pytest.mark.iam + @pytest.mark.neptune + def test_do_gremlin_status_nonexistent(self): + query_id = "some-guid-here" + res = self.client.gremlin_status(query_id) + assert res.status_code == 400 + js = res.json() + assert js['code'] == 'InvalidParameterException' + assert js['detailedMessage'] == f'Supplied queryId {query_id} is invalid' + + @pytest.mark.iam + @pytest.mark.neptune + def test_do_gremlin_cancel_nonexistent(self): + query_id = "some-guid-here" + res = self.client.gremlin_cancel(query_id) + assert res.status_code == 400 + js = res.json() + assert js['code'] == 'InvalidParameterException' + assert js['detailedMessage'] == f'Supplied queryId {query_id} is invalid' + + @pytest.mark.iam + @pytest.mark.neptune + def test_do_gremlin_cancel_empty_query_id(self): + with self.assertRaises(ValueError): + self.client.gremlin_cancel('') + + @pytest.mark.iam + @pytest.mark.neptune + def test_do_gremlin_cancel_non_str_query_id(self): + with self.assertRaises(ValueError): + self.client.gremlin_cancel(42) + + @pytest.mark.iam + @pytest.mark.neptune + def test_do_gremlin_status_and_cancel(self): + long_running_query = "g.V().out().out().out().out().out().out().out().out()" + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(long_running_gremlin_query, self.client, long_running_query) + + time.sleep(1) + status_res = self.client.gremlin_status() + assert status_res.status_code == 200 + + status_js = status_res.json() + query_id = '' + for q in status_js['queries']: + if q['queryString'] == long_running_query: + query_id = q['queryId'] + + assert query_id != '' + + cancel_res = self.client.gremlin_cancel(query_id) + assert cancel_res.status_code == 200 + assert cancel_res.json()['status'] == '200 OK' + + time.sleep(1) + status_after_cancel = self.client.gremlin_status(query_id) + assert status_after_cancel.status_code == 400 # check that the query is no longer valid + assert status_after_cancel.json()['code'] == 'InvalidParameterException' + + future.result() + + @pytest.mark.iam + @pytest.mark.neptune + def test_do_gremlin_status_include_waiting(self): + query = "g.V().out().out().out().out()" + num_threads = cpu_count() * 4 + threads = [] + for x in range(0, num_threads): + thread = threading.Thread(target=long_running_gremlin_query, args=(self.client, query)) + thread.start() + threads.append(thread) + + time.sleep(5) + + res = self.client.gremlin_status(include_waiting=True) + assert res.status_code == 200 + status_res = res.json() + + self.assertEqual(type(status_res), dict) + self.assertTrue('acceptedQueryCount' in status_res) + self.assertTrue('runningQueryCount' in status_res) + self.assertTrue('queries' in status_res) + self.assertEqual(status_res['acceptedQueryCount'], len(status_res['queries'])) + + for q in status_res['queries']: + # cancel all the queries we executed since they can take a very long time. + if q['queryString'] == query: + self.client.gremlin_cancel(q['queryId']) + + for t in threads: + t.join() diff --git a/test/integration/iam/gremlin/test_gremlin_with_iam.py b/test/integration/iam/gremlin/test_gremlin_with_iam.py new file mode 100644 index 00000000..375e2b7d --- /dev/null +++ b/test/integration/iam/gremlin/test_gremlin_with_iam.py @@ -0,0 +1,55 @@ +""" +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +SPDX-License-Identifier: Apache-2.0 +""" +import pytest +from botocore.session import get_session +from gremlin_python.structure.graph import Vertex + +from test.integration import IntegrationTest + + +class TestGremlinWithIam(IntegrationTest): + def setUp(self) -> None: + self.client = self.client_builder.with_iam(get_session()).build() + + @pytest.mark.iam + @pytest.mark.gremlin + def test_do_gremlin_query_with_iam(self): + query = 'g.V().limit(1)' + results = self.client.gremlin_query(query) + assert type(results) is list + for r in results: + assert type(r) is Vertex + + @pytest.mark.iam + @pytest.mark.gremlin + def test_do_gremlin_explain_with_iam(self): + query = 'g.V().limit(1)' + res = self.client.gremlin_explain(query) + assert res.status_code == 200 + results = res.content.decode('utf-8') + self.assertTrue('Explain' in results) + + @pytest.mark.iam + @pytest.mark.gremlin + def test_do_gremlin_profile_with_iam(self): + query = 'g.V().limit(1)' + res = self.client.gremlin_profile(query) + assert res.status_code == 200 + + results = res.content.decode('utf-8') + self.assertTrue('Profile' in results) + + @pytest.mark.iam + @pytest.mark.gremlin + def test_iam_gremlin_http_query(self): + query = 'g.V().limit(1)' + res = self.client.gremlin_http_query(query) + assert res.status_code == 200 + assert 'result' in res.json() + + def test_iam_gremlin_connection(self): + conn = self.client.get_gremlin_connection() + conn.submit('g.V().limit(1)') + print('here') diff --git a/src/graph_notebook/system/__init__.py b/test/integration/iam/load/__init__.py similarity index 100% rename from src/graph_notebook/system/__init__.py rename to test/integration/iam/load/__init__.py diff --git a/test/integration/iam/load/test_load_with_iam.py b/test/integration/iam/load/test_load_with_iam.py new file mode 100644 index 00000000..102cee66 --- /dev/null +++ b/test/integration/iam/load/test_load_with_iam.py @@ -0,0 +1,60 @@ +import time + +import pytest +import unittest + +from botocore.session import get_session + +from test.integration import IntegrationTest + +TEST_BULKLOAD_SOURCE = 's3://aws-ml-customer-samples-%s/bulkload-datasets/%s/airroutes/v01' + + +@unittest.skip +class TestLoadWithIAM(IntegrationTest): + def setUp(self) -> None: + assert self.config.load_from_s3_arn != '' + self.client = self.client_builder.with_iam(get_session()).build() + + @pytest.mark.neptune + def test_iam_load(self): + load_format = 'turtle' + source = TEST_BULKLOAD_SOURCE % (self.config.aws_region, 'turtle') + + # for a full list of options, see https://docs.aws.amazon.com/neptune/latest/userguide/bulk-load-data.html + kwargs = { + 'failOnError': "TRUE", + } + res = self.client.load(source, load_format, self.config.load_from_s3_arn, **kwargs) + assert res.status_code == 200 + + load_js = res.json() + assert 'loadId' in load_js['payload'] + load_id = load_js['payload']['loadId'] + + time.sleep(1) # brief wait to ensure the load job can be obtained + + res = self.client.load_status(load_id, details="TRUE") + assert res.status_code == 200 + + load_status = res.json() + assert 'overallStatus' in load_status['payload'] + status = load_status['payload']['overallStatus'] + assert status['fullUri'] == source + + res = self.client.cancel_load(load_id) + assert res.status_code == 200 + + time.sleep(5) + res = self.client.load_status(load_id, details="TRUE") + cancelled_status = res.json() + assert 'LOAD_CANCELLED_BY_USER' in cancelled_status['payload']['feedCount'][-1] + + @pytest.mark.neptune + def test_iam_load_status(self): + res = self.client.load_status() # This should only give a list of load ids + assert res.status_code == 200 + + js = res.json() + assert 'loadIds' in js['payload'] + assert len(js['payload'].keys()) == 1 diff --git a/test/integration/iam/ml/__init__.py b/test/integration/iam/ml/__init__.py new file mode 100644 index 00000000..70385e35 --- /dev/null +++ b/test/integration/iam/ml/__init__.py @@ -0,0 +1,27 @@ +""" +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +SPDX-License-Identifier: Apache-2.0 +""" + +from botocore.session import get_session + +from graph_notebook.configuration.generate_config import Configuration +from graph_notebook.neptune.client import Client, ClientBuilder + + +def setup_iam_client(config: Configuration) -> Client: + client = ClientBuilder() \ + .with_host(config.host) \ + .with_port(config.port) \ + .with_region(config.aws_region) \ + .with_tls(config.ssl) \ + .with_sparql_path(config.sparql.path) \ + .with_iam(get_session()) \ + .build() + + assert client.host == config.host + assert client.port == config.port + assert client.region == config.aws_region + assert client.sparql_path == config.sparql.path + assert client.ssl is config.ssl + return client diff --git a/test/integration/iam/ml/test_neptune_client_with_iam.py b/test/integration/iam/ml/test_neptune_client_with_iam.py new file mode 100644 index 00000000..d5af29ce --- /dev/null +++ b/test/integration/iam/ml/test_neptune_client_with_iam.py @@ -0,0 +1,14 @@ +""" +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +SPDX-License-Identifier: Apache-2.0 +""" + +from graph_notebook.configuration.generate_config import Configuration +from graph_notebook.neptune.client import Client + +client: Client +config: Configuration + +TEST_BULKLOAD_SOURCE = 's3://aws-ml-customer-samples-%s/bulkload-datasets/%s/airroutes/v01' +GREMLIN_TEST_LABEL = 'graph-notebook-test' +SPARQL_TEST_PREDICATE = '