diff --git a/scripts/confidential_compute.py b/scripts/confidential_compute.py index 3fd696f70..21160a171 100644 --- a/scripts/confidential_compute.py +++ b/scripts/confidential_compute.py @@ -51,6 +51,10 @@ class UID2ServicesUnreachableError(ConfidentialComputeStartupError): def __init__(self, cls, ip=None): super().__init__(error_name=f"E06: {self.__class__.__name__}", provider=cls, extra_message=ip) +class OperatorKeyRejectedError(ConfidentialComputeStartupError): + def __init__(self, cls): + super().__init__(error_name=f"E07: {self.__class__.__name__}", provider=cls) + class OperatorKeyPermissionError(ConfidentialComputeStartupError): def __init__(self, cls, message = None): super().__init__(error_name=f"E08: {self.__class__.__name__}", provider=cls, extra_message=message) @@ -97,6 +101,30 @@ def validate_connectivity() -> None: raise UID2ServicesUnreachableError(self.__class__.__name__, core_ip) except Exception as e: raise UID2ServicesUnreachableError(self.__class__.__name__) + + def validate_operator_key_with_core_service() -> None: + """Pre-flight check: verifies the operator key is accepted by the core service. + POSTs to /attest with only the Authorization header; core returns 401 for an + invalid key before it even inspects the attestation payload.""" + core_url = self.configs["core_base_url"] + operator_key = self.configs.get("operator_key") + try: + response = requests.post( + f"{core_url}/attest", + headers={"Authorization": f"Bearer {operator_key}"}, + json={}, + timeout=5 + ) + if response.status_code == 401: + logging.error(f"Operator key rejected by core service. Response: {response.text}") + raise OperatorKeyRejectedError(self.__class__.__name__) + logging.info(f"Operator key verified with core service (HTTP {response.status_code})") + except OperatorKeyRejectedError: + raise + except (requests.ConnectionError, requests.Timeout) as e: + logging.warning(f"Could not reach core service for key pre-verification: {e}") + except Exception as e: + logging.warning(f"Unexpected error during operator key pre-verification: {e}") type_hints = get_type_hints(ConfidentialComputeConfig, include_extras=True) required_keys = [field for field, hint in type_hints.items() if "NotRequired" not in str(hint)] @@ -115,6 +143,7 @@ def validate_connectivity() -> None: validate_url("optout_base_url", environment) validate_operator_key() validate_connectivity() + validate_operator_key_with_core_service() logging.info("Completed static validation of confidential compute config values") @abstractmethod diff --git a/scripts/tests/test_confidential_compute.py b/scripts/tests/test_confidential_compute.py new file mode 100644 index 000000000..30a02dd18 --- /dev/null +++ b/scripts/tests/test_confidential_compute.py @@ -0,0 +1,139 @@ +import pytest +import requests +import sys +import os +from unittest.mock import patch, MagicMock + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from confidential_compute import ( + ConfidentialCompute, + ConfidentialComputeConfig, + OperatorKeyRejectedError, + OperatorKeyValidationError, + UID2ServicesUnreachableError, + ConfigurationMissingError, + ConfigurationValueError, +) + + +class ConcreteConfidentialCompute(ConfidentialCompute): + """Minimal concrete implementation for testing the base class.""" + + def _set_confidential_config(self, secret_identifier): + pass + + def _setup_auxiliaries(self): + pass + + def _validate_auxiliaries(self): + pass + + def run_compute(self): + pass + + +VALID_CONFIG = { + "operator_key": "UID2-O-I-1-abcdefghijklmnop", + "core_base_url": "https://core-integ.uidapi.com", + "optout_base_url": "https://optout-integ.uidapi.com", + "environment": "integ", + "uid_instance_id_prefix": "ec2-abc123-ami-xyz", +} + + +def make_instance(config_overrides=None): + cc = ConcreteConfidentialCompute() + cc.configs = {**VALID_CONFIG, **(config_overrides or {})} + return cc + + +class TestValidateOperatorKeyWithService: + """Tests for the pre-flight operator key verification against the core service.""" + + def _run_validate(self, cc, mock_response): + with patch("confidential_compute.socket.gethostbyname", return_value="1.2.3.4"), \ + patch("confidential_compute.requests.get") as mock_get, \ + patch("confidential_compute.requests.post", return_value=mock_response) as mock_post: + mock_get.return_value = MagicMock(status_code=200) + cc.validate_configuration() + return mock_post + + def test_invalid_key_raises_operator_key_rejected_error(self): + cc = make_instance() + mock_resp = MagicMock() + mock_resp.status_code = 401 + mock_resp.text = '{"status":"Unauthorized"}' + + with pytest.raises(OperatorKeyRejectedError): + self._run_validate(cc, mock_resp) + + def test_valid_key_with_no_payload_passes(self): + cc = make_instance() + mock_resp = MagicMock() + mock_resp.status_code = 400 # valid key, missing attestation_request + + self._run_validate(cc, mock_resp) # should not raise + + def test_valid_key_200_response_passes(self): + cc = make_instance() + mock_resp = MagicMock() + mock_resp.status_code = 200 + + self._run_validate(cc, mock_resp) # should not raise + + def test_server_error_is_non_blocking(self): + cc = make_instance() + mock_resp = MagicMock() + mock_resp.status_code = 500 + + self._run_validate(cc, mock_resp) # should not raise + + def test_connection_error_is_non_blocking(self): + cc = make_instance() + + with patch("confidential_compute.socket.gethostbyname", return_value="1.2.3.4"), \ + patch("confidential_compute.requests.get") as mock_get, \ + patch("confidential_compute.requests.post", side_effect=requests.ConnectionError("refused")): + mock_get.return_value = MagicMock(status_code=200) + cc.validate_configuration() # should not raise + + def test_timeout_is_non_blocking(self): + cc = make_instance() + + with patch("confidential_compute.socket.gethostbyname", return_value="1.2.3.4"), \ + patch("confidential_compute.requests.get") as mock_get, \ + patch("confidential_compute.requests.post", side_effect=requests.Timeout("timed out")): + mock_get.return_value = MagicMock(status_code=200) + cc.validate_configuration() # should not raise + + def test_unexpected_exception_is_non_blocking(self): + cc = make_instance() + + with patch("confidential_compute.socket.gethostbyname", return_value="1.2.3.4"), \ + patch("confidential_compute.requests.get") as mock_get, \ + patch("confidential_compute.requests.post", side_effect=RuntimeError("unexpected")): + mock_get.return_value = MagicMock(status_code=200) + cc.validate_configuration() # should not raise + + def test_post_sent_to_correct_endpoint(self): + cc = make_instance() + mock_resp = MagicMock() + mock_resp.status_code = 400 + + mock_post = self._run_validate(cc, mock_resp) + mock_post.assert_called_once_with( + "https://core-integ.uidapi.com/attest", + headers={"Authorization": f"Bearer {VALID_CONFIG['operator_key']}"}, + json={}, + timeout=5, + ) + + def test_skip_validations_bypasses_key_service_check(self): + cc = make_instance({"skip_validations": True}) + + with patch("confidential_compute.requests.post") as mock_post: + # validate_configuration is not called when skip_validations is True + # (the cloud scripts check this flag before calling validate_configuration) + # But we can confirm the function itself is gated correctly: + # skip_validations=True means validate_configuration() is never called. + mock_post.assert_not_called()