diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml
index de3fb15e..b0b0f40d 100644
--- a/.github/workflows/integration.yml
+++ b/.github/workflows/integration.yml
@@ -81,37 +81,27 @@ jobs:
       - name: Give user /etc/hosts permission
         run: |
           sudo chmod 777 /etc/hosts
-      - name: Run Basic Tests
+      - name: Generate test configuration
         run: |
-          python test/integration/NeptuneIntegrationWorkflowSteps.py \
-            run-tests \
-            --pattern "*without_iam.py" \
+          python test/integration/NeptuneIntegrationWorkflowSteps.py generate-config \
             --cfn-stack-name ${{ needs.generate-stack-name.outputs.stack-name }} \
             --aws-region ${{ secrets.AWS_REGION }}
-      - name: Run Networkx Tests
+      - name: Run Basic Tests
+        env:
+          GRAPH_NOTEBOK_CONFIG: /tmp/graph_notebook_config_integration_test.json
         run: |
-          python test/integration/NeptuneIntegrationWorkflowSteps.py \
-            run-tests \
-            --pattern "*network*.py" \
-            --cfn-stack-name ${{ needs.generate-stack-name.outputs.stack-name }} \
-            --aws-region ${{ secrets.AWS_REGION }}
-      - name: Run Notebook Tests
+          pytest test/integration/without_iam
+      - name: Generate iam test configuration
         run: |
-          python test/integration/NeptuneIntegrationWorkflowSteps.py \
-            run-tests \
-            --pattern "*graph_notebook.py" \
+          python test/integration/NeptuneIntegrationWorkflowSteps.py generate-config \
             --cfn-stack-name ${{ needs.generate-stack-name.outputs.stack-name }} \
-            --aws-region ${{ secrets.AWS_REGION }}
+            --aws-region ${{ secrets.AWS_REGION }} \
+            --iam
       - name: Run IAM Tests
         env:
           GRAPH_NOTEBOK_CONFIG: /tmp/graph_notebook_config_integration_test.json
         run: |
-          python test/integration/NeptuneIntegrationWorkflowSteps.py \
-            run-tests \
-            --pattern "*with_iam.py" \
-            --iam \
-            --cfn-stack-name ${{ needs.generate-stack-name.outputs.stack-name }} \
-            --aws-region ${{ secrets.AWS_REGION }}
+          pytest test/integration/iam
       - name: Cleanup
         run: |
           python test/integration/NeptuneIntegrationWorkflowSteps.py \
diff --git a/.github/workflows/unit.yml b/.github/workflows/unit.yml
index 8f7b5e7b..5f044413 100644
--- a/.github/workflows/unit.yml
+++ b/.github/workflows/unit.yml
@@ -35,4 +35,4 @@ jobs:
         python -m graph_notebook.notebooks.install
     - name: Test with pytest
       run: |
-        pytest
\ No newline at end of file
+        pytest test/unit
\ No newline at end of file
diff --git a/ChangeLog.md b/ChangeLog.md
index 524a81ad..68d89427 100644
--- a/ChangeLog.md
+++ b/ChangeLog.md
@@ -4,14 +4,15 @@ Starting with v1.31.6, this file will contain a record of major features and upd

 ## Upcoming

-- Add support for Mode, queueRequest, and Dependencies parameters when running %load command
-- Add support for list and dict as map keys in Python Gremlin
+- Add support for Mode, queueRequest, and Dependencies parameters when running %load command ([Link to PR](https://github.com/aws/graph-notebook/pull/91))
+- Add support for list and dict as map keys in Python Gremlin ([Link to PR](https://github.com/aws/graph-notebook/pull/100))
+- Refactor modules that call to Neptune or other SPARQL/Gremlin endpoints to use a unified client object ([Link to PR](https://github.com/aws/graph-notebook/pull/104))

 ## Release 2.0.12 (Mar 25, 2021)

- - Add default parameters for `get_load_status`
- - Add ipython as a dependency in `setup.py` ([Link to PT](https://github.com/aws/graph-notebook/pull/95))
- - Add parameters in `load_status` for `details`, `errors`, `page`, and `errorsPerPage`
+ - Add default parameters for `get_load_status` ([Link to PR](https://github.com/aws/graph-notebook/pull/96))
+ - Add ipython as a dependency in `setup.py` ([Link to PR](https://github.com/aws/graph-notebook/pull/95))
+ - Add parameters in `load_status` for `details`, `errors`, `page`, and `errorsPerPage` ([Link to PR](https://github.com/aws/graph-notebook/pull/88))

 ## Release 2.0.10 (Mar 18, 2021)
diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 00000000..4fee5c5f
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,11 @@
+[pytest]
+markers =
+    neptune: tests which have to run against neptune
+    iam: tests which require iam authentication
+    gremlin: tests which run against a gremlin endpoint
+    sparql: tests which run against a SPARQL 1.1 endpoint
+    neptuneml: tests which run Neptune ML workloads
+    jupyter: tests which run against ipython/jupyter frameworks
+    reset: test which performs a fast reset against Neptune, running this will wipe your database!
+
+
diff --git a/requirements.txt b/requirements.txt
index 04e0b7e6..1852a511 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,8 +9,9 @@ jupyter-contrib-nbextensions
 widgetsnbextension
 gremlinpython
 requests==2.24.0
+ipython==7.16.1

 # requirements for testing
 boto3==1.15.15
 botocore==1.18.18
-ipython==7.16.1
\ No newline at end of file
+pytest==6.2.2
\ No newline at end of file
diff --git a/setup.py b/setup.py
index f7de0895..bb08f9c5 100644
--- a/setup.py
+++ b/setup.py
@@ -93,4 +93,7 @@ def get_version():
         'License :: OSI Approved :: Apache Software License'
     ],
     keywords='jupyter neptune gremlin sparql',
+    tests_require=[
+        'pytest'
+    ]
 )
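With the markers registered in the new pytest.ini above, the suites can be sliced locally without a live cluster. A minimal sketch, where the marker expression and target directory are illustrative:

```python
import pytest

# select tests by the markers this PR registers in pytest.ini; running
# with "not neptune and not iam" skips anything needing a live endpoint
exit_code = pytest.main(['-m', 'not neptune and not iam', 'test/unit'])
print(exit_code)  # 0 when every selected test passes
```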
diff --git a/src/graph_notebook/authentication/iam_credentials_provider/__init__.py b/src/graph_notebook/authentication/iam_credentials_provider/__init__.py
deleted file mode 100644
index fa84f3bc..00000000
--- a/src/graph_notebook/authentication/iam_credentials_provider/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-"""
-Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-SPDX-License-Identifier: Apache-2.0
-"""
\ No newline at end of file
diff --git a/src/graph_notebook/authentication/iam_credentials_provider/credentials_factory.py b/src/graph_notebook/authentication/iam_credentials_provider/credentials_factory.py
deleted file mode 100644
index c849b595..00000000
--- a/src/graph_notebook/authentication/iam_credentials_provider/credentials_factory.py
+++ /dev/null
@@ -1,24 +0,0 @@
-"""
-Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-SPDX-License-Identifier: Apache-2.0
-"""
-
-from enum import Enum
-
-from graph_notebook.authentication.iam_credentials_provider.credentials_provider import CredentialsProviderBase
-from graph_notebook.authentication.iam_credentials_provider.env_credentials_provider import EnvCredentialsProvider
-from graph_notebook.authentication.iam_credentials_provider.ec2_metadata_credentials_provider import MetadataCredentialsProvider
-
-
-class IAMAuthCredentialsProvider(Enum):
-    ROLE = "ROLE"
-    ENV = "ENV"
-
-
-def credentials_provider_factory(mode: IAMAuthCredentialsProvider) -> CredentialsProviderBase:
-    if mode == IAMAuthCredentialsProvider.ENV:
-        return EnvCredentialsProvider()
-    elif mode == IAMAuthCredentialsProvider.ROLE:
-        return MetadataCredentialsProvider()
-    else:
-        raise NotImplementedError(f'the provided mode of {mode} has not been implemented by credentials_provider_factory')
diff --git a/src/graph_notebook/authentication/iam_credentials_provider/credentials_provider.py b/src/graph_notebook/authentication/iam_credentials_provider/credentials_provider.py
deleted file mode 100644
index 8e6f34cf..00000000
--- a/src/graph_notebook/authentication/iam_credentials_provider/credentials_provider.py
+++ /dev/null
@@ -1,20 +0,0 @@
-"""
-Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-SPDX-License-Identifier: Apache-2.0
-"""
-
-from abc import ABC, abstractmethod
-
-
-class Credentials(object):
-    def __init__(self, key, secret, region, token=''):
-        self.key = key
-        self.secret = secret
-        self.token = token
-        self.region = region
-
-
-class CredentialsProviderBase(ABC):
-    @abstractmethod
-    def get_iam_credentials(self) -> Credentials:
-        pass
diff --git a/src/graph_notebook/authentication/iam_credentials_provider/ec2_metadata_credentials_provider.py b/src/graph_notebook/authentication/iam_credentials_provider/ec2_metadata_credentials_provider.py
deleted file mode 100644
index 8dd9bf58..00000000
--- a/src/graph_notebook/authentication/iam_credentials_provider/ec2_metadata_credentials_provider.py
+++ /dev/null
@@ -1,26 +0,0 @@
-"""
-Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-SPDX-License-Identifier: Apache-2.0
-"""
-
-import botocore.session
-import requests
-
-
-from graph_notebook.authentication.iam_credentials_provider.credentials_provider import CredentialsProviderBase, \
-    Credentials
-
-region_url = 'http://169.254.169.254/latest/meta-data/placement/availability-zone'
-
-
-class MetadataCredentialsProvider(CredentialsProviderBase):
-    def __init__(self):
-        res = requests.get(region_url)
-        zone = res.content.decode('utf-8')
-        region = zone[0:len(zone) - 1]
-        self.region = region
-
-    def get_iam_credentials(self) -> Credentials:
-        session = botocore.session.get_session()
-        creds = session.get_credentials()
-        return Credentials(key=creds.access_key, secret=creds.secret_key, token=creds.token, region=self.region)
diff --git a/src/graph_notebook/authentication/iam_credentials_provider/env_credentials_provider.py b/src/graph_notebook/authentication/iam_credentials_provider/env_credentials_provider.py
deleted file mode 100644
index 0ba39e9d..00000000
--- a/src/graph_notebook/authentication/iam_credentials_provider/env_credentials_provider.py
+++ /dev/null
@@ -1,35 +0,0 @@
-"""
-Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-SPDX-License-Identifier: Apache-2.0
-"""
-
-import os
-
-from graph_notebook.authentication.iam_credentials_provider.credentials_provider import CredentialsProviderBase, \
-    Credentials
-
-ACCESS_ENV_KEY = 'AWS_ACCESS_KEY_ID'
-SECRET_ENV_KEY = 'AWS_SECRET_ACCESS_KEY'
-REGION_ENV_KEY = 'AWS_REGION'
-AWS_TOKEN_ENV_KEY = 'AWS_SESSION_TOKEN'
-
-
-class EnvCredentialsProvider(CredentialsProviderBase):
-    def __init__(self):
-        self.creds = Credentials(key='', secret='', region='', token='')
-        self.loaded = False
-
-    def load_iam_credentials(self):
-        access_key = os.environ.get(ACCESS_ENV_KEY, '')
-        secret_key = os.environ.get(SECRET_ENV_KEY, '')
-        region = os.environ.get(REGION_ENV_KEY, '')
-        token = os.environ.get(AWS_TOKEN_ENV_KEY, '')
-        self.creds = Credentials(access_key, secret_key, region, token)
-        self.loaded = True
-        return
-
-    def get_iam_credentials(self, service=None) -> Credentials:
-        if not self.loaded:
-            self.load_iam_credentials()
-
-        return self.creds
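The two deleted providers (env vars and EC2 instance role) are superseded by botocore's default credential resolution, which the new unified client receives via `get_session()` in graph_magic.py below. A minimal sketch of the replacement behavior:

```python
from botocore.session import get_session

# botocore resolves credentials the same way the deleted providers did,
# checking environment variables first and falling back to the instance role
session = get_session()
creds = session.get_credentials()
if creds is not None:
    print(creds.access_key is not None, creds.token is not None)
```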
diff --git a/src/graph_notebook/authentication/iam_headers.py b/src/graph_notebook/authentication/iam_headers.py
deleted file mode 100644
index 6c39038a..00000000
--- a/src/graph_notebook/authentication/iam_headers.py
+++ /dev/null
@@ -1,217 +0,0 @@
-"""
-Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-SPDX-License-Identifier: Apache-2.0
-"""
-
-import datetime
-import hashlib
-import hmac
-import json
-import logging
-import urllib
-
-logging.basicConfig()
-logger = logging.getLogger("graph_magic")
-
-
-# Key derivation functions. See:
-# https://docs.aws.amazon.com/general/latest/gr/signature-v4-examples.html#signature-v4-examples-python
-def sign(key, msg):
-    return hmac.new(key, msg.encode('utf-8'), hashlib.sha256).digest()
-
-
-def get_signature_key(key, dateStamp, regionName, serviceName):
-    k_date = sign(('AWS4' + key).encode('utf-8'), dateStamp)
-    k_region = sign(k_date, regionName)
-    k_service = sign(k_region, serviceName)
-    k_signing = sign(k_service, 'aws4_request')
-    return k_signing
-
-
-def get_canonical_uri_and_payload(query_type, query):
-    # Set the stack and payload depending on query_type.
-    if query_type == 'sparql':
-        canonical_uri = '/sparql/'
-        payload = query
-
-    elif query_type == 'sparqlupdate':
-        canonical_uri = '/sparql/'
-        payload = query
-
-    elif query_type == 'sparql/status':
-        canonical_uri = '/sparql/status/'
-        payload = query
-
-    elif query_type == 'gremlin':
-        canonical_uri = '/gremlin'
-        payload = {}
-
-    elif query_type == 'gremlin/status':
-        canonical_uri = '/gremlin/status/'
-        payload = query
-
-    elif query_type == "loader":
-        canonical_uri = "/loader/"
-        payload = query
-
-    elif query_type == "status":
-        canonical_uri = "/status/"
-        payload = {}
-
-    elif query_type == "gremlin/explain":
-        canonical_uri = "/gremlin/explain/"
-        payload = query
-
-    elif query_type == "gremlin/profile":
-        canonical_uri = "/gremlin/profile/"
-        payload = query
-
-    elif query_type == "system":
-        canonical_uri = "/system/"
-        payload = query
-
-    elif query_type.startswith("ml"):
-        canonical_uri = f'/{query_type}'
-        payload = query
-
-    elif query_type.startswith("ml/dataprocessing"):
-        canonical_uri = f'/{query_type}'
-        payload = query
-
-    elif query_type.startswith("ml/endpoints"):
-        canonical_uri = f'/{query_type}'
-        payload = query
-
-    else:
-        raise ValueError('query_type %s is not valid' % query_type)
-
-    return canonical_uri, payload
-
-
-def normalize_query_string(query):
-    kv = (list(map(str.strip, s.split("=")))
-          for s in query.split('&')
-          if len(s) > 0)
-
-    normalized = '&'.join('%s=%s' % (p[0], p[1] if len(p) > 1 else '')
-                          for p in sorted(kv))
-    return normalized
-
-
-def make_signed_request(method, query_type, query, host, port, signing_access_key, signing_secret, signing_region,
-                        use_ssl=False, signing_token='', additional_headers=None):
-    if additional_headers is None:
-        additional_headers = []
-
-    signing_region = signing_region.lower()
-    service = 'neptune-db'
-
-    if use_ssl:
-        protocol = 'https'
-    else:
-        protocol = 'http'
-
-    # this is always http right now
-    endpoint = f'{protocol}://{host}:{port}'
-
-    # get canonical_uri and payload
-    canonical_uri, payload = get_canonical_uri_and_payload(query_type, query)
-
-    if 'content-type' in additional_headers and additional_headers['content-type'] == 'application/json':
-        request_parameters = payload if type(payload) is str else json.dumps(payload)
-    else:
-        request_parameters = urllib.parse.urlencode(payload, quote_via=urllib.parse.quote)
-        request_parameters = request_parameters.replace('%27', '%22')
-
-    t = datetime.datetime.utcnow()
-    amz_date = t.strftime('%Y%m%dT%H%M%SZ')
-    date_stamp = t.strftime('%Y%m%d')  # Date w/o time, used in credential scope
-
-    method = method.upper()
-    if method == 'GET' or method == 'DELETE':
-        canonical_querystring = normalize_query_string(request_parameters)
-    elif method == 'POST':
-        canonical_querystring = ''
-    else:
-        raise ValueError('method %s is not valid when creating canonical request' % method)
-
-    # Step 4: Create the canonical headers and signed headers. Header names
-    # must be trimmed and lowercase, and sorted in code point order from
-    # low to high. Note that there is a trailing \n.
-    canonical_headers = f'host:{host}:{port}\nx-amz-date:{amz_date}\n'
-
-    # Step 5: Create the list of signed headers. This lists the headers
-    # in the canonical_headers list, delimited with ";" and in alpha order.
-    # Note: The request can include any headers; canonical_headers and
-    # signed_headers lists those that you want to be included in the
-    # hash of the request. "Host" and "x-amz-date" are always required.
-    signed_headers = 'host;x-amz-date'
-
-    # Step 6: Create payload hash (hash of the request body content). For GET and DELETE
-    # requests, the payload is an empty string ("").
-    if method == 'GET' or method == 'DELETE':
-        post_payload = ''
-    elif method == 'POST':
-        post_payload = request_parameters
-    else:
-        raise ValueError('method %s is not supported' % method)
-
-    payload_hash = hashlib.sha256(post_payload.encode('utf-8')).hexdigest()
-
-    # Step 7: Combine elements to create canonical request.
-    canonical_request = method + '\n' + canonical_uri + '\n' + canonical_querystring + '\n' + canonical_headers + '\n' + signed_headers + '\n' + payload_hash
-
-    # ************* TASK 2: CREATE THE STRING TO SIGN*************
-    # Match the algorithm to the hashing algorithm you use, either SHA-1 or
-    # SHA-256 (recommended)
-    algorithm = 'AWS4-HMAC-SHA256'
-    credential_scope = date_stamp + '/' + signing_region + '/' + service + '/' + 'aws4_request'
-    string_to_sign = algorithm + '\n' + amz_date + '\n' + credential_scope + '\n' + hashlib.sha256(
-        canonical_request.encode('utf-8')).hexdigest()
-
-    # ************* TASK 3: CALCULATE THE SIGNATURE *************
-    # Create the signing key using the function defined above.
-    signing_key = get_signature_key(signing_secret, date_stamp, signing_region, service)
-
-    # Sign the string_to_sign using the signing_key
-    signature = hmac.new(signing_key, string_to_sign.encode('utf-8'), hashlib.sha256).hexdigest()
-
-    # ************* TASK 4: ADD SIGNING INFORMATION TO THE REQUEST *************
-    # The signing information can be either in a query string value or in
-    # a header named Authorization. This code shows how to use a header.
-    # Create authorization header and add to request headers
-    authorization_header = algorithm + ' ' + 'Credential=' + signing_access_key + '/' + credential_scope + ', ' + 'SignedHeaders=' + signed_headers + ', ' + 'Signature=' + signature
-
-    # The request can include any headers, but MUST include "host", "x-amz-date",
-    # and (for this scenario) "Authorization". "host" and "x-amz-date" must
-    # be included in the canonical_headers and signed_headers, as noted
-    # earlier. Order here is not significant.
-    # Python note: The 'host' header is added automatically by the Python 'requests' library.
-    if method == 'GET' or method == 'DELETE':
-        headers = {
-            'x-amz-date': amz_date,
-            'Authorization': authorization_header
-        }
-    elif method == 'POST':
-        headers = {
-            'content-type': 'application/x-www-form-urlencoded',
-            'x-amz-date': amz_date,
-            'Authorization': authorization_header,
-        }
-    else:
-        raise ValueError('method %s is not valid while creating request headers' % method)
-
-    if additional_headers is not None:
-        for key in additional_headers:
-            headers[key] = additional_headers[key]
-
-    if signing_token != '':
-        headers['X-Amz-Security-Token'] = signing_token
-
-    # ************* SEND THE REQUEST *************
-    request_url = endpoint + canonical_uri
-
-    return {
-        'url': request_url,
-        'headers': headers,
-        'params': request_parameters
-    }
\ No newline at end of file
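The hand-rolled Signature Version 4 implementation deleted above can be replaced by the signing support already shipped in botocore. A hedged sketch of the equivalent call, with a placeholder endpoint and region:

```python
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.session import get_session

# sign a Neptune request with botocore instead of the deleted iam_headers.py;
# the host and region here are illustrative
request = AWSRequest(method='GET', url='https://my-neptune-host:8182/status')
SigV4Auth(get_session().get_credentials(), 'neptune-db', 'us-east-1').add_auth(request)
print(request.headers['Authorization'])  # AWS4-HMAC-SHA256 Credential=...
```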
diff --git a/src/graph_notebook/configuration/generate_config.py b/src/graph_notebook/configuration/generate_config.py
index 496455f0..f57aaa11 100644
--- a/src/graph_notebook/configuration/generate_config.py
+++ b/src/graph_notebook/configuration/generate_config.py
@@ -8,10 +8,7 @@
 import os
 from enum import Enum

-from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider
-from graph_notebook.sparql.query import SPARQL_ACTION
-
-DEFAULT_IAM_CREDENTIALS_PROVIDER = IAMAuthCredentialsProvider.ROLE
+from graph_notebook.neptune.client import SPARQL_ACTION

 DEFAULT_CONFIG_LOCATION = os.path.expanduser('~/graph_notebook_config.json')

@@ -38,6 +35,9 @@ def __init__(self, path: str = SPARQL_ACTION, endpoint_prefix: str = ''):
             print('endpoint_prefix has been deprecated and will be removed in version 2.0.20 or greater.')
             if path == '':
                 path = f'{endpoint_prefix}/sparql'
+        elif path == '':
+            path = SPARQL_ACTION
+
         self.path = path

     def to_dict(self):
@@ -47,13 +47,11 @@ def to_dict(self):

 class Configuration(object):
     def __init__(self, host: str, port: int, auth_mode: AuthModeEnum = AuthModeEnum.DEFAULT,
-                 iam_credentials_provider_type: IAMAuthCredentialsProvider = DEFAULT_IAM_CREDENTIALS_PROVIDER,
                  load_from_s3_arn='', ssl: bool = True, aws_region: str = 'us-east-1',
                  sparql_section: SparqlSection = None):
         self.host = host
         self.port = port
         self.auth_mode = auth_mode
-        self.iam_credentials_provider_type = iam_credentials_provider_type
         self.load_from_s3_arn = load_from_s3_arn
         self.ssl = ssl
         self.aws_region = aws_region
@@ -64,7 +62,6 @@ def to_dict(self) -> dict:
             'host': self.host,
             'port': self.port,
             'auth_mode': self.auth_mode.value,
-            'iam_credentials_provider_type': self.iam_credentials_provider_type.value,
             'load_from_s3_arn': self.load_from_s3_arn,
             'ssl': self.ssl,
             'aws_region': self.aws_region,
@@ -79,21 +76,15 @@ def write_to_file(self, file_path=DEFAULT_CONFIG_LOCATION):
         return


-def generate_config(host, port, auth_mode, ssl, iam_credentials_provider_type, load_from_s3_arn, aws_region):
+def generate_config(host, port, auth_mode, ssl, load_from_s3_arn, aws_region):
     use_ssl = False if ssl in [False, 'False', 'false', 'FALSE'] else True
-
-    if iam_credentials_provider_type not in [IAMAuthCredentialsProvider.ENV,
-                                             IAMAuthCredentialsProvider.ROLE]:
-        iam_credentials_provider_type = DEFAULT_IAM_CREDENTIALS_PROVIDER
-
-    config = Configuration(host, port, auth_mode, iam_credentials_provider_type, load_from_s3_arn, use_ssl, aws_region)
-    return config
+    c = Configuration(host, port, auth_mode, load_from_s3_arn, use_ssl, aws_region)
+    return c


 def generate_default_config():
-    config = generate_config('change-me', 8182, AuthModeEnum.DEFAULT, True, DEFAULT_IAM_CREDENTIALS_PROVIDER, '',
-                             'us-east-1')
-    return config
+    c = generate_config('change-me', 8182, AuthModeEnum.DEFAULT, True, '', 'us-east-1')
+    return c


 if __name__ == "__main__":
@@ -102,6 +93,8 @@ def generate_default_config():
     parser.add_argument("--port", help="the port to use when creating a connection", default="8182")
     parser.add_argument("--auth_mode", default=AuthModeEnum.DEFAULT.value,
                         help="type of authentication the cluster being connected to is using. Can be DEFAULT or IAM")
+
+    # TODO: this can now be removed.
     parser.add_argument("--iam_credentials_provider", default='ROLE',
                         help="The mode used to obtain credentials for IAM Authentication. Can be ROLE or ENV")
     parser.add_argument("--ssl",
@@ -110,14 +103,11 @@
     parser.add_argument("--config_destination", help="location to put generated config",
                         default=DEFAULT_CONFIG_LOCATION)
     parser.add_argument("--load_from_s3_arn", help="arn of role to use for bulk loader", default='')
-    parser.add_argument("--aws_region", help="aws region your neptune cluster is in.", default='us-east-1')
+    parser.add_argument("--aws_region", help="aws region your ml cluster is in.", default='us-east-1')

     args = parser.parse_args()

     auth_mode_arg = args.auth_mode if args.auth_mode != '' else AuthModeEnum.DEFAULT.value
-    iam_credentials_provider_arg = args.iam_credentials_provider if args.iam_credentials_provider != '' else IAMAuthCredentialsProvider.ROLE.value
-
-    config = generate_config(args.host, int(args.port), AuthModeEnum(auth_mode_arg), args.ssl,
-                             IAMAuthCredentialsProvider(iam_credentials_provider_arg),
+    config = generate_config(args.host, int(args.port), AuthModeEnum(auth_mode_arg), args.ssl,
                              args.load_from_s3_arn, args.aws_region)

     config.write_to_file(args.config_destination)
diff --git a/src/graph_notebook/configuration/get_config.py b/src/graph_notebook/configuration/get_config.py
index 35698391..72ab829a 100644
--- a/src/graph_notebook/configuration/get_config.py
+++ b/src/graph_notebook/configuration/get_config.py
@@ -5,7 +5,6 @@

 import json

-from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider
 from graph_notebook.configuration.generate_config import DEFAULT_CONFIG_LOCATION, Configuration, AuthModeEnum, \
     SparqlSection

@@ -13,8 +12,6 @@ def get_config_from_dict(data: dict) -> Configuration:
     sparql_section = SparqlSection(**data['sparql']) if 'sparql' in data else SparqlSection('')
     config = Configuration(host=data['host'], port=data['port'], auth_mode=AuthModeEnum(data['auth_mode']),
-                           iam_credentials_provider_type=IAMAuthCredentialsProvider(
-                               data['iam_credentials_provider_type']),
                            ssl=data['ssl'], load_from_s3_arn=data['load_from_s3_arn'],
                            aws_region=data['aws_region'], sparql_section=sparql_section)
     return config
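After this change, `generate_config` drops the credentials-provider argument entirely. A minimal sketch of the new signature in use, with a placeholder host (the argument order is taken from the diff above):

```python
from graph_notebook.configuration.generate_config import AuthModeEnum, generate_config

# host is a placeholder; note there is no longer a credentials provider argument
config = generate_config('change-me.cluster-abc.us-east-1.neptune.amazonaws.com', 8182,
                         AuthModeEnum.IAM, True, '', 'us-east-1')
config.write_to_file('/tmp/graph_notebook_config_integration_test.json')
```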
diff --git a/src/graph_notebook/gremlin/client_provider/default_client.py b/src/graph_notebook/gremlin/client_provider/default_client.py
deleted file mode 100644
index 0d6bd4b8..00000000
--- a/src/graph_notebook/gremlin/client_provider/default_client.py
+++ /dev/null
@@ -1,21 +0,0 @@
-"""
-Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-SPDX-License-Identifier: Apache-2.0
-"""
-
-import graph_notebook.gremlin.client_provider.graphsonV3d0_MapType_objectify_patch  # noqa F401
-import logging
-
-from gremlin_python.driver import client
-
-logging.basicConfig()
-logger = logging.getLogger("default_client")
-
-
-class ClientProvider(object):
-    @staticmethod
-    def get_client(host, port, use_ssl):
-        protocol = 'wss' if use_ssl else 'ws'
-        url = f'{protocol}://{host}:{port}/gremlin'
-        c = client.Client(url, 'g')
-        return c
diff --git a/src/graph_notebook/gremlin/client_provider/factory.py b/src/graph_notebook/gremlin/client_provider/factory.py
deleted file mode 100644
index 0f1a459e..00000000
--- a/src/graph_notebook/gremlin/client_provider/factory.py
+++ /dev/null
@@ -1,21 +0,0 @@
-"""
-Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-SPDX-License-Identifier: Apache-2.0
-"""
-
-from graph_notebook.configuration.generate_config import AuthModeEnum
-from graph_notebook.gremlin.client_provider.default_client import ClientProvider
-from graph_notebook.gremlin.client_provider.iam_client import IamClientProvider
-from graph_notebook.authentication.iam_credentials_provider.credentials_factory import credentials_provider_factory, \
-    IAMAuthCredentialsProvider
-
-
-def create_client_provider(mode: AuthModeEnum,
-                           credentials_provider_mode: IAMAuthCredentialsProvider = IAMAuthCredentialsProvider.ROLE):
-    if mode == AuthModeEnum.DEFAULT:
-        return ClientProvider()
-    elif mode == AuthModeEnum.IAM:
-        credentials_provider = credentials_provider_factory(credentials_provider_mode)
-        return IamClientProvider(credentials_provider)
-    else:
-        raise NotImplementedError(f"invalid client mode {mode} provided")
diff --git a/src/graph_notebook/gremlin/client_provider/iam_client.py b/src/graph_notebook/gremlin/client_provider/iam_client.py
deleted file mode 100644
index aca981ab..00000000
--- a/src/graph_notebook/gremlin/client_provider/iam_client.py
+++ /dev/null
@@ -1,52 +0,0 @@
-"""
-Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-SPDX-License-Identifier: Apache-2.0
-"""
-
-import graph_notebook.gremlin.client_provider.graphsonV3d0_MapType_objectify_patch  # noqa F401
-import hashlib
-import hmac
-import logging
-
-from gremlin_python.driver import client
-from gremlin_python.driver.client import Client
-from tornado import httpclient
-
-from graph_notebook.authentication.iam_credentials_provider.credentials_provider import CredentialsProviderBase
-from graph_notebook.authentication.iam_headers import make_signed_request
-
-logging.basicConfig()
-logger = logging.getLogger("iam_client")
-
-
-def sign(key, msg):
-    return hmac.new(key, msg.encode('utf-8'), hashlib.sha256).digest()
-
-
-def get_signature_key(key, date_stamp, region_name, service_name):
-    k_date = sign(('AWS4' + key).encode('utf-8'), date_stamp)
-    k_region = sign(k_date, region_name)
-    k_service = sign(k_region, service_name)
-    k_signing = sign(k_service, 'aws4_request')
-    return k_signing
-
-
-class IamClientProvider(object):
-    def __init__(self, credentials_provider: CredentialsProviderBase):
-        self.credentials_provider = credentials_provider
-
-    def get_client(self, host, port, use_ssl) -> Client:
-        credentials = self.credentials_provider.get_iam_credentials()
-        request_params = make_signed_request('get', 'gremlin', '', host, port, credentials.key,
-                                             credentials.secret, credentials.region, use_ssl,
-                                             credentials.token)
-        ws_url = request_params['url'].strip('/').replace('http', 'ws')
-        signed_ws_request = httpclient.HTTPRequest(ws_url, headers=request_params['headers'])
-
-        try:
-            c = client.Client(signed_ws_request, 'g')
-            return c
-        # TODO: handle exception explicitly
-        except Exception as e:
-            logger.error(f'error while creating client {e}')
-            raise e
diff --git a/src/graph_notebook/gremlin/query.py b/src/graph_notebook/gremlin/query.py
deleted file mode 100644
index 00e60010..00000000
--- a/src/graph_notebook/gremlin/query.py
+++ /dev/null
@@ -1,57 +0,0 @@
-"""
-Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-SPDX-License-Identifier: Apache-2.0
-"""
-
-import logging
-
-from graph_notebook.gremlin.client_provider.default_client import ClientProvider
-from graph_notebook.request_param_generator.default_request_generator import DefaultRequestGenerator
-from graph_notebook.request_param_generator.call_and_get_response import call_and_get_response
-
-logging.basicConfig()
-logger = logging.getLogger("gremlin")
-
-
-def do_gremlin_query(query_str, host, port, use_ssl, client_provider=ClientProvider()):
-    c = client_provider.get_client(host, port, use_ssl)
-
-    try:
-        result = c.submit(query_str)
-        future_results = result.all()
-        results = future_results.result()
-    except Exception as e:
-        raise e  # let the upstream decide what to do with this error.
-    finally:
-        c.close()  # no matter the outcome we need to close the websocket connection
-
-    return results
-
-
-def do_gremlin_explain(query_str, host, port, use_ssl, request_param_generator=DefaultRequestGenerator()):
-    data = {
-        'gremlin': query_str
-    }
-
-    action = 'gremlin/explain'
-    res = call_and_get_response('get', action, host, port, request_param_generator, use_ssl, data)
-    content = res.content.decode('utf-8')
-    result = {
-        'explain': content
-    }
-    return result
-
-
-def do_gremlin_profile(query_str, host, port, use_ssl, request_param_generator=DefaultRequestGenerator()):
-    data = {
-        'gremlin': query_str
-    }
-
-    action = 'gremlin/profile'
-    res = call_and_get_response('get', action, host, port, request_param_generator, use_ssl, data)
-    content = res.content.decode('utf-8').strip(' ')
-    result = {
-        'profile': content
-    }
-
-    return result
diff --git a/src/graph_notebook/gremlin/status.py b/src/graph_notebook/gremlin/status.py
deleted file mode 100644
index 4817206d..00000000
--- a/src/graph_notebook/gremlin/status.py
+++ /dev/null
@@ -1,44 +0,0 @@
-"""
-Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-SPDX-License-Identifier: Apache-2.0
-"""
-
-from graph_notebook.configuration.generate_config import AuthModeEnum
-from graph_notebook.request_param_generator.call_and_get_response import call_and_get_response
-
-GREMLIN_STATUS_ACTION = 'gremlin/status'
-
-
-def do_gremlin_status(host, port, use_ssl, mode, request_param_generator, query_id: str, include_waiting: bool):
-    data = {'includeWaiting': include_waiting}
-    if query_id != '':
-        data['queryId'] = query_id
-
-    headers = {}
-    if mode == AuthModeEnum.DEFAULT:
-        """Add correct content-type header for the request.
-        This is needed because call_and_get_response requires custom headers to be set.
-        """
-        headers['Content-Type'] = 'application/x-www-form-urlencoded'
-
-    res = call_and_get_response('post', GREMLIN_STATUS_ACTION, host, port, request_param_generator, use_ssl, data,
-                                headers)
-    content = res.json()
-    return content
-
-
-def do_gremlin_cancel(host, port, use_ssl, mode, request_param_generator, query_id):
-    if type(query_id) != str or query_id == '':
-        raise ValueError("query id must be a non-empty string")
-
-    data = {'cancelQuery': True, 'queryId': query_id}
-
-    headers = {}
-    if mode == AuthModeEnum.DEFAULT:
-        """Add correct content-type header for the request.
-        This is needed because call_and_get_response requires custom headers to be set.
-        """
-        headers['Content-Type'] = 'application/x-www-form-urlencoded'
-
-    res = call_and_get_response('post', GREMLIN_STATUS_ACTION, host, port, request_param_generator, use_ssl, data,
-                                headers)
-    content = res.json()
-    return content
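With these helpers gone, the Gremlin query, explain, profile, status, and cancel paths all hang off the unified client. A hedged sketch of the equivalents; the `client` construction is shown later in this diff, and the traversal is illustrative:

```python
# assumes a `client` built with the ClientBuilder introduced later in this diff
explain_res = client.gremlin_explain('g.V().limit(1)')
explain_res.raise_for_status()
print(explain_res.content.decode('utf-8'))

# submits over the websocket connection and gathers materialized results
results = client.gremlin_query('g.V().limit(1)')
```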
diff --git a/src/graph_notebook/loader/load.py b/src/graph_notebook/loader/load.py
deleted file mode 100644
index 3eaa8dbe..00000000
--- a/src/graph_notebook/loader/load.py
+++ /dev/null
@@ -1,90 +0,0 @@
-"""
-Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-SPDX-License-Identifier: Apache-2.0
-"""
-
-import json
-from graph_notebook.request_param_generator.call_and_get_response import call_and_get_response
-
-FORMAT_CSV = 'csv'
-FORMAT_NTRIPLE = 'ntriples'
-FORMAT_NQUADS = 'nquads'
-FORMAT_RDFXML = 'rdfxml'
-FORMAT_TURTLE = 'turtle'
-
-PARALLELISM_LOW = 'LOW'
-PARALLELISM_MEDIUM = 'MEDIUM'
-PARALLELISM_HIGH = 'HIGH'
-PARALLELISM_OVERSUBSCRIBE = 'OVERSUBSCRIBE'
-
-MODE_RESUME = 'RESUME'
-MODE_NEW = 'NEW'
-MODE_AUTO = 'AUTO'
-
-LOAD_JOB_MODES = [MODE_RESUME, MODE_NEW, MODE_AUTO]
-VALID_FORMATS = [FORMAT_CSV, FORMAT_NTRIPLE, FORMAT_NQUADS, FORMAT_RDFXML, FORMAT_TURTLE]
-PARALLELISM_OPTIONS = [PARALLELISM_LOW, PARALLELISM_MEDIUM, PARALLELISM_HIGH, PARALLELISM_OVERSUBSCRIBE]
-LOADER_ACTION = 'loader'
-
-FINAL_LOAD_STATUSES = ['LOAD_COMPLETED',
-                       'LOAD_COMMITTED_W_WRITE_CONFLICTS',
-                       'LOAD_CANCELLED_BY_USER',
-                       'LOAD_CANCELLED_DUE_TO_ERRORS',
-                       'LOAD_FAILED',
-                       'LOAD_UNEXPECTED_ERROR',
-                       'LOAD_DATA_DEADLOCK',
-                       'LOAD_DATA_FAILED_DUE_TO_FEED_MODIFIED_OR_DELETED',
-                       'LOAD_S3_READ_ERROR',
-                       'LOAD_S3_ACCESS_DENIED_ERROR',
-                       'LOAD_IN_QUEUE',
-                       'LOAD_FAILED_BECAUSE_DEPENDENCY_NOT_SATISFIED',
-                       'LOAD_FAILED_INVALID_REQUEST', ]
-
-
-def do_load(host, port, load_format, use_ssl, source, region, arn, fail_on_error, request_param_generator, mode="AUTO",
-            parallelism="HIGH", update_single_cardinality="FALSE", queue_request="FALSE", dependencies=[]):
-    payload = {
-        'source': source,
-        'format': load_format,
-        'mode': mode,
-        'region': region,
-        'failOnError': fail_on_error,
-        'parallelism': parallelism,
-        'updateSingleCardinalityProperties': update_single_cardinality,
-        'queueRequest': queue_request
-    }
-
-    if arn != '':
-        payload['iamRoleArn'] = arn
-
-    if dependencies:
-        payload['dependencies'] = json.dumps(dependencies)
-
-    res = call_and_get_response('post', LOADER_ACTION, host, port, request_param_generator, use_ssl, payload)
-    return res.json()
-
-
-def get_loader_jobs(host, port, use_ssl, request_param_generator):
-    res = call_and_get_response('get', LOADER_ACTION, host, port, request_param_generator, use_ssl)
-    return res.json()
-
-
-def get_load_status(host, port, use_ssl, request_param_generator, id, loader_details="FALSE", loader_errors="FALSE", loader_page=1, loader_epp=10):
-    payload = {
-        'loadId': id,
-        'details': loader_details,
-        'errors': loader_errors,
-        'page': loader_page,
-        'errorsPerPage': loader_epp
-    }
-    res = call_and_get_response('get', LOADER_ACTION, host, port, request_param_generator, use_ssl, payload)
-    return res.json()
-
-
-def cancel_load(host, port, use_ssl, request_param_generator, load_id):
-    payload = {
-        'loadId': load_id
-    }
-
-    res = call_and_get_response('get', LOADER_ACTION, host, port, request_param_generator, use_ssl, payload)
-    return res.status_code == 200
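The deleted `do_load` moves to `Client.load`, with the optional loader parameters passed as keyword arguments (see the `%load` magic further down). A minimal sketch under that assumption; the bucket, role ARN, and parameter values are illustrative:

```python
# assumes a `client` built with the ClientBuilder introduced later in this diff;
# the call shape mirrors the %load magic in graph_magic.py below
kwargs = {
    'failOnError': 'TRUE',
    'parallelism': 'HIGH',
    'updateSingleCardinalityProperties': 'FALSE',
    'queueRequest': 'TRUE',
}
load_res = client.load('s3://example-bucket/data/', 'csv',
                       'arn:aws:iam::123456789012:role/example-loader-role',
                       'us-east-1', **kwargs)
load_res.raise_for_status()
print(load_res.json()['payload']['loadId'])
```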
diff --git a/src/graph_notebook/magics/graph_magic.py b/src/graph_notebook/magics/graph_magic.py
index 90b30a16..12b8a9be 100644
--- a/src/graph_notebook/magics/graph_magic.py
+++ b/src/graph_notebook/magics/graph_magic.py
@@ -14,6 +14,8 @@
 from enum import Enum

 import ipywidgets as widgets
+from SPARQLWrapper import SPARQLWrapper
+from botocore.session import get_session
 from gremlin_python.driver.protocol import GremlinServerError
 from IPython.core.display import HTML, display_html, display
 from IPython.core.magic import (Magics, magics_class, cell_magic, line_magic, line_cell_magic, needs_local_scope)
@@ -21,25 +23,18 @@
 from requests import HTTPError

 import graph_notebook
-from graph_notebook.configuration.generate_config import generate_default_config, DEFAULT_CONFIG_LOCATION
+from graph_notebook.configuration.generate_config import generate_default_config, DEFAULT_CONFIG_LOCATION, AuthModeEnum, \
+    Configuration
 from graph_notebook.decorators.decorators import display_exceptions
 from graph_notebook.magics.ml import neptune_ml_magic_handler, generate_neptune_ml_parser
+from graph_notebook.neptune.client import ClientBuilder, Client, VALID_FORMATS, PARALLELISM_OPTIONS, PARALLELISM_HIGH, \
+    LOAD_JOB_MODES, MODE_AUTO, FINAL_LOAD_STATUSES, SPARQL_ACTION
 from graph_notebook.network import SPARQLNetwork
 from graph_notebook.network.gremlin.GremlinNetwork import parse_pattern_list_str, GremlinNetwork
-from graph_notebook.sparql.table import get_rows_and_columns
-from graph_notebook.gremlin.query import do_gremlin_query, do_gremlin_explain, do_gremlin_profile
-from graph_notebook.gremlin.status import do_gremlin_status, do_gremlin_cancel
-from graph_notebook.sparql.query import get_query_type, do_sparql_query, do_sparql_explain, SPARQL_ACTION
-from graph_notebook.sparql.status import do_sparql_status, do_sparql_cancel
-from graph_notebook.system.database_reset import perform_database_reset, initiate_database_reset
+from graph_notebook.visualization.sparql_rows_and_columns import get_rows_and_columns
 from graph_notebook.visualization.template_retriever import retrieve_template
-from graph_notebook.gremlin.client_provider.factory import create_client_provider
-from graph_notebook.request_param_generator.factory import create_request_generator
-from graph_notebook.loader.load import do_load, get_loader_jobs, get_load_status, cancel_load, VALID_FORMATS, \
-    PARALLELISM_OPTIONS, PARALLELISM_HIGH, FINAL_LOAD_STATUSES, LOAD_JOB_MODES, MODE_AUTO
 from graph_notebook.configuration.get_config import get_config, get_config_from_dict
 from graph_notebook.seed.load_query import get_data_sets, get_queries
-from graph_notebook.status.get_status import get_status
 from graph_notebook.widgets import Force
 from graph_notebook.options import OPTIONS_DEFAULT_DIRECTED, vis_options_merge
@@ -91,6 +86,28 @@ def str_to_query_mode(s: str) -> QueryMode:
     return QueryMode.DEFAULT


+ACTION_TO_QUERY_TYPE = {
+    'sparql': 'application/sparql-query',
+    'sparqlupdate': 'application/sparql-update'
+}
+
+
+def get_query_type(query):
+    s = SPARQLWrapper('')
+    s.setQuery(query)
+    return s.queryType
+
+
+def query_type_to_action(query_type):
+    query_type = query_type.upper()
+    if query_type in ['SELECT', 'CONSTRUCT', 'ASK', 'DESCRIBE']:
+        return 'sparql'
+    else:
+        # TODO: check explicitly for all query types, raise exception for invalid query
+        return 'sparqlupdate'
+
+
+# TODO: refactor large magic commands into their own modules like what we do with %neptune_ml
 # noinspection PyTypeChecker
 @magics_class
 class Graph(Magics):
@@ -98,17 +115,35 @@ def __init__(self, shell):
         # You must call the parent constructor
         super(Graph, self).__init__(shell)

+        self.graph_notebook_config = generate_default_config()
         try:
             self.config_location = os.getenv('GRAPH_NOTEBOOK_CONFIG', DEFAULT_CONFIG_LOCATION)
+            self.client: Client = None
             self.graph_notebook_config = get_config(self.config_location)
         except FileNotFoundError:
-            self.graph_notebook_config = generate_default_config()
             print(
                 'Could not find a valid configuration. Do not forget to validate your settings using %graph_notebook_config')

+        self.max_results = DEFAULT_MAX_RESULTS
         self.graph_notebook_vis_options = OPTIONS_DEFAULT_DIRECTED
+        self._generate_client_from_config(self.graph_notebook_config)
         logger.setLevel(logging.ERROR)

+    def _generate_client_from_config(self, config: Configuration):
+        if self.client:
+            self.client.close()
+
+        builder = ClientBuilder() \
+            .with_host(config.host) \
+            .with_port(config.port) \
+            .with_region(config.aws_region) \
+            .with_tls(config.ssl) \
+            .with_sparql_path(config.sparql.path)
+        if config.auth_mode == AuthModeEnum.IAM:
+            builder = builder.with_iam(get_session())
+
+        self.client = builder.build()
+
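`_generate_client_from_config` above is the one place a `Client` gets built; outside the magics, the same builder can be used directly. A hedged sketch, with the host and region as placeholders and every method name taken from the diff itself:

```python
from botocore.session import get_session

from graph_notebook.neptune.client import ClientBuilder

builder = ClientBuilder() \
    .with_host('change-me.cluster-abc.us-east-1.neptune.amazonaws.com') \
    .with_port(8182) \
    .with_region('us-east-1') \
    .with_tls(True) \
    .with_sparql_path('sparql')
builder = builder.with_iam(get_session())  # only needed for IAM-auth clusters

client = builder.build()
status_res = client.status()
status_res.raise_for_status()
print(status_res.json())
```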
     @line_cell_magic
     @display_exceptions
     def graph_notebook_config(self, line='', cell=''):
@@ -116,6 +151,7 @@ def graph_notebook_config(self, line='', cell=''):
             data = json.loads(cell)
             config = get_config_from_dict(data)
             self.graph_notebook_config = config
+            self._generate_client_from_config(config)
             print('set notebook config to:')
             print(json.dumps(self.graph_notebook_config.to_dict(), indent=2))
         elif line == 'reset':
@@ -143,6 +179,7 @@ def graph_notebook_host(self, line):

         # TODO: we should attempt to make a status call to this host before we set the config to this value.
         self.graph_notebook_config.host = line
+        self._generate_client_from_config(self.graph_notebook_config)
         print(f'set host to {line}')

     @cell_magic
@@ -155,26 +192,24 @@ def sparql(self, line='', cell='', local_ns: dict = None):
         parser.add_argument('--path', '-p', default='',
                             help='prefix path to sparql endpoint. For example, if "foo/bar" were specified, the endpoint called would be host:port/foo/bar')
         parser.add_argument('--expand-all', action='store_true')
-
-        request_generator = create_request_generator(self.graph_notebook_config.auth_mode,
-                                                     self.graph_notebook_config.iam_credentials_provider_type,
-                                                     command='sparql')
+        parser.add_argument('--explain-type', default='dynamic',
+                            help='explain mode to use when using the explain query mode',
+                            choices=['dynamic', 'static', 'details'])
+        parser.add_argument('--explain-format', default='text/html', help='response format for explain query mode',
+                            choices=['text/csv', 'text/html'])
         parser.add_argument('--store-to', type=str, default='', help='store query result to this variable')
         args = parser.parse_args(line.split())
         mode = str_to_query_mode(args.query_mode)
         tab = widgets.Tab()

         path = args.path if args.path != '' else self.graph_notebook_config.sparql.path
-
         logger.debug(f'using mode={mode}')
         if mode == QueryMode.EXPLAIN:
-            res = do_sparql_explain(cell, self.graph_notebook_config.host, self.graph_notebook_config.port,
-                                    self.graph_notebook_config.ssl, request_generator, path=path)
-            store_to_ns(args.store_to, res, local_ns)
-            if 'error' in res:
-                html = error_template.render(error=json.dumps(res['error'], indent=2))
-            else:
-                html = sparql_explain_template.render(table=res)
+            res = self.client.sparql_explain(cell, args.explain_type, args.explain_format, path=path)
+            res.raise_for_status()
+            explain = res.content.decode('utf-8')
+            store_to_ns(args.store_to, explain, local_ns)
+            html = sparql_explain_template.render(table=explain)
             explain_output = widgets.Output(layout=DEFAULT_LAYOUT)
             with explain_output:
                 display(HTML(html))
@@ -186,8 +221,10 @@ def sparql(self, line='', cell='', local_ns: dict = None):
             query_type = get_query_type(cell)
             headers = {} if query_type not in ['SELECT', 'CONSTRUCT', 'DESCRIBE'] else {
                 'Accept': 'application/sparql-results+json'}
-            res = do_sparql_query(cell, self.graph_notebook_config.host, self.graph_notebook_config.port,
-                                  self.graph_notebook_config.ssl, request_generator, headers, path=path)
+
+            query_res = self.client.sparql(cell, path=path, headers=headers)
+            query_res.raise_for_status()
+            res = query_res.json()
             store_to_ns(args.store_to, res, local_ns)
             titles = []
             children = []
@@ -207,8 +244,7 @@ def sparql(self, line='', cell='', local_ns: dict = None):
                 titles.append('Table')
                 children.append(hbox)

-                expand_all = line == '--expand-all'
-                sn = SPARQLNetwork(expand_all=expand_all)
+                sn = SPARQLNetwork(expand_all=args.expand_all)
                 sn.extract_prefix_declarations_from_query(cell)
                 try:
                     sn.add_results(res)
@@ -268,20 +304,18 @@ def sparql_status(self, line='', local_ns: dict = None):
         parser.add_argument('--store-to', type=str, default='', help='store query result to this variable')
         args = parser.parse_args(line.split())

-        request_generator = create_request_generator(self.graph_notebook_config.auth_mode,
-                                                     self.graph_notebook_config.iam_credentials_provider_type)
-
         if not args.cancelQuery:
-            res = do_sparql_status(self.graph_notebook_config.host, self.graph_notebook_config.port,
-                                   self.graph_notebook_config.ssl, request_generator, args.queryId)
-
+            status_res = self.client.sparql_status(args.queryId)
+            status_res.raise_for_status()
+            res = status_res.json()
         else:
             if args.queryId == '':
                 print(SPARQL_CANCEL_HINT_MSG)
                 return
             else:
-                res = do_sparql_cancel(self.graph_notebook_config.host, self.graph_notebook_config.port,
-                                       self.graph_notebook_config.ssl, request_generator, args.queryId, args.silent)
+                cancel_res = self.client.sparql_cancel(args.queryId, args.silent)
+                cancel_res.raise_for_status()
+                res = cancel_res.json()

         store_to_ns(args.store_to, res, local_ns)
         print(json.dumps(res, indent=2))
@@ -304,13 +338,11 @@ def gremlin(self, line, cell, local_ns: dict = None):
         tab = widgets.Tab()

         if mode == QueryMode.EXPLAIN:
-            request_generator = create_request_generator(self.graph_notebook_config.auth_mode,
-                                                         self.graph_notebook_config.iam_credentials_provider_type)
-
-            query_res = do_gremlin_explain(cell, self.graph_notebook_config.host, self.graph_notebook_config.port,
-                                           self.graph_notebook_config.ssl, request_generator)
-            if 'explain' in query_res:
-                html = pre_container_template.render(content=query_res['explain'])
+            res = self.client.gremlin_explain(cell)
+            res.raise_for_status()
+            query_res = res.content.decode('utf-8')
+            if 'Neptune Gremlin Explain' in query_res:
+                html = pre_container_template.render(content=query_res)
             else:
                 html = pre_container_template.render(content='No explain found')

@@ -321,13 +353,11 @@ def gremlin(self, line, cell, local_ns: dict = None):
             tab.set_title(0, 'Explain')
             display(tab)
         elif mode == QueryMode.PROFILE:
-            request_generator = create_request_generator(self.graph_notebook_config.auth_mode,
-                                                         self.graph_notebook_config.iam_credentials_provider_type)
-
-            query_res = do_gremlin_profile(cell, self.graph_notebook_config.host, self.graph_notebook_config.port,
-                                           self.graph_notebook_config.ssl, request_generator)
-            if 'profile' in query_res:
-                html = pre_container_template.render(content=query_res['profile'])
+            res = self.client.gremlin_profile(cell)
+            res.raise_for_status()
+            query_res = res.content.decode('utf-8')
+            if 'Neptune Gremlin Profile' in query_res:
+                html = pre_container_template.render(content=query_res)
             else:
                 html = pre_container_template.render(content='No profile found')
             profile_output = widgets.Output(layout=DEFAULT_LAYOUT)
             with profile_output:
                 display(HTML(html))
             tab.set_title(0, 'Profile')
             display(tab)
         else:
-            client_provider = create_client_provider(self.graph_notebook_config.auth_mode,
-                                                     self.graph_notebook_config.iam_credentials_provider_type)
-            query_res = do_gremlin_query(cell, self.graph_notebook_config.host, self.graph_notebook_config.port,
-                                         self.graph_notebook_config.ssl, client_provider)
+            query_res = self.client.gremlin_query(cell)
             children = []
             titles = []

@@ -393,24 +420,18 @@ def gremlin_status(self, line='', local_ns: dict = None):
         parser.add_argument('--store-to', type=str, default='', help='store query result to this variable')
         args = parser.parse_args(line.split())

-        request_generator = create_request_generator(self.graph_notebook_config.auth_mode,
-                                                     self.graph_notebook_config.iam_credentials_provider_type)
-
         if not args.cancelQuery:
-            res = do_gremlin_status(self.graph_notebook_config.host,
-                                    self.graph_notebook_config.port,
-                                    self.graph_notebook_config.ssl, self.graph_notebook_config.auth_mode,
-                                    request_generator, args.queryId, args.includeWaiting)
-
+            status_res = self.client.gremlin_status(args.queryId)
+            status_res.raise_for_status()
+            res = status_res.json()
         else:
             if args.queryId == '':
                 print(GREMLIN_CANCEL_HINT_MSG)
                 return
             else:
-                res = do_gremlin_cancel(self.graph_notebook_config.host, self.graph_notebook_config.port,
-                                        self.graph_notebook_config.ssl, self.graph_notebook_config.auth_mode,
-                                        request_generator, args.queryId)
-
+                cancel_res = self.client.gremlin_cancel(args.queryId)
+                cancel_res.raise_for_status()
+                res = cancel_res.json()
         print(json.dumps(res, indent=2))
         store_to_ns(args.store_to, res, local_ns)

@@ -418,41 +439,34 @@ def gremlin_status(self, line='', local_ns: dict = None):
     @display_exceptions
     def status(self, line):
         logger.info(f'calling for status on endpoint {self.graph_notebook_config.host}')
-        request_generator = create_request_generator(self.graph_notebook_config.auth_mode,
-                                                     self.graph_notebook_config.iam_credentials_provider_type)
-        logger.info(
-            f'used credentials_provider_mode={self.graph_notebook_config.iam_credentials_provider_type.name} and auth_mode={self.graph_notebook_config.auth_mode.name} to make status request')
-
-        res = get_status(self.graph_notebook_config.host, self.graph_notebook_config.port,
-                         self.graph_notebook_config.ssl, request_generator)
+        status_res = self.client.status()
+        status_res.raise_for_status()
+        res = status_res.json()
         logger.info(f'got the response {res}')

         return res

     @line_magic
     @display_exceptions
     def db_reset(self, line):
-        host = self.graph_notebook_config.host
-        port = self.graph_notebook_config.port
-        ssl = self.graph_notebook_config.ssl
-
-        logger.info(f'calling system endpoint {host}')
+        logger.info(f'calling system endpoint {self.client.host}')
         parser = argparse.ArgumentParser()
         parser.add_argument('-g', '--generate-token', action='store_true', help='generate token for database reset')
-        parser.add_argument('-t', '--token', nargs=1, default='', help='perform database reset with given token')
+        parser.add_argument('-t', '--token', default='', help='perform database reset with given token')
         parser.add_argument('-y', '--yes', action='store_true', help='skip the prompt and perform database reset')
         args = parser.parse_args(line.split())
         generate_token = args.generate_token
         skip_prompt = args.yes
-        request_generator = create_request_generator(self.graph_notebook_config.auth_mode,
-                                                     self.graph_notebook_config.iam_credentials_provider_type)
-        logger.info(
-            f'used credentials_provider_mode={self.graph_notebook_config.iam_credentials_provider_type.name} and auth_mode={self.graph_notebook_config.auth_mode.name} to make system request')
         if generate_token is False and args.token == '':
             if skip_prompt:
-                res = initiate_database_reset(host, port, ssl, request_generator)
+                initiate_res = self.client.initiate_reset()
+                initiate_res.raise_for_status()
+                res = initiate_res.json()
                 token = res['payload']['token']
-                res = perform_database_reset(token, host, port, ssl, request_generator)
+
+                perform_reset_res = self.client.perform_reset(token)
+                perform_reset_res.raise_for_status()
                 logger.info(f'got the response {res}')
+                res = perform_reset_res.json()
                 return res

             output = widgets.Output()
@@ -473,7 +487,9 @@ def db_reset(self, line):
             display(text_hbox, check_box, button_hbox, output)

             def on_button_delete_clicked(b):
-                result = initiate_database_reset(host, port, ssl, request_generator)
+                initiate_res = self.client.initiate_reset()
+                initiate_res.raise_for_status()
+                result = initiate_res.json()

                 text_hbox.close()
                 check_box.close()
@@ -492,7 +508,9 @@ def on_button_delete_clicked(b):
                         print(result)
                         return

-                result = perform_database_reset(token, host, port, ssl, request_generator)
+                perform_reset_res = self.client.perform_reset(token)
+                perform_reset_res.raise_for_status()
+                result = perform_reset_res.json()

                 if 'status' not in result or result['status'] != '200 OK':
                     with output:
@@ -522,7 +540,9 @@ def on_button_delete_clicked(b):
                         display_html(HTML(loading_wheel_html))
                     try:
                         retry -= 1
-                        interval_check_response = get_status(host, port, ssl, request_generator)
+                        status_res = self.client.status()
+                        status_res.raise_for_status()
+                        interval_check_response = status_res.json()
                     except Exception as e:
                         # Exception is expected when database is resetting, continue waiting
                         with job_status_output:
@@ -560,10 +580,14 @@ def on_button_cancel_clicked(b):
             button_cancel.on_click(on_button_cancel_clicked)
             return
         elif generate_token:
-            res = initiate_database_reset(host, port, ssl, request_generator)
+            initiate_res = self.client.initiate_reset()
+            initiate_res.raise_for_status()
+            res = initiate_res.json()
         else:
             # args.token is an array of a single string, e.g., args.token=['ade-23-c23'], use index 0 to take the string
-            res = perform_database_reset(args.token[0], host, port, ssl, request_generator)
+            perform_res = self.client.perform_reset(args.token)
+            perform_res.raise_for_status()
+            res = perform_res.json()

         logger.info(f'got the response {res}')
         return res
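The reset flow above is now a two-step exchange on the client: fetch a token, then redeem it. A condensed sketch of just those calls, mirroring the diff:

```python
# assumes a `client` built as in _generate_client_from_config above;
# this wipes the database, exactly like the %db_reset magic it mirrors
initiate_res = client.initiate_reset()
initiate_res.raise_for_status()
token = initiate_res.json()['payload']['token']

perform_res = client.perform_reset(token)
perform_res.raise_for_status()
print(perform_res.json())
```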
@@ -572,6 +596,7 @@ def on_button_cancel_clicked(b):
     @cell_magic
     @needs_local_scope
     @display_exceptions
     def load(self, line='', local_ns: dict = None):
+        # TODO: change widgets to let any arbitrary inputs be added by users
         parser = argparse.ArgumentParser()
         parser.add_argument('-s', '--source', default='s3://')
         parser.add_argument('-l', '--loader-arn', default=self.graph_notebook_config.load_from_s3_arn)
@@ -588,16 +613,7 @@ def load(self, line='', local_ns: dict = None):
         args = parser.parse_args(line.split())

-        # since this can be a long-running task, freezing variables in the case
-        # that a user alters them in another command.
-        host = self.graph_notebook_config.host
-        port = self.graph_notebook_config.port
-        ssl = self.graph_notebook_config.ssl
-
-        credentials_provider_mode = self.graph_notebook_config.iam_credentials_provider_type
-        request_generator = create_request_generator(self.graph_notebook_config.auth_mode, credentials_provider_mode)
         region = self.graph_notebook_config.aws_region
-
         button = widgets.Button(description="Submit")
         output = widgets.Output()
         source = widgets.Text(
@@ -708,7 +724,8 @@ def on_button_clicked(b):
                 if not len(dependencies_list) < 64:
                     validated = False
-                    dep_validation_label = widgets.HTML('A maximum of 64 jobs may be queued at once.')
+                    dep_validation_label = widgets.HTML(
+                        'A maximum of 64 jobs may be queued at once.')
                     dep_hbox.children += (dep_validation_label,)

             if not validated:
                 return

             source_exp = os.path.expandvars(
                 source.value)  # replace any env variables in source.value with their values, can use $foo or ${foo}. Particularly useful for ${AWS_REGION}
             logger.info(f'using source_exp: {source_exp}')
             try:
-                load_result = do_load(host, port, source_format.value, ssl, str(source_exp), region_box.value,
-                                      arn.value, fail_on_error.value, request_generator, mode=mode.value,
-                                      parallelism=parallelism.value, update_single_cardinality=update_single_cardinality.value,
-                                      queue_request=queue_request.value, dependencies=dependencies_list)
+                kwargs = {
+                    'failOnError': fail_on_error.value,
+                    'parallelism': parallelism.value,
+                    'updateSingleCardinalityProperties': update_single_cardinality.value,
+                    'queueRequest': queue_request.value
+                }
+
+                if dependencies:
+                    kwargs['dependencies'] = dependencies_list
+
+                load_res = self.client.load(source.value, source_format.value, arn.value, region_box.value, **kwargs)
+                load_res.raise_for_status()
+                load_result = load_res.json()
                 store_to_ns(args.store_to, load_result, local_ns)

                 source_hbox.close()
@@ -767,8 +792,9 @@ def on_button_clicked(b):
                     with job_status_output:
                         display_html(HTML(loading_wheel_html))
                     try:
-                        interval_check_response = get_load_status(host, port, ssl, request_generator,
-                                                                  load_result['payload']['loadId'])
+                        load_status_res = self.client.load_status(load_result['payload']['loadId'])
+                        load_status_res.raise_for_status()
+                        interval_check_response = load_status_res.json()
                     except Exception as e:
                         logger.error(e)
                         with job_status_output:
@@ -806,10 +832,9 @@ def load_ids(self, line, local_ns: dict = None):
         parser.add_argument('--store-to', type=str, default='')
         args = parser.parse_args(line.split())

-        credentials_provider_mode = self.graph_notebook_config.iam_credentials_provider_type
-        request_generator = create_request_generator(self.graph_notebook_config.auth_mode, credentials_provider_mode)
-        res = get_loader_jobs(self.graph_notebook_config.host, self.graph_notebook_config.port,
-                              self.graph_notebook_config.ssl, request_generator)
+        load_status = self.client.load_status()
+        load_status.raise_for_status()
+        res = load_status.json()
         ids = []
         if 'payload' in res and 'loadIds' in res['payload']:
             ids = res['payload']['loadIds']
@@ -834,15 +859,21 @@ def load_status(self, line, local_ns: dict = None):
         parser.add_argument('--store-to', type=str, default='')
         parser.add_argument('--details', action='store_true', default=False)
         parser.add_argument('--errors', action='store_true', default=False)
-        parser.add_argument('--page', '-p', default='1', help='The error page number. Only valid when the --errors option is set.')
-        parser.add_argument('--errorsPerPage', '-e', default='10', help='The number of errors per each page. Only valid when the --errors option is set.')
+        parser.add_argument('--page', '-p', default='1',
+                            help='The error page number. Only valid when the --errors option is set.')
+        parser.add_argument('--errorsPerPage', '-e', default='10',
+                            help='The number of errors per each page. Only valid when the --errors option is set.')
         args = parser.parse_args(line.split())

-        credentials_provider_mode = self.graph_notebook_config.iam_credentials_provider_type
-        request_generator = create_request_generator(self.graph_notebook_config.auth_mode, credentials_provider_mode)
-        res = get_load_status(self.graph_notebook_config.host, self.graph_notebook_config.port,
-                              self.graph_notebook_config.ssl, request_generator, args.load_id, args.details, args.errors,
-                              args.page, args.errorsPerPage)
+        payload = {
+            'details': args.details,
+            'errors': args.errors,
+            'page': args.page,
+            'errorsPerPage': args.errorsPerPage
+        }
+
+        load_status_res = self.client.load_status(args.load_id, **payload)
+        load_status_res.raise_for_status()
+        res = load_status_res.json()
         print(json.dumps(res, indent=2))

         if args.store_to != '' and local_ns is not None:
@@ -857,10 +888,9 @@ def cancel_load(self, line, local_ns: dict = None):
         parser.add_argument('--store-to', type=str, default='')
         args = parser.parse_args(line.split())

-        credentials_provider_mode = self.graph_notebook_config.iam_credentials_provider_type
-        request_generator = create_request_generator(self.graph_notebook_config.auth_mode, credentials_provider_mode)
-        res = cancel_load(self.graph_notebook_config.host, self.graph_notebook_config.port,
-                          self.graph_notebook_config.ssl, request_generator, args.load_id)
+        cancel_res = self.client.cancel_load(args.load_id)
+        cancel_res.raise_for_status()
+        res = cancel_res.json()
         if res:
             print('Cancelled successfully.')
         else:
@@ -875,7 +905,7 @@ def seed(self, line):
         parser = argparse.ArgumentParser()
         parser.add_argument('--language', type=str, default='', choices=SEED_LANGUAGE_OPTIONS)
         parser.add_argument('--dataset', type=str, default='')
-        # TODO: Gremlin paths are not yet supported.
+        # TODO: Gremlin api paths are not yet supported.
         parser.add_argument('--path', '-p', default=SPARQL_ACTION,
                             help='prefix path to query endpoint. For example, "foo/bar". The queried path would then be host:port/foo/bar for sparql seed commands')
         parser.add_argument('--run', action='store_true')
@@ -938,21 +968,11 @@ def on_button_clicked(b=None):
             for q in queries:
                 with output:
                     print(f'{progress.value}/{len(queries)}:\t{q["name"]}')
-                # Just like with the load command, seed is long-running
-                # as such, we want to obtain the values of host, port, etc. in case they
-                # change during execution.
-                host = self.graph_notebook_config.host
-                port = self.graph_notebook_config.port
-                auth_mode = self.graph_notebook_config.auth_mode
-                ssl = self.graph_notebook_config.ssl
-
                 if language == 'gremlin':
-                    client_provider = create_client_provider(auth_mode,
-                                                             self.graph_notebook_config.iam_credentials_provider_type)
                     # IMPORTANT: We treat each line as its own query!
                     for line in q['content'].splitlines():
                         try:
-                            do_gremlin_query(line, host, port, ssl, client_provider)
+                            self.client.gremlin_query(line)
                         except GremlinServerError as gremlinEx:
                             try:
                                 error = json.loads(gremlinEx.args[0][5:])  # remove the leading error code.
@@ -975,10 +995,8 @@ def on_button_clicked(b=None): progress.close() return else: - request_generator = create_request_generator(auth_mode, - self.graph_notebook_config.iam_credentials_provider_type) try: - do_sparql_query(q['content'], host, port, ssl, request_generator, path=args.path) + self.client.sparql(q['content'], path=args.path) except HTTPError as httpEx: # attempt to turn response into json try: @@ -1053,11 +1071,9 @@ def neptune_ml(self, line, cell='', local_ns: dict = None): parser = generate_neptune_ml_parser() args = parser.parse_args(line.split()) logger.info(f'received call to neptune_ml with details: {args.__dict__}, cell={cell}, local_ns={local_ns}') - request_generator = create_request_generator(self.graph_notebook_config.auth_mode, - self.graph_notebook_config.iam_credentials_provider_type) main_output = widgets.Output() display(main_output) - res = neptune_ml_magic_handler(args, request_generator, self.graph_notebook_config, main_output, cell, local_ns) + res = neptune_ml_magic_handler(args, self.client, main_output, cell, local_ns) message = json.dumps(res, indent=2) if type(res) is dict else res store_to_ns(args.store_to, res, local_ns) with main_output: diff --git a/src/graph_notebook/magics/ml.py b/src/graph_notebook/magics/ml.py index b40db040..91e42474 100644 --- a/src/graph_notebook/magics/ml.py +++ b/src/graph_notebook/magics/ml.py @@ -6,12 +6,8 @@ from IPython.core.display import display from ipywidgets import widgets -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import credentials_provider_factory -from graph_notebook.authentication.iam_credentials_provider.credentials_provider import Credentials -from graph_notebook.configuration.generate_config import Configuration, AuthModeEnum from graph_notebook.magics.parsing import str_to_namespace_var -from graph_notebook.ml.sagemaker import start_export, get_export_status, start_processing_job, get_processing_status, \ - start_training, get_training_status, start_create_endpoint, get_endpoint_status, EXPORT_SERVICE_NAME +from graph_notebook.neptune.client import Client logger = logging.getLogger("neptune_ml_magic_handler") @@ -146,17 +142,26 @@ def generate_neptune_ml_parser(): return parser -def neptune_ml_export_start(params, export_url: str, export_ssl: bool = True, creds: Credentials = None): +def neptune_ml_export_start(client: Client, params, export_url: str, export_ssl: bool = True): if type(params) is str: params = json.loads(params) - job = start_export(export_url, params, export_ssl, creds) + export_res = client.export(export_url, params, export_ssl) + export_res.raise_for_status() + job = export_res.json() return job -def wait_for_export(export_url: str, job_id: str, output: widgets.Output, +def neptune_ml_export_status(client: Client, export_url: str, job_id: str, export_ssl: bool = True): + res = client.export_status(export_url, job_id, export_ssl) + res.raise_for_status() + job = res.json() + return job + + +def wait_for_export(client: Client, export_url: str, job_id: str, output: widgets.Output, export_ssl: bool = True, wait_interval: int = DEFAULT_WAIT_INTERVAL, - wait_timeout: int = DEFAULT_WAIT_TIMEOUT, creds: Credentials = None): + wait_timeout: int = DEFAULT_WAIT_TIMEOUT): job_id_output = widgets.Output() update_widget_output = widgets.Output() with output: @@ -170,7 +175,9 @@ def wait_for_export(export_url: str, job_id: str, output: widgets.Output, while datetime.datetime.utcnow() - beginning_time < (datetime.timedelta(seconds=wait_timeout)): 
update_widget_output.clear_output() print('Checking for latest status...') - export_status = get_export_status(export_url, export_ssl, job_id, creds) + status_res = client.export_status(export_url, job_id, export_ssl) + status_res.raise_for_status() + export_status = status_res.json() if export_status['status'] in ['succeeded', 'failed']: print('Export is finished') return export_status @@ -180,32 +187,30 @@ def wait_for_export(export_url: str, job_id: str, output: widgets.Output, time.sleep(wait_interval) -def neptune_ml_export(args: argparse.Namespace, config: Configuration, output: widgets.Output, cell: str): - auth_mode = AuthModeEnum.IAM if args.export_iam else AuthModeEnum.DEFAULT - creds = None - if auth_mode == AuthModeEnum.IAM: - creds = credentials_provider_factory(config.iam_credentials_provider_type).get_iam_credentials() - +def neptune_ml_export(args: argparse.Namespace, client: Client, output: widgets.Output, + cell: str): export_ssl = not args.export_no_ssl if args.which_sub == 'start': if cell == '': return 'Cell body must have json payload or reference notebook variable using syntax ${payload_var}' - export_job = neptune_ml_export_start(cell, args.export_url, export_ssl, creds) + export_job = neptune_ml_export_start(client, cell, args.export_url, export_ssl) if args.wait: - return wait_for_export(args.export_url, export_job['jobId'], - output, export_ssl, args.wait_interval, args.wait_timeout, creds) + return wait_for_export(client, args.export_url, export_job['jobId'], + output, export_ssl, args.wait_interval, args.wait_timeout) else: return export_job elif args.which_sub == 'status': if args.wait: - status = wait_for_export(args.export_url, args.job_id, output, export_ssl, args.wait_interval, - args.wait_timeout, creds) + status = wait_for_export(client, args.export_url, args.job_id, output, export_ssl, args.wait_interval, + args.wait_timeout) else: - status = get_export_status(args.export_url, export_ssl, args.job_id, creds) + status_res = client.export_status(args.export_url, args.job_id, export_ssl) + status_res.raise_for_status() + status = status_res.json() return status -def wait_for_dataprocessing(job_id: str, config: Configuration, request_param_generator, output: widgets.Output, +def wait_for_dataprocessing(job_id: str, client: Client, output: widgets.Output, wait_interval: int = DEFAULT_WAIT_INTERVAL, wait_timeout: int = DEFAULT_WAIT_TIMEOUT): job_id_output = widgets.Output() update_status_output = widgets.Output() @@ -219,7 +224,9 @@ def wait_for_dataprocessing(job_id: str, config: Configuration, request_param_ge beginning_time = datetime.datetime.utcnow() while datetime.datetime.utcnow() - beginning_time < (datetime.timedelta(seconds=wait_timeout)): update_status_output.clear_output() - status = get_processing_status(config.host, str(config.port), config.ssl, request_param_generator, job_id) + status_res = client.dataprocessing_job_status(job_id) + status_res.raise_for_status() + status = status_res.json() if status['status'] in ['Completed', 'Failed']: print('Data processing is finished') return status @@ -229,37 +236,34 @@ def wait_for_dataprocessing(job_id: str, config: Configuration, request_param_ge time.sleep(wait_interval) -def neptune_ml_dataprocessing(args: argparse.Namespace, request_param_generator, output: widgets.Output, - config: Configuration, params: dict = None): +def neptune_ml_dataprocessing(args: argparse.Namespace, client, output: widgets.Output, params: dict = None): if args.which_sub == 'start': if params is None or params == '' or 
params == {}: params = { - 'inputDataS3Location': args.s3_input_uri, - 'processedDataS3Location': args.s3_processed_uri, 'id': args.job_id, 'configFileName': args.config_file_name } - processing_job = start_processing_job(config.host, str(config.port), config.ssl, - request_param_generator, params) + processing_job_res = client.dataprocessing_start(args.s3_input_uri, args.s3_processed_uri, **params) + processing_job_res.raise_for_status() + processing_job = processing_job_res.json() job_id = params['id'] if args.wait: - return wait_for_dataprocessing(job_id, config, request_param_generator, - output, args.wait_interval, args.wait_timeout) + return wait_for_dataprocessing(job_id, client, output, args.wait_interval, args.wait_timeout) else: return processing_job elif args.which_sub == 'status': if args.wait: - return wait_for_dataprocessing(args.job_id, config, request_param_generator, output, args.wait_interval, - args.wait_timeout) + return wait_for_dataprocessing(args.job_id, client, output, args.wait_interval, args.wait_timeout) else: - return get_processing_status(config.host, str(config.port), config.ssl, request_param_generator, - args.job_id) + processing_status = client.dataprocessing_job_status(args.job_id) + processing_status.raise_for_status() + return processing_status.json() else: return f'Sub parser "{args.which} {args.which_sub}" was not recognized' -def wait_for_training(job_id: str, config: Configuration, request_param_generator, output: widgets.Output, +def wait_for_training(job_id: str, client: Client, output: widgets.Output, wait_interval: int = DEFAULT_WAIT_INTERVAL, wait_timeout: int = DEFAULT_WAIT_TIMEOUT): job_id_output = widgets.Output() update_status_output = widgets.Output() @@ -273,7 +277,9 @@ def wait_for_training(job_id: str, config: Configuration, request_param_generato beginning_time = datetime.datetime.utcnow() while datetime.datetime.utcnow() - beginning_time < (datetime.timedelta(seconds=wait_timeout)): update_status_output.clear_output() - status = get_training_status(config.host, str(config.port), config.ssl, request_param_generator, job_id) + training_status_res = client.modeltraining_job_status(job_id) + training_status_res.raise_for_status() + status = training_status_res.json() if status['status'] in ['Completed', 'Failed']: print('Training is finished') return status @@ -283,35 +289,34 @@ def wait_for_training(job_id: str, config: Configuration, request_param_generato time.sleep(wait_interval) -def neptune_ml_training(args: argparse.Namespace, request_param_generator, config: Configuration, - output: widgets.Output, params): +def neptune_ml_training(args: argparse.Namespace, client: Client, output: widgets.Output, params): if args.which_sub == 'start': if params is None or params == '' or params == {}: params = { "id": args.job_id, "dataProcessingJobId": args.data_processing_id, "trainingInstanceType": args.instance_type, - "trainModelS3Location": args.s3_output_uri } - training_job = start_training(config.host, str(config.port), config.ssl, request_param_generator, params) + start_training_res = client.modeltraining_start(args.job_id, args.s3_output_uri, **params) + start_training_res.raise_for_status() + training_job = start_training_res.json() if args.wait: - return wait_for_training(training_job['id'], config, request_param_generator, output, args.wait_interval, - args.wait_timeout) + return wait_for_training(training_job['id'], client, output, args.wait_interval, args.wait_timeout) else: return training_job elif args.which_sub == 'status': if 
args.wait: - return wait_for_training(args.job_id, config, request_param_generator, output, args.wait_interval, - args.wait_timeout) + return wait_for_training(args.job_id, client, output, args.wait_interval, args.wait_timeout) else: - return get_training_status(config.host, str(config.port), config.ssl, request_param_generator, - args.job_id) + training_status_res = client.modeltraining_job_status(args.job_id) + training_status_res.raise_for_status() + return training_status_res.json() else: return f'Sub parser "{args.which} {args.which_sub}" was not recognized' -def wait_for_endpoint(job_id: str, config: Configuration, request_param_generator, output: widgets.Output, +def wait_for_endpoint(job_id: str, client: Client, output: widgets.Output, wait_interval: int = DEFAULT_WAIT_INTERVAL, wait_timeout: int = DEFAULT_WAIT_TIMEOUT): job_id_output = widgets.Output() update_status_output = widgets.Output() @@ -325,7 +330,9 @@ def wait_for_endpoint(job_id: str, config: Configuration, request_param_generato beginning_time = datetime.datetime.utcnow() while datetime.datetime.utcnow() - beginning_time < (datetime.timedelta(seconds=wait_timeout)): update_status_output.clear_output() - status = get_endpoint_status(config.host, str(config.port), config.ssl, request_param_generator, job_id) + endpoint_status_res = client.endpoints_status(job_id) + endpoint_status_res.raise_for_status() + status = endpoint_status_res.json() if status['status'] in ['InService', 'Failed']: print('Endpoint creation is finished') return status @@ -335,47 +342,44 @@ def wait_for_endpoint(job_id: str, config: Configuration, request_param_generato time.sleep(wait_interval) -def neptune_ml_endpoint(args: argparse.Namespace, request_param_generator, - config: Configuration, output: widgets.Output, params): +def neptune_ml_endpoint(args: argparse.Namespace, client: Client, output: widgets.Output, params): if args.which_sub == 'create': if params is None or params == '' or params == {}: params = { "id": args.job_id, - "mlModelTrainingJobId": args.model_job_id, 'instanceType': args.instance_type } - create_endpoint_job = start_create_endpoint(config.host, str(config.port), config.ssl, - request_param_generator, params) - + create_endpoint_res = client.endpoints_create(args.model_job_id, **params) + create_endpoint_res.raise_for_status() + create_endpoint_job = create_endpoint_res.json() if args.wait: - return wait_for_endpoint(create_endpoint_job['id'], config, request_param_generator, output, - args.wait_interval, args.wait_timeout) + return wait_for_endpoint(create_endpoint_job['id'], client, output, args.wait_interval, args.wait_timeout) else: return create_endpoint_job elif args.which_sub == 'status': if args.wait: - return wait_for_endpoint(args.job_id, config, request_param_generator, output, - args.wait_interval, args.wait_timeout) + return wait_for_endpoint(args.job_id, client, output, args.wait_interval, args.wait_timeout) else: - return get_endpoint_status(config.host, str(config.port), config.ssl, request_param_generator, args.job_id) + endpoint_status = client.endpoints_status(args.job_id) + endpoint_status.raise_for_status() + return endpoint_status.json() else: return f'Sub parser "{args.which} {args.which_sub}" was not recognized' -def neptune_ml_magic_handler(args, request_param_generator, config: Configuration, output: widgets.Output, - cell: str = '', local_ns: dict = None) -> any: +def neptune_ml_magic_handler(args, client: Client, output: widgets.Output, cell: str = '', local_ns: dict = None): if local_ns is 
None: local_ns = {} cell = str_to_namespace_var(cell, local_ns) if args.which == 'export': - return neptune_ml_export(args, config, output, cell) + return neptune_ml_export(args, client, output, cell) elif args.which == 'dataprocessing': - return neptune_ml_dataprocessing(args, request_param_generator, output, config, cell) + return neptune_ml_dataprocessing(args, client, output, cell) elif args.which == 'training': - return neptune_ml_training(args, request_param_generator, config, output, cell) + return neptune_ml_training(args, client, output, cell) elif args.which == 'endpoint': - return neptune_ml_endpoint(args, request_param_generator, config, output, cell) + return neptune_ml_endpoint(args, client, output, cell) else: return f'sub parser {args.which} was not recognized' diff --git a/src/graph_notebook/ml/sagemaker.py b/src/graph_notebook/ml/sagemaker.py deleted file mode 100644 index 71c4a59b..00000000 --- a/src/graph_notebook/ml/sagemaker.py +++ /dev/null @@ -1,86 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -import json -import requests -from requests_aws4auth import AWS4Auth - -from graph_notebook.authentication.iam_credentials_provider.credentials_provider import Credentials -from graph_notebook.request_param_generator.call_and_get_response import call_and_get_response - -EXPORT_SERVICE_NAME = 'execute-api' -EXPORT_ACTION = 'neptune-export' -EXTRA_HEADERS = {'content-type': 'application/json'} -UPDATE_DELAY_SECONDS = 60 - - -def start_export(export_host: str, export_params: dict, use_ssl: bool, - creds: Credentials = None) -> dict: - auth = None - if creds is not None: - auth = AWS4Auth(creds.key, creds.secret, creds.region, EXPORT_SERVICE_NAME, - session_token=creds.token) - - protocol = 'https' if use_ssl else 'http' - url = f'{protocol}://{export_host}/{EXPORT_ACTION}' - res = requests.post(url, json=export_params, headers=EXTRA_HEADERS, auth=auth) - res.raise_for_status() - job = res.json() - return job - - -def get_export_status(export_host: str, use_ssl: bool, job_id: str, creds: Credentials = None): - auth = None - if creds is not None: - auth = AWS4Auth(creds.key, creds.secret, creds.region, EXPORT_SERVICE_NAME, - session_token=creds.token) - - protocol = 'https' if use_ssl else 'http' - url = f'{protocol}://{export_host}/{EXPORT_ACTION}/{job_id}' - res = requests.get(url, headers=EXTRA_HEADERS, auth=auth) - res.raise_for_status() - job = res.json() - return job - - -def get_processing_status(host: str, port: str, use_ssl: bool, request_param_generator, job_name: str): - res = call_and_get_response('get', f'ml/dataprocessing/{job_name}', host, port, request_param_generator, - use_ssl, extra_headers=EXTRA_HEADERS) - status = res.json() - return status - - -def start_processing_job(host: str, port: str, use_ssl: bool, request_param_generator, params: dict): - params_raw = json.dumps(params) if type(params) is dict else params - res = call_and_get_response('post', 'ml/dataprocessing', host, port, request_param_generator, use_ssl, params_raw, - EXTRA_HEADERS) - job = res.json() - return job - - -def start_training(host: str, port: str, use_ssl: bool, request_param_generator, params): - params_raw = json.dumps(params) if type(params) is dict else params - res = call_and_get_response('post', 'ml/modeltraining', host, port, request_param_generator, use_ssl, params_raw, - EXTRA_HEADERS) - return res.json() - - -def get_training_status(host: str, port: str, use_ssl: bool, request_param_generator, 
training_job_name: str): - res = call_and_get_response('get', f'ml/modeltraining/{training_job_name}', host, port, - request_param_generator, use_ssl, extra_headers=EXTRA_HEADERS) - return res.json() - - -def start_create_endpoint(host: str, port: str, use_ssl: bool, request_param_generator, params): - params_raw = json.dumps(params) if type(params) is dict else params - res = call_and_get_response('post', 'ml/endpoints', host, port, request_param_generator, use_ssl, params_raw, - EXTRA_HEADERS) - return res.json() - - -def get_endpoint_status(host: str, port: str, use_ssl: bool, request_param_generator, training_job_name: str): - res = call_and_get_response('get', f'ml/endpoints/{training_job_name}', host, port, request_param_generator, - use_ssl, extra_headers=EXTRA_HEADERS) - return res.json() diff --git a/src/graph_notebook/authentication/__init__.py b/src/graph_notebook/neptune/__init__.py similarity index 100% rename from src/graph_notebook/authentication/__init__.py rename to src/graph_notebook/neptune/__init__.py diff --git a/src/graph_notebook/neptune/client.py b/src/graph_notebook/neptune/client.py new file mode 100644 index 00000000..5bad2701 --- /dev/null +++ b/src/graph_notebook/neptune/client.py @@ -0,0 +1,497 @@ +""" +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +SPDX-License-Identifier: Apache-2.0 +""" + +import json + +import botocore +import requests +from SPARQLWrapper import SPARQLWrapper +from boto3 import Session +from botocore.auth import SigV4Auth +from botocore.awsrequest import AWSRequest +from gremlin_python.driver import client +from tornado import httpclient + +import graph_notebook.neptune.gremlin.graphsonV3d0_MapType_objectify_patch # noqa F401 + +DEFAULT_SPARQL_CONTENT_TYPE = 'application/x-www-form-urlencoded' +DEFAULT_PORT = 8182 +DEFAULT_REGION = 'us-east-1' + +NEPTUNE_SERVICE_NAME = 'neptune-db' + +# TODO: Constants for states of each long-running job +# TODO: add doc links to each command + +FORMAT_CSV = 'csv' +FORMAT_NTRIPLE = 'ntriples' +FORMAT_NQUADS = 'nquads' +FORMAT_RDFXML = 'rdfxml' +FORMAT_TURTLE = 'turtle' + +PARALLELISM_LOW = 'LOW' +PARALLELISM_MEDIUM = 'MEDIUM' +PARALLELISM_HIGH = 'HIGH' +PARALLELISM_OVERSUBSCRIBE = 'OVERSUBSCRIBE' + +MODE_RESUME = 'RESUME' +MODE_NEW = 'NEW' +MODE_AUTO = 'AUTO' + +LOAD_JOB_MODES = [MODE_RESUME, MODE_NEW, MODE_AUTO] +VALID_FORMATS = [FORMAT_CSV, FORMAT_NTRIPLE, FORMAT_NQUADS, FORMAT_RDFXML, FORMAT_TURTLE] +PARALLELISM_OPTIONS = [PARALLELISM_LOW, PARALLELISM_MEDIUM, PARALLELISM_HIGH, PARALLELISM_OVERSUBSCRIBE] +LOADER_ACTION = 'loader' + +FINAL_LOAD_STATUSES = ['LOAD_COMPLETED', + 'LOAD_COMMITTED_W_WRITE_CONFLICTS', + 'LOAD_CANCELLED_BY_USER', + 'LOAD_CANCELLED_DUE_TO_ERRORS', + 'LOAD_FAILED', + 'LOAD_UNEXPECTED_ERROR', + 'LOAD_DATA_DEADLOCK', + 'LOAD_DATA_FAILED_DUE_TO_FEED_MODIFIED_OR_DELETED', + 'LOAD_S3_READ_ERROR', + 'LOAD_S3_ACCESS_DENIED_ERROR', + 'LOAD_IN_QUEUE', + 'LOAD_FAILED_BECAUSE_DEPENDENCY_NOT_SATISFIED', + 'LOAD_FAILED_INVALID_REQUEST', ] + +EXPORT_SERVICE_NAME = 'execute-api' +EXPORT_ACTION = 'neptune-export' +EXTRA_HEADERS = {'content-type': 'application/json'} +SPARQL_ACTION = 'sparql' + + +class Client(object): + def __init__(self, host: str, port: int = DEFAULT_PORT, ssl: bool = True, region: str = DEFAULT_REGION, + sparql_path: str = '/sparql', auth=None, session: Session = None): + self.host = host + self.port = port + self.ssl = ssl + self.sparql_path = sparql_path + self.region = region + self._auth = auth + self._session = session + + self._http_protocol = 'https' 
if self.ssl else 'http' + self._ws_protocol = 'wss' if self.ssl else 'ws' + + self._http_session = None + + def sparql_query(self, query: str, headers=None, explain: str = '', path: str = '') -> requests.Response: + if headers is None: + headers = {} + + data = {'query': query} + return self.do_sparql_request(data, headers, explain, path=path) + + def sparql_update(self, update: str, headers=None, explain: str = '', path: str = '') -> requests.Response: + if headers is None: + headers = {} + + data = {'update': update} + return self.do_sparql_request(data, headers, explain, path=path) + + def do_sparql_request(self, data: dict, headers=None, explain: str = '', path: str = ''): + if 'content-type' not in headers: + headers['content-type'] = DEFAULT_SPARQL_CONTENT_TYPE + + explain = explain.lower() + if explain != '': + if explain not in ['static', 'dynamic', 'details']: + raise ValueError('explain mode not valid, must be one of "static", "dynamic", or "details"') + else: + data['explain'] = explain + + sparql_path = path if path != '' else self.sparql_path + uri = f'{self._http_protocol}://{self.host}:{self.port}/{sparql_path}' + req = self._prepare_request('POST', uri, data=data, headers=headers) + res = self._http_session.send(req) + return res + + def sparql(self, query: str, headers=None, explain: str = '', path: str = '') -> requests.Response: + if headers is None: + headers = {} + + s = SPARQLWrapper('') + s.setQuery(query) + query_type = s.queryType.upper() + if query_type in ['SELECT', 'CONSTRUCT', 'ASK', 'DESCRIBE']: + return self.sparql_query(query, headers, explain, path=path) + else: + return self.sparql_update(query, headers, explain, path=path) + + # TODO: enum/constants for supported types + def sparql_explain(self, query: str, explain: str = 'dynamic', output_format: str = 'text/html', + headers=None, path: str = '') -> requests.Response: + if headers is None: + headers = {} + + if 'Accept' not in headers: + headers['Accept'] = output_format + + return self.sparql(query, headers, explain, path=path) + + def sparql_status(self, query_id: str = ''): + return self._query_status('sparql', query_id=query_id) + + def sparql_cancel(self, query_id: str, silent: bool = False): + if type(query_id) is not str or query_id == '': + raise ValueError('query_id must be a non-empty string') + return self._query_status('sparql', query_id=query_id, silent=silent, cancelQuery=True) + + def get_gremlin_connection(self) -> client.Client: + uri = f'{self._http_protocol}://{self.host}:{self.port}/gremlin' + request = self._prepare_request('GET', uri) + + ws_url = f'{self._ws_protocol}://{self.host}:{self.port}/gremlin' + ws_request = httpclient.HTTPRequest(ws_url, headers=dict(request.headers)) + return client.Client(ws_request, 'g') + + def gremlin_query(self, query, bindings=None): + c = self.get_gremlin_connection() + try: + result = c.submit(query, bindings) + future_results = result.all() + results = future_results.result() + c.close() + return results + except Exception as e: + c.close() + raise e + + def gremlin_http_query(self, query, headers=None) -> requests.Response: + if headers is None: + headers = {} + + uri = f'{self._http_protocol}://{self.host}:{self.port}/gremlin' + data = {'gremlin': query} + req = self._prepare_request('POST', uri, data=json.dumps(data), headers=headers) + res = self._http_session.send(req) + return res + + def gremlin_status(self, query_id: str = '', include_waiting: bool = False): + kwargs = {} + if include_waiting: + kwargs['includeWaiting'] = True + return 
self._query_status('gremlin', query_id=query_id, **kwargs) + + def gremlin_cancel(self, query_id: str): + if type(query_id) is not str or query_id == '': + raise ValueError('query_id must be a non-empty string') + return self._query_status('gremlin', query_id=query_id, cancelQuery=True) + + def gremlin_explain(self, query: str) -> requests.Response: + return self._gremlin_query_plan(query, 'explain') + + def gremlin_profile(self, query: str) -> requests.Response: + return self._gremlin_query_plan(query, 'profile') + + def _gremlin_query_plan(self, query: str, plan_type: str) -> requests.Response: + url = f'{self._http_protocol}://{self.host}:{self.port}/gremlin/{plan_type}' + data = {'gremlin': query} + req = self._prepare_request('POST', url, data=json.dumps(data)) + res = self._http_session.send(req) + return res + + def status(self) -> requests.Response: + url = f'{self._http_protocol}://{self.host}:{self.port}/status' + req = self._prepare_request('GET', url, data='') + res = self._http_session.send(req) + return res + + def load(self, source: str, source_format: str, iam_role_arn: str, region: str, **kwargs) -> requests.Response: + """ + For a full list of allowed parameters, see the AWS documentation for the Neptune loader + endpoint: https://docs.aws.amazon.com/neptune/latest/userguide/load-api-reference-load.html + """ + payload = { + 'source': source, + 'format': source_format, + 'region': region, + 'iamRoleArn': iam_role_arn + } + + for key, value in kwargs.items(): + payload[key] = value + + url = f'{self._http_protocol}://{self.host}:{self.port}/loader' + raw = json.dumps(payload) + req = self._prepare_request('POST', url, data=raw, headers={'content-type': 'application/json'}) + res = self._http_session.send(req) + return res + + def load_status(self, load_id: str = '', **kwargs) -> requests.Response: + params = {} + for k, v in kwargs.items(): + params[k] = v + + if load_id != '': + params['loadId'] = load_id + + url = f'{self._http_protocol}://{self.host}:{self.port}/loader' + req = self._prepare_request('GET', url, params=params) + res = self._http_session.send(req) + return res + + def cancel_load(self, load_id: str) -> requests.Response: + url = f'{self._http_protocol}://{self.host}:{self.port}/loader' + params = {'loadId': load_id} + req = self._prepare_request('DELETE', url, params=params) + res = self._http_session.send(req) + return res + + def initiate_reset(self) -> requests.Response: + data = { + 'action': 'initiateDatabaseReset' + } + url = f'{self._http_protocol}://{self.host}:{self.port}/system' + req = self._prepare_request('POST', url, data=data) + res = self._http_session.send(req) + return res + + def perform_reset(self, token: str) -> requests.Response: + data = { + 'action': 'performDatabaseReset', + 'token': token + } + url = f'{self._http_protocol}://{self.host}:{self.port}/system' + req = self._prepare_request('POST', url, data=data) + res = self._http_session.send(req) + return res + + def dataprocessing_start(self, s3_input_uri: str, s3_output_uri: str, **kwargs) -> requests.Response: + data = { + 'inputDataS3Location': s3_input_uri, + 'processedDataS3Location': s3_output_uri, + } + + for k, v in kwargs.items(): + data[k] = v + + url = f'{self._http_protocol}://{self.host}:{self.port}/ml/dataprocessing' + req = self._prepare_request('POST', url, data=json.dumps(data), headers={'content-type': 'application/json'}) + res = self._http_session.send(req) + return res + + def dataprocessing_job_status(self, job_id: str, neptune_iam_role_arn: str = '') -> 
requests.Response: + url = f'{self._http_protocol}://{self.host}:{self.port}/ml/dataprocessing/{job_id}' + data = {} + if neptune_iam_role_arn != '': + data['neptuneIamRoleArn'] = neptune_iam_role_arn + req = self._prepare_request('GET', url, params=data) + res = self._http_session.send(req) + return res + + def dataprocessing_status(self, max_items: int = 10, neptune_iam_role_arn: str = '') -> requests.Response: + url = f'{self._http_protocol}://{self.host}:{self.port}/ml/dataprocessing' + data = { + 'maxItems': max_items + } + + if neptune_iam_role_arn != '': + data['neptuneIamRoleArn'] = neptune_iam_role_arn + req = self._prepare_request('GET', url, params=data) + res = self._http_session.send(req) + return res + + def dataprocessing_stop(self, job_id: str, clean=False, neptune_iam_role_arn: str = '') -> requests.Response: + url = f'{self._http_protocol}://{self.host}:{self.port}/ml/dataprocessing/{job_id}' + data = { + 'clean': clean + } + if neptune_iam_role_arn != '': + data['neptuneIamRoleArn'] = neptune_iam_role_arn + + req = self._prepare_request('DELETE', url, params=data) + res = self._http_session.send(req) + return res + + def modeltraining_start(self, data_processing_job_id: str, train_model_s3_location: str, + **kwargs) -> requests.Response: + """ + for a full list of supported parameters, see: + https://docs.aws.amazon.com/neptune/latest/userguide/machine-learning-api-modeltraining.html + """ + data = { + 'dataProcessingJobId': data_processing_job_id, + 'trainModelS3Location': train_model_s3_location + } + + for k, v in kwargs.items(): + data[k] = v + + url = f'{self._http_protocol}://{self.host}:{self.port}/ml/modeltraining' + req = self._prepare_request('POST', url, data=json.dumps(data), headers={'content-type': 'application/json'}) + res = self._http_session.send(req) + return res + + def modeltraining_status(self, max_items: int = 10, neptune_iam_role_arn: str = '') -> requests.Response: + data = { + 'maxItems': max_items + } + + if neptune_iam_role_arn != '': + data['neptuneIamRoleArn'] = neptune_iam_role_arn + + url = f'{self._http_protocol}://{self.host}:{self.port}/ml/modeltraining' + req = self._prepare_request('GET', url, params=data) + res = self._http_session.send(req) + return res + + def modeltraining_job_status(self, training_job_id: str, neptune_iam_role_arn: str = '') -> requests.Response: + data = {} if neptune_iam_role_arn == '' else {'neptuneIamRoleArn': neptune_iam_role_arn} + url = f'{self._http_protocol}://{self.host}:{self.port}/ml/modeltraining/{training_job_id}' + req = self._prepare_request('GET', url, params=data) + res = self._http_session.send(req) + return res + + def modeltraining_stop(self, training_job_id: str, neptune_iam_role_arn: str = '', + clean: bool = False) -> requests.Response: + data = { + 'clean': "TRUE" if clean else "FALSE", + } + + if neptune_iam_role_arn != '': + data['neptuneIamRoleArn'] = neptune_iam_role_arn + + url = f'{self._http_protocol}://{self.host}:{self.port}/ml/modeltraining/{training_job_id}' + req = self._prepare_request('DELETE', url, params=data) + res = self._http_session.send(req) + return res + + def endpoints_create(self, training_job_id: str, **kwargs) -> requests.Response: + data = { + 'mlModelTrainingJobId': training_job_id + } + + for k, v in kwargs.items(): + data[k] = v + + url = f'{self._http_protocol}://{self.host}:{self.port}/ml/endpoints' + req = self._prepare_request('POST', url, data=json.dumps(data), headers={'content-type': 'application/json'}) + res = self._http_session.send(req) + return 
res + + def endpoints_status(self, endpoint_id: str, neptune_iam_role_arn: str = '') -> requests.Response: + data = {} if neptune_iam_role_arn == '' else {'neptuneIamRoleArn': neptune_iam_role_arn} + url = f'{self._http_protocol}://{self.host}:{self.port}/ml/endpoints/{endpoint_id}' + req = self._prepare_request('GET', url, params=data) + res = self._http_session.send(req) + return res + + def endpoints_delete(self, endpoint_id: str, neptune_iam_role_arn: str = '') -> requests.Response: + data = {} if neptune_iam_role_arn == '' else {'neptuneIamRoleArn': neptune_iam_role_arn} + url = f'{self._http_protocol}://{self.host}:{self.port}/ml/endpoints/{endpoint_id}' + req = self._prepare_request('DELETE', url, params=data) + res = self._http_session.send(req) + return res + + def endpoints(self, max_items: int = 10, neptune_iam_role_arn: str = ''): + data = { + 'maxItems': max_items + } + if neptune_iam_role_arn != '': + data['neptuneIamRoleArn'] = neptune_iam_role_arn + + url = f'{self._http_protocol}://{self.host}:{self.port}/ml/endpoints' + req = self._prepare_request('GET', url, params=data) + res = self._http_session.send(req) + return res + + def export(self, host: str, params: dict, ssl: bool = True) -> requests.Response: + protocol = 'https' if ssl else 'http' + url = f'{protocol}://{host}/{EXPORT_ACTION}' + req = self._prepare_request('POST', url, data=json.dumps(params), service="execute-api") + res = self._http_session.send(req) + return res + + def export_status(self, host, job_id, ssl: bool = True) -> requests.Response: + protocol = 'https' if ssl else 'http' + url = f'{protocol}://{host}/{EXPORT_ACTION}/{job_id}' + req = self._prepare_request('GET', url, service="execute-api") + res = self._http_session.send(req) + return res + + def _query_status(self, language: str, *, query_id: str = '', **kwargs) -> requests.Response: + data = {} + if query_id != '': + data['queryId'] = query_id + + for k, v in kwargs.items(): + data[k] = v + + headers = { + 'Content-Type': 'application/x-www-form-urlencoded' + } + url = f'{self._http_protocol}://{self.host}:{self.port}/{language}/status' + req = self._prepare_request('POST', url, data=data, headers=headers) + res = self._http_session.send(req) + return res + + def _prepare_request(self, method, url, *, data=None, params=None, headers=None, service=NEPTUNE_SERVICE_NAME): + self._ensure_http_session() + request = requests.Request(method=method, url=url, data=data, params=params, headers=headers, auth=self._auth) + if self._session is not None: + credentials = self._session.get_credentials() + frozen_creds = credentials.get_frozen_credentials() + + req = AWSRequest(method=method, url=url, data=data, params=params, headers=headers) + SigV4Auth(frozen_creds, service, self.region).add_auth(req) + prepared_iam_req = req.prepare() + request.headers = dict(prepared_iam_req.headers) + + return request.prepare() + + def _ensure_http_session(self): + if not self._http_session: + self._http_session = requests.Session() + + def set_session(self, session: Session): + self._session = session + + def close(self): + if self._http_session: + self._http_session.close() + self._http_session = None + + @property + def iam_enabled(self): + return type(self._session) is botocore.session.Session + + +class ClientBuilder(object): + def __init__(self, args: dict = None): + if args is None: + args = {} + self.args = args + + def with_host(self, host: str): + self.args['host'] = host + return ClientBuilder(self.args) + + def with_port(self, port: int): + self.args['port'] 
= port + return ClientBuilder(self.args) + + def with_sparql_path(self, path: str): + self.args['sparql_path'] = path + return ClientBuilder(self.args) + + def with_tls(self, tls: bool): + self.args['ssl'] = tls + return ClientBuilder(self.args) + + def with_region(self, region: str): + self.args['region'] = region + return ClientBuilder(self.args) + + def with_iam(self, session: Session): + self.args['session'] = session + return ClientBuilder(self.args) + + def build(self) -> Client: + return Client(**self.args) diff --git a/src/graph_notebook/ml/__init__.py b/src/graph_notebook/neptune/gremlin/__init__.py similarity index 100% rename from src/graph_notebook/ml/__init__.py rename to src/graph_notebook/neptune/gremlin/__init__.py diff --git a/src/graph_notebook/gremlin/client_provider/graphsonV3d0_MapType_objectify_patch.py b/src/graph_notebook/neptune/gremlin/graphsonV3d0_MapType_objectify_patch.py similarity index 94% rename from src/graph_notebook/gremlin/client_provider/graphsonV3d0_MapType_objectify_patch.py rename to src/graph_notebook/neptune/gremlin/graphsonV3d0_MapType_objectify_patch.py index ab0b6896..a28b85a7 100644 --- a/src/graph_notebook/gremlin/client_provider/graphsonV3d0_MapType_objectify_patch.py +++ b/src/graph_notebook/neptune/gremlin/graphsonV3d0_MapType_objectify_patch.py @@ -4,7 +4,8 @@ """ from gremlin_python.structure.io.graphsonV3d0 import MapType -from graph_notebook.gremlin.client_provider.hashable_dict_patch import HashableDict +from graph_notebook.neptune.gremlin.hashable_dict_patch import HashableDict + # Original code from Tinkerpop 3.4.1 # diff --git a/src/graph_notebook/gremlin/client_provider/hashable_dict_patch.py b/src/graph_notebook/neptune/gremlin/hashable_dict_patch.py similarity index 100% rename from src/graph_notebook/gremlin/client_provider/hashable_dict_patch.py rename to src/graph_notebook/neptune/gremlin/hashable_dict_patch.py diff --git a/src/graph_notebook/notebooks/03-Sample-Applications/00-Sample-Applications-Overview.ipynb b/src/graph_notebook/notebooks/03-Sample-Applications/00-Sample-Applications-Overview.ipynb index efbd443e..208830e1 100644 --- a/src/graph_notebook/notebooks/03-Sample-Applications/00-Sample-Applications-Overview.ipynb +++ b/src/graph_notebook/notebooks/03-Sample-Applications/00-Sample-Applications-Overview.ipynb @@ -56,4 +56,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/src/graph_notebook/notebooks/04-Machine-Learning/neptune_ml_utils.py b/src/graph_notebook/notebooks/04-Machine-Learning/neptune_ml_utils.py index 7517ab6d..5541ca68 100644 --- a/src/graph_notebook/notebooks/04-Machine-Learning/neptune_ml_utils.py +++ b/src/graph_notebook/notebooks/04-Machine-Learning/neptune_ml_utils.py @@ -15,6 +15,8 @@ # How often to check the status UPDATE_DELAY_SECONDS = 15 +HOME_DIRECTORY = os.path.expanduser("~") + def signed_request(method, url, data=None, params=None, headers=None, service=None): creds = boto3.Session().get_credentials().get_frozen_credentials() @@ -24,7 +26,7 @@ def signed_request(method, url, data=None, params=None, headers=None, service=No def load_configuration(): - with open('/home/ec2-user/graph_notebook_config.json') as f: + with open(f'{HOME_DIRECTORY}/graph_notebook_config.json') as f: data = json.load(f) host = data['host'] port = data['port'] @@ -34,62 +36,69 @@ def load_configuration(): iam = False return host, port, iam + def get_host(): host, port, iam = load_configuration() return host + def get_iam(): host, port, iam = load_configuration() return iam 
+ def get_training_job_name(prefix: str): return f'{prefix}-{int(time.time())}' + def check_ml_enabled(): - host, port, use_iam = load_configuration() + host, port, use_iam = load_configuration() response = signed_request("GET", url=f'https://{host}:{port}/ml/modeltraining', service='neptune-db') if response.status_code != 200: print('''This Neptune cluster \033[1mis not\033[0m configured to use Neptune ML. Please configure the cluster according to the Amazon Neptune ML documentation before proceeding.''') else: print("This Neptune cluster is configured to use Neptune ML") - + + def get_export_service_host(): - with open('/home/ec2-user/.bashrc') as f: + with open(f'{HOME_DIRECTORY}/.bashrc') as f: data = f.readlines() for d in data: if str.startswith(d, 'export NEPTUNE_EXPORT_API_URI'): parts = d.split('=') - if len(parts)==2: - path=urlparse(parts[1].rstrip()) + if len(parts) == 2: + path = urlparse(parts[1].rstrip()) return path.hostname + "/v1" - logging.error("Unable to determine the Neptune Export Service Endpoint. You will need to enter this assign this manually.") + logging.error( + "Unable to determine the Neptune Export Service Endpoint. You will need to enter this or assign it manually.") return None + def delete_pretrained_data(setup_node_classification: bool, setup_node_regression: bool, setup_link_prediction: bool): host, port, use_iam = load_configuration() if setup_node_classification: response = signed_request("POST", service='neptune-db', - url=f'https://{host}:{port}/gremlin', - headers={'content-type': 'application/json'}, - data=json.dumps({'gremlin': "g.V('movie_1', 'movie_7', 'movie_15').properties('genre').drop()"})) - + url=f'https://{host}:{port}/gremlin', + headers={'content-type': 'application/json'}, + data=json.dumps( + {'gremlin': "g.V('movie_1', 'movie_7', 'movie_15').properties('genre').drop()"})) + if response.status_code != 200: print(response.content.decode('utf-8')) if setup_node_regression: response = signed_request("POST", service='neptune-db', - url=f'https://{host}:{port}/gremlin', - headers={'content-type': 'application/json'}, - data=json.dumps({'gremlin': "g.V('user_1').out('wrote').properties('score').drop()"})) + url=f'https://{host}:{port}/gremlin', + headers={'content-type': 'application/json'}, + data=json.dumps({'gremlin': "g.V('user_1').out('wrote').properties('score').drop()"})) if response.status_code != 200: print(response.content.decode('utf-8')) if setup_link_prediction: response = signed_request("POST", service='neptune-db', - url=f'https://{host}:{port}/gremlin', - headers={'content-type': 'application/json'}, - data=json.dumps({'gremlin': "g.V('user_1').outE('rated').drop()"})) + url=f'https://{host}:{port}/gremlin', + headers={'content-type': 'application/json'}, + data=json.dumps({'gremlin': "g.V('user_1').outE('rated').drop()"})) if response.status_code != 200: print(response.content.decode('utf-8')) @@ -114,7 +123,8 @@ def delete_endpoint(training_job_name: str, neptune_iam_role_arn=None): query_string = f'?neptuneIamRoleArn={neptune_iam_role_arn}' host, port, use_iam = load_configuration() response = signed_request("DELETE", service='neptune-db', - url=f'https://{host}:{port}/ml/endpoints/{training_job_name}{query_string}', headers={'content-type': 'application/json'}) + url=f'https://{host}:{port}/ml/endpoints/{training_job_name}{query_string}', + headers={'content-type': 'application/json'}) if response.status_code != 200: print(response.content.decode('utf-8')) else: @@ -129,28 +139,28 @@ def 
prepare_movielens_data(s3_bucket_uri: str): logging.error(e) - def setup_pretrained_endpoints(s3_bucket_uri: str, setup_node_classification: bool, setup_node_regression: bool, setup_link_prediction: bool): delete_pretrained_data(setup_node_classification, setup_node_regression, setup_link_prediction) try: - return PretrainedModels().setup_pretrained_endpoints(s3_bucket_uri, setup_node_classification, setup_node_regression, setup_link_prediction) + return PretrainedModels().setup_pretrained_endpoints(s3_bucket_uri, setup_node_classification, + setup_node_regression, setup_link_prediction) except Exception as e: logging.error(e) class MovieLensProcessor: - raw_directory = r'/home/ec2-user/data/raw' - formatted_directory = r'/home/ec2-user/data/formatted' + raw_directory = fr'{HOME_DIRECTORY}/data/raw' + formatted_directory = fr'{HOME_DIRECTORY}/data/formatted' def __download_and_unzip(self): - if not os.path.exists('/home/ec2-user/data'): - os.makedirs('/home/ec2-user/data') - if not os.path.exists('/home/ec2-user/data/raw'): - os.makedirs('/home/ec2-user/data/raw') - if not os.path.exists('/home/ec2-user/data/formatted'): - os.makedirs('/home/ec2-user/data/formatted') + if not os.path.exists(f'{HOME_DIRECTORY}/data'): + os.makedirs(f'{HOME_DIRECTORY}/data') + if not os.path.exists(f'{HOME_DIRECTORY}/data/raw'): + os.makedirs(f'{HOME_DIRECTORY}/data/raw') + if not os.path.exists(f'{HOME_DIRECTORY}/data/formatted'): + os.makedirs(f'{HOME_DIRECTORY}/data/formatted') # Download the MovieLens dataset url = 'http://files.grouplens.org/datasets/movielens/ml-100k.zip' r = requests.get(url, allow_redirects=True) @@ -163,7 +173,7 @@ def __process_movies_genres(self): # process the movies_vertex.csv print('Processing Movies', end='\r') movies_df = pd.read_csv(os.path.join( - self.raw_directory, 'ml-100k/u.item'), sep='|', encoding='ISO-8859-1', + self.raw_directory, 'ml-100k/u.item'), sep='|', encoding='ISO-8859-1', names=['~id', 'title', 'release_date', 'video_release_date', 'imdb_url', 'unknown', 'Action', 'Adventure', 'Animation', 'Childrens', 'Comedy', 'Crime', 'Documentary', 'Drama', 'Fantasy', 'Film-Noir', 'Horror', 'Musical', @@ -218,7 +228,7 @@ def __process_ratings_users(self): # Create ratings vertices and add edges on both sides print('Processing Ratings', end='\r') ratings_vertices = pd.read_csv(os.path.join( - self.raw_directory, 'ml-100k/u.data'), sep='\t', encoding='ISO-8859-1', + self.raw_directory, 'ml-100k/u.data'), sep='\t', encoding='ISO-8859-1', names=['~from', '~to', 'score:Int', 'timestamp']) ratings_vertices['~from'] = ratings_vertices['~from'].apply( lambda x: f'user_{x}') @@ -232,10 +242,10 @@ def __process_ratings_users(self): dict = {} for index, row in ratings_vertices.iterrows(): - dict[index*2] = {'~id': uuid.uuid4(), '~label': 'wrote', - '~from': row['~from'], '~to': row['~id']} - dict[index*2 + 1] = {'~id': uuid.uuid4(), '~label': 'about', - '~from': row['~id'], '~to': row['~to']} + dict[index * 2] = {'~id': uuid.uuid4(), '~label': 'wrote', + '~from': row['~from'], '~to': row['~id']} + dict[index * 2 + 1] = {'~id': uuid.uuid4(), '~label': 'about', + '~from': row['~id'], '~to': row['~to']} rating_edges_df = pd.DataFrame.from_dict(dict, "index") # Remove the from and to columns and write this out as a vertex now @@ -259,7 +269,7 @@ def __process_users(self): # User Vertices - Load, rename column with type, and save user_df = pd.read_csv(os.path.join( - self.raw_directory, 'ml-100k/u.user'), sep='|', encoding='ISO-8859-1', + self.raw_directory, 'ml-100k/u.user'), sep='|', 
encoding='ISO-8859-1', names=['~id', 'age:Int', 'gender', 'occupation', 'zip_code']) user_df['~id'] = user_df['~id'].apply( lambda x: f'user_{x}') @@ -372,12 +382,12 @@ def __create_model(self, name: str, model_s3_location: str): return name def __get_neptune_ml_role(self): - with open('/home/ec2-user/.bashrc') as f: + with open(f'{HOME_DIRECTORY}/.bashrc') as f: data = f.readlines() for d in data: if str.startswith(d, 'export NEPTUNE_ML_ROLE_ARN'): parts = d.split('=') - if len(parts)==2: + if len(parts) == 2: return parts[1].rstrip() logging.error("Unable to determine the Neptune ML IAM Role.") return None diff --git a/src/graph_notebook/request_param_generator/call_and_get_response.py b/src/graph_notebook/request_param_generator/call_and_get_response.py deleted file mode 100644 index f0bc3e84..00000000 --- a/src/graph_notebook/request_param_generator/call_and_get_response.py +++ /dev/null @@ -1,32 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -import requests - - -def call_and_get_response(method: str, action: str, host: str, port: str, request_param_generator, use_ssl: bool, - query='', extra_headers=None): - if extra_headers is None: - extra_headers = {} - - method = method.upper() - protocol = 'https' if use_ssl else 'http' - - request_params = request_param_generator.generate_request_params(method=method, action=action, query=query, - host=host, port=port, protocol=protocol, - headers=extra_headers) - headers = request_params['headers'] if request_params['headers'] is not None else {} - - if method == 'GET': - res = requests.get(url=request_params['url'], params=request_params['params'], headers=headers) - elif method == 'DELETE': - res = requests.delete(url=request_params['url'], params=request_params['params'], headers=headers) - elif method == 'POST': - res = requests.post(url=request_params['url'], data=request_params['params'], headers=headers) - else: - raise NotImplementedError(f'Use of method {method} has not been implemented in call_and_get_response') - - res.raise_for_status() - return res diff --git a/src/graph_notebook/request_param_generator/default_request_generator.py b/src/graph_notebook/request_param_generator/default_request_generator.py deleted file mode 100644 index 0fb2b8cf..00000000 --- a/src/graph_notebook/request_param_generator/default_request_generator.py +++ /dev/null @@ -1,18 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - - -class DefaultRequestGenerator(object): - @staticmethod - def generate_request_params(method, action, query, host, port, protocol, headers=None): - url = f'{protocol}://{host}:{port}/{action}' if port != '' else f'{protocol}://{host}/{action}' - params = { - 'method': method, - 'url': url, - 'headers': headers, - 'params': query, - } - - return params diff --git a/src/graph_notebook/request_param_generator/factory.py b/src/graph_notebook/request_param_generator/factory.py deleted file mode 100644 index 2ad79dae..00000000 --- a/src/graph_notebook/request_param_generator/factory.py +++ /dev/null @@ -1,23 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-SPDX-License-Identifier: Apache-2.0 -""" - -from graph_notebook.configuration.generate_config import AuthModeEnum -from graph_notebook.request_param_generator.default_request_generator import DefaultRequestGenerator -from graph_notebook.request_param_generator.iam_request_generator import IamRequestGenerator -from graph_notebook.request_param_generator.sparql_request_generator import SPARQLRequestGenerator -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider, credentials_provider_factory - - -def create_request_generator(mode: AuthModeEnum, - credentials_provider_mode: IAMAuthCredentialsProvider = IAMAuthCredentialsProvider.ROLE, - command: str = ''): - - if mode == AuthModeEnum.DEFAULT and command == 'sparql': - return SPARQLRequestGenerator() - elif mode == AuthModeEnum.IAM: - credentials_provider_mode = credentials_provider_factory(credentials_provider_mode) - return IamRequestGenerator(credentials_provider_mode) - else: - return DefaultRequestGenerator() diff --git a/src/graph_notebook/request_param_generator/iam_request_generator.py b/src/graph_notebook/request_param_generator/iam_request_generator.py deleted file mode 100644 index fc88e809..00000000 --- a/src/graph_notebook/request_param_generator/iam_request_generator.py +++ /dev/null @@ -1,21 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -from graph_notebook.authentication.iam_headers import make_signed_request - - -class IamRequestGenerator(object): - def __init__(self, credentials_provider): - self.credentials_provider = credentials_provider - - def generate_request_params(self, method, action, query, host, port, protocol, headers=None): - credentials = self.credentials_provider.get_iam_credentials() - if protocol in ['https', 'wss']: - use_ssl = True - else: - use_ssl = False - - return make_signed_request(method, action, query, host, port, credentials.key, credentials.secret, - credentials.region, use_ssl, credentials.token, additional_headers=headers) diff --git a/src/graph_notebook/request_param_generator/sparql_request_generator.py b/src/graph_notebook/request_param_generator/sparql_request_generator.py deleted file mode 100644 index f63ae479..00000000 --- a/src/graph_notebook/request_param_generator/sparql_request_generator.py +++ /dev/null @@ -1,22 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - - -class SPARQLRequestGenerator(object): - @staticmethod - def generate_request_params(method, action, query, host, port, protocol, headers=None): - if headers is None: - headers = {} - - if 'Content-Type' not in headers: - headers['Content-Type'] = "application/x-www-form-urlencoded" - - url = f'{protocol}://{host}:{port}/{action}' - return { - 'method': method, - 'url': url, - 'headers': headers, - 'params': query, - } diff --git a/src/graph_notebook/sparql/query.py b/src/graph_notebook/sparql/query.py deleted file mode 100644 index 6d104600..00000000 --- a/src/graph_notebook/sparql/query.py +++ /dev/null @@ -1,82 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-SPDX-License-Identifier: Apache-2.0 -""" - -import logging - -from SPARQLWrapper import SPARQLWrapper -from graph_notebook.request_param_generator.call_and_get_response import call_and_get_response - -logging.basicConfig() -logger = logging.getLogger("sparql") - -ACTION_TO_QUERY_TYPE = { - 'sparql': 'application/sparql-query', - 'sparqlupdate': 'application/sparql-update' -} - -SPARQL_ACTION = 'sparql' - - -def get_query_type(query): - s = SPARQLWrapper('') - s.setQuery(query) - return s.queryType - - -def query_type_to_action(query_type): - query_type = query_type.upper() - if query_type in ['SELECT', 'CONSTRUCT', 'ASK', 'DESCRIBE']: - return 'sparql' - else: - # TODO: check explicitly for all query types, raise exception for invalid query - return 'sparqlupdate' - - -def do_sparql_query(query, host, port, use_ssl, request_param_generator, extra_headers=None, path: str = SPARQL_ACTION): - path = SPARQL_ACTION if path == '' else path - - if extra_headers is None: - extra_headers = {} - logger.debug(f'query={query}, endpoint={host}, port={port}') - query_type = get_query_type(query) - action = query_type_to_action(query_type) - - data = {} - if action == 'sparql': - data['query'] = query - elif action == 'sparqlupdate': - data['update'] = query - - res = call_and_get_response('post', path, host, port, request_param_generator, use_ssl, data, extra_headers) - try: - content = res.json() # attempt to return json, otherwise we will return the content string. - except Exception: - content = res.content.decode('utf-8') - return content - - -def do_sparql_explain(query: str, host: str, port: str, use_ssl: bool, request_param_generator, - accept_type='text/html', path: str = ''): - path = SPARQL_ACTION if path == '' else path - - query_type = get_query_type(query) - action = query_type_to_action(query_type) - - data = { - 'explain': 'dynamic', - } - - if action == 'sparql': - data['query'] = query - elif action == 'sparqlupdate': - data['update'] = query - - extra_headers = { - 'Accept': accept_type - } - - res = call_and_get_response('post', path, host, port, request_param_generator, use_ssl, data, - extra_headers) - return res.content.decode('utf-8') diff --git a/src/graph_notebook/sparql/status.py b/src/graph_notebook/sparql/status.py deleted file mode 100644 index 737d61b7..00000000 --- a/src/graph_notebook/sparql/status.py +++ /dev/null @@ -1,49 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -from graph_notebook.request_param_generator.call_and_get_response import call_and_get_response - -SPARQL_STATUS_ACTION = 'sparql/status' - - -def do_sparql_status(host, port, use_ssl, request_param_generator, query_id=None): - data = {} - if query_id != '' and query_id is not None: - data['queryId'] = query_id - - headers = { - 'Content-Type': 'application/x-www-form-urlencoded' - } - res = call_and_get_response('post', SPARQL_STATUS_ACTION, host, port, request_param_generator, use_ssl, data, - headers) - try: - content = res.json() # attempt to return json, otherwise we will return the content string. - except Exception: - """When a invalid UUID is supplied, status servlet returns an empty string. - See https://sim.amazon.com/issues/NEPTUNE-16137 - """ - content = 'UUID is invalid.' 
- return content - - -def do_sparql_cancel(host, port, use_ssl, request_param_generator, query_id, silent=False): - if type(query_id) is not str or query_id == '': - raise ValueError("query id must be a non-empty string") - - data = {'cancelQuery': True, 'queryId': query_id, 'silent': silent} - - headers = { - 'Content-Type': 'application/x-www-form-urlencoded' - } - res = call_and_get_response('post', SPARQL_STATUS_ACTION, host, port, request_param_generator, use_ssl, data, - headers) - try: - content = res.json() - except Exception: - """When a invalid UUID is supplied, status servlet returns an empty string. - See https://sim.amazon.com/issues/NEPTUNE-16137 - """ - content = 'UUID is invalid.' - return content diff --git a/src/graph_notebook/status/get_status.py b/src/graph_notebook/status/get_status.py deleted file mode 100644 index bd86128c..00000000 --- a/src/graph_notebook/status/get_status.py +++ /dev/null @@ -1,17 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" -from json import JSONDecodeError - -from graph_notebook.request_param_generator.call_and_get_response import call_and_get_response -from graph_notebook.request_param_generator.default_request_generator import DefaultRequestGenerator - - -def get_status(host, port, use_ssl, request_param_generator=DefaultRequestGenerator()): - res = call_and_get_response('get', 'status', host, port, request_param_generator, use_ssl) - try: - js = res.json() - except JSONDecodeError: - js = res.content - return js diff --git a/src/graph_notebook/system/database_reset.py b/src/graph_notebook/system/database_reset.py deleted file mode 100644 index 7e6fb46e..00000000 --- a/src/graph_notebook/system/database_reset.py +++ /dev/null @@ -1,21 +0,0 @@ -from graph_notebook.request_param_generator.call_and_get_response import call_and_get_response -from graph_notebook.request_param_generator.default_request_generator import DefaultRequestGenerator - -SYSTEM_ACTION = 'system' - - -def initiate_database_reset(host, port, use_ssl, request_param_generator=DefaultRequestGenerator()): - data = { - 'action': 'initiateDatabaseReset' - } - res = call_and_get_response('post', SYSTEM_ACTION, host, port, request_param_generator, use_ssl, data) - return res.json() - - -def perform_database_reset(token, host, port, use_ssl, request_param_generator=DefaultRequestGenerator()): - data = { - 'action': 'performDatabaseReset', - 'token': token - } - res = call_and_get_response('post', SYSTEM_ACTION, host, port, request_param_generator, use_ssl, data) - return res.json() diff --git a/src/graph_notebook/sparql/table.py b/src/graph_notebook/visualization/sparql_rows_and_columns.py similarity index 89% rename from src/graph_notebook/sparql/table.py rename to src/graph_notebook/visualization/sparql_rows_and_columns.py index 0029a7ad..14ac1548 100644 --- a/src/graph_notebook/sparql/table.py +++ b/src/graph_notebook/visualization/sparql_rows_and_columns.py @@ -3,11 +3,13 @@ SPDX-License-Identifier: Apache-2.0 """ + def get_rows_and_columns(sparql_results): if type(sparql_results) is not dict: return None - if 'head' in sparql_results and 'vars' in sparql_results['head'] and 'results' in sparql_results and 'bindings' in sparql_results['results']: + if 'head' in sparql_results and 'vars' in sparql_results['head'] and 'results' in sparql_results and 'bindings' in \ + sparql_results['results']: columns = [] for v in sparql_results['head']['vars']: columns.append(v) diff --git 
a/test/integration/DataDrivenGremlinTest.py b/test/integration/DataDrivenGremlinTest.py index 2a794091..3ada56e5 100644 --- a/test/integration/DataDrivenGremlinTest.py +++ b/test/integration/DataDrivenGremlinTest.py @@ -5,9 +5,7 @@ import logging -from graph_notebook.gremlin.client_provider.factory import create_client_provider from graph_notebook.seed.load_query import get_queries -from graph_notebook.gremlin.query import do_gremlin_query from test.integration import IntegrationTest @@ -16,9 +14,9 @@ class DataDrivenGremlinTest(IntegrationTest): def setUp(self): super().setUp() - self.client_provider = create_client_provider(self.auth_mode, self.iam_credentials_provider_type) + self.client = self.client_builder.build() query_check_for_airports = "g.V('3684').outE().inV().has(id, '3444')" - res = do_gremlin_query(query_check_for_airports, self.host, self.port, self.ssl, self.client_provider) + res = self.client.gremlin_query(query_check_for_airports) if len(res) < 1: logging.info('did not find final airports edge, seeding database now...') airport_queries = get_queries('gremlin', 'airports') @@ -30,7 +28,7 @@ def setUp(self): # we are deciding to try except because we do not know if the database # we are connecting to has a partially complete set of airports data or not. try: - do_gremlin_query(line, self.host, self.port, self.ssl, self.client_provider) + self.client.gremlin_query(line) except Exception as e: logging.error(f'query {q} failed due to {e}') continue diff --git a/test/integration/DataDrivenSparqlTest.py b/test/integration/DataDrivenSparqlTest.py index 5ce77861..a013c01e 100644 --- a/test/integration/DataDrivenSparqlTest.py +++ b/test/integration/DataDrivenSparqlTest.py @@ -5,11 +5,7 @@ import logging -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider from graph_notebook.seed.load_query import get_queries -from graph_notebook.request_param_generator.factory import create_request_generator -from graph_notebook.sparql.query import do_sparql_query - from test.integration import IntegrationTest logger = logging.getLogger('DataDrivenSparqlTest') @@ -17,15 +13,14 @@ class DataDrivenSparqlTest(IntegrationTest): - @classmethod - def setUpClass(cls): - super().setUpClass() - cls.request_generator = create_request_generator(cls.auth_mode, IAMAuthCredentialsProvider.ENV) + def setUp(self) -> None: + super().setUp() airport_queries = get_queries('sparql', 'epl') for q in airport_queries: try: # we are deciding to try except because we do not know if the database we are connecting to has a partially complete set of airports data or not. 
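
Both data-driven base classes use the same seed-if-missing setup: probe for a sentinel record, replay the seed queries only when it is absent, and tolerate per-statement failures because the target database may already hold partial data. A condensed sketch of that pattern; the sentinel query is copied from the Gremlin test above, while the per-line split of each seed file is an assumption:

```python
from graph_notebook.seed.load_query import get_queries

# 'client' stands in for the built test client; the seed file layout is assumed.
if len(client.gremlin_query("g.V('3684').outE().inV().has(id, '3444')")) < 1:
    for q in get_queries('gremlin', 'airports'):
        for line in q['content'].splitlines():
            try:
                client.gremlin_query(line)  # tolerate partially seeded data
            except Exception:
                continue  # the record may already exist; keep going
```
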
- do_sparql_query(q['content'], cls.host, cls.port, cls.ssl, cls.request_generator) + res = self.client.sparql(q['content'].strip()) + print(res) except Exception as e: logger.error(f'query {q["content"]} failed due to {e}') continue diff --git a/test/integration/notebook/GraphNotebookIntegrationTest.py b/test/integration/GraphNotebookIntegrationTest.py similarity index 100% rename from test/integration/notebook/GraphNotebookIntegrationTest.py rename to test/integration/GraphNotebookIntegrationTest.py diff --git a/test/integration/IntegrationTest.py b/test/integration/IntegrationTest.py index 49a87ad9..058c3bc7 100644 --- a/test/integration/IntegrationTest.py +++ b/test/integration/IntegrationTest.py @@ -5,19 +5,34 @@ import unittest +from botocore.session import get_session + +from graph_notebook.configuration.generate_config import Configuration, AuthModeEnum from graph_notebook.configuration.get_config import get_config +from graph_notebook.neptune.client import ClientBuilder from test.integration.NeptuneIntegrationWorkflowSteps import TEST_CONFIG_PATH +def setup_client_builder(config: Configuration) -> ClientBuilder: + builder = ClientBuilder() \ + .with_host(config.host) \ + .with_port(config.port) \ + .with_region(config.aws_region) \ + .with_tls(config.ssl) \ + .with_sparql_path(config.sparql.path) + + if config.auth_mode == AuthModeEnum.IAM: + builder = builder.with_iam(get_session()) + + return builder + + class IntegrationTest(unittest.TestCase): @classmethod def setUpClass(cls): super().setUpClass() - config = get_config(TEST_CONFIG_PATH) - cls.config = config - cls.host = config.host - cls.port = config.port - cls.auth_mode = config.auth_mode - cls.ssl = config.ssl - cls.iam_credentials_provider_type = config.iam_credentials_provider_type - cls.load_from_s3_arn = config.load_from_s3_arn + cls.config = get_config(TEST_CONFIG_PATH) + cls.client_builder = setup_client_builder(cls.config) + + def setUp(self) -> None: + self.client = self.client_builder.build() diff --git a/test/integration/NeptuneIntegrationWorkflowSteps.py b/test/integration/NeptuneIntegrationWorkflowSteps.py index 5149330a..36f4ce78 100644 --- a/test/integration/NeptuneIntegrationWorkflowSteps.py +++ b/test/integration/NeptuneIntegrationWorkflowSteps.py @@ -15,12 +15,12 @@ import boto3 as boto3 import requests -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider from graph_notebook.configuration.generate_config import AuthModeEnum, Configuration SUBPARSER_CREATE_CFN = 'create-cfn-stack' SUBPARSER_DELETE_CFN = 'delete-cfn-stack' SUBPARSER_RUN_TESTS = 'run-tests' +SUBPARSER_GENERATE_CONFIG = 'generate-config' SUBPARSER_ENABLE_IAM = 'toggle-cluster-iam' sys.path.insert(0, os.path.abspath('..')) @@ -177,8 +177,8 @@ def get_stack_details_to_run(stack: dict, region: str = 'us-east-1', timeout_min ip = network_interface['PrivateIpAddresses'][0]['Association']['PublicIp'] logging.info(f'checking if ip {ip} can be used ') + url = f'https://{ip}:80/status' try: - url = f'https://{ip}:80/status' logging.info(f'checking ip address {ip}, url={url}') # hard-coded to port 80 since that's what this CFN stack uses for its load balancer requests.get(url, verify=False, timeout=5) # an exception is thrown if the host cannot be reached. 
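
The `setup_client_builder` helper above is the heart of this refactor: every integration test now derives its connection from a single fluent `ClientBuilder` chain, with IAM signing attached only when the configuration demands it. Outside the test harness the same chain can be used directly; a minimal sketch with placeholder endpoint values:

```python
from botocore.session import get_session
from graph_notebook.neptune.client import ClientBuilder

# Placeholder endpoint values; with_iam() is only needed for IAM-auth clusters.
client = ClientBuilder() \
    .with_host('my-cluster.cluster-abc123.us-east-1.neptune.amazonaws.com') \
    .with_port(8182) \
    .with_region('us-east-1') \
    .with_tls(True) \
    .with_sparql_path('sparql') \
    .with_iam(get_session()) \
    .build()

assert client.status().status_code == 200  # same health probe the tests use
```
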
@@ -268,8 +268,7 @@ def generate_config_from_stack(stack: dict, region: str, iam: bool) -> Configura file.writelines(new_lines) auth = AuthModeEnum.IAM if iam else AuthModeEnum.DEFAULT - conf = Configuration(details['endpoint'], 80, auth, IAMAuthCredentialsProvider.ENV, details['loader_arn'], - ssl=True, aws_region=region) + conf = Configuration(details['endpoint'], 80, auth, details['loader_arn'], ssl=True, aws_region=region) logging.info(f'generated configuration for test run: {conf.to_dict()}') return conf @@ -323,15 +322,12 @@ def main(): delete_parser.add_argument('--cfn-stack-name', type=str, default='') delete_parser.add_argument('--aws-region', type=str, default='us-east-1') - # sub parser for running tests - parser_run_tests = subparsers.add_parser(SUBPARSER_RUN_TESTS, - help='run tests with the pattern *_without_iam.py') - parser_run_tests.add_argument('--pattern', type=str), - parser_run_tests.add_argument('--iam', action='store_true') - parser_run_tests.add_argument('--cfn-stack-name', type=str, default='') - parser_run_tests.add_argument('--aws-region', type=str, default='us-east-1') - parser_run_tests.add_argument('--skip-config-generation', action='store_true', - help=f'skips config generation for testing, using the one found under {TEST_CONFIG_PATH}') + # sub parser generate config + config_parser = subparsers.add_parser(SUBPARSER_GENERATE_CONFIG, + help='generate test configuration from supplied cfn stack') + config_parser.add_argument('--cfn-stack-name', type=str, default='') + config_parser.add_argument('--aws-region', type=str, default='us-east-1') + config_parser.add_argument('--iam', action='store_true') args = parser.parse_args() @@ -342,20 +338,17 @@ def main(): handle_create_cfn_stack(stack_name, args.cfn_template_url, args.cfn_s3_bucket, cfn_client, args.cfn_runner_role) elif args.which == SUBPARSER_DELETE_CFN: delete_stack(args.cfn_stack_name, cfn_client) - elif args.which == SUBPARSER_RUN_TESTS: - if not args.skip_config_generation: - loop_until_stack_is_complete(args.cfn_stack_name, cfn_client) - stack = get_cfn_stack_details(args.cfn_stack_name, cfn_client) - cluster_identifier = get_neptune_identifier_from_cfn(args.cfn_stack_name, cfn_client) - set_iam_auth_on_neptune_cluster(cluster_identifier, args.iam, neptune_client) - config = generate_config_from_stack(stack, args.aws_region, args.iam) - config.write_to_file(TEST_CONFIG_PATH) - run_integration_tests(args.pattern) elif args.which == SUBPARSER_ENABLE_IAM: cluster_identifier = get_neptune_identifier_from_cfn(args.cfn_stack_name, cfn_client) set_iam_auth_on_neptune_cluster(cluster_identifier, True, neptune_client) logging.info('waiting for one minute while change is applied...') time.sleep(60) + elif args.which == SUBPARSER_GENERATE_CONFIG: + stack = get_cfn_stack_details(args.cfn_stack_name, cfn_client) + cluster_identifier = get_neptune_identifier_from_cfn(args.cfn_stack_name, cfn_client) + set_iam_auth_on_neptune_cluster(cluster_identifier, args.iam, neptune_client) + config = generate_config_from_stack(stack, args.aws_region, args.iam) + config.write_to_file(TEST_CONFIG_PATH) if __name__ == '__main__': diff --git a/test/integration/__init__.py b/test/integration/__init__.py index d0f15c29..cf393a60 100644 --- a/test/integration/__init__.py +++ b/test/integration/__init__.py @@ -6,4 +6,5 @@ from .IntegrationTest import IntegrationTest # noqa F401 from .DataDrivenGremlinTest import DataDrivenGremlinTest # noqa F401 from .DataDrivenSparqlTest import DataDrivenSparqlTest # noqa F401 +from 
.GraphNotebookIntegrationTest import GraphNotebookIntegrationTest # noqa F401 from .NeptuneIntegrationWorkflowSteps import TEST_CONFIG_PATH # noqa F401 diff --git a/test/integration/gremlin/client_provider/__init__.py b/test/integration/gremlin/client_provider/__init__.py deleted file mode 100644 index 9049dd04..00000000 --- a/test/integration/gremlin/client_provider/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" diff --git a/test/integration/gremlin/client_provider/client_provider_factory.py b/test/integration/gremlin/client_provider/client_provider_factory.py deleted file mode 100644 index 5b3d6b31..00000000 --- a/test/integration/gremlin/client_provider/client_provider_factory.py +++ /dev/null @@ -1,22 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -import unittest - -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider -from graph_notebook.configuration.generate_config import AuthModeEnum -from graph_notebook.gremlin.client_provider.default_client import ClientProvider -from graph_notebook.gremlin.client_provider.factory import create_client_provider -from graph_notebook.gremlin.client_provider.iam_client import IamClientProvider - - -class TestClientProviderFactory(unittest.TestCase): - def test_create_default_client(self): - client_provider = create_client_provider(AuthModeEnum.DEFAULT) - self.assertEqual(ClientProvider, type(client_provider)) - - def test_create_iam_client_from_env(self): - client_provider = create_client_provider(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - self.assertEqual(IamClientProvider, type(client_provider)) diff --git a/test/integration/gremlin/gremlin_query_with_iam.py b/test/integration/gremlin/gremlin_query_with_iam.py deleted file mode 100644 index a3cac15e..00000000 --- a/test/integration/gremlin/gremlin_query_with_iam.py +++ /dev/null @@ -1,37 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-SPDX-License-Identifier: Apache-2.0 -""" - -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider -from graph_notebook.configuration.generate_config import AuthModeEnum -from graph_notebook.gremlin.query import do_gremlin_query, do_gremlin_explain, do_gremlin_profile -from graph_notebook.gremlin.client_provider.factory import create_client_provider -from graph_notebook.request_param_generator.factory import create_request_generator - -from test.integration import IntegrationTest - - -class TestGremlinWithIam(IntegrationTest): - def test_do_gremlin_query_with_iam(self): - client_provider = create_client_provider(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - query = 'g.V().limit(1)' - results = do_gremlin_query(query, self.host, self.port, self.ssl, client_provider) - - self.assertEqual(type(results), list) - - def test_do_gremlin_explain_with_iam(self): - query = 'g.V().limit(1)' - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - results = do_gremlin_explain(query, self.host, self.port, self.ssl, request_generator) - - self.assertEqual(type(results), dict) - self.assertTrue('explain' in results) - - def test_do_gremlin_profile_with_iam(self): - query = 'g.V().limit(1)' - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - results = do_gremlin_profile(query, self.host, self.port, self.ssl, request_generator) - - self.assertEqual(type(results), dict) - self.assertTrue('profile' in results) diff --git a/test/integration/gremlin/gremlin_query_without_iam.py b/test/integration/gremlin/gremlin_query_without_iam.py deleted file mode 100644 index 7f61d020..00000000 --- a/test/integration/gremlin/gremlin_query_without_iam.py +++ /dev/null @@ -1,35 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -from graph_notebook.gremlin.client_provider.default_client import ClientProvider -from graph_notebook.gremlin.query import do_gremlin_query, do_gremlin_explain, do_gremlin_profile -from graph_notebook.request_param_generator.default_request_generator import DefaultRequestGenerator - -from test.integration import IntegrationTest - - -class TestGremlin(IntegrationTest): - def test_do_gremlin_query(self): - client_provider = ClientProvider() - query = 'g.V().limit(1)' - results = do_gremlin_query(query, self.host, self.port, self.ssl, client_provider) - - self.assertEqual(type(results), list) - - def test_do_gremlin_explain(self): - query = 'g.V().limit(1)' - request_generator = DefaultRequestGenerator() - results = do_gremlin_explain(query, self.host, self.port, self.ssl, request_generator) - - self.assertEqual(type(results), dict) - self.assertTrue('explain' in results) - - def test_do_gremlin_profile(self): - query = 'g.V().limit(1)' - request_generator = DefaultRequestGenerator() - results = do_gremlin_profile(query, self.host, self.port, self.ssl, request_generator) - - self.assertEqual(type(results), dict) - self.assertTrue('profile' in results) diff --git a/test/integration/gremlin/gremlin_status_with_iam.py b/test/integration/gremlin/gremlin_status_with_iam.py deleted file mode 100644 index 0066db90..00000000 --- a/test/integration/gremlin/gremlin_status_with_iam.py +++ /dev/null @@ -1,134 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-SPDX-License-Identifier: Apache-2.0 -""" - -import threading -import time -import requests -from os import cpu_count - -from graph_notebook.configuration.generate_config import AuthModeEnum -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider -from graph_notebook.gremlin.query import do_gremlin_query -from graph_notebook.gremlin.status import do_gremlin_status, do_gremlin_cancel -from graph_notebook.gremlin.client_provider.factory import create_client_provider -from graph_notebook.request_param_generator.factory import create_request_generator -from gremlin_python.driver.protocol import GremlinServerError - -from test.integration import DataDrivenGremlinTest - - -class TestGremlinStatusWithIam(DataDrivenGremlinTest): - def do_gremlin_query_save_results(self, query, res): - client_provider = create_client_provider(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - try: - res['result'] = do_gremlin_query(query, self.host, self.port, self.ssl, client_provider) - except GremlinServerError as exception: - res['error'] = str(exception) - - def test_do_gremlin_status_nonexistent(self): - with self.assertRaises(requests.HTTPError): - query_id = "some-guid-here" - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - try: - do_gremlin_status(self.host, self.port, self.ssl, AuthModeEnum.IAM, request_generator, query_id, False) - except requests.HTTPError as exception: - content = exception.response.json() - self.assertTrue('requestId' in content) - self.assertTrue('code' in content) - self.assertTrue('detailedMessage' in content) - self.assertEqual('InvalidParameterException', content['code']) - raise exception - - def test_do_gremlin_cancel_nonexistent(self): - with self.assertRaises(requests.HTTPError): - query_id = "some-guid-here" - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - try: - do_gremlin_cancel(self.host, self.port, self.ssl, AuthModeEnum.IAM, request_generator, query_id) - except requests.HTTPError as exception: - content = exception.response.json() - self.assertTrue('requestId' in content) - self.assertTrue('code' in content) - self.assertTrue('detailedMessage' in content) - self.assertEqual('InvalidParameterException', content['code']) - raise exception - - def test_do_gremlin_cancel_empty_query_id(self): - with self.assertRaises(ValueError): - query_id = '' - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - do_gremlin_cancel(self.host, self.port, self.ssl, AuthModeEnum.IAM, request_generator, query_id) - - def test_do_gremlin_cancel_non_str_query_id(self): - with self.assertRaises(ValueError): - query_id = 42 - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - do_gremlin_cancel(self.host, self.port, self.ssl, AuthModeEnum.IAM, request_generator, query_id) - - def test_do_gremlin_status_and_cancel(self): - query = "g.V().out().out().out().out()" - query_res = {} - gremlin_query_thread = threading.Thread(target=self.do_gremlin_query_save_results, args=(query, query_res,)) - gremlin_query_thread.start() - time.sleep(3) - - query_id = '' - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - status_res = do_gremlin_status(self.host, self.port, self.ssl, AuthModeEnum.IAM, - request_generator, query_id, False) - self.assertEqual(type(status_res), dict) - self.assertTrue('acceptedQueryCount' in 
status_res) - self.assertTrue('runningQueryCount' in status_res) - self.assertTrue(status_res['runningQueryCount'] == 1) - self.assertTrue('queries' in status_res) - - query_id = '' - for q in status_res['queries']: - if query in q['queryString']: - query_id = q['queryId'] - - self.assertNotEqual(query_id, '') - - cancel_res = do_gremlin_cancel(self.host, self.port, self.ssl, AuthModeEnum.IAM, request_generator, query_id) - self.assertEqual(type(cancel_res), dict) - self.assertTrue('status' in cancel_res) - self.assertTrue('payload' in cancel_res) - self.assertEqual('200 OK', cancel_res['status']) - - gremlin_query_thread.join() - self.assertFalse('result' in query_res) - self.assertTrue('error' in query_res) - self.assertTrue('code' in query_res['error']) - self.assertTrue('requestId' in query_res['error']) - self.assertTrue('detailedMessage' in query_res['error']) - self.assertTrue('TimeLimitExceededException' in query_res['error']) - - def test_do_gremlin_status_include_waiting(self): - query = "g.V().out().out().out().out()" - num_threads = 4 * cpu_count() - threads = [] - query_results = [] - for x in range(0, num_threads): - query_res = {} - gremlin_query_thread = threading.Thread(target=self.do_gremlin_query_save_results, args=(query, query_res,)) - threads.append(gremlin_query_thread) - query_results.append(query_res) - gremlin_query_thread.start() - - time.sleep(5) - - query_id = '' - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - status_res = do_gremlin_status(self.host, self.port, self.ssl, AuthModeEnum.IAM, - request_generator, query_id, True) - - self.assertEqual(type(status_res), dict) - self.assertTrue('acceptedQueryCount' in status_res) - self.assertTrue('runningQueryCount' in status_res) - self.assertTrue('queries' in status_res) - self.assertEqual(status_res['acceptedQueryCount'], len(status_res['queries'])) - - for gremlin_query_thread in threads: - gremlin_query_thread.join() diff --git a/test/integration/gremlin/gremlin_status_without_iam.py b/test/integration/gremlin/gremlin_status_without_iam.py deleted file mode 100644 index 03f6d7b1..00000000 --- a/test/integration/gremlin/gremlin_status_without_iam.py +++ /dev/null @@ -1,129 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-SPDX-License-Identifier: Apache-2.0 -""" - -import threading -import time -import requests -from os import cpu_count - -from gremlin_python.driver.protocol import GremlinServerError - -from graph_notebook.configuration.generate_config import AuthModeEnum -from graph_notebook.gremlin.query import do_gremlin_query -from graph_notebook.gremlin.status import do_gremlin_status, do_gremlin_cancel -from graph_notebook.request_param_generator.factory import create_request_generator - -from test.integration import DataDrivenGremlinTest - - -class TestGremlinStatusWithoutIam(DataDrivenGremlinTest): - def do_gremlin_query_save_results(self, query, res): - try: - res['result'] = do_gremlin_query(query, self.host, self.port, self.ssl, self.client_provider) - except GremlinServerError as exception: - res['error'] = str(exception) - - def test_do_gremlin_status_nonexistent(self): - with self.assertRaises(requests.HTTPError): - query_id = "ac7d5a03-00cf-4280-b464-edbcbf51ffce" - request_generator = create_request_generator(AuthModeEnum.DEFAULT) - try: - do_gremlin_status(self.host, self.port, self.ssl, self.auth_mode, request_generator, query_id, False) - except requests.HTTPError as exception: - content = exception.response.json() - self.assertTrue('requestId' in content) - self.assertTrue('code' in content) - self.assertTrue('detailedMessage' in content) - self.assertEqual('InvalidParameterException', content['code']) - raise exception - - def test_do_gremlin_cancel_nonexistent(self): - with self.assertRaises(requests.HTTPError): - query_id = "ac7d5a03-00cf-4280-b464-edbcbf51ffce" - request_generator = create_request_generator(AuthModeEnum.DEFAULT) - try: - do_gremlin_cancel(self.host, self.port, self.ssl, self.auth_mode, request_generator, query_id) - except requests.HTTPError as exception: - content = exception.response.json() - self.assertTrue('requestId' in content) - self.assertTrue('code' in content) - self.assertTrue('detailedMessage' in content) - self.assertEqual('InvalidParameterException', content['code']) - raise exception - - def test_do_gremlin_cancel_empty_query_id(self): - with self.assertRaises(ValueError): - query_id = '' - request_generator = create_request_generator(AuthModeEnum.DEFAULT) - do_gremlin_cancel(self.host, self.port, self.ssl, self.auth_mode, request_generator, query_id) - - def test_do_gremlin_cancel_non_str_query_id(self): - with self.assertRaises(ValueError): - query_id = 42 - request_generator = create_request_generator(AuthModeEnum.DEFAULT) - do_gremlin_cancel(self.host, self.port, self.ssl, self.auth_mode, request_generator, query_id) - - def test_do_gremlin_status_and_cancel(self): - query = "g.V().out().out().out().out()" - query_res = {} - gremlin_query_thread = threading.Thread(target=self.do_gremlin_query_save_results, args=(query, query_res,)) - gremlin_query_thread.start() - time.sleep(3) - - query_id = '' - request_generator = create_request_generator(AuthModeEnum.DEFAULT) - status_res = do_gremlin_status(self.host, self.port, self.ssl, self.auth_mode, - request_generator, query_id, False) - self.assertEqual(type(status_res), dict) - self.assertTrue('acceptedQueryCount' in status_res) - self.assertTrue('runningQueryCount' in status_res) - self.assertTrue(status_res['runningQueryCount'] == 1) - self.assertTrue('queries' in status_res) - - query_id = '' - for q in status_res['queries']: - if query in q['queryString']: - query_id = q['queryId'] - - self.assertNotEqual(query_id, '') - - cancel_res = do_gremlin_cancel(self.host, self.port, self.ssl, 
self.auth_mode, request_generator, query_id) - self.assertEqual(type(cancel_res), dict) - self.assertTrue('status' in cancel_res) - self.assertTrue('payload' in cancel_res) - self.assertEqual('200 OK', cancel_res['status']) - - gremlin_query_thread.join() - self.assertFalse('result' in query_res) - self.assertTrue('error' in query_res) - self.assertTrue('code' in query_res['error']) - self.assertTrue('requestId' in query_res['error']) - self.assertTrue('detailedMessage' in query_res['error']) - self.assertTrue('TimeLimitExceededException' in query_res['error']) - - def test_do_gremlin_status_include_waiting(self): - query = "g.V().out().out().out().out()" - num_threads = 4 * cpu_count() - threads = [] - for x in range(0, num_threads): - gremlin_query_thread = threading.Thread(target=self.do_gremlin_query_save_results, args=(query, {})) - threads.append(gremlin_query_thread) - gremlin_query_thread.start() - - time.sleep(5) - - query_id = '' - request_generator = create_request_generator(AuthModeEnum.DEFAULT) - status_res = do_gremlin_status(self.host, self.port, self.ssl, self.auth_mode, - request_generator, query_id, True) - - self.assertEqual(type(status_res), dict) - self.assertTrue('acceptedQueryCount' in status_res) - self.assertTrue('runningQueryCount' in status_res) - self.assertTrue('queries' in status_res) - self.assertEqual(status_res['acceptedQueryCount'], len(status_res['queries'])) - - for gremlin_query_thread in threads: - gremlin_query_thread.join() diff --git a/src/graph_notebook/gremlin/__init__.py b/test/integration/iam/__init__.py similarity index 100% rename from src/graph_notebook/gremlin/__init__.py rename to test/integration/iam/__init__.py diff --git a/src/graph_notebook/gremlin/client_provider/__init__.py b/test/integration/iam/gremlin/__init__.py similarity index 100% rename from src/graph_notebook/gremlin/client_provider/__init__.py rename to test/integration/iam/gremlin/__init__.py diff --git a/test/integration/iam/gremlin/test_gremlin_status_with_iam.py b/test/integration/iam/gremlin/test_gremlin_status_with_iam.py new file mode 100644 index 00000000..502b7c9a --- /dev/null +++ b/test/integration/iam/gremlin/test_gremlin_status_with_iam.py @@ -0,0 +1,122 @@ +""" +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+SPDX-License-Identifier: Apache-2.0 +""" +import concurrent.futures +import threading +import time + +import pytest +from os import cpu_count + +from botocore.session import get_session +from gremlin_python.driver.protocol import GremlinServerError + +from graph_notebook.neptune.client import Client +from test.integration import DataDrivenGremlinTest + + +def long_running_gremlin_query(c: Client, query: str): + with pytest.raises(GremlinServerError): + c.gremlin_query(query) + return + + +class TestGremlinStatusWithIam(DataDrivenGremlinTest): + def setUp(self) -> None: + super().setUp() + if not self.client.iam_enabled: + self.client = self.client_builder.with_iam(get_session()).build() + + @pytest.mark.iam + @pytest.mark.neptune + def test_do_gremlin_status_nonexistent(self): + query_id = "some-guid-here" + res = self.client.gremlin_status(query_id) + assert res.status_code == 400 + js = res.json() + assert js['code'] == 'InvalidParameterException' + assert js['detailedMessage'] == f'Supplied queryId {query_id} is invalid' + + @pytest.mark.iam + @pytest.mark.neptune + def test_do_gremlin_cancel_nonexistent(self): + query_id = "some-guid-here" + res = self.client.gremlin_cancel(query_id) + assert res.status_code == 400 + js = res.json() + assert js['code'] == 'InvalidParameterException' + assert js['detailedMessage'] == f'Supplied queryId {query_id} is invalid' + + @pytest.mark.iam + @pytest.mark.neptune + def test_do_gremlin_cancel_empty_query_id(self): + with self.assertRaises(ValueError): + self.client.gremlin_cancel('') + + @pytest.mark.iam + @pytest.mark.neptune + def test_do_gremlin_cancel_non_str_query_id(self): + with self.assertRaises(ValueError): + self.client.gremlin_cancel(42) + + @pytest.mark.iam + @pytest.mark.neptune + def test_do_gremlin_status_and_cancel(self): + long_running_query = "g.V().out().out().out().out().out().out().out().out()" + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(long_running_gremlin_query, self.client, long_running_query) + + time.sleep(1) + status_res = self.client.gremlin_status() + assert status_res.status_code == 200 + + status_js = status_res.json() + query_id = '' + for q in status_js['queries']: + if q['queryString'] == long_running_query: + query_id = q['queryId'] + + assert query_id != '' + + cancel_res = self.client.gremlin_cancel(query_id) + assert cancel_res.status_code == 200 + assert cancel_res.json()['status'] == '200 OK' + + time.sleep(1) + status_after_cancel = self.client.gremlin_status(query_id) + assert status_after_cancel.status_code == 400 # check that the query is no longer valid + assert status_after_cancel.json()['code'] == 'InvalidParameterException' + + future.result() + + @pytest.mark.iam + @pytest.mark.neptune + def test_do_gremlin_status_include_waiting(self): + query = "g.V().out().out().out().out()" + num_threads = cpu_count() * 4 + threads = [] + for x in range(0, num_threads): + thread = threading.Thread(target=long_running_gremlin_query, args=(self.client, query)) + thread.start() + threads.append(thread) + + time.sleep(5) + + res = self.client.gremlin_status(include_waiting=True) + assert res.status_code == 200 + status_res = res.json() + + self.assertEqual(type(status_res), dict) + self.assertTrue('acceptedQueryCount' in status_res) + self.assertTrue('runningQueryCount' in status_res) + self.assertTrue('queries' in status_res) + self.assertEqual(status_res['acceptedQueryCount'], len(status_res['queries'])) + + for q in status_res['queries']: + # cancel all the queries we 
executed since they can take a very long time.
+            if q['queryString'] == query:
+                self.client.gremlin_cancel(q['queryId'])
+
+        for t in threads:
+            t.join()
diff --git a/test/integration/iam/gremlin/test_gremlin_with_iam.py b/test/integration/iam/gremlin/test_gremlin_with_iam.py
new file mode 100644
index 00000000..375e2b7d
--- /dev/null
+++ b/test/integration/iam/gremlin/test_gremlin_with_iam.py
@@ -0,0 +1,56 @@
+"""
+Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+SPDX-License-Identifier: Apache-2.0
+"""
+import pytest
+from botocore.session import get_session
+from gremlin_python.structure.graph import Vertex
+
+from test.integration import IntegrationTest
+
+
+class TestGremlinWithIam(IntegrationTest):
+    def setUp(self) -> None:
+        self.client = self.client_builder.with_iam(get_session()).build()
+
+    @pytest.mark.iam
+    @pytest.mark.gremlin
+    def test_do_gremlin_query_with_iam(self):
+        query = 'g.V().limit(1)'
+        results = self.client.gremlin_query(query)
+        assert type(results) is list
+        for r in results:
+            assert type(r) is Vertex
+
+    @pytest.mark.iam
+    @pytest.mark.gremlin
+    def test_do_gremlin_explain_with_iam(self):
+        query = 'g.V().limit(1)'
+        res = self.client.gremlin_explain(query)
+        assert res.status_code == 200
+        results = res.content.decode('utf-8')
+        self.assertTrue('Explain' in results)
+
+    @pytest.mark.iam
+    @pytest.mark.gremlin
+    def test_do_gremlin_profile_with_iam(self):
+        query = 'g.V().limit(1)'
+        res = self.client.gremlin_profile(query)
+        assert res.status_code == 200
+
+        results = res.content.decode('utf-8')
+        self.assertTrue('Profile' in results)
+
+    @pytest.mark.iam
+    @pytest.mark.gremlin
+    def test_iam_gremlin_http_query(self):
+        query = 'g.V().limit(1)'
+        res = self.client.gremlin_http_query(query)
+        assert res.status_code == 200
+        assert 'result' in res.json()
+
+    @pytest.mark.iam
+    @pytest.mark.gremlin
+    def test_iam_gremlin_connection(self):
+        conn = self.client.get_gremlin_connection()
+        conn.submit('g.V().limit(1)')  # smoke test: a signed connection can submit a traversal
diff --git a/src/graph_notebook/system/__init__.py b/test/integration/iam/load/__init__.py
similarity index 100%
rename from src/graph_notebook/system/__init__.py
rename to test/integration/iam/load/__init__.py
diff --git a/test/integration/iam/load/test_load_with_iam.py b/test/integration/iam/load/test_load_with_iam.py
new file mode 100644
index 00000000..102cee66
--- /dev/null
+++ b/test/integration/iam/load/test_load_with_iam.py
@@ -0,0 +1,60 @@
+import time
+
+import pytest
+import unittest
+
+from botocore.session import get_session
+
+from test.integration import IntegrationTest
+
+TEST_BULKLOAD_SOURCE = 's3://aws-ml-customer-samples-%s/bulkload-datasets/%s/airroutes/v01'
+
+
+@unittest.skip
+class TestLoadWithIAM(IntegrationTest):
+    def setUp(self) -> None:
+        assert self.config.load_from_s3_arn != ''
+        self.client = self.client_builder.with_iam(get_session()).build()
+
+    @pytest.mark.neptune
+    def test_iam_load(self):
+        load_format = 'turtle'
+        source = TEST_BULKLOAD_SOURCE % (self.config.aws_region, 'turtle')
+
+        # for a full list of options, see https://docs.aws.amazon.com/neptune/latest/userguide/bulk-load-data.html
+        kwargs = {
+            'failOnError': "TRUE",
+        }
+        res = self.client.load(source, load_format, self.config.load_from_s3_arn, **kwargs)
+        assert res.status_code == 200
+
+        load_js = res.json()
+        assert 'loadId' in load_js['payload']
+        load_id = load_js['payload']['loadId']
+
+        time.sleep(1)  # brief wait to ensure the load job can be obtained
+
+        res = self.client.load_status(load_id, details="TRUE")
+        assert res.status_code == 200
+
+        
load_status = res.json() + assert 'overallStatus' in load_status['payload'] + status = load_status['payload']['overallStatus'] + assert status['fullUri'] == source + + res = self.client.cancel_load(load_id) + assert res.status_code == 200 + + time.sleep(5) + res = self.client.load_status(load_id, details="TRUE") + cancelled_status = res.json() + assert 'LOAD_CANCELLED_BY_USER' in cancelled_status['payload']['feedCount'][-1] + + @pytest.mark.neptune + def test_iam_load_status(self): + res = self.client.load_status() # This should only give a list of load ids + assert res.status_code == 200 + + js = res.json() + assert 'loadIds' in js['payload'] + assert len(js['payload'].keys()) == 1 diff --git a/test/integration/iam/ml/__init__.py b/test/integration/iam/ml/__init__.py new file mode 100644 index 00000000..70385e35 --- /dev/null +++ b/test/integration/iam/ml/__init__.py @@ -0,0 +1,27 @@ +""" +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +SPDX-License-Identifier: Apache-2.0 +""" + +from botocore.session import get_session + +from graph_notebook.configuration.generate_config import Configuration +from graph_notebook.neptune.client import Client, ClientBuilder + + +def setup_iam_client(config: Configuration) -> Client: + client = ClientBuilder() \ + .with_host(config.host) \ + .with_port(config.port) \ + .with_region(config.aws_region) \ + .with_tls(config.ssl) \ + .with_sparql_path(config.sparql.path) \ + .with_iam(get_session()) \ + .build() + + assert client.host == config.host + assert client.port == config.port + assert client.region == config.aws_region + assert client.sparql_path == config.sparql.path + assert client.ssl is config.ssl + return client diff --git a/test/integration/iam/ml/test_neptune_client_with_iam.py b/test/integration/iam/ml/test_neptune_client_with_iam.py new file mode 100644 index 00000000..d5af29ce --- /dev/null +++ b/test/integration/iam/ml/test_neptune_client_with_iam.py @@ -0,0 +1,14 @@ +""" +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +SPDX-License-Identifier: Apache-2.0 +""" + +from graph_notebook.configuration.generate_config import Configuration +from graph_notebook.neptune.client import Client + +client: Client +config: Configuration + +TEST_BULKLOAD_SOURCE = 's3://aws-ml-customer-samples-%s/bulkload-datasets/%s/airroutes/v01' +GREMLIN_TEST_LABEL = 'graph-notebook-test' +SPARQL_TEST_PREDICATE = '' diff --git a/test/integration/iam/ml/test_neptune_ml_with_iam.py b/test/integration/iam/ml/test_neptune_ml_with_iam.py new file mode 100644 index 00000000..36d1c05e --- /dev/null +++ b/test/integration/iam/ml/test_neptune_ml_with_iam.py @@ -0,0 +1,179 @@ +""" +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+SPDX-License-Identifier: Apache-2.0 +""" + +import logging +import os +import threading +import time +import unittest + +import pytest +from botocore.session import get_session + +from graph_notebook.configuration.get_config import get_config +from test.integration import IntegrationTest +from test.integration.iam.ml import setup_iam_client + +logger = logging.getLogger() + + +@unittest.skip +class TestNeptuneMLWithIAM(IntegrationTest): + def setUp(self) -> None: + self.client = self.client_builder.with_iam(get_session()).build() + + def tearDown(self) -> None: + endpoint_ids = client.endpoints().json()['ids'] + for endpoint_id in endpoint_ids: + self.client.endpoints_delete(endpoint_id) + + client.close() + + @pytest.mark.neptuneml + @pytest.mark.iam + def test_neptune_ml_e2e(self): + s3_input_uri = os.getenv('NEPTUNE_ML_DATAPROCESSING_S3_INPUT', '') + s3_processed_uri = os.getenv('NEPTUNE_ML_DATAPROCESSING_S3_PROCESSED', '') + train_model_s3_location = os.getenv('NEPTUNE_ML_TRAINING_S3_LOCATION', '') + + assert s3_input_uri != '' + assert s3_processed_uri != '' + assert train_model_s3_location != '' + + logger.info("dataprocessing...") + dataprocessing_job = do_dataprocessing(s3_input_uri, s3_processed_uri) + dataprocessing_id = dataprocessing_job['id'] + + p = threading.Thread(target=wait_for_dataprocessing_complete, args=(dataprocessing_id,)) + p.start() + p.join(3600) + + logger.info("model training...") + training_job = do_modeltraining(dataprocessing_id, train_model_s3_location) + training_job_id = training_job['id'] + + p = threading.Thread(target=wait_for_modeltraining_complete, args=(training_job_id,)) + p.start() + p.join(3600) + + logger.info("endpoint...") + endpoint_job = do_create_endpoint(training_job_id) + endpoint_job_id = endpoint_job['id'] + p = threading.Thread(target=wait_for_endpoint_complete, args=(endpoint_job_id,)) + p.start() + p.join(3600) + + @pytest.mark.neptuneml + @pytest.mark.iam + def test_neptune_ml_dataprocessing_status(self): + status = client.dataprocessing_status() + + assert status.status_code == 200 + assert 'ids' in status.json() + + @pytest.mark.neptuneml + @pytest.mark.iam + def test_neptune_ml_modeltraining_status(self): + status = client.modeltraining_status() + assert status.status_code == 200 + assert 'ids' in status.json() + + @pytest.mark.neptuneml + @pytest.mark.iam + def test_neptune_ml_training(self): + dataprocessing_id = os.getenv('NEPTUNE_ML_DATAPROCESSING_ID', '') + train_model_s3_location = os.getenv('NEPTUNE_ML_TRAINING_S3_LOCATION', '') + + assert dataprocessing_id != '' + assert train_model_s3_location != '' + + dataprocessing_status = client.dataprocessing_job_status(dataprocessing_id) + assert dataprocessing_status.status_code == 200 + + job_start_res = client.modeltraining_start(dataprocessing_id, train_model_s3_location) + assert job_start_res.status_code == 200 + + job_id = job_start_res.json()['id'] + training_status_res = client.modeltraining_job_status(job_id) + assert training_status_res.status_code == 200 + + job_stop_res = client.modeltraining_stop(job_id, clean=True) + assert job_stop_res.status_code == 200 + + +def setup_module(): + global client + client = setup_iam_client(get_config()) + + +def teardown_module(): + endpoint_ids = client.endpoints().json()['ids'] + for endpoint_id in endpoint_ids: + client.endpoints_delete(endpoint_id) + + client.close() + + +def do_dataprocessing(s3_input, s3_processed) -> dict: + logger.info(f"starting dataprocessing job with input={s3_input} and processed={s3_processed}") + 
dataprocessing_res = client.dataprocessing_start(s3_input, s3_processed) + assert dataprocessing_res.status_code == 200 + return dataprocessing_res.json() + + +def wait_for_dataprocessing_complete(dataprocessing_id: str): + logger.info(f"waiting for dataprocessing job {dataprocessing_id} to complete") + while True: + status = client.dataprocessing_job_status(dataprocessing_id) + assert status.status_code == 200 + raw = status.json() + logger.info(f"status is {raw['status']}") + if raw['status'] != 'InProgress': + assert raw['status'] == 'Completed' + return raw + logger.info("waiting for 10 seconds then checking again") + time.sleep(10) + + +def do_modeltraining(dataprocessing_id, train_model_s3_location): + logger.info( + f"starting training job from dataprocessing_job_id={dataprocessing_id} and training_model_s3_location={train_model_s3_location}") + training_start = client.modeltraining_start(dataprocessing_id, train_model_s3_location) + assert training_start.status_code == 200 + return training_start.json() + + +def wait_for_modeltraining_complete(training_job: str) -> dict: + logger.info(f"waiting for modeltraining job {training_job} to complete") + while True: + status = client.modeltraining_job_status(training_job) + assert status.status_code == 200 + raw = status.json() + logger.info(f"status is {raw['status']}") + if raw['status'] != 'InProgress': + assert raw['status'] == 'Completed' + return raw + logger.info("waiting for 10 seconds then checking again") + time.sleep(10) + + +def do_create_endpoint(training_job_id: str) -> dict: + endpoint_res = client.endpoints_create(training_job_id) + assert endpoint_res.status_code == 200 + return endpoint_res.json() + + +def wait_for_endpoint_complete(endpoint_job_id): + logger.info(f"waiting for endpoint creation job {endpoint_job_id} to complete") + while True: + endpoint_status = client.endpoints_status(endpoint_job_id) + assert endpoint_status.status_code == 200 + raw = endpoint_status.json() + logger.info(f"status is {raw['status']}") + if raw['status'] != 'Creating': + assert raw['status'] == 'InService' + return raw + logger.info("waiting for 10 seconds then checking again") + time.sleep(10) diff --git a/test/integration/system/__init__.py b/test/integration/iam/sparql/__init__.py similarity index 100% rename from test/integration/system/__init__.py rename to test/integration/iam/sparql/__init__.py diff --git a/test/integration/iam/sparql/test_sparql_query_with_iam.py b/test/integration/iam/sparql/test_sparql_query_with_iam.py new file mode 100644 index 00000000..6b4d9c68 --- /dev/null +++ b/test/integration/iam/sparql/test_sparql_query_with_iam.py @@ -0,0 +1,54 @@ +""" +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+SPDX-License-Identifier: Apache-2.0 +""" +from json import JSONDecodeError + +import pytest +from botocore.session import get_session + +from test.integration import DataDrivenSparqlTest + + +class TestSparqlQueryWithIam(DataDrivenSparqlTest): + def setUp(self) -> None: + super().setUp() + self.client = self.client_builder.with_iam(get_session()).build() + + @pytest.mark.iam + @pytest.mark.sparql + def test_do_sparql_query(self): + query = "SELECT * WHERE {?s ?p ?o} LIMIT 1" + query_res = self.client.sparql(query) + assert query_res.status_code == 200 + res = query_res.json() + + self.assertEqual(type(res), dict) + self.assertTrue('s' in res['head']['vars']) + self.assertTrue('p' in res['head']['vars']) + self.assertTrue('o' in res['head']['vars']) + + @pytest.mark.iam + @pytest.mark.sparql + def test_do_sparql_explain(self): + query = "SELECT * WHERE {?s ?p ?o} LIMIT 1" + query_res = self.client.sparql_explain(query) + assert query_res.status_code == 200 + res = query_res.content.decode('utf-8') + self.assertEqual(type(res), str) + self.assertTrue(res.startswith('')) + + @pytest.mark.iam + @pytest.mark.sparql + def test_iam_describe(self): + query = '''PREFIX soccer: + DESCRIBE soccer:Arsenal''' + res = self.client.sparql(query) + assert res.status_code == 200 + + # test that we do not get back json + with pytest.raises(JSONDecodeError): + res.json() + + content = res.content.decode('utf-8') + assert len(content.splitlines()) == 6 diff --git a/test/integration/iam/sparql/test_sparql_status_with_iam.py b/test/integration/iam/sparql/test_sparql_status_with_iam.py new file mode 100644 index 00000000..c6038099 --- /dev/null +++ b/test/integration/iam/sparql/test_sparql_status_with_iam.py @@ -0,0 +1,128 @@ +""" +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +SPDX-License-Identifier: Apache-2.0 +""" + +import concurrent.futures +import time + +import pytest +from botocore.session import get_session + +from graph_notebook.neptune.client import Client + +from test.integration import DataDrivenSparqlTest + + +def long_running_sparql_query(c: Client, query: str): + res = c.sparql(query) + return res + + +class TestSparqlStatusWithIam(DataDrivenSparqlTest): + def setUp(self) -> None: + super().setUp() + self.client = self.client_builder.with_iam(get_session()).build() + + @pytest.mark.iam + @pytest.mark.neptune + def test_do_sparql_status_nonexistent(self): + query_id = "invalid-guid" + status_res = self.client.sparql_status(query_id) + assert status_res.status_code == 200 + assert status_res.content == b'' + + @pytest.mark.iam + @pytest.mark.neptune + def test_do_sparql_cancel_nonexistent(self): + query_id = "invalid-guid" + cancel_res = self.client.sparql_cancel(query_id) + assert cancel_res.status_code == 200 + assert cancel_res.content == b'' + + @pytest.mark.iam + @pytest.mark.neptune + def test_do_sparql_cancel_empty_query_id(self): + with pytest.raises(ValueError): + self.client.sparql_cancel('') + + @pytest.mark.iam + @pytest.mark.neptune + def test_do_sparql_cancel_non_str_query_id(self): + with pytest.raises(ValueError): + self.client.sparql_cancel(42) + + @pytest.mark.iam + @pytest.mark.neptune + def test_do_sparql_status_and_cancel(self): + query = "SELECT * WHERE { ?s ?p ?o . ?s2 ?p2 ?o2 .?s3 ?p3 ?o3 . 
?s4 ?s5 ?s6 .} ORDER BY DESC(?s)" + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(long_running_sparql_query, self.client, query) + time.sleep(1) + + status = self.client.sparql_status() + status_res = status.json() + assert 'acceptedQueryCount' in status_res + assert 'runningQueryCount' in status_res + assert 'queries' in status_res + + time.sleep(1) + + query_id = '' + for q in status_res['queries']: + if query in q['queryString']: + query_id = q['queryId'] + + self.assertNotEqual(query_id, '') + + cancel = self.client.sparql_cancel(query_id, False) + cancel_res = cancel.json() + + assert 'acceptedQueryCount' in cancel_res + assert 'acceptedQueryCount' in cancel_res + assert 'runningQueryCount' in cancel_res + assert 'queries' in cancel_res + + res = future.result() + assert res.status_code == 500 + raw = res.json() + assert raw['code'] == 'CancelledByUserException' + assert raw['detailedMessage'] == 'Operation terminated (cancelled by user)' + + @pytest.mark.iam + @pytest.mark.neptune + def test_do_sparql_status_and_cancel_silently(self): + + query = "SELECT * WHERE { ?s ?p ?o . ?s2 ?p2 ?o2 .?s3 ?p3 ?o3 . ?s4 ?s5 ?s6 .} ORDER BY DESC(?s)" + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(long_running_sparql_query, self.client, query) + time.sleep(1) + + status = self.client.sparql_status() + status_res = status.json() + assert 'acceptedQueryCount' in status_res + assert 'runningQueryCount' in status_res + assert 'queries' in status_res + + time.sleep(1) + + query_id = '' + for q in status_res['queries']: + if query in q['queryString']: + query_id = q['queryId'] + + assert query_id != '' + + cancel = self.client.sparql_cancel(query_id, True) + cancel_res = cancel.json() + assert 'acceptedQueryCount' in cancel_res + assert 'runningQueryCount' in cancel_res + assert 'queries' in cancel_res + + res = future.result() + query_res = res.json() + assert type(query_res) is dict + assert 's3' in query_res['head']['vars'] + assert 'p3' in query_res['head']['vars'] + assert 'o3' in query_res['head']['vars'] + assert [] == query_res['results']['bindings'] diff --git a/test/integration/iam/status/__init__.py b/test/integration/iam/status/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/integration/iam/status/test_status_with_iam.py b/test/integration/iam/status/test_status_with_iam.py new file mode 100644 index 00000000..0cd83446 --- /dev/null +++ b/test/integration/iam/status/test_status_with_iam.py @@ -0,0 +1,29 @@ +""" +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+SPDX-License-Identifier: Apache-2.0
+"""
+import pytest
+from botocore.session import get_session
+
+from test.integration import IntegrationTest
+
+
+class TestStatusWithIAM(IntegrationTest):
+    def setUp(self) -> None:
+        super().setUp()
+        self.client = self.client_builder.with_iam(get_session()).build()
+
+    @pytest.mark.neptune
+    @pytest.mark.iam
+    def test_do_status_with_iam_credentials(self):
+        res = self.client.status()
+        assert res.status_code == 200
+        status = res.json()
+        self.assertEqual(status['status'], 'healthy')
+
+    @pytest.mark.neptune
+    @pytest.mark.iam
+    def test_do_status_without_iam_credentials(self):
+        client = self.client_builder.with_iam(None).build()
+        res = client.status()
+        assert res.status_code != 200
diff --git a/test/integration/iam/system/__init__.py b/test/integration/iam/system/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/test/integration/iam/system/test_system_with_iam.py b/test/integration/iam/system/test_system_with_iam.py
new file mode 100644
index 00000000..a176f404
--- /dev/null
+++ b/test/integration/iam/system/test_system_with_iam.py
@@ -0,0 +1,60 @@
+import datetime
+import time
+
+import pytest
+from botocore.session import get_session
+from test.integration import IntegrationTest
+
+
+class TestSystemWithIAM(IntegrationTest):
+    def setUp(self) -> None:
+        self.client = self.client_builder.with_iam(get_session()).build()
+
+    @pytest.mark.iam
+    @pytest.mark.neptune
+    def test_do_db_reset_initiate_with_iam_credentials(self):
+        token = self.client.initiate_reset()
+        result = token.json()
+        self.assertNotEqual(result['payload']['token'], '')
+
+    @pytest.mark.iam
+    @pytest.mark.neptune
+    def test_do_db_reset_perform_with_wrong_token_with_iam_credentials(self):
+        res = self.client.perform_reset('invalid')
+        assert res.status_code == 400
+
+        expected_message = "System command parameter 'token' : 'invalid' does not match database reset token"
+        assert expected_message == res.json()['detailedMessage']
+
+    @pytest.mark.iam
+    @pytest.mark.neptune
+    def test_do_db_reset_initiate_without_iam_credentials(self):
+        client = self.client_builder.with_iam(None).build()
+        res = client.initiate_reset()
+        assert res.status_code == 403
+
+    @pytest.mark.iam
+    @pytest.mark.neptune
+    @pytest.mark.reset
+    def test_iam_fast_reset(self):
+        initiate_reset_res = self.client.initiate_reset()
+        assert initiate_reset_res.status_code == 200
+
+        token = initiate_reset_res.json()['payload']['token']
+        reset_res = self.client.perform_reset(token)
+        assert reset_res.json()['status'] == '200 OK'
+
+        # check for status for 5 minutes while reset is performed
+        end_time = datetime.datetime.now() + datetime.timedelta(minutes=5)
+        status = None
+        while end_time >= datetime.datetime.now():
+            try:
+                status = self.client.status()
+                if status.status_code != 200:
+                    time.sleep(5)  # wait momentarily until we obtain the status again
+                else:
+                    break
+            except Exception:
+                time.sleep(5)
+
+        assert status.status_code == 200
diff --git a/test/integration/network/__init__.py b/test/integration/network/__init__.py
deleted file mode 100644
index 9049dd04..00000000
--- a/test/integration/network/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-"""
-Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-SPDX-License-Identifier: Apache-2.0 -""" diff --git a/test/integration/network/gremlin/__init__.py b/test/integration/network/gremlin/__init__.py deleted file mode 100644 index 9049dd04..00000000 --- a/test/integration/network/gremlin/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" diff --git a/test/integration/notebook/__init__.py b/test/integration/notebook/__init__.py deleted file mode 100644 index 9049dd04..00000000 --- a/test/integration/notebook/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" diff --git a/test/integration/sparql/__init__.py b/test/integration/sparql/__init__.py deleted file mode 100644 index 9049dd04..00000000 --- a/test/integration/sparql/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" diff --git a/test/integration/sparql/sparql_query_with_iam.py b/test/integration/sparql/sparql_query_with_iam.py deleted file mode 100644 index 8764b182..00000000 --- a/test/integration/sparql/sparql_query_with_iam.py +++ /dev/null @@ -1,30 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -from graph_notebook.configuration.generate_config import AuthModeEnum -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider -from graph_notebook.sparql.query import do_sparql_query, do_sparql_explain -from graph_notebook.request_param_generator.factory import create_request_generator - -from test.integration import IntegrationTest - - -class TestSparqlQueryWithIam(IntegrationTest): - def test_do_sparql_query(self): - query = "SELECT * WHERE {?s ?p ?o} LIMIT 1" - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - res = do_sparql_query(query, self.host, self.port, self.ssl, request_generator) - self.assertEqual(type(res), dict) - self.assertTrue('s' in res['head']['vars']) - self.assertTrue('p' in res['head']['vars']) - self.assertTrue('o' in res['head']['vars']) - - def test_do_sparql_explain(self): - query = "SELECT * WHERE {?s ?p ?o} LIMIT 1" - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - - res = do_sparql_explain(query, self.host, self.port, self.ssl, request_generator) - self.assertEqual(type(res), str) - self.assertTrue(res.startswith('')) diff --git a/test/integration/sparql/sparql_status_with_iam.py b/test/integration/sparql/sparql_status_with_iam.py deleted file mode 100644 index 60ee3fb3..00000000 --- a/test/integration/sparql/sparql_status_with_iam.py +++ /dev/null @@ -1,134 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-SPDX-License-Identifier: Apache-2.0 -""" - -import threading -import time -import requests - -from graph_notebook.configuration.generate_config import AuthModeEnum -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider -from graph_notebook.sparql.query import do_sparql_query -from graph_notebook.sparql.status import do_sparql_status, do_sparql_cancel -from graph_notebook.request_param_generator.factory import create_request_generator - -from test.integration import DataDrivenSparqlTest - - -class TestSparqlStatusWithIam(DataDrivenSparqlTest): - def do_sparql_query_save_result(self, query, res): - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - try: - res['result'] = do_sparql_query(query, self.host, self.port, self.ssl, request_generator) - except requests.HTTPError as exception: - res['error'] = exception.response.json() - - def setUp(self) -> None: - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - res = do_sparql_status(self.host, self.port, self.ssl, request_generator) - for q in res['queries']: - do_sparql_cancel(self.host, self.port, self.ssl, request_generator, q['queryId'], False) - - def test_do_sparql_status_nonexistent(self): - query_id = "ac7d5a03-00cf-4280-b464-edbcbf51ffce" - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - res = do_sparql_status(self.host, self.port, self.ssl, request_generator, query_id) - self.assertEqual(type(res), dict) - self.assertTrue('acceptedQueryCount' in res) - self.assertTrue('runningQueryCount' in res) - self.assertTrue('queries' in res) - - def test_do_sparql_cancel_nonexistent(self): - query_id = "ac7d5a03-00cf-4280-b464-edbcbf51ffce" - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - res = do_sparql_cancel(self.host, self.port, self.ssl, request_generator, query_id, False) - self.assertEqual(type(res), dict) - self.assertTrue('acceptedQueryCount' in res) - self.assertTrue('runningQueryCount' in res) - self.assertTrue('queries' in res) - - def test_do_sparql_cancel_empty_query_id(self): - with self.assertRaises(ValueError): - query_id = '' - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - do_sparql_cancel(self.host, self.port, self.ssl, request_generator, query_id, False) - - def test_do_sparql_cancel_non_str_query_id(self): - with self.assertRaises(ValueError): - query_id = 42 - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - do_sparql_cancel(self.host, self.port, self.ssl, request_generator, query_id, False) - - def test_do_sparql_status_and_cancel(self): - query = "SELECT * WHERE { ?s ?p ?o . 
?s2 ?p2 ?o2 .?s3 ?p3 ?o3 .} ORDER BY DESC(?s) LIMIT 100" - query_res = {} - sparql_query_thread = threading.Thread(target=self.do_sparql_query_save_result, args=(query, query_res,)) - sparql_query_thread.start() - time.sleep(1) - - query_id = '' - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - status_res = do_sparql_status(self.host, self.port, self.ssl, self.request_generator, query_id) - self.assertEqual(type(status_res), dict) - self.assertTrue('acceptedQueryCount' in status_res) - self.assertTrue('runningQueryCount' in status_res) - self.assertTrue('queries' in status_res) - - time.sleep(1) - - query_id = '' - for q in status_res['queries']: - if query in q['queryString']: - query_id = q['queryId'] - - self.assertNotEqual(query_id, '') - - cancel_res = do_sparql_cancel(self.host, self.port, self.ssl, request_generator, query_id, False) - self.assertEqual(type(cancel_res), dict) - self.assertTrue('acceptedQueryCount' in cancel_res) - self.assertTrue('runningQueryCount' in cancel_res) - self.assertTrue('queries' in cancel_res) - - sparql_query_thread.join() - self.assertFalse('result' in query_res) - self.assertTrue('error' in query_res) - self.assertTrue('code' in query_res['error']) - self.assertTrue('requestId' in query_res['error']) - self.assertTrue('detailedMessage' in query_res['error']) - self.assertEqual('CancelledByUserException', query_res['error']['code']) - - def test_do_sparql_status_and_cancel_silently(self): - query = "SELECT * WHERE { ?s ?p ?o . ?s2 ?p2 ?o2 .?s3 ?p3 ?o3 .} ORDER BY DESC(?s) LIMIT 100" - query_res = {} - sparql_query_thread = threading.Thread(target=self.do_sparql_query_save_result, args=(query, query_res,)) - sparql_query_thread.start() - time.sleep(1) - - query_id = '' - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - status_res = do_sparql_status(self.host, self.port, self.ssl, request_generator, query_id) - self.assertEqual(type(status_res), dict) - self.assertTrue('acceptedQueryCount' in status_res) - self.assertTrue('runningQueryCount' in status_res) - self.assertTrue('queries' in status_res) - - query_id = '' - for q in status_res['queries']: - if query in q['queryString']: - query_id = q['queryId'] - - self.assertNotEqual(query_id, '') - - cancel_res = do_sparql_cancel(self.host, self.port, self.ssl, request_generator, query_id, True) - self.assertEqual(type(cancel_res), dict) - self.assertTrue('acceptedQueryCount' in cancel_res) - self.assertTrue('runningQueryCount' in cancel_res) - self.assertTrue('queries' in cancel_res) - - sparql_query_thread.join() - self.assertEqual(type(query_res['result']), dict) - self.assertTrue('s3' in query_res['result']['head']['vars']) - self.assertTrue('p3' in query_res['result']['head']['vars']) - self.assertTrue('o3' in query_res['result']['head']['vars']) - self.assertEqual([], query_res['result']['results']['bindings']) diff --git a/test/integration/sparql/sparql_status_without_iam.py b/test/integration/sparql/sparql_status_without_iam.py deleted file mode 100644 index 1d600b34..00000000 --- a/test/integration/sparql/sparql_status_without_iam.py +++ /dev/null @@ -1,137 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
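Note: the status tests deleted here, their without-iam twin just below, and their pytest replacements later in this patch all share one choreography: run a slow query on a worker, discover its queryId through the status endpoint, cancel it, then assert on how the worker's result surfaces. Condensed, using the client method names introduced elsewhere in this patch (helper name is illustrative):

import concurrent.futures
import time


def run_and_cancel(client, query):
    """Start a slow SPARQL query, cancel it by queryId, return the worker's response."""
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future = executor.submit(client.sparql, query)
        time.sleep(1)  # give the server time to register the query
        queries = client.sparql_status().json()['queries']
        query_id = next((q['queryId'] for q in queries if query in q['queryString']), None)
        if query_id is not None:
            client.sparql_cancel(query_id, False)  # False: loud cancel, the worker sees the error
        return future.result()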
-SPDX-License-Identifier: Apache-2.0 -""" - -import threading - -import logging -import time -import requests - -from graph_notebook.configuration.generate_config import AuthModeEnum -from graph_notebook.sparql.query import do_sparql_query -from graph_notebook.sparql.status import do_sparql_status, do_sparql_cancel -from graph_notebook.request_param_generator.factory import create_request_generator -from graph_notebook.request_param_generator.sparql_request_generator import SPARQLRequestGenerator - -from test.integration import DataDrivenSparqlTest - -logger = logging.getLogger('TestSparqlStatusWithoutIam') - - -class TestSparqlStatusWithoutIam(DataDrivenSparqlTest): - def do_sparql_query_save_result(self, query, res): - try: - res['result'] = do_sparql_query(query, self.host, self.port, self.ssl, SPARQLRequestGenerator()) - except requests.HTTPError as exception: - res['error'] = exception.response.json() - - def setUp(self) -> None: - request_generator = create_request_generator(AuthModeEnum.DEFAULT) - res = do_sparql_status(self.host, self.port, self.ssl, request_generator) - for q in res['queries']: - do_sparql_cancel(self.host, self.port, self.ssl, request_generator, q['queryId'], False) - - def test_do_sparql_status_nonexistent(self): - query_id = "ac7d5a03-00cf-4280-b464-edbcbf51ffce" - request_generator = create_request_generator(AuthModeEnum.DEFAULT) - res = do_sparql_status(self.host, self.port, self.ssl, request_generator, query_id) - self.assertEqual(type(res), dict) - self.assertTrue('acceptedQueryCount' in res) - self.assertTrue('runningQueryCount' in res) - self.assertTrue('queries' in res) - - def test_do_sparql_cancel_nonexistent(self): - query_id = "ac7d5a03-00cf-4280-b464-edbcbf51ffce" - request_generator = create_request_generator(AuthModeEnum.DEFAULT) - res = do_sparql_cancel(self.host, self.port, self.ssl, request_generator, query_id, False) - self.assertEqual(type(res), dict) - self.assertTrue('acceptedQueryCount' in res) - self.assertTrue('runningQueryCount' in res) - self.assertTrue('queries' in res) - - def test_do_sparql_cancel_empty_query_id(self): - with self.assertRaises(ValueError): - query_id = '' - request_generator = create_request_generator(AuthModeEnum.DEFAULT) - do_sparql_cancel(query_id, False, self.host, self.port, self.ssl, request_generator) - - def test_do_sparql_cancel_non_str_query_id(self): - with self.assertRaises(ValueError): - query_id = 42 - request_generator = create_request_generator(AuthModeEnum.DEFAULT) - do_sparql_cancel(query_id, False, self.host, self.port, self.ssl, request_generator) - - def test_do_sparql_status_and_cancel(self): - query = "SELECT * WHERE { ?s ?p ?o . 
?s2 ?p2 ?o2 .?s3 ?p3 ?o3 .} ORDER BY DESC(?s) LIMIT 100" - query_res = {} - sparql_query_thread = threading.Thread(target=self.do_sparql_query_save_result, args=(query, query_res,)) - sparql_query_thread.start() - time.sleep(3) - - query_id = '' - request_generator = create_request_generator(AuthModeEnum.DEFAULT) - status_res = do_sparql_status(self.host, self.port, self.ssl, request_generator, query_id) - self.assertEqual(type(status_res), dict) - self.assertTrue('acceptedQueryCount' in status_res) - self.assertTrue('runningQueryCount' in status_res) - self.assertEqual(1, status_res['runningQueryCount']) - self.assertTrue('queries' in status_res) - - query_id = '' - for q in status_res['queries']: - if query in q['queryString']: - query_id = q['queryId'] - - self.assertNotEqual(query_id, '') - - cancel_res = do_sparql_cancel(self.host, self.port, self.ssl, request_generator, query_id, False) - self.assertEqual(type(cancel_res), dict) - self.assertTrue('acceptedQueryCount' in cancel_res) - self.assertTrue('runningQueryCount' in cancel_res) - self.assertTrue('queries' in cancel_res) - - sparql_query_thread.join() - self.assertFalse('result' in query_res) - self.assertTrue('error' in query_res) - self.assertTrue('code' in query_res['error']) - self.assertTrue('requestId' in query_res['error']) - self.assertTrue('detailedMessage' in query_res['error']) - self.assertEqual('CancelledByUserException', query_res['error']['code']) - - def test_do_sparql_status_and_cancel_silently(self): - query = "SELECT * WHERE { ?s ?p ?o . ?s2 ?p2 ?o2 .?s3 ?p3 ?o3 .} ORDER BY DESC(?s) LIMIT 100" - query_res = {} - sparql_query_thread = threading.Thread(target=self.do_sparql_query_save_result, args=(query, query_res,)) - sparql_query_thread.start() - time.sleep(3) - - query_id = '' - request_generator = create_request_generator(AuthModeEnum.DEFAULT) - status_res = do_sparql_status(self.host, self.port, self.ssl, request_generator, query_id) - self.assertEqual(type(status_res), dict) - self.assertTrue('acceptedQueryCount' in status_res) - self.assertTrue('runningQueryCount' in status_res) - self.assertEqual(1, status_res['runningQueryCount']) - self.assertTrue('queries' in status_res) - - query_id = '' - for q in status_res['queries']: - if query in q['queryString']: - query_id = q['queryId'] - - self.assertNotEqual(query_id, '') - - cancel_res = do_sparql_cancel(self.host, self.port, self.ssl, request_generator, query_id, True) - self.assertEqual(type(cancel_res), dict) - self.assertTrue('acceptedQueryCount' in cancel_res) - self.assertTrue('runningQueryCount' in cancel_res) - self.assertTrue('queries' in cancel_res) - - sparql_query_thread.join() - self.assertEqual(type(query_res['result']), dict) - self.assertTrue('s3' in query_res['result']['head']['vars']) - self.assertTrue('p3' in query_res['result']['head']['vars']) - self.assertTrue('o3' in query_res['result']['head']['vars']) - self.assertEqual([], query_res['result']['results']['bindings']) diff --git a/test/integration/status/__init__.py b/test/integration/status/__init__.py deleted file mode 100644 index 9049dd04..00000000 --- a/test/integration/status/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
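Note: both deleted status files also pin down input validation: cancelling with an empty or non-string queryId must raise ValueError before any request goes out. (In passing, the without-iam variant above passes its arguments to do_sparql_cancel in a different order than the with-iam one; routing everything through the unified client removes that foot-gun.) The check the client presumably performs, reconstructed from those expectations — the message text is a guess:

def validate_query_id(query_id):
    # reconstructed from the ValueError expectations in the tests above
    if not isinstance(query_id, str) or query_id == '':
        raise ValueError('query_id must be a non-empty string')
    return query_id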
-SPDX-License-Identifier: Apache-2.0 -""" diff --git a/test/integration/status/status_with_iam.py b/test/integration/status/status_with_iam.py deleted file mode 100644 index 5878cede..00000000 --- a/test/integration/status/status_with_iam.py +++ /dev/null @@ -1,24 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -from requests.exceptions import HTTPError - -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider -from graph_notebook.configuration.generate_config import AuthModeEnum -from graph_notebook.request_param_generator.factory import create_request_generator -from graph_notebook.status.get_status import get_status - -from test.integration import IntegrationTest - - -class TestStatusWithIAM(IntegrationTest): - def test_do_status_with_iam_credentials(self): - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - status = get_status(self.host, self.port, self.ssl, request_generator) - self.assertEqual(status['status'], 'healthy') - - def test_do_status_without_iam_credentials(self): - with self.assertRaises(HTTPError): - get_status(self.host, self.port, self.ssl) diff --git a/test/integration/system/system_with_iam.py b/test/integration/system/system_with_iam.py deleted file mode 100644 index a3298630..00000000 --- a/test/integration/system/system_with_iam.py +++ /dev/null @@ -1,25 +0,0 @@ -from requests.exceptions import HTTPError - -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider -from graph_notebook.configuration.generate_config import AuthModeEnum -from graph_notebook.request_param_generator.factory import create_request_generator -from graph_notebook.system.database_reset import initiate_database_reset, perform_database_reset -from test.integration import IntegrationTest - - -class TestStatusWithIAM(IntegrationTest): - def test_do_db_reset_initiate_with_iam_credentials(self): - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - result = initiate_database_reset(self.host, self.port, self.ssl, request_generator) - self.assertNotEqual(result['payload']['token'], '') - - def test_do_db_reset_perform_with_wrong_token_with_iam_credentials(self): - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - with self.assertRaises(HTTPError) as cm: - perform_database_reset('x', self.host, self.port, self.ssl, request_generator) - expected_message = "System command parameter 'token' : 'x' does not match database reset token" - self.assertEqual(expected_message, str(cm.exception.response.json()['detailedMessage'])) - - def test_do_db_reset_initiate_without_iam_credentials(self): - with self.assertRaises(HTTPError): - initiate_database_reset(self.host, self.port, self.ssl) diff --git a/test/integration/system/system_without_iam.py b/test/integration/system/system_without_iam.py deleted file mode 100644 index 566f0a12..00000000 --- a/test/integration/system/system_without_iam.py +++ /dev/null @@ -1,15 +0,0 @@ -from requests.exceptions import HTTPError -from graph_notebook.system.database_reset import initiate_database_reset, perform_database_reset -from test.integration import IntegrationTest - - -class TestStatusWithoutIAM(IntegrationTest): - def test_do_database_reset_initiate(self): - result = initiate_database_reset(self.host, self.port, self.ssl) - 
self.assertNotEqual(result['payload']['token'], '') - - def test_do_database_reset_perform_with_wrong_token(self): - with self.assertRaises(HTTPError) as cm: - perform_database_reset('x', self.host, self.port, self.ssl) - expected_message = "System command parameter 'token' : 'x' does not match database reset token" - self.assertEqual(expected_message, str(cm.exception.response.json()['detailedMessage'])) diff --git a/test/integration/without_iam/__init__.py b/test/integration/without_iam/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/integration/without_iam/gremlin/__init__.py b/test/integration/without_iam/gremlin/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/integration/gremlin/bug_fix_tests_without_iam.py b/test/integration/without_iam/gremlin/test_gremlin_patching.py similarity index 71% rename from test/integration/gremlin/bug_fix_tests_without_iam.py rename to test/integration/without_iam/gremlin/test_gremlin_patching.py index 5a7c63d4..37557faa 100644 --- a/test/integration/gremlin/bug_fix_tests_without_iam.py +++ b/test/integration/without_iam/gremlin/test_gremlin_patching.py @@ -5,13 +5,11 @@ import logging -from graph_notebook.gremlin.client_provider.default_client import ClientProvider -from graph_notebook.gremlin.query import do_gremlin_query +import pytest from test.integration import IntegrationTest - -logger = logging.getLogger('TestUnhashableTypeDict') +logger = logging.getLogger('test_bug_fixes') class TestBugFixes(IntegrationTest): @@ -21,12 +19,8 @@ class TestBugFixes(IntegrationTest): is not. We ran into this in a Data Lab a while back but we worked around it there by monkey patching the Gremlin Python client. We may want to do the same for the version of Gremlin Python used by the workbench.""" - @classmethod - def setUpClass(cls): - super(TestBugFixes, cls).setUpClass() - - cls.client_provider = ClientProvider() - + def setUp(self) -> None: + self.client = self.client_builder.build() queries = [ "g.addV('Interest').property(id,'i1').property('value', 4)", "g.addV('Priority').property(id, 'p1').property('name', 'P1')", @@ -34,25 +28,19 @@ def setUpClass(cls): "g.V('m1').addE('interested').to(g.V('i1'))", "g.V('m1').addE('prioritized').to(g.V('p1'))" ] - cls.runQueries(queries) + for q in queries: + self.client.gremlin_query(q) - @classmethod - def tearDownClass(cls): + def tearDown(self) -> None: queries = [ "g.V('i1').drop()", "g.V('p1').drop()", "g.V('m1').drop()" ] - cls.runQueries(queries) - - @classmethod - def runQueries(cls, queries): - for query in queries: - try: - do_gremlin_query(query, cls.host, cls.port, cls.ssl, cls.client_provider) - except Exception as e: - logger.error(f'query {query} failed due to {e}') + for q in queries: + self.client.gremlin_query(q) + @pytest.mark.gremlin def test_do_gremlin_query_with_map_as_key(self): query = """ g.V().hasLabel("Interest").as("int") @@ -63,7 +51,7 @@ def test_do_gremlin_query_with_map_as_key(self): .by("name") .groupCount().unfold() """ - results = do_gremlin_query(query, self.host, self.port, self.ssl, self.client_provider) + results = self.client.gremlin_query(query) keys_are_hashable = True for key in results[0].keys(): try: @@ -73,13 +61,14 @@ def test_do_gremlin_query_with_map_as_key(self): break self.assertEqual(keys_are_hashable, True) + @pytest.mark.gremlin def test_do_gremlin_query_with_list_as_key(self): query = """ g.V('m1').group() .by(out().fold()) .by(out().count()) """ - results = do_gremlin_query(query, self.host, self.port, self.ssl, 
self.client_provider) + results = self.client.gremlin_query(query) keys_are_hashable = True for key in results[0].keys(): try: diff --git a/test/integration/without_iam/gremlin/test_gremlin_query.py b/test/integration/without_iam/gremlin/test_gremlin_query.py new file mode 100644 index 00000000..79ee575b --- /dev/null +++ b/test/integration/without_iam/gremlin/test_gremlin_query.py @@ -0,0 +1,37 @@ +""" +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +SPDX-License-Identifier: Apache-2.0 +""" +import pytest +from gremlin_python.structure.graph import Vertex + +from test.integration import IntegrationTest + + +class TestGremlin(IntegrationTest): + @pytest.mark.gremlin + def test_do_gremlin_query(self): + query = 'g.V().limit(1)' + results = self.client.gremlin_query(query) + assert type(results) is list + for r in results: + assert type(r) is Vertex + + self.assertEqual(type(results), list) + + @pytest.mark.gremlin + def test_do_gremlin_explain(self): + query = 'g.V().limit(1)' + res = self.client.gremlin_explain(query) + assert res.status_code == 200 + results = res.content.decode('utf-8') + self.assertTrue('Explain' in results) + + @pytest.mark.gremlin + def test_do_gremlin_profile(self): + query = 'g.V().limit(1)' + res = self.client.gremlin_profile(query) + assert res.status_code == 200 + + results = res.content.decode('utf-8') + self.assertTrue('Profile' in results) diff --git a/test/integration/without_iam/gremlin/test_gremlin_status_without_iam.py b/test/integration/without_iam/gremlin/test_gremlin_status_without_iam.py new file mode 100644 index 00000000..37337ee0 --- /dev/null +++ b/test/integration/without_iam/gremlin/test_gremlin_status_without_iam.py @@ -0,0 +1,113 @@ +""" +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
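Note: the file that follows is the first in this patch to import the new Client type directly. Taken together, the call sites in this patch exercise the surface sketched below; this is a typing.Protocol reconstruction from those call sites, not the actual definition in graph_notebook.neptune.client, and the signatures and defaults are assumptions:

from typing import Any, List, Protocol  # Protocol: Python 3.8+, used here for illustration only

import requests


class NeptuneClientSurface(Protocol):
    def status(self) -> requests.Response: ...
    def sparql(self, query: str) -> requests.Response: ...
    def sparql_status(self, query_id: str = '') -> requests.Response: ...
    def sparql_cancel(self, query_id: str, silent: bool = False) -> requests.Response: ...
    def gremlin_query(self, query: str) -> List[Any]: ...
    def gremlin_explain(self, query: str) -> requests.Response: ...
    def gremlin_profile(self, query: str) -> requests.Response: ...
    def gremlin_status(self, query_id: str = '', include_waiting: bool = False) -> requests.Response: ...
    def gremlin_cancel(self, query_id: str) -> requests.Response: ...
    def initiate_reset(self) -> requests.Response: ...
    def perform_reset(self, token: str) -> requests.Response: ...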
+SPDX-License-Identifier: Apache-2.0 +""" + +import threading +import time + +import pytest +import concurrent.futures +from os import cpu_count + +from gremlin_python.driver.protocol import GremlinServerError + +from graph_notebook.neptune.client import Client +from test.integration import DataDrivenGremlinTest + + +def long_running_gremlin_query(c: Client, query: str): + res = c.gremlin_query(query) + return res + + +class TestGremlinStatusWithoutIam(DataDrivenGremlinTest): + @pytest.mark.neptune + def test_do_gremlin_status_nonexistent(self): + query_id = "some-guid-here" + res = self.client.gremlin_status(query_id) + assert res.status_code == 400 + js = res.json() + assert js['code'] == 'InvalidParameterException' + assert js['detailedMessage'] == f'Supplied queryId {query_id} is invalid' + + @pytest.mark.neptune + def test_do_gremlin_cancel_nonexistent(self): + query_id = "some-guid-here" + res = self.client.gremlin_cancel(query_id) + assert res.status_code == 400 + js = res.json() + assert js['code'] == 'InvalidParameterException' + assert js['detailedMessage'] == f'Supplied queryId {query_id} is invalid' + + @pytest.mark.neptune + def test_do_gremlin_cancel_empty_query_id(self): + with self.assertRaises(ValueError): + self.client.gremlin_cancel('') + + @pytest.mark.neptune + def test_do_gremlin_cancel_non_str_query_id(self): + with self.assertRaises(ValueError): + self.client.gremlin_cancel(42) + + @pytest.mark.neptune + def test_do_gremlin_status_and_cancel(self): + long_running_query = "g.V().out().out().out().out().out().out().out().out()" + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(long_running_gremlin_query, self.client, long_running_query) + + time.sleep(1) + status_res = self.client.gremlin_status() + assert status_res.status_code == 200 + + status_js = status_res.json() + query_id = '' + for q in status_js['queries']: + if q['queryString'] == long_running_query: + query_id = q['queryId'] + + assert query_id != '' + + cancel_res = self.client.gremlin_cancel(query_id) + assert cancel_res.status_code == 200 + assert cancel_res.json()['status'] == '200 OK' + + time.sleep(1) + status_after_cancel = self.client.gremlin_status(query_id) + assert status_after_cancel.status_code == 400 # check that the query is no longer valid + assert status_after_cancel.json()['code'] == 'InvalidParameterException' + + with pytest.raises(GremlinServerError): + # this result corresponds to the cancel query, so our gremlin client will raise an exception + future.result() + + @pytest.mark.iam + @pytest.mark.neptune + def test_do_gremlin_status_include_waiting(self): + query = "g.V().out().out().out().out()" + num_threads = cpu_count() * 4 + threads = [] + for x in range(0, num_threads): + thread = threading.Thread(target=long_running_gremlin_query, args=(self.client, query)) + thread.start() + threads.append(thread) + + time.sleep(5) + + res = self.client.gremlin_status(include_waiting=True) + assert res.status_code == 200 + status_res = res.json() + + self.assertEqual(type(status_res), dict) + self.assertTrue('acceptedQueryCount' in status_res) + self.assertTrue('runningQueryCount' in status_res) + self.assertTrue('queries' in status_res) + self.assertEqual(status_res['acceptedQueryCount'], len(status_res['queries'])) + + for q in status_res['queries']: + # cancel all the queries we executed since they can take a very long time. 
+ if q['queryString'] == query: + self.client.gremlin_cancel(q['queryId']) + + for t in threads: + t.join() diff --git a/src/graph_notebook/loader/__init__.py b/test/integration/without_iam/network/__init__.py similarity index 100% rename from src/graph_notebook/loader/__init__.py rename to test/integration/without_iam/network/__init__.py diff --git a/src/graph_notebook/request_param_generator/__init__.py b/test/integration/without_iam/network/gremlin/__init__.py similarity index 100% rename from src/graph_notebook/request_param_generator/__init__.py rename to test/integration/without_iam/network/gremlin/__init__.py diff --git a/test/integration/network/gremlin/gremlin_network_from_queries.py b/test/integration/without_iam/network/gremlin/test_gremlin_network_from_queries.py similarity index 81% rename from test/integration/network/gremlin/gremlin_network_from_queries.py rename to test/integration/without_iam/network/gremlin/test_gremlin_network_from_queries.py index e96f08f9..51b81748 100644 --- a/test/integration/network/gremlin/gremlin_network_from_queries.py +++ b/test/integration/without_iam/network/gremlin/test_gremlin_network_from_queries.py @@ -2,18 +2,19 @@ Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. SPDX-License-Identifier: Apache-2.0 """ +import pytest -from graph_notebook.gremlin.query import do_gremlin_query from graph_notebook.network.gremlin.GremlinNetwork import GremlinNetwork from test.integration import DataDrivenGremlinTest class TestGremlinNetwork(DataDrivenGremlinTest): + + @pytest.mark.gremlin def test_add_paths_to_network(self): airports_path_query = "g.V().has('code', 'SEA').outE().inV().path()" - results = do_gremlin_query(airports_path_query, self.host, self.port, self.ssl, self.client_provider) - + results = self.client.gremlin_query(airports_path_query) gremlin_network = GremlinNetwork() gremlin_network.add_results(results) sea_code = '22' @@ -23,10 +24,10 @@ def test_add_paths_to_network(self): actual_label = gremlin_network.graph[sea_code][aus_code][edge_id]['label'] self.assertEqual(expected_label, actual_label) + @pytest.mark.gremlin def test_add_value_map_to_network(self): airports_path_query = "g.V().has('code', 'SEA').outE().inV().path().by(valueMap(true))" - results = do_gremlin_query(airports_path_query, self.host, self.port, self.ssl, self.client_provider) - + results = self.client.gremlin_query(airports_path_query) gremlin_network = GremlinNetwork() gremlin_network.add_results(results) edge_id = '4406' @@ -34,9 +35,10 @@ def test_add_value_map_to_network(self): actual_label = gremlin_network.graph.nodes.get(edge_id)['label'] self.assertEqual(expected_label, actual_label) + @pytest.mark.gremlin def test_add_entire_path(self): sea_to_bmi = "g.V().has('code', 'SEA').outE().inV().has('code', 'ORD').outE().inV().has('code', 'BMI').path()" - results = do_gremlin_query(sea_to_bmi, self.host, self.port, self.ssl, self.client_provider) + results = self.client.gremlin_query(sea_to_bmi) gremlin_network = GremlinNetwork() gremlin_network.add_results(results) @@ -48,9 +50,10 @@ def test_add_entire_path(self): self.assertTrue(gremlin_network.graph.has_edge('22', '18', '4420')) self.assertTrue(gremlin_network.graph.has_edge('18', '359', '7126')) + @pytest.mark.gremlin def test_add_paths_with_bad_pattern(self): query = "g.V().out().out().path().limit(10)" - results = do_gremlin_query(query, self.host, self.port, self.ssl, self.client_provider) + results = self.client.gremlin_query(query) gremlin_network = GremlinNetwork() 
gremlin_network.add_results(results) @@ -61,17 +64,19 @@ def test_add_paths_with_bad_pattern(self): self.assertEqual('', edge['label']) self.assertFalse(edge['arrows']['to']['enabled']) + @pytest.mark.gremlin def test_add_path_with_repeat(self): query = "g.V().has('airport', 'code', 'ANC').repeat(outE().inV().simplePath()).times(2).path().by('code').by()" - results = do_gremlin_query(query, self.host, self.port, self.ssl, self.client_provider) + results = self.client.gremlin_query(query) gremlin_network = GremlinNetwork() gremlin_network.add_results(results) self.assertEqual('route', gremlin_network.graph.edges[('ANC', 'BLI', '5276')]['label']) + @pytest.mark.gremlin def test_valuemap_without_ids(self): query = "g.V().has('code', 'ANC').out().path().by(valueMap()).limit(10)" - results = do_gremlin_query(query, self.host, self.port, self.ssl, self.client_provider) + results = self.client.gremlin_query(query) gremlin_network = GremlinNetwork() gremlin_network.add_results(results) @@ -79,25 +84,28 @@ def test_valuemap_without_ids(self): node = gremlin_network.graph.nodes.get(n) self.assertEqual(gremlin_network.label_max_length, len(node['label'])) + @pytest.mark.gremlin def test_path_without_by_nodes_have_ids(self): query = "g.V().has('code', 'AUS').outE().inV().outE().inV().has('code', 'SEA').path()" - results = do_gremlin_query(query, self.host, self.port, self.ssl, self.client_provider) + results = self.client.gremlin_query(query) gremlin_network = GremlinNetwork() gremlin_network.add_results(results) node = gremlin_network.graph.nodes.get('9') self.assertIsNotNone(node) + @pytest.mark.gremlin def test_path_without_by_oute_has_arrows(self): query = "g.V().hasLabel('airport').has('code', 'SEA').outE().inV().path()" - results = do_gremlin_query(query, self.host, self.port, self.ssl, self.client_provider) + results = self.client.gremlin_query(query) gremlin_network = GremlinNetwork() gremlin_network.add_results(results) edge = gremlin_network.graph.edges[('22', '151', '7389')] self.assertTrue('arrows' not in edge) + @pytest.mark.gremlin def test_path_without_by_ine_has_arrows(self): query = "g.V().hasLabel('airport').has('code', 'SEA').inE().outV().path()" - results = do_gremlin_query(query, self.host, self.port, self.ssl, self.client_provider) + results = self.client.gremlin_query(query) gremlin_network = GremlinNetwork() gremlin_network.add_results(results) edge = gremlin_network.graph.edges[('3670', '22', '53637')] diff --git a/test/integration/network/gremlin/gremlin_network_with_pattern.py b/test/integration/without_iam/network/gremlin/test_gremlin_network_with_pattern.py similarity index 83% rename from test/integration/network/gremlin/gremlin_network_with_pattern.py rename to test/integration/without_iam/network/gremlin/test_gremlin_network_with_pattern.py index 43f3fce0..1a154f88 100644 --- a/test/integration/network/gremlin/gremlin_network_with_pattern.py +++ b/test/integration/without_iam/network/gremlin/test_gremlin_network_with_pattern.py @@ -3,7 +3,6 @@ SPDX-License-Identifier: Apache-2.0 """ -from graph_notebook.gremlin.query import do_gremlin_query from graph_notebook.network.gremlin.GremlinNetwork import GremlinNetwork, PathPattern from test.integration import DataDrivenGremlinTest @@ -12,7 +11,7 @@ class TestGremlinNetwork(DataDrivenGremlinTest): def test_add_path_with_edge_object(self): query = "g.V().has('airport','code','AUS').outE().inV().path().by('code').by().limit(10)" - results = do_gremlin_query(query, self.host, self.port, self.ssl, self.client_provider) + results = 
self.client.gremlin_query(query) gn = GremlinNetwork() pattern = [PathPattern.V, PathPattern.OUT_E, PathPattern.IN_V] gn.add_results_with_pattern(results, pattern) @@ -27,7 +26,7 @@ def test_add_path_by_dist(self): path(). by('code'). by('dist')""" - results = do_gremlin_query(query, self.host, self.port, self.ssl, self.client_provider) + results = self.client.gremlin_query(query) gn = GremlinNetwork() pattern = [PathPattern.V, PathPattern.OUT_E, PathPattern.IN_V, PathPattern.OUT_E] gn.add_results_with_pattern(results, pattern) @@ -41,7 +40,7 @@ def test_path_with_dict(self): by(valueMap('code','city','region','desc','lat','lon'). order(local). by(keys))""" - results = do_gremlin_query(query, self.host, self.port, self.ssl, self.client_provider) + results = self.client.gremlin_query(query) gn = GremlinNetwork() pattern = [PathPattern.V, PathPattern.IN_V] gn.add_results_with_pattern(results, pattern) @@ -55,7 +54,7 @@ def test_out_v_unhashable_dict(self): out(). path(). by(valueMap())""" - results = do_gremlin_query(query, self.host, self.port, self.ssl, self.client_provider) + results = self.client.gremlin_query(query) gn = GremlinNetwork() pattern = [PathPattern.V, PathPattern.OUT_V] gn.add_results_with_pattern(results, pattern) diff --git a/src/graph_notebook/sparql/__init__.py b/test/integration/without_iam/notebook/__init__.py similarity index 100% rename from src/graph_notebook/sparql/__init__.py rename to test/integration/without_iam/notebook/__init__.py diff --git a/test/integration/notebook/test_gremlin_graph_notebook.py b/test/integration/without_iam/notebook/test_gremlin_graph_notebook.py similarity index 86% rename from test/integration/notebook/test_gremlin_graph_notebook.py rename to test/integration/without_iam/notebook/test_gremlin_graph_notebook.py index c8ed1f92..c456c1b3 100644 --- a/test/integration/notebook/test_gremlin_graph_notebook.py +++ b/test/integration/without_iam/notebook/test_gremlin_graph_notebook.py @@ -2,8 +2,9 @@ Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. SPDX-License-Identifier: Apache-2.0 """ +import pytest -from test.integration.notebook.GraphNotebookIntegrationTest import GraphNotebookIntegrationTest +from test.integration import GraphNotebookIntegrationTest class TestGraphMagicGremlin(GraphNotebookIntegrationTest): @@ -11,6 +12,8 @@ def tearDown(self) -> None: delete_query = "g.V('graph-notebook-test').drop()" self.ip.run_cell_magic('gremlin', 'query', delete_query) + @pytest.mark.jupyter + @pytest.mark.gremlin def test_gremlin_query(self): label = 'graph-notebook-test' query = f"g.addV('{label}')" diff --git a/test/integration/notebook/test_sparql_graph_notebook.py b/test/integration/without_iam/notebook/test_sparql_graph_notebook.py similarity index 64% rename from test/integration/notebook/test_sparql_graph_notebook.py rename to test/integration/without_iam/notebook/test_sparql_graph_notebook.py index a71c1935..69abd498 100644 --- a/test/integration/notebook/test_sparql_graph_notebook.py +++ b/test/integration/without_iam/notebook/test_sparql_graph_notebook.py @@ -2,12 +2,26 @@ Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
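Note: the network tests above drive GremlinNetwork two ways: add_results infers vertex/edge structure from raw path results, while add_results_with_pattern is told the shape of each path element explicitly. Judging by the three-part has_edge calls, the graph attribute behaves like a networkx multigraph. A condensed usage sketch (queries taken from the tests; the client fixture is assumed):

from graph_notebook.network.gremlin.GremlinNetwork import GremlinNetwork, PathPattern


def network_from_paths(client):
    results = client.gremlin_query("g.V().has('code', 'SEA').outE().inV().path()")
    gn = GremlinNetwork()
    gn.add_results(results)  # structure inferred from the path objects
    return gn.graph


def network_with_pattern(client):
    results = client.gremlin_query(
        "g.V().has('airport','code','AUS').outE().inV().path().by('code').by().limit(10)")
    gn = GremlinNetwork()
    gn.add_results_with_pattern(results, [PathPattern.V, PathPattern.OUT_E, PathPattern.IN_V])
    return gn.graph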
SPDX-License-Identifier: Apache-2.0 """ +import pytest -from test.integration.notebook.GraphNotebookIntegrationTest import GraphNotebookIntegrationTest +from test.integration import GraphNotebookIntegrationTest class TestGraphMagicGremlin(GraphNotebookIntegrationTest): + + @pytest.mark.jupyter + @pytest.mark.sparql def test_sparql_query(self): + query = 'SELECT * WHERE {?s ?o ?p } LIMIT 1' + store_to_var = 'sparql_res' + self.ip.run_cell_magic('sparql', f'--store-to {store_to_var}', query) + self.assertFalse('graph_notebook_error' in self.ip.user_ns) + sparql_res = self.ip.user_ns[store_to_var] + self.assertEqual(['s', 'o', 'p'], sparql_res['head']['vars']) + + @pytest.mark.jupyter + @pytest.mark.sparql + def test_sparql_query_explain(self): query = 'SELECT * WHERE {?s ?o ?p } LIMIT 1' store_to_var = 'sparql_res' self.ip.run_cell_magic('sparql', f'explain --store-to {store_to_var}', query) @@ -16,6 +30,7 @@ def test_sparql_query(self): self.assertTrue(sparql_res.startswith('')) self.assertTrue('' in sparql_res) + @pytest.mark.jupyter def test_load_sparql_config(self): config = '''{ "host": "localhost", diff --git a/test/integration/notebook/test_status_graph_notebook.py b/test/integration/without_iam/notebook/test_status_graph_notebook.py similarity index 71% rename from test/integration/notebook/test_status_graph_notebook.py rename to test/integration/without_iam/notebook/test_status_graph_notebook.py index 1b280f64..23387820 100644 --- a/test/integration/notebook/test_status_graph_notebook.py +++ b/test/integration/without_iam/notebook/test_status_graph_notebook.py @@ -2,11 +2,14 @@ Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. SPDX-License-Identifier: Apache-2.0 """ +import pytest -from test.integration.notebook.GraphNotebookIntegrationTest import GraphNotebookIntegrationTest +from test.integration import GraphNotebookIntegrationTest class TestGraphMagicStatus(GraphNotebookIntegrationTest): + @pytest.mark.jupyter + @pytest.mark.neptune def test_status(self): res = self.ip.run_line_magic('status', '') self.assertEqual('healthy', res['status']) diff --git a/src/graph_notebook/status/__init__.py b/test/integration/without_iam/sparql/__init__.py similarity index 100% rename from src/graph_notebook/status/__init__.py rename to test/integration/without_iam/sparql/__init__.py diff --git a/test/integration/sparql/sparql_query_without_iam.py b/test/integration/without_iam/sparql/test_sparql_query_without_iam.py similarity index 59% rename from test/integration/sparql/sparql_query_without_iam.py rename to test/integration/without_iam/sparql/test_sparql_query_without_iam.py index c1bccc27..58cf3045 100644 --- a/test/integration/sparql/sparql_query_without_iam.py +++ b/test/integration/without_iam/sparql/test_sparql_query_without_iam.py @@ -2,27 +2,30 @@ Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
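Note: the notebook tests above exercise the magics through a live IPython shell (self.ip on GraphNotebookIntegrationTest): run the cell magic with --store-to, pull the stored result out of user_ns, and confirm no graph_notebook_error was recorded. (The sparql file still carries the copied class name TestGraphMagicGremlin, which this patch leaves as-is.) The pattern, condensed into an illustrative helper:

def run_sparql_magic(ip, query, var='sparql_res'):
    # ip: an IPython InteractiveShell, as provided by the test fixture
    ip.run_cell_magic('sparql', f'--store-to {var}', query)
    assert 'graph_notebook_error' not in ip.user_ns  # the magic records failures here
    return ip.user_ns[var]  # whatever --store-to stashed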
SPDX-License-Identifier: Apache-2.0 """ - -from graph_notebook.request_param_generator.sparql_request_generator import SPARQLRequestGenerator -from graph_notebook.sparql.query import do_sparql_query, do_sparql_explain +import pytest from test.integration import IntegrationTest class TestSparqlQuery(IntegrationTest): + @pytest.mark.sparql def test_do_sparql_query(self): query = "SELECT * WHERE {?s ?p ?o} LIMIT 1" - request_generator = SPARQLRequestGenerator() - res = do_sparql_query(query, self.host, self.port, self.ssl, request_generator) + sparql_res = self.client.sparql(query) + assert sparql_res.status_code == 200 + res = sparql_res.json() + self.assertEqual(type(res), dict) self.assertTrue('s' in res['head']['vars']) self.assertTrue('p' in res['head']['vars']) self.assertTrue('o' in res['head']['vars']) + @pytest.mark.sparql def test_do_sparql_explain(self): query = "SELECT * WHERE {?s ?p ?o} LIMIT 1" - request_generator = SPARQLRequestGenerator() - res = do_sparql_explain(query, self.host, self.port, self.ssl, request_generator) + query_res = self.client.sparql_explain(query) + assert query_res.status_code == 200 + res = query_res.content.decode('utf-8') self.assertEqual(type(res), str) self.assertTrue(res.startswith('')) diff --git a/test/integration/without_iam/sparql/test_sparql_status_without_iam.py b/test/integration/without_iam/sparql/test_sparql_status_without_iam.py new file mode 100644 index 00000000..9ee39c99 --- /dev/null +++ b/test/integration/without_iam/sparql/test_sparql_status_without_iam.py @@ -0,0 +1,118 @@ +""" +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +SPDX-License-Identifier: Apache-2.0 +""" + +import concurrent.futures + +import logging +import time +import pytest + +from graph_notebook.neptune.client import Client +from test.integration import DataDrivenSparqlTest + +logger = logging.getLogger('TestSparqlStatusWithoutIam') + + +def long_running_sparql_query(c: Client, query: str): + res = c.sparql(query) + return res + + +class TestSparqlStatusWithoutIam(DataDrivenSparqlTest): + @pytest.mark.neptune + def test_do_sparql_status_nonexistent(self): + query_id = "invalid-guid" + status_res = self.client.sparql_status(query_id) + assert status_res.status_code == 200 + assert status_res.content == b'' + + @pytest.mark.neptune + def test_do_sparql_cancel_nonexistent(self): + query_id = "invalid-guid" + cancel_res = self.client.sparql_cancel(query_id) + assert cancel_res.status_code == 200 + assert cancel_res.content == b'' + + @pytest.mark.neptune + def test_do_sparql_cancel_empty_query_id(self): + with pytest.raises(ValueError): + self.client.sparql_cancel('') + + @pytest.mark.neptune + def test_do_sparql_cancel_non_str_query_id(self): + with pytest.raises(ValueError): + self.client.sparql_cancel(42) + + @pytest.mark.neptune + def test_do_sparql_status_and_cancel(self): + query = "SELECT * WHERE { ?s ?p ?o . ?s2 ?p2 ?o2 .?s3 ?p3 ?o3 . 
?s4 ?s5 ?s6 .} ORDER BY DESC(?s)" + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(long_running_sparql_query, self.client, query) + time.sleep(1) + + status = self.client.sparql_status() + status_res = status.json() + assert 'acceptedQueryCount' in status_res + assert 'runningQueryCount' in status_res + assert 'queries' in status_res + + time.sleep(1) + + query_id = '' + for q in status_res['queries']: + if query in q['queryString']: + query_id = q['queryId'] + + self.assertNotEqual(query_id, '') + + cancel = self.client.sparql_cancel(query_id, False) + cancel_res = cancel.json() + + assert 'acceptedQueryCount' in cancel_res + assert 'runningQueryCount' in cancel_res + assert 'queries' in cancel_res + + res = future.result() + assert res.status_code == 500 + raw = res.json() + assert raw['code'] == 'CancelledByUserException' + assert raw['detailedMessage'] == 'Operation terminated (cancelled by user)' + + @pytest.mark.neptune + def test_do_sparql_status_and_cancel_silently(self): + query = "SELECT * WHERE { ?s ?p ?o . ?s2 ?p2 ?o2 .?s3 ?p3 ?o3 . ?s4 ?s5 ?s6 .} ORDER BY DESC(?s)" + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(long_running_sparql_query, self.client, query) + time.sleep(1) + + status = self.client.sparql_status() + status_res = status.json() + assert 'acceptedQueryCount' in status_res + assert 'runningQueryCount' in status_res + assert 'queries' in status_res + + time.sleep(1) + + query_id = '' + for q in status_res['queries']: + if query in q['queryString']: + query_id = q['queryId'] + + assert query_id != '' + + cancel = self.client.sparql_cancel(query_id, True) + cancel_res = cancel.json() + assert 'acceptedQueryCount' in cancel_res + assert 'runningQueryCount' in cancel_res + assert 'queries' in cancel_res + + res = future.result() + query_res = res.json() + assert type(query_res) is dict + assert 's3' in query_res['head']['vars'] + assert 'p3' in query_res['head']['vars'] + assert 'o3' in query_res['head']['vars'] + assert [] == query_res['results']['bindings'] diff --git a/test/integration/gremlin/__init__.py b/test/integration/without_iam/status/__init__.py similarity index 100% rename from test/integration/gremlin/__init__.py rename to test/integration/without_iam/status/__init__.py diff --git a/test/integration/status/status_without_iam.py b/test/integration/without_iam/status/test_status_without_iam.py similarity index 71% rename from test/integration/status/status_without_iam.py rename to test/integration/without_iam/status/test_status_without_iam.py index d72ce1d3..2b20e6de 100644 --- a/test/integration/status/status_without_iam.py +++ b/test/integration/without_iam/status/test_status_without_iam.py @@ -2,13 +2,15 @@ Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
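Note: these SPARQL status tests, next to the Gremlin ones earlier in the patch, pin down an asymmetry between the two endpoints: for an unknown query id, SPARQL status/cancel answers 200 with an empty body, while Gremlin answers 400 with InvalidParameterException. A probe capturing both expectations (guid reused from the deleted tests; helper name is illustrative):

def probe_unknown_query_id(client, query_id='ac7d5a03-00cf-4280-b464-edbcbf51ffce'):
    sparql_res = client.sparql_status(query_id)
    assert sparql_res.status_code == 200 and sparql_res.content == b''
    gremlin_res = client.gremlin_status(query_id)
    assert gremlin_res.status_code == 400
    assert gremlin_res.json()['code'] == 'InvalidParameterException'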
SPDX-License-Identifier: Apache-2.0 """ - -from graph_notebook.status.get_status import get_status +import pytest from test.integration import IntegrationTest class TestStatusWithoutIAM(IntegrationTest): + + @pytest.mark.neptune def test_do_status(self): - status = get_status(self.host, self.port, self.ssl) + res = self.client.status() + status = res.json() self.assertEqual(status['status'], 'healthy') diff --git a/test/integration/without_iam/system/__init__.py b/test/integration/without_iam/system/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/integration/without_iam/system/test_system_without_iam.py b/test/integration/without_iam/system/test_system_without_iam.py new file mode 100644 index 00000000..c171e96a --- /dev/null +++ b/test/integration/without_iam/system/test_system_without_iam.py @@ -0,0 +1,18 @@ +import pytest +from test.integration import IntegrationTest + + +class TestStatusWithoutIAM(IntegrationTest): + + @pytest.mark.neptune + def test_do_database_reset_initiate(self): + res = self.client.initiate_reset() + result = res.json() + self.assertNotEqual(result['payload']['token'], '') + + @pytest.mark.neptune + def test_do_database_reset_perform_with_wrong_token(self): + res = self.client.perform_reset('invalid') + assert res.status_code == 400 + expected_message = "System command parameter 'token' : 'invalid' does not match database reset token" + assert expected_message == res.json()['detailedMessage'] diff --git a/test/unit/configuration/test_configuration.py b/test/unit/configuration/test_configuration.py index a845ce4a..0bbb4d50 100644 --- a/test/unit/configuration/test_configuration.py +++ b/test/unit/configuration/test_configuration.py @@ -7,9 +7,7 @@ import unittest from graph_notebook.configuration.get_config import get_config -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider -from graph_notebook.configuration.generate_config import Configuration, DEFAULT_AUTH_MODE, \ - DEFAULT_IAM_CREDENTIALS_PROVIDER, AuthModeEnum, generate_config +from graph_notebook.configuration.generate_config import Configuration, DEFAULT_AUTH_MODE, AuthModeEnum, generate_config class TestGenerateConfiguration(unittest.TestCase): @@ -28,25 +26,21 @@ def test_configuration_default_auth_defaults(self): self.assertEqual(self.host, config.host) self.assertEqual(self.port, config.port) self.assertEqual(DEFAULT_AUTH_MODE, config.auth_mode) - self.assertEqual(DEFAULT_IAM_CREDENTIALS_PROVIDER, config.iam_credentials_provider_type) self.assertEqual(True, config.ssl) self.assertEqual('', config.load_from_s3_arn) def test_configuration_override_defaults(self): auth_mode = AuthModeEnum.IAM - credentials_provider = IAMAuthCredentialsProvider.ENV ssl = False loader_arn = 'foo' - config = Configuration(self.host, self.port, auth_mode, credentials_provider, loader_arn, ssl) + config = Configuration(self.host, self.port, auth_mode, loader_arn, ssl) self.assertEqual(auth_mode, config.auth_mode) - self.assertEqual(credentials_provider, config.iam_credentials_provider_type) self.assertEqual(ssl, config.ssl) self.assertEqual(loader_arn, config.load_from_s3_arn) def test_generate_configuration_with_defaults(self): config = Configuration(self.host, self.port) c = generate_config(config.host, config.port, config.auth_mode, config.ssl, - config.iam_credentials_provider_type, config.load_from_s3_arn, config.aws_region) c.write_to_file(self.test_file_path) config_from_file = get_config(self.test_file_path) @@ -54,14 +48,12 @@ 
def test_generate_configuration_with_defaults(self): def test_generate_configuration_override_defaults(self): auth_mode = AuthModeEnum.IAM - credentials_provider = IAMAuthCredentialsProvider.ENV ssl = False loader_arn = 'foo' aws_region = 'us-west-2' - config = Configuration(self.host, self.port, auth_mode, credentials_provider, loader_arn, ssl, aws_region) + config = Configuration(self.host, self.port, auth_mode, loader_arn, ssl, aws_region) c = generate_config(config.host, config.port, config.auth_mode, config.ssl, - config.iam_credentials_provider_type, config.load_from_s3_arn, config.aws_region) c.write_to_file(self.test_file_path) config_from_file = get_config(self.test_file_path) diff --git a/test/unit/configuration/test_configuration_from_main.py b/test/unit/configuration/test_configuration_from_main.py index 5b2d6452..ebfd604c 100644 --- a/test/unit/configuration/test_configuration_from_main.py +++ b/test/unit/configuration/test_configuration_from_main.py @@ -8,7 +8,6 @@ from graph_notebook.configuration.generate_config import AuthModeEnum, Configuration from graph_notebook.configuration.get_config import get_config -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider class TestGenerateConfigurationMain(unittest.TestCase): @@ -24,11 +23,11 @@ def tearDown(self) -> None: os.remove(self.test_file_path) def test_generate_configuration_main_defaults(self): - expected_config = Configuration(self.host, self.port, AuthModeEnum.DEFAULT, IAMAuthCredentialsProvider.ROLE, '', True) + expected_config = Configuration(self.host, self.port, AuthModeEnum.DEFAULT, '', True) self.generate_config_from_main_and_test(expected_config) def test_generate_configuration_main_override_defaults(self): - expected_config = Configuration(self.host, self.port, AuthModeEnum.IAM, IAMAuthCredentialsProvider.ROLE, 'loader_arn', False) + expected_config = Configuration(self.host, self.port, AuthModeEnum.IAM, 'loader_arn', False) self.generate_config_from_main_and_test(expected_config) def test_generate_configuration_main_empty_args(self): @@ -42,7 +41,7 @@ def generate_config_from_main_and_test(self, source_config: Configuration): # This will run the main method that our install script runs on a Sagemaker notebook. # The return code should be 0, but more importantly, we need to assert that the # Configuration object we get from the resulting file is what we expect. - result = os.system(f'{self.python_cmd} -m graph_notebook.configuration.generate_config --host "{source_config.host}" --port "{source_config.port}" --auth_mode "{source_config.auth_mode.value}" --ssl "{source_config.ssl}" --iam_credentials_provider "{source_config.iam_credentials_provider_type.value}" --load_from_s3_arn "{source_config.load_from_s3_arn}" --config_destination="{self.test_file_path}" ') + result = os.system(f'{self.python_cmd} -m graph_notebook.configuration.generate_config --host "{source_config.host}" --port "{source_config.port}" --auth_mode "{source_config.auth_mode.value}" --ssl "{source_config.ssl}" --load_from_s3_arn "{source_config.load_from_s3_arn}" --config_destination="{self.test_file_path}" ') self.assertEqual(result, 0) config = get_config(self.test_file_path) self.assertEqual(source_config.to_dict(), config.to_dict()) diff --git a/test/unit/gremlin/__init__.py b/test/unit/gremlin/__init__.py deleted file mode 100644 index 9049dd04..00000000 --- a/test/unit/gremlin/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. 
All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" diff --git a/test/unit/request_param_generator/__init__.py b/test/unit/request_param_generator/__init__.py deleted file mode 100644 index 9049dd04..00000000 --- a/test/unit/request_param_generator/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" diff --git a/test/unit/request_param_generator/test_default_request_generator.py b/test/unit/request_param_generator/test_default_request_generator.py deleted file mode 100644 index e948d2a6..00000000 --- a/test/unit/request_param_generator/test_default_request_generator.py +++ /dev/null @@ -1,32 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -import unittest - -from graph_notebook.request_param_generator.default_request_generator import DefaultRequestGenerator - - -class TestDefaultRequestGenerator(unittest.TestCase): - def test_generate_request_params(self): - method = 'post' - action = 'foo' - query = { - 'bar': 'baz' - } - host = 'host_endpoint' - port = 8182 - protocol = 'https' - headers = { - 'header1': 'header_value_1' - } - - rpg = DefaultRequestGenerator() - request_params = rpg.generate_request_params(method, action, query, host, port, protocol, headers) - - expected_url = f'{protocol}://{host}:{port}/{action}' - self.assertEqual(request_params['method'], method) - self.assertEqual(request_params['url'], expected_url) - self.assertEqual(request_params['headers'], headers) - self.assertEqual(request_params['params'], query) diff --git a/test/unit/request_param_generator/test_factory_generator.py b/test/unit/request_param_generator/test_factory_generator.py deleted file mode 100644 index 6878cb6d..00000000 --- a/test/unit/request_param_generator/test_factory_generator.py +++ /dev/null @@ -1,33 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
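Note: with the request-param-generator layer gone, the behavior its unit tests pinned down moves into the unified client. For reference, what the deleted DefaultRequestGenerator produced, reconstructed from the unit test above (the real implementation may have differed in details):

def build_request_params(method, action, query, host, port, protocol, headers=None):
    # reconstruction from the deleted test_default_request_generator.py
    return {
        'method': method,
        'url': f'{protocol}://{host}:{port}/{action}',
        'headers': headers if headers is not None else {},
        'params': query,
    }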
-SPDX-License-Identifier: Apache-2.0 -""" - -import unittest - -from graph_notebook.configuration.generate_config import AuthModeEnum -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider -from graph_notebook.authentication.iam_credentials_provider.env_credentials_provider import EnvCredentialsProvider -from graph_notebook.request_param_generator.default_request_generator import DefaultRequestGenerator -from graph_notebook.request_param_generator.factory import create_request_generator -from graph_notebook.request_param_generator.iam_request_generator import IamRequestGenerator -from graph_notebook.request_param_generator.sparql_request_generator import SPARQLRequestGenerator - - -class TestRequestParamGeneratorFactory(unittest.TestCase): - def test_create_request_generator_sparql(self): - mode = AuthModeEnum.DEFAULT - command = 'sparql' - rpg = create_request_generator(mode, command=command) - self.assertEqual(SPARQLRequestGenerator, type(rpg)) - - def test_create_request_generator_default(self): - mode = AuthModeEnum.DEFAULT - rpg = create_request_generator(mode) - self.assertEqual(DefaultRequestGenerator, type(rpg)) - - def test_create_request_generator_iam_env(self): - mode = AuthModeEnum.IAM - rpg = create_request_generator(mode, IAMAuthCredentialsProvider.ENV) - self.assertEqual(IamRequestGenerator, type(rpg)) - self.assertEqual(EnvCredentialsProvider, type(rpg.credentials_provider)) diff --git a/test/unit/request_param_generator/test_sparql_request_generator.py b/test/unit/request_param_generator/test_sparql_request_generator.py deleted file mode 100644 index 3584a7c5..00000000 --- a/test/unit/request_param_generator/test_sparql_request_generator.py +++ /dev/null @@ -1,58 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
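Note: the SPARQL variant deleted below differed from the default generator in exactly one respect its tests checked: it injected a form-encoding Content-Type header whether or not other headers were supplied. Reconstructed:

def sparql_request_headers(headers=None):
    # per the deleted SPARQLRequestGenerator tests: SPARQL requests are sent
    # form-encoded, so this header is always present
    merged = dict(headers) if headers else {}
    merged['Content-Type'] = 'application/x-www-form-urlencoded'
    return merged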
-SPDX-License-Identifier: Apache-2.0 -""" - -import unittest - -from graph_notebook.request_param_generator.sparql_request_generator import SPARQLRequestGenerator - - -class TestSparqlRequestGenerator(unittest.TestCase): - def test_generate_request_params(self): - method = 'post' - action = 'foo' # action is a no-op since we know it is sparql - query = { - 'bar': 'baz' - } - host = 'host_endpoint' - port = 8182 - protocol = 'https' - headers = { - 'header1': 'header_value_1' - } - - rpg = SPARQLRequestGenerator() - request_params = rpg.generate_request_params(method, action, query, host, port, protocol, headers) - expected_headers = { - 'header1': 'header_value_1', - 'Content-Type': 'application/x-www-form-urlencoded' - } - - expected_url = f'{protocol}://{host}:{port}/{action}' - self.assertEqual(request_params['method'], method) - self.assertEqual(request_params['url'], expected_url) - self.assertEqual(request_params['headers'], expected_headers) - self.assertEqual(request_params['params'], query) - - def test_generate_request_params_no_headers(self): - method = 'post' - action = 'foo' # action is a no-op since we know it is sparql - query = { - 'bar': 'baz' - } - host = 'host_endpoint' - port = 8182 - protocol = 'https' - - rpg = SPARQLRequestGenerator() - request_params = rpg.generate_request_params(method, action, query, host, port, protocol, headers=None) - expected_headers = { - 'Content-Type': 'application/x-www-form-urlencoded' - } - - expected_url = f'{protocol}://{host}:{port}/{action}' - self.assertEqual(request_params['method'], method) - self.assertEqual(request_params['url'], expected_url) - self.assertEqual(request_params['headers'], expected_headers) - self.assertEqual(request_params['params'], query) diff --git a/test/unit/sparql/test_sparql.py b/test/unit/sparql/test_sparql.py index 03d2a4b2..a4bb97cc 100644 --- a/test/unit/sparql/test_sparql.py +++ b/test/unit/sparql/test_sparql.py @@ -5,7 +5,7 @@ import unittest -from graph_notebook.sparql.query import get_query_type, query_type_to_action +from graph_notebook.magics.graph_magic import get_query_type, query_type_to_action class TestSparql(unittest.TestCase):
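Note: the final hunk re-points the SPARQL unit tests at the new home of get_query_type and query_type_to_action in graph_notebook.magics.graph_magic; the functions themselves are assumed unchanged by the move. A minimal smoke test of the relocated imports:

from graph_notebook.magics.graph_magic import get_query_type, query_type_to_action


def test_relocated_sparql_helpers():
    query_type = get_query_type('SELECT * WHERE { ?s ?p ?o } LIMIT 1')
    action = query_type_to_action(query_type)
    assert isinstance(action, str)  # exact value not pinned here; it is endpoint-specific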