Skip to content

Commit

Permalink
removed FileSource pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
fcollman committed Jan 3, 2018
1 parent 26751f0 commit e4217a7
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 55 deletions.
24 changes: 11 additions & 13 deletions argschema/sources/json_source.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .source import FileSource, FileSink
from .source import ArgSource, ArgSink
import json
import marshmallow as mm
import argschema
Expand All @@ -11,19 +11,17 @@ class JsonOutputConfigSchema(mm.Schema):
output_json = argschema.fields.OutputFile(required=True,
description = 'filepath to save output_json')

class JsonSource(FileSource):
class JsonSource(ArgSource):
ConfigSchema = JsonInputConfigSchema

def get_dict(self):
with open(self.input_json,'r') as fp:
return json.load(fp)

def __init__(self,input_json=None):
self.filepath = input_json
def read_file(self,fp):
return json.load(fp)

class JsonSink(FileSink):
class JsonSink(ArgSink):
ConfigSchema = JsonOutputConfigSchema

def __init__(self,output_json=None):
self.filepath = output_json

def write_file(self,fp,d):
json.dump(d,fp)
def put_dict(self,d):
with open(self.output_json,'w') as fp:
json.dump(d,fp)

41 changes: 15 additions & 26 deletions argschema/sources/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class MultipleConfiguredSourceError(ConfigurableSourceError):
pass

def d_contains_any_fields(schema,d):
if len(schema.declared_fields)==0:
return True
for field_name, field in schema.declared_fields.items():
if field_name in d.keys():
if d[field_name] is not None:
Expand All @@ -40,12 +42,8 @@ def __init__(self,**kwargs):
which will define the set of fields that are allowed (and their defaults)
"""
schema = self.ConfigSchema()
result,errors = schema.load(kwargs)
if len(errors)>0:
raise MisconfiguredSourceError('invalid keyword arguments passed {}'.format(kwargs))
self.__dict__=result
for field_name, field in schema.declared_fields.items():
self.__dict__[field_name]=result[field_name]
result = self.get_config(self.ConfigSchema,kwargs)
self.__dict__.update(result)

@staticmethod
def get_config(Schema,d):
Expand All @@ -58,30 +56,21 @@ def get_config(Schema,d):
raise MisconfiguredSourceError("Source incorrectly configured\n" + json.dumps(errors, indent=2))
else:
return result


class ArgSource(ConfigurableSource):
def get_dict(self):
pass

class ArgSink(ConfigurableSource):
def put_dict(self,d):
pass

class FileSource(ArgSource):

def get_dict(self):
with open(self.filepath,'r') as fp:
d = self.read_file(fp)
return d

def read_file(self,fp):
pass

class FileSink(ArgSink):

def write_file(self,fp,d):
pass
def get_input_from_config(ArgSource, config_d):
if config_d is not None:
input_config_d = ArgSource.get_config(ArgSource.ConfigSchema, config_d)
input_source = ArgSource(**input_config_d)
input_data = input_source.get_dict()
return input_data
else:
raise NotConfiguredSourceError('No dictionary provided')

class ArgSink(ConfigurableSource):
def put_dict(self,d):
with open(self.filepath,'w') as fp:
self.write_file(fp,d)
pass
21 changes: 9 additions & 12 deletions argschema/sources/yaml_source.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import yaml
from .source import FileSource,FileSink
from .source import ArgSource,ArgSink
import argschema
import marshmallow as mm

Expand All @@ -11,20 +11,17 @@ class YamlOutputConfigSchema(mm.Schema):
output_yaml = argschema.fields.OutputFile(required=True,
description = 'filepath to save output yaml')

class YamlSource(FileSource):
class YamlSource(ArgSource):
ConfigSchema = YamlInputConfigSchema

def __init__(self,input_yaml=None):
self.filepath = input_yaml
def get_dict(self):
with open(self.input_yaml,'r') as fp:
return yaml.load(fp)

def read_file(self,fp):
return yaml.load(fp)

class YamlSink(FileSink):
class YamlSink(ArgSink):
ConfigSchema = YamlOutputConfigSchema

def __init__(self,output_yaml=None):
self.filepath = output_yaml
def put_dict(self,d):
with open(self.output_yaml,'w') as fp:
yaml.dump(d,fp,default_flow_style=False)

def write_file(self,fp,d):
yaml.dump(d,fp,default_flow_style=False)
3 changes: 2 additions & 1 deletion test/sources/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ def test_input_file(tmpdir_factory):
return str(file_in)

def test_json_source(test_input_file):
mod = MyParser(input_source= JsonSource(test_input_file), args=[])
source = JsonSource(input_json=test_input_file)
mod = MyParser(input_source= source, args=[])

def test_json_source_command(test_input_file):
mod = MyParser(args = ['--input_json',test_input_file])
9 changes: 6 additions & 3 deletions test/sources/test_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def test_json_input_file(tmpdir_factory):


def test_yaml_source(test_yaml_input_file):
mod = MyParser(input_source=YamlSource(test_yaml_input_file), args=[])
source = YamlSource(input_yaml=test_yaml_input_file)
mod = MyParser(input_source=source, args=[])


def test_yaml_source_command(test_yaml_input_file):
Expand All @@ -54,8 +55,10 @@ def test_yaml_sink(test_yaml_input_file, tmpdir):
output_data = {
'a': 3
}
mod = MyParser(input_source=YamlSource(test_yaml_input_file),
output_sink=YamlSink(str(outfile)))
source = YamlSource(input_yaml=test_yaml_input_file)
sink = YamlSink(output_yaml = str(outfile))
mod = MyParser(input_source=source,
output_sink=sink)
mod.output(output_data)

with open(str(outfile), 'r') as fp:
Expand Down

0 comments on commit e4217a7

Please sign in to comment.