diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index 2dbf46eb14b12..4076741008ab5 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -105,7 +105,7 @@ def __init__(self,
         :param log_save_interval:
         :param add_log_row_interval:
         :param distributed_backend:
-            'np' to use DistributedParallel, 'dp' to use DistributedDataParallel
+            'dp' to use DataParallel, 'ddp' to use DistributedDataParallel, None to use neither
         :param use_amp:
         :param print_nan_grads:
         :param print_weights_summary:
@@ -147,6 +147,7 @@ def __init__(self,
         self.node_rank = 0
         self.use_ddp = False
         self.use_dp = False
+        self.single_gpu = False
 
         # training bookeeping
         self.total_batch_nb = 0
@@ -194,6 +195,12 @@ def __init__(self,
                 'To silence this warning set distributed_backend=ddp'
             warnings.warn(w)
 
+        # remove dp and ddp when requesting single gpu
+        if self.data_parallel_device_ids is not None and len(self.data_parallel_device_ids) == 1:
+            self.use_ddp = False
+            self.use_dp = False
+            self.single_gpu = True
+
         # extract SLURM flag vars
         # whenever we have the correct number of tasks, we let slurm manage processes
         # otherwise we launch the required number of processes
@@ -385,6 +392,13 @@ def validate(self, model, dataloader, max_batches):
                 output = model(data_batch, batch_i)
                 output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
 
+            elif self.single_gpu:
+                gpu_id = self.data_parallel_device_ids[0]
+                for i, x in enumerate(data_batch):
+                    if isinstance(x, torch.Tensor):
+                        data_batch[i] = x.cuda(gpu_id)
+                output = model.validation_step(data_batch, batch_i)
+
             else:
                 output = model.validation_step(data_batch, batch_i)
 
@@ -463,6 +477,9 @@ def fit(self, model):
         elif self.use_dp:
             self.__dp_train(model)
 
+        elif self.single_gpu:
+            self.__single_gpu_train(model)
+
         # ON CPU
         else:
             # run through amp wrapper
@@ -482,6 +499,24 @@ def fit(self, model):
         # used for testing or when we need to know that training succeeded
         return 1
 
+    def __single_gpu_train(self, model):
+        # CHOOSE OPTIMIZER
+        # allow for lr schedulers as well
+        self.optimizers = model.configure_optimizers()
+        if len(self.optimizers) == 2:
+            self.optimizers, self.lr_schedulers = self.optimizers
+
+        model.cuda(self.data_parallel_device_ids[0])
+
+        if self.use_amp:
+            # wrap the model and optimizers for mixed-precision training
+            model, optimizers = amp.initialize(
+                model, self.optimizers, opt_level=self.amp_level,
+            )
+            self.optimizers = optimizers
+
+        self.__run_pretrain_routine(model)
+
     def __dp_train(self, model):
 
         # CHOOSE OPTIMIZER
@@ -814,6 +849,13 @@ def __run_tng_batch(self, data_batch, batch_nb):
         elif self.use_dp:
             output = self.model(data_batch, batch_nb)
             output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
+        elif self.single_gpu:
+            gpu_id = self.data_parallel_device_ids[0]
+            for i, x in enumerate(data_batch):
+                if isinstance(x, torch.Tensor):
+                    data_batch[i] = x.cuda(gpu_id)
+            output = self.model.training_step(data_batch, batch_nb)
+
         else:
             output = self.model.training_step(data_batch, batch_nb)
 
diff --git a/tests/test_models.py b/tests/test_models.py
index ce1697a91e185..eeab97f00d59a 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -27,6 +27,34 @@
 # TESTS
 # ------------------------------------------------------------------------
 
+def test_amp_single_gpu():
+    """
+    Make sure single GPU + AMP work
+    :return:
+    """
+    if not torch.cuda.is_available():
+        warnings.warn('test_amp_single_gpu cannot run. '
+                      'Rerun on a GPU node to run this test')
+        return
+    if not torch.cuda.device_count() > 1:
+        warnings.warn('test_amp_single_gpu cannot run. '
+                      'Rerun on a node with 2+ GPUs to run this test')
+        return
+
+    hparams = get_hparams()
+    model = LightningTestModel(hparams)
+
+    trainer_options = dict(
+        progress_bar=True,
+        max_nb_epochs=1,
+        gpus=[0],
+        distributed_backend='dp',
+        use_amp=True
+    )
+
+    run_gpu_model_test(trainer_options, model, hparams)
+
+
 def test_cpu_restore_training():
     """
     Verify continue training session on CPU