diff --git a/pytorch_lightning/utilities/migration.py b/pytorch_lightning/utilities/migration.py index bc3761e47b835..30cc823210423 100644 --- a/pytorch_lightning/utilities/migration.py +++ b/pytorch_lightning/utilities/migration.py @@ -14,10 +14,14 @@ from __future__ import annotations import sys +import threading from types import ModuleType, TracebackType import pytorch_lightning.utilities.argparse +# Create a global lock to ensure no race condition with deleting sys modules +_lock = threading.Lock() + class pl_legacy_patch: """Registers legacy artifacts (classes, methods, etc.) that were removed but still need to be included for @@ -35,6 +39,7 @@ class pl_legacy_patch: """ def __enter__(self) -> None: + _lock.acquire() # `pl.utilities.argparse_utils` was renamed to `pl.utilities.argparse` legacy_argparse_module = ModuleType("pytorch_lightning.utilities.argparse_utils") sys.modules["pytorch_lightning.utilities.argparse_utils"] = legacy_argparse_module @@ -49,3 +54,4 @@ def __exit__( if hasattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default"): delattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default") del sys.modules["pytorch_lightning.utilities.argparse_utils"] + _lock.release() diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py index 7e753617a6331..ac2806cbf3811 100644 --- a/tests/checkpointing/test_legacy_checkpoints.py +++ b/tests/checkpointing/test_legacy_checkpoints.py @@ -14,6 +14,7 @@ import glob import os import sys +import threading from unittest.mock import patch import pytest @@ -60,6 +61,28 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo trainer.should_stop = True +@pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS) +def test_legacy_ckpt_threading(tmpdir, pl_version: str): + def load_model(): + import torch + + from pytorch_lightning.utilities.migration import pl_legacy_patch + + with pl_legacy_patch(): + _ = torch.load(PATH_LEGACY) + + PATH_LEGACY = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version) + with patch("sys.path", [PATH_LEGACY] + sys.path): + t1 = threading.Thread(target=load_model) + t2 = threading.Thread(target=load_model) + + t1.start() + t2.start() + + t1.join() + t2.join() + + @pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS) def test_resume_legacy_checkpoints(tmpdir, pl_version: str): PATH_LEGACY = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version)