LightningCLI doesn't parse callback args correctly if more than one args for each callback #15007

junwang-wish opened this issue Oct 5, 2022 · 4 comments
Labels: 3rd party, bug, lightningcli


Bug description

I'm following the PyTorch Lightning 1.7.7 docs to specify LightningCLI args:

python predict \
        --model=LLM_Inference_Conditional_LM \
        --model.llm_type="conditional_lm" \
        --model.ckpt_path="models/model/version_1/epoch=0-step=29648.ckpt" \
        --model.config_path="models/model/version_1/config.yaml" \
        --model.allowed_gen_sequences="['a', 'b']" \
        --data=LLMData \
        --data.data_source_yaml_path="datasets/data/data.yaml" \
        --data.model_name="t5-base" \
        --trainer.callbacks+=PredictionWriter_Conditional_LM \
        --trainer.callbacks.write_interval="batch" \
        --trainer.callbacks.output_dir="models/model/version_1"

Ideally, both trainer.callbacks.output_dir and trainer.callbacks.write_interval get passed in to instantiate PredictionWriter_Conditional_LM. However, I get

# pytorch_lightning==1.7.7
seed_everything: true
trainer:
  logger: true
  enable_checkpointing: true
  callbacks:
  - class_path: __main__.PredictionWriter_Conditional_LM
    init_args:
      output_dir: models/model/version_1
      write_interval: null
  default_root_dir: null
  gradient_clip_val: null
  gradient_clip_algorithm: null
  num_nodes: 1
  num_processes: null
  devices: null
  gpus: null
  auto_select_gpus: false
  tpu_cores: null
  ipus: null
  enable_progress_bar: true
  overfit_batches: 0.0
  track_grad_norm: -1
  check_val_every_n_epoch: 1
  fast_dev_run: false
  accumulate_grad_batches: null
  max_epochs: null
  min_epochs: null
  max_steps: -1
  min_steps: null
  max_time: null
  limit_train_batches: null
  limit_val_batches: null
  limit_test_batches: null
  limit_predict_batches: null
  val_check_interval: null
  log_every_n_steps: 50
  accelerator: null
  strategy: null
  sync_batchnorm: false
  precision: 32
  enable_model_summary: true
  weights_save_path: null
  num_sanity_val_steps: 2
  resume_from_checkpoint: null
  profiler: null
  benchmark: null
  deterministic: null
  reload_dataloaders_every_n_epochs: 0
  auto_lr_find: false
  replace_sampler_ddp: true
  detect_anomaly: false
  auto_scale_batch_size: false
  plugins: null
  amp_backend: native
  amp_level: null
  move_metrics_to_cpu: false
  multiple_trainloader_mode: max_size_cycle
return_predictions: null
ckpt_path: null
model:
  class_path: __main__.LLM_Inference_Conditional_LM
  init_args:
    llm_type: conditional_lm
    ckpt_path: models/model/version_1/epoch=0-step=29648.ckpt
    config_path: models/model/version_1/config.yaml
    num_beams: 1
    num_return_sequences: 1
    do_sample: false
    length_penalty: 0.0
    max_new_tokens: 50
    allowed_gen_sequences:
    - a
    - b
data:
  class_path: main_utils.LLMData
  init_args:
    data_source_yaml_path: datasets/data/data.yaml
    model_name: t5-base
    raw_cache_dir: /data/junwang/.cache/general
    batch_size: 16
    overwrite_cache: false
    max_length: 250
    predict_on_test: true
    num_workers: 80
    max_length_out: 100
    cache_dir: null
    force_download: false
    resume_download: false
    proxies: null
    use_auth_token: null
    local_files_only: false
    revision: null
    trust_remote_code: null
    subfolder: ''

Note that only the last used --trainer.callbacks.output_dir="models/model/version_1" gets passed in, and not --trainer.callbacks.write_interval="batch".
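
The expected behavior can be sketched in plain Python. This is only an illustration of how the arguments should accumulate, not jsonargparse's actual implementation; the function name `merge_callback_args` and the (key, value) tuple format are hypothetical.

```python
from typing import Any, Dict, List, Tuple


def merge_callback_args(args: List[Tuple[str, str]]) -> List[Dict[str, Any]]:
    """Hypothetical sketch: fold CLI-style (key, value) pairs into callback configs.

    "trainer.callbacks+" appends a new callback; every following
    "trainer.callbacks.<name>" pair sets an init arg on the callback
    that was appended most recently.
    """
    callbacks: List[Dict[str, Any]] = []
    for key, value in args:
        if key == "trainer.callbacks+":
            callbacks.append({"class_path": value, "init_args": {}})
        elif key.startswith("trainer.callbacks."):
            if not callbacks:
                raise ValueError(f"{key} given before any callback was appended")
            name = key[len("trainer.callbacks."):]
            # Each argument is added to init_args; earlier ones are kept.
            callbacks[-1]["init_args"][name] = value
    return callbacks


expected = merge_callback_args([
    ("trainer.callbacks+", "PredictionWriter_Conditional_LM"),
    ("trainer.callbacks.write_interval", "batch"),
    ("trainer.callbacks.output_dir", "models/model/version_1"),
])
print(expected)
```

In this sketch both `write_interval` and `output_dir` end up in the same `init_args` mapping, which is what the command above is meant to produce.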

junwang-wish added the needs triage label Oct 5, 2022

mauvilsa commented Oct 6, 2022

@junwang-wish is it possible for you to post a minimal python script that reproduces this?

junwang-wish commented Oct 6, 2022

@mauvilsa thx, here u go, say this is named

from main_utils import LLMData
import pytorch_lightning as pl
from pytorch_lightning.cli import LightningCLI
from pytorch_lightning import trainer
from pytorch_lightning.callbacks import BasePredictionWriter


class PredictionWriter_Conditional_LM(BasePredictionWriter):
    def __init__(self, output_dir: str, write_interval: str):
        # Body not shown in the report; minimal placeholder.
        super().__init__(write_interval)
        self.output_dir = output_dir


class LLM_Inference_Base(pl.LightningModule):
    def __init__(self, llm_type: str, ckpt_path: str, config_path: str, output_dir: str,
            write_interval: str, **kwargs):
        # Body not shown in the report; minimal placeholder.
        super().__init__()


def cli_main():
    # LightningCLI parses the command line and runs the chosen subcommand.
    cli = LightningCLI()


if __name__ == "__main__":
    cli_main()

If I run

python predict \
        --model=LLM_Inference_Base \
        --model.llm_type="conditional_lm" \
        --model.ckpt_path="models/model/version_1/epoch=0-step=100.ckpt" \
        --model.output_dir="models/model/version_1/config.yaml" \
        --model.write_interval="batch" \
        --data=LLMData \
        --data.data_source_yaml_path="datasets/data/data.yaml" \
        --data.model_name="t5-base" \
        --trainer.callbacks+=PredictionWriter_Conditional_LM \
        --trainer.callbacks.write_interval="batch" \
        --trainer.callbacks.output_dir="models/model/version_1"

u would get

# pytorch_lightning==1.7.7
seed_everything: true
trainer:
  logger: true
  enable_checkpointing: true
  callbacks:
  - class_path: __main__.PredictionWriter_Conditional_LM
    init_args:
      output_dir: models/model/version_1
      write_interval: null
  default_root_dir: null
  gradient_clip_val: null
  gradient_clip_algorithm: null
  num_nodes: 1
  num_processes: null
  devices: null
  gpus: null
  auto_select_gpus: false
  tpu_cores: null
  ipus: null
  enable_progress_bar: true
  overfit_batches: 0.0
  track_grad_norm: -1
  check_val_every_n_epoch: 1
  fast_dev_run: false
  accumulate_grad_batches: null
  max_epochs: null
  min_epochs: null
  max_steps: -1
  min_steps: null
  max_time: null
  limit_train_batches: null
  limit_val_batches: null
  limit_test_batches: null
  limit_predict_batches: null
  val_check_interval: null
  log_every_n_steps: 50
  accelerator: null
  strategy: null
  sync_batchnorm: false
  precision: 32
  enable_model_summary: true
  weights_save_path: null
  num_sanity_val_steps: 2
  resume_from_checkpoint: null
  profiler: null
  benchmark: null
  deterministic: null
  reload_dataloaders_every_n_epochs: 0
  auto_lr_find: false
  replace_sampler_ddp: true
  detect_anomaly: false
  auto_scale_batch_size: false
  plugins: null
  amp_backend: native
  amp_level: null
  move_metrics_to_cpu: false
  multiple_trainloader_mode: max_size_cycle
return_predictions: null
ckpt_path: null
model:
  class_path: __main__.LLM_Inference_Base
  init_args:
    llm_type: conditional_lm
    ckpt_path: models/model/version_1/epoch=0-step=100.ckpt
    config_path: null
    output_dir: models/model/version_1/config.yaml
    write_interval: batch
data:
  class_path: main_utils.LLMData
  init_args:
    data_source_yaml_path: datasets/data/data.yaml
    model_name: t5-base
    raw_cache_dir: /data/junwang/.cache/general
    batch_size: 16
    overwrite_cache: false
    max_length: 250
    predict_on_test: true
    num_workers: 80
    max_length_out: 100
    cache_dir: null
    force_download: false
    resume_download: false
    proxies: null
    use_auth_token: null
    local_files_only: false
    revision: null
    trust_remote_code: null
    subfolder: ''

Notice that write_interval of __main__.PredictionWriter_Conditional_LM is not set, even though it is passed in.

mauvilsa commented Oct 7, 2022

Great, thank you!

mauvilsa commented Oct 7, 2022

@junwang-wish thank you very much for reporting. This was a bug in jsonargparse, fixed in commit 3337a0e and just released as version 4.15.1. Please update the package (e.g. pip3 install -U jsonargparse) and the problem should be fixed.
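
To confirm that an environment already has the fixed release, a small check along these lines may help. It is a sketch only: the helper names `parse_version` and `has_fix` are made up for illustration, and the version parsing is deliberately simplified (pre-release suffixes such as `rc1` are ignored).

```python
import re
from importlib.metadata import PackageNotFoundError, version

# First jsonargparse release that contains the fix, per the comment above.
FIXED_VERSION = (4, 15, 1)


def parse_version(text: str) -> tuple:
    """Extract the leading numeric components, e.g. "4.15.1" -> (4, 15, 1)."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", text)
    if match is None:
        raise ValueError(f"unrecognized version string: {text!r}")
    # A missing patch component is treated as 0.
    return tuple(int(part or 0) for part in match.groups())


def has_fix(installed: str) -> bool:
    """Return True if the given version string is at least FIXED_VERSION."""
    return parse_version(installed) >= FIXED_VERSION


try:
    print("jsonargparse has the fix:", has_fix(version("jsonargparse")))
except PackageNotFoundError:
    print("jsonargparse is not installed")
```

If the check reports an older version, upgrading as described above should pull in the fix.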

carmocca closed this as completed Oct 8, 2022
carmocca added the bug, 3rd party and lightningcli labels and removed the needs triage label Oct 8, 2022
