Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LightningCLI doesn't parse callback args correctly if more than one args for each callback #15007

Closed
junwang-wish opened this issue Oct 5, 2022 · 4 comments
Labels
3rd party Related to a 3rd-party bug Something isn't working lightningcli pl.cli.LightningCLI

Comments

@junwang-wish
Copy link

Bug description

I'm following Pytorch Lightning 1.7.7 Doc to specify LightningCLI args:

python main_inference.py predict \
        --model=LLM_Inference_Conditional_LM \
        --model.llm_type="conditional_lm" \
        --model.ckpt_path="models/model/version_1/epoch=0-step=29648.ckpt" \
        --model.config_path="models/model/version_1/config.yaml" \
        --model.allowed_gen_sequences="['a', 'b']" \
        --data=LLMData \
        --data.data_source_yaml_path="datasets/data/data.yaml" \
        --data.model_name="t5-base" \
        --trainer.callbacks+=PredictionWriter_Conditional_LM \
        --trainer.callbacks.write_interval="batch" \
        --trainer.callbacks.output_dir="models/model/version_1" \
        --print_config

Ideally, both trainer.callbacks.output_dir and trainer.callbacks.write_interval get passed in to instantiate PredictionWriter_Conditional_LM. However, I get

# pytorch_lightning==1.7.7
seed_everything: true
trainer:
  logger: true
  enable_checkpointing: true
  callbacks:
  - class_path: __main__.PredictionWriter_Conditional_LM
    init_args:
      output_dir: models/model/version_1
      write_interval: null
  default_root_dir: null
  gradient_clip_val: null
  gradient_clip_algorithm: null
  num_nodes: 1
  num_processes: null
  devices: null
  gpus: null
  auto_select_gpus: false
  tpu_cores: null
  ipus: null
  enable_progress_bar: true
  overfit_batches: 0.0
  track_grad_norm: -1
  check_val_every_n_epoch: 1
  fast_dev_run: false
  accumulate_grad_batches: null
  max_epochs: null
  min_epochs: null
  max_steps: -1
  min_steps: null
  max_time: null
  limit_train_batches: null
  limit_val_batches: null
  limit_test_batches: null
  limit_predict_batches: null
  val_check_interval: null
  log_every_n_steps: 50
  accelerator: null
  strategy: null
  sync_batchnorm: false
  precision: 32
  enable_model_summary: true
  weights_save_path: null
  num_sanity_val_steps: 2
  resume_from_checkpoint: null
  profiler: null
  benchmark: null
  deterministic: null
  reload_dataloaders_every_n_epochs: 0
  auto_lr_find: false
  replace_sampler_ddp: true
  detect_anomaly: false
  auto_scale_batch_size: false
  plugins: null
  amp_backend: native
  amp_level: null
  move_metrics_to_cpu: false
  multiple_trainloader_mode: max_size_cycle
return_predictions: null
ckpt_path: null
model:
  class_path: __main__.LLM_Inference_Conditional_LM
  init_args:
    llm_type: conditional_lm
    ckpt_path: models/model/version_1/epoch=0-step=29648.ckpt
    config_path: models/model/version_1/config.yaml
    num_beams: 1
    num_return_sequences: 1
    do_sample: false
    length_penalty: 0.0
    max_new_tokens: 50
    allowed_gen_sequences:
    - a
    - b
data:
  class_path: main_utils.LLMData
  init_args:
    data_source_yaml_path: datasets/data/data.yaml
    model_name: t5-base
    raw_cache_dir: /data/junwang/.cache/general
    batch_size: 16
    overwrite_cache: false
    max_length: 250
    predict_on_test: true
    num_workers: 80
    max_length_out: 100
    cache_dir: null
    force_download: false
    resume_download: false
    proxies: null
    use_auth_token: null
    local_files_only: false
    revision: null
    trust_remote_code: null
    subfolder: ''

Note that only the last used --trainer.callbacks.output_dir="models/model/version_1" gets passed in, and not --trainer.callbacks.write_interval="batch".

How to reproduce the bug

No response

Error messages and logs


# Error messages and logs here please

Environment


* CUDA:
	- GPU:
		- Tesla V100-SXM2-16GB
	- available:         True
	- version:           10.2
* Lightning:
	- pytorch-lightning: 1.7.7
	- torch:             1.12.1
	- torchmetrics:      0.9.3
	- torchtext:         0.13.1
* Packages:
	- absl-py:           1.2.0
	- aiobotocore:       2.4.0
	- aiohttp:           3.8.3
	- aiohttp-retry:     2.8.3
	- aioitertools:      0.11.0
	- aiosignal:         1.2.0
	- amqp:              5.1.1
	- antlr4-python3-runtime: 4.9.3
	- appdirs:           1.4.4
	- asttokens:         2.0.8
	- async-timeout:     4.0.2
	- asyncssh:          2.12.0
	- atpublic:          3.1.1
	- attrs:             22.1.0
	- backcall:          0.2.0
	- billiard:          3.6.4.0
	- boto3:             1.24.59
	- botocore:          1.27.59
	- cachetools:        5.2.0
	- celery:            5.2.7
	- certifi:           2022.9.14
	- cffi:              1.15.1
	- charset-normalizer: 2.1.1
	- click:             8.1.3
	- click-didyoumean:  0.3.0
	- click-plugins:     1.1.1
	- click-repl:        0.2.0
	- colorama:          0.4.5
	- commonmark:        0.9.1
	- configobj:         5.0.6
	- contourpy:         1.0.5
	- cryptography:      38.0.1
	- cycler:            0.11.0
	- datasets:          2.5.1
	- debugpy:           1.6.3
	- decorator:         5.1.1
	- deepspeed:         0.7.3
	- dictdiffer:        0.9.0
	- dill:              0.3.5.1
	- diskcache:         5.4.0
	- distro:            1.7.0
	- docstring-parser:  0.15
	- dpath:             2.0.6
	- dulwich:           0.20.46
	- dvc:               2.25.0
	- dvc-data:          0.7.1
	- dvc-http:          2.19.1
	- dvc-objects:       0.2.2
	- dvc-render:        0.0.10
	- dvc-s3:            2.20.0
	- dvc-task:          0.1.2
	- dvclive:           0.11.0
	- entrypoints:       0.4
	- et-xmlfile:        1.1.0
	- executing:         1.1.0
	- filelock:          3.8.0
	- flatten-dict:      0.4.2
	- flufl.lock:        7.1.1
	- fonttools:         4.37.3
	- frozenlist:        1.3.1
	- fsspec:            2022.8.2
	- funcy:             1.17
	- future:            0.18.2
	- gcsfs:             2022.8.2
	- gitdb:             4.0.9
	- gitpython:         3.1.27
	- google-api-core:   2.8.2
	- google-auth:       2.11.1
	- google-auth-oauthlib: 0.4.6
	- google-cloud-core: 2.3.2
	- google-cloud-storage: 2.5.0
	- google-crc32c:     1.5.0
	- google-resumable-media: 2.3.3
	- googleapis-common-protos: 1.56.4
	- grandalf:          0.6
	- grpcio:            1.49.1
	- hjson:             3.1.0
	- huggingface-hub:   0.9.1
	- hydra-core:        1.2.0
	- idna:              3.4
	- importlib-metadata: 4.12.0
	- importlib-resources: 5.9.0
	- ipykernel:         6.16.0
	- ipython:           8.5.0
	- ipywidgets:        8.0.2
	- jedi:              0.18.1
	- jmespath:          1.0.1
	- jsonargparse:      4.14.1
	- jupyter-client:    7.3.5
	- jupyter-core:      4.11.1
	- jupyterlab-widgets: 3.0.3
	- kiwisolver:        1.4.4
	- kombu:             5.2.4
	- markdown:          3.4.1
	- markupsafe:        2.1.1
	- matplotlib:        3.6.0
	- matplotlib-inline: 0.1.6
	- multidict:         6.0.2
	- multiprocess:      0.70.13
	- nanotime:          0.5.2
	- nest-asyncio:      1.5.5
	- networkx:          2.8.6
	- ninja:             1.10.2.4
	- numpy:             1.23.3
	- oauthlib:          3.2.1
	- omegaconf:         2.2.3
	- openpyxl:          3.0.10
	- packaging:         21.3
	- pandas:            1.5.0
	- parso:             0.8.3
	- pathspec:          0.9.0
	- pexpect:           4.8.0
	- pickleshare:       0.7.5
	- pillow:            9.2.0
	- pip:               22.1.2
	- prompt-toolkit:    3.0.31
	- protobuf:          3.19.5
	- psutil:            5.9.2
	- ptyprocess:        0.7.0
	- pure-eval:         0.2.2
	- py-cpuinfo:        8.0.0
	- pyarrow:           9.0.0
	- pyasn1:            0.4.8
	- pyasn1-modules:    0.2.8
	- pycparser:         2.21
	- pydantic:          1.10.2
	- pydeprecate:       0.3.2
	- pydot:             1.4.2
	- pygit2:            1.10.1
	- pygments:          2.13.0
	- pygtrie:           2.5.0
	- pyparsing:         3.0.9
	- python-dateutil:   2.8.2
	- pytorch-lightning: 1.7.7
	- pytz:              2022.2.1
	- pyyaml:            6.0
	- pyzmq:             24.0.1
	- regex:             2022.9.13
	- requests:          2.28.1
	- requests-oauthlib: 1.3.1
	- responses:         0.18.0
	- rich:              12.5.1
	- rsa:               4.9
	- ruamel.yaml:       0.17.21
	- ruamel.yaml.clib:  0.2.6
	- s3fs:              2022.8.2
	- s3transfer:        0.6.0
	- scmrepo:           0.1.1
	- setuptools:        63.4.1
	- shortuuid:         1.0.9
	- shtab:             1.5.5
	- six:               1.16.0
	- smmap:             5.0.0
	- stack-data:        0.5.1
	- tabulate:          0.8.10
	- tensorboard:       2.10.1
	- tensorboard-data-server: 0.6.1
	- tensorboard-plugin-wit: 1.8.1
	- tokenizers:        0.12.1
	- tomlkit:           0.11.4
	- torch:             1.12.1
	- torchmetrics:      0.9.3
	- torchtext:         0.13.1
	- tornado:           6.2
	- tqdm:              4.64.1
	- traitlets:         5.4.0
	- transformers:      4.22.1
	- typing-extensions: 4.3.0
	- urllib3:           1.26.12
	- vine:              5.0.0
	- voluptuous:        0.13.1
	- wcwidth:           0.2.5
	- werkzeug:          2.2.2
	- wheel:             0.37.1
	- widgetsnbextension: 4.0.3
	- wrapt:             1.14.1
	- xxhash:            3.0.0
	- yarl:              1.8.1
	- zc.lockfile:       2.0
	- zipp:              3.8.1
