Skip to content

Commit

Permalink
Better import error message for the TinyDB observer. (#554)
Browse files Browse the repository at this point in the history
* Better error message tinydb.

* Made the separation for tinyfs.

* Finished fixing imports.

* Fixed imports.
  • Loading branch information
gabrieldemarmiesse authored and JarnoRFB committed Aug 2, 2019
1 parent 7fc2cf6 commit 6b177b3
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 130 deletions.
134 changes: 9 additions & 125 deletions sacred/observers/tinydb_hashfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,71 +5,13 @@
absolute_import)

import os
import datetime as dt
import json
import uuid
import textwrap
import uuid
from collections import OrderedDict

from io import BufferedReader, FileIO

from sacred.__about__ import __version__
from sacred.observers import RunObserver
from sacred.commandline_options import CommandLineOption
import sacred.optional as opt

# Set data type values for abstract properties in Serializers.
# Each falls back to None when the optional dependency is missing —
# presumably a None OBJ_CLASS means the serializer never matches
# anything (confirm against tinydb_serialization).
series_type = opt.pandas.Series if opt.has_pandas else None
dataframe_type = opt.pandas.DataFrame if opt.has_pandas else None
ndarray_type = opt.np.ndarray if opt.has_numpy else None


class BufferedReaderWrapper(BufferedReader):
    """Custom wrapper to allow for copying of file handle.

    tinydb_serialisation currently does a deepcopy on all the content of the
    dictionary before serialisation. By default, file handles are not
    copiable so this wrapper is necessary to create a duplicate of the
    file handle passed in.
    Note that the file passed in will therefore remain open as the copy is the
    one that gets closed.
    """

    def __init__(self, f_obj):
        # Re-open the file by name so this instance owns an independent
        # handle; the caller's original handle is left untouched (and open).
        f_obj = FileIO(f_obj.name)
        super(BufferedReaderWrapper, self).__init__(f_obj)

    def __copy__(self):
        f = open(self.name, self.mode)
        return BufferedReaderWrapper(f)

    def __deepcopy__(self, memo):
        # A "deep" copy of a file handle is simply another independent
        # handle on the same file.
        f = open(self.name, self.mode)
        return BufferedReaderWrapper(f)


def get_db_file_manager(root_dir):
    """Return ``(db, fs)`` rooted at *root_dir*.

    ``fs`` is a content-addressed HashFS store under ``root_dir/hashfs``;
    ``db`` is a TinyDB instance backed by ``root_dir/metadata.json`` with
    serializers registered for datetimes, files and (optionally)
    numpy/pandas objects.
    """
    fs = HashFS(os.path.join(root_dir, 'hashfs'), depth=3,
                width=2, algorithm='md5')

    # Setup Serialisation object for non list/dict objects
    serialization_store = SerializationMiddleware()
    serialization_store.register_serializer(DateTimeSerializer(), 'TinyDate')
    serialization_store.register_serializer(FileSerializer(fs), 'TinyFile')

    # numpy / pandas support is optional — only register those serializers
    # when the libraries are importable.
    if opt.has_numpy:
        serialization_store.register_serializer(NdArraySerializer(),
                                                'TinyArray')
    if opt.has_pandas:
        serialization_store.register_serializer(DataFrameSerializer(),
                                                'TinyDataFrame')
        serialization_store.register_serializer(SeriesSerializer(),
                                                'TinySeries')

    db = TinyDB(os.path.join(root_dir, 'metadata.json'),
                storage=serialization_store)
    return db, fs
from sacred.observers import RunObserver


class TinyDbObserver(RunObserver):
Expand All @@ -78,7 +20,7 @@ class TinyDbObserver(RunObserver):

@staticmethod
def create(path='./runs_db', overwrite=None):

from .tinydb_hashfs_bases import get_db_file_manager
root_dir = os.path.abspath(path)
if not os.path.exists(root_dir):
os.makedirs(root_dir)
Expand All @@ -104,7 +46,7 @@ def save(self):
self.db_run_id = db_run_id

def save_sources(self, ex_info):

from .tinydb_hashfs_bases import BufferedReaderWrapper
source_info = []
for source_name, md5 in ex_info['sources']:

