from_argparse_args on child classes doesn't pass around the correct arguments #12974
Comments
I will have a look at it and attempt to fix it ;)
I looked at it, and it looks like it filters out arguments that don't belong to the selected class (parents are not considered), so that only class-specific ones get passed. It seems to me that this filtering is necessary and we cannot do without it; at the same time, the use of Lightning described in the issue seems reasonable and should be allowed. Possible solutions I see:

    trainer_kwargs = {name: params[name] for name in valid_kwargs if name in cls.allowed_params}
    trainer_kwargs.update(**kwargs)

    class BaseDataModule(pl.LightningDataModule):
        def __init__(
            self,
            batch_size: int = 4,
            **kwargs
        ):

I don't know, each of the solutions has downsides. Maybe someone knows how to find all the parameters that can be passed to a function? I did not find anything like that. @carmocca what do you think?
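On the question of finding every parameter a constructor can accept: one possible approach is to walk the class's MRO and merge the `inspect.signature` of each `__init__`, continuing upward only while a `**kwargs` catch-all is present. The sketch below uses only the standard library and is not something Lightning provides. `collect_init_params` is a hypothetical helper, the two classes are plain-Python stand-ins for the data modules discussed here, and it assumes that `**kwargs` is always forwarded to `super().__init__`, which the code cannot verify.

```python
import inspect


class BaseDataModule:
    # Plain-Python stand-in for the LightningDataModule subclass in this thread.
    def __init__(self, batch_size: int = 4):
        self.batch_size = batch_size


class SpecialDataModule(BaseDataModule):
    def __init__(self, some_arg: int = 1, **kwargs):
        super().__init__(**kwargs)
        self.some_arg = some_arg


def collect_init_params(cls):
    """Hypothetical helper: names accepted by cls.__init__, following **kwargs up the MRO."""
    names = []
    for klass in cls.__mro__:
        init = klass.__dict__.get("__init__")
        if init is None:
            continue  # this class does not define its own __init__
        forwards_kwargs = False
        for name, param in inspect.signature(init).parameters.items():
            if name == "self":
                continue
            if param.kind is inspect.Parameter.VAR_KEYWORD:
                forwards_kwargs = True  # assume these are passed on to the parent
            elif param.kind is not inspect.Parameter.VAR_POSITIONAL and name not in names:
                names.append(name)
        if not forwards_kwargs:
            break  # nothing is forwarded further up, so stop here
    return names


print(collect_init_params(SpecialDataModule))  # ['some_arg', 'batch_size']
```

The obvious downside is the forwarding assumption: a class that accepts `**kwargs` and does something else with them would make the collected list too generous.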
FYI @agarwalmanvi the code will work if you simply get rid of

    class SpecialDataModule(BaseDataModule):
        def __init__(
            self,
            some_arg: int = 1,
            batch_size: int = 4
        ):
            super(SpecialDataModule, self).__init__(batch_size=batch_size)
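Why the explicit signature helps can be seen with a small stand-in for the filtering step described earlier in the thread. This is a simplified sketch under stated assumptions, not Lightning's actual code: `filter_for_init` is a hypothetical helper that keeps only the parsed arguments named in the class's own `__init__` signature, and the classes are plain Python rather than `pl.LightningDataModule` subclasses.

```python
import inspect
from argparse import Namespace


def filter_for_init(cls, args):
    """Hypothetical stand-in: keep only parsed args named in cls.__init__'s own signature."""
    valid = inspect.signature(cls.__init__).parameters
    params = vars(args)
    return {name: params[name] for name in valid if name in params}


class BaseDataModule:
    def __init__(self, batch_size: int = 4):
        self.batch_size = batch_size


class KwargsDataModule(BaseDataModule):
    # batch_size is only reachable through **kwargs, so it is not in the signature.
    def __init__(self, some_arg: int = 1, **kwargs):
        super().__init__(**kwargs)
        self.some_arg = some_arg


class ExplicitDataModule(BaseDataModule):
    # Variant from the comment above: batch_size is named in the signature.
    def __init__(self, some_arg: int = 1, batch_size: int = 4):
        super().__init__(batch_size=batch_size)
        self.some_arg = some_arg


# Equivalent of parsing: --batch_size=10 --some_arg=20
args = Namespace(batch_size=10, some_arg=20)

dropped = KwargsDataModule(**filter_for_init(KwargsDataModule, args))
kept = ExplicitDataModule(**filter_for_init(ExplicitDataModule, args))

print(dropped.batch_size)  # prints 4: the filter never saw batch_size
print(kept.batch_size)     # prints 10: the command-line value is passed through
```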
Hi! This is a problem that also impacts the CLI (see #11653). I strongly suggest updating to use the

    import pytorch_lightning as pl
    from pytorch_lightning.demos.boring_classes import BoringModel
    from pytorch_lightning.utilities.cli import LightningCLI

    class BaseDataModule(pl.LightningDataModule):
        def __init__(self, batch_size: int = 4):
            super(BaseDataModule, self).__init__()
            self.batch_size = batch_size
            # use batch_size later for dataloaders

    class SpecialDataModule(BaseDataModule):
        def __init__(self, some_arg: int = 1, **kwargs):
            super(SpecialDataModule, self).__init__(**kwargs)
            self.some_arg = some_arg

    cli = LightningCLI(BoringModel, SpecialDataModule, run=False)
    print(cli.datamodule.batch_size)

    python main.py --data.batch_size=10 --data.some_arg=20
Great, so we are waiting for @mauvilsa ;)
I would suggest the same as @carmocca, to use The same support could be implemented for
So to sum up, we can mark this as "won't do", can't we?
🐛 Bug
For my research I made a hierarchy of classes where the parent BaseDataModule inherits from pl.LightningDataModule and all other child classes inherit from BaseDataModule. Initializing a child class using from_argparse_args takes the default values for attributes inherited from BaseDataModule and ignores values passed from the command line.
To Reproduce
In main.py:
Run using:
python main.py --batch_size=10 --some_arg=20
Expected behavior
My datamodule should show a batch_size of 10, as passed in the command line arguments, but takes the default value 4.
Environment
Dependencies as given in my project's pyproject.toml: