Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions scripts/confidential_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ class UID2ServicesUnreachableError(ConfidentialComputeStartupError):
def __init__(self, cls, ip=None):
super().__init__(error_name=f"E06: {self.__class__.__name__}", provider=cls, extra_message=ip)

class OperatorKeyRejectedError(ConfidentialComputeStartupError):
def __init__(self, cls):
super().__init__(error_name=f"E07: {self.__class__.__name__}", provider=cls)

class OperatorKeyPermissionError(ConfidentialComputeStartupError):
def __init__(self, cls, message = None):
super().__init__(error_name=f"E08: {self.__class__.__name__}", provider=cls, extra_message=message)
Expand Down Expand Up @@ -97,6 +101,30 @@ def validate_connectivity() -> None:
raise UID2ServicesUnreachableError(self.__class__.__name__, core_ip)
except Exception as e:
raise UID2ServicesUnreachableError(self.__class__.__name__)

def validate_operator_key_with_core_service() -> None:
"""Pre-flight check: verifies the operator key is accepted by the core service.
POSTs to /attest with only the Authorization header; core returns 401 for an
invalid key before it even inspects the attestation payload."""
core_url = self.configs["core_base_url"]
operator_key = self.configs.get("operator_key")
try:
response = requests.post(
f"{core_url}/attest",
headers={"Authorization": f"Bearer {operator_key}"},
json={},
timeout=5
)
if response.status_code == 401:
logging.error(f"Operator key rejected by core service. Response: {response.text}")
raise OperatorKeyRejectedError(self.__class__.__name__)
logging.info(f"Operator key verified with core service (HTTP {response.status_code})")
except OperatorKeyRejectedError:
raise
except (requests.ConnectionError, requests.Timeout) as e:
logging.warning(f"Could not reach core service for key pre-verification: {e}")
except Exception as e:
logging.warning(f"Unexpected error during operator key pre-verification: {e}")

type_hints = get_type_hints(ConfidentialComputeConfig, include_extras=True)
required_keys = [field for field, hint in type_hints.items() if "NotRequired" not in str(hint)]
Expand All @@ -115,6 +143,7 @@ def validate_connectivity() -> None:
validate_url("optout_base_url", environment)
validate_operator_key()
validate_connectivity()
validate_operator_key_with_core_service()
logging.info("Completed static validation of confidential compute config values")

@abstractmethod
Expand Down
139 changes: 139 additions & 0 deletions scripts/tests/test_confidential_compute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import pytest
import requests
import sys
import os
from unittest.mock import patch, MagicMock

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from confidential_compute import (
ConfidentialCompute,
ConfidentialComputeConfig,
OperatorKeyRejectedError,
OperatorKeyValidationError,
UID2ServicesUnreachableError,
ConfigurationMissingError,
ConfigurationValueError,
)


class ConcreteConfidentialCompute(ConfidentialCompute):
"""Minimal concrete implementation for testing the base class."""

def _set_confidential_config(self, secret_identifier):
pass

def _setup_auxiliaries(self):
pass

def _validate_auxiliaries(self):
pass

def run_compute(self):
pass


VALID_CONFIG = {
"operator_key": "UID2-O-I-1-abcdefghijklmnop",
"core_base_url": "https://core-integ.uidapi.com",
"optout_base_url": "https://optout-integ.uidapi.com",
"environment": "integ",
"uid_instance_id_prefix": "ec2-abc123-ami-xyz",
}


def make_instance(config_overrides=None):
cc = ConcreteConfidentialCompute()
cc.configs = {**VALID_CONFIG, **(config_overrides or {})}
return cc


class TestValidateOperatorKeyWithService:
"""Tests for the pre-flight operator key verification against the core service."""

def _run_validate(self, cc, mock_response):
with patch("confidential_compute.socket.gethostbyname", return_value="1.2.3.4"), \
patch("confidential_compute.requests.get") as mock_get, \
patch("confidential_compute.requests.post", return_value=mock_response) as mock_post:
mock_get.return_value = MagicMock(status_code=200)
cc.validate_configuration()
return mock_post

def test_invalid_key_raises_operator_key_rejected_error(self):
cc = make_instance()
mock_resp = MagicMock()
mock_resp.status_code = 401
mock_resp.text = '{"status":"Unauthorized"}'

with pytest.raises(OperatorKeyRejectedError):
self._run_validate(cc, mock_resp)

def test_valid_key_with_no_payload_passes(self):
cc = make_instance()
mock_resp = MagicMock()
mock_resp.status_code = 400 # valid key, missing attestation_request

self._run_validate(cc, mock_resp) # should not raise

def test_valid_key_200_response_passes(self):
cc = make_instance()
mock_resp = MagicMock()
mock_resp.status_code = 200

self._run_validate(cc, mock_resp) # should not raise

def test_server_error_is_non_blocking(self):
cc = make_instance()
mock_resp = MagicMock()
mock_resp.status_code = 500

self._run_validate(cc, mock_resp) # should not raise

def test_connection_error_is_non_blocking(self):
cc = make_instance()

with patch("confidential_compute.socket.gethostbyname", return_value="1.2.3.4"), \
patch("confidential_compute.requests.get") as mock_get, \
patch("confidential_compute.requests.post", side_effect=requests.ConnectionError("refused")):
mock_get.return_value = MagicMock(status_code=200)
cc.validate_configuration() # should not raise

def test_timeout_is_non_blocking(self):
cc = make_instance()

with patch("confidential_compute.socket.gethostbyname", return_value="1.2.3.4"), \
patch("confidential_compute.requests.get") as mock_get, \
patch("confidential_compute.requests.post", side_effect=requests.Timeout("timed out")):
mock_get.return_value = MagicMock(status_code=200)
cc.validate_configuration() # should not raise

def test_unexpected_exception_is_non_blocking(self):
cc = make_instance()

with patch("confidential_compute.socket.gethostbyname", return_value="1.2.3.4"), \
patch("confidential_compute.requests.get") as mock_get, \
patch("confidential_compute.requests.post", side_effect=RuntimeError("unexpected")):
mock_get.return_value = MagicMock(status_code=200)
cc.validate_configuration() # should not raise

def test_post_sent_to_correct_endpoint(self):
cc = make_instance()
mock_resp = MagicMock()
mock_resp.status_code = 400

mock_post = self._run_validate(cc, mock_resp)
mock_post.assert_called_once_with(
"https://core-integ.uidapi.com/attest",
headers={"Authorization": f"Bearer {VALID_CONFIG['operator_key']}"},
json={},
timeout=5,
)

def test_skip_validations_bypasses_key_service_check(self):
cc = make_instance({"skip_validations": True})

with patch("confidential_compute.requests.post") as mock_post:
# validate_configuration is not called when skip_validations is True
# (the cloud scripts check this flag before calling validate_configuration)
# But we can confirm the function itself is gated correctly:
# skip_validations=True means validate_configuration() is never called.
mock_post.assert_not_called()
Loading