-
Notifications
You must be signed in to change notification settings - Fork 1
Bump flame-sdk version and update project dependencies #13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
e6f6d6b
2ccf992
78e09fe
9aa9c85
f1d0ede
1743b72
34c613a
e440841
e3161f0
94a84b6
19d73ca
998249f
c9d52f6
3d19717
01f00d4
87d0f44
32d1d1c
fc4fbf4
4b3f86d
bc75289
1e0cd40
9e7905e
0a66d16
38880d6
95e78e5
a6babd9
1650326
a71cde6
57799d7
c07cd24
48cf411
84942ec
827f567
be1697f
188d9d2
fa2fb8b
69819cf
f8fa6fc
141d3cd
54945b3
ba41a88
e3a4dfc
cf22bd4
44fc4a3
cad909b
a02d737
64dc2a6
afdc79d
4e2852f
63a78f3
6aa2a0b
672a0eb
3f5f264
568e199
a7927f3
94a47ef
2734901
ff69c56
01c4fbb
0fbb53d
a972033
b65c69e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,8 +2,10 @@ | |
| import threading | ||
| import uuid | ||
| from typing import Any, Type, Literal, Optional, Union | ||
| import traceback | ||
|
|
||
| from flame.star import StarModel, StarLocalDPModel, StarAnalyzer, StarAggregator | ||
| from flame.utils.mock_flame_core import MockFlameCoreSDK | ||
|
|
||
|
|
||
| class StarModelTester: | ||
|
|
@@ -28,6 +30,9 @@ def __init__(self, | |
| participant_ids = [str(uuid.uuid4()) for _ in range(len(node_roles) + 1)] | ||
|
|
||
| threads = [] | ||
| thread_errors = {} | ||
| results_queue = [] | ||
| MockFlameCoreSDK.stop_event = [] # shared stop event for all threads in case of failure in any thread | ||
| for i, participant_id in enumerate(participant_ids): | ||
| test_kwargs = { | ||
| 'analyzer': analyzer, | ||
|
|
@@ -54,13 +59,28 @@ def __init__(self, | |
| test_kwargs['epsilon'] = epsilon | ||
| test_kwargs['sensitivity'] = sensitivity | ||
|
|
||
| results_queue = [] | ||
| def run_node(kwargs=test_kwargs, use_dp=use_local_dp): | ||
| if not use_dp: | ||
| flame = StarModel(**kwargs).flame | ||
| else: | ||
| flame = StarLocalDPModel(**kwargs).flame | ||
| results_queue.append(flame.final_results_storage) | ||
| try: | ||
| if not use_dp: | ||
| flame = StarModel(**kwargs).flame | ||
| else: | ||
| flame = StarLocalDPModel(**kwargs).flame | ||
| results_queue.append(flame.final_results_storage) | ||
| except Exception: | ||
| stop_event = MockFlameCoreSDK.stop_event | ||
| if not stop_event: | ||
| stack_trace = traceback.format_exc()#.replace('\n', '\\n').replace('\t', '\\t') | ||
| thread_errors[(kwargs['test_kwargs']['role'], | ||
| kwargs['test_kwargs']['node_id'])] = f"\033[31m{stack_trace}\033[0m" | ||
| stop_event.append(kwargs['test_kwargs']['node_id']) | ||
|
Comment on lines
+70
to
+75
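There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Class-level state on `MockFlameCoreSDK` is shared by all instances and persists between test runs.
Consider resetting the shared attributes (`stop_event`, `message_broker`, `final_results_storage`, `num_iterations`) before each run.
🐛 Proposed fix to reset shared state
Add at the beginning of `__init__`:
# Reset shared mock state for fresh test run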
MockFlameCoreSDK.stop_event = []
MockFlameCoreSDK.message_broker = {}
MockFlameCoreSDK.final_results_storage = None
MockFlameCoreSDK.num_iterations = IterationTracker()🤖 Prompt for AI Agents |
||
| mock = MockFlameCoreSDK(test_kwargs=kwargs['test_kwargs']) | ||
| mock.__pop_logs__(failure_message=True) | ||
| else: | ||
| thread_errors[(kwargs['test_kwargs']['role'], | ||
| kwargs['test_kwargs']['node_id'])] = (Exception("Another thread already failed, " | ||
| "stopping this thread as well.")) | ||
| return | ||
|
Comment on lines
+69
to
+82
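There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
TOCTOU race condition on `stop_event`. The check `if not stop_event` and the later `stop_event.append(...)` are not atomic, so two threads that fail at nearly the same time can both take the first-failure branch.
Additionally, Line 79-80 stores an `Exception` object, whereas the first-failure branch stores a formatted string.
🔒 Suggested fix for atomicity and consistent error format
+ error_lock = threading.Lock()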
+ first_failure = [False] # Use list for mutable closure
...
except Exception:
stop_event = MockFlameCoreSDK.stop_event
- if not stop_event:
- stack_trace = traceback.format_exc()
- thread_errors[(kwargs['test_kwargs']['role'],
- kwargs['test_kwargs']['node_id'])] = f"\033[31m{stack_trace}\033[0m"
- stop_event.append(kwargs['test_kwargs']['node_id'])
- mock = MockFlameCoreSDK(test_kwargs=kwargs['test_kwargs'])
- mock.__pop_logs__(failure_message=True)
- else:
- thread_errors[(kwargs['test_kwargs']['role'],
- kwargs['test_kwargs']['node_id'])] = (Exception("Another thread already failed, "
- "stopping this thread as well."))
+ with error_lock:
+ is_first = not first_failure[0]
+ if is_first:
+ first_failure[0] = True
+ stop_event.append(kwargs['test_kwargs']['node_id'])
+ if is_first:
+ stack_trace = traceback.format_exc()
+ thread_errors[(kwargs['test_kwargs']['role'],
+ kwargs['test_kwargs']['node_id'])] = f"\033[31m{stack_trace}\033[0m"
+ mock = MockFlameCoreSDK(test_kwargs=kwargs['test_kwargs'])
+ mock.__pop_logs__(failure_message=True)
+ else:
+ thread_errors[(kwargs['test_kwargs']['role'],
+ kwargs['test_kwargs']['node_id'])] = "Another thread already failed, stopping this thread as well."
return
🧰 Tools
🪛 Ruff (0.15.6)
[warning] 68-68: Do not catch blind exception: `Exception` (BLE001)
🤖 Prompt for AI Agents |
||
|
|
||
| thread = threading.Thread(target=run_node) | ||
| threads.append(thread) | ||
|
|
||
|
|
@@ -70,8 +90,14 @@ def run_node(kwargs=test_kwargs, use_dp=use_local_dp): | |
| for thread in threads: | ||
| thread.join() | ||
|
|
||
|
|
||
| # write final results | ||
| self.write_result(results_queue[0], output_type, result_filepath, multiple_results) | ||
| if results_queue: | ||
| self.write_result(results_queue[0], output_type, result_filepath, multiple_results) | ||
| else: | ||
| print("No results to write. All threads failed with errors:") | ||
| for (role, node_id), error in thread_errors.items(): | ||
| print(f"\t{(role if role != 'default' else 'analyzer').capitalize()} {node_id}: {error}") | ||
|
|
||
|
|
||
| @staticmethod | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -44,11 +44,23 @@ def __init__(self, test_kwargs) -> None: | |
| self.finished: bool = False | ||
|
|
||
|
|
||
| class IterationTracker: | ||
| def __init__(self): | ||
| self.iter = 0 | ||
|
|
||
| def increment(self): | ||
| self.iter += 1 | ||
|
|
||
| def get_iterations(self): | ||
| return self.iter | ||
|
|
||
|
|
||
| class MockFlameCoreSDK: | ||
| num_iterations: int = 0 | ||
| num_iterations: IterationTracker = IterationTracker() | ||
| logger: dict[str, list[str]] = {} | ||
| message_broker: dict[str, list[dict[str, Any]]] = {} | ||
| final_results_storage: Optional[Any] = None | ||
| stop_event: list[tuple[str]] = [] | ||
|
Comment on lines
+59
to
+63
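There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Class-level mutable attributes persist across test runs and may cause test pollution.
The class-level `num_iterations`, `logger`, `message_broker`, `final_results_storage` and `stop_event` attributes are shared by every `MockFlameCoreSDK` instance and keep their values between runs.
Consider resetting this state explicitly before each run:
🔧 Proposed class reset method
@classmethod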
def reset_state(cls):
"""Reset all shared state for a fresh test run."""
cls.num_iterations = IterationTracker()
cls.logger = {}
cls.message_broker = {}
cls.final_results_storage = None
    cls.stop_event = []
🧰 Tools
🪛 Ruff (0.15.6)
[warning] 60-60: Mutable default value for class attribute (RUF012)
[warning] 61-61: Mutable default value for class attribute (RUF012)
[warning] 63-63: Mutable default value for class attribute (RUF012)
🤖 Prompt for AI Agents |
||
|
|
||
| def __init__(self, test_kwargs): | ||
| self.sanity_check(test_kwargs) | ||
|
|
@@ -202,6 +214,8 @@ def await_messages(self, | |
| break | ||
| raise KeyError | ||
| except KeyError: | ||
| if self.stop_event: | ||
| raise Exception | ||
| time.sleep(.01) | ||
| pass | ||
|
|
||
|
|
@@ -323,12 +337,17 @@ def _node_finished(self) -> bool: | |
| self.config.finished = True | ||
| return self.config.finished | ||
|
|
||
| def __pop_logs__(self) -> None: | ||
| print(f"--- Starting Iteration {self.num_iterations} ---") | ||
| def __pop_logs__(self, failure_message: bool = False) -> None: | ||
| print(f"--- Starting Iteration {self.__get_iteration__()} ---") | ||
| if failure_message: | ||
| self.flame_log("Exception was raised (see Stacktrace)!", log_type='error') | ||
| for k, v in self.logger.items(): | ||
| role, log = self.logger[k] | ||
| print(f"Logs for {'Analyzer' if role == 'default' else role.capitalize()} {k}:") | ||
| self.logger[k] = [role, ''] | ||
| print(log, end='') | ||
| print(f"--- Ending Iteration {self.num_iterations} ---\n") | ||
| self.num_iterations += 1 | ||
| print(f"--- Ending Iteration {self.__get_iteration__()} ---\n") | ||
| self.num_iterations.increment() | ||
|
|
||
| def __get_iteration__(self): | ||
| return self.num_iterations.get_iterations() | ||
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thread-safety issue: `thread_errors` and `results_queue` are shared across threads without synchronization.
Both `thread_errors` (dict) and `results_queue` (list) are accessed concurrently by multiple threads. While CPython's GIL makes individual operations like `list.append()` and `dict.__setitem__()` atomic, relying on this is fragile and non-portable. Consider using `threading.Lock` or thread-safe collections like `queue.Queue` for `results_queue`.
🔒 Suggested fix using Queue
📝 Committable suggestion
🤖 Prompt for AI Agents