diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..decbfee --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,25 @@ +{ + "terminal.integrated.env.osx": { + "PYTHONPATH": "${workspaceFolder}" + }, + "terminal.integrated.env.windows": { + "PYTHONPATH": "${workspaceFolder}" + }, + "files.insertFinalNewline": true, + "[javascript]": { + "editor.formatOnSave": true, + "editor.codeActionsOnSave": { + "source.fixAll.eslint": "always" + }, + }, + "[python]": { + "editor.codeActionsOnSave": { + "source.fixAll": "explicit", + "source.organizeImports": "explicit" + }, + "editor.formatOnSave": true, + "editor.defaultFormatter": "charliermarsh.ruff", + }, + "python.languageServer": "Pylance", + "python.analysis.typeCheckingMode": "strict", +} diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..5a6752a --- /dev/null +++ b/mypy.ini @@ -0,0 +1,14 @@ +[mypy] +check_untyped_defs = True +disallow_untyped_calls = True +disallow_untyped_decorators = True +disallow_untyped_defs = True +ignore_missing_imports = True +show_column_numbers = True +strict = True +strict_bytes = True +strict_equality = True +warn_redundant_casts = True +warn_return_any = True +warn_unreachable = True +warn_unused_ignores = True diff --git a/pyproject.toml b/pyproject.toml index 07c6274..bb8deb2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,35 +1,52 @@ [build-system] -#https://python-poetry.org/docs/pyproject -requires = ["poetry-core>=1.0.0"] +requires = ["poetry-core>=2.0.0,<3.0.0"] build-backend = "poetry.core.masonry.api" -[tool.poetry] +[project] name = "pysafeguard" description = "One Identity Safeguard Python Package" version = "7.4.0" -readme = "README.md" -keywords = ["safeguard","oneidentity"] +readme = { file = "README.md", content-type = "text/markdown" } +keywords = ["safeguard", "oneidentity"] +license = "Apache" repository = "https://github.com/OneIdentity/PySafeguard" authors = [ - "Tania Engel ", + { name = "Tania Engel", email = "Tania.Engel@oneidentity.com" } ] maintainers = [ - "Stephanie Zinn " + { name = "Stephanie Zinn", email = "Stephanie.Zinn@oneidentity.com" } ] +requires-python = ">=3.10" classifiers = [ + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", "Programming Language :: Python :: 3", - "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", + "Programming Language :: Python :: 3 :: Only", "Operating System :: OS Independent", ] +dependencies = [ + "requests >= 2.32", + "truststore >= 0.10", + "typing_extensions >= 4.15; python_version < '3.11'", +] -[tool.poetry.urls] -"Bug Tracker" = "https://github.com/OneIdentity/PySafeguard/issues" - -[tool.poetry.dependencies] -python = "^3.7" -requests = "^2.28.1" -#signalrcore is optional because it is only imported for SignalR functionality -signalrcore = { version="^0.9.5", optional = true } +[project.optional-dependencies] +async = [ + "aiohttp >= 3.13", +] +signalr = [ + "signalrcore >= 0.9", +] +dev = [ + "types-requests >= 2.32", + "mypy >= 1.19", + "ruff >= 0.14", +] -[tool.poetry.extras] -signalr = ["signalrcore"] \ No newline at end of file +[project.urls] +"Bug Tracker" = "https://github.com/OneIdentity/PySafeguard/issues" diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 0000000..7a3f6a3 --- /dev/null +++ b/ruff.toml @@ -0,0 +1 @@ +line-length = 160 diff --git a/src/pysafeguard.py b/src/pysafeguard.py deleted file mode 100644 index ef2216b..0000000 --- a/src/pysafeguard.py +++ /dev/null @@ -1,260 +0,0 @@ -import requests -import os -import json -from requests.structures import CaseInsensitiveDict -from urllib.parse import urlunparse,urlencode -from enum import Enum - -class Services: - CORE = 'service/core' - APPLIANCE = 'service/appliance' - NOTIFICATION = 'service/notification' - A2A = 'service/a2a' - EVENT = 'service/event' - RSTS = 'RSTS' - -class HttpMethods: - GET = requests.get - POST = requests.post - PUT = requests.put - DELETE = requests.delete - -class A2ATypes: - PASSWORD = "password" - PRIVATEKEY = "privatekey" - APIKEYSECRET = "apikey" - -class SshKeyFormats: - OPENSSH = "openssh" - SSH2 = "ssh2" - PUTTY = "putty" - -class WebRequestError(Exception): - def __init__(self, req): - self.req = req - self.message = '{} {}: {} {}\n{}'.format(req.status_code,req.reason,req.request.method,req.url,req.text) - super().__init__(self.message) - -def _assemble_path(*args): - return '/'.join(map(lambda x: str(x).strip('/'), filter(None, args))) - -def _assemble_url(netloc='',path='',query={},fragment='',scheme='https'): - return urlunparse((scheme,netloc,path,'',urlencode(query,True),fragment)) - -def _create_merging_thing(cls): - def _inner_merge(*args,**kwargs): - return cls(sum(map(lambda x: list(x.items()), args+(kwargs,)),[])) - return _inner_merge - -_merge_dict = _create_merging_thing(dict) -_merge_idict = _create_merging_thing(CaseInsensitiveDict) - -class PySafeguardConnection: - - def __init__(self, host, verify=True, apiVersion='v4'): - """Initialize a Safeguard connection object - - Arguments: - host -- the appliance hostname - verify -- A path to a file with CA certificate information or - False to disable verification - apiVersion -- The version of the API with which to connect - """ - self.host = host - self.UserToken = None - self.apiVersion = apiVersion - self.req_globals = dict(verify=verify,cert=None) - self.headers = CaseInsensitiveDict({'Accept':'application/json'}) - - @staticmethod - def __execute_web_request(httpMethod, url, body, headers, verify, cert): - bodystyle = dict(data=body) - if body and httpMethod in [HttpMethods.POST, HttpMethods.PUT] and not headers.get('content-type'): - bodystyle = dict(json=body) - headers = _merge_idict(headers, {'Content-type':'application/json'}) - with httpMethod(url, headers=headers, cert=cert, verify=verify, **bodystyle) as req: - return req - - @staticmethod - def a2a_get_credential(host, apiKey, cert, key, verify=True, a2aType=A2ATypes.PASSWORD, keyFormat=SshKeyFormats.OPENSSH, apiVersion='v4'): - '''(Public) Retrieves an application to application credential. - - Keyword arguments: - host -- Name or ip of the safeguard appliance. - apiKey -- A2A api key. - cert -- Path to the user certificate in pem format. - key -- Path to the user certificate's key in key format. - verify -- A path to a file with CA certificate information or False to disable verification - a2aType -- Type of credential to retrieve (password, privatekey). Defaults to password. - keyFormat -- The privateKeyFormat to return (openssh, ssh2, putty). Defaults to openshh. - apiVersion -- API version to use. Defaults to v4. - ''' - if not apiKey: - raise Exception("apiKey may not be null or empty") - - if not cert and not key: - raise Exception("cert path and key path may not be null or empty") - - header = { - 'Authorization': f'A2A {apiKey}' - } - query = _merge_dict(dict(type=a2aType), dict(keyFormat=keyFormat) if a2aType == A2ATypes.PRIVATEKEY else {}) - credential = PySafeguardConnection.__execute_web_request(HttpMethods.GET, _assemble_url(host, _assemble_path(Services.A2A, apiVersion, "Credentials"), query), body={}, headers=header, verify=verify, cert=(cert, key)) - if credential.status_code != 200: - raise WebRequestError(credential) - return credential.json() - - def get_provider_id(self, name): - """Get an authentication provider by name to use when authenticating - - Arguments: - name -- The name of a configured provider - - Returns: - A string value which is the ID of a configured provider - """ - req = self.invoke(HttpMethods.GET, Services.CORE, 'AuthenticationProviders') - providers = req.json() - matches = list(filter(lambda x: name.upper() == x['Name'].upper(), providers)) - if matches: - return matches[0]['RstsProviderId'] - else: - raise Exception('Unable to find Provider with Name {} in\n{}'.format(name,json.dumps(providers,indent=2,sort_keys=True))) - - def __connect(self, body, *args, **kwargs): - req = self.invoke(HttpMethods.POST, Services.RSTS, 'oauth2/token', body=body, *args, **kwargs) - if req.status_code == 200 and 'application/json' in req.headers.get('Content-type',''): - data = req.json() - req = self.invoke(HttpMethods.POST, Services.CORE, 'Token/LoginResponse', body=dict(StsAccessToken=data.get('access_token'))) - if req.status_code == 200 and 'application/json' in req.headers.get('Content-type',''): - data = req.json() - self.connect_token(data.get('UserToken')) - else: - raise WebRequestError(req) - else: - raise WebRequestError(req) - - def connect_password(self, username, password, provider='local'): - """Obtain a token using username and password - used when connecting - - Arguments: - username -- The username of an authorized user - password -- The password for the user - provider -- An authentication provider ID associated with user - """ - body = { - 'scope': 'rsts:sts:primaryproviderid:{}'.format(provider), - 'grant_type': 'password', - 'username': username, - 'password': password - } - self.__connect(body) - - def connect_certificate(self, certFile, keyFile, provider='certificate'): - """Obtain a token using certificate and key file - used when connecting - - Arguments: - certFile -- path to the client certificate - keyFile -- path to the key for the certificate - provider -- An authentication provider ID associated with certificate - """ - body = { - 'scope': 'rsts:sts:primaryproviderid:{}'.format(provider), - 'grant_type': 'client_credentials' - } - self.__connect(body,cert=(certFile,keyFile)) - - def connect_token(self, token): - """Use an existing token"" - - Arguments: - token -- The user token - """ - self.UserToken = token - self.headers.update(Authorization='Bearer {}'.format(self.UserToken)) - - def invoke(self, httpMethod, httpService, endpoint=None, query={}, body=None, additionalHeaders={}, host=None, cert=None, apiVersion=None): - """Invoke a web request against the Safeguard API - - Arguments: - httpMethod -- One of the predefined HttpMethods - httpService -- One of the predefined Services - endpoint -- The path of an API endpoint to use (e.g. 'Users', 'Assets') - query -- A dictionary of query parameters that are added to endpoint - body -- The data that is sent in the request. Usually a dictionary. - headers -- Headers that are added to the request - host -- The host to which the request is made (useful for clusters) - cert -- A 2-tuple of the certificate and key - apiVersion -- Which version of the API to use in this request - - Returns: - Request Response object. - """ - url = _assemble_url(host or self.host, _assemble_path(httpService, (apiVersion or self.apiVersion) if httpService != Services.RSTS else '', endpoint), query) - merged_headers = _merge_idict(self.headers, additionalHeaders) - return PySafeguardConnection.__execute_web_request(httpMethod, url, body, merged_headers, **_merge_dict(self.req_globals, cert=cert)) - - def get_remaining_token_lifetime(self): - """Get the remaining time left on the access token - - Returns: - integer value in minutes - """ - req = self.invoke(HttpMethods.GET, Services.APPLIANCE, 'SystemTime') - return req.headers.get('X-tokenlifetimeremaining') - - @staticmethod - def __register_signalr(host, callback, options, verify): - """Register a SignalR callback and start listening.""" - from signalrcore.hub_connection_builder import HubConnectionBuilder - if not callback: - raise Exception("A callback must be specified to register for the SignalR events.") - options.update({'verify_ssl':verify}) - server_url = _assemble_url(host, _assemble_path(Services.EVENT, 'signalr')) - hub_connection = HubConnectionBuilder() \ - .with_url(server_url, options=options) \ - .with_automatic_reconnect({ - "type": "raw", - "keep_alive_interval": 10, - "reconnect_interval": 10, - "max_attempts": 5 - }).build() - - hub_connection.on("ReceiveMessage", callback) - hub_connection.on("NotifyEventAsync", callback) - hub_connection.on_open(lambda: print("connection opened and handshake received ready to send messages")) - hub_connection.on_close(lambda: print("connection closed")) - hub_connection.start() - - @staticmethod - def register_signalr_username(conn, callback, username, password): - """Wrapper to register a SignalR callback using username/password authentication. - - Arguments: - conn -- PySafeguardConnection instance object - callback -- Callback function to handle messages that come back - username -- Username for authentication - password -- Password for authentication - """ - def _token_factory_username(): - conn.connect_password(username, password) - return conn.UserToken - options = {"access_token_factory": _token_factory_username} - PySafeguardConnection.__register_signalr(conn.host, callback, options, bool(conn.req_globals.get('verify',True))) - - @staticmethod - def register_signalr_certificate(conn, callback, certfile, keyfile): - """Wrapper to register a SignalR callback using certificate authentication. - - Arguments: - conn -- PySafeguardConnection instance object - callback -- Callback function to handle messages that come back - certfile -- Path to the user certificate in pem format. - keyfile -- Path to the user certificate's key in key format. - """ - def _token_factory_certificate(): - conn.connect_certificate(certfile, keyfile, provider="certificate") - return conn.UserToken - options = options={"access_token_factory": _token_factory_certificate} - PySafeguardConnection.__register_signalr(conn.host, callback, options, bool(conn.req_globals.get('verify',True))) - diff --git a/src/pysafeguard/__init__.py b/src/pysafeguard/__init__.py new file mode 100644 index 0000000..d225b00 --- /dev/null +++ b/src/pysafeguard/__init__.py @@ -0,0 +1,70 @@ +# mypy: ignore-errors +# type: ignore + +from .connection import Connection +from .connection import WebRequestError as WebRequestError +from .data_types import A2ATypes as A2ATypes +from .data_types import HttpMethods as HttpMethods +from .data_types import Services +from .data_types import SshKeyFormats as SshKeyFormats +from .utility import assemble_path, assemble_url + + +class PySafeguardConnection(Connection): + @staticmethod + def __register_signalr(host, callback, options, verify): + """Register a SignalR callback and start listening.""" + from signalrcore.hub_connection_builder import HubConnectionBuilder + + if not callback: + raise Exception("A callback must be specified to register for the SignalR events.") + options.update({"verify_ssl": verify}) + server_url = assemble_url(host, assemble_path(Services.EVENT, "signalr")) + hub_connection = ( + HubConnectionBuilder() + .with_url(server_url, options=options) + .with_automatic_reconnect({"type": "raw", "keep_alive_interval": 10, "reconnect_interval": 10, "max_attempts": 5}) + .build() + ) + + hub_connection.on("ReceiveMessage", callback) + hub_connection.on("NotifyEventAsync", callback) + hub_connection.on_open(lambda: print("connection opened and handshake received ready to send messages")) + hub_connection.on_close(lambda: print("connection closed")) + hub_connection.start() + + @staticmethod + def register_signalr_username(conn, callback, username, password): + """Wrapper to register a SignalR callback using username/password authentication. + + Arguments: + conn -- PySafeguardConnection instance object + callback -- Callback function to handle messages that come back + username -- Username for authentication + password -- Password for authentication + """ + + def _token_factory_username(): + conn.connect_password(username, password) + return conn.UserToken + + options = {"access_token_factory": _token_factory_username} + PySafeguardConnection.__register_signalr(conn.host, callback, options, bool(conn.req_globals.get("verify", True))) + + @staticmethod + def register_signalr_certificate(conn, callback, certfile, keyfile): + """Wrapper to register a SignalR callback using certificate authentication. + + Arguments: + conn -- PySafeguardConnection instance object + callback -- Callback function to handle messages that come back + certfile -- Path to the user certificate in pem format. + keyfile -- Path to the user certificate's key in key format. + """ + + def _token_factory_certificate(): + conn.connect_certificate(certfile, keyfile, provider="certificate") + return conn.UserToken + + options = options = {"access_token_factory": _token_factory_certificate} + PySafeguardConnection.__register_signalr(conn.host, callback, options, bool(conn.req_globals.get("verify", True))) diff --git a/src/pysafeguard/async_connection.py b/src/pysafeguard/async_connection.py new file mode 100644 index 0000000..12a6af4 --- /dev/null +++ b/src/pysafeguard/async_connection.py @@ -0,0 +1,253 @@ +import json +import ssl +import typing +from collections.abc import Mapping + +from aiohttp import ClientResponse, ClientSession +from multidict import CIMultiDict +from truststore import SSLContext + +from .data_types import A2ATypes, HttpMethods, JsonType, Services, SshKeyFormats +from .utility import LiteralString, assemble_path, assemble_url + + +class AsyncWebRequestError(Exception): + def __init__(self, resp: ClientResponse) -> None: + self.req = resp + self.message = f"{resp.status} {resp.reason}: {resp.method} {resp.url}" + super().__init__(self.message) + + +class AsyncConnection: + host: str | None + UserToken: str | None + apiVersion: str + verify: bool | str + headers: CIMultiDict[str] + + def __init__(self, host: str | None, verify: bool | str = True, apiVersion: LiteralString = "v4") -> None: + """ + Initialize a Safeguard connection object + + :param host: The appliance hostname. + :param verify: A path to a file with CA certificate information or `False` to disable verification. + :param apiVersion: The version of the API with which to connect. + """ + + self.host = host + self.UserToken = None + self.apiVersion = apiVersion + self.verify = verify + self.headers = CIMultiDict({"accept": "application/json"}) + + @staticmethod + async def __execute_web_request( + httpMethod: HttpMethods, url: str, body: JsonType | str | None, headers: Mapping[str, str], verify: str | bool, cert: tuple[str, str] | None + ) -> ClientResponse: + data_body: str | None + json_body: JsonType | None + updated_headers = CIMultiDict(headers) + if body and httpMethod in [HttpMethods.POST, HttpMethods.PUT] and not headers.get("content-type"): + data_body = None + if not isinstance(body, dict): + raise TypeError("expected: body as a JSON object") + json_body = body + updated_headers["content-type"] = "application/json" + else: + if body is not None and not isinstance(body, str): + raise TypeError("expected: body as a string") + data_body = body + json_body = None + + ctx = SSLContext(ssl.PROTOCOL_TLS_CLIENT) + if cert is not None: + certfile, keyfile = cert + ctx.load_cert_chain(certfile, keyfile) + elif isinstance(verify, str): + ctx.load_cert_chain(verify) + + async with ClientSession() as session: + async with session.request(httpMethod, url, headers=updated_headers, ssl=ctx, data=data_body, json=json_body) as resp: + await resp.read() + return resp + + @classmethod + async def a2a_get_credential( + cls, + host: str, + apiKey: str, + cert: str, + key: str, + verify: str | bool = True, + a2aType: A2ATypes = A2ATypes.PASSWORD, + keyFormat: SshKeyFormats = SshKeyFormats.OPENSSH, + apiVersion: LiteralString = "v4", + ) -> JsonType: + """ + (Public) Retrieves an application to application credential. + + :param host: Name or ip of the safeguard appliance. + :param apiKey: A2A api key. + :param cert: Path to the user certificate in pem format. + :param key: Path to the user certificate's key in key format. + :param verify: A path to a file with CA certificate information or False to disable verification + :param a2aType: Type of credential to retrieve (password, privatekey). Defaults to password. + :param keyFormat: The privateKeyFormat to return (openssh, ssh2, putty). Defaults to openshh. + :param apiVersion: API version to use. Defaults to v4. + """ + + if not apiKey: + raise Exception("apiKey may not be null or empty") + + if not cert and not key: + raise Exception("cert path and key path may not be null or empty") + + headers = CIMultiDict({"authorization": f"A2A {apiKey}"}) + query: dict[str, str] = {"type": a2aType} + if a2aType == A2ATypes.PRIVATEKEY: + query["keyFormat"] = keyFormat + + resp = await cls.__execute_web_request( + HttpMethods.GET, + assemble_url(host, assemble_path(Services.A2A, apiVersion, "Credentials"), query), + body=None, + headers=headers, + verify=verify, + cert=(cert, key), + ) + if resp.status != 200: + raise AsyncWebRequestError(resp) + return typing.cast(JsonType, await resp.json()) + + async def get_provider_id(self, name: str) -> str: + """ + Get an authentication provider by name to use when authenticating. + + :param name: The name of a configured provider. + :returns: A string value which is the ID of a configured provider. + """ + + resp = await self.invoke(HttpMethods.GET, Services.CORE, "AuthenticationProviders") + providers = await resp.json() + matches = [provider for provider in providers if name.upper() == typing.cast(str, provider["Name"]).upper()] + if not matches: + raise Exception("Unable to find Provider with Name {} in\n{}".format(name, json.dumps(providers, indent=2, sort_keys=True))) + + return typing.cast(str, matches[0]["RstsProviderId"]) + + async def __connect(self, body: JsonType, cert: tuple[str, str] | None = None) -> None: + data: JsonType + resp = await self.invoke(HttpMethods.POST, Services.RSTS, "oauth2/token", body=body, cert=cert) + if resp.status == 200 and "application/json" in resp.headers.get("content-type", ""): + data = await resp.json() + if not isinstance(data, dict): + raise TypeError("expected: JSON object with field `access_token`") + access_token = data.get("access_token") + + resp = await self.invoke(HttpMethods.POST, Services.CORE, "Token/LoginResponse", body=dict(StsAccessToken=access_token)) + if resp.status == 200 and "application/json" in resp.headers.get("content-type", ""): + data = await resp.json() + if not isinstance(data, dict): + raise TypeError("expected: JSON object with field `UserToken`") + user_token = typing.cast(str, data.get("UserToken")) + self.connect_token(user_token) + else: + raise AsyncWebRequestError(resp) + else: + raise AsyncWebRequestError(resp) + + async def connect_password(self, username: str, password: str, provider: str = "local") -> None: + """ + Obtain a token using username and password - used when connecting. + + :param username: The username of an authorized user. + :param password: The password for the user. + :param provider: An authentication provider ID associated with user. + """ + + body: JsonType = { + "scope": f"rsts:sts:primaryproviderid:{provider}", + "grant_type": "password", + "username": username, + "password": password, + } + await self.__connect(body) + + async def connect_certificate(self, certFile: str, keyFile: str, provider: str = "certificate") -> None: + """ + Obtain a token using certificate and key file - used when connecting. + + :param certFile: Path to the client certificate. + :param keyFile: Path to the key for the certificate. + :param provider: An authentication provider ID associated with certificate. + """ + + body: JsonType = { + "scope": f"rsts:sts:primaryproviderid:{provider}", + "grant_type": "client_credentials", + } + await self.__connect(body, cert=(certFile, keyFile)) + + def connect_token(self, token: str | None) -> None: + """ + Use an existing token. + + :param token: The user token. + """ + + self.UserToken = token + self.headers.update(authorization="Bearer {}".format(self.UserToken)) + + async def invoke( + self, + httpMethod: HttpMethods, + httpService: Services, + endpoint: str | None = None, + query: Mapping[str, str] = {}, + body: JsonType | None = None, + additionalHeaders: Mapping[str, str] = {}, + host: str | None = None, + cert: tuple[str, str] | None = None, + apiVersion: str | None = None, + ) -> ClientResponse: + """ + Invoke a web request against the Safeguard API. + + :param httpMethod: One of the predefined `HttpMethods`. + :param httpService: One of the predefined `Services`. + :param endpoint: The path of an API endpoint to use (e.g. 'Users', 'Assets'). + :param query: A dictionary of query parameters that are added to endpoint. + :param body: The data that is sent in the request. Usually a dictionary. + :param headers: Headers that are added to the request. + :param host: The host to which the request is made (useful for clusters). + :param cert: A 2-tuple of the certificate and key. + :param apiVersion: Which version of the API to use in this request. + :returns: Request `Response` object. + """ + + url = assemble_url( + host or self.host or "", + assemble_path( + httpService, + (apiVersion or self.apiVersion) if httpService != Services.RSTS else "", + endpoint, + ), + query, + ) + headers = CIMultiDict(self.headers) + headers.update(additionalHeaders) + return await self.__execute_web_request(httpMethod, url, body, headers, verify=self.verify, cert=cert) + + async def get_remaining_token_lifetime(self) -> int | None: + """ + Get the remaining time left on the access token. + + :returns: An integer value in minutes. + """ + + resp = await self.invoke(HttpMethods.GET, Services.APPLIANCE, "SystemTime") + remaining = resp.headers.get("x-tokenlifetimeremaining") + if remaining is not None: + return int(remaining, base=10) + else: + return None diff --git a/src/pysafeguard/connection.py b/src/pysafeguard/connection.py new file mode 100644 index 0000000..76976bf --- /dev/null +++ b/src/pysafeguard/connection.py @@ -0,0 +1,242 @@ +import json +import typing +from collections.abc import Mapping + +from requests import Response, request +from requests.structures import CaseInsensitiveDict + +from .data_types import A2ATypes, HttpMethods, JsonType, Services, SshKeyFormats +from .utility import LiteralString, assemble_path, assemble_url + + +class WebRequestError(Exception): + def __init__(self, resp: Response) -> None: + self.req = resp + self.message = f"{resp.status_code} {resp.reason}: {resp.request.method} {resp.url}\n{resp.text}" + super().__init__(self.message) + + +class Connection: + host: str | None + UserToken: str | None + apiVersion: str + verify: bool | str + headers: CaseInsensitiveDict[str] + + def __init__(self, host: str | None, verify: bool | str = True, apiVersion: LiteralString = "v4") -> None: + """ + Initialize a Safeguard connection object + + :param host: The appliance hostname. + :param verify: A path to a file with CA certificate information or `False` to disable verification. + :param apiVersion: The version of the API with which to connect. + """ + + self.host = host + self.UserToken = None + self.apiVersion = apiVersion + self.verify = verify + self.headers = CaseInsensitiveDict({"accept": "application/json"}) + + @staticmethod + def __execute_web_request( + httpMethod: HttpMethods, url: str, body: JsonType | str | None, headers: Mapping[str, str], verify: str | bool, cert: tuple[str, str] | None + ) -> Response: + data_body: str | None + json_body: JsonType | None + updated_headers = CaseInsensitiveDict(headers) + if body and httpMethod in [HttpMethods.POST, HttpMethods.PUT] and not headers.get("content-type"): + data_body = None + if not isinstance(body, dict): + raise TypeError("expected: body as a JSON object") + json_body = body + updated_headers["content-type"] = "application/json" + else: + if body is not None and not isinstance(body, str): + raise TypeError("expected: body as a string") + data_body = body + json_body = None + + with request(httpMethod, url, headers=updated_headers, cert=cert, verify=verify, data=data_body, json=json_body) as resp: + return resp + + @classmethod + def a2a_get_credential( + cls, + host: str, + apiKey: str, + cert: str, + key: str, + verify: str | bool = True, + a2aType: A2ATypes = A2ATypes.PASSWORD, + keyFormat: SshKeyFormats = SshKeyFormats.OPENSSH, + apiVersion: LiteralString = "v4", + ) -> JsonType: + """ + (Public) Retrieves an application to application credential. + + :param host: Name or ip of the safeguard appliance. + :param apiKey: A2A api key. + :param cert: Path to the user certificate in pem format. + :param key: Path to the user certificate's key in key format. + :param verify: A path to a file with CA certificate information or False to disable verification + :param a2aType: Type of credential to retrieve (password, privatekey). Defaults to password. + :param keyFormat: The privateKeyFormat to return (openssh, ssh2, putty). Defaults to openshh. + :param apiVersion: API version to use. Defaults to v4. + """ + + if not apiKey: + raise Exception("apiKey may not be null or empty") + + if not cert and not key: + raise Exception("cert path and key path may not be null or empty") + + headers = CaseInsensitiveDict({"authorization": f"A2A {apiKey}"}) + query: dict[str, str] = {"type": a2aType} + if a2aType == A2ATypes.PRIVATEKEY: + query["keyFormat"] = keyFormat + + credential = cls.__execute_web_request( + HttpMethods.GET, + assemble_url(host, assemble_path(Services.A2A, apiVersion, "Credentials"), query), + body=None, + headers=headers, + verify=verify, + cert=(cert, key), + ) + if credential.status_code != 200: + raise WebRequestError(credential) + return typing.cast(JsonType, credential.json()) + + def get_provider_id(self, name: str) -> str: + """ + Get an authentication provider by name to use when authenticating. + + :param name: The name of a configured provider. + :returns: A string value which is the ID of a configured provider. + """ + + resp = self.invoke(HttpMethods.GET, Services.CORE, "AuthenticationProviders") + providers = resp.json() + matches = [provider for provider in providers if name.upper() == typing.cast(str, provider["Name"]).upper()] + if not matches: + raise Exception("Unable to find Provider with Name {} in\n{}".format(name, json.dumps(providers, indent=2, sort_keys=True))) + + return typing.cast(str, matches[0]["RstsProviderId"]) + + def __connect(self, body: JsonType, cert: tuple[str, str] | None = None) -> None: + data: JsonType + resp = self.invoke(HttpMethods.POST, Services.RSTS, "oauth2/token", body=body, cert=cert) + if resp.status_code == 200 and "application/json" in resp.headers.get("content-type", ""): + data = resp.json() + if not isinstance(data, dict): + raise TypeError("expected: JSON object with field `access_token`") + access_token = data.get("access_token") + + resp = self.invoke(HttpMethods.POST, Services.CORE, "Token/LoginResponse", body=dict(StsAccessToken=access_token)) + if resp.status_code == 200 and "application/json" in resp.headers.get("content-type", ""): + data = resp.json() + if not isinstance(data, dict): + raise TypeError("expected: JSON object with field `UserToken`") + user_token = typing.cast(str, data.get("UserToken")) + self.connect_token(user_token) + else: + raise WebRequestError(resp) + else: + raise WebRequestError(resp) + + def connect_password(self, username: str, password: str, provider: str = "local") -> None: + """ + Obtain a token using username and password - used when connecting. + + :param username: The username of an authorized user. + :param password: The password for the user. + :param provider: An authentication provider ID associated with user. + """ + + body: JsonType = { + "scope": f"rsts:sts:primaryproviderid:{provider}", + "grant_type": "password", + "username": username, + "password": password, + } + self.__connect(body) + + def connect_certificate(self, certFile: str, keyFile: str, provider: str = "certificate") -> None: + """ + Obtain a token using certificate and key file - used when connecting. + + :param certFile: Path to the client certificate. + :param keyFile: Path to the key for the certificate. + :param provider: An authentication provider ID associated with certificate. + """ + + body: JsonType = { + "scope": f"rsts:sts:primaryproviderid:{provider}", + "grant_type": "client_credentials", + } + self.__connect(body, cert=(certFile, keyFile)) + + def connect_token(self, token: str | None) -> None: + """ + Use an existing token. + + :param token: The user token. + """ + + self.UserToken = token + self.headers.update(authorization="Bearer {}".format(self.UserToken)) + + def invoke( + self, + httpMethod: HttpMethods, + httpService: Services, + endpoint: str | None = None, + query: Mapping[str, str] = {}, + body: JsonType | None = None, + additionalHeaders: Mapping[str, str] = {}, + host: str | None = None, + cert: tuple[str, str] | None = None, + apiVersion: str | None = None, + ) -> Response: + """ + Invoke a web request against the Safeguard API. + + :param httpMethod: One of the predefined `HttpMethods`. + :param httpService: One of the predefined `Services`. + :param endpoint: The path of an API endpoint to use (e.g. 'Users', 'Assets'). + :param query: A dictionary of query parameters that are added to endpoint. + :param body: The data that is sent in the request. Usually a dictionary. + :param headers: Headers that are added to the request. + :param host: The host to which the request is made (useful for clusters). + :param cert: A 2-tuple of the certificate and key. + :param apiVersion: Which version of the API to use in this request. + :returns: Request `Response` object. + """ + + url = assemble_url( + host or self.host or "", + assemble_path( + httpService, + (apiVersion or self.apiVersion) if httpService != Services.RSTS else "", + endpoint, + ), + query, + ) + headers = CaseInsensitiveDict(self.headers) + headers.update(additionalHeaders) + return self.__execute_web_request(httpMethod, url, body, headers, verify=self.verify, cert=cert) + + def get_remaining_token_lifetime(self) -> int | None: + """ + Get the remaining time left on the access token. + + :returns: An integer value in minutes. + """ + + resp = self.invoke(HttpMethods.GET, Services.APPLIANCE, "SystemTime") + remaining = resp.headers.get("x-tokenlifetimeremaining") + if remaining is not None: + return int(remaining, base=10) + else: + return None diff --git a/src/pysafeguard/data_types.py b/src/pysafeguard/data_types.py new file mode 100644 index 0000000..75e4ebb --- /dev/null +++ b/src/pysafeguard/data_types.py @@ -0,0 +1,40 @@ +import enum +import sys + +JsonType = None | bool | int | float | str | dict[str, "JsonType"] | list["JsonType"] + + +if sys.version_info < (3, 11): + + class StrEnum(str, enum.Enum): + pass +else: + from enum import StrEnum + + +class Services(StrEnum): + CORE = "service/core" + APPLIANCE = "service/appliance" + NOTIFICATION = "service/notification" + A2A = "service/a2a" + EVENT = "service/event" + RSTS = "RSTS" + + +class HttpMethods(StrEnum): + GET = "GET" + POST = "POST" + PUT = "PUT" + DELETE = "DELETE" + + +class A2ATypes(StrEnum): + PASSWORD = "password" + PRIVATEKEY = "privatekey" + APIKEYSECRET = "apikey" + + +class SshKeyFormats(StrEnum): + OPENSSH = "openssh" + SSH2 = "ssh2" + PUTTY = "putty" diff --git a/src/__init__.py b/src/pysafeguard/py.typed similarity index 100% rename from src/__init__.py rename to src/pysafeguard/py.typed diff --git a/src/pysafeguard/utility.py b/src/pysafeguard/utility.py new file mode 100644 index 0000000..4a19009 --- /dev/null +++ b/src/pysafeguard/utility.py @@ -0,0 +1,16 @@ +import sys +from collections.abc import Mapping +from urllib.parse import urlencode, urlunparse + +if sys.version_info < (3, 11): + from typing_extensions import LiteralString as LiteralString +else: + from typing import LiteralString as LiteralString + + +def assemble_path(*args: str | None) -> str: + return "/".join(arg.strip("/") for arg in args if arg is not None) + + +def assemble_url(netloc: str = "", path: str = "", query: Mapping[str, str] = {}, fragment: str = "", scheme: LiteralString = "https") -> str: + return urlunparse((scheme, netloc, path, "", urlencode(query, True), fragment))