Commit 0e56774

Merge branch 'master' into ci/mypy
SeanNaren committed Dec 8, 2020
2 parents 2921087 + 127454a commit 0e56774
Showing 27 changed files with 771 additions and 62 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## Unreleased

### Added

- Added `all_gather` method to `LightningModule` which allows gradient based tensor synchronizations for use-cases such as negative sampling. ([#5012](https://github.com/PyTorchLightning/pytorch-lightning/pull/5012))

### Fixed

- Fixed `LoggerConnector` to have logged metrics on root device in DP ([#4138](https://github.com/PyTorchLightning/pytorch-lightning/pull/4138))
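The changelog entry above only names the new API. As a rough, non-authoritative illustration of the use-case it mentions, the sketch below shows how a `LightningModule` might call `all_gather(..., sync_grads=True)` for in-batch negative sampling; the model, shapes, and loss are invented for the example, and only the method name, the `sync_grads` flag, and the `(world_size, batch, ...)` return shape come from this commit.

```python
import torch
import torch.nn.functional as F
import pytorch_lightning as pl


class NegativeSamplingModel(pl.LightningModule):
    """Hypothetical model using the new ``all_gather`` hook."""

    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Linear(128, 64)

    def training_step(self, batch, batch_idx):
        anchors, positives = batch                      # each (batch, 128)
        z_anchor = self.encoder(anchors)                # (batch, 64)
        z_pos = self.encoder(positives)                 # (batch, 64)

        # Gather positives from every process so each rank can treat the other
        # ranks' samples as negatives; sync_grads=True keeps the op in the
        # autograd graph instead of detaching the gathered tensors.
        all_pos = self.all_gather(z_pos, sync_grads=True)  # (world_size, batch, 64)
        all_pos = all_pos.reshape(-1, all_pos.shape[-1])    # (world_size * batch, 64)

        logits = z_anchor @ all_pos.t()                      # (batch, world_size * batch)
        # The positive for local sample i sits at index global_rank * batch + i.
        labels = torch.arange(z_anchor.size(0), device=logits.device)
        labels = labels + self.global_rank * z_anchor.size(0)
        return F.cross_entropy(logits, labels)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```
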
1 change: 1 addition & 0 deletions benchmarks/test_sharded_parity.py
@@ -161,6 +161,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir):
gpus=2,
accelerator='ddp_spawn',
model_cls=SeedTrainLoaderManualModel,
max_percent_speed_diff=0.25  # increase the allowed speed diff since only 2 GPUs are sharding 2 optimizers
)


31 changes: 27 additions & 4 deletions pytorch_lightning/accelerators/accelerator.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager
from enum import Enum
from typing import Any, Optional, Union
@@ -21,10 +20,8 @@
from torch.optim import Optimizer

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict

if torch.distributed.is_available():
@@ -175,6 +172,20 @@ def sync_tensor(self,
"""
raise NotImplementedError()

def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
"""
Function to gather a tensor from several distributed processes
Args:
tensor: tensor of shape (batch, ...)
group: the process group to gather results from. Defaults to all processes (world)
sync_grads: flag that allows users to synchronize gradients for all_gather op
Return:
A tensor of shape (world_size, batch, ...)
"""
raise NotImplementedError()

def optimizer_state(self, optimizer: Optimizer) -> dict:
"""
Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom
@@ -222,6 +233,18 @@ def __setstate__(self, d):
def on_save(self, checkpoint):
return checkpoint

@property
def rpc_enabled(self):
return self.ddp_plugin is not None and isinstance(self.ddp_plugin, RPCPlugin)

@property
def distributed_sampler_kwargs(self):
raise NotImplementedError

@property
def require_distributed_sampler(self):
raise NotImplementedError

@contextmanager
def block_ddp_plugin_sync_behaviour(self):
"""
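Taken together, the accelerator.py hunks above make `all_gather`, `distributed_sampler_kwargs`, and `require_distributed_sampler` part of the base `Accelerator` contract. A minimal sketch of a custom accelerator filling in those hooks might look like the following; the subclass and its return values are hypothetical, only the hook names and the documented return shape come from the diff.

```python
from typing import Any, Optional

import torch

from pytorch_lightning.accelerators.accelerator import Accelerator


class MySingleProcessAccelerator(Accelerator):
    """Hypothetical accelerator showing the hooks added in this commit."""

    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False):
        # With a single process, "gathering" just adds the leading world_size
        # dimension promised by the base-class docstring.
        return tensor.unsqueeze(0)

    @property
    def distributed_sampler_kwargs(self):
        # Non-distributed fallback; the DDP accelerators below return real
        # num_replicas/rank values derived from the trainer.
        return dict(num_replicas=1, rank=0)

    @property
    def require_distributed_sampler(self):
        # Tells the trainer whether its dataloaders need a DistributedSampler
        # (False here, True for the DDP/DDP2 accelerators in this commit).
        return False
```
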
4 changes: 4 additions & 0 deletions pytorch_lightning/accelerators/cpu_accelerator.py
@@ -90,3 +90,7 @@ def sync_tensor(self,
group: Optional[Any] = None,
reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
return tensor

@property
def require_distributed_sampler(self):
return False
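
cpu_accelerator.py simply opts out of distributed sampling. How the trainer consumes that flag is not part of this diff; the helper below is a hedged guess at the shape of that logic, with `trainer`, the dataloader rebuilding, and the function itself being hypothetical — only `require_distributed_sampler` and `distributed_sampler_kwargs` come from the commit.

```python
from torch.utils.data import DataLoader, DistributedSampler


def maybe_add_distributed_sampler(trainer, dataloader: DataLoader) -> DataLoader:
    # CPUAccelerator.require_distributed_sampler is False, so its dataloaders
    # are returned untouched; the DDP/DDP2 accelerators return True and supply
    # num_replicas/rank through distributed_sampler_kwargs.
    accelerator = trainer.accelerator_backend
    if not accelerator.require_distributed_sampler:
        return dataloader

    sampler = DistributedSampler(dataloader.dataset, **accelerator.distributed_sampler_kwargs)
    return DataLoader(
        dataloader.dataset,
        batch_size=dataloader.batch_size,
        sampler=sampler,
        num_workers=dataloader.num_workers,
    )
```
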
53 changes: 49 additions & 4 deletions pytorch_lightning/accelerators/ddp2_accelerator.py
@@ -23,8 +23,9 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available
from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available, all_gather_ddp_if_available

if HYDRA_AVAILABLE:
from hydra.core.hydra_config import HydraConfig
@@ -101,9 +102,11 @@ def set_world_ranks(self, process_idx):
def broadcast(self, obj, src=0):
return self.dist.broadcast(obj)

def model_to_device(self, model, process_idx):
def init_device(self, process_idx):
self.trainer.root_gpu = process_idx
torch.cuda.set_device(self.trainer.root_gpu)

def model_to_device(self, model):
model.cuda(self.trainer.root_gpu)

def get_device_ids(self):
@@ -133,6 +136,9 @@ def ddp_train(self, process_idx, mp_queue, model):
# set warning rank
rank_zero_only.rank = self.trainer.global_rank

# Initialize cuda device
self.init_device(process_idx)

# set up server using proc 0's ip address
# try to init for 20 times at max in case ports are taken
# where to store ip_table
@@ -143,6 +149,15 @@
self.trainer.is_slurm_managing_tasks
)

if isinstance(self.ddp_plugin, RPCPlugin):
if not self.ddp_plugin.is_main_rpc_process:
self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer)
self.ddp_plugin.exit_rpc_process()
if self.ddp_plugin.return_after_exit_rpc_process:
return
else:
self.ddp_plugin.on_main_rpc_connection(self.trainer)

# call setup after the ddp process has connected
self.trainer.call_setup_hook(model)

@@ -158,12 +173,14 @@ def ddp_train(self, process_idx, mp_queue, model):
model = self.configure_sync_batchnorm(model)

# move the model to the correct device
self.model_to_device(model, process_idx)
self.model_to_device(model)

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
self.setup_optimizers(model)

self.ddp_plugin.on_after_setup_optimizers(self.trainer)

# set model properties before going into wrapper
self.trainer.model_connector.copy_trainer_model_properties(model)

@@ -189,7 +206,7 @@ def ddp_train(self, process_idx, mp_queue, model):
return results

def configure_ddp(
self, model: LightningModule, device_ids: List[int]
self, model: LightningModule, device_ids: List[int]
) -> DistributedDataParallel:
model = self.ddp_plugin.configure_ddp(model, device_ids)
return model
@@ -217,5 +234,33 @@ def sync_tensor(self,
reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
return sync_ddp_if_available(tensor, group, reduce_op)

def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
"""
Function to gather a tensor from several distributed processes
Args:
tensor: tensor of shape (batch, ...)
group: the process group to gather results from. Defaults to all processes (world)
sync_grads: flag that allows users to synchronize gradients for all_gather op
Return:
A tensor of shape (world_size, batch, ...)
"""
return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)

def get_reference_model(self, model) -> LightningModule:
return self.ddp_plugin.get_model_from_plugin(model)

@property
def distributed_sampler_kwargs(self):
distributed_sampler_kwargs = dict(
num_replicas=self.trainer.num_nodes,
rank=self.trainer.global_rank
)
if self.ddp_plugin is not None:
distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs)
return distributed_sampler_kwargs

@property
def require_distributed_sampler(self):
return True
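
Both DDP accelerators now delegate `all_gather` to `all_gather_ddp_if_available` with a `sync_grads` switch, but that helper's body is outside this diff. The sketch below is a non-authoritative illustration of how such a gradient-preserving all-gather is typically built with plain `torch.distributed` primitives (an autograd Function that all-gathers in forward and reduce-sums in backward); it is not necessarily Lightning's exact implementation.

```python
import torch
import torch.distributed as dist


class _AllGatherGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, group):
        ctx.group = group
        gathered = [torch.zeros_like(tensor) for _ in range(dist.get_world_size(group))]
        dist.all_gather(gathered, tensor, group=group)
        return torch.stack(gathered, dim=0)  # (world_size, batch, ...)

    @staticmethod
    def backward(ctx, grad_output):
        # Sum the gradient slices contributed by every process, then return the
        # slice that belongs to this rank's input tensor.
        grad_output = grad_output.contiguous()
        dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=ctx.group)
        return grad_output[dist.get_rank(ctx.group)], None


def all_gather_with_grads(tensor, group=None, sync_grads=False):
    group = group if group is not None else dist.group.WORLD
    if sync_grads:
        return _AllGatherGrad.apply(tensor, group)
    # Without sync_grads, gather outside the autograd graph.
    with torch.no_grad():
        return _AllGatherGrad.apply(tensor, group)
```
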
61 changes: 54 additions & 7 deletions pytorch_lightning/accelerators/ddp_accelerator.py
@@ -27,7 +27,9 @@
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
from pytorch_lightning.utilities.distributed import find_free_network_port, rank_zero_only, sync_ddp_if_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything
@@ -162,8 +164,11 @@ def _step(self, args):
return output

def barrier(self, name: Optional[str] = None):
if torch_distrib.is_initialized():
torch_distrib.barrier()
if self.rpc_enabled:
# Allow RPC to handle barrier on main RPC processes
self.ddp_plugin.barrier()
elif torch_distrib.is_initialized():
torch_distrib.barrier(group=self.ddp_plugin.data_parallel_group)

def _check_can_spawn_children(self):
if self._has_spawned_children:
@@ -177,9 +182,11 @@ def set_world_ranks(self, process_idx):
self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes

def model_to_device(self, model, process_idx):
def init_device(self, process_idx):
self.trainer.root_gpu = self.trainer.data_parallel_device_ids[self.trainer.local_rank]
torch.cuda.set_device(self.trainer.root_gpu)

def model_to_device(self, model):
model.cuda(self.trainer.root_gpu)

def get_device_ids(self):
@@ -192,12 +199,12 @@ def on_train_end(self):
def early_stopping_should_stop(self, pl_module):
stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
torch_distrib.all_reduce(stop, op=torch_distrib.reduce_op.SUM)
torch_distrib.barrier()
self.barrier('early_stopping')
should_stop = stop == self.trainer.world_size
return should_stop

def broadcast(self, obj, src=0):
return self.dist.broadcast(obj)
return self.dist.broadcast(obj, group=self.ddp_plugin.data_parallel_group)

def ddp_train(self, process_idx, model):
"""
@@ -226,6 +233,9 @@ def ddp_train(self, process_idx, model):
# set warning rank
rank_zero_only.rank = self.trainer.global_rank

# Initialize cuda device
self.init_device(process_idx)

# set up server using proc 0's ip address
# try to init for 20 times at max in case ports are taken
# where to store ip_table
@@ -236,6 +246,15 @@
self.trainer.is_slurm_managing_tasks
)

if isinstance(self.ddp_plugin, RPCPlugin):
if not self.ddp_plugin.is_main_rpc_process:
self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer)
self.ddp_plugin.exit_rpc_process()
if self.ddp_plugin.return_after_exit_rpc_process:
return
else:
self.ddp_plugin.on_main_rpc_connection(self.trainer)

# call setup after the ddp process has connected
self.trainer.call_setup_hook(model)

@@ -251,7 +270,7 @@ def ddp_train(self, process_idx, model):
model = self.configure_sync_batchnorm(model)

# move the model to the correct device
self.model_to_device(model, process_idx)
self.model_to_device(model)

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
@@ -284,7 +303,7 @@ def ddp_train(self, process_idx, model):
return results

def configure_ddp(
self, model: LightningModule, device_ids: List[int]
self, model: LightningModule, device_ids: List[int]
) -> DistributedDataParallel:
model = self.ddp_plugin.configure_ddp(model, device_ids)
return model
@@ -315,5 +334,33 @@ def sync_tensor(self,
"""
return sync_ddp_if_available(tensor, group, reduce_op)

def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
"""
Function to gather a tensor from several distributed processes
Args:
tensor: tensor of shape (batch, ...)
group: the process group to gather results from. Defaults to all processes (world)
sync_grads: flag that allows users to synchronize gradients for all_gather op
Return:
A tensor of shape (world_size, batch, ...)
"""
return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)

def get_reference_model(self, model) -> LightningModule:
return self.ddp_plugin.get_model_from_plugin(model)

@property
def distributed_sampler_kwargs(self):
distributed_sampler_kwargs = dict(
num_replicas=self.trainer.num_nodes * self.trainer.num_processes,
rank=self.trainer.global_rank
)
if self.ddp_plugin is not None:
distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs)
return distributed_sampler_kwargs

@property
def require_distributed_sampler(self):
return True
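
The main substantive difference from the DDP2 version above is the sampler math: DDP shards data across every process, DDP2 across nodes. A quick numeric check under an assumed 2-node, 4-GPU-per-node job (the numbers and the `global_rank` placeholder are illustrative, not from the commit):

```python
# Hypothetical job: 2 nodes, 4 processes (GPUs) per node.
num_nodes, num_processes = 2, 4
global_rank = 0  # would be trainer.global_rank in the real properties

# DDPAccelerator.distributed_sampler_kwargs: one sampler replica per process.
ddp_kwargs = dict(num_replicas=num_nodes * num_processes, rank=global_rank)

# DDP2Accelerator.distributed_sampler_kwargs: one replica per node, since each
# node's DDP2 process consumes the whole per-node batch.
ddp2_kwargs = dict(num_replicas=num_nodes, rank=global_rank)

assert ddp_kwargs["num_replicas"] == 8
assert ddp2_kwargs["num_replicas"] == 2
```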