Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions hca/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,13 @@ def _get_command_arg_settings(self, param_data):
else:
return dict(type=type(param_data.default), default=param_data.default)

def _get_param_argparse_type(self, anno):
if anno in {typing.List, typing.Mapping}:
return json.loads
elif isinstance(getattr(anno, "__args__", None), tuple) and anno == typing.Optional[anno.__args__[0]]:
return anno.__args__[0]
return anno

def build_argparse_subparsers(self, subparsers):
for method_name, method_data in self.methods.items():
subcommand_name = method_name.replace("_", "-")
Expand All @@ -398,9 +405,9 @@ def build_argparse_subparsers(self, subparsers):
continue
logger.debug("Registering %s %s %s", method_name, param_name, param.annotation)
nargs = "+" if param.annotation == typing.List else None
argparse_type = json.loads if param.annotation in {typing.List, typing.Mapping} else param.annotation
subparser.add_argument("--" + param_name.replace("_", "-").replace("/", "-"), dest=param_name,
type=argparse_type, nargs=nargs, help=method_data["args"][param_name]["doc"],
type=self._get_param_argparse_type(param.annotation), nargs=nargs,
help=method_data["args"][param_name]["doc"],
choices=method_data["args"][param_name]["choices"],
required=method_data["args"][param_name]["required"])
subparser.set_defaults(entry_point=method_data["entry_point"])
Expand Down
9 changes: 9 additions & 0 deletions test/test_dss_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,15 @@


class TestDssCLI(unittest.TestCase):
def test_post_search_cli(self):
query = json.dumps({})
replica = "aws"
args = ["dss", "post-search", "--es-query", query, "--replica", replica, "--output-format", "raw"]
with CapturingIO('stdout') as stdout:
hca.cli.main(args)
result = json.loads(stdout.captured())
self.assertIn("results", result)

def test_get_files_cli(self):
filename = "SRR2967608_1.fastq.gz"
dirpath = os.path.join(os.path.dirname(os.path.realpath(__file__)), "bundle")
Expand Down