44 changes: 43 additions & 1 deletion pytorch_lightning/models/trainer.py
@@ -105,7 +105,7 @@ def __init__(self,
:param log_save_interval:
:param add_log_row_interval:
:param distributed_backend:
'np' to use DistributedParallel, 'dp' to use DistributedDataParallel
'dp' to use DataParallel, 'ddp' to use DistributedDataParallel, None to use neither
:param use_amp:
:param print_nan_grads:
:param print_weights_summary:
@@ -147,6 +147,7 @@ def __init__(self,
self.node_rank = 0
self.use_ddp = False
self.use_dp = False
self.single_gpu = False

# training bookkeeping
self.total_batch_nb = 0
@@ -194,6 +195,12 @@ def __init__(self,
'To silence this warning set distributed_backend=ddp'
warnings.warn(w)

# remove dp and ddp when requesting single gpu
if self.data_parallel_device_ids is not None and len(self.data_parallel_device_ids) == 1:
self.use_ddp = False
self.use_dp = False
self.single_gpu = True

# extract SLURM flag vars
# whenever we have the correct number of tasks, we let slurm manage processes
# otherwise we launch the required number of processes
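
The core of this change is the block above: when data_parallel_device_ids contains exactly one id, both DP and DDP are switched off and the new single_gpu flag is set. A minimal standalone sketch of that behavior, outside the Trainer (the helper name resolve_flags is illustrative and not part of this PR):

def resolve_flags(device_ids, use_dp=False, use_ddp=False):
    # Mirrors the override above: exactly one requested GPU disables DP/DDP.
    single_gpu = False
    if device_ids is not None and len(device_ids) == 1:
        use_dp, use_ddp, single_gpu = False, False, True
    return use_dp, use_ddp, single_gpu

assert resolve_flags([0], use_dp=True) == (False, False, True)
assert resolve_flags([0, 1], use_dp=True) == (True, False, False)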
@@ -385,6 +392,13 @@ def validate(self, model, dataloader, max_batches):
output = model(data_batch, batch_i)
output = reduce_distributed_output(output, len(self.data_parallel_device_ids))

elif self.single_gpu:
gpu_id = self.data_parallel_device_ids[0]
for i, x in enumerate(data_batch):
if isinstance(x, torch.Tensor):
data_batch[i] = x.cuda(gpu_id)
output = model.validation_step(data_batch, batch_i)

else:
output = model.validation_step(data_batch, batch_i)

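The same move-to-device loop is repeated in __run_tng_batch further down. A hedged standalone sketch of that pattern, assuming the batch is a mutable sequence (a list), as the in-place assignment above requires; move_batch_to_gpu is an illustrative name, not part of this PR:

import torch

def move_batch_to_gpu(data_batch, gpu_id):
    # Copy every tensor in the batch to the requested GPU, in place,
    # leaving non-tensor items untouched.
    for i, x in enumerate(data_batch):
        if isinstance(x, torch.Tensor):
            data_batch[i] = x.cuda(gpu_id)
    return data_batch
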
@@ -463,6 +477,9 @@ def fit(self, model):
elif self.use_dp:
self.__dp_train(model)

elif self.single_gpu:
self.__single_gpu_train(model)

# ON CPU
else:
# run through amp wrapper
@@ -482,6 +499,24 @@ def fit(self, model):
# used for testing or when we need to know that training succeeded
return 1

def __single_gpu_train(self, model):
# CHOOSE OPTIMIZER
# allow for lr schedulers as well
self.optimizers = model.configure_optimizers()
if len(self.optimizers) == 2:
self.optimizers, self.lr_schedulers = self.optimizers

model.cuda(self.data_parallel_device_ids[0])

if self.use_amp:
# wrap model and optimizers with NVIDIA apex for mixed-precision training
model, optimizers = amp.initialize(
model, self.optimizers, opt_level=self.amp_level,
)
self.optimizers = optimizers

self.__run_pretrain_routine(model)

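A hedged usage sketch of the new path, assuming the Trainer accepts the same keyword arguments the tests below pass (gpus, use_amp) and that NVIDIA apex is installed when use_amp=True:

from pytorch_lightning.models.trainer import Trainer

# `MyLightningModule` stands in for any LightningModule subclass
# (the tests below use LightningTestModel).
model = MyLightningModule()

# A one-element gpus list now routes fit() through __single_gpu_train()
# instead of the DP/DDP branches.
trainer = Trainer(gpus=[0], use_amp=True)
trainer.fit(model)
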
def __dp_train(self, model):

# CHOOSE OPTIMIZER
@@ -814,6 +849,13 @@ def __run_tng_batch(self, data_batch, batch_nb):
elif self.use_dp:
output = self.model(data_batch, batch_nb)
output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
elif self.single_gpu:
gpu_id = self.data_parallel_device_ids[0]
for i, x in enumerate(data_batch):
if isinstance(x, torch.Tensor):
data_batch[i] = x.cuda(gpu_id)
output = self.model.training_step(data_batch, batch_nb)

else:
output = self.model.training_step(data_batch, batch_nb)

28 changes: 28 additions & 0 deletions tests/test_models.py
@@ -27,6 +27,34 @@
# TESTS
# ------------------------------------------------------------------------

def test_amp_single_gpu():
"""
Make sure AMP training works on a single GPU
:return:
"""
if not torch.cuda.is_available():
warnings.warn('test_amp_single_gpu cannot run.'
' Rerun on a GPU node to run this test')
return
if not torch.cuda.device_count() > 1:
warnings.warn('test_amp_single_gpu cannot run.'
' Rerun on a node with 2+ GPUs to run this test')
return

hparams = get_hparams()
model = LightningTestModel(hparams)

trainer_options = dict(
progress_bar=True,
max_nb_epochs=1,
gpus=[0],
distributed_backend='dp',
use_amp=True
)

run_gpu_model_test(trainer_options, model, hparams)


def test_cpu_restore_training():
"""
Verify continue training session on CPU