Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ref: separate argparse #3428

Merged
merged 1 commit into from Sep 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
107 changes: 1 addition & 106 deletions pytorch_lightning/trainer/trainer.py
Expand Up @@ -488,112 +488,7 @@ def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namesp

@classmethod
def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
r"""Extends existing argparse by default `Trainer` attributes.

Args:
parent_parser:
The custom cli arguments parser, which will be extended by
the Trainer default arguments.

Only arguments of the allowed types (str, float, int, bool) will
extend the `parent_parser`.

Examples:
>>> import argparse
>>> import pprint
>>> parser = argparse.ArgumentParser()
>>> parser = Trainer.add_argparse_args(parser)
>>> args = parser.parse_args([])
>>> pprint.pprint(vars(args)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
{...
'check_val_every_n_epoch': 1,
'checkpoint_callback': True,
'default_root_dir': None,
'deterministic': False,
'distributed_backend': None,
'early_stop_callback': False,
...
'logger': True,
'max_epochs': 1000,
'max_steps': None,
'min_epochs': 1,
'min_steps': None,
...
'profiler': None,
'progress_bar_refresh_rate': 1,
...}

"""
parser = ArgumentParser(parents=[parent_parser], add_help=False,)

blacklist = ['kwargs']
depr_arg_names = cls.get_deprecated_arg_names() + blacklist

allowed_types = (str, int, float, bool)

# TODO: get "help" from docstring :)
for arg, arg_types, arg_default in (
at for at in argparse_utils.get_init_arguments_and_types(cls) if at[0] not in depr_arg_names
):
arg_types = [at for at in allowed_types if at in arg_types]
if not arg_types:
# skip argument with not supported type
continue
arg_kwargs = {}
if bool in arg_types:
arg_kwargs.update(nargs="?", const=True)
# if the only arg type is bool
if len(arg_types) == 1:
use_type = parsing.str_to_bool
# if only two args (str, bool)
elif len(arg_types) == 2 and set(arg_types) == {str, bool}:
use_type = parsing.str_to_bool_or_str
else:
# filter out the bool as we need to use more general
use_type = [at for at in arg_types if at is not bool][0]
else:
use_type = arg_types[0]

if arg == 'gpus' or arg == 'tpu_cores':
use_type = Trainer._gpus_allowed_type
arg_default = Trainer._gpus_arg_default

# hack for types in (int, float)
if len(arg_types) == 2 and int in set(arg_types) and float in set(arg_types):
use_type = Trainer._int_or_float_type

# hack for track_grad_norm
if arg == 'track_grad_norm':
use_type = float

parser.add_argument(
f'--{arg}',
dest=arg,
default=arg_default,
type=use_type,
help='autogenerated by pl.Trainer',
**arg_kwargs,
)

return parser

def _gpus_allowed_type(x) -> Union[int, str]:
if ',' in x:
return str(x)
else:
return int(x)

def _gpus_arg_default(x) -> Union[int, str]:
if ',' in x:
return str(x)
else:
return int(x)

def _int_or_float_type(x) -> Union[int, float]:
if '.' in str(x):
return float(x)
else:
return int(x)
return argparse_utils.add_argparse_args(cls, parent_parser)

@property
def num_gpus(self) -> int:
Expand Down
114 changes: 114 additions & 0 deletions pytorch_lightning/utilities/argparse_utils.py
@@ -1,6 +1,7 @@
import inspect
from argparse import ArgumentParser, Namespace
from typing import Union, List, Tuple, Any
from pytorch_lightning.utilities import parsing


def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
Expand Down Expand Up @@ -107,3 +108,116 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
name_type_default.append((arg, arg_types, arg_default))

return name_type_default


def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
r"""Extends existing argparse by default `Trainer` attributes.

Args:
parent_parser:
The custom cli arguments parser, which will be extended by
the Trainer default arguments.

Only arguments of the allowed types (str, float, int, bool) will
extend the `parent_parser`.

Examples:
>>> import argparse
>>> import pprint
>>> from pytorch_lightning import Trainer
>>> parser = argparse.ArgumentParser()
>>> parser = Trainer.add_argparse_args(parser)
>>> args = parser.parse_args([])
>>> pprint.pprint(vars(args)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
{...
'check_val_every_n_epoch': 1,
'checkpoint_callback': True,
'default_root_dir': None,
'deterministic': False,
'distributed_backend': None,
'early_stop_callback': False,
...
'logger': True,
'max_epochs': 1000,
'max_steps': None,
'min_epochs': 1,
'min_steps': None,
...
'profiler': None,
'progress_bar_refresh_rate': 1,
...}

"""
parser = ArgumentParser(parents=[parent_parser], add_help=False,)

blacklist = ['kwargs']
depr_arg_names = cls.get_deprecated_arg_names() + blacklist

allowed_types = (str, int, float, bool)

# TODO: get "help" from docstring :)
for arg, arg_types, arg_default in (
at for at in get_init_arguments_and_types(cls) if at[0] not in depr_arg_names
):
arg_types = [at for at in allowed_types if at in arg_types]
if not arg_types:
# skip argument with not supported type
continue
arg_kwargs = {}
if bool in arg_types:
arg_kwargs.update(nargs="?", const=True)
# if the only arg type is bool
if len(arg_types) == 1:
use_type = parsing.str_to_bool
# if only two args (str, bool)
elif len(arg_types) == 2 and set(arg_types) == {str, bool}:
use_type = parsing.str_to_bool_or_str
else:
# filter out the bool as we need to use more general
use_type = [at for at in arg_types if at is not bool][0]
else:
use_type = arg_types[0]

if arg == 'gpus' or arg == 'tpu_cores':
use_type = _gpus_allowed_type
arg_default = _gpus_arg_default

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@williamFalcon Should this be a function ? Raising issue in hparams serialization with omegaconf, when --gpus is not explicitly provided in the command line.


# hack for types in (int, float)
if len(arg_types) == 2 and int in set(arg_types) and float in set(arg_types):
use_type = _int_or_float_type

# hack for track_grad_norm
if arg == 'track_grad_norm':
use_type = float

parser.add_argument(
f'--{arg}',
dest=arg,
default=arg_default,
type=use_type,
help='autogenerated by pl.Trainer',
**arg_kwargs,
)

return parser


def _gpus_allowed_type(x) -> Union[int, str]:
if ',' in x:
return str(x)
else:
return int(x)


def _gpus_arg_default(x) -> Union[int, str]:
if ',' in x:
return str(x)
else:
return int(x)


def _int_or_float_type(x) -> Union[int, float]:
if '.' in str(x):
return float(x)
else:
return int(x)