Skip to content

Commit

Permalink
merging dev into master
Browse files Browse the repository at this point in the history
  • Loading branch information
fcollman committed Apr 26, 2018
2 parents 671801f + 645b887 commit c95700e
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 27 deletions.
19 changes: 6 additions & 13 deletions argschema/argschema_parser.py
Expand Up @@ -6,6 +6,7 @@
import copy
from . import schemas
from . import utils
from . import fields
import marshmallow as mm


Expand Down Expand Up @@ -160,26 +161,20 @@ def __init__(self,
self.logger.debug('argsdict is {}'.format(argsdict))

if argsobj.input_json is not None:
result = self.schema.load(argsdict)
if 'input_json' in result.errors:
raise mm.ValidationError(result.errors['input_json'])
with open(result.data['input_json'], 'r') as j:
fields.files.validate_input_path(argsobj.input_json)
with open(argsobj.input_json, 'r') as j:
jsonargs = json.load(j)
else:
jsonargs = input_data if input_data else {}


# merge the command line dictionary into the input json
args = utils.smart_merge(jsonargs, argsdict)
self.logger.debug('args after merge {}'.format(args))

# validate with load!
result = self.load_schema_with_defaults(self.schema, args)
if len(result.errors) > 0:
raise mm.ValidationError(json.dumps(result.errors, indent=2))

self.schema_args = result
self.args = result.data
self.args = result
self.output_schema_type = output_schema_type
self.logger = self.initialize_logger(
logger_name, self.args.get('log_level'))
Expand All @@ -204,9 +199,7 @@ def get_output_json(self,d):
"""
if self.output_schema_type is not None:
schema = self.output_schema_type()
(output_json,errors)=schema.dump(d)
if len(errors)>0:
raise mm.ValidationError(json.dumps(errors))
output_json = utils.dump(schema,d)
else:
self.logger.warning("output_schema_type is not defined,\
the output won't be validated")
Expand Down Expand Up @@ -278,7 +271,7 @@ def load_schema_with_defaults(self ,schema, args):
'Recursive schemas need to subclass argschema.DefaultSchema else defaults will not work')

# load the dictionary via the schema
result = schema.load(args)
result = utils.load(schema, args)

return result

Expand Down
10 changes: 6 additions & 4 deletions argschema/fields/files.py
Expand Up @@ -113,6 +113,11 @@ def _validate(self, value):
# use outputfile to test that a file in this location is a valid path
validate_outpath(value)

def validate_input_path(value):
if not os.path.isfile(value):
raise mm.ValidationError("%s is not a file" % value)
elif not os.access(value, os.R_OK):
raise mm.ValidationError("%s is not readable" % value)

class InputDir(mm.fields.Str):
"""InputDir is :class:`marshmallow.fields.Str` subclass which is a path to a
Expand All @@ -135,7 +140,4 @@ class InputFile(mm.fields.Str):
"""

def _validate(self, value):
if not os.path.isfile(value):
raise mm.ValidationError("%s is not a file" % value)
elif not os.access(value, os.R_OK):
raise mm.ValidationError("%s is not readable" % value)
validate_input_path(value)
69 changes: 65 additions & 4 deletions argschema/utils.py
Expand Up @@ -371,11 +371,11 @@ def schema_argparser(schema):
"""

#build up a list of argument groups using recursive function
#to traverse the tree, root node gets the description given by doc string
#of the schema
# build up a list of argument groups using recursive function
# to traverse the tree, root node gets the description given by doc string
# of the schema
arguments = build_schema_arguments(schema,description=schema.__doc__)
#make the root schema appeear first rather than last
# make the root schema appeear first rather than last
arguments = [arguments[-1]]+arguments[0:-1]

parser = argparse.ArgumentParser()
Expand All @@ -385,3 +385,64 @@ def schema_argparser(schema):
for arg_name,arg in arg_group['args'].items():
group.add_argument(arg_name, **arg)
return parser

def load(schema, d):
""" function to wrap marshmallow load to smooth
differences from marshmallow 2 to 3
Parameters
----------
schema: marshmallow.Schema
schema that you want to use to validate
d: dict
dictionary to validate and load
Returns
-------
dict
deserialized and validated dictionary
Raises
------
marshmallow.ValidationError
if the dictionary does not conform to the schema
"""

results = schema.load(d)
if isinstance(results, tuple):
(results, errors) = results
if len(errors) > 0:
raise mm.ValidationError(errors)

return results


def dump(schema, d):
""" function to wrap marshmallow dump to smooth
differences from marshmallow 2 to 3
Parameters
----------
schema: marshmallow.Schema
schema that you want to use to validate and dump
d: dict
dictionary to validate and dump
Returns
-------
dict
serialized and validated dictionary
Raises
------
marshmallow.ValidationError
if the dictionary does not conform to the schema
"""

results = schema.dump(d)
if isinstance(results, tuple):
(results, errors) = results
if len(errors) > 0:
raise mm.ValidationError(errors)

return results
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,2 +1,2 @@
numpy
marshmallow<=3.0.0
marshmallow
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -7,7 +7,7 @@
test_required = f.read().splitlines()

setup(name='argschema',
version='1.16.6',
version='1.17.1',
description=' a wrapper for setting up modules that can have parameters specified by command line arguments,\
json_files, or dictionary objects. Providing a common wrapper for data processing modules.',
author='Forrest Collman,David Feng',
Expand Down
4 changes: 2 additions & 2 deletions test/fields/test_numpyarray.py
@@ -1,6 +1,7 @@
import pytest
from argschema import ArgSchemaParser, ArgSchema
from argschema.fields import NumpyArray
from argschema.utils import load,dump
import marshmallow as mm
import numpy as np

Expand Down Expand Up @@ -46,7 +47,6 @@ def test_serialize():
object_dict = {
'a': np.array([1, 2])
}
(json_dict, errors) = schema.dump(object_dict)
assert(len(errors) == 0)
json_dict = dump(schema, object_dict)
assert(type(json_dict['a']) == list)
assert(json_dict['a'] == object_dict['a'].tolist())
3 changes: 1 addition & 2 deletions test/test_nested_examples.py
Expand Up @@ -18,5 +18,4 @@ def test_nested_example():

def test_nested_marshmallow_example():
schema = MySchema()
(result,errors)=schema.load({})
assert(len(errors)==0)
argschema.utils.load(schema, {})

0 comments on commit c95700e

Please sign in to comment.