* System:
	- OS:                Linux
	- architecture:
		- 64bit
		- ELF
	- processor:         x86_64
	- python:            3.8.13
	- version:           #168-Ubuntu SMP Wed Jan 16 21:00:45 UTC 2019


More info

No response

@junwang-wish junwang-wish added the needs triage Waiting to be triaged by maintainers label Oct 5, 2022
@mauvilsa
Copy link
Contributor

mauvilsa commented Oct 6, 2022

@junwang-wish is it possible for you to post a minimal python script that reproduces this?

@junwang-wish
Copy link
Author

junwang-wish commented Oct 6, 2022

@mauvilsa thx, here u go, say this is named tmp.py

from main_utils import LLMData
import pytorch_lightning as pl 
from pytorch_lightning.cli import LightningCLI
from pytorch_lightning import trainer
from pytorch_lightning.callbacks import BasePredictionWriter

class PredictionWriter_Conditional_LM(BasePredictionWriter):
    def __init__(self, output_dir: str, write_interval: str):
        super().__init__(write_interval)


class LLM_Inference_Base(pl.LightningModule):
    def __init__(self, llm_type: str, ckpt_path: str, config_path: str, output_dir: str,
            write_interval: str, **kwargs):
        super().__init__()
        self.save_hyperparameters()

def cli_main():
    cli = LightningCLI()
    
