From afa4548b12ada2d69cd3452d902789efa2e8e165 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Wed, 7 Aug 2019 13:39:40 -0400
Subject: [PATCH 1/8] added single gpu train

---
 pytorch_lightning/models/trainer.py | 30 ++++++++++++++++++++++++++++-
 1 file changed, 29 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index 2dbf46eb14b12..b0ce3d1a199b6 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -105,7 +105,7 @@ def __init__(self,
         :param log_save_interval:
         :param add_log_row_interval:
         :param distributed_backend:
-            'np' to use DistributedParallel, 'dp' to use DistributedDataParallel
+            'do' to use DistributedParallel, 'dp' to use DistributedDataParallel, 'n' to use none
         :param use_amp:
         :param print_nan_grads:
         :param print_weights_summary:
@@ -147,6 +147,7 @@ def __init__(self,
         self.node_rank = 0
         self.use_ddp = False
         self.use_dp = False
+        self.single_gpu = False
 
         # training bookeeping
         self.total_batch_nb = 0
@@ -194,6 +195,12 @@ def __init__(self,
                     'To silence this warning set distributed_backend=ddp'
                 warnings.warn(w)
 
+            # remove dp and ddp when requesting single gpu
+            if len(self.data_parallel_device_ids) == 1:
+                self.use_ddp = False
+                self.use_dp = False
+                self.single_gpu = True
+
         # extract SLURM flag vars
         # whenever we have the correct number of tasks, we let slurm manage processes
         # otherwise we launch the required number of processes
@@ -463,6 +470,9 @@ def fit(self, model):
         elif self.use_dp:
             self.__dp_train(model)
 
+        elif self.single_gpu:
+            self.__single_gpu_train(model)\
+
         # ON CPU
         else:
             # run through amp wrapper
@@ -482,6 +492,24 @@ def fit(self, model):
         # used for testing or when we need to know that training succeeded
         return 1
 
+    def __single_gpu_train(self, model):
+        # CHOOSE OPTIMIZER
+        # allow for lr schedulers as well
+        self.optimizers = model.configure_optimizers()
+        if len(self.optimizers) == 2:
+            self.optimizers, self.lr_schedulers = self.optimizers
+
+        model.cuda(self.data_parallel_device_ids[0])
+
+        if self.use_amp:
+            # An example
+            model, optimizers = amp.initialize(
+                model, self.optimizers, opt_level=self.amp_level,
+            )
+            self.optimizers = optimizers
+
+        self.__run_pretrain_routine(model)
+
     def __dp_train(self, model):
 
         # CHOOSE OPTIMIZER

From c7e843608302f77ce0edd827f2420336cb0f8fcf Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Wed, 7 Aug 2019 13:40:51 -0400
Subject: [PATCH 2/8] added single gpu train test

---
 tests/test_models.py | 28 ++++++++++++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/tests/test_models.py b/tests/test_models.py
index ce1697a91e185..eeab97f00d59a 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -27,6 +27,34 @@
 # TESTS
 # ------------------------------------------------------------------------
 
+def test_amp_single_gpu():
+    """
+    Make sure AMP works on a single GPU
+    :return:
+    """
+    if not torch.cuda.is_available():
+        warnings.warn('test_amp_single_gpu cannot run.'
+                      'Rerun on a GPU node to run this test')
+        return
+    if not torch.cuda.device_count() > 1:
+        warnings.warn('test_amp_single_gpu cannot run.'
+                      'Rerun on a node with 2+ GPUs to run this test')
+        return
+
+    hparams = get_hparams()
+    model = LightningTestModel(hparams)
+
+    trainer_options = dict(
+        progress_bar=True,
+        max_nb_epochs=1,
+        gpus=[0],
+        distributed_backend='dp',
+        use_amp=True
+    )
+
+    run_gpu_model_test(trainer_options, model, hparams)
+
+
 def test_cpu_restore_training():
     """
     Verify continue training session on CPU
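Taken together, the first two patches let a user request single-GPU (optionally mixed-precision) training purely through Trainer arguments. A minimal sketch of that call path, reusing the get_hparams and LightningTestModel helpers the test file above already imports, and assuming Trainer is importable from the pytorch_lightning.models.trainer module shown in the diff paths:

    from pytorch_lightning.models.trainer import Trainer

    hparams = get_hparams()                 # same helpers the test above uses
    model = LightningTestModel(hparams)

    # a single-element `gpus` list now routes fit() through __single_gpu_train
    # instead of the dp/ddp branches; use_amp additionally wraps the model and
    # its optimizers with apex amp inside that method
    trainer = Trainer(max_nb_epochs=1, gpus=[0], use_amp=True)
    result = trainer.fit(model)             # fit() returns 1 when training succeeded
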
From e30514922b5236f9e61985b19cae239b27eb34f1 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Wed, 7 Aug 2019 13:43:28 -0400
Subject: [PATCH 3/8] added single gpu train test

---
 pytorch_lightning/models/trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index b0ce3d1a199b6..69c7a3e878b25 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -196,7 +196,7 @@ def __init__(self,
                 warnings.warn(w)
 
             # remove dp and ddp when requesting single gpu
-            if len(self.data_parallel_device_ids) == 1:
+            if self.data_parallel_device_ids is not None and len(self.data_parallel_device_ids) == 1:
                 self.use_ddp = False
                 self.use_dp = False
                 self.single_gpu = True

From 9ecb1f2aee849fa16b958fcf06a827936c6b73b2 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Wed, 7 Aug 2019 13:46:06 -0400
Subject: [PATCH 4/8] added single gpu train test

---
 pytorch_lightning/models/trainer.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index 69c7a3e878b25..55f06d53d94d6 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -392,6 +392,9 @@ def validate(self, model, dataloader, max_batches):
                 output = model(data_batch, batch_i)
                 output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
 
+            elif self.single_gpu:
+                output = model(data_batch.cuda(self.data_parallel_device_ids[0]), batch_i)
+
             else:
                 output = model.validation_step(data_batch, batch_i)
 
@@ -842,6 +845,8 @@ def __run_tng_batch(self, data_batch, batch_nb):
         elif self.use_dp:
             output = self.model(data_batch, batch_nb)
             output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
+        elif self.single_gpu:
+            output = self.model(data_batch.cuda(self.data_parallel_device_ids[0]), batch_nb)
 
         else:
             output = self.model.training_step(data_batch, batch_nb)
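A note on the two lines patch 4 adds: data_batch.cuda(...) only works when the batch object is itself a tensor. Lightning dataloaders usually yield a list or tuple such as (inputs, targets), which has no .cuda() method, and patches 5 to 7 rework the transfer accordingly. A standalone sketch of the per-tensor move the series converges on (the helper name is ours, not library API):

    import torch

    def move_batch_to_gpu(data_batch, gpu_id):
        # move only the tensor elements in place; ints, strings, etc. are left alone
        for i, x in enumerate(data_batch):
            if isinstance(x, torch.Tensor):
                data_batch[i] = x.cuda(gpu_id)
        return data_batch

    # usage (requires a CUDA device):
    batch = [torch.randn(4, 3), torch.tensor([0, 1, 1, 0])]
    batch = move_batch_to_gpu(batch, gpu_id=0)
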
From 2f7a9ad40d2199ea7371bfc53e9a785c50952d14 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Wed, 7 Aug 2019 13:49:01 -0400
Subject: [PATCH 5/8] added single gpu train test

---
 pytorch_lightning/models/trainer.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index 55f06d53d94d6..91eb7cc89f022 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -393,7 +393,9 @@ def validate(self, model, dataloader, max_batches):
                 output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
 
             elif self.single_gpu:
-                output = model(data_batch.cuda(self.data_parallel_device_ids[0]), batch_i)
+                gpu_id = self.data_parallel_device_ids[0]
+                data_batch = [x.cuda(gpu_id) for x in data_batch if isinstance(x, torch.Tensor)]
+                output = model(data_batch, batch_i)
 
             else:
                 output = model.validation_step(data_batch, batch_i)
@@ -474,7 +476,7 @@ def fit(self, model):
             self.__dp_train(model)
 
         elif self.single_gpu:
-            self.__single_gpu_train(model)\
+            self.__single_gpu_train(model)
 
         # ON CPU
         else:
@@ -846,7 +848,10 @@ def __run_tng_batch(self, data_batch, batch_nb):
             output = self.model(data_batch, batch_nb)
             output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
         elif self.single_gpu:
+            gpu_id = self.data_parallel_device_ids[0]
+            data_batch = [x.cuda(gpu_id) for x in data_batch if isinstance(x, torch.Tensor)]
             output = self.model(data_batch.cuda(self.data_parallel_device_ids[0]), batch_nb)
+
         else:
             output = self.model.training_step(data_batch, batch_nb)

From 56f16694c4e85894e5e0d8e87c87dbf1112a86ab Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Wed, 7 Aug 2019 13:50:12 -0400
Subject: [PATCH 6/8] added single gpu train test

---
 pytorch_lightning/models/trainer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index 91eb7cc89f022..37fd9d8a47612 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -395,7 +395,7 @@ def validate(self, model, dataloader, max_batches):
             elif self.single_gpu:
                 gpu_id = self.data_parallel_device_ids[0]
                 data_batch = [x.cuda(gpu_id) for x in data_batch if isinstance(x, torch.Tensor)]
-                output = model(data_batch, batch_i)
+                output = model.validation_step(data_batch, batch_i)
 
             else:
                 output = model.validation_step(data_batch, batch_i)
@@ -850,7 +850,7 @@ def __run_tng_batch(self, data_batch, batch_nb):
         elif self.single_gpu:
             gpu_id = self.data_parallel_device_ids[0]
             data_batch = [x.cuda(gpu_id) for x in data_batch if isinstance(x, torch.Tensor)]
-            output = self.model(data_batch.cuda(self.data_parallel_device_ids[0]), batch_nb)
+            output = self.model.training_step(data_batch.cuda(self.data_parallel_device_ids[0]), batch_nb)
 
         else:
             output = self.model.training_step(data_batch, batch_nb)

From ab499573a1c96342bb79e693d4abeb65654e4309 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Wed, 7 Aug 2019 13:53:44 -0400
Subject: [PATCH 7/8] added single gpu train test

---
 pytorch_lightning/models/trainer.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index 37fd9d8a47612..55f013b6e2f88 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -394,7 +394,9 @@ def validate(self, model, dataloader, max_batches):
 
             elif self.single_gpu:
                 gpu_id = self.data_parallel_device_ids[0]
-                data_batch = [x.cuda(gpu_id) for x in data_batch if isinstance(x, torch.Tensor)]
+                for i, x in enumerate(data_batch):
+                    if isinstance(x, torch.Tensor):
+                        data_batch[i] = x.cuda(gpu_id)
                 output = model.validation_step(data_batch, batch_i)
 
             else:
@@ -849,7 +851,9 @@ def __run_tng_batch(self, data_batch, batch_nb):
             output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
         elif self.single_gpu:
             gpu_id = self.data_parallel_device_ids[0]
-            data_batch = [x.cuda(gpu_id) for x in data_batch if isinstance(x, torch.Tensor)]
+            for i, x in enumerate(data_batch):
+                if isinstance(x, torch.Tensor):
+                    data_batch[i] = x.cuda(gpu_id)
             output = self.model.training_step(data_batch.cuda(self.data_parallel_device_ids[0]), batch_nb)
 
         else:
             output = self.model.training_step(data_batch, batch_nb)
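The change in patch 7 from a list comprehension to an index-based loop is not cosmetic: the comprehension silently drops every non-tensor element of the batch, so training_step and validation_step could receive a shorter batch than the dataloader produced. A small CPU-only illustration of the difference (the .cuda(gpu_id) call is omitted so the snippet runs anywhere):

    import torch

    batch = [torch.randn(2, 3), 'sample-id-7', 42]

    # comprehension: non-tensor items disappear, len(filtered) == 1
    filtered = [x for x in batch if isinstance(x, torch.Tensor)]

    # index-based loop: tensors are replaced in place, everything else survives
    for i, x in enumerate(batch):
        if isinstance(x, torch.Tensor):
            batch[i] = x          # the trainer calls x.cuda(gpu_id) here

    assert len(filtered) == 1 and len(batch) == 3
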
From f3dea818f27c764fb2fb6c1f27aeed2ea99b4a79 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Wed, 7 Aug 2019 13:54:32 -0400
Subject: [PATCH 8/8] added single gpu train test

---
 pytorch_lightning/models/trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index 55f013b6e2f88..4076741008ab5 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -854,7 +854,7 @@ def __run_tng_batch(self, data_batch, batch_nb):
             for i, x in enumerate(data_batch):
                 if isinstance(x, torch.Tensor):
                     data_batch[i] = x.cuda(gpu_id)
-            output = self.model.training_step(data_batch.cuda(self.data_parallel_device_ids[0]), batch_nb)
+            output = self.model.training_step(data_batch, batch_nb)
 
         else:
             output = self.model.training_step(data_batch, batch_nb)
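For readers skimming the series, the single-GPU setup that patches 1 to 8 leave behind boils down to the condensed, standalone sketch below. The function name and argument defaults are ours, and in the actual Trainer this logic is spread across __single_gpu_train, validate and __run_tng_batch; amp is the apex mixed-precision package the diff already relies on.

    import torch

    def single_gpu_setup(model, gpu_id=0, use_amp=False, amp_level='O1'):
        # optimizers (and optional lr schedulers) come from the LightningModule,
        # mirroring the Trainer's len() == 2 check for an (optimizers, schedulers) pair
        optimizers = model.configure_optimizers()
        lr_schedulers = []
        if len(optimizers) == 2:
            optimizers, lr_schedulers = optimizers

        # the whole model moves to the one requested device
        model.cuda(gpu_id)

        # optional mixed precision, the same amp.initialize call patch 1 uses
        if use_amp:
            from apex import amp
            model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)

        return model, optimizers, lr_schedulers

After this setup, each batch is moved tensor by tensor (as in the loop patches 7 and 8 settle on) right before training_step or validation_step is called.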