
Improving Hydra+DDP support #11617

Merged: 40 commits merged into Lightning-AI:master on Sep 22, 2022

Conversation

@jgbos (Contributor) commented on Jan 25, 2022

What does this PR do?

There are documented issues for Hydra+DDP in #11300 and #2727. This PR attempts to fix these issues by redefining how a process is spawned when Hydra is used.

Fixes #11300

Problem

Current approach is defined here: https://github.com/PyTorchLightning/pytorch-lightning/blob/fe34bf2a653ebd50e6a3a00be829e3611f820c3c/pytorch_lightning/strategies/ddp.py#L233

This PR addresses the issue of running with Hydra multirun. For example, let's say we have the following Hydra app:

import hydra
from pytorch_lightning import Trainer

# BoringModel is Lightning's minimal example model; any LightningModule works here
@hydra.main(config_path=None, config_name="myconfig", version_base="1.1")
def task_fn(cfg):
    trainer = Trainer(gpus=2, strategy="ddp")
    model = BoringModel()
    trainer.fit(model)

if __name__ == "__main__":
    task_fn()

We can execute a multirun job with this app using the following command:

python script.py foo=1,2 --multirun

This command will attempt to launch 2 jobs sequentially: one with foo=1 and one with foo=2. For the first job, foo=1, PL launches a normal job that begins execution, while the second (local rank 1) process is spawned via subprocess using the following command derived from sys.argv:

python script.py foo=1,2 --multirun hydra.run.dir=<os_cwd> hydra.job.name=train_ddp_process_<local_rank>

This will spawn a new multirun job instead of running a normal job with foo=2. The command should instead be:

python script.py foo=2 hydra.run.dir=<os_cwd>
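
For context, the sketch below paraphrases how a sys.argv-based spawn ends up producing the first, incorrect command. This is not the verbatim ddp.py code linked above, and local_rank is just an example value for illustration:

import os
import sys

# Paraphrased sketch of the sys.argv-based spawn described above (an assumption,
# not the exact PL code). Replaying sys.argv wholesale also replays
# "foo=1,2 --multirun", which re-triggers the whole sweep in the worker process.
local_rank = 1  # example: rank of the worker process being spawned

command = [sys.executable] + sys.argv
command += [
    f'hydra.run.dir="{os.getcwd()}"',
    f"hydra.job.name=train_ddp_process_{local_rank}",
]
print(" ".join(command))  # e.g. ... script.py foo=1,2 --multirun hydra.run.dir=... hydra.job.name=train_ddp_process_1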

Solution

Every Hydra process saves the reproducible configuration of the job in a config.yaml file located in the hydra.output_subdir experiment directory. Using Hydra's CLI, we can execute the app with the same configuration as the current experiment by passing --config-path/-cp and --config-name/-cn:

python script.py -cp <path to hydra output dir> -cn config.yaml hydra.run.dir=<os_cwd> 

Here config.yaml contains the value of foo appropriate for the current multirun job. I've outlined multirun support here, but this approach should work for any Hydra application launched from the command line.
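
To make this concrete, here is a rough sketch of how such a re-launch command could be assembled from inside a Hydra app. This is not the exact code added by this PR; the helper name ddp_relaunch_command and its local_rank argument are hypothetical:

import os
import sys

from hydra.core.hydra_config import HydraConfig


def ddp_relaunch_command(local_rank: int) -> list:
    # HydraConfig.get() is only populated inside a @hydra.main task function
    hydra_cfg = HydraConfig.get()
    # Hydra writes the composed config of *this* job to <run dir>/<output_subdir>/config.yaml
    config_dir = os.path.join(os.getcwd(), hydra_cfg.output_subdir)
    return [
        sys.executable,
        sys.argv[0],
        "-cp", config_dir,
        "-cn", "config.yaml",
        f'hydra.run.dir="{os.getcwd()}"',
        f"hydra.job.name=train_ddp_process_{local_rank}",
    ]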

Lingering Issue Not Solved

In order to run a multirun on a local machine, the user must add the additional code shown below to their task function before launching the next job. This code destroys all distributed processes and removes PL-related environment variables; without it, the multirun job will hang after the first job.

import os

import hydra
import torch.distributed as dist
from pytorch_lightning import Trainer

@hydra.main(config_path=None, config_name="myconfig", version_base="1.1")
def task_fn(cfg):
    trainer = Trainer(gpus=2, strategy="ddp")
    model = BoringModel()
    trainer.fit(model)
    trainer.test(model)

    # Need to do this in addition to Lightning shutting down the
    # distributed processes in order to run a multirun loop with hydra
    if dist.is_initialized():
        dist.destroy_process_group()

    os.environ.pop("LOCAL_RANK", None)

Thoughts

This solution should help most people, but IMHO the current method of spawning via sys.argv is not robust. It would be nice to be able to execute a PL application similarly to how I implemented the Hydra solution here, which would require dynamically creating configurations for the application. The use of save_hyperparameters is nice, but it only happens inside the creation of a LightningModule; it would be nice to create a description of the model at a higher level, outside of the module. With the description defined outside the module, one could use a similar approach to the Hydra solution above for PL applications. I would recommend taking a look at hydra-zen (FYI, I'm a contributor).
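
Purely as an illustration of that idea (and not part of this PR), here is a minimal hydra-zen sketch of describing the trainer and model outside the LightningModule; the names TrainerConf, ModelConf, Config, and task_fn are hypothetical:

from hydra_zen import builds, instantiate, make_config

from pytorch_lightning import Trainer

# Dataclass configs that *describe* the objects, created outside any LightningModule
TrainerConf = builds(Trainer, gpus=2, strategy="ddp")
ModelConf = builds(BoringModel)  # BoringModel stands in for any LightningModule class

Config = make_config(trainer=TrainerConf, model=ModelConf)


def task_fn(cfg):
    # Rebuild the exact objects from the saved, reproducible config
    trainer = instantiate(cfg.trainer)
    model = instantiate(cfg.model)
    trainer.fit(model)

Because the whole application is then described by a single structured config, it could be saved and re-composed for a spawned process in the same way the config.yaml is reused in the solution above.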

Does your PR introduce any breaking changes? If yes, please list them.

None

Before submitting

Personal TODO:

  • Add a test checking that spawned processes use the correct parameters

PL TODO:

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or internal minor changes/refactorings)

@akihironitta akihironitta added the argparse (removed) Related to argument parsing (argparse, Hydra, ...) label Jan 30, 2022
@tchaton (Contributor) left a comment

Looks really neat! Mind moving the hydra-specific logic to its own utility file?

@jgbos (Contributor, Author) commented on Jan 31, 2022

@tchaton I'm working through some issues with teardown. It appears there are issues with the process hanging on torch.distributed.destroy_process_group(), but if I do it in my task function outside of trainer.fit I do not have any issues. Is there a different spot to put the code currently in teardown?

@carmocca carmocca added the feature Is an improvement or enhancement label Feb 8, 2022
@carmocca carmocca added this to the 1.6 milestone Feb 8, 2022
@awaelchli (Member) commented:

@rohitgr7 We could potentially have a Hydra-specific launcher after #11643 is finalized. In that sense, I like this PR's approach of creating a utility function that encapsulates the Hydra command.

@carmocca carmocca enabled auto-merge (squash) September 22, 2022 15:59
@carmocca carmocca merged commit 45ca781 into Lightning-AI:master Sep 22, 2022
@carmocca (Member) commented:

Finally merged! Thank you so much for your time @jgbos

rohitgr7 added a commit that referenced this pull request Sep 24, 2022
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
@carmocca carmocca modified the milestones: pl:future, pl:1.8 Sep 27, 2022
@Liangtaiwan commented on Oct 22, 2022

@jgbos @awaelchli @carmocca @rohitgr7
It seems this PR causes some issues.

  1. PYTHONPATH issue.
    When I use ddp with 2 devices (GPUs), the second process cannot access my own module.
    There is no issue with pl==1.7.7.
    PYTHONPATH=.:$PYTHONPATH  # my cmd
    Traceback (most recent call last):
      File "/home/$username/diskey-scale/diskey/main.py", line 11, in <module>
        from diskey.conf.config import MainConfig
    ModuleNotFoundError: No module named 'diskey'

  2. ddp deadlock
    After changing PYTHONPATH to
    PYTHONPATH=$HOME/diskey-scale:$PYTHONPATH
    the ModuleNotFoundError is solved.
    However, the following error is then triggered:
    hydra.errors.ConfigCompositionException: 'config.yaml' is validated against ConfigStore schema with the same name. This behavior is deprecated in Hydra 1.1 and will be removed in Hydra 1.2. In addition, the automatically matched schema contains a defaults list. This combination is no longer supported. See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/automatic_schema_matching for migration instructions.

These issues did not happen with hydra-submitit-launcher (with slurm).

@awaelchli (Member) commented:

@Liangtaiwan I'm not sure what is going on there. Could you open a GitHub issue with the exact error, the steps to reproduce, and the environment you are running in? Perhaps @jgbos will be able to help too?
Thanks!

@jgbos (Contributor, Author) commented on Oct 24, 2022

@Liangtaiwan sorry, I would need to know more about the error. Definitely open an issue and I will try to help. It would be good to know the Hydra version and how you executed the script. Also, this error should not occur when launching on SLURM, since each GPU task is launched via the submission script (i.e., the submitit launcher) rather than by this script.

Labels

  • argparse (removed): Related to argument parsing (argparse, Hydra, ...)
  • community: This PR is from the community
  • feature: Is an improvement or enhancement
  • pl: Generic label for PyTorch Lightning package
  • ready: PRs ready to be merged
Development

Successfully merging this pull request may close these issues.

DDP with Hydra multirun doesn't work when dirpath in checkpoint callback is specified