Skip to content

Commit

Permalink
Merge 2f66fd1 into 2b1e756
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrieldemarmiesse committed Aug 7, 2019
2 parents 2b1e756 + 2f66fd1 commit 1d1bea3
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 13 deletions.
20 changes: 12 additions & 8 deletions sacred/observers/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import json
from threading import Lock
import warnings

from sacred.commandline_options import CommandLineOption
from sacred.observers.base import RunObserver
Expand All @@ -16,17 +17,20 @@
class SqlObserver(RunObserver):
@classmethod
def create(cls, url, echo=False, priority=DEFAULT_SQL_PRIORITY):
warnings.warn('SqlObserver.create() is deprecated in favor of'
'SqlObserver().')
return cls(url, echo=echo, priority=priority)

def __init__(self, url=None, *, echo=False, engine=None, session=None,
priority=DEFAULT_SQL_PRIORITY):
from sqlalchemy.orm import sessionmaker, scoped_session
import sqlalchemy as sa
engine = sa.create_engine(url, echo=echo)
session_factory = sessionmaker(bind=engine)
# make session thread-local to avoid problems with sqlite (see #275)
session = scoped_session(session_factory)
return cls(engine, session, priority)

def __init__(self, engine, session, priority=DEFAULT_SQL_PRIORITY):
self.engine = engine
self.session = session
self.engine = engine or sa.create_engine(url, echo=echo)
if session is None:
self.session = scoped_session(sessionmaker(bind=self.engine))
else:
self.session = session
self.priority = priority
self.run = None
self.lock = Lock()
Expand Down
8 changes: 4 additions & 4 deletions tests/test_observers/test_sql_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def session(engine):

@pytest.fixture
def sql_obs(session, engine):
return SqlObserver(engine, session)
return SqlObserver(engine=engine, session=session)


@pytest.fixture
Expand Down Expand Up @@ -227,7 +227,7 @@ def test_fs_observer_resource_event(sql_obs, sample_run, session, tmpfile):


def test_fs_observer_doesnt_duplicate_sources(sql_obs, sample_run, session, tmpfile):
sql_obs2 = SqlObserver(sql_obs.engine, session)
sql_obs2 = SqlObserver(engine=sql_obs.engine, session=session)
sample_run['_id'] = None
sample_run['ex_info']['sources'] = [[tmpfile.name, tmpfile.md5sum]]

Expand All @@ -239,7 +239,7 @@ def test_fs_observer_doesnt_duplicate_sources(sql_obs, sample_run, session, tmpf


def test_fs_observer_doesnt_duplicate_resources(sql_obs, sample_run, session, tmpfile):
sql_obs2 = SqlObserver(sql_obs.engine, session)
sql_obs2 = SqlObserver(engine=sql_obs.engine, session=session)
sample_run['_id'] = None
sample_run['ex_info']['sources'] = [[tmpfile.name, tmpfile.md5sum]]

Expand All @@ -254,7 +254,7 @@ def test_fs_observer_doesnt_duplicate_resources(sql_obs, sample_run, session, tm


def test_sql_observer_equality(sql_obs, engine, session):
sql_obs2 = SqlObserver(engine, session)
sql_obs2 = SqlObserver(engine=engine, session=session)
assert sql_obs == sql_obs2

assert not sql_obs != sql_obs2
Expand Down
2 changes: 1 addition & 1 deletion tests/test_observers/test_sql_observer_not_installed.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def ex():
@pytest.mark.skipif(has_sqlalchemy, reason='We are testing the import error.')
def test_importerror_sql(ex):
with pytest.raises(ImportError):
ex.observers.append(SqlObserver.create('some_uri'))
ex.observers.append(SqlObserver('some_uri'))

@ex.config
def cfg():
Expand Down

0 comments on commit 1d1bea3

Please sign in to comment.