Better import error message for the tinydb observer #554

Merged
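This PR makes the TinyDB observer module safe to import without its optional dependencies: the serializers, BufferedReaderWrapper, and get_db_file_manager move from sacred/observers/tinydb_hashfs.py into a new module, sacred/observers/tinydb_hashfs_bases.py, and tinydb_hashfs.py now imports them lazily inside the methods that need them. A missing tinydb/hashfs installation therefore surfaces as an ImportError only when the observer is actually used. Below is a minimal sketch of that deferred-import pattern; the class name and error message are illustrative, not sacred's actual code:

```python
class LazyObserverSketch(object):
    """Sketch: defer optional-dependency imports until first use."""

    @staticmethod
    def create(path='./runs_db'):
        # Importing inside the method means `import this_module` always
        # succeeds; the ImportError is raised here, at the point of use,
        # where it can carry a more helpful message.
        try:
            import tinydb  # noqa: F401 -- optional dependency
        except ImportError:
            raise ImportError('TinyDbObserver requires the tinydb package. '
                              'Install it with: pip install tinydb hashfs')
        return LazyObserverSketch()
```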
134 changes: 9 additions & 125 deletions sacred/observers/tinydb_hashfs.py
@@ -5,71 +5,13 @@
                        absolute_import)

import os
import datetime as dt
import json
import uuid
import textwrap
import uuid
from collections import OrderedDict

from io import BufferedReader, FileIO

from sacred.__about__ import __version__
from sacred.observers import RunObserver
from sacred.commandline_options import CommandLineOption
import sacred.optional as opt

# Set data type values for abstract properties in Serializers
series_type = opt.pandas.Series if opt.has_pandas else None
dataframe_type = opt.pandas.DataFrame if opt.has_pandas else None
ndarray_type = opt.np.ndarray if opt.has_numpy else None


class BufferedReaderWrapper(BufferedReader):
"""Custom wrapper to allow for copying of file handle.

tinydb_serialisation currently does a deepcopy on all the content of the
dictionary before serialisation. By default, file handles are not
copiable so this wrapper is necessary to create a duplicate of the
file handle passes in.

Note that the file passed in will therefor remain open as the copy is the
one that gets closed.
"""

def __init__(self, f_obj):
f_obj = FileIO(f_obj.name)
super(BufferedReaderWrapper, self).__init__(f_obj)

def __copy__(self):
f = open(self.name, self.mode)
return BufferedReaderWrapper(f)

def __deepcopy__(self, memo):
f = open(self.name, self.mode)
return BufferedReaderWrapper(f)


def get_db_file_manager(root_dir):
    fs = HashFS(os.path.join(root_dir, 'hashfs'), depth=3,
                width=2, algorithm='md5')

    # Setup Serialisation object for non list/dict objects
    serialization_store = SerializationMiddleware()
    serialization_store.register_serializer(DateTimeSerializer(), 'TinyDate')
    serialization_store.register_serializer(FileSerializer(fs), 'TinyFile')

    if opt.has_numpy:
        serialization_store.register_serializer(NdArraySerializer(),
                                                'TinyArray')
    if opt.has_pandas:
        serialization_store.register_serializer(DataFrameSerializer(),
                                                'TinyDataFrame')
        serialization_store.register_serializer(SeriesSerializer(),
                                                'TinySeries')

    db = TinyDB(os.path.join(root_dir, 'metadata.json'),
                storage=serialization_store)
    return db, fs
from sacred.observers import RunObserver


class TinyDbObserver(RunObserver):
@@ -78,7 +20,7 @@ class TinyDbObserver(RunObserver):

    @staticmethod
    def create(path='./runs_db', overwrite=None):

        from .tinydb_hashfs_bases import get_db_file_manager
        root_dir = os.path.abspath(path)
        if not os.path.exists(root_dir):
            os.makedirs(root_dir)
@@ -104,7 +46,7 @@ def save(self):
        self.db_run_id = db_run_id

    def save_sources(self, ex_info):

        from .tinydb_hashfs_bases import BufferedReaderWrapper
        source_info = []
        for source_name, md5 in ex_info['sources']:

@@ -184,7 +126,7 @@ def failed_event(self, fail_time, fail_trace):
        self.save()

    def resource_event(self, filename):

        from .tinydb_hashfs_bases import BufferedReaderWrapper
        id_ = self.fs.put(filename).id
        handle = BufferedReaderWrapper(open(filename, 'rb'))
        resource = [filename, id_, handle]
@@ -194,7 +136,7 @@ def resource_event(self, filename):
        self.save()

    def artifact_event(self, name, filename, metadata=None, content_type=None):

        from .tinydb_hashfs_bases import BufferedReaderWrapper
        id_ = self.fs.put(filename).id
        handle = BufferedReaderWrapper(open(filename, 'rb'))
        artifact = [name, filename, id_, handle]
@@ -231,7 +173,7 @@ def parse_tinydb_arg(cls, args):
class TinyDbReader(object):

    def __init__(self, path):

        from .tinydb_hashfs_bases import get_db_file_manager
        root_dir = os.path.abspath(path)
        if not os.path.exists(root_dir):
            raise IOError('Path does not exist: %s' % path)
@@ -367,6 +309,8 @@ def fetch_report(self, exp_name=None, query=None, indices=None):

    def fetch_metadata(self, exp_name=None, query=None, indices=None):
        """Return all metadata for matching experiment name, index or query."""
        from tinydb import Query
        from tinydb.queries import QueryImpl
        if exp_name or query:
            if query:
                assert isinstance(query, QueryImpl)
@@ -422,63 +366,3 @@ def _indent(self, message, prefix):
        formatted_text = '\n'.join(formatted_lines)

        return formatted_text


if opt.has_tinydb:  # noqa
    from tinydb import TinyDB, Query
    from tinydb.queries import QueryImpl
    from hashfs import HashFS
    from tinydb_serialization import Serializer, SerializationMiddleware

    class DateTimeSerializer(Serializer):
        OBJ_CLASS = dt.datetime  # The class this serializer handles

        def encode(self, obj):
            return obj.strftime('%Y-%m-%dT%H:%M:%S.%f')

        def decode(self, s):
            return dt.datetime.strptime(s, '%Y-%m-%dT%H:%M:%S.%f')

    class NdArraySerializer(Serializer):
        OBJ_CLASS = ndarray_type

        def encode(self, obj):
            return json.dumps(obj.tolist(), check_circular=True)

        def decode(self, s):
            return opt.np.array(json.loads(s))

    class DataFrameSerializer(Serializer):
        OBJ_CLASS = dataframe_type

        def encode(self, obj):
            return obj.to_json()

        def decode(self, s):
            return opt.pandas.read_json(s)

    class SeriesSerializer(Serializer):
        OBJ_CLASS = series_type

        def encode(self, obj):
            return obj.to_json()

        def decode(self, s):
            return opt.pandas.read_json(s, typ='series')

    class FileSerializer(Serializer):
        OBJ_CLASS = BufferedReaderWrapper

        def __init__(self, fs):
            self.fs = fs

        def encode(self, obj):
            address = self.fs.put(obj)
            return json.dumps(address.id)

        def decode(self, s):
            id_ = json.loads(s)
            file_reader = self.fs.open(id_)
            file_reader = BufferedReaderWrapper(file_reader)
            file_reader.hash = id_
            return file_reader
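One practical note on the function-level imports introduced above: Python caches modules in sys.modules, so after the first call the repeated `from .tinydb_hashfs_bases import ...` in save_sources, resource_event, and artifact_event is just a cache lookup. A self-contained sketch of that caching behaviour:

```python
import sys

def uses_json():
    # A function-level import: the module body executes only on the
    # first call; afterwards the import is a sys.modules dict lookup.
    import json
    return json

first = uses_json()
assert 'json' in sys.modules              # cached after the first import
second = uses_json()                      # cache hit, no re-execution
assert first is second is sys.modules['json']
```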
121 changes: 121 additions & 0 deletions sacred/observers/tinydb_hashfs_bases.py
@@ -0,0 +1,121 @@
import datetime as dt
import json
import os
from io import BufferedReader, FileIO

from hashfs import HashFS
from tinydb import TinyDB
from tinydb_serialization import Serializer, SerializationMiddleware

import sacred.optional as opt

# Set data type values for abstract properties in Serializers
series_type = opt.pandas.Series if opt.has_pandas else None
dataframe_type = opt.pandas.DataFrame if opt.has_pandas else None
ndarray_type = opt.np.ndarray if opt.has_numpy else None


class BufferedReaderWrapper(BufferedReader):
"""Custom wrapper to allow for copying of file handle.

tinydb_serialisation currently does a deepcopy on all the content of the
dictionary before serialisation. By default, file handles are not
copiable so this wrapper is necessary to create a duplicate of the
file handle passes in.

Note that the file passed in will therefor remain open as the copy is the
one that gets closed.
"""

def __init__(self, f_obj):
f_obj = FileIO(f_obj.name)
super(BufferedReaderWrapper, self).__init__(f_obj)

def __copy__(self):
f = open(self.name, self.mode)
return BufferedReaderWrapper(f)

def __deepcopy__(self, memo):
f = open(self.name, self.mode)
return BufferedReaderWrapper(f)


class DateTimeSerializer(Serializer):
    OBJ_CLASS = dt.datetime  # The class this serializer handles

    def encode(self, obj):
        return obj.strftime('%Y-%m-%dT%H:%M:%S.%f')

    def decode(self, s):
        return dt.datetime.strptime(s, '%Y-%m-%dT%H:%M:%S.%f')


class NdArraySerializer(Serializer):
    OBJ_CLASS = ndarray_type

    def encode(self, obj):
        return json.dumps(obj.tolist(), check_circular=True)

    def decode(self, s):
        return opt.np.array(json.loads(s))


class DataFrameSerializer(Serializer):
    OBJ_CLASS = dataframe_type

    def encode(self, obj):
        return obj.to_json()

    def decode(self, s):
        return opt.pandas.read_json(s)


class SeriesSerializer(Serializer):
    OBJ_CLASS = series_type

    def encode(self, obj):
        return obj.to_json()

    def decode(self, s):
        return opt.pandas.read_json(s, typ='series')


class FileSerializer(Serializer):
    OBJ_CLASS = BufferedReaderWrapper

    def __init__(self, fs):
        self.fs = fs

    def encode(self, obj):
        address = self.fs.put(obj)
        return json.dumps(address.id)

    def decode(self, s):
        id_ = json.loads(s)
        file_reader = self.fs.open(id_)
        file_reader = BufferedReaderWrapper(file_reader)
        file_reader.hash = id_
        return file_reader


def get_db_file_manager(root_dir):
    fs = HashFS(os.path.join(root_dir, 'hashfs'), depth=3,
                width=2, algorithm='md5')

    # Setup Serialisation object for non list/dict objects
    serialization_store = SerializationMiddleware()
    serialization_store.register_serializer(DateTimeSerializer(), 'TinyDate')
    serialization_store.register_serializer(FileSerializer(fs), 'TinyFile')

    if opt.has_numpy:
        serialization_store.register_serializer(NdArraySerializer(),
                                                'TinyArray')
    if opt.has_pandas:
        serialization_store.register_serializer(DataFrameSerializer(),
                                                'TinyDataFrame')
        serialization_store.register_serializer(SeriesSerializer(),
                                                'TinySeries')

    db = TinyDB(os.path.join(root_dir, 'metadata.json'),
                storage=serialization_store)
    return db, fs
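Since the serializers are plain encode/decode pairs, they can be exercised in isolation. A quick round-trip check of DateTimeSerializer, assuming tinydb, hashfs, and tinydb_serialization are installed (tinydb_hashfs_bases imports them at module level):

```python
import datetime as dt
from sacred.observers.tinydb_hashfs_bases import DateTimeSerializer

ser = DateTimeSerializer()
stamp = dt.datetime(2018, 5, 1, 12, 30, 0, 123456)
encoded = ser.encode(stamp)            # '2018-05-01T12:30:00.123456'
assert isinstance(encoded, str)
assert ser.decode(encoded) == stamp    # microseconds survive the round trip
```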
11 changes: 6 additions & 5 deletions tests/test_observers/test_tinydb_observer.py
@@ -17,8 +17,9 @@
from hashfs import HashFS

from sacred.dependencies import get_digest
from sacred.observers.tinydb_hashfs import (TinyDbObserver, TinyDbOption,
                                            BufferedReaderWrapper)
from sacred.observers.tinydb_hashfs import TinyDbObserver, TinyDbOption
from sacred.observers.tinydb_hashfs_bases import BufferedReaderWrapper

from sacred import optional as opt
from sacred.experiment import Experiment

@@ -307,7 +308,7 @@ def test_custom_bufferreaderwrapper(tmpdir):

@pytest.mark.skipif(not opt.has_numpy, reason='needs numpy')
def test_serialisation_of_numpy_ndarray(tmpdir):
    from sacred.observers.tinydb_hashfs import NdArraySerializer
    from sacred.observers.tinydb_hashfs_bases import NdArraySerializer
    from tinydb_serialization import SerializationMiddleware
    import numpy as np

@@ -339,8 +340,8 @@ def test_serialisation_of_numpy_ndarray(tmpdir):

@pytest.mark.skipif(not opt.has_pandas, reason='needs pandas')
def test_serialisation_of_pandas_dataframe(tmpdir):
    from sacred.observers.tinydb_hashfs import (DataFrameSerializer,
                                                SeriesSerializer)
    from sacred.observers.tinydb_hashfs_bases import DataFrameSerializer
    from sacred.observers.tinydb_hashfs_bases import SeriesSerializer
    from tinydb_serialization import SerializationMiddleware

    import numpy as np
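For context on test_custom_bufferreaderwrapper above: an ordinary file handle cannot be deep-copied, which is exactly what BufferedReaderWrapper works around by reopening the underlying file by name. A small sketch, assuming sacred and its tinydb extras are installed:

```python
import copy
import tempfile

from sacred.observers.tinydb_hashfs_bases import BufferedReaderWrapper

# Create a throwaway file to wrap.
with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as f:
    f.write('hello')
    path = f.name

handle = BufferedReaderWrapper(open(path, 'rb'))
clone = copy.deepcopy(handle)           # __deepcopy__ reopens the file
assert clone is not handle
assert clone.read() == handle.read()    # both read the same bytes
```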
25 changes: 25 additions & 0 deletions tests/test_observers/test_tinydb_observer_not_installed.py
@@ -0,0 +1,25 @@
import pytest
from sacred.optional import has_tinydb
from sacred.observers import TinyDbObserver
from sacred import Experiment


@pytest.fixture
def ex():
    return Experiment('ator3000')


@pytest.mark.skipif(has_tinydb, reason='We are testing the import error.')
def test_importerror_sql(ex):
    with pytest.raises(ImportError):
        ex.observers.append(TinyDbObserver.create('some_uri'))

        @ex.config
        def cfg():
            a = {'b': 1}

        @ex.main
        def foo(a):
            return a['b']

        ex.run()
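The new test only runs where tinydb is genuinely absent (the skipif guard), so most development environments will skip it. A hypothetical alternative for exercising import failures with the package installed is to hide the module from the import machinery via pytest's monkeypatch; note this would not flip sacred.optional.has_tinydb, which is evaluated when sacred is first imported, so this is a sketch of the general technique rather than a drop-in replacement for the test above:

```python
import builtins
import pytest

@pytest.fixture
def hide_tinydb(monkeypatch):
    # Hypothetical fixture: make `import tinydb` fail even if installed.
    real_import = builtins.__import__

    def fake_import(name, *args, **kwargs):
        if name == 'tinydb' or name.startswith('tinydb.'):
            raise ImportError("No module named 'tinydb'")
        return real_import(name, *args, **kwargs)

    monkeypatch.setattr(builtins, '__import__', fake_import)
```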