[Feat] Add FastForwardSampler 2/n - Fault Tolerant Training #8307

Merged
merged 38 commits into master from add_fast_forward_sampler
Jul 7, 2021
Changes from 10 commits
Commits
38 commits
c6a774c
wip
tchaton Jul 6, 2021
a946159
update
tchaton Jul 6, 2021
a2b74f0
resolve bug
tchaton Jul 6, 2021
668b02e
wip
tchaton Jul 6, 2021
3827208
wip
tchaton Jul 6, 2021
1f4ef8c
wip
tchaton Jul 6, 2021
770a78b
resolved tests
tchaton Jul 6, 2021
76f1f53
update on comments
tchaton Jul 7, 2021
3aaf0ea
update
tchaton Jul 7, 2021
ed056aa
update
tchaton Jul 7, 2021
f1cdcdc
Merge branch 'master' into add_fast_forward_sampler
tchaton Jul 7, 2021
81bf954
Update pytorch_lightning/utilities/auto_restart.py
tchaton Jul 7, 2021
7a05094
update on comments
tchaton Jul 7, 2021
bff288c
Merge branch 'add_fast_forward_sampler' of https://github.com/PyTorch…
tchaton Jul 7, 2021
98ec265
Update pytorch_lightning/utilities/auto_restart.py
tchaton Jul 7, 2021
82b1cf1
resolve bug
tchaton Jul 7, 2021
8972d82
update
tchaton Jul 7, 2021
1fb8c02
move properties to top
awaelchli Jul 7, 2021
f086edb
update docs for fast forward sampler
awaelchli Jul 7, 2021
7450388
move public attribute to top
awaelchli Jul 7, 2021
5e43757
add missing super call
awaelchli Jul 7, 2021
eae11c3
update docs for state_dict
awaelchli Jul 7, 2021
efcb882
fix merge conflict
awaelchli Jul 7, 2021
c068704
add missing super() call
awaelchli Jul 7, 2021
79ff550
move property to top
awaelchli Jul 7, 2021
50ac617
update on comments
tchaton Jul 7, 2021
733e329
Merge branch 'add_fast_forward_sampler' of https://github.com/PyTorch…
tchaton Jul 7, 2021
67a3691
update
tchaton Jul 7, 2021
4eee70a
resolve bug
tchaton Jul 7, 2021
8b93505
update
tchaton Jul 7, 2021
028d773
update on comments
tchaton Jul 7, 2021
5c3e328
activate coverage for CaptureIterableDataset
tchaton Jul 7, 2021
461bee9
update on comments
tchaton Jul 7, 2021
2de5290
update
tchaton Jul 7, 2021
613ae7d
update
tchaton Jul 7, 2021
9e1aa51
update
tchaton Jul 7, 2021
fd7ea17
resole bug
tchaton Jul 7, 2021
6daff95
update
tchaton Jul 7, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -134,6 +134,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `restore` function and `restarting` attribute to base `Loop` ([#8247](https://github.com/PyTorchLightning/pytorch-lightning/pull/8247))


- Added `FastForwardSampler` and `CaptureIterativeDataset` ([#8307](https://github.com/PyTorchLightning/pytorch-lightning/pull/8307))


### Changed


231 changes: 231 additions & 0 deletions pytorch_lightning/utilities/auto_restart.py
@@ -0,0 +1,231 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Generator, Iterator, List, Optional, Union

from torch.utils.data import BatchSampler, get_worker_info, Sampler
from torch.utils.data.dataloader import IterableDataset

from pytorch_lightning.utilities.enums import AutoRestartBatchKeys


class FastForwardSampler(Sampler):
"""
This class is used to wrap a :class:`torch.utils.data.Sampler` and record the number
of iterations performed during an epoch.
On reload, if a ``state_dict`` is provided, it will be used to fast-forward the wrapped sampler.

"""

def __init__(self, sampler: Union[Sampler, BatchSampler, Generator]) -> None:
self._sampler = sampler
self._current_iteration = 0
self._dataloader_batch_size: Optional[int] = None
self.restarting: bool = False
self._cached_state_dict: Optional[Dict[str, Any]] = None

def setup(self, dataloader_batch_size: Optional[int] = None) -> None:
"""
Set up the ``FastForwardSampler``.
This is required only when the provided dataset subclasses :class:`torch.utils.data.Dataset`.
"""
self._dataloader_batch_size = dataloader_batch_size

@property
def worker_id(self) -> int:
worker_info = get_worker_info()
return worker_info.id if worker_info else 0

def __iter__(self) -> Iterator[List[int]]:
# setup local counter
restart_counter = 0

# iterate over the wrapped sampler
for batch in self._sampler:

# if we are restarting, fast-forward past the already processed batches
if self.restarting:

# the state dict is cached until workers are made available through the DataLoader
if self._cached_state_dict is not None and self.worker_id in self._cached_state_dict:

# reload the current state dict
self.load_state_dict(self._cached_state_dict, workers_initialized=True)
self._cached_state_dict = None

# increment counter
restart_counter += 1

# if restart counter matches the current iteration, we should stop restarting
if restart_counter == self._current_iteration:
self.restarting = False
else:
self._current_iteration += 1

# yield the batch
yield batch

self.reset()

def reset(self) -> None:
self._current_iteration = 0

def __len__(self) -> int:
return len(self._sampler)

@property
def drop_last(self) -> bool:
return self._sampler.drop_last

@property
def batch_size(self) -> int:
return self._sampler.batch_size

@property
def sampler(self) -> Sampler:
return self._sampler

@property
def batch_indices(self) -> Optional[List[int]]:
return self._sampler.batch_indices

def _compute_current_iteration(self, number_batch_processed: Optional[int] = None) -> int:
"""
This function is used to compute the effective iteration.
As the DataLoader can perform ``prefetching`` or training can fail while processing a batch,
the current iteration needs to be computed from the ``number_batch_processed`` information.
"""
if number_batch_processed is not None:
current_iteration = number_batch_processed
else:
current_iteration = self._current_iteration

if self._dataloader_batch_size:
current_iteration *= self._dataloader_batch_size

return current_iteration

def state_dict(self, number_batch_processed: Optional[int] = None) -> Dict[int, Dict[str, int]]:
return {self.worker_id: {"current_iteration": self._compute_current_iteration(number_batch_processed)}}

def load_state_dict(self, state_dict: Dict[str, Any], workers_initialized: bool = False) -> None:
# if the state dict contains multiple states, it means there were multiple workers.
# as the workers aren't available yet, the ``state_dict`` is cached until they are made available.
if len(state_dict) > 1 and not workers_initialized:
self._cached_state_dict = state_dict
self.restarting = self._cached_state_dict[self.worker_id]["current_iteration"] > 0
return
self._current_iteration = state_dict[self.worker_id]["current_iteration"]
self.restarting = self._current_iteration > 0


class CaptureIterativeDataset(IterableDataset):
"""
The ``CaptureIterativeDataset`` is used to wrap an :class:`torch.utils.data.IterableDataset`.
On ``__iter__``, the ``CaptureIterativeDataset`` wraps the generators of the wrapped dataset
into ``FastForwardSampler`` samplers to keep track of progress.
On ``__next__``, the ``CaptureIterativeDataset`` returns a dictionary containing the
user data and metadata with the ``state_dict`` of the ``FastForwardSampler`` samplers.
"""

def __init__(self, dataset: IterableDataset, initial_seed: Optional[int] = None):
self.dataset = dataset
self.state_dict: Optional[Dict[int, Any]] = None
self.initial_seed = initial_seed
self.samplers: Optional[Dict[str, FastForwardSampler]] = None

def load_state_dict(self, state_dict: Dict[int, Any]) -> None:
self.state_dict = state_dict

def _wrap_generator_samplers(self) -> None:
if self.samplers is None:
self.samplers = {}

# access wrapped dataset attributes
dataset_dict = self.dataset.__dict__

# create a tuple of sampler names
samplers_names = tuple(v.__class__.__name__ for k, v in dataset_dict.items() if isinstance(v, Sampler))

# create a dictionary of the generators present within the dataset attributes
dataset_sampler_generators = {k: v for k, v in dataset_dict.items() if isinstance(v, Generator)}

# iterate over the generators. If a generator was created from a ``Sampler``,
# it will be wrapped into a ``FastForwardSampler``.
for (generator_attr_name, generator) in dataset_sampler_generators.items():

# Generator names have the form `SamplerName.__iter__`
generator_name = generator.__qualname__.split('.')[0]

# validate the base generator name matches a sampler name.
if any(sampler_name == generator_name for sampler_name in samplers_names):

# wrap the generator into a ``FastForwardSampler``
sampler = FastForwardSampler(generator)

# if a state dict was provided to the ``CaptureIterativeDataset``, the sampler should reload its own state.
if self.state_dict is not None:
sampler.load_state_dict(self.state_dict[generator_attr_name])

# store the samplers
self.samplers[generator_attr_name] = sampler

# replace the generator with the one from the ``FastForwardSampler``.
dataset_dict[generator_attr_name] = iter(sampler)

# reset state dict.
self.state_dict = None

def reset_on_epoch(self) -> None:
self.state_dict = None

@property
def sampler(self) -> Sampler:
return self.dataset.sampler

def __iter__(self) -> Iterator:
# create an iterator from the wrapped iterable dataset;
# if the dataset contains samplers, they will have been transformed into generators
self.iter_data = iter(self.dataset)

# wrap any generator associated with a Sampler into a ``FastForwardSampler``.
self._wrap_generator_samplers()
return self

def __next__(self):
# fetch next data
data = next(self.iter_data)

# create current samplers state_dict
worker_info = get_worker_info()
state_dicts = {"id": worker_info.id if worker_info is not None else 0}
for k, v in self.samplers.items():
state_dicts.update({k: v.state_dict()})

# return both current data and samplers ``state_dict``.
return {"data": data, AutoRestartBatchKeys.PL_SAMPLERS: state_dicts}

@staticmethod
def convert_batch_into_state_dict(batch) -> Dict[str, Dict[int, Any]]:
"""
This function is used to convert a batch into a ``state_dict``.
"""
return {
k: {
batch[AutoRestartBatchKeys.PL_SAMPLERS]["id"][-1].item(): {
"current_iteration": v[batch[AutoRestartBatchKeys.PL_SAMPLERS]["id"][-1].item()]
["current_iteration"][-1].item(),
}
}
for k, v in batch[AutoRestartBatchKeys.PL_SAMPLERS].items() if k != "id"
}
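To make the intended save/restore flow concrete, here is a minimal usage sketch (not part of this diff; the `range(10)` data and batch size are illustrative). It wraps a standard `BatchSampler` in the `FastForwardSampler` added above, captures its state after two batches, and resumes from a fresh instance:

from torch.utils.data import BatchSampler, SequentialSampler

from pytorch_lightning.utilities.auto_restart import FastForwardSampler

# wrap a BatchSampler over 10 indices with batches of size 2
sampler = FastForwardSampler(BatchSampler(SequentialSampler(range(10)), batch_size=2, drop_last=False))

iterator = iter(sampler)
assert next(iterator) == [0, 1]
assert next(iterator) == [2, 3]

# capture the progress after two batches, e.g. right before a simulated failure
state = sampler.state_dict()  # {0: {"current_iteration": 2}} in the main process

# on restart, a fresh wrapper reloads the state and fast-forwards past the first two batches
restored = FastForwardSampler(BatchSampler(SequentialSampler(range(10)), batch_size=2, drop_last=False))
restored.load_state_dict(state)
assert next(iter(restored)) == [4, 5]

Note that the skipped batches are consumed silently rather than re-yielded: `__iter__` only yields again once `restarting` is back to `False`.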
8 changes: 8 additions & 0 deletions pytorch_lightning/utilities/enums.py
@@ -111,3 +111,11 @@ class GradClipAlgorithmType(LightningEnum):
"""
VALUE = 'value'
NORM = 'norm'


class AutoRestartBatchKeys(LightningEnum):
"""
Defines special dictionary keys used to track sampler progress with multiple workers.
"""

PL_SAMPLERS = "__pl_samplers__"
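For context, here is an illustrative sketch (the attribute name `sampler_iter` and the values are hypothetical) of the per-sample metadata that `CaptureIterativeDataset.__next__` attaches under this key. After the DataLoader collates a batch of such samples, `convert_batch_into_state_dict` reads the last element of each collated field to rebuild a per-worker sampler `state_dict`:

from pytorch_lightning.utilities.enums import AutoRestartBatchKeys

# shape of one sample emitted by CaptureIterativeDataset.__next__, assuming the wrapped
# dataset exposes a single generator attribute named "sampler_iter" and the sample is
# produced in the main process (worker id 0)
sample = {
    "data": ...,  # whatever the wrapped IterableDataset yields
    AutoRestartBatchKeys.PL_SAMPLERS: {
        "id": 0,  # worker id, 0 when get_worker_info() returns None
        "sampler_iter": {0: {"current_iteration": 3}},  # FastForwardSampler.state_dict()
    },
}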