From 409547f6b939bebf97f781c72e99ea5b2e0ba199 Mon Sep 17 00:00:00 2001 From: Domen Date: Wed, 8 Mar 2023 13:12:22 +0100 Subject: [PATCH] Adding annotations to errors, rest_client and client --- plugins/module_utils/client.py | 96 +++++++++++++++++------ plugins/module_utils/errors.py | 27 ++++--- plugins/module_utils/hypercore_version.py | 2 +- plugins/module_utils/rest_client.py | 65 +++++++++------ plugins/module_utils/support_tunnel.py | 8 +- plugins/module_utils/typed_classes.py | 7 ++ plugins/modules/support_tunnel.py | 2 +- pyproject.toml | 5 -- 8 files changed, 138 insertions(+), 74 deletions(-) diff --git a/plugins/module_utils/client.py b/plugins/module_utils/client.py index a8128c6df..5a34b0f7d 100644 --- a/plugins/module_utils/client.py +++ b/plugins/module_utils/client.py @@ -4,14 +4,17 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function +from __future__ import annotations __metaclass__ = type import json +from typing import Any, Optional, Union from ansible.module_utils.urls import Request, basic_auth_header from .errors import AuthError, ScaleComputingError, UnexpectedAPIResponse +from ..module_utils.typed_classes import TypedClusterInstance from ansible.module_utils.six.moves.urllib.error import HTTPError, URLError from ansible.module_utils.six.moves.urllib.parse import urlencode, quote @@ -25,7 +28,9 @@ class Response: # Response(raw_resp) would be simpler. # How is this used in other projects? Jure? # Maybe we need/want both. - def __init__(self, status, data, headers=None): + def __init__( + self, status: int, data: Any, headers: Optional[dict[Any, Any]] = None + ): self.status = status self.data = data # [('h1', 'v1'), ('H2', 'V2')] -> {'h1': 'v1', 'h2': 'V2'} @@ -36,7 +41,7 @@ def __init__(self, status, data, headers=None): self._json = None @property - def json(self): + def json(self) -> Any: if self._json is None: try: self._json = json.loads(self.data) @@ -66,11 +71,11 @@ def __init__( self.password = password self.timeout = timeout - self._auth_header = None + self._auth_header: Optional[dict[str, bytes]] = None self._client = Request() @classmethod - def get_client(cls, cluster_instance: dict): + def get_client(cls, cluster_instance: TypedClusterInstance) -> Client: return cls( cluster_instance["host"], cluster_instance["username"], @@ -79,18 +84,25 @@ def get_client(cls, cluster_instance: dict): ) @property - def auth_header(self): + def auth_header(self) -> dict[str, bytes]: if not self._auth_header: self._auth_header = self._login() return self._auth_header - def _login(self): + def _login(self) -> dict[str, bytes]: return self._login_username_password() - def _login_username_password(self): + def _login_username_password(self) -> dict[str, bytes]: return dict(Authorization=basic_auth_header(self.username, self.password)) - def _request(self, method, path, data=None, headers=None, timeout=None): + def _request( + self, + method: str, + path: str, + data: Optional[Union[dict[Any, Any], bytes, str]] = None, + headers: Optional[dict[Any, Any]] = None, + timeout: Optional[float] = None, + ) -> Response: if ( timeout is None ): # If timeout from request is not specifically provided, take it from the Client. @@ -134,14 +146,14 @@ def _request(self, method, path, data=None, headers=None, timeout=None): def request( self, - method, - path, - query=None, - data=None, - headers=None, - binary_data=None, - timeout=None, - ): + method: str, + path: str, + query: Optional[dict[Any, Any]] = None, + data: Optional[dict[Any, Any]] = None, + headers: Optional[dict[Any, Any]] = None, + binary_data: Optional[bytes] = None, + timeout: Optional[float] = None, + ) -> Response: # Make sure we only have one kind of payload if data is not None and binary_data is not None: raise AssertionError( @@ -155,31 +167,64 @@ def request( url = "{0}?{1}".format(url, urlencode(query)) headers = dict(headers or DEFAULT_HEADERS, **self.auth_header) if data is not None: - data = json.dumps(data, separators=(",", ":")) headers["Content-type"] = "application/json" + return self._request( + method, + url, + data=json.dumps(data, separators=(",", ":")), + headers=headers, + timeout=timeout, + ) elif binary_data is not None: - data = binary_data + return self._request( + method, url, data=binary_data, headers=headers, timeout=timeout + ) return self._request(method, url, data=data, headers=headers, timeout=timeout) - def get(self, path, query=None, timeout=None): + def get( + self, + path: str, + query: Optional[dict[Any, Any]] = None, + timeout: Optional[float] = None, + ) -> Request: resp = self.request("GET", path, query=query, timeout=timeout) if resp.status in (200, 404): return resp raise UnexpectedAPIResponse(response=resp) - def post(self, path, data, query=None, timeout=None): + def post( + self, + path: str, + data: Optional[dict[Any, Any]], + query: Optional[dict[Any, Any]] = None, + timeout: Optional[float] = None, + ) -> Request: resp = self.request("POST", path, data=data, query=query, timeout=timeout) if resp.status == 201 or resp.status == 200: return resp raise UnexpectedAPIResponse(response=resp) - def patch(self, path, data, query=None, timeout=None): + def patch( + self, + path: str, + data: dict[Any, Any], + query: Optional[dict[Any, Any]] = None, + timeout: Optional[float] = None, + ) -> Request: resp = self.request("PATCH", path, data=data, query=query, timeout=timeout) if resp.status == 200: return resp raise UnexpectedAPIResponse(response=resp) - def put(self, path, data, query=None, timeout=None, binary_data=None, headers=None): + def put( + self, + path: str, + data: dict[Any, Any], + query: Optional[dict[Any, Any]] = None, + timeout: Optional[float] = None, + binary_data: Optional[bytes] = None, + headers: Optional[dict[Any, Any]] = None, + ) -> Request: resp = self.request( "PUT", path, @@ -193,7 +238,12 @@ def put(self, path, data, query=None, timeout=None, binary_data=None, headers=No return resp raise UnexpectedAPIResponse(response=resp) - def delete(self, path, query=None, timeout=None): + def delete( + self, + path: str, + query: Optional[dict[Any, Any]] = None, + timeout: Optional[float] = None, + ) -> Request: resp = self.request("DELETE", path, query=query, timeout=timeout) if resp.status == 204 or resp.status == 200: return resp diff --git a/plugins/module_utils/errors.py b/plugins/module_utils/errors.py index a7ac7e0b9..904e2f58f 100644 --- a/plugins/module_utils/errors.py +++ b/plugins/module_utils/errors.py @@ -7,7 +7,8 @@ __metaclass__ = type -from typing import Any +from typing import Union +from ansible.module_utils.urls import Request class ScaleComputingError(Exception): @@ -19,7 +20,7 @@ class AuthError(ScaleComputingError): class UnexpectedAPIResponse(ScaleComputingError): - def __init__(self, response): + def __init__(self, response: Request): self.message = "Unexpected response - {0} {1}".format( response.status, response.data ) @@ -28,46 +29,46 @@ def __init__(self, response): class InvalidUuidFormatError(ScaleComputingError): - def __init__(self, data): + def __init__(self, data: Union[str, Exception]): self.message = "Invalid UUID - {0}".format(data) super(InvalidUuidFormatError, self).__init__(self.message) # In-case function parameter is optional but required class MissingFunctionParameter(ScaleComputingError): - def __init__(self, data): + def __init__(self, data: Union[str, Exception]): self.message = "Missing parameter - {0}".format(data) super(MissingFunctionParameter, self).__init__(self.message) # In-case argument spec doesn't catch exception class MissingValueAnsible(ScaleComputingError): - def __init__(self, data): + def __init__(self, data: Union[str, Exception]): self.message = "Missing value - {0}".format(data) super(MissingValueAnsible, self).__init__(self.message) # In-case argument spec doesn't catch exception class MissingValueHypercore(ScaleComputingError): - def __init__(self, data: Any): + def __init__(self, data: Union[str, Exception]): self.message = "Missing values from hypercore API - {0}".format(data) super(MissingValueHypercore, self).__init__(self.message) class DeviceNotUnique(ScaleComputingError): - def __init__(self, data): + def __init__(self, data: Union[str, Exception]): self.message = "Device is not unique - {0} - already exists".format(data) super(DeviceNotUnique, self).__init__(self.message) class VMNotFound(ScaleComputingError): - def __init__(self, data): + def __init__(self, data: Union[str, Exception]): self.message = "Virtual machine - {0} - not found".format(data) super(VMNotFound, self).__init__(self.message) class ReplicationNotUnique(ScaleComputingError): - def __init__(self, data): + def __init__(self, data: Union[str, Exception]): self.message = ( "There is already a replication on - {0} - virtual machine".format(data) ) @@ -75,13 +76,13 @@ def __init__(self, data): class ClusterConnectionNotFound(ScaleComputingError): - def __init__(self, data): + def __init__(self, data: Union[str, Exception]): self.message = "No cluster connection found - {0}".format(data) super(ClusterConnectionNotFound, self).__init__(self.message) class SMBServerNotFound(ScaleComputingError): - def __init__(self, data): + def __init__(self, data: Union[str, Exception]): self.message = "SMB server is either not connected or not in the same network - {0}".format( data ) @@ -89,12 +90,12 @@ def __init__(self, data): class VMInvalidParams(ScaleComputingError): - def __init__(self): + def __init__(self) -> None: self.message = "Invalid set of parameters - strict affinity set to true and nodes not provided." super(VMInvalidParams, self).__init__(self.message) class SupportTunnelError(ScaleComputingError): - def __init__(self, data): + def __init__(self, data: Union[str, Exception]): self.message = "{0}".format(data) super(SupportTunnelError, self).__init__(self.message) diff --git a/plugins/module_utils/hypercore_version.py b/plugins/module_utils/hypercore_version.py index a3ddfb0ad..a3eff82d3 100644 --- a/plugins/module_utils/hypercore_version.py +++ b/plugins/module_utils/hypercore_version.py @@ -328,5 +328,5 @@ def __eq__(self, other: object) -> bool: @classmethod def get(cls, rest_client: RestClient, check_mode: bool = False) -> UpdateStatus: - update_status = rest_client.client.get("update/update_status.json").json # type: ignore + update_status = rest_client.client.get("update/update_status.json").json return cls.from_hypercore(update_status) diff --git a/plugins/module_utils/rest_client.py b/plugins/module_utils/rest_client.py index 20b2188a8..c9d60793a 100644 --- a/plugins/module_utils/rest_client.py +++ b/plugins/module_utils/rest_client.py @@ -13,10 +13,10 @@ __metaclass__ = type -from typing import Any, Union +from typing import Any, Optional -def _query(original=None): +def _query(original: Optional[dict[Any, Any]] = None) -> dict[Any, Any]: # Make sure the query isn't equal to None # If any default query values need to added in the future, they may be added here return dict(original or {}) @@ -27,7 +27,10 @@ def __init__(self, client: Client): self.client = client def list_records( - self, endpoint: str, query: dict[Any, Any] = None, timeout: float = None + self, + endpoint: str, + query: Optional[dict[Any, Any]] = None, + timeout: Optional[float] = None, ) -> list[Any]: """Results are obtained so that first off, all records are obtained and then filtered manually""" @@ -40,10 +43,10 @@ def list_records( def get_record( self, endpoint: str, - query: dict[Any, Any] = None, + query: Optional[dict[Any, Any]] = None, must_exist: bool = False, - timeout: float = None, - ) -> Union[dict[Any, Any], None]: + timeout: Optional[float] = None, + ) -> Optional[dict[Any, Any]]: records = self.list_records(endpoint=endpoint, query=query, timeout=timeout) if len(records) > 1: raise errors.ScaleComputingError( @@ -60,7 +63,11 @@ def get_record( return records[0] if records else None def create_record( - self, endpoint, payload, check_mode, timeout=None + self, + endpoint: str, + payload: Optional[dict[Any, Any]], + check_mode: bool, + timeout: Optional[float] = None, ) -> TypedTaskTag: if check_mode: return utils.MOCKED_TASK_TAG @@ -73,7 +80,11 @@ def create_record( return response def update_record( - self, endpoint, payload, check_mode, record=None, timeout=None + self, + endpoint: str, + payload: dict[Any, Any], + check_mode: bool, + timeout: Optional[float] = None, ) -> TypedTaskTag: # No action is possible when updating a record if check_mode: @@ -86,7 +97,9 @@ def update_record( raise errors.ScaleComputingError(f"Request timed out: {e}") return response - def delete_record(self, endpoint, check_mode, timeout=None) -> TypedTaskTag: + def delete_record( + self, endpoint: str, check_mode: bool, timeout: Optional[float] = None + ) -> TypedTaskTag: # No action is possible when deleting a record if check_mode: return utils.MOCKED_TASK_TAG @@ -98,33 +111,30 @@ def delete_record(self, endpoint, check_mode, timeout=None) -> TypedTaskTag: def put_record( self, - endpoint, - payload, - check_mode, - query=None, - timeout=None, - binary_data=None, - headers=None, - ): + endpoint: str, + payload: dict[Any, Any], + check_mode: bool, + query: Optional[dict[Any, Any]] = None, + timeout: Optional[float] = None, + binary_data: Optional[bytes] = None, + headers: Optional[dict[Any, Any]] = None, + ) -> TypedTaskTag: # Method put doesn't support check mode # IT ACTUALLY DOES if check_mode: - return None + return utils.MOCKED_TASK_TAG try: - response = self.client.put( + response: TypedTaskTag = self.client.put( endpoint, data=payload, query=query, timeout=timeout, binary_data=binary_data, headers=headers, - ) + ).json except TimeoutError as e: raise errors.ScaleComputingError(f"Request timed out: {e}") - try: - return response.json - except errors.ScaleComputingError: - return response + return response class CachedRestClient(RestClient): @@ -133,10 +143,13 @@ class CachedRestClient(RestClient): def __init__(self, client: Client): super().__init__(client) - self.cache = dict() + self.cache: dict[Any, Any] = dict() def list_records( - self, endpoint: str, query: dict[Any, Any] = None, timeout: float = None + self, + endpoint: str, + query: Optional[dict[Any, Any]] = None, + timeout: Optional[float] = None, ) -> list[Any]: if endpoint in self.cache: records = self.cache[endpoint] diff --git a/plugins/module_utils/support_tunnel.py b/plugins/module_utils/support_tunnel.py index 1b1a66cb3..4c5bea17a 100644 --- a/plugins/module_utils/support_tunnel.py +++ b/plugins/module_utils/support_tunnel.py @@ -64,15 +64,13 @@ def __eq__( @classmethod def check_tunnel_status(cls, client: Client) -> SupportTunnel: - response = client.get("/support-api/check") # type: ignore + response = client.get("/support-api/check") return cls.from_hypercore(response.json) @staticmethod def open_tunnel(module: AnsibleModule, client: Client) -> None: - client.get( - "/support-api/open", query={"code": module.params["code"]} - ) # type: ignore + client.get("/support-api/open", query={"code": module.params["code"]}) @staticmethod def close_tunnel(client: Client) -> None: - client.get("/support-api/close") # type: ignore + client.get("/support-api/close") diff --git a/plugins/module_utils/typed_classes.py b/plugins/module_utils/typed_classes.py index 5e31cc155..5d985b9c6 100644 --- a/plugins/module_utils/typed_classes.py +++ b/plugins/module_utils/typed_classes.py @@ -14,6 +14,13 @@ # Typed Classes use for Python hints. +class TypedClusterInstance(TypedDict): + host: str + username: str + password: str + timeout: float + + # Registration to ansible return dict. class TypedRegistrationToAnsible(TypedDict): company_name: Union[str, None] diff --git a/plugins/modules/support_tunnel.py b/plugins/modules/support_tunnel.py index 2f67f2d57..03f492d3b 100644 --- a/plugins/modules/support_tunnel.py +++ b/plugins/modules/support_tunnel.py @@ -89,7 +89,7 @@ def open_tunnel( SupportTunnel.open_tunnel(module, client) new_tunnel_status = SupportTunnel.check_tunnel_status(client) if new_tunnel_status.open is False: - raise errors.SupportTunnelError( # type: ignore + raise errors.SupportTunnelError( "Support tunnel can't be opened, probably the code is already in use." ) return ( diff --git a/pyproject.toml b/pyproject.toml index 802841bae..eee180217 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,6 @@ ignore_missing_imports = true module = [ "plugins.module_utils.arguments", "plugins.module_utils.disk", - "plugins.module_utils.errors", "plugins.module_utils.iso", "plugins.module_utils.nic", "plugins.module_utils.node", @@ -33,13 +32,9 @@ module = [ "plugins.module_utils.remote_cluster", "plugins.module_utils.replication", "plugins.module_utils.snapshot_schedule", - "plugins.module_utils.state", "plugins.module_utils.task_tag", - "plugins.module_utils.type", "plugins.module_utils.utils", "plugins.module_utils.vm", - "plugins.module_utils.rest_client", - "plugins.module_utils.client", "plugins.module_utils.time_server", "plugins.module_utils.time_zone", "plugins.modules.api",