diff --git a/argschema/sources/json_source.py b/argschema/sources/json_source.py index b275436..d87b4ca 100644 --- a/argschema/sources/json_source.py +++ b/argschema/sources/json_source.py @@ -1,4 +1,4 @@ -from .source import FileSource, FileSink +from .source import ArgSource, ArgSink import json import marshmallow as mm import argschema @@ -11,19 +11,17 @@ class JsonOutputConfigSchema(mm.Schema): output_json = argschema.fields.OutputFile(required=True, description = 'filepath to save output_json') -class JsonSource(FileSource): +class JsonSource(ArgSource): ConfigSchema = JsonInputConfigSchema + + def get_dict(self): + with open(self.input_json,'r') as fp: + return json.load(fp) - def __init__(self,input_json=None): - self.filepath = input_json - def read_file(self,fp): - return json.load(fp) - -class JsonSink(FileSink): +class JsonSink(ArgSink): ConfigSchema = JsonOutputConfigSchema - def __init__(self,output_json=None): - self.filepath = output_json - - def write_file(self,fp,d): - json.dump(d,fp) + def put_dict(self,d): + with open(self.output_json,'w') as fp: + json.dump(d,fp) + diff --git a/argschema/sources/source.py b/argschema/sources/source.py index f3cb6d1..42e391b 100644 --- a/argschema/sources/source.py +++ b/argschema/sources/source.py @@ -19,6 +19,8 @@ class MultipleConfiguredSourceError(ConfigurableSourceError): pass def d_contains_any_fields(schema,d): + if len(schema.declared_fields)==0: + return True for field_name, field in schema.declared_fields.items(): if field_name in d.keys(): if d[field_name] is not None: @@ -40,12 +42,8 @@ def __init__(self,**kwargs): which will define the set of fields that are allowed (and their defaults) """ schema = self.ConfigSchema() - result,errors = schema.load(kwargs) - if len(errors)>0: - raise MisconfiguredSourceError('invalid keyword arguments passed {}'.format(kwargs)) - self.__dict__=result - for field_name, field in schema.declared_fields.items(): - self.__dict__[field_name]=result[field_name] + result = self.get_config(self.ConfigSchema,kwargs) + self.__dict__.update(result) @staticmethod def get_config(Schema,d): @@ -58,30 +56,21 @@ def get_config(Schema,d): raise MisconfiguredSourceError("Source incorrectly configured\n" + json.dumps(errors, indent=2)) else: return result + class ArgSource(ConfigurableSource): def get_dict(self): pass -class ArgSink(ConfigurableSource): - def put_dict(self,d): - pass - -class FileSource(ArgSource): - - def get_dict(self): - with open(self.filepath,'r') as fp: - d = self.read_file(fp) - return d - - def read_file(self,fp): - pass - -class FileSink(ArgSink): - - def write_file(self,fp,d): - pass +def get_input_from_config(ArgSource, config_d): + if config_d is not None: + input_config_d = ArgSource.get_config(ArgSource.ConfigSchema, config_d) + input_source = ArgSource(**input_config_d) + input_data = input_source.get_dict() + return input_data + else: + raise NotConfiguredSourceError('No dictionary provided') +class ArgSink(ConfigurableSource): def put_dict(self,d): - with open(self.filepath,'w') as fp: - self.write_file(fp,d) \ No newline at end of file + pass \ No newline at end of file diff --git a/argschema/sources/yaml_source.py b/argschema/sources/yaml_source.py index 480f4d8..1692b9d 100644 --- a/argschema/sources/yaml_source.py +++ b/argschema/sources/yaml_source.py @@ -1,5 +1,5 @@ import yaml -from .source import FileSource,FileSink +from .source import ArgSource,ArgSink import argschema import marshmallow as mm @@ -11,20 +11,17 @@ class YamlOutputConfigSchema(mm.Schema): output_yaml = argschema.fields.OutputFile(required=True, description = 'filepath to save output yaml') -class YamlSource(FileSource): +class YamlSource(ArgSource): ConfigSchema = YamlInputConfigSchema - def __init__(self,input_yaml=None): - self.filepath = input_yaml + def get_dict(self): + with open(self.input_yaml,'r') as fp: + return yaml.load(fp) - def read_file(self,fp): - return yaml.load(fp) - -class YamlSink(FileSink): +class YamlSink(ArgSink): ConfigSchema = YamlOutputConfigSchema - def __init__(self,output_yaml=None): - self.filepath = output_yaml + def put_dict(self,d): + with open(self.output_yaml,'w') as fp: + yaml.dump(d,fp,default_flow_style=False) - def write_file(self,fp,d): - yaml.dump(d,fp,default_flow_style=False) \ No newline at end of file diff --git a/test/sources/test_json.py b/test/sources/test_json.py index 3055cb2..1acc001 100644 --- a/test/sources/test_json.py +++ b/test/sources/test_json.py @@ -22,7 +22,8 @@ def test_input_file(tmpdir_factory): return str(file_in) def test_json_source(test_input_file): - mod = MyParser(input_source= JsonSource(test_input_file), args=[]) + source = JsonSource(input_json=test_input_file) + mod = MyParser(input_source= source, args=[]) def test_json_source_command(test_input_file): mod = MyParser(args = ['--input_json',test_input_file]) \ No newline at end of file diff --git a/test/sources/test_yaml.py b/test/sources/test_yaml.py index 585a8d2..96c0bce 100644 --- a/test/sources/test_yaml.py +++ b/test/sources/test_yaml.py @@ -42,7 +42,8 @@ def test_json_input_file(tmpdir_factory): def test_yaml_source(test_yaml_input_file): - mod = MyParser(input_source=YamlSource(test_yaml_input_file), args=[]) + source = YamlSource(input_yaml=test_yaml_input_file) + mod = MyParser(input_source=source, args=[]) def test_yaml_source_command(test_yaml_input_file): @@ -54,8 +55,10 @@ def test_yaml_sink(test_yaml_input_file, tmpdir): output_data = { 'a': 3 } - mod = MyParser(input_source=YamlSource(test_yaml_input_file), - output_sink=YamlSink(str(outfile))) + source = YamlSource(input_yaml=test_yaml_input_file) + sink = YamlSink(output_yaml = str(outfile)) + mod = MyParser(input_source=source, + output_sink=sink) mod.output(output_data) with open(str(outfile), 'r') as fp: