Skip to content

Commit bd9a394

Browse files
author
Chris Elion
authored
Fix timers when using multithreading. (#3901)
1 parent 9ab9203 commit bd9a394

File tree

5 files changed

+101
-20
lines changed

5 files changed

+101
-20
lines changed

.pylintrc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,6 @@ disable =
4646

4747
# Using the global statement
4848
W0603,
49+
50+
# "Access to a protected member _foo of a client class (protected-access)"
51+
W0212

docs/Profiling-Python.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,14 @@ By default, at the end of training, timers are collected and written in json for
4343
is optional and defaults to false.
4444

4545
### Parallel execution
46+
#### Subprocesses
4647
For code that executes in multiple processes (for example, SubprocessEnvManager), we periodically send the timer
4748
information back to the "main" process, aggregate the timers there, and flush them in the subprocess. Note that
4849
(depending on the number of processes) this can result in timers where the total time may exceed the parent's total
4950
time. This is analogous to the difference between "real" and "user" values reported from the Unix `time` command. In the
5051
timer output, blocks that were run in parallel are indicated by the `is_parallel` flag.
5152

53+
#### Threads
54+
Timers currently use `time.perf_counter()` to track time spent, which may not give accurate results for multiple
55+
threads. If this is problematic, set `threaded: false` in your trainer configuration.
56+

ml-agents-envs/mlagents_envs/tests/test_timers.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,8 @@ def decorated_func(x: int = 0, y: float = 1.0) -> str:
1010

1111

1212
def test_timers() -> None:
13-
with mock.patch(
14-
"mlagents_envs.timers._global_timer_stack", new_callable=timers.TimerStack
15-
) as test_timer:
13+
test_timer = timers.TimerStack()
14+
with mock.patch("mlagents_envs.timers._get_thread_timer", return_value=test_timer):
1615
# First, run some simple code
1716
with timers.hierarchical_timer("top_level"):
1817
for i in range(3):

ml-agents-envs/mlagents_envs/timers.py

Lines changed: 59 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,10 @@ def main():
3131
import math
3232
import sys
3333
import time
34+
import threading
3435

3536
from contextlib import contextmanager
36-
from typing import Any, Callable, Dict, Generator, TypeVar
37+
from typing import Any, Callable, Dict, Generator, Optional, TypeVar
3738

3839
TIMER_FORMAT_VERSION = "0.1.0"
3940

@@ -97,19 +98,31 @@ class GaugeNode:
9798
Tracks the most recent value of a metric. This is analogous to gauges in statsd.
9899
"""
99100

100-
__slots__ = ["value", "min_value", "max_value", "count"]
101+
__slots__ = ["value", "min_value", "max_value", "count", "_timestamp"]
101102

102103
def __init__(self, value: float):
103104
self.value = value
104105
self.min_value = value
105106
self.max_value = value
106107
self.count = 1
108+
# Internal timestamp so we can determine priority.
109+
self._timestamp = time.time()
107110

108111
def update(self, new_value: float) -> None:
109112
self.min_value = min(self.min_value, new_value)
110113
self.max_value = max(self.max_value, new_value)
111114
self.value = new_value
112115
self.count += 1
116+
self._timestamp = time.time()
117+
118+
def merge(self, other: "GaugeNode") -> None:
119+
if self._timestamp < other._timestamp:
120+
# Keep the "later" value
121+
self.value = other.value
122+
self._timestamp = other._timestamp
123+
self.min_value = min(self.min_value, other.min_value)
124+
self.max_value = max(self.max_value, other.max_value)
125+
self.count += other.count
113126

114127
def as_dict(self) -> Dict[str, float]:
115128
return {
@@ -232,9 +245,23 @@ def _add_default_metadata(self):
232245
self.metadata["command_line_arguments"] = " ".join(sys.argv)
233246

234247

235-
# Global instance of a TimerStack. This is generally all that we need for profiling, but you can potentially
236-
# create multiple instances and pass them to the contextmanager
237-
_global_timer_stack = TimerStack()
248+
# Maintain a separate "global" timer per thread, so that they don't accidentally conflict with each other.
249+
_thread_timer_stacks: Dict[int, TimerStack] = {}
250+
251+
252+
def _get_thread_timer() -> TimerStack:
253+
ident = threading.get_ident()
254+
if ident not in _thread_timer_stacks:
255+
timer_stack = TimerStack()
256+
_thread_timer_stacks[ident] = timer_stack
257+
return _thread_timer_stacks[ident]
258+
259+
260+
def get_timer_stack_for_thread(t: threading.Thread) -> Optional[TimerStack]:
261+
if t.ident is None:
262+
# Thread hasn't started; this shouldn't ever happen.
263+
return None
264+
return _thread_timer_stacks.get(t.ident)
238265

239266

240267
@contextmanager
@@ -243,7 +270,7 @@ def hierarchical_timer(name: str, timer_stack: TimerStack = None) -> Generator:
243270
Creates a scoped timer around a block of code. This time spent will automatically be incremented when
244271
the context manager exits.
245272
"""
246-
timer_stack = timer_stack or _global_timer_stack
273+
timer_stack = timer_stack or _get_thread_timer()
247274
timer_node = timer_stack.push(name)
248275
start_time = time.perf_counter()
249276

@@ -284,34 +311,52 @@ def set_gauge(name: str, value: float, timer_stack: TimerStack = None) -> None:
284311
"""
285312
Updates the value of the gauge (or creates it if it hasn't been set before).
286313
"""
287-
timer_stack = timer_stack or _global_timer_stack
314+
timer_stack = timer_stack or _get_thread_timer()
288315
timer_stack.set_gauge(name, value)
289316

290317

318+
def merge_gauges(gauges: Dict[str, GaugeNode], timer_stack: TimerStack = None) -> None:
319+
"""
320+
Merge the gauges from another TimerStack with the provided one (or the
321+
current thread's stack if none is provided).
322+
:param gauges:
323+
:param timer_stack:
324+
:return:
325+
"""
326+
timer_stack = timer_stack or _get_thread_timer()
327+
for n, g in gauges.items():
328+
if n in timer_stack.gauges:
329+
timer_stack.gauges[n].merge(g)
330+
else:
331+
timer_stack.gauges[n] = g
332+
333+
291334
def add_metadata(key: str, value: str, timer_stack: TimerStack = None) -> None:
292-
timer_stack = timer_stack or _global_timer_stack
335+
timer_stack = timer_stack or _get_thread_timer()
293336
timer_stack.add_metadata(key, value)
294337

295338

296339
def get_timer_tree(timer_stack: TimerStack = None) -> Dict[str, Any]:
297340
"""
298-
Return the tree of timings from the TimerStack as a dictionary (or the global stack if none is provided)
341+
Return the tree of timings from the TimerStack as a dictionary (or the
342+
current thread's stack if none is provided)
299343
"""
300-
timer_stack = timer_stack or _global_timer_stack
344+
timer_stack = timer_stack or _get_thread_timer()
301345
return timer_stack.get_timing_tree()
302346

303347

304348
def get_timer_root(timer_stack: TimerStack = None) -> TimerNode:
305349
"""
306-
Get the root TimerNode of the timer_stack (or the global TimerStack if not specified)
350+
Get the root TimerNode of the timer_stack (or the current thread's
351+
TimerStack if not specified)
307352
"""
308-
timer_stack = timer_stack or _global_timer_stack
353+
timer_stack = timer_stack or _get_thread_timer()
309354
return timer_stack.get_root()
310355

311356

312357
def reset_timers(timer_stack: TimerStack = None) -> None:
313358
"""
314-
Reset the timer_stack (or the global TimerStack if not specified)
359+
Reset the timer_stack (or the current thread's TimerStack if not specified)
315360
"""
316-
timer_stack = timer_stack or _global_timer_stack
361+
timer_stack = timer_stack or _get_thread_timer()
317362
timer_stack.reset()

ml-agents/mlagents/trainers/trainer_controller.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,12 @@
1919
UnityCommunicatorStoppedException,
2020
)
2121
from mlagents.trainers.sampler_class import SamplerManager
22-
from mlagents_envs.timers import hierarchical_timer, timed
22+
from mlagents_envs.timers import (
23+
hierarchical_timer,
24+
timed,
25+
get_timer_stack_for_thread,
26+
merge_gauges,
27+
)
2328
from mlagents.trainers.trainer import Trainer
2429
from mlagents.trainers.meta_curriculum import MetaCurriculum
2530
from mlagents.trainers.trainer_util import TrainerFactory
@@ -228,7 +233,7 @@ def start_learning(self, env_manager: EnvManager) -> None:
228233
if self._should_save_model(global_step):
229234
self._save_model()
230235
# Stop advancing trainers
231-
self.kill_trainers = True
236+
self.join_threads()
232237
# Final save Tensorflow model
233238
if global_step != 0 and self.train_model:
234239
self._save_model()
@@ -238,7 +243,7 @@ def start_learning(self, env_manager: EnvManager) -> None:
238243
UnityEnvironmentException,
239244
UnityCommunicatorStoppedException,
240245
) as ex:
241-
self.kill_trainers = True
246+
self.join_threads()
242247
if self.train_model:
243248
self._save_model_when_interrupted()
244249

@@ -315,6 +320,30 @@ def advance(self, env: EnvManager) -> int:
315320

316321
return num_steps
317322

323+
def join_threads(self, timeout_seconds: float = 1.0) -> None:
324+
"""
325+
Wait for threads to finish, and merge their timer information into the main thread's timer stack.
326+
:param timeout_seconds:
327+
:return:
328+
"""
329+
self.kill_trainers = True
330+
for t in self.trainer_threads:
331+
try:
332+
t.join(timeout_seconds)
333+
except Exception:
334+
pass
335+
336+
with hierarchical_timer("trainer_threads") as main_timer_node:
337+
for trainer_thread in self.trainer_threads:
338+
thread_timer_stack = get_timer_stack_for_thread(trainer_thread)
339+
if thread_timer_stack:
340+
main_timer_node.merge(
341+
thread_timer_stack.root,
342+
root_name="thread_root",
343+
is_parallel=True,
344+
)
345+
merge_gauges(thread_timer_stack.gauges)
346+
318347
def trainer_update_func(self, trainer: Trainer) -> None:
319348
while not self.kill_trainers:
320349
with hierarchical_timer("trainer_advance"):

0 commit comments

Comments
 (0)