LightningCLI dataclass support broken #9207

Closed
leezu opened this issue Aug 30, 2021 · 4 comments
Labels: 3rd party, argparse (removed), bug

Comments

leezu (Contributor) commented Aug 30, 2021

🐛 Bug

Usage of simple dataclasses with LightningCLI broke with jsonargparse 3.14+, which is required by recent versions of pytorch-lightning. I'm unable to reproduce the issue with jsonargparse directly (though I haven't tried extensively), so I am reporting it here.

@dataclass
class Config:
    name: str


class Module(pl.LightningModule):
    def __init__(self, *, cfg: Config):
        super().__init__()

[...]

if __name__ == "__main__":
    LightningCLI(Module)

This triggers:

  File "test.py", line 41, in <module>
    print(LightningCLI(Module))
  File ".local/lib/python3.8/site-packages/pytorch_lightning/utilities/cli.py", line 285, in __init__
    self.instantiate_classes()
  File ".local/lib/python3.8/site-packages/pytorch_lightning/utilities/cli.py", line 338, in instantiate_classes
    self.config_init = self.parser.instantiate_classes(self.config)
  File ".local/lib/python3.8/site-packages/jsonargparse/core.py", line 1126, in instantiate_classes
    component.instantiate_class(component, cfg)
  File ".local/lib/python3.8/site-packages/jsonargparse/signatures.py", line 550, in group_instantiate_class
    parent[key] = group.group_class(**value)
TypeError: type object argument after ** must be a mapping, not Config
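The `TypeError` itself is plain Python behaviour: `**` unpacking requires a mapping, and a dataclass instance is not one. A minimal sketch, independent of jsonargparse and Lightning, that reproduces the same error by passing an already-built `Config` where keyword arguments are expected. `unpack_into` is a hypothetical helper that only mirrors the failing line, and the exact message wording varies between Python versions:

```python
from dataclasses import dataclass


@dataclass
class Config:
    name: str


def unpack_into(cls, value):
    # Hypothetical helper mirroring the failing line `group.group_class(**value)`:
    # it only works when `value` is a mapping of keyword arguments
    return cls(**value)


print(unpack_into(Config, {"name": "test"}))  # a dict works

try:
    unpack_into(Config, Config(name="test"))  # a dataclass instance does not
except TypeError as err:
    # e.g. "type object argument after ** must be a mapping, not Config" on Python 3.8
    print(err)
```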

To Reproduce

from dataclasses import dataclass
import torch
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl
from pytorch_lightning.utilities.cli import LightningCLI


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


@dataclass
class Config:
    name: str


class Module(pl.LightningModule):
    def __init__(self, *, cfg: Config):
        super().__init__()

    def training_step(self, *args):
        pass

    def train_dataloader(self, *args):
        return DataLoader(RandomDataset(32, 64), batch_size=2)

    def configure_optimizers(self, *args):
        pass


if __name__ == "__main__":
    print(LightningCLI(Module))

Expected behavior

No crash.

Proposed fix

diff --git a/jsonargparse/signatures.py b/jsonargparse/signatures.py
index ba936c6..d3d0c3e 100644
--- a/jsonargparse/signatures.py
+++ b/jsonargparse/signatures.py
@@ -3,6 +3,7 @@
 import inspect
 import re
 from argparse import Namespace
+import dataclasses
 from functools import wraps
 from typing import Any, Callable, List, Optional, Set, Tuple, Type, Union
 
@@ -547,6 +549,8 @@ def group_instantiate_class(group, cfg):
         parent = cfg
         key = group.dest
         assert '.' not in key
+    if dataclasses.is_dataclass(value):
+        value = dataclasses.asdict(value)
     parent[key] = group.group_class(**value)
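The effect of the patch above can be illustrated outside jsonargparse. In this sketch, `instantiate_group` is a hypothetical stand-in for the patched end of `group_instantiate_class`: a dataclass instance is converted back to a dict with `dataclasses.asdict` before being unpacked, so a plain mapping and an already-built `Config` give the same result. Unlike the patch, it also checks that `value` is an instance, because `dataclasses.is_dataclass` returns `True` for the class itself too:

```python
import dataclasses
from dataclasses import dataclass
from typing import Any


@dataclass
class Config:
    name: str


def instantiate_group(group_class: type, value: Any):
    # Hypothetical stand-in for the end of the patched group_instantiate_class
    if dataclasses.is_dataclass(value) and not isinstance(value, type):
        # Turn an already-built dataclass instance back into keyword arguments
        value = dataclasses.asdict(value)
    return group_class(**value)


print(instantiate_group(Config, {"name": "test"}))     # Config(name='test')
print(instantiate_group(Config, Config(name="test")))  # Config(name='test')
```

One caveat of this approach: `dataclasses.asdict` recurses into nested dataclasses, so a nested dataclass field would be passed on as a plain dict rather than as an instance.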

@mauvilsa, is the above patch the correct fix, or is this issue due to pytorch-lightning calling jsonargparse incorrectly?

leezu added the bug and help wanted labels on Aug 30, 2021
ananthsub added the argparse (removed) label on Aug 30, 2021
carmocca (Member) commented Aug 30, 2021

@leezu, can you report your exact jsonargparse and Lightning versions? I cannot reproduce the failure locally using the latest version of both:

jsonargparse==3.19.0
lightning==master

Running your repro code gets me:

error: Configuration check failed :: Key "model.cfg.name" is required but not included in config object or its value is None.

which is expected, since no value is provided for cfg.
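The "required" error follows from the dataclass definition: `name` has no default, so a parser that derives options from the dataclass fields has to treat it as mandatory. A small sketch using the standard `dataclasses.fields` API; the helper `required_fields` and the class `ConfigWithDefault` are hypothetical and are not how jsonargparse works internally:

```python
import dataclasses
from dataclasses import dataclass


@dataclass
class Config:
    name: str  # no default: must be supplied, e.g. via --model.cfg.name


@dataclass
class ConfigWithDefault:
    name: str = "default"  # hypothetical variant with a default value


def required_fields(cls) -> list:
    # A field is required when it has neither a default nor a default_factory
    return [
        f.name
        for f in dataclasses.fields(cls)
        if f.default is dataclasses.MISSING
        and f.default_factory is dataclasses.MISSING
    ]


print(required_fields(Config))             # ['name']
print(required_fields(ConfigWithDefault))  # []
```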

carmocca added the 3rd party and waiting on author labels and removed the help wanted label on Aug 30, 2021
mauvilsa (Contributor) commented
I see what the issue is. Previously only subclasses were instantiated; instantiation of classes was added later. Since a dataclass is also a class and is added to the parser as a group, instantiate_classes tries to instantiate it. Although the suggested fix would work, the proper fix is for instantiate_classes to skip dataclasses. I will look into it.
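A minimal sketch of the fix described in this comment, assuming (as the traceback suggests) that by the time classes are instantiated, the value for a dataclass group is already a dataclass instance. `instantiate_groups`, the `groups` mapping and the `Encoder` class are hypothetical simplifications of the parser's internal state, not jsonargparse's API:

```python
import dataclasses
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class Config:
    name: str


class Encoder:
    # Hypothetical ordinary (non-dataclass) class registered as a group
    def __init__(self, hidden: int):
        self.hidden = hidden


def instantiate_groups(cfg: Dict[str, Any], groups: Dict[str, type]) -> Dict[str, Any]:
    """Instantiate each group from its kwargs, leaving dataclass groups untouched."""
    out = dict(cfg)
    for key, group_class in groups.items():
        if dataclasses.is_dataclass(group_class):
            # Skip dataclass groups: their value is assumed to be an instance already
            continue
        out[key] = group_class(**out[key])
    return out


result = instantiate_groups(
    {"cfg": Config(name="test"), "encoder": {"hidden": 8}},
    {"cfg": Config, "encoder": Encoder},
)
print(result["cfg"])             # Config(name='test')
print(result["encoder"].hidden)  # 8
```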

leezu (Contributor, Author) commented Sep 1, 2021

@carmocca, sorry for omitting the script invocation instructions:

python3 ~/test.py --model.cfg.name test --print_config > config
python3 ~/test.py --config config

carmocca removed the waiting on author label on Sep 1, 2021
mauvilsa (Contributor) commented Sep 3, 2021

This is fixed in jsonargparse v3.19.1.
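To check whether an environment already includes this fix, the installed jsonargparse version can be compared against 3.19.1. A standard-library-only sketch; the helper names are hypothetical, and `packaging.version.Version` would be a more robust parser if that package is available:

```python
from importlib.metadata import PackageNotFoundError, version


def release_tuple(v: str) -> tuple:
    # Keep the leading numeric release components, e.g. "3.19.1" -> (3, 19, 1).
    # Pre-release suffixes are ignored, so "3.19.1rc1" is treated as (3, 19, 1).
    parts = []
    for piece in v.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
        if len(digits) != len(piece):
            break  # stop at the first component with a suffix
    return tuple(parts)


def has_dataclass_fix(v: str) -> bool:
    return release_tuple(v) >= (3, 19, 1)


try:
    installed = version("jsonargparse")
    print(installed, "includes the v3.19.1 fix" if has_dataclass_fix(installed) else "is older than v3.19.1")
except PackageNotFoundError:
    print("jsonargparse is not installed")
```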

carmocca closed this as completed on Sep 3, 2021
4 participants