From fd4fea3d0eac36ac3d5539714bc9bab9b4c72f36 Mon Sep 17 00:00:00 2001 From: Andrey Kislyuk Date: Mon, 11 Dec 2017 10:30:21 -0800 Subject: [PATCH] Fix typevar passthrough to argparse with typing.Optional kwargs --- hca/util/__init__.py | 11 +++++++++-- test/test_dss_cli.py | 9 +++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/hca/util/__init__.py b/hca/util/__init__.py index 120836a7..6da12d15 100755 --- a/hca/util/__init__.py +++ b/hca/util/__init__.py @@ -388,6 +388,13 @@ def _get_command_arg_settings(self, param_data): else: return dict(type=type(param_data.default), default=param_data.default) + def _get_param_argparse_type(self, anno): + if anno in {typing.List, typing.Mapping}: + return json.loads + elif isinstance(getattr(anno, "__args__", None), tuple) and anno == typing.Optional[anno.__args__[0]]: + return anno.__args__[0] + return anno + def build_argparse_subparsers(self, subparsers): for method_name, method_data in self.methods.items(): subcommand_name = method_name.replace("_", "-") @@ -398,9 +405,9 @@ def build_argparse_subparsers(self, subparsers): continue logger.debug("Registering %s %s %s", method_name, param_name, param.annotation) nargs = "+" if param.annotation == typing.List else None - argparse_type = json.loads if param.annotation in {typing.List, typing.Mapping} else param.annotation subparser.add_argument("--" + param_name.replace("_", "-").replace("/", "-"), dest=param_name, - type=argparse_type, nargs=nargs, help=method_data["args"][param_name]["doc"], + type=self._get_param_argparse_type(param.annotation), nargs=nargs, + help=method_data["args"][param_name]["doc"], choices=method_data["args"][param_name]["choices"], required=method_data["args"][param_name]["required"]) subparser.set_defaults(entry_point=method_data["entry_point"]) diff --git a/test/test_dss_cli.py b/test/test_dss_cli.py index 645f3a01..5cef9332 100755 --- a/test/test_dss_cli.py +++ b/test/test_dss_cli.py @@ -19,6 +19,15 @@ class TestDssCLI(unittest.TestCase): + def test_post_search_cli(self): + query = json.dumps({}) + replica = "aws" + args = ["dss", "post-search", "--es-query", query, "--replica", replica, "--output-format", "raw"] + with CapturingIO('stdout') as stdout: + hca.cli.main(args) + result = json.loads(stdout.captured()) + self.assertIn("results", result) + def test_get_files_cli(self): filename = "SRR2967608_1.fastq.gz" dirpath = os.path.join(os.path.dirname(os.path.realpath(__file__)), "bundle")