Merged
31 commits
f04ee90 - added load on CPU first (williamFalcon, Sep 10, 2019)
3765986 - added load on CPU first (williamFalcon, Sep 10, 2019)
76a39f6 - added load on CPU first (williamFalcon, Sep 10, 2019)
1dbc700 - added load on CPU first (williamFalcon, Sep 10, 2019)
ab3c97e - added load on CPU first (williamFalcon, Sep 10, 2019)
f3bd4ef - added load on CPU first (williamFalcon, Sep 10, 2019)
deaf833 - added load on CPU first (williamFalcon, Sep 10, 2019)
6e28f49 - added load on CPU first (williamFalcon, Sep 10, 2019)
93a601d - added load on CPU first (williamFalcon, Sep 10, 2019)
d6b9ebe - added load on CPU first (williamFalcon, Sep 10, 2019)
4ba419a - added load on CPU first (williamFalcon, Sep 10, 2019)
587f4d7 - added load on CPU first (williamFalcon, Sep 10, 2019)
90971e9 - added load on CPU first (williamFalcon, Sep 10, 2019)
04b2d3b - added load on CPU first (williamFalcon, Sep 10, 2019)
5f98e58 - added load on CPU first (williamFalcon, Sep 10, 2019)
f341014 - added load on CPU first (williamFalcon, Sep 10, 2019)
b0ea811 - added load on CPU first (williamFalcon, Sep 10, 2019)
addef00 - added load on CPU first (williamFalcon, Sep 10, 2019)
390295a - added load on CPU first (williamFalcon, Sep 10, 2019)
0f8f266 - added load on CPU first (williamFalcon, Sep 10, 2019)
3d5c6d7 - added load on CPU first (williamFalcon, Sep 10, 2019)
a5d842a - added load on CPU first (williamFalcon, Sep 10, 2019)
afa42d2 - added load on CPU first (williamFalcon, Sep 10, 2019)
3b2a39b - added load on CPU first (williamFalcon, Sep 10, 2019)
059247d - added load on CPU first (williamFalcon, Sep 10, 2019)
b73940b - added load on CPU first (williamFalcon, Sep 10, 2019)
5c41ed4 - added load on CPU first (williamFalcon, Sep 10, 2019)
2d31ac1 - added print logs (williamFalcon, Sep 10, 2019)
653a88b - added print logs (williamFalcon, Sep 10, 2019)
7bbea9c - changed close order (williamFalcon, Sep 10, 2019)
08da62e - changed close order (williamFalcon, Sep 10, 2019)
32 changes: 23 additions & 9 deletions pytorch_lightning/trainer/trainer.py
@@ -170,6 +170,7 @@ def __init__(self,

# allow int, string and gpu list
self.data_parallel_device_ids = self.__parse_gpu_ids(gpus)
self.root_gpu = self.__set_root_gpu(self.data_parallel_device_ids)

# distributed backend choice
self.use_ddp = False
@@ -270,6 +271,17 @@ def __parse_gpu_ids(self, gpus):

return gpus

def __set_root_gpu(self, gpus):
if gpus is None:
return None

# set root gpu
root_gpu = 0
if type(gpus) is list:
root_gpu = gpus[0]

return root_gpu

@property
def num_gpus(self):
gpus = self.data_parallel_device_ids
@@ -701,10 +713,7 @@ def __single_gpu_train(self, model):
# allow for lr schedulers as well
self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())

root_gpu = 0
if type(self.data_parallel_device_ids) is list:
root_gpu = self.data_parallel_device_ids[0]
model.cuda(root_gpu)
model.cuda(self.root_gpu)

if self.use_amp:
# An example
@@ -721,10 +730,7 @@ def __dp_train(self, model):
# allow for lr schedulers as well
self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())

root_gpu = 0
if type(self.data_parallel_device_ids) is list:
root_gpu = self.data_parallel_device_ids[0]
model.cuda(root_gpu)
model.cuda(self.root_gpu)

# check for this bug (amp + dp + !01 doesn't work)
# https://github.com/NVIDIA/apex/issues/227
@@ -736,7 +742,12 @@
"""
raise MisconfigurationException(m)

model = LightningDataParallel(model, device_ids=self.data_parallel_device_ids)
# create list of device ids
device_ids = self.data_parallel_device_ids
if type(device_ids) is int:
device_ids = list(range(device_ids))

model = LightningDataParallel(model, device_ids=device_ids)

self.__run_pretrain_routine(model)

@@ -787,6 +798,9 @@ def ddp_train(self, gpu_nb, model):
torch.cuda.set_device(gpu_nb)
model.cuda(gpu_nb)

# override root GPU
self.root_gpu = gpu_nb

# AMP
# run through amp wrapper before going to distributed DP
if self.use_amp:
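Taken together, the trainer.py changes compute a single root GPU once at construction time (via __set_root_gpu) and reuse it wherever a device id used to be recomputed, and they expand an integer gpus argument into an explicit device-id list before handing it to LightningDataParallel. A minimal standalone sketch of that selection logic, assuming the same gpus inputs the Trainer accepts; the helper names below are illustrative, not Trainer methods:

# Sketch of the root-GPU / device-id handling introduced above; pick_root_gpu
# and normalize_device_ids are illustrative names, not part of the Trainer API.

def pick_root_gpu(gpus):
    """Return the GPU id everything is anchored to, or None for CPU runs."""
    if gpus is None:
        return None
    if isinstance(gpus, list):
        return gpus[0]  # explicit list: the first entry is the root
    return 0            # an int count always starts at device 0

def normalize_device_ids(gpus):
    """Expand an int GPU count (e.g. gpus=2) into the list DataParallel expects."""
    if isinstance(gpus, int):
        return list(range(gpus))
    return gpus

# gpus=2      -> root 0, device_ids [0, 1]
# gpus=[3, 4] -> root 3, device_ids [3, 4]
assert pick_root_gpu(2) == 0 and normalize_device_ids(2) == [0, 1]
assert pick_root_gpu([3, 4]) == 3 and normalize_device_ids([3, 4]) == [3, 4]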
51 changes: 33 additions & 18 deletions pytorch_lightning/trainer/trainer_io.py
@@ -1,6 +1,7 @@
import os
import re
import signal
import pdb
from subprocess import call

import torch
@@ -78,7 +79,7 @@ def register_slurm_signal_handlers(self):
except Exception as e:
pass

if on_slurm and self.proc_rank == 0:
if on_slurm:
print('set slurm handle signals')
signal.signal(signal.SIGUSR1, self.sig_handler)
signal.signal(signal.SIGTERM, self.term_handler)
@@ -103,6 +104,9 @@ def sig_handler(self, signum, frame):
else:
print('requeue failed...')

# close experiment to avoid issues
self.experiment.close()

def term_handler(self, signum, frame):
# save
print("bypassing sigterm")
@@ -118,19 +122,22 @@ def save_checkpoint(self, filepath):

def restore(self, checkpoint_path, on_gpu):

if on_gpu:
checkpoint = torch.load(checkpoint_path)
else:
checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)

# load training state (affects trainer only)
self.restore_training_state(checkpoint)
# if on_gpu:
# checkpoint = torch.load(checkpoint_path)
# else:
# load on CPU first
checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)

# load model state
model = self.__get_model()

# load the state_dict on the model automatically
model.load_state_dict(checkpoint['state_dict'])
if on_gpu:
model.cuda(self.root_gpu)

# load training state (affects trainer only)
self.restore_training_state(checkpoint)

def dump_checkpoint(self):

@@ -210,6 +217,14 @@ def restore_training_state(self, checkpoint):
for optimizer, opt_state in zip(self.optimizers, optimizer_states):
optimizer.load_state_dict(opt_state)

# move optimizer to GPU 1 weight at a time
# avoids OOM
if self.root_gpu is not None:
for state in optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.cuda(self.root_gpu)

# restore the lr schedulers
lr_schedulers = checkpoint['lr_schedulers']
for scheduler, lrs_state in zip(self.lr_schedulers, lr_schedulers):
@@ -225,9 +240,6 @@ def hpc_save(self, folderpath, experiment):
# save exp to make sure we get all the metrics
experiment.save()

# close experiment to avoid issues
experiment.close()

ckpt_number = self.max_ckpt_in_folder(folderpath) + 1

if not os.path.exists(folderpath):
@@ -248,23 +260,26 @@ def hpc_save(self, folderpath, experiment):
def hpc_load(self, folderpath, on_gpu):
filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, self.max_ckpt_in_folder(folderpath))

if on_gpu:
checkpoint = torch.load(filepath)
else:
checkpoint = torch.load(filepath, map_location=lambda storage, loc: storage)

# load training state (affects trainer only)
self.restore_training_state(checkpoint)
# load on CPU first
checkpoint = torch.load(filepath, map_location=lambda storage, loc: storage)

# load model state
model = self.__get_model()

# load the state_dict on the model automatically
model.load_state_dict(checkpoint['state_dict'])

if self.root_gpu is not None:
model.cuda(self.root_gpu)

# load training state (affects trainer only)
self.restore_training_state(checkpoint)

# call model hook
model.on_hpc_load(checkpoint)

print(f'restored hpc model from: {filepath}')

def max_ckpt_in_folder(self, path, name_key='ckpt_'):
files = os.listdir(path)
files = [x for x in files if name_key in x]
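Every change in trainer_io.py serves one pattern: deserialize the checkpoint onto the CPU via map_location, restore model and optimizer state there, and only then move tensors to the root GPU, so restoring a large checkpoint never allocates GPU memory up front. A self-contained sketch of that pattern in plain PyTorch, assuming the checkpoint keys used in the diff above; the model, optimizer, and path are stand-ins:

import torch

def restore_on_cpu_first(checkpoint_path, model, optimizer, root_gpu=None):
    # 1. always load onto the CPU, regardless of where training will resume
    checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)

    # 2. restore the weights while they still live on the CPU, then move the model
    model.load_state_dict(checkpoint['state_dict'])
    if root_gpu is not None:
        model.cuda(root_gpu)

    # 3. restore the optimizer, then move its state to the GPU one tensor at a time
    #    (moving the whole state in one shot is the OOM this PR works around)
    optimizer.load_state_dict(checkpoint['optimizer_states'][0])
    if root_gpu is not None:
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda(root_gpu)

    return checkpoint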
72 changes: 55 additions & 17 deletions tests/debug.py
@@ -214,10 +214,20 @@ def get_hparams(continue_training=False, hpc_exp_number=0):


def main():
"""Verify test() on fitted model"""
"""
Make sure DDP + AMP continue training correctly
:return:
"""
hparams = get_hparams()
model = LightningTestModel(hparams)

trainer_options = dict(
show_progress_bar=True,
max_nb_epochs=4,
gpus=2,
distributed_backend='dp',
)

save_dir = init_save_dir()

# exp file to get meta
@@ -228,31 +238,59 @@ def main():
# exp file to get weights
checkpoint = ModelCheckpoint(save_dir)

trainer_options = dict(
show_progress_bar=False,
max_nb_epochs=1,
train_percent_check=0.4,
val_percent_check=0.2,
checkpoint_callback=checkpoint,
experiment=exp,
gpus=[0, 1],
distributed_backend='ddp'
)
# add these to the trainer options
trainer_options['experiment'] = exp
trainer_options['checkpoint_callback'] = checkpoint

# fit model
trainer = Trainer(**trainer_options)
trainer.is_slurm_managing_tasks = True
result = trainer.fit(model)

# track epoch before saving
real_global_epoch = trainer.current_epoch

# correct result and ok accuracy
assert result == 1, 'training failed to complete'
pretrained_model = load_model(exp, save_dir, on_gpu=True, module_class=LightningTestModel)
assert result == 1, 'amp + dp model failed to complete'

# ---------------------------
# HPC LOAD/SAVE
# ---------------------------
# save
trainer.hpc_save(save_dir, exp)

# init new trainer
new_exp = get_exp(False, version=exp.version)
trainer_options['experiment'] = new_exp
trainer_options['checkpoint_callback'] = ModelCheckpoint(save_dir)
trainer_options['train_percent_check'] = 0.2
trainer_options['val_percent_check'] = 0.2
trainer_options['max_nb_epochs'] = 1
new_trainer = Trainer(**trainer_options)
new_trainer.test(pretrained_model)

# test we have good test accuracy
assert_ok_test_acc(new_trainer)
# clear_save_dir()
# set the epoch start hook so we can predict before the model does the full training
def assert_good_acc():
assert trainer.current_epoch == real_global_epoch and trainer.current_epoch > 0

# if model and state loaded correctly, predictions will be good even though we
# haven't trained with the new loaded model
dp_model = new_trainer.model
dp_model.eval()

_ = [run_prediction(dataloader, dp_model, dp=True) for dataloader in trainer.val_dataloader]

# new model
model = LightningTestModel(hparams)
model.on_sanity_check_start = assert_good_acc

# fit new model which should load hpc weights
new_trainer.fit(model)

# test freeze on gpu
model.freeze()
model.unfreeze()

clear_save_dir()


if __name__ == '__main__':
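The rewritten test in tests/debug.py leans on one trick worth isolating: the accuracy assertion is attached to a hook that runs before training resumes (on_sanity_check_start in the diff), so it can only pass if the HPC checkpoint restored both the weights and the epoch counter. A hedged sketch of that pattern; make_restore_check is an illustrative name, while run_prediction and the trainer objects come from the test above:

# Sketch of the test's restore-verification trick; make_restore_check is a
# stand-in helper, wrapping the assertions the test performs in its hook.

def make_restore_check(trainer, expected_epoch, val_dataloaders, run_prediction):
    def assert_good_acc():
        # the epoch counter must come back from the checkpoint, not restart at 0
        assert trainer.current_epoch == expected_epoch and trainer.current_epoch > 0

        # predictions must already be good even though no new training has run,
        # which is only possible if the weights were restored correctly
        model = trainer.model
        model.eval()
        for dataloader in val_dataloaders:
            run_prediction(dataloader, model, dp=True)
    return assert_good_acc

# roughly how the test above wires it up:
#   model.on_sanity_check_start = make_restore_check(new_trainer, real_global_epoch,
#                                                    trainer.val_dataloader, run_prediction)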