From fd457f087537331a732541a09acf290af017d904 Mon Sep 17 00:00:00 2001 From: Roberto Gazia Date: Fri, 28 Nov 2025 09:30:26 +0100 Subject: [PATCH 01/12] feat: introduce a centralized TLS certificate manager --- src/arduino/app_utils/tls_cert_manager.py | 223 ++++++++++ .../tls_cert_manager/test_tls_cert_manager.py | 408 ++++++++++++++++++ 2 files changed, 631 insertions(+) create mode 100644 src/arduino/app_utils/tls_cert_manager.py create mode 100644 tests/arduino/app_utils/tls_cert_manager/test_tls_cert_manager.py diff --git a/src/arduino/app_utils/tls_cert_manager.py b/src/arduino/app_utils/tls_cert_manager.py new file mode 100644 index 00000000..ef387723 --- /dev/null +++ b/src/arduino/app_utils/tls_cert_manager.py @@ -0,0 +1,223 @@ +# SPDX-FileCopyrightText: Copyright (C) ARDUINO SRL (http://www.arduino.cc) +# +# SPDX-License-Identifier: MPL-2.0 + +import os +import threading +from pathlib import Path +from cryptography import x509 +from cryptography.x509.oid import NameOID +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives import serialization +from datetime import datetime, timedelta, UTC + + +DEFAULT_CERTS_DIR = "/app/certs" +DEFAULT_CERTS_PARAMS = { + "country_name": "IT", + "state_or_province_name": "Piedmont", + "locality_name": "Turin", + "organization_name": "Arduino", + "common_name": "0.0.0.0", + "validity_days": 365 +} + +class TLSCertificateManager: + """Certificate manager for TLS certificates. + + This class handles certificate generation and retrieval on a Brick basis. By default, all bricks + share certificates from the default directory (/app/certs). + Components can use their own certificates by providing a different certs_dir path. + """ + + _locks = {} + _locks_lock = threading.Lock() + + + @classmethod + def get_or_create_certificates( + cls, + certs_dir: str = DEFAULT_CERTS_DIR, + country_name: str = DEFAULT_CERTS_PARAMS["country_name"], + state_or_province_name: str = DEFAULT_CERTS_PARAMS["state_or_province_name"], + locality_name: str = DEFAULT_CERTS_PARAMS["locality_name"], + organization_name: str = DEFAULT_CERTS_PARAMS["organization_name"], + common_name: str = DEFAULT_CERTS_PARAMS["common_name"], + validity_days: int = DEFAULT_CERTS_PARAMS["validity_days"] + ) -> tuple[str, str]: + """Get or create TLS certificates at the specified path. + + By default, uses shared certificates in /app/certs. If a different certs_dir is provided, + uses certificates specific to that directory (useful for brick-specific certificates). + + Concurrent access is managed to prevent race conditions when multiple bricks attempt to + access certificates simultaneously. + + Args: + certs_dir (str, optional): Directory for certificates. Defaults to /app/certs (shared + by all bricks). Provide a different path for brick-specific certificates. + country_name (str, optional): Country name for the certificate. Defaults to "IT". + state_or_province_name (str, optional): State or province name for the certificate. + Defaults to "Piedmont". + locality_name (str, optional): Locality name for the certificate. Defaults to "Turin". + organization_name (str, optional): Organization name for the certificate. Defaults to "Arduino". + common_name (str, optional): Common name for the certificate. Defaults to "0.0.0.0". + validity_days (int, optional): Certificate validity period in days. Defaults to 365. + + Returns: + tuple[str, str]: Paths to (certificate_file, private_key_file) + + Raises: + RuntimeError: If certificate generation fails. + """ + target_dir = certs_dir or DEFAULT_CERTS_DIR + cert_path = os.path.join(target_dir, "cert.pem") + key_path = os.path.join(target_dir, "key.pem") + + if cls.certificates_exist(target_dir): + return cert_path, key_path + + dir_lock = cls._get_dir_lock(target_dir) + with dir_lock: + if cls.certificates_exist(target_dir): + return cert_path, key_path + + try: + cls._generate_self_signed_cert( + target_dir, + country_name, + state_or_province_name, + locality_name, + organization_name, + common_name, + validity_days + ) + return cert_path, key_path + except Exception as e: + raise RuntimeError(f"Failed to generate TLS certificates in {target_dir}: {e}") from e + + @classmethod + def certificates_exist(cls, certs_dir: str = DEFAULT_CERTS_DIR) -> bool: + """Check if TLS certificates exist in the given directory. + + Args: + certs_dir (str, optional): Directory for certificates. + Defaults to /app/certs. + + Returns: + bool: True if both certificate and key files exist, False otherwise. + """ + target_dir = certs_dir or DEFAULT_CERTS_DIR + cert_path = os.path.join(target_dir, "cert.pem") + key_path = os.path.join(target_dir, "key.pem") + return os.path.exists(cert_path) and os.path.exists(key_path) + + @classmethod + def get_certificates_paths(cls, certs_dir: str = DEFAULT_CERTS_DIR) -> tuple[str, str]: + """Get the paths to the TLS certificate and private key files. + + Args: + certs_dir (str, optional): Directory for certificates. Defaults to /app/certs. + Returns: + tuple[str, str]: Paths to certificate_file and private_key_file + """ + return cls.get_certificate_path(certs_dir), cls.get_private_key_path(certs_dir) + + @classmethod + def get_certificate_path(cls, certs_dir: str = DEFAULT_CERTS_DIR) -> str: + """Get the path to the TLS certificate file. + + Args: + certs_dir (str, optional): Directory for certificates. Defaults to /app/certs. + + Returns: + str: Path to the certificate file. + """ + return os.path.join(certs_dir or DEFAULT_CERTS_DIR, "cert.pem") + + @classmethod + def get_private_key_path(cls, certs_dir: str = DEFAULT_CERTS_DIR) -> str: + """Get the path to the TLS private key file. + + Args: + certs_dir (str, optional): Directory for certificates. Defaults to /app/certs. + + Returns: + str: Path to the private key file. + """ + return os.path.join(certs_dir or DEFAULT_CERTS_DIR, "key.pem") + + @classmethod + def _get_dir_lock(cls, target_dir: str) -> threading.Lock: + """Get or create a lock for a specific directory. + + This ensures that only operations on the same directory block each other, + while operations on different directories can proceed concurrently. + + Args: + target_dir (str): The normalized absolute path to the directory. + + Returns: + threading.Lock: A lock specific to this directory. + """ + with cls._locks_lock: + if target_dir not in cls._locks: + cls._locks[target_dir] = threading.Lock() + return cls._locks[target_dir] + + @staticmethod + def _generate_self_signed_cert( + target_dir: str, + country_name: str, + state_or_province_name: str, + locality_name: str, + organization_name: str, + common_name: str, + validity_days: int + ): + # Generate a private key + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + ) + + # Generate a self-signed certificate + subject = issuer = x509.Name([ + x509.NameAttribute(NameOID.COUNTRY_NAME, country_name), + x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, state_or_province_name), + x509.NameAttribute(NameOID.LOCALITY_NAME, locality_name), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, organization_name), + x509.NameAttribute(NameOID.COMMON_NAME, common_name), + ]) + + cert = x509.CertificateBuilder() + cert = cert.subject_name(subject) + cert = cert.issuer_name(issuer) + cert = cert.public_key(private_key.public_key()) + cert = cert.serial_number(x509.random_serial_number()) + cert = cert.not_valid_before(datetime.now(UTC)) + cert = cert.not_valid_after(datetime.now(UTC) + timedelta(days=validity_days)) + cert = cert.add_extension( + x509.SubjectAlternativeName([x509.DNSName(common_name)]), + critical=False + ) + cert = cert.sign(private_key, hashes.SHA256()) + + Path(target_dir).mkdir(parents=True, exist_ok=True) + + # Write the certificate to a PEM file + cert_path = os.path.join(target_dir, "cert.pem") + with open(cert_path, "wb") as cert_file: + cert_file.write(cert.public_bytes(serialization.Encoding.PEM)) + + # Write the private key to a PEM file + key_path = os.path.join(target_dir, "key.pem") + with open(key_path, "wb") as key_file: + key_file.write( + private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + ) diff --git a/tests/arduino/app_utils/tls_cert_manager/test_tls_cert_manager.py b/tests/arduino/app_utils/tls_cert_manager/test_tls_cert_manager.py new file mode 100644 index 00000000..55e2e914 --- /dev/null +++ b/tests/arduino/app_utils/tls_cert_manager/test_tls_cert_manager.py @@ -0,0 +1,408 @@ +# SPDX-FileCopyrightText: Copyright (C) ARDUINO SRL (http://www.arduino.cc) +# +# SPDX-License-Identifier: MPL-2.0 + +import os +import shutil +import tempfile +import threading +import time +import pytest +from cryptography import x509 +from cryptography.hazmat.backends import default_backend + +from arduino.app_utils.tls_cert_manager import TLSCertificateManager + + +@pytest.fixture +def temp_certs_dir(): + """Create a temporary directory for certificates and clean up after tests.""" + temp_dir = tempfile.mkdtemp() + + yield temp_dir + + # Cleanup + if os.path.exists(temp_dir): + shutil.rmtree(temp_dir) + + +@pytest.fixture +def reset_manager(): + """Reset the TLSCertificateManager state between tests.""" + yield + + TLSCertificateManager._locks.clear() # Reset state + + +class TestBasicFunctionality: + """Test basic certificate creation and retrieval.""" + + def test_create_certificates_in_custom_dir(self, temp_certs_dir, reset_manager): + """Test creating certificates in a custom directory.""" + cert_path, key_path = TLSCertificateManager.get_or_create_certificates( + certs_dir=temp_certs_dir + ) + + # Verify paths are correct + assert cert_path == os.path.join(temp_certs_dir, "cert.pem") + assert key_path == os.path.join(temp_certs_dir, "key.pem") + + # Verify files exist + assert os.path.exists(cert_path) + assert os.path.exists(key_path) + + def test_certificates_are_valid(self, temp_certs_dir, reset_manager): + """Test that generated certificates are valid X.509 certificates.""" + cert_path, key_path = TLSCertificateManager.get_or_create_certificates( + certs_dir=temp_certs_dir, + common_name="test.local" + ) + + # Load and verify certificate + with open(cert_path, "rb") as f: + cert = x509.load_pem_x509_certificate(f.read(), default_backend()) + + # Check common name + common_name = cert.subject.get_attributes_for_oid(x509.NameOID.COMMON_NAME)[0].value + assert common_name == "test.local" + + # Check organization + org = cert.subject.get_attributes_for_oid(x509.NameOID.ORGANIZATION_NAME)[0].value + assert org == "Arduino" + + def test_reuse_existing_certificates(self, temp_certs_dir, reset_manager): + """Test that existing certificates are reused instead of regenerated.""" + cert_path1, key_path1 = TLSCertificateManager.get_or_create_certificates( + certs_dir=temp_certs_dir + ) + + # Get modification time + mtime1 = os.path.getmtime(cert_path1) + + # Get certificates again + cert_path2, key_path2 = TLSCertificateManager.get_or_create_certificates( + certs_dir=temp_certs_dir + ) + + assert cert_path1 == cert_path2 + assert key_path1 == key_path2 + + # Check modification time is unchanged + mtime2 = os.path.getmtime(cert_path2) + assert mtime1 == mtime2 + + def test_custom_validity_period(self, temp_certs_dir, reset_manager): + """Test creating certificates with custom validity period.""" + cert_path, _ = TLSCertificateManager.get_or_create_certificates( + certs_dir=temp_certs_dir, + validity_days=1 + ) + + with open(cert_path, "rb") as f: + cert = x509.load_pem_x509_certificate(f.read(), default_backend()) + + validity_days = (cert.not_valid_after_utc - cert.not_valid_before_utc).days + assert validity_days == 1 + + +class TestHelperMethods: + """Test helper methods for checking and retrieving certificate paths.""" + + def test_certificates_exist_returns_false_for_missing(self, temp_certs_dir, reset_manager): + """Test certificates_exist returns False when certificates don't exist.""" + assert not TLSCertificateManager.certificates_exist(certs_dir=temp_certs_dir) + + def test_certificates_exist_returns_true_after_creation(self, temp_certs_dir, reset_manager): + """Test certificates_exist returns True after certificates are created.""" + TLSCertificateManager.get_or_create_certificates(certs_dir=temp_certs_dir) + assert TLSCertificateManager.certificates_exist(certs_dir=temp_certs_dir) + + def test_get_certificate_path(self, temp_certs_dir, reset_manager): + """Test get_certificate_path returns correct path.""" + expected_path = os.path.join(temp_certs_dir, "cert.pem") + actual_path = TLSCertificateManager.get_certificate_path(certs_dir=temp_certs_dir) + assert actual_path == expected_path + + def test_get_private_key_path(self, temp_certs_dir, reset_manager): + """Test get_private_key_path returns correct path.""" + expected_path = os.path.join(temp_certs_dir, "key.pem") + actual_path = TLSCertificateManager.get_private_key_path(certs_dir=temp_certs_dir) + assert actual_path == expected_path + + +class TestConcurrentAccess: + """Test concurrent access and race condition handling.""" + + def test_concurrent_access_same_directory(self, temp_certs_dir, reset_manager): + """Test multiple threads accessing the same directory concurrently.""" + results = [] + errors = [] + + def create_certs(thread_id): + try: + start_time = time.time() + cert_path, key_path = TLSCertificateManager.get_or_create_certificates( + certs_dir=temp_certs_dir + ) + elapsed = time.time() - start_time + results.append({ + 'thread_id': thread_id, + 'cert_path': cert_path, + 'key_path': key_path, + 'elapsed': elapsed + }) + except Exception as e: + errors.append({'thread_id': thread_id, 'error': str(e)}) + + # Start 10 threads simultaneously + threads = [] + for i in range(10): + thread = threading.Thread(target=create_certs, args=(i,)) + threads.append(thread) + + # Start all threads at once + for thread in threads: + thread.start() + + # Wait for all to complete + for thread in threads: + thread.join() + + # Verify no errors occurred + assert len(errors) == 0, f"Errors occurred: {errors}" + + # Verify all threads got the same paths + assert len(results) == 10 + cert_paths = set(r['cert_path'] for r in results) + key_paths = set(r['key_path'] for r in results) + assert len(cert_paths) == 1, "All threads should get the same certificate path" + assert len(key_paths) == 1, "All threads should get the same key path" + + # Verify certificates exist and are valid + cert_path = results[0]['cert_path'] + assert os.path.exists(cert_path) + with open(cert_path, "rb") as f: + cert = x509.load_pem_x509_certificate(f.read(), default_backend()) + assert cert is not None + + def test_concurrent_access_different_directories(self, temp_certs_dir, reset_manager): + """Test multiple threads accessing different directories concurrently.""" + results = [] + errors = [] + + def create_certs(component_name): + try: + start_time = time.time() + component_dir = os.path.join(temp_certs_dir, component_name) + cert_path, key_path = TLSCertificateManager.get_or_create_certificates( + certs_dir=component_dir + ) + elapsed = time.time() - start_time + results.append({ + 'component': component_name, + 'cert_path': cert_path, + 'key_path': key_path, + 'elapsed': elapsed + }) + except Exception as e: + errors.append({'component': component_name, 'error': str(e)}) + + # Simulate multiple components starting simultaneously + components = ['webui', 'api', 'mqtt', 'scanner', 'processor'] + threads = [] + + for component in components: + thread = threading.Thread(target=create_certs, args=(component,)) + threads.append(thread) + + # Start all threads + for thread in threads: + thread.start() + + # Wait for completion + for thread in threads: + thread.join() + + # Verify no errors + assert len(errors) == 0, f"Errors occurred: {errors}" + + # Verify all components succeeded + assert len(results) == len(components) + + # Verify each component has its own certificates + cert_dirs = set(os.path.dirname(r['cert_path']) for r in results) + assert len(cert_dirs) == len(components), "Each component should have its own directory" + + # Verify all certificates exist and are in correct directories + for result in results: + component = result['component'] + expected_dir = os.path.join(temp_certs_dir, component) + assert expected_dir in result['cert_path'] + assert os.path.exists(result['cert_path']) + assert os.path.exists(result['key_path']) + + def test_concurrent_mixed_access(self, temp_certs_dir, reset_manager): + """Test concurrent access with both shared and component-specific directories.""" + results = [] + errors = [] + lock = threading.Lock() + + def create_certs(name, use_custom_dir): + try: + start_time = time.time() + if use_custom_dir: + certs_dir = os.path.join(temp_certs_dir, name) + else: + certs_dir = temp_certs_dir + + cert_path, key_path = TLSCertificateManager.get_or_create_certificates( + certs_dir=certs_dir + ) + elapsed = time.time() - start_time + + with lock: + results.append({ + 'name': name, + 'use_custom': use_custom_dir, + 'cert_path': cert_path, + 'elapsed': elapsed + }) + except Exception as e: + with lock: + errors.append({'name': name, 'error': str(e)}) + + # Mix of shared and custom directory access + configs = [ + ('webui', False), # Shared + ('api', False), # Shared + ('mqtt', True), # Custom + ('scanner', True), # Custom + ('backup', False), # Shared + ('processor', True), # Custom + ] + + threads = [] + for name, use_custom in configs: + thread = threading.Thread(target=create_certs, args=(name, use_custom)) + threads.append(thread) + + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + # Verify no errors + assert len(errors) == 0, f"Errors occurred: {errors}" + assert len(results) == len(configs) + + # Verify shared components use the same certificates + shared_certs = [r for r in results if not r['use_custom']] + shared_paths = set(r['cert_path'] for r in shared_certs) + assert len(shared_paths) == 1, "Shared components should use same certificates" + + # Verify custom components have unique certificates + custom_certs = [r for r in results if r['use_custom']] + custom_paths = set(r['cert_path'] for r in custom_certs) + assert len(custom_paths) == len(custom_certs), "Custom components should have unique certificates" + + +class TestDirectoryCreation: + """Test automatic directory creation.""" + + def test_creates_missing_directory(self, temp_certs_dir, reset_manager): + """Test that missing directories are created automatically.""" + nested_dir = os.path.join(temp_certs_dir, "deeply", "nested", "path") + assert not os.path.exists(nested_dir) + + cert_path, key_path = TLSCertificateManager.get_or_create_certificates( + certs_dir=nested_dir + ) + + assert os.path.exists(nested_dir) + assert os.path.exists(cert_path) + assert os.path.exists(key_path) + + def test_handles_existing_directory(self, temp_certs_dir, reset_manager): + """Test that existing directories are handled correctly.""" + # Pre-create the directory + os.makedirs(temp_certs_dir, exist_ok=True) + + cert_path, key_path = TLSCertificateManager.get_or_create_certificates( + certs_dir=temp_certs_dir + ) + + assert os.path.exists(cert_path) + assert os.path.exists(key_path) + + +class TestErrorHandling: + """Test error handling scenarios.""" + + def test_invalid_directory_permissions(self, reset_manager): + """Test handling of directories with invalid permissions.""" + # This test is platform-specific and may need adjustment + if os.name != 'posix': + pytest.skip("Permission test only applicable on POSIX systems") + + temp_dir = tempfile.mkdtemp() + try: + # Make directory read-only + os.chmod(temp_dir, 0o444) + + with pytest.raises(RuntimeError) as exc_info: + TLSCertificateManager.get_or_create_certificates(certs_dir=temp_dir) + + assert "Failed to generate TLS certificates" in str(exc_info.value) + finally: + # Restore permissions for cleanup + os.chmod(temp_dir, 0o755) + shutil.rmtree(temp_dir) + + +class TestPerformance: + """Test performance characteristics.""" + + def test_fast_path_no_lock_overhead(self, temp_certs_dir, reset_manager): + """Test that retrieving existing certificates is fast (no lock acquisition).""" + # Create certificates first + TLSCertificateManager.get_or_create_certificates(certs_dir=temp_certs_dir) + + # Measure retrieval time + iterations = 100 + start = time.time() + for _ in range(iterations): + TLSCertificateManager.get_or_create_certificates(certs_dir=temp_certs_dir) + elapsed = time.time() - start + + # Should be very fast (< 1ms per call on average) + avg_time = elapsed / iterations + assert avg_time < 0.001, f"Average retrieval time too slow: {avg_time:.6f}s" + + def test_concurrent_different_dirs_no_blocking(self, temp_certs_dir, reset_manager): + """Test that different directories don't block each other significantly.""" + total_elapsed_lock = threading.Lock() + elapsed_times = [] + + def create_certs(brick_name): + brick_dir = os.path.join(temp_certs_dir, brick_name) + + start = time.time() + TLSCertificateManager.get_or_create_certificates(certs_dir=brick_dir) + elapsed = time.time() - start + + with total_elapsed_lock: + elapsed_times.append(elapsed) + + bricks = ['brick1', 'brick2', 'brick3', 'brick4'] + threads = [threading.Thread(target=create_certs, args=(c,)) for c in bricks] + + start = time.time() + for thread in threads: + thread.start() + for thread in threads: + thread.join() + overall_run_time = time.time() - start + + # If bricks truly run in parallel, total time should be lower than + # the cumulative total run times by all threads + assert overall_run_time < sum(elapsed_times), f"Bricks blocked each other: {overall_run_time:.3f}s should be lower than {sum(elapsed_times):.3f}s" From 9ac55b5443dec66281488a027fbc1b3e4c6889e1 Mon Sep 17 00:00:00 2001 From: Roberto Gazia Date: Fri, 28 Nov 2025 09:30:47 +0100 Subject: [PATCH 02/12] refactor: move web_ui cert management to TLSCertificateManager --- src/arduino/app_bricks/web_ui/certs.py | 95 ------------------------- src/arduino/app_bricks/web_ui/web_ui.py | 71 ++++++++++-------- 2 files changed, 40 insertions(+), 126 deletions(-) delete mode 100644 src/arduino/app_bricks/web_ui/certs.py diff --git a/src/arduino/app_bricks/web_ui/certs.py b/src/arduino/app_bricks/web_ui/certs.py deleted file mode 100644 index 4ae3e1f9..00000000 --- a/src/arduino/app_bricks/web_ui/certs.py +++ /dev/null @@ -1,95 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (C) ARDUINO SRL (http://www.arduino.cc) -# -# SPDX-License-Identifier: MPL-2.0 - -import os -from cryptography import x509 -from cryptography.x509.oid import NameOID -from cryptography.hazmat.primitives import hashes -from cryptography.hazmat.primitives.asymmetric import rsa -from cryptography.hazmat.primitives import serialization -from datetime import datetime, timedelta, UTC - - -def cert_exists(root_dir: str) -> bool: - """Check if the SSL certificate and private key files exist. - - Args: - root_dir (str): The root directory where the SSL files are stored. - - Returns: - bool: True if both key and cert files exist, False otherwise. - """ - return os.path.exists(os.path.join(root_dir, "key.pem")) and os.path.exists(os.path.join(root_dir, "cert.pem")) - - -def get_cert(root_dir: str) -> str: - """Get the path to the SSL certificate file. - - Args: - root_dir (str): The root directory where the SSL files are stored. - - Returns: - str: The path to the SSL certificate file. - """ - return os.path.join(root_dir, "cert.pem") - - -def get_pkey(root_dir: str) -> str: - """Get the path to the SSL private key file. - - Args: - root_dir: The root directory where the SSL files are stored. - - Returns: - str: The path to the SSL private key file. - """ - return os.path.join(root_dir, "key.pem") - - -def generate_self_signed_cert(root_dir: str): - """Generate a self-signed SSL certificate and private key. - - Args: - root_dir (str): The root directory where the SSL files will be stored. - """ - # Generate a private key - private_key = rsa.generate_private_key( - public_exponent=65537, - key_size=2048, - ) - - # Generate a self-signed certificate - subject = issuer = x509.Name([ - x509.NameAttribute(NameOID.COUNTRY_NAME, "IT"), - x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Piedmont"), - x509.NameAttribute(NameOID.LOCALITY_NAME, "Turin"), - x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Arduino"), - x509.NameAttribute(NameOID.COMMON_NAME, "0.0.0.0"), - ]) - cert = x509.CertificateBuilder() - cert = cert.subject_name(subject) - cert = cert.issuer_name(issuer) - cert = cert.public_key(private_key.public_key()) - cert = cert.serial_number(x509.random_serial_number()) - cert = cert.not_valid_before(datetime.now(UTC)) - cert = cert.not_valid_after(datetime.now(UTC) + timedelta(days=365)) # Valid for 1 year - cert = cert.add_extension(x509.SubjectAlternativeName([x509.DNSName("0.0.0.0")]), critical=False) - cert = cert.sign(private_key, hashes.SHA256()) - - if not os.path.exists(root_dir): - os.makedirs(root_dir) - - # Write the private key to a PEM file - with open(os.path.join(root_dir, "key.pem"), "wb") as key_file: - key_file.write( - private_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption(), - ) - ) - - # Write the certificate to a PEM file - with open(os.path.join(root_dir, "cert.pem"), "wb") as cert_file: - cert_file.write(cert.public_bytes(serialization.Encoding.PEM)) diff --git a/src/arduino/app_bricks/web_ui/web_ui.py b/src/arduino/app_bricks/web_ui/web_ui.py index eee20a69..2796121f 100644 --- a/src/arduino/app_bricks/web_ui/web_ui.py +++ b/src/arduino/app_bricks/web_ui/web_ui.py @@ -2,14 +2,17 @@ # # SPDX-License-Identifier: MPL-2.0 -from collections.abc import Callable -import asyncio import os +import asyncio import threading +from typing import Any +from collections.abc import Callable + import uvicorn from fastapi import FastAPI from fastapi.responses import FileResponse from fastapi_socketio import SocketManager + from arduino.app_utils import brick, Logger logger = Logger("WebUI") @@ -32,7 +35,8 @@ def __init__( api_path_prefix: str = "", assets_dir_path: str = "/app/assets", certs_dir_path: str = "/app/certs", - use_ssl: bool = False, + use_tls: bool = False, + use_ssl: bool | None = None, # Deprecated alias for use_tls ): """Initialize the web server. @@ -42,9 +46,14 @@ def __init__( ui_path_prefix (str, optional): URL prefix for UI routes. Defaults to "" (root). api_path_prefix (str, optional): URL prefix for API routes. Defaults to "" (root). assets_dir_path (str, optional): Path to static assets directory. Defaults to "/app/assets". - certs_dir_path (str, optional): Path to SSL certificates directory. Defaults to "/app/certs". - use_ssl (bool, optional): Enable SSL/HTTPS. Defaults to False. + certs_dir_path (str, optional): Path to TLS certificates directory. Defaults to "/app/certs". + use_tls (bool, optional): Enable TLS/HTTPS. Defaults to False. + use_ssl (bool, optional): Deprecated. Use use_tls instead. Defaults to None. """ + # Handle deprecated use_ssl parameter + if use_ssl is not None: + logger.warning("'use_ssl' parameter is deprecated. Use 'use_tls' instead.") + use_tls = use_ssl self.app = FastAPI(title=__name__, openapi_url=None, on_startup=[self._on_startup]) self.sio = SocketManager(app=self.app, mount_location="/socket.io", socketio_path="", max_http_buffer_size=10 * 1024 * 1024) @@ -54,23 +63,23 @@ def __init__( self._api_path_prefix = api_path_prefix self._assets_dir_path = os.path.abspath(assets_dir_path) self._certs_dir_path = os.path.abspath(certs_dir_path) - self._use_ssl = use_ssl - self._protocol = "https" if self._use_ssl else "http" - self._server: uvicorn.Server = None + self._use_tls = use_tls + self._protocol = "https" if self._use_tls else "http" + self._server: uvicorn.Server | None = None self._server_loop: asyncio.AbstractEventLoop | None = None - self._on_connect_cb: Callable[[str], None] = None - self._on_disconnect_cb: Callable[[str], None] = None + self._on_connect_cb: Callable[[str], None] | None = None + self._on_disconnect_cb: Callable[[str], None] | None = None self._on_message_cbs = {} self._on_message_cbs_lock = threading.Lock() def start(self): """Start the web server asynchronously. - This sets up static file routing and WebSocket event handlers, configures SSL if enabled, and launches the server using Uvicorn. + This sets up static file routing and WebSocket event handlers, configures TLS if enabled, and launches the server using Uvicorn. Raises: RuntimeError: If 'index.html' is missing in the static assets directory. - RuntimeError: If SSL is enabled but certificates are missing or fail to generate. + RuntimeError: If TLS is enabled but certificates fail to generate. RuntimeWarning: If the server is already running. """ # Setup static routes and SocketIO events @@ -82,18 +91,18 @@ def start(self): self._init_socketio() config = uvicorn.Config(self.app, host=self._addr, port=self._port, log_level="warning") - if self._use_ssl: - from . import certs - - if not certs.cert_exists(self._certs_dir_path): - try: - certs.generate_self_signed_cert(self._certs_dir_path) - except Exception as e: - logger.exception(f"Failed to generate SSL certificate: {e}") - raise RuntimeError("Failed to generate SSL certificate. Please check the certs directory.") from e - - config.ssl_keyfile = certs.get_pkey(self._certs_dir_path) - config.ssl_certfile = certs.get_cert(self._certs_dir_path) + if self._use_tls: + from arduino.app_utils.tls_cert_manager import TLSCertificateManager + try: + cert_path, key_path = TLSCertificateManager.get_or_create_certificates( + certs_dir=self._certs_dir_path, + common_name=self._addr + ) + config.ssl_certfile = cert_path + config.ssl_keyfile = key_path + except Exception as e: + logger.exception(f"Failed to configure SSL certificate: {e}") + raise RuntimeError("Failed to configure TLS certificate. Please check the certs directory.") from e self._server = uvicorn.Server(config) @@ -108,8 +117,8 @@ def stop(self): def execute(self): logger.debug(f"Serving static web files from {self._assets_dir_path}") - if self._use_ssl: - logger.debug(f"Serving certificates from {self._certs_dir_path}") + if self._use_tls: + logger.debug(f"Using TLS certificates from {self._certs_dir_path}") logger.debug("Starting server...") @@ -126,7 +135,7 @@ def execute(self): except Exception as e: logger.exception(f"Error running server: {e}") - def expose_api(self, method: str, path: str, function: callable): + def expose_api(self, method: str, path: str, function: Callable): """Register a route with the specified HTTP method and path. The path will be prefixed with the api_path_prefix configured during initialization. @@ -134,7 +143,7 @@ def expose_api(self, method: str, path: str, function: callable): Args: method (str): HTTP method to use (e.g., "GET", "POST"). path (str): URL path for the API endpoint (without the prefix). - function (callable): Function to execute when the route is accessed. + function (Callable): Function to execute when the route is accessed. """ self.app.add_api_route(self._api_path_prefix + path, function, methods=[method]) @@ -160,7 +169,7 @@ def on_disconnect(self, callback: Callable[[str], None]): """ self._on_disconnect_cb = callback - def on_message(self, message_type: str, callback: Callable[[str, any], any]): + def on_message(self, message_type: str, callback: Callable[[str, Any], Any]): """Register a callback function for a specific WebSocket message type received by clients. The client should send messages named as message_type for this callback to be triggered. @@ -170,7 +179,7 @@ def on_message(self, message_type: str, callback: Callable[[str, any], any]): Args: message_type (str): The message type name to listen for. - callback (Callable[[str, any], any]): Function to handle the message. Receives two arguments: + callback (Callable[[str, Any], Any]): Function to handle the message. Receives two arguments: the session ID (sid) and the incoming message data. """ @@ -180,7 +189,7 @@ def on_message(self, message_type: str, callback: Callable[[str, any], any]): self._on_message_cbs[message_type] = callback logger.debug(f"Registered listener for message '{message_type}'") - def send_message(self, message_type: str, message: dict | str, room: str = None): + def send_message(self, message_type: str, message: dict | str, room: str | None = None): """Send a message to connected WebSocket clients. Args: From 4516781330b0ddbd3a25eee468293369c5581b88 Mon Sep 17 00:00:00 2001 From: Roberto Gazia Date: Fri, 28 Nov 2025 13:32:26 +0100 Subject: [PATCH 03/12] refactor: expose url and local_url from WebUI --- src/arduino/app_bricks/web_ui/web_ui.py | 36 ++++++++++++++++++++----- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/src/arduino/app_bricks/web_ui/web_ui.py b/src/arduino/app_bricks/web_ui/web_ui.py index 2796121f..1f04ac38 100644 --- a/src/arduino/app_bricks/web_ui/web_ui.py +++ b/src/arduino/app_bricks/web_ui/web_ui.py @@ -58,7 +58,12 @@ def __init__( self.sio = SocketManager(app=self.app, mount_location="/socket.io", socketio_path="", max_http_buffer_size=10 * 1024 * 1024) self._addr = addr - self._port = port + def pick_free_port(): + import socket + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(('', 0)) + return s.getsockname()[1] + self._port = port if port != 0 else pick_free_port() self._ui_path_prefix = ui_path_prefix self._api_path_prefix = api_path_prefix self._assets_dir_path = os.path.abspath(assets_dir_path) @@ -71,6 +76,24 @@ def __init__( self._on_disconnect_cb: Callable[[str], None] | None = None self._on_message_cbs = {} self._on_message_cbs_lock = threading.Lock() + + @property + def local_url(self) -> str: + """Get the locally addressable URL of the web server. + + Returns: + str: The server's URL (including protocol, address, and port). + """ + return f"{self._protocol}://localhost:{self._port}" + + @property + def url(self) -> str: + """Get the externally addressable URL of the web server. + + Returns: + str: The server's URL (including protocol, address, and port). + """ + return f"{self._protocol}://{os.getenv('HOST_IP') or self._addr}:{self._port}" def start(self): """Start the web server asynchronously. @@ -123,11 +146,9 @@ def execute(self): logger.debug("Starting server...") startup_log = "The application interface is available here:\n" - startup_log += f" - Local URL: {self._protocol}://localhost:{self._port}" - host_ip = os.getenv("HOST_IP") - if host_ip: - network_url = f"{self._protocol}://{host_ip}:{self._port}" - startup_log += f"\n - Network URL: {network_url}" + startup_log += f" - Local URL: {self.local_url}" + if os.getenv("HOST_IP"): + startup_log += f"\n - Network URL: {self.url}" logger.info(startup_log) try: @@ -209,7 +230,8 @@ def send_message(self, message_type: str, message: dict | str, room: str | None logger.exception(f"Failed to send WebSocket message '{message_type}': {e}") async def _on_startup(self): - """This function is called by uvicorn when the server starts up, it is necessary to capture the running + """ + This function is called by uvicorn when the server starts up, it is necessary to capture the running asyncio event loop and reuse it later for emitting socket.io events as it requires an asyncio context. """ self._server_loop = asyncio.get_running_loop() From fb2c14631da426654c7b5ec75ec01b717dcbb666 Mon Sep 17 00:00:00 2001 From: Roberto Gazia Date: Fri, 28 Nov 2025 13:33:22 +0100 Subject: [PATCH 04/12] refactor: remove deprecation notice --- src/arduino/app_bricks/web_ui/web_ui.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/arduino/app_bricks/web_ui/web_ui.py b/src/arduino/app_bricks/web_ui/web_ui.py index 1f04ac38..e3025f45 100644 --- a/src/arduino/app_bricks/web_ui/web_ui.py +++ b/src/arduino/app_bricks/web_ui/web_ui.py @@ -5,6 +5,7 @@ import os import asyncio import threading +from contextlib import asynccontextmanager from typing import Any from collections.abc import Callable @@ -54,7 +55,13 @@ def __init__( if use_ssl is not None: logger.warning("'use_ssl' parameter is deprecated. Use 'use_tls' instead.") use_tls = use_ssl - self.app = FastAPI(title=__name__, openapi_url=None, on_startup=[self._on_startup]) + + @asynccontextmanager + async def lifespan(app): + await self._on_startup() + yield + + self.app = FastAPI(title=__name__, openapi_url=None, lifespan=lifespan) self.sio = SocketManager(app=self.app, mount_location="/socket.io", socketio_path="", max_http_buffer_size=10 * 1024 * 1024) self._addr = addr From 404917c6ddef38964fcc3e7700d55b60d1c311c6 Mon Sep 17 00:00:00 2001 From: Roberto Gazia Date: Fri, 28 Nov 2025 16:30:18 +0100 Subject: [PATCH 05/12] test: add basic tests for WebUI --- pyproject.toml | 1 + .../arduino/app_bricks/web_ui/test_web_ui.py | 63 ++++++++++++++ .../web_ui/test_web_ui_integration.py | 84 +++++++++++++++++++ 3 files changed, 148 insertions(+) create mode 100644 tests/arduino/app_bricks/web_ui/test_web_ui.py create mode 100644 tests/arduino/app_bricks/web_ui/test_web_ui_integration.py diff --git a/pyproject.toml b/pyproject.toml index 86765c9b..477f8080 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ dev = [ "setuptools", "build", "pytest", + "websocket-client", "ruff", "docstring_parser>=0.16", "arduino_app_bricks[all]", diff --git a/tests/arduino/app_bricks/web_ui/test_web_ui.py b/tests/arduino/app_bricks/web_ui/test_web_ui.py new file mode 100644 index 00000000..06e6ae05 --- /dev/null +++ b/tests/arduino/app_bricks/web_ui/test_web_ui.py @@ -0,0 +1,63 @@ +from fastapi.testclient import TestClient +from arduino.app_bricks.web_ui.web_ui import WebUI + + +def test_webui_init_defaults(): + ui = WebUI() + assert ui._addr == "0.0.0.0" + assert ui._port == 7000 + assert ui._ui_path_prefix == "" + assert ui._api_path_prefix == "" + assert ui._assets_dir_path.endswith("/app/assets") + assert ui._certs_dir_path.endswith("/app/certs") + assert ui._use_tls is False + assert ui._protocol == "http" + assert ui._server is None + assert ui._server_loop is None + +def test_webui_init_use_ssl_deprecated(): + webui = WebUI(use_ssl=True) + assert webui._use_tls is True + +def test_expose_api_route(): + ui = WebUI() + def dummy(): + return {"ok": True} + ui.expose_api("GET", "/dummy", dummy) + client = TestClient(ui.app) + response = client.get("/dummy") + assert response.status_code == 200 + assert response.json() == {"ok": True} + +def test_on_connect_and_disconnect(): + ui = WebUI() + called = {"connect": False, "disconnect": False} + def connect_cb(sid): + called["connect"] = True + def disconnect_cb(sid): + called["disconnect"] = True + ui.on_connect(connect_cb) + ui.on_disconnect(disconnect_cb) + assert ui._on_connect_cb == connect_cb + assert ui._on_disconnect_cb == disconnect_cb + +def test_on_message_registration(): + ui = WebUI() + def msg_cb(sid, data): + return "pong" + ui.on_message("ping", msg_cb) + assert "ping" in ui._on_message_cbs + assert ui._on_message_cbs["ping"] == msg_cb + +def test_send_message_no_loop(): + ui = WebUI() + ui.send_message("test", {"msg": "hi"}) # Should not raise + +def test_stop_sets_should_exit(): + import unittest.mock + ui = WebUI() + dummy_server = unittest.mock.Mock() + dummy_server.should_exit = False + ui._server = dummy_server + ui.stop() + assert dummy_server.should_exit is True diff --git a/tests/arduino/app_bricks/web_ui/test_web_ui_integration.py b/tests/arduino/app_bricks/web_ui/test_web_ui_integration.py new file mode 100644 index 00000000..ea2727c6 --- /dev/null +++ b/tests/arduino/app_bricks/web_ui/test_web_ui_integration.py @@ -0,0 +1,84 @@ +import threading +import time +import pytest +import requests +import socketio +from arduino.app_bricks.web_ui.web_ui import WebUI + +@pytest.fixture(scope="module") +def webui_server(): + import tempfile + with tempfile.TemporaryDirectory() as tmp_assets: + import os + os.makedirs(tmp_assets, exist_ok=True) + with open(os.path.join(tmp_assets, "index.html"), "w") as f: + f.write("Hello") + + ui = WebUI(port=0, assets_dir_path=tmp_assets) + ui.start() + + thread = threading.Thread(target=ui.execute, daemon=True) + thread.start() + + time.sleep(1) # Wait for server to start + + yield ui + + ui.stop() + thread.join(timeout=2) + + +def test_http_index(webui_server): + resp = requests.get(f"{webui_server.url}/") + assert resp.status_code == 200 + assert "Hello" in resp.text + + +def test_expose_api_rest(webui_server): + def get_hello(): + return {"msg": "hello"} + webui_server.expose_api("GET", "/api/hello", get_hello) + def post_echo(data: dict): + return {"echo": data.get("value")} + webui_server.expose_api("POST", "/api/echo", post_echo) + + resp = requests.get(f"{webui_server.url}/api/hello") + assert resp.status_code == 200 + assert resp.json() == {"msg": "hello"} + + resp = requests.post(f"{webui_server.url}/api/echo", json={"value": "test123"}) + assert resp.status_code == 200 + assert resp.json() == {"echo": "test123"} + + +def test_websocket_exchange(webui_server): + sio = socketio.Client() + received = {} + test_done = threading.Event() + + @sio.event + def connect(): + received["connect"] = True + + @sio.event + def disconnect(): + received["disconnect"] = True + + def on_ping_response(data): + received["ping_response"] = data + test_done.set() + sio.on("ping_response", on_ping_response) + + # Register a ping handler on server + def ping_cb(sid, data): + return "pong" + webui_server.on_message("ping", ping_cb) + + sio.connect(f"{webui_server.url}", socketio_path="/socket.io") + sio.emit("ping", {"msg": "hi"}) + test_done.wait(timeout=2) + sio.disconnect() + + assert received.get("connect") is True + assert received.get("ping_response") == "pong" + assert received.get("disconnect") is True From b7ddd2bdd0c068b13edbea8d9a2ea4ff840238a3 Mon Sep 17 00:00:00 2001 From: Roberto Gazia Date: Fri, 28 Nov 2025 16:30:47 +0100 Subject: [PATCH 06/12] fix: fmt --- src/arduino/app_bricks/web_ui/web_ui.py | 15 +- src/arduino/app_utils/tls_cert_manager.py | 83 +++--- .../arduino/app_bricks/web_ui/test_web_ui.py | 14 + .../web_ui/test_web_ui_integration.py | 14 +- .../tls_cert_manager/test_tls_cert_manager.py | 247 ++++++++---------- 5 files changed, 176 insertions(+), 197 deletions(-) diff --git a/src/arduino/app_bricks/web_ui/web_ui.py b/src/arduino/app_bricks/web_ui/web_ui.py index e3025f45..ce1a8f01 100644 --- a/src/arduino/app_bricks/web_ui/web_ui.py +++ b/src/arduino/app_bricks/web_ui/web_ui.py @@ -60,16 +60,19 @@ def __init__( async def lifespan(app): await self._on_startup() yield - + self.app = FastAPI(title=__name__, openapi_url=None, lifespan=lifespan) self.sio = SocketManager(app=self.app, mount_location="/socket.io", socketio_path="", max_http_buffer_size=10 * 1024 * 1024) self._addr = addr + def pick_free_port(): import socket + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(('', 0)) + s.bind(("", 0)) return s.getsockname()[1] + self._port = port if port != 0 else pick_free_port() self._ui_path_prefix = ui_path_prefix self._api_path_prefix = api_path_prefix @@ -83,7 +86,7 @@ def pick_free_port(): self._on_disconnect_cb: Callable[[str], None] | None = None self._on_message_cbs = {} self._on_message_cbs_lock = threading.Lock() - + @property def local_url(self) -> str: """Get the locally addressable URL of the web server. @@ -123,11 +126,9 @@ def start(self): config = uvicorn.Config(self.app, host=self._addr, port=self._port, log_level="warning") if self._use_tls: from arduino.app_utils.tls_cert_manager import TLSCertificateManager + try: - cert_path, key_path = TLSCertificateManager.get_or_create_certificates( - certs_dir=self._certs_dir_path, - common_name=self._addr - ) + cert_path, key_path = TLSCertificateManager.get_or_create_certificates(certs_dir=self._certs_dir_path, common_name=self._addr) config.ssl_certfile = cert_path config.ssl_keyfile = key_path except Exception as e: diff --git a/src/arduino/app_utils/tls_cert_manager.py b/src/arduino/app_utils/tls_cert_manager.py index ef387723..76e93ac0 100644 --- a/src/arduino/app_utils/tls_cert_manager.py +++ b/src/arduino/app_utils/tls_cert_manager.py @@ -20,21 +20,21 @@ "locality_name": "Turin", "organization_name": "Arduino", "common_name": "0.0.0.0", - "validity_days": 365 + "validity_days": 365, } + class TLSCertificateManager: """Certificate manager for TLS certificates. - + This class handles certificate generation and retrieval on a Brick basis. By default, all bricks share certificates from the default directory (/app/certs). Components can use their own certificates by providing a different certs_dir path. """ - + _locks = {} _locks_lock = threading.Lock() - - + @classmethod def get_or_create_certificates( cls, @@ -44,16 +44,16 @@ def get_or_create_certificates( locality_name: str = DEFAULT_CERTS_PARAMS["locality_name"], organization_name: str = DEFAULT_CERTS_PARAMS["organization_name"], common_name: str = DEFAULT_CERTS_PARAMS["common_name"], - validity_days: int = DEFAULT_CERTS_PARAMS["validity_days"] + validity_days: int = DEFAULT_CERTS_PARAMS["validity_days"], ) -> tuple[str, str]: """Get or create TLS certificates at the specified path. - + By default, uses shared certificates in /app/certs. If a different certs_dir is provided, uses certificates specific to that directory (useful for brick-specific certificates). - + Concurrent access is managed to prevent race conditions when multiple bricks attempt to access certificates simultaneously. - + Args: certs_dir (str, optional): Directory for certificates. Defaults to /app/certs (shared by all bricks). Provide a different path for brick-specific certificates. @@ -64,17 +64,17 @@ def get_or_create_certificates( organization_name (str, optional): Organization name for the certificate. Defaults to "Arduino". common_name (str, optional): Common name for the certificate. Defaults to "0.0.0.0". validity_days (int, optional): Certificate validity period in days. Defaults to 365. - + Returns: tuple[str, str]: Paths to (certificate_file, private_key_file) - + Raises: RuntimeError: If certificate generation fails. """ target_dir = certs_dir or DEFAULT_CERTS_DIR cert_path = os.path.join(target_dir, "cert.pem") key_path = os.path.join(target_dir, "key.pem") - + if cls.certificates_exist(target_dir): return cert_path, key_path @@ -82,29 +82,23 @@ def get_or_create_certificates( with dir_lock: if cls.certificates_exist(target_dir): return cert_path, key_path - + try: cls._generate_self_signed_cert( - target_dir, - country_name, - state_or_province_name, - locality_name, - organization_name, - common_name, - validity_days + target_dir, country_name, state_or_province_name, locality_name, organization_name, common_name, validity_days ) return cert_path, key_path except Exception as e: raise RuntimeError(f"Failed to generate TLS certificates in {target_dir}: {e}") from e - + @classmethod def certificates_exist(cls, certs_dir: str = DEFAULT_CERTS_DIR) -> bool: """Check if TLS certificates exist in the given directory. - + Args: certs_dir (str, optional): Directory for certificates. Defaults to /app/certs. - + Returns: bool: True if both certificate and key files exist, False otherwise. """ @@ -112,52 +106,52 @@ def certificates_exist(cls, certs_dir: str = DEFAULT_CERTS_DIR) -> bool: cert_path = os.path.join(target_dir, "cert.pem") key_path = os.path.join(target_dir, "key.pem") return os.path.exists(cert_path) and os.path.exists(key_path) - + @classmethod def get_certificates_paths(cls, certs_dir: str = DEFAULT_CERTS_DIR) -> tuple[str, str]: """Get the paths to the TLS certificate and private key files. - + Args: certs_dir (str, optional): Directory for certificates. Defaults to /app/certs. Returns: tuple[str, str]: Paths to certificate_file and private_key_file """ return cls.get_certificate_path(certs_dir), cls.get_private_key_path(certs_dir) - + @classmethod def get_certificate_path(cls, certs_dir: str = DEFAULT_CERTS_DIR) -> str: """Get the path to the TLS certificate file. - + Args: certs_dir (str, optional): Directory for certificates. Defaults to /app/certs. - + Returns: str: Path to the certificate file. """ return os.path.join(certs_dir or DEFAULT_CERTS_DIR, "cert.pem") - + @classmethod def get_private_key_path(cls, certs_dir: str = DEFAULT_CERTS_DIR) -> str: """Get the path to the TLS private key file. - + Args: certs_dir (str, optional): Directory for certificates. Defaults to /app/certs. - + Returns: str: Path to the private key file. """ return os.path.join(certs_dir or DEFAULT_CERTS_DIR, "key.pem") - + @classmethod def _get_dir_lock(cls, target_dir: str) -> threading.Lock: """Get or create a lock for a specific directory. - + This ensures that only operations on the same directory block each other, while operations on different directories can proceed concurrently. - + Args: target_dir (str): The normalized absolute path to the directory. - + Returns: threading.Lock: A lock specific to this directory. """ @@ -165,7 +159,7 @@ def _get_dir_lock(cls, target_dir: str) -> threading.Lock: if target_dir not in cls._locks: cls._locks[target_dir] = threading.Lock() return cls._locks[target_dir] - + @staticmethod def _generate_self_signed_cert( target_dir: str, @@ -174,14 +168,14 @@ def _generate_self_signed_cert( locality_name: str, organization_name: str, common_name: str, - validity_days: int + validity_days: int, ): # Generate a private key private_key = rsa.generate_private_key( public_exponent=65537, key_size=2048, ) - + # Generate a self-signed certificate subject = issuer = x509.Name([ x509.NameAttribute(NameOID.COUNTRY_NAME, country_name), @@ -190,7 +184,7 @@ def _generate_self_signed_cert( x509.NameAttribute(NameOID.ORGANIZATION_NAME, organization_name), x509.NameAttribute(NameOID.COMMON_NAME, common_name), ]) - + cert = x509.CertificateBuilder() cert = cert.subject_name(subject) cert = cert.issuer_name(issuer) @@ -198,19 +192,16 @@ def _generate_self_signed_cert( cert = cert.serial_number(x509.random_serial_number()) cert = cert.not_valid_before(datetime.now(UTC)) cert = cert.not_valid_after(datetime.now(UTC) + timedelta(days=validity_days)) - cert = cert.add_extension( - x509.SubjectAlternativeName([x509.DNSName(common_name)]), - critical=False - ) + cert = cert.add_extension(x509.SubjectAlternativeName([x509.DNSName(common_name)]), critical=False) cert = cert.sign(private_key, hashes.SHA256()) - + Path(target_dir).mkdir(parents=True, exist_ok=True) - + # Write the certificate to a PEM file cert_path = os.path.join(target_dir, "cert.pem") with open(cert_path, "wb") as cert_file: cert_file.write(cert.public_bytes(serialization.Encoding.PEM)) - + # Write the private key to a PEM file key_path = os.path.join(target_dir, "key.pem") with open(key_path, "wb") as key_file: diff --git a/tests/arduino/app_bricks/web_ui/test_web_ui.py b/tests/arduino/app_bricks/web_ui/test_web_ui.py index 06e6ae05..ccdbcd18 100644 --- a/tests/arduino/app_bricks/web_ui/test_web_ui.py +++ b/tests/arduino/app_bricks/web_ui/test_web_ui.py @@ -15,46 +15,60 @@ def test_webui_init_defaults(): assert ui._server is None assert ui._server_loop is None + def test_webui_init_use_ssl_deprecated(): webui = WebUI(use_ssl=True) assert webui._use_tls is True + def test_expose_api_route(): ui = WebUI() + def dummy(): return {"ok": True} + ui.expose_api("GET", "/dummy", dummy) client = TestClient(ui.app) response = client.get("/dummy") assert response.status_code == 200 assert response.json() == {"ok": True} + def test_on_connect_and_disconnect(): ui = WebUI() called = {"connect": False, "disconnect": False} + def connect_cb(sid): called["connect"] = True + def disconnect_cb(sid): called["disconnect"] = True + ui.on_connect(connect_cb) ui.on_disconnect(disconnect_cb) assert ui._on_connect_cb == connect_cb assert ui._on_disconnect_cb == disconnect_cb + def test_on_message_registration(): ui = WebUI() + def msg_cb(sid, data): return "pong" + ui.on_message("ping", msg_cb) assert "ping" in ui._on_message_cbs assert ui._on_message_cbs["ping"] == msg_cb + def test_send_message_no_loop(): ui = WebUI() ui.send_message("test", {"msg": "hi"}) # Should not raise + def test_stop_sets_should_exit(): import unittest.mock + ui = WebUI() dummy_server = unittest.mock.Mock() dummy_server.should_exit = False diff --git a/tests/arduino/app_bricks/web_ui/test_web_ui_integration.py b/tests/arduino/app_bricks/web_ui/test_web_ui_integration.py index ea2727c6..e221efe6 100644 --- a/tests/arduino/app_bricks/web_ui/test_web_ui_integration.py +++ b/tests/arduino/app_bricks/web_ui/test_web_ui_integration.py @@ -5,25 +5,28 @@ import socketio from arduino.app_bricks.web_ui.web_ui import WebUI + @pytest.fixture(scope="module") def webui_server(): import tempfile + with tempfile.TemporaryDirectory() as tmp_assets: import os + os.makedirs(tmp_assets, exist_ok=True) with open(os.path.join(tmp_assets, "index.html"), "w") as f: f.write("Hello") - + ui = WebUI(port=0, assets_dir_path=tmp_assets) ui.start() - + thread = threading.Thread(target=ui.execute, daemon=True) thread.start() time.sleep(1) # Wait for server to start yield ui - + ui.stop() thread.join(timeout=2) @@ -37,9 +40,12 @@ def test_http_index(webui_server): def test_expose_api_rest(webui_server): def get_hello(): return {"msg": "hello"} + webui_server.expose_api("GET", "/api/hello", get_hello) + def post_echo(data: dict): return {"echo": data.get("value")} + webui_server.expose_api("POST", "/api/echo", post_echo) resp = requests.get(f"{webui_server.url}/api/hello") @@ -67,11 +73,13 @@ def disconnect(): def on_ping_response(data): received["ping_response"] = data test_done.set() + sio.on("ping_response", on_ping_response) # Register a ping handler on server def ping_cb(sid, data): return "pong" + webui_server.on_message("ping", ping_cb) sio.connect(f"{webui_server.url}", socketio_path="/socket.io") diff --git a/tests/arduino/app_utils/tls_cert_manager/test_tls_cert_manager.py b/tests/arduino/app_utils/tls_cert_manager/test_tls_cert_manager.py index 55e2e914..ec3472e6 100644 --- a/tests/arduino/app_utils/tls_cert_manager/test_tls_cert_manager.py +++ b/tests/arduino/app_utils/tls_cert_manager/test_tls_cert_manager.py @@ -30,99 +30,87 @@ def temp_certs_dir(): def reset_manager(): """Reset the TLSCertificateManager state between tests.""" yield - + TLSCertificateManager._locks.clear() # Reset state class TestBasicFunctionality: """Test basic certificate creation and retrieval.""" - + def test_create_certificates_in_custom_dir(self, temp_certs_dir, reset_manager): """Test creating certificates in a custom directory.""" - cert_path, key_path = TLSCertificateManager.get_or_create_certificates( - certs_dir=temp_certs_dir - ) - + cert_path, key_path = TLSCertificateManager.get_or_create_certificates(certs_dir=temp_certs_dir) + # Verify paths are correct assert cert_path == os.path.join(temp_certs_dir, "cert.pem") assert key_path == os.path.join(temp_certs_dir, "key.pem") - + # Verify files exist assert os.path.exists(cert_path) assert os.path.exists(key_path) - + def test_certificates_are_valid(self, temp_certs_dir, reset_manager): """Test that generated certificates are valid X.509 certificates.""" - cert_path, key_path = TLSCertificateManager.get_or_create_certificates( - certs_dir=temp_certs_dir, - common_name="test.local" - ) - + cert_path, key_path = TLSCertificateManager.get_or_create_certificates(certs_dir=temp_certs_dir, common_name="test.local") + # Load and verify certificate with open(cert_path, "rb") as f: cert = x509.load_pem_x509_certificate(f.read(), default_backend()) - + # Check common name common_name = cert.subject.get_attributes_for_oid(x509.NameOID.COMMON_NAME)[0].value assert common_name == "test.local" - + # Check organization org = cert.subject.get_attributes_for_oid(x509.NameOID.ORGANIZATION_NAME)[0].value assert org == "Arduino" - + def test_reuse_existing_certificates(self, temp_certs_dir, reset_manager): """Test that existing certificates are reused instead of regenerated.""" - cert_path1, key_path1 = TLSCertificateManager.get_or_create_certificates( - certs_dir=temp_certs_dir - ) - + cert_path1, key_path1 = TLSCertificateManager.get_or_create_certificates(certs_dir=temp_certs_dir) + # Get modification time mtime1 = os.path.getmtime(cert_path1) - + # Get certificates again - cert_path2, key_path2 = TLSCertificateManager.get_or_create_certificates( - certs_dir=temp_certs_dir - ) - + cert_path2, key_path2 = TLSCertificateManager.get_or_create_certificates(certs_dir=temp_certs_dir) + assert cert_path1 == cert_path2 assert key_path1 == key_path2 - + # Check modification time is unchanged mtime2 = os.path.getmtime(cert_path2) assert mtime1 == mtime2 - + def test_custom_validity_period(self, temp_certs_dir, reset_manager): """Test creating certificates with custom validity period.""" - cert_path, _ = TLSCertificateManager.get_or_create_certificates( - certs_dir=temp_certs_dir, - validity_days=1 - ) - + cert_path, _ = TLSCertificateManager.get_or_create_certificates(certs_dir=temp_certs_dir, validity_days=1) + with open(cert_path, "rb") as f: cert = x509.load_pem_x509_certificate(f.read(), default_backend()) - + validity_days = (cert.not_valid_after_utc - cert.not_valid_before_utc).days assert validity_days == 1 class TestHelperMethods: """Test helper methods for checking and retrieving certificate paths.""" - + def test_certificates_exist_returns_false_for_missing(self, temp_certs_dir, reset_manager): """Test certificates_exist returns False when certificates don't exist.""" assert not TLSCertificateManager.certificates_exist(certs_dir=temp_certs_dir) - + def test_certificates_exist_returns_true_after_creation(self, temp_certs_dir, reset_manager): """Test certificates_exist returns True after certificates are created.""" TLSCertificateManager.get_or_create_certificates(certs_dir=temp_certs_dir) assert TLSCertificateManager.certificates_exist(certs_dir=temp_certs_dir) - + def test_get_certificate_path(self, temp_certs_dir, reset_manager): """Test get_certificate_path returns correct path.""" expected_path = os.path.join(temp_certs_dir, "cert.pem") actual_path = TLSCertificateManager.get_certificate_path(certs_dir=temp_certs_dir) assert actual_path == expected_path - + def test_get_private_key_path(self, temp_certs_dir, reset_manager): """Test get_private_key_path returns correct path.""" expected_path = os.path.join(temp_certs_dir, "key.pem") @@ -132,121 +120,107 @@ def test_get_private_key_path(self, temp_certs_dir, reset_manager): class TestConcurrentAccess: """Test concurrent access and race condition handling.""" - + def test_concurrent_access_same_directory(self, temp_certs_dir, reset_manager): """Test multiple threads accessing the same directory concurrently.""" results = [] errors = [] - + def create_certs(thread_id): try: start_time = time.time() - cert_path, key_path = TLSCertificateManager.get_or_create_certificates( - certs_dir=temp_certs_dir - ) + cert_path, key_path = TLSCertificateManager.get_or_create_certificates(certs_dir=temp_certs_dir) elapsed = time.time() - start_time - results.append({ - 'thread_id': thread_id, - 'cert_path': cert_path, - 'key_path': key_path, - 'elapsed': elapsed - }) + results.append({"thread_id": thread_id, "cert_path": cert_path, "key_path": key_path, "elapsed": elapsed}) except Exception as e: - errors.append({'thread_id': thread_id, 'error': str(e)}) - + errors.append({"thread_id": thread_id, "error": str(e)}) + # Start 10 threads simultaneously threads = [] for i in range(10): thread = threading.Thread(target=create_certs, args=(i,)) threads.append(thread) - + # Start all threads at once for thread in threads: thread.start() - + # Wait for all to complete for thread in threads: thread.join() - + # Verify no errors occurred assert len(errors) == 0, f"Errors occurred: {errors}" - + # Verify all threads got the same paths assert len(results) == 10 - cert_paths = set(r['cert_path'] for r in results) - key_paths = set(r['key_path'] for r in results) + cert_paths = set(r["cert_path"] for r in results) + key_paths = set(r["key_path"] for r in results) assert len(cert_paths) == 1, "All threads should get the same certificate path" assert len(key_paths) == 1, "All threads should get the same key path" - + # Verify certificates exist and are valid - cert_path = results[0]['cert_path'] + cert_path = results[0]["cert_path"] assert os.path.exists(cert_path) with open(cert_path, "rb") as f: cert = x509.load_pem_x509_certificate(f.read(), default_backend()) assert cert is not None - + def test_concurrent_access_different_directories(self, temp_certs_dir, reset_manager): """Test multiple threads accessing different directories concurrently.""" results = [] errors = [] - + def create_certs(component_name): try: start_time = time.time() component_dir = os.path.join(temp_certs_dir, component_name) - cert_path, key_path = TLSCertificateManager.get_or_create_certificates( - certs_dir=component_dir - ) + cert_path, key_path = TLSCertificateManager.get_or_create_certificates(certs_dir=component_dir) elapsed = time.time() - start_time - results.append({ - 'component': component_name, - 'cert_path': cert_path, - 'key_path': key_path, - 'elapsed': elapsed - }) + results.append({"component": component_name, "cert_path": cert_path, "key_path": key_path, "elapsed": elapsed}) except Exception as e: - errors.append({'component': component_name, 'error': str(e)}) - + errors.append({"component": component_name, "error": str(e)}) + # Simulate multiple components starting simultaneously - components = ['webui', 'api', 'mqtt', 'scanner', 'processor'] + components = ["webui", "api", "mqtt", "scanner", "processor"] threads = [] - + for component in components: thread = threading.Thread(target=create_certs, args=(component,)) threads.append(thread) - + # Start all threads for thread in threads: thread.start() - + # Wait for completion for thread in threads: thread.join() - + # Verify no errors assert len(errors) == 0, f"Errors occurred: {errors}" - + # Verify all components succeeded assert len(results) == len(components) - + # Verify each component has its own certificates - cert_dirs = set(os.path.dirname(r['cert_path']) for r in results) + cert_dirs = set(os.path.dirname(r["cert_path"]) for r in results) assert len(cert_dirs) == len(components), "Each component should have its own directory" - + # Verify all certificates exist and are in correct directories for result in results: - component = result['component'] + component = result["component"] expected_dir = os.path.join(temp_certs_dir, component) - assert expected_dir in result['cert_path'] - assert os.path.exists(result['cert_path']) - assert os.path.exists(result['key_path']) - + assert expected_dir in result["cert_path"] + assert os.path.exists(result["cert_path"]) + assert os.path.exists(result["key_path"]) + def test_concurrent_mixed_access(self, temp_certs_dir, reset_manager): """Test concurrent access with both shared and component-specific directories.""" results = [] errors = [] lock = threading.Lock() - + def create_certs(name, use_custom_dir): try: start_time = time.time() @@ -254,104 +228,93 @@ def create_certs(name, use_custom_dir): certs_dir = os.path.join(temp_certs_dir, name) else: certs_dir = temp_certs_dir - - cert_path, key_path = TLSCertificateManager.get_or_create_certificates( - certs_dir=certs_dir - ) + + cert_path, key_path = TLSCertificateManager.get_or_create_certificates(certs_dir=certs_dir) elapsed = time.time() - start_time - + with lock: - results.append({ - 'name': name, - 'use_custom': use_custom_dir, - 'cert_path': cert_path, - 'elapsed': elapsed - }) + results.append({"name": name, "use_custom": use_custom_dir, "cert_path": cert_path, "elapsed": elapsed}) except Exception as e: with lock: - errors.append({'name': name, 'error': str(e)}) - + errors.append({"name": name, "error": str(e)}) + # Mix of shared and custom directory access configs = [ - ('webui', False), # Shared - ('api', False), # Shared - ('mqtt', True), # Custom - ('scanner', True), # Custom - ('backup', False), # Shared - ('processor', True), # Custom + ("webui", False), # Shared + ("api", False), # Shared + ("mqtt", True), # Custom + ("scanner", True), # Custom + ("backup", False), # Shared + ("processor", True), # Custom ] - + threads = [] for name, use_custom in configs: thread = threading.Thread(target=create_certs, args=(name, use_custom)) threads.append(thread) - + for thread in threads: thread.start() for thread in threads: thread.join() - + # Verify no errors assert len(errors) == 0, f"Errors occurred: {errors}" assert len(results) == len(configs) - + # Verify shared components use the same certificates - shared_certs = [r for r in results if not r['use_custom']] - shared_paths = set(r['cert_path'] for r in shared_certs) + shared_certs = [r for r in results if not r["use_custom"]] + shared_paths = set(r["cert_path"] for r in shared_certs) assert len(shared_paths) == 1, "Shared components should use same certificates" - + # Verify custom components have unique certificates - custom_certs = [r for r in results if r['use_custom']] - custom_paths = set(r['cert_path'] for r in custom_certs) + custom_certs = [r for r in results if r["use_custom"]] + custom_paths = set(r["cert_path"] for r in custom_certs) assert len(custom_paths) == len(custom_certs), "Custom components should have unique certificates" class TestDirectoryCreation: """Test automatic directory creation.""" - + def test_creates_missing_directory(self, temp_certs_dir, reset_manager): """Test that missing directories are created automatically.""" nested_dir = os.path.join(temp_certs_dir, "deeply", "nested", "path") assert not os.path.exists(nested_dir) - - cert_path, key_path = TLSCertificateManager.get_or_create_certificates( - certs_dir=nested_dir - ) - + + cert_path, key_path = TLSCertificateManager.get_or_create_certificates(certs_dir=nested_dir) + assert os.path.exists(nested_dir) assert os.path.exists(cert_path) assert os.path.exists(key_path) - + def test_handles_existing_directory(self, temp_certs_dir, reset_manager): """Test that existing directories are handled correctly.""" # Pre-create the directory os.makedirs(temp_certs_dir, exist_ok=True) - - cert_path, key_path = TLSCertificateManager.get_or_create_certificates( - certs_dir=temp_certs_dir - ) - + + cert_path, key_path = TLSCertificateManager.get_or_create_certificates(certs_dir=temp_certs_dir) + assert os.path.exists(cert_path) assert os.path.exists(key_path) class TestErrorHandling: """Test error handling scenarios.""" - + def test_invalid_directory_permissions(self, reset_manager): """Test handling of directories with invalid permissions.""" # This test is platform-specific and may need adjustment - if os.name != 'posix': + if os.name != "posix": pytest.skip("Permission test only applicable on POSIX systems") - + temp_dir = tempfile.mkdtemp() try: # Make directory read-only os.chmod(temp_dir, 0o444) - + with pytest.raises(RuntimeError) as exc_info: TLSCertificateManager.get_or_create_certificates(certs_dir=temp_dir) - + assert "Failed to generate TLS certificates" in str(exc_info.value) finally: # Restore permissions for cleanup @@ -361,23 +324,23 @@ def test_invalid_directory_permissions(self, reset_manager): class TestPerformance: """Test performance characteristics.""" - + def test_fast_path_no_lock_overhead(self, temp_certs_dir, reset_manager): """Test that retrieving existing certificates is fast (no lock acquisition).""" # Create certificates first TLSCertificateManager.get_or_create_certificates(certs_dir=temp_certs_dir) - + # Measure retrieval time iterations = 100 start = time.time() for _ in range(iterations): TLSCertificateManager.get_or_create_certificates(certs_dir=temp_certs_dir) elapsed = time.time() - start - + # Should be very fast (< 1ms per call on average) avg_time = elapsed / iterations assert avg_time < 0.001, f"Average retrieval time too slow: {avg_time:.6f}s" - + def test_concurrent_different_dirs_no_blocking(self, temp_certs_dir, reset_manager): """Test that different directories don't block each other significantly.""" total_elapsed_lock = threading.Lock() @@ -385,24 +348,26 @@ def test_concurrent_different_dirs_no_blocking(self, temp_certs_dir, reset_manag def create_certs(brick_name): brick_dir = os.path.join(temp_certs_dir, brick_name) - + start = time.time() TLSCertificateManager.get_or_create_certificates(certs_dir=brick_dir) elapsed = time.time() - start with total_elapsed_lock: elapsed_times.append(elapsed) - - bricks = ['brick1', 'brick2', 'brick3', 'brick4'] + + bricks = ["brick1", "brick2", "brick3", "brick4"] threads = [threading.Thread(target=create_certs, args=(c,)) for c in bricks] - + start = time.time() for thread in threads: thread.start() for thread in threads: thread.join() overall_run_time = time.time() - start - + # If bricks truly run in parallel, total time should be lower than # the cumulative total run times by all threads - assert overall_run_time < sum(elapsed_times), f"Bricks blocked each other: {overall_run_time:.3f}s should be lower than {sum(elapsed_times):.3f}s" + assert overall_run_time < sum(elapsed_times), ( + f"Bricks blocked each other: {overall_run_time:.3f}s should be lower than {sum(elapsed_times):.3f}s" + ) From 0c98f4793e38645b3cbb144481d7b67c8c0d4485 Mon Sep 17 00:00:00 2001 From: Roberto Gazia Date: Fri, 28 Nov 2025 16:39:46 +0100 Subject: [PATCH 07/12] Potential fix for code scanning alert no. 4: Binding a socket to all network interfaces Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> --- src/arduino/app_bricks/web_ui/web_ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/arduino/app_bricks/web_ui/web_ui.py b/src/arduino/app_bricks/web_ui/web_ui.py index ce1a8f01..f031d1bc 100644 --- a/src/arduino/app_bricks/web_ui/web_ui.py +++ b/src/arduino/app_bricks/web_ui/web_ui.py @@ -70,7 +70,7 @@ def pick_free_port(): import socket with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", 0)) + s.bind(("127.0.0.1", 0)) return s.getsockname()[1] self._port = port if port != 0 else pick_free_port() From 2d9ba7e0c5e51075db8bc8bd8d4b48e69cc40d9f Mon Sep 17 00:00:00 2001 From: Roberto Gazia Date: Fri, 28 Nov 2025 16:54:54 +0100 Subject: [PATCH 08/12] fix: tests --- .../app_bricks/web_ui/test_web_ui_integration.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/tests/arduino/app_bricks/web_ui/test_web_ui_integration.py b/tests/arduino/app_bricks/web_ui/test_web_ui_integration.py index e221efe6..218d4fa6 100644 --- a/tests/arduino/app_bricks/web_ui/test_web_ui_integration.py +++ b/tests/arduino/app_bricks/web_ui/test_web_ui_integration.py @@ -18,6 +18,12 @@ def webui_server(): f.write("Hello") ui = WebUI(port=0, assets_dir_path=tmp_assets) + def get_hello(): + return {"msg": "hello"} + ui.expose_api("GET", "/api/hello", get_hello) + def post_echo(data: dict): + return {"echo": data.get("value")} + ui.expose_api("POST", "/api/echo", post_echo) ui.start() thread = threading.Thread(target=ui.execute, daemon=True) @@ -38,16 +44,6 @@ def test_http_index(webui_server): def test_expose_api_rest(webui_server): - def get_hello(): - return {"msg": "hello"} - - webui_server.expose_api("GET", "/api/hello", get_hello) - - def post_echo(data: dict): - return {"echo": data.get("value")} - - webui_server.expose_api("POST", "/api/echo", post_echo) - resp = requests.get(f"{webui_server.url}/api/hello") assert resp.status_code == 200 assert resp.json() == {"msg": "hello"} From b9c98e5f43aa858b874ca1004111ad77ba643e6c Mon Sep 17 00:00:00 2001 From: Roberto Gazia Date: Fri, 28 Nov 2025 16:59:22 +0100 Subject: [PATCH 09/12] fix: linting --- tests/arduino/app_bricks/web_ui/test_web_ui_integration.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/arduino/app_bricks/web_ui/test_web_ui_integration.py b/tests/arduino/app_bricks/web_ui/test_web_ui_integration.py index 218d4fa6..1c67a2ed 100644 --- a/tests/arduino/app_bricks/web_ui/test_web_ui_integration.py +++ b/tests/arduino/app_bricks/web_ui/test_web_ui_integration.py @@ -18,11 +18,15 @@ def webui_server(): f.write("Hello") ui = WebUI(port=0, assets_dir_path=tmp_assets) + def get_hello(): return {"msg": "hello"} + ui.expose_api("GET", "/api/hello", get_hello) + def post_echo(data: dict): return {"echo": data.get("value")} + ui.expose_api("POST", "/api/echo", post_echo) ui.start() From dbb9e10a4e98a49214076120328454ad1d626d84 Mon Sep 17 00:00:00 2001 From: Roberto Gazia Date: Fri, 28 Nov 2025 17:03:51 +0100 Subject: [PATCH 10/12] fix: add license headers --- tests/arduino/app_bricks/web_ui/test_web_ui.py | 4 ++++ tests/arduino/app_bricks/web_ui/test_web_ui_integration.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/tests/arduino/app_bricks/web_ui/test_web_ui.py b/tests/arduino/app_bricks/web_ui/test_web_ui.py index ccdbcd18..d1f9a9ee 100644 --- a/tests/arduino/app_bricks/web_ui/test_web_ui.py +++ b/tests/arduino/app_bricks/web_ui/test_web_ui.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: Copyright (C) ARDUINO SRL (http://www.arduino.cc) +# +# SPDX-License-Identifier: MPL-2.0 + from fastapi.testclient import TestClient from arduino.app_bricks.web_ui.web_ui import WebUI diff --git a/tests/arduino/app_bricks/web_ui/test_web_ui_integration.py b/tests/arduino/app_bricks/web_ui/test_web_ui_integration.py index 1c67a2ed..5e4c4585 100644 --- a/tests/arduino/app_bricks/web_ui/test_web_ui_integration.py +++ b/tests/arduino/app_bricks/web_ui/test_web_ui_integration.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: Copyright (C) ARDUINO SRL (http://www.arduino.cc) +# +# SPDX-License-Identifier: MPL-2.0 + import threading import time import pytest From 5a069ddbf1e8d9fec6a18e5407544aaf990cd9a3 Mon Sep 17 00:00:00 2001 From: Roberto Gazia Date: Fri, 28 Nov 2025 17:09:20 +0100 Subject: [PATCH 11/12] chore: remove log line --- src/arduino/app_bricks/web_ui/web_ui.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/arduino/app_bricks/web_ui/web_ui.py b/src/arduino/app_bricks/web_ui/web_ui.py index f031d1bc..906c92f6 100644 --- a/src/arduino/app_bricks/web_ui/web_ui.py +++ b/src/arduino/app_bricks/web_ui/web_ui.py @@ -132,7 +132,6 @@ def start(self): config.ssl_certfile = cert_path config.ssl_keyfile = key_path except Exception as e: - logger.exception(f"Failed to configure SSL certificate: {e}") raise RuntimeError("Failed to configure TLS certificate. Please check the certs directory.") from e self._server = uvicorn.Server(config) From 791bc95bddffc92569f983fe9e21d6d37c755e23 Mon Sep 17 00:00:00 2001 From: Roberto Gazia Date: Fri, 28 Nov 2025 18:47:48 +0100 Subject: [PATCH 12/12] refactor: align behavior with other methods --- src/arduino/app_utils/tls_cert_manager.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/arduino/app_utils/tls_cert_manager.py b/src/arduino/app_utils/tls_cert_manager.py index 76e93ac0..a0a939b0 100644 --- a/src/arduino/app_utils/tls_cert_manager.py +++ b/src/arduino/app_utils/tls_cert_manager.py @@ -116,7 +116,8 @@ def get_certificates_paths(cls, certs_dir: str = DEFAULT_CERTS_DIR) -> tuple[str Returns: tuple[str, str]: Paths to certificate_file and private_key_file """ - return cls.get_certificate_path(certs_dir), cls.get_private_key_path(certs_dir) + target_dir = certs_dir or DEFAULT_CERTS_DIR + return cls.get_certificate_path(target_dir), cls.get_private_key_path(target_dir) @classmethod def get_certificate_path(cls, certs_dir: str = DEFAULT_CERTS_DIR) -> str: