
from_argparse_args on child classes doesn't pass around the correct arguments #12974

Closed
agarwalmanvi opened this issue May 4, 2022 · 7 comments
Labels
argparse (removed) Related to argument parsing (argparse, Hydra, ...) bug Something isn't working
Milestone

Comments

@agarwalmanvi

agarwalmanvi commented May 4, 2022

🐛 Bug

For my research I made a hierarchy of classes where the parent BaseDataModule inherits from pl.LightningDataModule and all other child classes inherit from BaseDataModule. Initializing a child class using from_argparse_args takes the default values for attributes inherited from BaseDataModule and ignores values passed from the command line.

To Reproduce

In main.py:

import pytorch_lightning as pl
from argparse import ArgumentParser

class BaseDataModule(pl.LightningDataModule):
    def __init__(
            self,
            batch_size: int = 4,
    ):
        super(BaseDataModule, self).__init__()
        self.batch_size = batch_size
        # use batch_size later for dataloaders

class SpecialDataModule(BaseDataModule):

    def __init__(
            self,
            some_arg: int = 1,
            **kwargs
    ):
        super(SpecialDataModule, self).__init__(**kwargs)
        self.some_arg = some_arg

if __name__ == "__main__":
    parser = ArgumentParser()

    # for BaseDataModule
    parser.add_argument("--batch_size", type=int)
    # for SpecialDataModule
    parser.add_argument("--some_arg", type=int)

    args = parser.parse_args()

    print(args.batch_size)
    dm = SpecialDataModule.from_argparse_args(args)
    print(dm.batch_size)

Run using: python main.py --batch_size=10 --some_arg=20

Expected behavior

My datamodule should have a batch_size of 10, as passed on the command line, but it takes the default value of 4 instead.

Environment

  • PyTorch Lightning Version (e.g., 1.5.0): 1.5.10
  • PyTorch Version (e.g., 1.10): 1.10.2+cu102
  • Python version (e.g., 3.9): 3.6.9
  • OS (e.g., Linux): Ubuntu 18.04
  • CUDA/cuDNN version: 11.4

Dependencies as given in my project's pyproject.toml:

[tool.poetry.dependencies]
python = "^3.6.2"
pytorch-lightning = "^1.5.9"
jupyterlab = "^3.2.8"
numpy = "<1.20.0"
sklearn = "^0.0"
poetry2setup = "^1.0.0"
nose = "^1.3.7"
Markdown = "3.3.4"
mlxtend = "^0.19.0"
wandb = "^0.12.10"
plotly = "^5.6.0"
differint = "^0.3.2"

[tool.poetry.dev-dependencies]
pytest = "^5.2"

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
@agarwalmanvi agarwalmanvi added the needs triage Waiting to be triaged by maintainers label May 4, 2022
@carmocca carmocca added bug Something isn't working argparse (removed) Related to argument parsing (argparse, Hydra, ...) and removed needs triage Waiting to be triaged by maintainers labels May 4, 2022
@carmocca carmocca added this to the 1.6.x milestone May 4, 2022
@plutasnyy

I will have a look at it and attempt to fix it ;)

@plutasnyy

plutasnyy commented Jun 3, 2022

I looked at it, and it filters out arguments that don't appear in the selected class's __init__ signature (parent classes are not considered), so only class-specific arguments get passed through.

https://github.com/PyTorchLightning/pytorch-lightning/blob/7938293cd9bb27e5ad97aabc7261fb29027dc6b3/pytorch_lightning/utilities/argparse.py#L67-L70

It seems to me that this filtering is necessary and we cannot do without it; at the same time, using Lightning as described in the issue seems reasonable and should be supported.
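To illustrate the behaviour, here is a simplified stand-in for the signature-based filtering (not Lightning's actual implementation; `explicit_init_params` is a hypothetical helper): inspecting only cls.__init__ hides parameters that parent classes expose behind **kwargs.

```python
import inspect

def explicit_init_params(cls):
    # Collect (name, default) pairs from cls.__init__ only.
    # Parameters inherited from parent classes and forwarded
    # via **kwargs are invisible here.
    sig = inspect.signature(cls.__init__)
    return [
        (name, param.default)
        for name, param in sig.parameters.items()
        if name not in ("self", "args", "kwargs")
    ]

class Base:
    def __init__(self, batch_size: int = 4):
        self.batch_size = batch_size

class Child(Base):
    def __init__(self, some_arg: int = 1, **kwargs):
        super().__init__(**kwargs)
        self.some_arg = some_arg

print(explicit_init_params(Child))  # [('some_arg', 1)] -- batch_size is lost
```

Because `batch_size` never appears in Child's signature, a filter built this way silently drops it, which matches the reported behaviour.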

Possible solutions I see

  1. If we want to support more complex cases (e.g. two levels of inheritance), we need to collect all the 'available' arguments by walking up the inheritance tree recursively.
  2. Don't change anything; just warn the user when cls is not a direct child of LightningDataModule or LightningModule during filtering. This solution has a lot of cons (it requires updates if something changes in other modules, users can't keep their architecture flexible, etc.).
  3. Move the logic of keeping the allowed-parameter list inside the classes. For example, DataModule and the other modules would have an internal method that sets self.allowed_params += inspect.signature(self.__class__.__init__).parameters, we would somehow force child classes to run this method too, and then in the argparse utilities we simply use:
 trainer_kwargs = {name: params[name] for name in valid_kwargs if name in cls.allowed_params} 
 trainer_kwargs.update(**kwargs) 
  4. Remove the filtering? But this would require users to add **kwargs everywhere and would probably be a painful breaking change:
class BaseDataModule(pl.LightningDataModule):
    def __init__(
            self,
            batch_size: int = 4,
            **kwargs
    ):

I don't know; each of the solutions has downsides. Maybe someone knows how to find all the parameters that can be passed to a function? I couldn't find anything like that. @carmocca what do you think?
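A minimal sketch of option (1), assuming a plain walk over the class's MRO (`collect_valid_params` is a hypothetical helper, not Lightning code): the union of every ancestor's explicit __init__ parameters recovers arguments hidden behind **kwargs.

```python
import inspect

def collect_valid_params(cls):
    # Walk the MRO and union the explicit __init__ parameters of
    # every class, so inherited arguments like `batch_size` are
    # recognised even when child classes only expose **kwargs.
    valid = set()
    for klass in cls.__mro__:
        if klass is object:
            continue
        sig = inspect.signature(klass.__init__)
        valid.update(
            name
            for name, param in sig.parameters.items()
            if param.kind in (param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY)
            and name != "self"
        )
    return valid

class Base:
    def __init__(self, batch_size: int = 4):
        self.batch_size = batch_size

class Child(Base):
    def __init__(self, some_arg: int = 1, **kwargs):
        super().__init__(**kwargs)
        self.some_arg = some_arg

print(sorted(collect_valid_params(Child)))  # ['batch_size', 'some_arg']
```

Note this over-approximates: it assumes every child actually forwards **kwargs to super().__init__, which is exactly the assumption-based search mentioned later in this thread.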

@plutasnyy

plutasnyy commented Jun 3, 2022

FYI @agarwalmanvi, the code will work if you simply get rid of **kwargs, but you have to do that in each SpecialDataModule:

class SpecialDataModule(BaseDataModule):
    def __init__(
            self,
            some_arg: int = 1,
            batch_size: int = 4
    ):
        super(SpecialDataModule, self).__init__(batch_size=batch_size)
        self.some_arg = some_arg

@carmocca
Contributor

carmocca commented Jun 6, 2022

Hi! This is a problem that also impacts the CLI (see #11653).
The ideal solution would be (1), but it's complex to implement when an easy workaround already exists (unrolling **kwargs as already described). cc @mauvilsa in case you have thoughts on this.

I strongly suggest updating to use the LightningCLI over the old argparse logic if you can. Your example will work well with it:

import pytorch_lightning as pl

from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.utilities.cli import LightningCLI


class BaseDataModule(pl.LightningDataModule):
    def __init__(self, batch_size: int = 4):
        super(BaseDataModule, self).__init__()
        self.batch_size = batch_size
        # use batch_size later for dataloaders


class SpecialDataModule(BaseDataModule):
    def __init__(self, some_arg: int = 1, **kwargs):
        super(SpecialDataModule, self).__init__(**kwargs)
        self.some_arg = some_arg


cli = LightningCLI(BoringModel, SpecialDataModule, run=False)
print(cli.datamodule.batch_size)
Run using: python main.py --data.batch_size=10 --data.some_arg=20

@plutasnyy

Great, so we are waiting for @mauvilsa ;)

@mauvilsa
Contributor

mauvilsa commented Jun 7, 2022

I would suggest the same as @carmocca: use LightningCLI instead of from_argparse_args. The example above already works with LightningCLI because it assumes that **kwargs are forwarded through super().__init__ and thus finds the parameters of all parent classes. I am also working on improving this support so that the search for parameters is based on the code itself rather than just on assumptions.

The same support could be implemented for from_argparse_args; however, is this effort justified? LightningCLI does the same as from_argparse_args while requiring less code and providing many more features out of the box. Keeping from_argparse_args makes sense so that old code keeps working, but when implementing something new I don't know of any good reason not to use LightningCLI.

@plutasnyy

So to sum up, we can mark this as "won't do", can't we?

@carmocca carmocca closed this as not planned Won't fix, can't repro, duplicate, stale Jun 20, 2022