Skip to content

Commit

Permalink
Merge e1d9dd2 into 2b1e756
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrieldemarmiesse authored Aug 7, 2019
2 parents 2b1e756 + e1d9dd2 commit 03a091b
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 13 deletions.
21 changes: 13 additions & 8 deletions sacred/observers/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import json
from threading import Lock
import warnings

from sacred.commandline_options import CommandLineOption
from sacred.observers.base import RunObserver
Expand All @@ -16,17 +17,21 @@
class SqlObserver(RunObserver):
@classmethod
def create(cls, url, echo=False, priority=DEFAULT_SQL_PRIORITY):
warnings.warn('SqlObserver.create() is deprecated in favor of'
'SqlObserver().')
return cls(url, echo=echo, priority=priority)

def __init__(self, url=None, echo=False, priority=DEFAULT_SQL_PRIORITY,
mock_args=None):
from sqlalchemy.orm import sessionmaker, scoped_session
import sqlalchemy as sa
engine = sa.create_engine(url, echo=echo)
session_factory = sessionmaker(bind=engine)
# make session thread-local to avoid problems with sqlite (see #275)
session = scoped_session(session_factory)
return cls(engine, session, priority)

def __init__(self, engine, session, priority=DEFAULT_SQL_PRIORITY):
self.engine = engine
self.session = session
if mock_args is None:
self.engine = sa.create_engine(url, echo=echo)
self.session = scoped_session(sessionmaker(bind=self.engine))
else:
self.engine = mock_args['engine']
self.session = mock_args['session']
self.priority = priority
self.run = None
self.lock = Lock()
Expand Down
11 changes: 7 additions & 4 deletions tests/test_observers/test_sql_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def session(engine):

@pytest.fixture
def sql_obs(session, engine):
return SqlObserver(engine, session)
return SqlObserver(mock_args={'engine': engine, 'session': session})


@pytest.fixture
Expand Down Expand Up @@ -227,7 +227,8 @@ def test_fs_observer_resource_event(sql_obs, sample_run, session, tmpfile):


def test_fs_observer_doesnt_duplicate_sources(sql_obs, sample_run, session, tmpfile):
sql_obs2 = SqlObserver(sql_obs.engine, session)
sql_obs2 = SqlObserver(mock_args={'engine': sql_obs.engine,
'session': session})
sample_run['_id'] = None
sample_run['ex_info']['sources'] = [[tmpfile.name, tmpfile.md5sum]]

Expand All @@ -239,7 +240,8 @@ def test_fs_observer_doesnt_duplicate_sources(sql_obs, sample_run, session, tmpf


def test_fs_observer_doesnt_duplicate_resources(sql_obs, sample_run, session, tmpfile):
sql_obs2 = SqlObserver(sql_obs.engine, session)
sql_obs2 = SqlObserver(mock_args={'engine': sql_obs.engine,
'session': session})
sample_run['_id'] = None
sample_run['ex_info']['sources'] = [[tmpfile.name, tmpfile.md5sum]]

Expand All @@ -254,7 +256,8 @@ def test_fs_observer_doesnt_duplicate_resources(sql_obs, sample_run, session, tm


def test_sql_observer_equality(sql_obs, engine, session):
sql_obs2 = SqlObserver(engine, session)
sql_obs2 = SqlObserver(mock_args={'engine': sql_obs.engine,
'session': session})
assert sql_obs == sql_obs2

assert not sql_obs != sql_obs2
Expand Down
2 changes: 1 addition & 1 deletion tests/test_observers/test_sql_observer_not_installed.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def ex():
@pytest.mark.skipif(has_sqlalchemy, reason='We are testing the import error.')
def test_importerror_sql(ex):
with pytest.raises(ImportError):
ex.observers.append(SqlObserver.create('some_uri'))
ex.observers.append(SqlObserver('some_uri'))

@ex.config
def cfg():
Expand Down

0 comments on commit 03a091b

Please sign in to comment.