Expand Down Expand Up @@ -184,7 +126,7 @@ def failed_event(self, fail_time, fail_trace):
self.save()

def resource_event(self, filename):

from .tinydb_hashfs_bases import BufferedReaderWrapper
id_ = self.fs.put(filename).id
handle = BufferedReaderWrapper(open(filename, 'rb'))
resource = [filename, id_, handle]
Expand All @@ -194,7 +136,7 @@ def resource_event(self, filename):
self.save()

def artifact_event(self, name, filename, metadata=None, content_type=None):

from .tinydb_hashfs_bases import BufferedReaderWrapper
id_ = self.fs.put(filename).id
handle = BufferedReaderWrapper(open(filename, 'rb'))
artifact = [name, filename, id_, handle]
Expand Down Expand Up @@ -231,7 +173,7 @@ def parse_tinydb_arg(cls, args):
class TinyDbReader(object):

def __init__(self, path):

from .tinydb_hashfs_bases import get_db_file_manager
root_dir = os.path.abspath(path)
if not os.path.exists(root_dir):
raise IOError('Path does not exist: %s' % path)
Expand Down Expand Up @@ -367,6 +309,8 @@ def fetch_report(self, exp_name=None, query=None, indices=None):

def fetch_metadata(self, exp_name=None, query=None, indices=None):
"""Return all metadata for matching experiment name, index or query."""
from tinydb import Query
from tinydb.queries import QueryImpl
if exp_name or query:
if query:
assert type(query), QueryImpl
Expand Down Expand Up @@ -422,63 +366,3 @@ def _indent(self, message, prefix):
formatted_text = '\n'.join(formatted_lines)

return formatted_text


if opt.has_tinydb:  # noqa
    from tinydb import TinyDB, Query
    from tinydb.queries import QueryImpl
    from hashfs import HashFS
    from tinydb_serialization import Serializer, SerializationMiddleware

    class DateTimeSerializer(Serializer):
        """Round-trip ``datetime`` objects as formatted strings."""

        OBJ_CLASS = dt.datetime  # The class this serializer handles

        def encode(self, obj):
            # Fixed format with microsecond precision for exact round-trips.
            return obj.strftime('%Y-%m-%dT%H:%M:%S.%f')

        def decode(self, s):
            return dt.datetime.strptime(s, '%Y-%m-%dT%H:%M:%S.%f')

    class NdArraySerializer(Serializer):
        """Round-trip numpy arrays as JSON nested lists."""

        # None when numpy is missing, so nothing is matched in that case.
        OBJ_CLASS = ndarray_type

        def encode(self, obj):
            return json.dumps(obj.tolist(), check_circular=True)

        def decode(self, s):
            return opt.np.array(json.loads(s))

    class DataFrameSerializer(Serializer):
        """Round-trip pandas DataFrames via ``to_json``/``read_json``."""

        OBJ_CLASS = dataframe_type

        def encode(self, obj):
            return obj.to_json()

        def decode(self, s):
            return opt.pandas.read_json(s)

    class SeriesSerializer(Serializer):
        """Round-trip pandas Series via ``to_json``/``read_json``."""

        OBJ_CLASS = series_type

        def encode(self, obj):
            return obj.to_json()

        def decode(self, s):
            return opt.pandas.read_json(s, typ='series')

    class FileSerializer(Serializer):
        """Store file handles in HashFS; serialise only their hash id."""

        OBJ_CLASS = BufferedReaderWrapper

        def __init__(self, fs):
            # HashFS instance used to store / retrieve file content.
            self.fs = fs

        def encode(self, obj):
            address = self.fs.put(obj)
            return json.dumps(address.id)

        def decode(self, s):
            id_ = json.loads(s)
            file_reader = self.fs.open(id_)
            file_reader = BufferedReaderWrapper(file_reader)
            # Remember the id so callers can identify the stored file.
            file_reader.hash = id_
            return file_reader
121 changes: 121 additions & 0 deletions sacred/observers/tinydb_hashfs_bases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import datetime as dt
import json
import os
from io import BufferedReader, FileIO

