diff --git a/flame/star/star_model_tester.py b/flame/star/star_model_tester.py index 6d15d14..39502de 100644 --- a/flame/star/star_model_tester.py +++ b/flame/star/star_model_tester.py @@ -2,8 +2,10 @@ import threading import uuid from typing import Any, Type, Literal, Optional, Union +import traceback from flame.star import StarModel, StarLocalDPModel, StarAnalyzer, StarAggregator +from flame.utils.mock_flame_core import MockFlameCoreSDK class StarModelTester: @@ -28,6 +30,9 @@ def __init__(self, participant_ids = [str(uuid.uuid4()) for _ in range(len(node_roles) + 1)] threads = [] + thread_errors = {} + results_queue = [] + MockFlameCoreSDK.stop_event = [] # shared stop event for all threads in case of failure in any thread for i, participant_id in enumerate(participant_ids): test_kwargs = { 'analyzer': analyzer, @@ -54,13 +59,28 @@ def __init__(self, test_kwargs['epsilon'] = epsilon test_kwargs['sensitivity'] = sensitivity - results_queue = [] def run_node(kwargs=test_kwargs, use_dp=use_local_dp): - if not use_dp: - flame = StarModel(**kwargs).flame - else: - flame = StarLocalDPModel(**kwargs).flame - results_queue.append(flame.final_results_storage) + try: + if not use_dp: + flame = StarModel(**kwargs).flame + else: + flame = StarLocalDPModel(**kwargs).flame + results_queue.append(flame.final_results_storage) + except Exception: + stop_event = MockFlameCoreSDK.stop_event + if not stop_event: + stack_trace = traceback.format_exc()#.replace('\n', '\\n').replace('\t', '\\t') + thread_errors[(kwargs['test_kwargs']['role'], + kwargs['test_kwargs']['node_id'])] = f"\033[31m{stack_trace}\033[0m" + stop_event.append(kwargs['test_kwargs']['node_id']) + mock = MockFlameCoreSDK(test_kwargs=kwargs['test_kwargs']) + mock.__pop_logs__(failure_message=True) + else: + thread_errors[(kwargs['test_kwargs']['role'], + kwargs['test_kwargs']['node_id'])] = (Exception("Another thread already failed, " + "stopping this thread as well.")) + return + thread = threading.Thread(target=run_node) threads.append(thread) @@ -70,8 +90,14 @@ def run_node(kwargs=test_kwargs, use_dp=use_local_dp): for thread in threads: thread.join() + # write final results - self.write_result(results_queue[0], output_type, result_filepath, multiple_results) + if results_queue: + self.write_result(results_queue[0], output_type, result_filepath, multiple_results) + else: + print("No results to write. All threads failed with errors:") + for (role, node_id), error in thread_errors.items(): + print(f"\t{(role if role != 'default' else 'analyzer').capitalize()} {node_id}: {error}") @staticmethod diff --git a/flame/utils/mock_flame_core.py b/flame/utils/mock_flame_core.py index e14ca0a..8358a8c 100644 --- a/flame/utils/mock_flame_core.py +++ b/flame/utils/mock_flame_core.py @@ -44,11 +44,23 @@ def __init__(self, test_kwargs) -> None: self.finished: bool = False +class IterationTracker: + def __init__(self): + self.iter = 0 + + def increment(self): + self.iter += 1 + + def get_iterations(self): + return self.iter + + class MockFlameCoreSDK: - num_iterations: int = 0 + num_iterations: IterationTracker = IterationTracker() logger: dict[str, list[str]] = {} message_broker: dict[str, list[dict[str, Any]]] = {} final_results_storage: Optional[Any] = None + stop_event: list[tuple[str]] = [] def __init__(self, test_kwargs): self.sanity_check(test_kwargs) @@ -202,6 +214,8 @@ def await_messages(self, break raise KeyError except KeyError: + if self.stop_event: + raise Exception time.sleep(.01) pass @@ -323,12 +337,17 @@ def _node_finished(self) -> bool: self.config.finished = True return self.config.finished - def __pop_logs__(self) -> None: - print(f"--- Starting Iteration {self.num_iterations} ---") + def __pop_logs__(self, failure_message: bool = False) -> None: + print(f"--- Starting Iteration {self.__get_iteration__()} ---") + if failure_message: + self.flame_log("Exception was raised (see Stacktrace)!", log_type='error') for k, v in self.logger.items(): role, log = self.logger[k] print(f"Logs for {'Analyzer' if role == 'default' else role.capitalize()} {k}:") self.logger[k] = [role, ''] print(log, end='') - print(f"--- Ending Iteration {self.num_iterations} ---\n") - self.num_iterations += 1 + print(f"--- Ending Iteration {self.__get_iteration__()} ---\n") + self.num_iterations.increment() + + def __get_iteration__(self): + return self.num_iterations.get_iterations() diff --git a/poetry.lock b/poetry.lock index d31d24b..70092fd 100644 --- a/poetry.lock +++ b/poetry.lock @@ -199,20 +199,20 @@ files = [ [[package]] name = "filelock" -version = "3.25.1" +version = "3.25.2" description = "A platform independent file lock." optional = false python-versions = ">=3.10" groups = ["dev"] markers = "python_version >= \"3.10\"" files = [ - {file = "filelock-3.25.1-py3-none-any.whl", hash = "sha256:18972df45473c4aa2c7921b609ee9ca4925910cc3a0fb226c96b92fc224ef7bf"}, - {file = "filelock-3.25.1.tar.gz", hash = "sha256:b9a2e977f794ef94d77cdf7d27129ac648a61f585bff3ca24630c1629f701aa9"}, + {file = "filelock-3.25.2-py3-none-any.whl", hash = "sha256:ca8afb0da15f229774c9ad1b455ed96e85a81373065fb10446672f64444ddf70"}, + {file = "filelock-3.25.2.tar.gz", hash = "sha256:b64ece2b38f4ca29dd3e810287aa8c48182bbecd1ae6e9ae126c9b35f1382694"}, ] [[package]] name = "flamesdk" -version = "0.4.1.1" +version = "0.4.2" description = "" optional = false python-versions = ">=3.9" @@ -228,8 +228,8 @@ uvicorn = "^0.27.1" [package.source] type = "git" url = "https://github.com/PrivateAIM/python-sdk.git" -reference = "0.4.1" -resolved_reference = "c588cc109a445bc790b26168f24e72506aa71df7" +reference = "0.4.2" +resolved_reference = "59699d4152f27521fb113647738825cd625084b7" [[package]] name = "h11" @@ -309,15 +309,15 @@ license = ["ukkonen"] [[package]] name = "identify" -version = "2.6.17" +version = "2.6.18" description = "File identification library for Python" optional = false python-versions = ">=3.10" groups = ["dev"] markers = "python_version >= \"3.10\"" files = [ - {file = "identify-2.6.17-py2.py3-none-any.whl", hash = "sha256:be5f8412d5ed4b20f2bd41a65f920990bdccaa6a4a18a08f1eefdcd0bdd885f0"}, - {file = "identify-2.6.17.tar.gz", hash = "sha256:f816b0b596b204c9fdf076ded172322f2723cf958d02f9c3587504834c8ff04d"}, + {file = "identify-2.6.18-py2.py3-none-any.whl", hash = "sha256:8db9d3c8ea9079db92cafb0ebf97abdc09d52e97f4dcf773a2e694048b7cd737"}, + {file = "identify-2.6.18.tar.gz", hash = "sha256:873ac56a5e3fd63e7438a7ecbc4d91aca692eb3fefa4534db2b7913f3fc352fd"}, ] [package.extras] @@ -827,30 +827,30 @@ files = [ [[package]] name = "ruff" -version = "0.15.5" +version = "0.15.6" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" groups = ["dev"] files = [ - {file = "ruff-0.15.5-py3-none-linux_armv6l.whl", hash = "sha256:4ae44c42281f42e3b06b988e442d344a5b9b72450ff3c892e30d11b29a96a57c"}, - {file = "ruff-0.15.5-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:6edd3792d408ebcf61adabc01822da687579a1a023f297618ac27a5b51ef0080"}, - {file = "ruff-0.15.5-py3-none-macosx_11_0_arm64.whl", hash = "sha256:89f463f7c8205a9f8dea9d658d59eff49db05f88f89cc3047fb1a02d9f344010"}, - {file = "ruff-0.15.5-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ba786a8295c6574c1116704cf0b9e6563de3432ac888d8f83685654fe528fd65"}, - {file = "ruff-0.15.5-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fd4b801e57955fe9f02b31d20375ab3a5c4415f2e5105b79fb94cf2642c91440"}, - {file = "ruff-0.15.5-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:391f7c73388f3d8c11b794dbbc2959a5b5afe66642c142a6effa90b45f6f5204"}, - {file = "ruff-0.15.5-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8dc18f30302e379fe1e998548b0f5e9f4dff907f52f73ad6da419ea9c19d66c8"}, - {file = "ruff-0.15.5-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1cc6e7f90087e2d27f98dc34ed1b3ab7c8f0d273cc5431415454e22c0bd2a681"}, - {file = "ruff-0.15.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c1cb7169f53c1ddb06e71a9aebd7e98fc0fea936b39afb36d8e86d36ecc2636a"}, - {file = "ruff-0.15.5-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:9b037924500a31ee17389b5c8c4d88874cc6ea8e42f12e9c61a3d754ff72f1ca"}, - {file = "ruff-0.15.5-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:65bb414e5b4eadd95a8c1e4804f6772bbe8995889f203a01f77ddf2d790929dd"}, - {file = "ruff-0.15.5-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:d20aa469ae3b57033519c559e9bc9cd9e782842e39be05b50e852c7c981fa01d"}, - {file = "ruff-0.15.5-py3-none-musllinux_1_2_i686.whl", hash = "sha256:15388dd28c9161cdb8eda68993533acc870aa4e646a0a277aa166de9ad5a8752"}, - {file = "ruff-0.15.5-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:b30da330cbd03bed0c21420b6b953158f60c74c54c5f4c1dabbdf3a57bf355d2"}, - {file = "ruff-0.15.5-py3-none-win32.whl", hash = "sha256:732e5ee1f98ba5b3679029989a06ca39a950cced52143a0ea82a2102cb592b74"}, - {file = "ruff-0.15.5-py3-none-win_amd64.whl", hash = "sha256:821d41c5fa9e19117616c35eaa3f4b75046ec76c65e7ae20a333e9a8696bc7fe"}, - {file = "ruff-0.15.5-py3-none-win_arm64.whl", hash = "sha256:b498d1c60d2fe5c10c45ec3f698901065772730b411f164ae270bb6bfcc4740b"}, - {file = "ruff-0.15.5.tar.gz", hash = "sha256:7c3601d3b6d76dce18c5c824fc8d06f4eef33d6df0c21ec7799510cde0f159a2"}, + {file = "ruff-0.15.6-py3-none-linux_armv6l.whl", hash = "sha256:7c98c3b16407b2cf3d0f2b80c80187384bc92c6774d85fefa913ecd941256fff"}, + {file = "ruff-0.15.6-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:ee7dcfaad8b282a284df4aa6ddc2741b3f4a18b0555d626805555a820ea181c3"}, + {file = "ruff-0.15.6-py3-none-macosx_11_0_arm64.whl", hash = "sha256:3bd9967851a25f038fc8b9ae88a7fbd1b609f30349231dffaa37b6804923c4bb"}, + {file = "ruff-0.15.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:13f4594b04e42cd24a41da653886b04d2ff87adbf57497ed4f728b0e8a4866f8"}, + {file = "ruff-0.15.6-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e2ed8aea2f3fe57886d3f00ea5b8aae5bf68d5e195f487f037a955ff9fbaac9e"}, + {file = "ruff-0.15.6-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:70789d3e7830b848b548aae96766431c0dc01a6c78c13381f423bf7076c66d15"}, + {file = "ruff-0.15.6-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:542aaf1de3154cea088ced5a819ce872611256ffe2498e750bbae5247a8114e9"}, + {file = "ruff-0.15.6-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1c22e6f02c16cfac3888aa636e9eba857254d15bbacc9906c9689fdecb1953ab"}, + {file = "ruff-0.15.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98893c4c0aadc8e448cfa315bd0cc343a5323d740fe5f28ef8a3f9e21b381f7e"}, + {file = "ruff-0.15.6-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:70d263770d234912374493e8cc1e7385c5d49376e41dfa51c5c3453169dc581c"}, + {file = "ruff-0.15.6-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:55a1ad63c5a6e54b1f21b7514dfadc0c7fb40093fa22e95143cf3f64ebdcd512"}, + {file = "ruff-0.15.6-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:8dc473ba093c5ec238bb1e7429ee676dca24643c471e11fbaa8a857925b061c0"}, + {file = "ruff-0.15.6-py3-none-musllinux_1_2_i686.whl", hash = "sha256:85b042377c2a5561131767974617006f99f7e13c63c111b998f29fc1e58a4cfb"}, + {file = "ruff-0.15.6-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:cef49e30bc5a86a6a92098a7fbf6e467a234d90b63305d6f3ec01225a9d092e0"}, + {file = "ruff-0.15.6-py3-none-win32.whl", hash = "sha256:bbf67d39832404812a2d23020dda68fee7f18ce15654e96fb1d3ad21a5fe436c"}, + {file = "ruff-0.15.6-py3-none-win_amd64.whl", hash = "sha256:aee25bc84c2f1007ecb5037dff75cef00414fdf17c23f07dc13e577883dca406"}, + {file = "ruff-0.15.6-py3-none-win_arm64.whl", hash = "sha256:c34de3dd0b0ba203be50ae70f5910b17188556630e2178fd7d79fc030eb0d837"}, + {file = "ruff-0.15.6.tar.gz", hash = "sha256:8394c7bb153a4e3811a4ecdacd4a8e6a4fa8097028119160dffecdcdf9b56ae4"}, ] [[package]] @@ -1118,4 +1118,4 @@ dev = ["pytest", "setuptools"] [metadata] lock-version = "2.1" python-versions = ">=3.9,<4.0" -content-hash = "2f6f86abc633248dd544352187db71678c6cf5f7df6e3e9bb53a950a97e69818" +content-hash = "36c87da31bfe90017f3cb6e335187ed06347e79c0f471371c29b7824aa6c8a38" diff --git a/pyproject.toml b/pyproject.toml index bb57102..af8c247 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "flame" -version = "0.6.0" +version = "0.6.1" description = "" authors = ["Alexander Röhl ", "David Hieber "] readme = "README.md" @@ -9,7 +9,7 @@ packages = [{ include = "flame" }] [tool.poetry.dependencies] python = ">=3.9,<4.0" -flamesdk = {git = "https://github.com/PrivateAIM/python-sdk.git", tag = "0.4.1"} +flamesdk = {git = "https://github.com/PrivateAIM/python-sdk.git", tag = "0.4.2"} opendp = ">=0.12.1,<0.13.0" [tool.poetry.group.dev.dependencies]