From 05e8d7c7b4172e5add5abfa11c2dfad4e7092aaf Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Tue, 17 Nov 2020 19:07:44 -0500 Subject: [PATCH] Add global lock for torch.onnx.export() (#4659) --- ml-agents/mlagents/trainers/torch/model_serialization.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/ml-agents/mlagents/trainers/torch/model_serialization.py b/ml-agents/mlagents/trainers/torch/model_serialization.py index b82e9005d4..e9a4dfc9f4 100644 --- a/ml-agents/mlagents/trainers/torch/model_serialization.py +++ b/ml-agents/mlagents/trainers/torch/model_serialization.py @@ -19,14 +19,20 @@ class exporting_to_onnx: This implementation is thread safe. """ + # local is_exporting flag for each thread _local_data = threading.local() _local_data._is_exporting = False + # global lock shared among all threads, to make sure only one thread is exporting at a time + _lock = threading.Lock() + def __enter__(self): + self._lock.acquire() self._local_data._is_exporting = True def __exit__(self, *args): self._local_data._is_exporting = False + self._lock.release() @staticmethod def is_exporting():