
hparams.yaml does not save Enum-style arguments properly #8912

@dasayan05

Description


🐛 Bug

To Reproduce

When an Enum is used to emulate a choices-style parser argument of a LightningModule or LightningDataModule, for example:

from enum import Enum

import pytorch_lightning as pl

class Options(str, Enum):
    option1 = "option1"
    option2 = "option2"
    option3 = "option3"

class BoringModel(pl.LightningModule):

    def __init__(self,
                 learning_rate: float = 0.0001,
                 switch: Options = Options.option3, # argument of interest
                 batch_size: int = 32):
        super().__init__()
        ...

the hparams.yaml written by TensorBoardLogger (and possibly other loggers) looks like this:

learning_rate: 0.0001
switch: !!python/object/apply:__main__.Options
- option3
batch_size: 32
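The !!python/object/apply tag appears to come from PyYAML's default (non-safe) representer, which, for objects without a dedicated representer, falls back to the pickle protocol (__reduce_ex__). An Enum member reduces to its class plus its value, which is exactly the __main__.Options / option3 pair in the YAML above. The stdlib-only sketch below mimics that step; yaml_apply_tag is a hypothetical helper for illustration, not a PyYAML or Lightning function.

```python
from enum import Enum


class Options(str, Enum):
    option1 = "option1"
    option2 = "option2"
    option3 = "option3"


def yaml_apply_tag(obj):
    """Hypothetical helper: rebuild the tag and argument list that PyYAML's
    non-safe representer derives from the pickle protocol for `obj`."""
    # For Enum members, __reduce_ex__ returns (EnumClass, (value,)).
    func, args = obj.__reduce_ex__(2)[:2]
    tag = "!!python/object/apply:%s.%s" % (func.__module__, func.__qualname__)
    return tag, list(args)


tag, args = yaml_apply_tag(Options.option3)
print(tag)   # module part is __main__ when this file is run as a script
print(args)  # ['option3']
```

Because the tag embeds the module path, loading the file only works where that path resolves to the same class.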

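It seems Lightning's internal hyperparameter-saving function, save_hparams_to_yaml(..) (used by TensorBoardLogger and others), cannot handle Enum values properly: the member is written as a Python-specific !!python/object/apply tag that refers to the class by its import path (__main__.Options). The drawback is that the hparams.yaml cannot be loaded in a different context, where that path does not resolve, for example: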

# different file or notebook
>>> model = BoringModel.load_from_checkpoint(..., hparams_file='/path/to/above/yaml/hparams.yaml')
...
    return super(UnsafeConstructor, self).find_python_name(name, mark, unsafe=True)
  File "<path>/anaconda3/envs/pltest/lib/python3.8/site-packages/yaml/constructor.py", line 560, in find_python_name
    raise ConstructorError("while constructing a Python object", mark,
yaml.constructor.ConstructorError: while constructing a Python object
cannot find 'Options' in the module '__main__'
  in "<path>/lightning_logs/myexp/version_5/hparams.yaml", line 4, column 9

Expected behavior

Because jsonargparse supports such arguments, the config.yaml written by LightningCLI stores the Enum as its plain value, which is the expected behavior:

model:
  learning_rate: 0.0001
  switch: option3
  batch_size: 32
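With the plain-value form, nothing Python-specific is stored, and the Enum can be rebuilt on load: calling an Enum class with a value returns the matching member. Since Options also subclasses str, the member even compares equal to the raw string. A minimal sketch, assuming the loaded "model" section is a plain dict (the dict literal below stands in for what a YAML loader would return):

```python
from enum import Enum


class Options(str, Enum):
    option1 = "option1"
    option2 = "option2"
    option3 = "option3"


# Stand-in for the "model" section of config.yaml after parsing:
# the Enum is stored as its plain string value.
loaded = {"learning_rate": 0.0001, "switch": "option3", "batch_size": 32}

# Calling the Enum class with a value looks up the matching member
# (an unknown value raises ValueError).
switch = Options(loaded["switch"])

print(switch is Options.option3)  # True
print(switch == "option3")        # True, because Options subclasses str
```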

Additional context

Since save_hparams_to_yaml(..) optionally uses OmegaConf, it can handle Enum values IF OmegaConf is installed. However, OmegaConf is not a required dependency, so users without it run into this problem.
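Without OmegaConf, one possible user-side workaround is to convert Enum members to their plain values before the hyperparameters reach the YAML dumper. A stdlib-only sketch; enums_to_values is a hypothetical helper, not a Lightning API, and it only recurses into dicts, lists and tuples:

```python
from enum import Enum


class Options(str, Enum):
    option1 = "option1"
    option2 = "option2"
    option3 = "option3"


def enums_to_values(obj):
    """Hypothetical helper: recursively replace Enum members with their
    .value so that a plain YAML dumper can serialize the result."""
    if isinstance(obj, Enum):
        return obj.value
    if isinstance(obj, dict):
        return {key: enums_to_values(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [enums_to_values(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(enums_to_values(item) for item in obj)
    return obj


hparams = {"learning_rate": 0.0001, "switch": Options.option3, "batch_size": 32}
clean = enums_to_values(hparams)
print(clean)                  # {'learning_rate': 0.0001, 'switch': 'option3', 'batch_size': 32}
print(type(clean["switch"]))  # <class 'str'>
```

Note that Options.option3 == "option3" already holds because of the str mixin, so the type check, not the equality, is what shows the conversion happened.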

Environment

  • PyTorch Lightning Version: master branch
  • PyTorch Version: 1.9.0
  • Python version: 3.8
  • OS: Linux
  • CUDA version: 11.1
