diff --git a/setup.cfg b/setup.cfg index 015f3d2..e511511 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,15 +1,15 @@ [metadata] name = odm2datamodels -description = "Collection of object-relational mapping (ORM) data models for ODM2" +description = Collection of object-relational mapping (ORM) data models for ODM2 long_description = file: README.md -long_description_content = text/markdown -version = 0.0.1 -author = "ODM2 Team" -author_email = "" +long_description_content_type = text/markdown +version = 0.0.2 +author = ODM2 Team +author_email = ptomasula@limno.com url = https://github.com/ODM2/ODM2DataModels project_urls = - bugtracker = https://github.com/ODM2/ODM2DataModels/issueshttps://github.com/pypa/sampleproject/issues + bugtracker = https://github.com/ODM2/ODM2DataModels/issues keywords='Observations Data Model ODM2' [options] diff --git a/src/odm2datamodels/base.py b/src/odm2datamodels/base.py index d3f0dd9..7c93b10 100644 --- a/src/odm2datamodels/base.py +++ b/src/odm2datamodels/base.py @@ -28,11 +28,12 @@ from .models import samplingfeatures from .models import simulation +import warnings +warnings.simplefilter("ignore", category=sqlalchemy.exc.SAWarning) + +OUTPUT_FORMATS = ('json', 'dataframe', 'dict') + -class OutputFormats(Enum): - JSON ='JSON' - DATAFRAME = 'DATAFRAME' - DICT = 'DICT' class Base(): @@ -82,19 +83,25 @@ def __init__(self, session_maker:sqlalchemy.orm.sessionmaker) -> None: def read_query(self, query: Union[Query, Select], - output_format:OutputFormats=OutputFormats.JSON, + output_format:str='json', orient:str='records') -> Union[str, pd.DataFrame]: + + # guard against invalid output_format strings + if output_format not in OUTPUT_FORMATS: + raise ValueError(f':argument output_format={output_format}, is not a valid output_format strings: {OUTPUT_FORMATS}') + + # use SQLAlchemy session to read_query and return response in the designated output_format with self.session_maker() as session: if isinstance(query, Select): df = pd.read_sql(query, session.bind) else: df = pd.read_sql(query.statement, session.bind) - if output_format == OutputFormats.JSON: + if output_format == 'json': return df.to_json(orient=orient) - elif output_format == OutputFormats.DATAFRAME: + elif output_format == 'dataframe': return df - elif output_format == OutputFormats.DICT: + elif output_format == 'dict': return df.to_dict() raise TypeError("Unknown output format") @@ -117,7 +124,7 @@ def create_object(self, obj:object) -> Union[int, str]: return pkey_value def read_object(self, model:Type[Base], pkey:Union[int, str], - output_format: OutputFormats=OutputFormats.DICT, + output_format:str='dict', orient:str='records') -> Dict[str, Any]: with self.session_maker() as session: @@ -126,8 +133,14 @@ def read_object(self, model:Type[Base], pkey:Union[int, str], if obj is None: raise ObjectNotFound(f"No '{model.__name__}' object found with {pkey_name} = {pkey}") session.commit() + # convert obj_dict to a dictionary if it isn't one already obj_dict = obj.to_dict() - if output_format == OutputFormats.DICT: + + # guard against invalid output_format strings + if output_format not in OUTPUT_FORMATS: + raise ValueError(f':param output_format = {output_format}, which is not one of the following valid output_format strings: {OUTPUT_FORMATS}') + + if output_format == 'dict': return obj_dict else: @@ -139,9 +152,9 @@ def read_object(self, model:Type[Base], pkey:Union[int, str], obj_dict[key] = new_value obj_df = pd.DataFrame.from_dict(obj_dict) - if output_format == OutputFormats.DATAFRAME: + if output_format == 'dataframe': return obj_df - elif output_format == OutputFormats.JSON: + elif output_format == 'json': return obj_df.to_json(orient=orient) raise TypeError("Unknown output format")