From 420e4ed2c8c7969b94b42984f4941f9305c22c72 Mon Sep 17 00:00:00 2001 From: gabrieldemarmiesse Date: Wed, 7 Aug 2019 16:11:33 +0200 Subject: [PATCH 1/2] Trying to be backward compatible. --- sacred/observers/sql.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/sacred/observers/sql.py b/sacred/observers/sql.py index fedde548..3cb6ba9c 100644 --- a/sacred/observers/sql.py +++ b/sacred/observers/sql.py @@ -3,6 +3,7 @@ import json from threading import Lock +import warnings from sacred.commandline_options import CommandLineOption from sacred.observers.base import RunObserver @@ -16,17 +17,20 @@ class SqlObserver(RunObserver): @classmethod def create(cls, url, echo=False, priority=DEFAULT_SQL_PRIORITY): + warnings.warn('SqlObserver.create() is deprecated in favor of' + 'SqlObserver().') + return cls(url, echo=echo, priority=priority) + + def __init__(self, url=None, *, echo=False, engine=None, session=None, + priority=DEFAULT_SQL_PRIORITY): from sqlalchemy.orm import sessionmaker, scoped_session import sqlalchemy as sa - engine = sa.create_engine(url, echo=echo) - session_factory = sessionmaker(bind=engine) # make session thread-local to avoid problems with sqlite (see #275) - session = scoped_session(session_factory) - return cls(engine, session, priority) - - def __init__(self, engine, session, priority=DEFAULT_SQL_PRIORITY): - self.engine = engine - self.session = session + self.engine = engine or sa.create_engine(url, echo=echo) + if session is None: + self.session = scoped_session(sessionmaker(bind=self.engine)) + else: + self.session = session self.priority = priority self.run = None self.lock = Lock() From 2f66fd13d94b2df80ce3c0dbb69b7894a3458a91 Mon Sep 17 00:00:00 2001 From: gabrieldemarmiesse Date: Wed, 7 Aug 2019 16:13:42 +0200 Subject: [PATCH 2/2] Fixed tests. --- tests/test_observers/test_sql_observer.py | 8 ++++---- tests/test_observers/test_sql_observer_not_installed.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_observers/test_sql_observer.py b/tests/test_observers/test_sql_observer.py index 06de897f..62b0a5e9 100644 --- a/tests/test_observers/test_sql_observer.py +++ b/tests/test_observers/test_sql_observer.py @@ -46,7 +46,7 @@ def session(engine): @pytest.fixture def sql_obs(session, engine): - return SqlObserver(engine, session) + return SqlObserver(engine=engine, session=session) @pytest.fixture @@ -227,7 +227,7 @@ def test_fs_observer_resource_event(sql_obs, sample_run, session, tmpfile): def test_fs_observer_doesnt_duplicate_sources(sql_obs, sample_run, session, tmpfile): - sql_obs2 = SqlObserver(sql_obs.engine, session) + sql_obs2 = SqlObserver(engine=sql_obs.engine, session=session) sample_run['_id'] = None sample_run['ex_info']['sources'] = [[tmpfile.name, tmpfile.md5sum]] @@ -239,7 +239,7 @@ def test_fs_observer_doesnt_duplicate_sources(sql_obs, sample_run, session, tmpf def test_fs_observer_doesnt_duplicate_resources(sql_obs, sample_run, session, tmpfile): - sql_obs2 = SqlObserver(sql_obs.engine, session) + sql_obs2 = SqlObserver(engine=sql_obs.engine, session=session) sample_run['_id'] = None sample_run['ex_info']['sources'] = [[tmpfile.name, tmpfile.md5sum]] @@ -254,7 +254,7 @@ def test_fs_observer_doesnt_duplicate_resources(sql_obs, sample_run, session, tm def test_sql_observer_equality(sql_obs, engine, session): - sql_obs2 = SqlObserver(engine, session) + sql_obs2 = SqlObserver(engine=engine, session=session) assert sql_obs == sql_obs2 assert not sql_obs != sql_obs2 diff --git a/tests/test_observers/test_sql_observer_not_installed.py b/tests/test_observers/test_sql_observer_not_installed.py index a386cdd3..14432a13 100644 --- a/tests/test_observers/test_sql_observer_not_installed.py +++ b/tests/test_observers/test_sql_observer_not_installed.py @@ -12,7 +12,7 @@ def ex(): @pytest.mark.skipif(has_sqlalchemy, reason='We are testing the import error.') def test_importerror_sql(ex): with pytest.raises(ImportError): - ex.observers.append(SqlObserver.create('some_uri')) + ex.observers.append(SqlObserver('some_uri')) @ex.config def cfg():