from hashfs import HashFS
from tinydb import TinyDB
from tinydb_serialization import Serializer, SerializationMiddleware

import sacred.optional as opt

# Set data type values for abstract properties in Serializers.
# Each falls back to None when the optional dependency is missing —
# presumably a None OBJ_CLASS means the serializer never matches
# anything (confirm against tinydb_serialization).
series_type = opt.pandas.Series if opt.has_pandas else None
dataframe_type = opt.pandas.DataFrame if opt.has_pandas else None
ndarray_type = opt.np.ndarray if opt.has_numpy else None


class BufferedReaderWrapper(BufferedReader):
    """Copyable wrapper around a buffered file reader.

    ``tinydb_serialization`` deep-copies every value of the document
    dictionary before serialising it, and plain file handles cannot be
    copied.  This wrapper therefore re-opens the underlying file whenever
    a (deep) copy is requested.  Note that the file object passed to
    ``__init__`` stays open: only the freshly opened duplicate is owned
    (and eventually closed) by this instance.
    """

    def __init__(self, f_obj):
        # Re-open the file by name so this instance owns its own handle.
        raw = FileIO(f_obj.name)
        super(BufferedReaderWrapper, self).__init__(raw)

    def __copy__(self):
        duplicate = open(self.name, self.mode)
        return BufferedReaderWrapper(duplicate)

    def __deepcopy__(self, memo):
        # A deep copy of a handle is just another independent handle.
        duplicate = open(self.name, self.mode)
        return BufferedReaderWrapper(duplicate)


class DateTimeSerializer(Serializer):
    """(De)serialise ``datetime`` objects as fixed-format strings."""

    # The class this serializer handles.
    OBJ_CLASS = dt.datetime

    # Microsecond-precision format guarantees an exact round-trip.
    FORMAT = '%Y-%m-%dT%H:%M:%S.%f'

    def encode(self, obj):
        """Render *obj* as a timestamp string."""
        return obj.strftime(self.FORMAT)

    def decode(self, s):
        """Parse a timestamp string back into a ``datetime``."""
        return dt.datetime.strptime(s, self.FORMAT)


class NdArraySerializer(Serializer):
    """(De)serialise numpy arrays via their nested-list JSON form."""

    # numpy.ndarray, or None when numpy is not installed.
    OBJ_CLASS = ndarray_type

    def encode(self, obj):
        """Dump the array's nested-list representation as JSON."""
        nested = obj.tolist()
        return json.dumps(nested, check_circular=True)

    def decode(self, s):
        """Rebuild a numpy array from its JSON nested-list form."""
        nested = json.loads(s)
        return opt.np.array(nested)


class DataFrameSerializer(Serializer):
    """(De)serialise pandas DataFrames via their JSON representation."""

    # pandas.DataFrame, or None when pandas is not installed.
    OBJ_CLASS = dataframe_type

    def encode(self, obj):
        """Serialise the frame with ``DataFrame.to_json``."""
        return obj.to_json()

    def decode(self, s):
        """Reconstruct a frame with ``pandas.read_json``."""
        return opt.pandas.read_json(s)


class SeriesSerializer(Serializer):
    """(De)serialise pandas Series via their JSON representation."""

    # pandas.Series, or None when pandas is not installed.
    OBJ_CLASS = series_type

    def encode(self, obj):
        """Serialise the series with ``Series.to_json``."""
        return obj.to_json()

    def decode(self, s):
        """Reconstruct a series with ``pandas.read_json(typ='series')``."""
        return opt.pandas.read_json(s, typ='series')


class FileSerializer(Serializer):
    """(De)serialise open file handles through a HashFS store.

    ``encode`` writes the handle's content into the content-addressed
    store and records only the resulting hash id; ``decode`` re-opens the
    stored content and wraps it so it can be deep-copied again.
    """

    OBJ_CLASS = BufferedReaderWrapper

    def __init__(self, fs):
        # Content-addressed file store used in both directions.
        self.fs = fs

    def encode(self, obj):
        """Store *obj*'s content and return its hash id as JSON."""
        address = self.fs.put(obj)
        return json.dumps(address.id)

    def decode(self, s):
        """Open the stored content identified by the JSON-encoded id."""
        id_ = json.loads(s)
        handle = BufferedReaderWrapper(self.fs.open(id_))
        handle.hash = id_  # remember the id so callers can identify the file
        return handle