if __name__ == "__main__":
    cli_main()

If I run

python tmp.py predict \
        --model=LLM_Inference_Base \
        --model.llm_type="conditional_lm" \
        --model.ckpt_path="models/model/version_1/epoch=0-step=100.ckpt" \
        --model.output_dir="models/model/version_1/config.yaml" \
        --model.write_interval="batch" \
        --data=LLMData \
        --data.data_source_yaml_path="datasets/data/data.yaml" \
        --data.model_name="t5-base" \
        --trainer.callbacks+=PredictionWriter_Conditional_LM \
        --trainer.callbacks.write_interval="batch" \
        --trainer.callbacks.output_dir="models/model/version_1" \
        --print_config

u would get

# pytorch_lightning==1.7.7
seed_everything: true
trainer:
  logger: true
  enable_checkpointing: true
  callbacks:
  - class_path: __main__.PredictionWriter_Conditional_LM
    init_args:
      output_dir: models/model/version_1
      write_interval: null
  default_root_dir: null
  gradient_clip_val: null
  gradient_clip_algorithm: null
  num_nodes: 1
  num_processes: null
  devices: null
  gpus: null
  auto_select_gpus: false
  tpu_cores: null
  ipus: null
  enable_progress_bar: true
  overfit_batches: 0.0
  track_grad_norm: -1
  check_val_every_n_epoch: 1
  fast_dev_run: false
  accumulate_grad_batches: null
  max_epochs: null
  min_epochs: null
  max_steps: -1
  min_steps: null
  max_time: null
  limit_train_batches: null
  limit_val_batches: null
  limit_test_batches: null
  limit_predict_batches: null
  val_check_interval: null
  log_every_n_steps: 50
  accelerator: null
  strategy: null
  sync_batchnorm: false
  precision: 32
  enable_model_summary: true
  weights_save_path: null
  num_sanity_val_steps: 2
  resume_from_checkpoint: null
  profiler: null
  benchmark: null
  deterministic: null
  reload_dataloaders_every_n_epochs: 0
  auto_lr_find: false
  replace_sampler_ddp: true
  detect_anomaly: false
  auto_scale_batch_size: false
  plugins: null
  amp_backend: native
  amp_level: null
  move_metrics_to_cpu: false
  multiple_trainloader_mode: max_size_cycle
return_predictions: null
ckpt_path: null
model:
  class_path: __main__.LLM_Inference_Base
  init_args:
    llm_type: conditional_lm
    ckpt_path: models/model/version_1/epoch=0-step=100.ckpt
    config_path: null
    output_dir: models/model/version_1/config.yaml
    write_interval: batch
data:
  class_path: main_utils.LLMData
  init_args:
    data_source_yaml_path: datasets/data/data.yaml
    model_name: t5-base
    raw_cache_dir: /data/junwang/.cache/general
    batch_size: 16
    overwrite_cache: false
    max_length: 250
    predict_on_test: true
    num_workers: 80
    max_length_out: 100
    cache_dir: null
    force_download: false
    resume_download: false
    proxies: null
    use_auth_token: null
    local_files_only: false
    revision: null
    trust_remote_code: null
    subfolder: ''

Notice that write_interval of __main__.PredictionWriter_Conditional_LM is no set, but it is passed in

@mauvilsa
Copy link
Contributor

mauvilsa commented Oct 7, 2022

Great, thank you!

@mauvilsa
Copy link
Contributor

mauvilsa commented Oct 7, 2022

@junwang-wish thank you very much for reporting. This was a bug in jsonargparse, fixed in commit 3337a0e and just released as version 4.15.1. Please update the package (e.g. pip3 install -U jsonargparse) and the problem should be fixed.

@carmocca carmocca closed this as completed Oct 8, 2022
@carmocca carmocca added bug Something isn't working 3rd party Related to a 3rd-party lightningcli pl.cli.LightningCLI and removed needs triage Waiting to be triaged by maintainers labels Oct 8, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
3rd party Related to a 3rd-party bug Something isn't working lightningcli pl.cli.LightningCLI
Projects
None yet
Development

No branches or pull requests

3 participants