diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f7f7eff35997f..9d462ef8f36aa 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -766,8 +766,16 @@ def use_type(x): use_type = arg_types[0] if arg == 'gpus' or arg == 'tpu_cores': - use_type = Trainer._allowed_type - arg_default = Trainer._arg_default + use_type = Trainer._gpus_allowed_type + arg_default = Trainer._gpus_arg_default + + # hack for types in (int, float) + if len(arg_types) == 2 and int in set(arg_types) and float in set(arg_types): + use_type = Trainer._int_or_float_type + + # hack for track_grad_norm + if arg == 'track_grad_norm': + use_type = float parser.add_argument( f'--{arg}', @@ -780,18 +788,24 @@ def use_type(x): return parser - def _allowed_type(x) -> Union[int, str]: + def _gpus_allowed_type(x) -> Union[int, str]: if ',' in x: return str(x) else: return int(x) - def _arg_default(x) -> Union[int, str]: + def _gpus_arg_default(x) -> Union[int, str]: if ',' in x: return str(x) else: return int(x) + def _int_or_float_type(x) -> Union[int, float]: + if '.' in str(x): + return float(x) + else: + return int(x) + @classmethod def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace: """Parse CLI arguments, required for custom bool types."""