diff --git a/ml-agents/mlagents/trainers/torch/model_serialization.py b/ml-agents/mlagents/trainers/torch/model_serialization.py index b82e9005d4..e9a4dfc9f4 100644 --- a/ml-agents/mlagents/trainers/torch/model_serialization.py +++ b/ml-agents/mlagents/trainers/torch/model_serialization.py @@ -19,14 +19,20 @@ class exporting_to_onnx: This implementation is thread safe. """ + # local is_exporting flag for each thread _local_data = threading.local() _local_data._is_exporting = False + # global lock shared among all threads, to make sure only one thread is exporting at a time + _lock = threading.Lock() + def __enter__(self): + self._lock.acquire() self._local_data._is_exporting = True def __exit__(self, *args): self._local_data._is_exporting = False + self._lock.release() @staticmethod def is_exporting():