From 203ce9189160a70b9a7fa8e6b75b573a1d31b4ae Mon Sep 17 00:00:00 2001 From: Kushashwa Ravi Shrimali Date: Mon, 18 Apr 2022 12:48:11 +0530 Subject: [PATCH 1/5] Add lock to pl_legacy_patch context manager; avoiding race condition with multithreading --- pytorch_lightning/utilities/migration.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pytorch_lightning/utilities/migration.py b/pytorch_lightning/utilities/migration.py index bc3761e47b835..664a047b04ac3 100644 --- a/pytorch_lightning/utilities/migration.py +++ b/pytorch_lightning/utilities/migration.py @@ -15,6 +15,7 @@ import sys from types import ModuleType, TracebackType +import threading import pytorch_lightning.utilities.argparse @@ -34,7 +35,12 @@ class pl_legacy_patch: torch.load("path/to/legacy/checkpoint.ckpt") """ + def __init__(self) -> None: + # Create a lock to ensure no race condition with deleting sys modules + self.lock = threading.Lock() + def __enter__(self) -> None: + self.lock.acquire() # `pl.utilities.argparse_utils` was renamed to `pl.utilities.argparse` legacy_argparse_module = ModuleType("pytorch_lightning.utilities.argparse_utils") sys.modules["pytorch_lightning.utilities.argparse_utils"] = legacy_argparse_module @@ -49,3 +55,4 @@ def __exit__( if hasattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default"): delattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default") del sys.modules["pytorch_lightning.utilities.argparse_utils"] + self.lock.release() From 99acdf2c6290b20471091fca881e7d4b42a9039b Mon Sep 17 00:00:00 2001 From: Kushashwa Ravi Shrimali Date: Thu, 28 Apr 2022 10:14:07 +0530 Subject: [PATCH 2/5] Add tests, make lock global after discussion with Adrian --- pytorch_lightning/utilities/migration.py | 12 ++++++------ tests/checkpointing/test_legacy_checkpoints.py | 18 ++++++++++++++++++ 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/utilities/migration.py b/pytorch_lightning/utilities/migration.py index 664a047b04ac3..f6ec362b86722 100644 --- a/pytorch_lightning/utilities/migration.py +++ b/pytorch_lightning/utilities/migration.py @@ -20,6 +20,10 @@ import pytorch_lightning.utilities.argparse +# Create a global lock to ensure no race condition with deleting sys modules +lock = threading.Lock() + + class pl_legacy_patch: """Registers legacy artifacts (classes, methods, etc.) that were removed but still need to be included for unpickling old checkpoints. The following patches apply. @@ -35,12 +39,8 @@ class pl_legacy_patch: torch.load("path/to/legacy/checkpoint.ckpt") """ - def __init__(self) -> None: - # Create a lock to ensure no race condition with deleting sys modules - self.lock = threading.Lock() - def __enter__(self) -> None: - self.lock.acquire() + lock.acquire() # `pl.utilities.argparse_utils` was renamed to `pl.utilities.argparse` legacy_argparse_module = ModuleType("pytorch_lightning.utilities.argparse_utils") sys.modules["pytorch_lightning.utilities.argparse_utils"] = legacy_argparse_module @@ -55,4 +55,4 @@ def __exit__( if hasattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default"): delattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default") del sys.modules["pytorch_lightning.utilities.argparse_utils"] - self.lock.release() + lock.release() diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py index 7e753617a6331..e1dad1a6a605b 100644 --- a/tests/checkpointing/test_legacy_checkpoints.py +++ b/tests/checkpointing/test_legacy_checkpoints.py @@ -60,6 +60,24 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo trainer.should_stop = True +@pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS) +def test_legacy_ckpt_threading(tmpdir, pl_version: str): + def load_model(): + with pl_legacy_patch(): + _ = torch.load(PATH_LEGACY) + + PATH_LEGACY = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version) + with patch("sys.path", [PATH_LEGACY] + sys.path): + t1 = threading.Thread(target=load_model) + t2 = threading.Thread(target=load_model) + + t1.start() + t2.start() + + t1.join() + t2.join() + + @pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS) def test_resume_legacy_checkpoints(tmpdir, pl_version: str): PATH_LEGACY = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version) From a8fc8205d18fee780a1a781aac4cb01afc5af388 Mon Sep 17 00:00:00 2001 From: Kushashwa Ravi Shrimali Date: Thu, 28 Apr 2022 10:26:06 +0530 Subject: [PATCH 3/5] Imports --- tests/checkpointing/test_legacy_checkpoints.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py index e1dad1a6a605b..4e28cb2240200 100644 --- a/tests/checkpointing/test_legacy_checkpoints.py +++ b/tests/checkpointing/test_legacy_checkpoints.py @@ -14,6 +14,7 @@ import glob import os import sys +import threading from unittest.mock import patch import pytest @@ -63,6 +64,8 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo @pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS) def test_legacy_ckpt_threading(tmpdir, pl_version: str): def load_model(): + import torch + from pytorch_lightning.utilities.migration import pl_legacy_patch with pl_legacy_patch(): _ = torch.load(PATH_LEGACY) From ee2179358b600254f66a33def93e9867540bca04 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 28 Apr 2022 04:58:00 +0000 Subject: [PATCH 4/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/utilities/migration.py | 3 +-- tests/checkpointing/test_legacy_checkpoints.py | 2 ++ 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/utilities/migration.py b/pytorch_lightning/utilities/migration.py index f6ec362b86722..7b7c618bae5af 100644 --- a/pytorch_lightning/utilities/migration.py +++ b/pytorch_lightning/utilities/migration.py @@ -14,12 +14,11 @@ from __future__ import annotations import sys -from types import ModuleType, TracebackType import threading +from types import ModuleType, TracebackType import pytorch_lightning.utilities.argparse - # Create a global lock to ensure no race condition with deleting sys modules lock = threading.Lock() diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py index 4e28cb2240200..ac2806cbf3811 100644 --- a/tests/checkpointing/test_legacy_checkpoints.py +++ b/tests/checkpointing/test_legacy_checkpoints.py @@ -65,7 +65,9 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo def test_legacy_ckpt_threading(tmpdir, pl_version: str): def load_model(): import torch + from pytorch_lightning.utilities.migration import pl_legacy_patch + with pl_legacy_patch(): _ = torch.load(PATH_LEGACY) From a67dc9bcf134efc4d582f6b2d0f6895656553cdd Mon Sep 17 00:00:00 2001 From: Kushashwa Ravi Shrimali Date: Thu, 28 Apr 2022 10:47:50 +0530 Subject: [PATCH 5/5] make lock protected --- pytorch_lightning/utilities/migration.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/utilities/migration.py b/pytorch_lightning/utilities/migration.py index 7b7c618bae5af..30cc823210423 100644 --- a/pytorch_lightning/utilities/migration.py +++ b/pytorch_lightning/utilities/migration.py @@ -20,7 +20,7 @@ import pytorch_lightning.utilities.argparse # Create a global lock to ensure no race condition with deleting sys modules -lock = threading.Lock() +_lock = threading.Lock() class pl_legacy_patch: @@ -39,7 +39,7 @@ class pl_legacy_patch: """ def __enter__(self) -> None: - lock.acquire() + _lock.acquire() # `pl.utilities.argparse_utils` was renamed to `pl.utilities.argparse` legacy_argparse_module = ModuleType("pytorch_lightning.utilities.argparse_utils") sys.modules["pytorch_lightning.utilities.argparse_utils"] = legacy_argparse_module @@ -54,4 +54,4 @@ def __exit__( if hasattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default"): delattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default") del sys.modules["pytorch_lightning.utilities.argparse_utils"] - lock.release() + _lock.release()