def get_db_file_manager(root_dir):
    """Create the TinyDB metadata database and the HashFS file store.

    Returns a ``(db, fs)`` tuple rooted at *root_dir*: ``fs`` is a
    content-addressed store under ``root_dir/hashfs`` and ``db`` a TinyDB
    instance backed by ``root_dir/metadata.json``, with serializers
    registered for datetimes, files and (optionally) numpy/pandas types.
    """
    fs = HashFS(os.path.join(root_dir, 'hashfs'), depth=3,
                width=2, algorithm='md5')

    # Collect (name, serializer) pairs for objects TinyDB cannot store
    # natively; numpy/pandas support is registered only when available.
    registrations = [
        ('TinyDate', DateTimeSerializer()),
        ('TinyFile', FileSerializer(fs)),
    ]
    if opt.has_numpy:
        registrations.append(('TinyArray', NdArraySerializer()))
    if opt.has_pandas:
        registrations.append(('TinyDataFrame', DataFrameSerializer()))
        registrations.append(('TinySeries', SeriesSerializer()))

    serialization_store = SerializationMiddleware()
    for name, serializer in registrations:
        serialization_store.register_serializer(serializer, name)

    db = TinyDB(os.path.join(root_dir, 'metadata.json'),
                storage=serialization_store)
    return db, fs
11 changes: 6 additions & 5 deletions tests/test_observers/test_tinydb_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
from hashfs import HashFS

from sacred.dependencies import get_digest
from sacred.observers.tinydb_hashfs import (TinyDbObserver, TinyDbOption,
BufferedReaderWrapper)
from sacred.observers.tinydb_hashfs import TinyDbObserver, TinyDbOption
from sacred.observers.tinydb_hashfs_bases import BufferedReaderWrapper

from sacred import optional as opt
from sacred.experiment import Experiment

Expand Down Expand Up @@ -307,7 +308,7 @@ def test_custom_bufferreaderwrapper(tmpdir):

@pytest.mark.skipif(not opt.has_numpy, reason='needs numpy')
def test_serialisation_of_numpy_ndarray(tmpdir):
from sacred.observers.tinydb_hashfs import NdArraySerializer
from sacred.observers.tinydb_hashfs_bases import NdArraySerializer
from tinydb_serialization import SerializationMiddleware
import numpy as np

Expand Down Expand Up @@ -339,8 +340,8 @@ def test_serialisation_of_numpy_ndarray(tmpdir):

@pytest.mark.skipif(not opt.has_pandas, reason='needs pandas')
def test_serialisation_of_pandas_dataframe(tmpdir):
from sacred.observers.tinydb_hashfs import (DataFrameSerializer,
SeriesSerializer)
from sacred.observers.tinydb_hashfs_bases import DataFrameSerializer
from sacred.observers.tinydb_hashfs_bases import SeriesSerializer
from tinydb_serialization import SerializationMiddleware

import numpy as np
Expand Down
25 changes: 25 additions & 0 deletions tests/test_observers/test_tinydb_observer_not_installed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import pytest
from sacred.optional import has_tinydb
from sacred.observers import TinyDbObserver
from sacred import Experiment


@pytest.fixture
def ex():
    """Provide a fresh, observer-less Experiment for each test."""
    return Experiment('ator3000')


@pytest.mark.skipif(has_tinydb, reason='We are testing the import error.')
def test_importerror_sql(ex):
    # NOTE(review): the name says "sql" but this exercises the TinyDB
    # observer — likely copied from the SQL variant; consider renaming.
    # ``create`` raises ImportError when tinydb is missing, so the body
    # after ``append`` never executes in this scenario.
    with pytest.raises(ImportError):
        ex.observers.append(TinyDbObserver.create('some_uri'))

        @ex.config
        def cfg():
            a = {'b': 1}

        @ex.main
        def foo(a):
            return a['b']

        ex.run()

0 comments on commit 6b177b3

Please sign in to comment.