diff --git a/README.md b/README.md index f0cd30db5d40c..0db819f9178e9 100644 --- a/README.md +++ b/README.md @@ -294,6 +294,7 @@ Lightning also adds a text column with all the hyperparameters for this experime #### Distributed training +- [Implement Your Own Distributed (DDP) training](https://williamfalcon.github.io/pytorch-lightning/Trainer/hooks/#init_ddp_connection) - [16-bit mixed precision](https://williamfalcon.github.io/pytorch-lightning/Trainer/Distributed%20training/#16-bit-mixed-precision) - [Multi-GPU](https://williamfalcon.github.io/pytorch-lightning/Trainer/Distributed%20training/#Multi-GPU) - [Multi-node](https://williamfalcon.github.io/pytorch-lightning/Trainer/Distributed%20training/#Multi-node) diff --git a/docs/LightningModule/RequiredTrainerInterface.md b/docs/LightningModule/RequiredTrainerInterface.md index 53bfc821910d4..7881dde4b057a 100644 --- a/docs/LightningModule/RequiredTrainerInterface.md +++ b/docs/LightningModule/RequiredTrainerInterface.md @@ -15,6 +15,7 @@ Otherwise, to Define a Lightning Module, implement the following methods: **Optional**: +- [training_end](RequiredTrainerInterface.md#training_end) - [validation_step](RequiredTrainerInterface.md#validation_step) - [validation_end](RequiredTrainerInterface.md#validation_end) - [test_step](RequiredTrainerInterface.md#test_step) @@ -178,6 +179,89 @@ def training_step(self, batch, batch_nb, hiddens): You can also return a -1 instead of a dict to stop the current loop. This is useful if you want to break out of the current training epoch early. +--- +### training_end + +``` {.python} +def training_end(self, train_step_outputs) +``` +In certain cases (dp, ddp2), you might want to use all outputs of every process to do something. +For instance, if using negative samples, you could run a batch via dp and use ALL the outputs +for a single softmax across the full batch (ie: the denominator would use the full batch). + +In this case you should define training_end to perform those calculations. + + +**Params** + +| Param | description | +|---|---| +| outputs | What you return in training_step. + +**Return** + +Dictionary or OrderedDict + +| key | value | is required | +|---|---|---| +| loss | tensor scalar | Y | +| progress_bar | Dict for progress bar display. Must have only tensors | N | +| log | Dict of metrics to add to logger. Must have only tensors (no images, etc) | N | + + +**Example** + +``` {.python} +# WITHOUT training_end +# if used in DP or DDP2, this batch is 1/nb_gpus large +def training_step(self, batch, batch_nb): + # batch is 1/nb_gpus big + x, y = batch + + out = self.forward(x) + loss = self.softmax(out) + loss = nce_loss(loss) + return {'loss': loss} + +# -------------- +# with training_end to do softmax over the full batch +def training_step(self, batch, batch_nb): + # batch is 1/nb_gpus big + x, y = batch + + out = self.forward(x) + return {'out': out} + +def training_end(self, outputs): + # this out is now the full size of the batch + out = outputs['out'] + + # this softmax now uses the full batch size + loss = self.softmax(out) + loss = nce_loss(loss) + return {'loss': loss} +``` + +If you define multiple optimizers, this step will also be called with an additional ```optimizer_idx``` param. 
+``` {.python} +# Multiple optimizers (ie: GANs) +def training_step(self, batch, batch_nb, optimizer_idx): + if optimizer_idx == 0: + # do training_step with encoder + if optimizer_idx == 1: + # do training_step with decoder +``` + +If you add truncated back propagation through time you will also get an additional argument with the hidden states of the previous step. +``` {.python} +# Truncated back-propagation through time +def training_step(self, batch, batch_nb, hiddens): + # hiddens are the hiddens from the previous truncated backprop step +``` + +You can also return a -1 instead of a dict to stop the current loop. This is useful if you want to +break out of the current training epoch early. + --- ### train_dataloader diff --git a/docs/Trainer/hooks.md b/docs/Trainer/hooks.md index 96c41e2ba991b..5f1baf84b6ebf 100644 --- a/docs/Trainer/hooks.md +++ b/docs/Trainer/hooks.md @@ -175,3 +175,92 @@ def tbptt_split_batch(self, batch, split_size): return splits ``` + +--- +#### configure_apex +Overwrite to define your own Apex implementation init. + +```python +def configure_apex(self, amp, model, optimizers, amp_level): + """ + Override to init AMP your own way + Must return a model and list of optimizers + :param amp: + :param model: + :param optimizers: + :param amp_level: + :return: Apex wrapped model and optimizers + """ + model, optimizers = amp.initialize( + model, optimizers, opt_level=amp_level, + ) + + return model, optimizers +``` + +--- +#### configure_ddp +Overwrite to define your own DDP implementation init. +The only requirement is that: +1. On a validation batch the call goes to model.validation_step. +2. On a training batch the call goes to model.training_step. +3. On a testing batch, the call goes to model.test_step + +```python +def configure_ddp(self, model, device_ids): + """ + Override to init DDP in a different way or use your own wrapper. + Must return model. + :param model: + :param device_ids: + :return: DDP wrapped model + """ + # Lightning DDP simply routes to test_step, val_step, etc... + model = LightningDistributedDataParallel( + model, + device_ids=device_ids, + find_unused_parameters=True + ) + return model +``` + +--- +#### init_ddp_connection +Override to init DDP in your own way. 
+ +```python +def init_ddp_connection(self): + """ + Connect all procs in the world using the env:// init + Use the first node as the root address + """ + + # use slurm job id for the port number + # guarantees unique ports across jobs from same grid search + try: + # use the last 4 numbers in the job id as the id + default_port = os.environ['SLURM_JOB_ID'] + default_port = default_port[-4:] + + # all ports should be in the 10k+ range + default_port = int(default_port) + 15000 + + except Exception as e: + default_port = 12910 + + # if user gave a port number, use that one instead + try: + default_port = os.environ['MASTER_PORT'] + except Exception: + os.environ['MASTER_PORT'] = str(default_port) + + # figure out the root node addr + try: + root_node = os.environ['SLURM_NODELIST'].split(' ')[0] + except Exception: + root_node = '127.0.0.2' + + root_node = self.trainer.resolve_root_node_address(root_node) + os.environ['MASTER_ADDR'] = root_node + dist.init_process_group('nccl', rank=self.proc_rank, world_size=self.world_size) +``` diff --git a/docs/Trainer/index.md b/docs/Trainer/index.md index 13df0539549ee..1c60786cfe2c0 100644 --- a/docs/Trainer/index.md +++ b/docs/Trainer/index.md @@ -42,6 +42,7 @@ But of course the fun is in all the advanced things it can do: **Distributed training** +- [Implement Your Own Distributed (DDP) training](https://williamfalcon.github.io/pytorch-lightning/Trainer/hooks/#init_ddp_connection) - [16-bit mixed precision](https://williamfalcon.github.io/pytorch-lightning/Trainer/Distributed%20training/#16-bit-mixed-precision) - [Multi-GPU](https://williamfalcon.github.io/pytorch-lightning/Trainer/Distributed%20training/#Multi-GPU) - [Multi-node](https://williamfalcon.github.io/pytorch-lightning/Trainer/Distributed%20training/#Multi-node) diff --git a/docs/index.md b/docs/index.md index 06a76fec231f8..b48b0d533f046 100644 --- a/docs/index.md +++ b/docs/index.md @@ -99,6 +99,7 @@ Notice a few things about this flow: ###### Distributed training +- [Implement Your Own Distributed (DDP) training](https://williamfalcon.github.io/pytorch-lightning/Trainer/hooks/#init_ddp_connection) - [16-bit mixed precision](https://williamfalcon.github.io/pytorch-lightning/Trainer/Distributed%20training/#16-bit-mixed-precision) - [Multi-GPU](https://williamfalcon.github.io/pytorch-lightning/Trainer/Distributed%20training/#Multi-GPU) - [Multi-node](https://williamfalcon.github.io/pytorch-lightning/Trainer/Distributed%20training/#Multi-node) diff --git a/pytorch_lightning/root_module/root_module.py b/pytorch_lightning/root_module/root_module.py index b25ab9e49443f..3c2f1efd0a4c8 100644 --- a/pytorch_lightning/root_module/root_module.py +++ b/pytorch_lightning/root_module/root_module.py @@ -1,8 +1,10 @@ +import os import warnings import collections from argparse import Namespace import torch +import torch.distributed as dist from pytorch_lightning.root_module.decorators import data_loader from pytorch_lightning.root_module.grads import GradInformation @@ -11,6 +13,7 @@ from pytorch_lightning.root_module.model_saving import ModelIO from pytorch_lightning.trainer.trainer_io import load_hparams_from_tags_csv import logging +from pytorch_lightning.pt_overrides.override_data_parallel import LightningDistributedDataParallel class LightningModule(GradInformation, ModelIO, ModelHooks): @@ -48,10 +51,19 @@ def training_step(self, *args, **kwargs): return loss, dict with metrics for tqdm :param called with batch, batch_nb additional: optimizer_i if multiple optimizers used - :return: + :return: 
dict with loss key and optional log, progress keys + if implementing training_step, return whatever you need in that step """ raise NotImplementedError + def training_end(self, *args, **kwargs): + """ + return loss, dict with metrics for tqdm + :param called with outputs of training_step + :return: dict with loss key and optional log, progress keys + """ + pass + def validation_step(self, *args, **kwargs): """ return whatever outputs will need to be aggregated in validation_end @@ -90,6 +102,72 @@ def test_end(self, outputs): """ pass + def configure_ddp(self, model, device_ids): + """ + Override to init DDP in a different way or use your own wrapper. + Must return model. + :param model: + :param device_ids: + :return: DDP wrapped model + """ + model = LightningDistributedDataParallel( + model, + device_ids=device_ids, + find_unused_parameters=True + ) + return model + + def init_ddp_connection(self, proc_rank, world_size): + """ + Connect all procs in the world using the env:// init + Use the first node as the root address + """ + + # use slurm job id for the port number + # guarantees unique ports across jobs from same grid search + try: + # use the last 4 numbers in the job id as the id + default_port = os.environ['SLURM_JOB_ID'] + default_port = default_port[-4:] + + # all ports should be in the 10k+ range + default_port = int(default_port) + 15000 + + except Exception as e: + default_port = 12910 + + # if user gave a port number, use that one instead + try: + default_port = os.environ['MASTER_PORT'] + except Exception: + os.environ['MASTER_PORT'] = str(default_port) + + # figure out the root node addr + try: + root_node = os.environ['SLURM_NODELIST'].split(' ')[0] + except Exception: + root_node = '127.0.0.2' + + root_node = self.trainer.resolve_root_node_address(root_node) + os.environ['MASTER_ADDR'] = root_node + dist.init_process_group('nccl', rank=proc_rank, world_size=world_size) + + def configure_apex(self, amp, model, optimizers, amp_level): + """ + Override to init AMP your own way + Must return a model and list of optimizers + :param amp: + :param model: + :param optimizers: + :param amp_level: + :return: Apex wrapped model and optimizers + """ + model, optimizers = amp.initialize( + model, optimizers, opt_level=amp_level, + ) + + return model, optimizers + def configure_optimizers(self): """ Return a list of optimizers and a list of schedulers (could be empty) diff --git a/pytorch_lightning/trainer/ddp_mixin.py b/pytorch_lightning/trainer/ddp_mixin.py index 9f6c1ad4199b5..0653fa58d2651 100644 --- a/pytorch_lightning/trainer/ddp_mixin.py +++ b/pytorch_lightning/trainer/ddp_mixin.py @@ -4,9 +4,7 @@ import logging import torch -import torch.distributed as dist -from pytorch_lightning.pt_overrides.override_data_parallel import LightningDistributedDataParallel from pytorch_lightning.utilities.debugging import MisconfigurationException try: @@ -145,7 +143,8 @@ def ddp_train(self, gpu_nb, model): # set up server using proc 0's ip address # try to init for 20 times at max in case ports are taken # where to store ip_table - self.__init_tcp_connection() + model.trainer = self + model.init_ddp_connection(self.proc_rank, self.world_size) # CHOOSE OPTIMIZER # allow for lr schedulers as well @@ -167,9 +166,7 @@ def ddp_train(self, gpu_nb, model): # run through amp wrapper before going to distributed DP if self.use_amp: # An example - model, optimizers = amp.initialize( - model, self.optimizers, opt_level=self.amp_level, - ) + model, optimizers = model.configure_apex(amp, model, 
self.optimizers, self.amp_level) self.optimizers = optimizers # DDP2 uses all GPUs on the machine @@ -178,53 +175,12 @@ def ddp_train(self, gpu_nb, model): elif self.use_ddp2: device_ids = None - model = LightningDistributedDataParallel( - model, - device_ids=device_ids, - find_unused_parameters=True - ) + # allow user to configure ddp + model = model.configure_ddp(model, device_ids) # continue training routine self.run_pretrain_routine(model) - def __init_tcp_connection(self): - """ - Connect all procs in the world using the env:// init - Use the first node as the root address - :param port: - :param tries: - :return: - """ - - # use slurm job id for the port number - # guarantees unique ports across jobs from same grid search - try: - # use the last 4 numbers in the job id as the id - default_port = os.environ['SLURM_JOB_ID'] - default_port = default_port[-4:] - - # all ports should be in the 10k+ range - default_port = int(default_port) + 15000 - - except Exception as e: - default_port = 12910 - - # if user gave a port number, use that one instead - try: - default_port = os.environ['MASTER_PORT'] - except Exception: - os.environ['MASTER_PORT'] = str(default_port) - - # figure out the root node addr - try: - root_node = os.environ['SLURM_NODELIST'].split(' ')[0] - except Exception: - root_node = '127.0.0.2' - - root_node = self.resolve_root_node_address(root_node) - os.environ['MASTER_ADDR'] = root_node - dist.init_process_group("nccl", rank=self.proc_rank, world_size=self.world_size) - def resolve_root_node_address(self, root_node): if '[' in root_node: name = root_node.split('[')[0] diff --git a/pytorch_lightning/trainer/dp_mixin.py b/pytorch_lightning/trainer/dp_mixin.py index 0bde8a7b1315a..684ff15c6989b 100644 --- a/pytorch_lightning/trainer/dp_mixin.py +++ b/pytorch_lightning/trainer/dp_mixin.py @@ -71,9 +71,7 @@ def single_gpu_train(self, model): if self.use_amp: # An example - model, optimizers = amp.initialize( - model, self.optimizers, opt_level=self.amp_level, - ) + model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level) self.optimizers = optimizers self.run_pretrain_routine(model) diff --git a/pytorch_lightning/trainer/logging_mixin.py b/pytorch_lightning/trainer/logging_mixin.py index 236e512d781ca..13e5afb6f6f07 100644 --- a/pytorch_lightning/trainer/logging_mixin.py +++ b/pytorch_lightning/trainer/logging_mixin.py @@ -156,6 +156,10 @@ def reduce_distributed_output(self, output, nb_gpus): if isinstance(output[k], dict): output[k] = self.reduce_distributed_output(output[k], nb_gpus) + # do nothing when there's a scalar + elif isinstance(output[k], torch.Tensor) and output[k].dim() == 0: + pass + # reduce only metrics that have the same nb of gpus elif output[k].size(0) == nb_gpus: reduced = torch.mean(output[k]) diff --git a/pytorch_lightning/trainer/train_loop_mixin.py b/pytorch_lightning/trainer/train_loop_mixin.py index d41fc7c16fee6..306416db8d7ed 100644 --- a/pytorch_lightning/trainer/train_loop_mixin.py +++ b/pytorch_lightning/trainer/train_loop_mixin.py @@ -189,13 +189,6 @@ def optimizer_closure(): callback_metrics = output[3] self.hiddens = output[4] - # track metrics for callbacks - all_callback_metrics.append(callback_metrics) - - # track progress bar metrics - self.add_tqdm_metrics(progress_bar_metrics) - all_log_metrics.append(log_metrics) - # accumulate loss # (if accumulate_grad_batches = 1 no effect) closure_loss = closure_loss / self.accumulate_grad_batches @@ -204,6 +197,13 @@ def optimizer_closure(): model_ref = 
self.get_model() model_ref.backward(self.use_amp, closure_loss, optimizer) + # track metrics for callbacks + all_callback_metrics.append(callback_metrics) + + # track progress bar metrics + self.add_tqdm_metrics(progress_bar_metrics) + all_log_metrics.append(log_metrics) + # insert after step hook if self.is_function_implemented('on_after_backward'): model_ref = self.get_model() @@ -277,13 +277,15 @@ def training_forward(self, batch, batch_nb, opt_idx, hiddens): if len(self.optimizers) > 1: args.append(opt_idx) + # pass hiddens if using tbptt if self.truncated_bptt_steps is not None: args.append(hiddens) - if self.use_ddp or self.use_ddp2: - output = self.model(*args) - elif self.use_dp: + # distributed forward + if self.use_ddp or self.use_ddp2 or self.use_dp: output = self.model(*args) + + # single GPU forward elif self.single_gpu: gpu_id = 0 if type(self.data_parallel_device_ids) is list: @@ -292,9 +294,16 @@ def training_forward(self, batch, batch_nb, opt_idx, hiddens): args[0] = batch output = self.model.training_step(*args) + # CPU forward else: output = self.model.training_step(*args) + # allow any mode to define training_end + if self.is_overriden('training_end'): + model_ref = self.get_model() + output = model_ref.training_end(output) + # format and reduce outputs accordingly output = self.process_output(output, train=True) + return output
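
Taken together, the hunks above mean `training_forward` now hands the (gathered) `training_step` output to `training_end` whenever that hook is overridden, in every mode (dp, ddp2, single GPU, CPU). Below is a minimal usage sketch of pairing the two hooks so that, under dp/ddp2, the loss denominator covers the full batch. It is not part of the patch: the model internals, the cross-entropy loss (standing in for the NCE example in the docs), the hyperparameters, and the import path are illustrative, and the dataloader hooks are omitted for brevity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# module path taken from the patch above; adjust to the installed version if needed
from pytorch_lightning.root_module.root_module import LightningModule


class FullBatchSoftmaxModel(LightningModule):
    """Pairs training_step with training_end so the loss sees the full batch in dp/ddp2."""

    def __init__(self):
        super(FullBatchSoftmaxModel, self).__init__()
        self.encoder = nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.encoder(x.view(x.size(0), -1))

    def training_step(self, batch, batch_nb):
        # under dp/ddp2 each GPU only sees 1/nb_gpus of the batch here
        x, y = batch
        return {'out': self.forward(x), 'y': y}

    def training_end(self, outputs):
        # by the time this runs, 'out' and 'y' have been gathered from all GPUs,
        # so the softmax denominator covers the full batch
        loss = F.cross_entropy(outputs['out'], outputs['y'])
        return {'loss': loss, 'progress_bar': {'train_loss': loss}}

    def configure_optimizers(self):
        return [torch.optim.Adam(self.parameters(), lr=0.02)]
```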
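
The patch also turns `init_ddp_connection`, `configure_ddp`, and `configure_apex` into overridable `LightningModule` hooks, since `ddp_train` and `single_gpu_train` now call them on the model. Continuing the sketch above, here is one way two of those hooks might be overridden. The fixed address, port, `gloo` backend, and explicit Apex loss scale are illustrative assumptions, not values taken from the patch; the hook signatures match the ones the trainer calls in the diff.

```python
import os

import torch.distributed as dist


class MySingleNodeModel(FullBatchSoftmaxModel):
    """Reuses the sketch above; only the distributed setup hooks change."""

    def init_ddp_connection(self, proc_rank, world_size):
        # skip the SLURM-derived defaults and rendezvous at a fixed address/port
        os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
        os.environ.setdefault('MASTER_PORT', '12910')
        dist.init_process_group('gloo', rank=proc_rank, world_size=world_size)

    def configure_apex(self, amp, model, optimizers, amp_level):
        # same call the default hook makes, with an explicit (illustrative) loss scale
        model, optimizers = amp.initialize(
            model, optimizers, opt_level=amp_level, loss_scale=128.0
        )
        return model, optimizers
```

Note that `configure_ddp` is left at its default here on purpose: the docs above require that the wrapped model still route validation and test batches to `validation_step` and `test_step`, which `LightningDistributedDataParallel` handles and a plain `DistributedDataParallel` wrapper would not.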