Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fault Tolerant Manual: Enable the feature #10707

Merged
merged 63 commits into from Nov 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
63 commits
Select commit Hold shift + click to select a range
ca8dbbb
update
tchaton Nov 19, 2021
3f0d28a
update
tchaton Nov 19, 2021
0c47670
update
tchaton Nov 19, 2021
3e5b52e
update
tchaton Nov 19, 2021
24c8245
update
tchaton Nov 19, 2021
bcd5569
update
tchaton Nov 19, 2021
8d3844d
update
tchaton Nov 19, 2021
1829b46
update
tchaton Nov 19, 2021
a1a364a
typo
tchaton Nov 19, 2021
de41675
update on comments
tchaton Nov 22, 2021
8178a32
Update pytorch_lightning/utilities/auto_restart.py
kaushikb11 Nov 22, 2021
00b9355
update
tchaton Nov 22, 2021
96f0517
update
tchaton Nov 22, 2021
297fd67
Merge branch 'fault_tolerant_enum' of https://github.com/PyTorchLight…
tchaton Nov 22, 2021
9800cba
update
tchaton Nov 22, 2021
427ed03
docstring improvement
tchaton Nov 22, 2021
ae712b0
update
tchaton Nov 22, 2021
9a5166d
Rename and simplify
carmocca Nov 22, 2021
b5fa819
Add comment
carmocca Nov 22, 2021
c82b2f2
update
tchaton Nov 22, 2021
2ede205
update
tchaton Nov 22, 2021
b16c4c0
update
tchaton Nov 22, 2021
ce9c23c
update
tchaton Nov 22, 2021
2baddb9
update
tchaton Nov 22, 2021
97548bb
update
tchaton Nov 22, 2021
d953ae9
update
tchaton Nov 22, 2021
41ffbab
use_teardown
tchaton Nov 22, 2021
d04596d
Use `Protocol`
carmocca Nov 22, 2021
ff7b836
Simplify test
carmocca Nov 22, 2021
a5698e6
Update CHANGELOG.md
carmocca Nov 22, 2021
79fdacc
update
tchaton Nov 22, 2021
916b520
update
tchaton Nov 22, 2021
4b67fbf
update
tchaton Nov 22, 2021
c9481e2
update
tchaton Nov 22, 2021
ef29342
update
tchaton Nov 22, 2021
4a1fff7
update
tchaton Nov 22, 2021
cb27e30
update
tchaton Nov 22, 2021
7903d24
resolve tests
tchaton Nov 22, 2021
20d19a1
update
tchaton Nov 22, 2021
1104cbc
update
tchaton Nov 23, 2021
f071f9a
change to 0
tchaton Nov 23, 2021
b777dc3
update
tchaton Nov 23, 2021
2da1674
update
tchaton Nov 23, 2021
647bebd
merge with master
tchaton Nov 23, 2021
dbcfa65
update changelog
tchaton Nov 23, 2021
ae18166
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 23, 2021
a527929
update
tchaton Nov 23, 2021
5827e8b
Merge branch 'add_reloading' of https://github.com/PyTorchLightning/p…
tchaton Nov 23, 2021
97421c3
update
tchaton Nov 23, 2021
51cf75b
update
tchaton Nov 23, 2021
35644b8
update on comments
tchaton Nov 23, 2021
04a5c3d
Merge branch 'add_reloading' into fault_tolerant_cleanup
tchaton Nov 23, 2021
26d46b0
update changelog
tchaton Nov 23, 2021
0529f29
update
tchaton Nov 23, 2021
27fa4ce
update changelog
tchaton Nov 23, 2021
5337e7d
remove deadcode
tchaton Nov 23, 2021
30cc2ec
update
tchaton Nov 23, 2021
ec0bdad
update
tchaton Nov 23, 2021
c7ee8e3
Merge branch 'fault_tolerant_cleanup' into enable_fault_tolerant_manual
tchaton Nov 23, 2021
d736679
update
tchaton Nov 24, 2021
e36cec1
update
tchaton Nov 24, 2021
b413d14
update
tchaton Nov 24, 2021
90b14d3
update
tchaton Nov 24, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Add an utility to collect the states across processes ([#10639](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639))
* Add logic to reload the states across data loading components ([#10699](https://github.com/PyTorchLightning/pytorch-lightning/issues/10699))
* Cleanup some fault tolerant utilities ([#10703](https://github.com/PyTorchLightning/pytorch-lightning/issues/10703))
* Enable Fault Tolerant Manual Training ([#10707](https://github.com/PyTorchLightning/pytorch-lightning/issues/10707))

-

Expand Down
20 changes: 11 additions & 9 deletions pytorch_lightning/utilities/auto_restart.py
Expand Up @@ -251,9 +251,7 @@ def __len__(self) -> int:

def load_state_dict(self, state_dict: Dict[int, Any], latest_worker_id: int, num_workers: int) -> None:
# as workers aren't available, the ``state_dict``` is cached until workers are made available.
state_dict = deepcopy(state_dict)
state_dict = _rotate_worker_indices(state_dict, latest_worker_id, num_workers)
self._cached_state_dict = state_dict
self._cached_state_dict = _rotate_worker_indices(deepcopy(state_dict), latest_worker_id, num_workers)

def state_dict(self) -> Dict[int, Dict[str, Any]]:
return {self.worker_id: {"rng_states": collect_rng_states()}}
Expand Down Expand Up @@ -513,14 +511,17 @@ def patch_dataloader_iterator(

def _add_capture_metadata_collate(dataloader: DataLoader) -> None:
"""Wrap default collate function to retrive captured dataset state dict when fault tolerant is enabled."""
faut_tolerant_mode = _FaultTolerantMode.detect_current_mode()
if not faut_tolerant_mode.is_enabled:
fault_tolerant_mode = _FaultTolerantMode.detect_current_mode()
collate_fn = dataloader.collate_fn
if not fault_tolerant_mode.is_enabled or (
isinstance(collate_fn, partial) and collate_fn.func is _capture_metadata_collate
):
return
dataloader.collate_fn = partial(
_capture_metadata_collate,
dataset=dataloader.dataset,
collate_fn=dataloader.collate_fn,
fault_tolerant_mode=faut_tolerant_mode,
collate_fn=collate_fn,
fault_tolerant_mode=fault_tolerant_mode,
)


Expand Down Expand Up @@ -658,8 +659,7 @@ def _next_index(self) -> Any:
return indexes

def _prepare_loader(self, loader):
if not isinstance(loader.collate_fn, partial):
loader.collate_fn = partial(_capture_metadata_collate, dataset=loader.dataset, collate_fn=loader.collate_fn)
_add_capture_metadata_collate(loader)
self._loader = loader
self._data_fetcher: "pl.utilities.fetching.AbstractDataFetcher" = loader._lightning_fetcher
self.num_batches_fetched = 0
Expand Down Expand Up @@ -723,6 +723,8 @@ def _get_iterator(self) -> "_BaseDataLoaderIter":

def _patch_dataloader_get_iterators() -> None:
"""This function is used to replace the DataLoader iterator by their stateful version."""
if not _FaultTolerantMode.detect_current_mode().is_manual:
return
if not hasattr(DataLoader, "_ori_get_iterator"):
DataLoader._ori_get_iterator = DataLoader._get_iterator
DataLoader._get_iterator = _get_iterator
Expand Down
12 changes: 6 additions & 6 deletions pytorch_lightning/utilities/fetching.py
Expand Up @@ -16,7 +16,6 @@
from collections.abc import Iterable, Iterator
from contextlib import contextmanager
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Generator, List, Optional, Tuple

import torch
Expand All @@ -27,6 +26,8 @@
from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
from pytorch_lightning.utilities.auto_restart import (
_add_capture_metadata_collate,
_patch_dataloader_get_iterators,
_teardown_dataloader_get_iterators,
IteratorState,
MergedIteratorState,
patch_dataloader_iterator,
Expand Down Expand Up @@ -109,11 +110,7 @@ def _add_capture_metadata_collate(dataloader: Iterable) -> None:
if isinstance(dataloader, CombinedLoader):
dataloader = dataloader.loaders

def add_capture_metadata_collate(dataloader: DataLoader):
if not isinstance(dataloader.collate_fn, partial):
_add_capture_metadata_collate(dataloader)

apply_to_collection(dataloader, DataLoader, add_capture_metadata_collate)
apply_to_collection(dataloader, DataLoader, _add_capture_metadata_collate)

def append_batch(self, batch) -> None:
self.batches.append(batch)
Expand Down Expand Up @@ -206,6 +203,8 @@ def __iter__(self) -> Generator[Tuple[Any, bool], None, None]:
if self.dataloader is None:
raise MisconfigurationException("The iterate hasn't been provided. HINT: Did you call setup function ?.")
self.reset()
self._attach_data_fetcher()
_patch_dataloader_get_iterators()
self.dataloader_iter = iter(self.dataloader)
self._apply_patch()
self.prefetching(self.prefetch_batches)
Expand All @@ -226,6 +225,7 @@ def teardown(self) -> None:
if isinstance(self.dataloader, DataLoader):
CombinedLoader._shutdown_workers_and_reset_iterator(self.dataloader)
self.dataloader_iter = None
_teardown_dataloader_get_iterators()


class DataFetcher(AbstractDataFetcher):
Expand Down
161 changes: 160 additions & 1 deletion tests/utilities/test_auto_restart.py
Expand Up @@ -20,7 +20,7 @@
from contextlib import suppress
from copy import deepcopy
from dataclasses import asdict
from typing import List, Optional
from typing import Iterator, List, Optional
from unittest import mock
from unittest.mock import ANY

Expand Down Expand Up @@ -1317,3 +1317,162 @@ def test_stateful_workers(num_workers):
_reload_dataloader_state_dict(dataloader, asdict(reloaded_state))
assert dataloader.sampler.counter == dataloader.dataset.counter == 1
data_fetcher.teardown()


class RandomFaultTolerantDataset(RandomGetItemDataset):
def __init__(self, *args, seed: int, **kwargs):
super().__init__(*args, **kwargs)
self.seed = seed
self._cache_state_dict = None
self.generator = None
self.counter_debug = 0

@property
def worker_id(self):
info = get_worker_info()
return info.id if info else 0

def __getitem__(self, index):
if self._cache_state_dict:
state_dict = self._cache_state_dict[self.worker_id]
self.generator = random.Random()
self.generator.setstate(state_dict["random_state"])
self._cache_state_dict = None

if not self.generator:
self.generator = random.Random(self.seed + self.worker_id)
return torch.tensor(index + self.generator.random())

def state_dict(self):
return {self.worker_id: {"random_state": self.generator.getstate()}}

def load_state_dict(self, state_dict):
self._cache_state_dict = state_dict


class RandomFaultTolerantSampler(RandomSampler):
def __init__(self, *args, seed: int = 0, generator=None, **kwargs):
generator = torch.Generator().manual_seed(seed)
super().__init__(*args, generator=generator, **kwargs)
self.counter = 0
self.restarting = False

def state_dict(self):
return {"random_state": self.state, "counter": self.counter}

def load_state_dict(self, state_dict):
self.generator.set_state(state_dict.get("random_state"))
self.counter = state_dict["counter"]
self.restarting = True

def __len__(self):
return len(self.data_source) - self.counter

def __iter__(self) -> Iterator[int]:
n = len(self.data_source)

self.state = self.generator.get_state()
indices = torch.randperm(n, generator=self.generator).tolist()

if not self.restarting:
self.counter = 0
else:
indices = indices[self.counter :]
self.restarting = False

for index in indices:
self.counter += 1
yield index

self.counter = 0


@pytest.mark.parametrize(
["train_dataset_cls", "val_dataset_cls"],
[
([RandomFaultTolerantDataset, RandomFaultTolerantDataset], [RandomFaultTolerantDataset]),
],
)
@pytest.mark.parametrize("val_check_interval", [0.5])
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "2"})
def test_fault_tolerant_manual_mode(val_check_interval, train_dataset_cls, val_dataset_cls, tmpdir):
class TestModel(BoringModel):
def __init__(self, should_fail: bool = False):
super().__init__()
self.layer = torch.nn.Linear(1, 2)
self.should_fail = should_fail
self.batches = []

def training_step(self, batch, batch_idx):
if self.should_fail and batch_idx == 7:
raise CustomException
self.batches.append(batch)
losses = []
for b in batch:
losses.append(super().training_step(b, batch_idx)["loss"])
return torch.stack(losses).mean()

def validation_step(self, batch, batch_idx, dataloader_idx=0):
pass

validation_epoch_end = None

def _create_dataloader_kwargs(self, dataset_class, dataset_len, seed, num_workers):
dl_kwargs = {}
dl_kwargs["dataset"] = dataset_class(dataset_len, 1, seed=seed)
dl_kwargs["sampler"] = RandomFaultTolerantSampler(dl_kwargs["dataset"], seed=seed)
dl_kwargs["num_workers"] = num_workers
dl_kwargs["batch_size"] = 1
return dl_kwargs

def train_dataloader(self):
return [
DataLoader(
**self._create_dataloader_kwargs(
dataset_class, 10, seed, seed + 1 if val_check_interval == 1.0 else 0
)
)
for seed, dataset_class in enumerate(train_dataset_cls)
]

def val_dataloader(self):
return [
DataLoader(**self._create_dataloader_kwargs(dataset_class, 1, seed, 0))
for seed, dataset_class in enumerate(val_dataset_cls)
]

def configure_optimizers(self):
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.001)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
return [optimizer], [lr_scheduler]

seed_everything(42)
model = TestModel()
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=val_check_interval)
trainer.fit(model)
total_batches = model.batches
total_weight = deepcopy(model.layer.weight)
trainer.train_dataloader = None

seed_everything(42)
model = TestModel(should_fail=True)
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=val_check_interval)
with suppress(CustomException):
trainer.fit(model)
trainer.train_dataloader = None
failed_batches = model.batches
failed_weight = deepcopy(model.layer.weight)

checkpoint_path = str(tmpdir / ".pl_auto_save.ckpt")
assert os.path.exists(checkpoint_path)

seed_everything(42)
model = TestModel()
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=val_check_interval)
trainer.fit(model, ckpt_path=checkpoint_path)
trainer.train_dataloader = None
restart_batches = model.batches

torch.testing.assert_allclose(total_batches, failed_batches + restart_batches)
assert not torch.equal(total_weight, failed_weight)
assert torch.equal(total_weight, model.layer.weight)