Skip to content

Commit

Permalink
Merge 6c032c4 into 2b1e756
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrieldemarmiesse committed Aug 7, 2019
2 parents 2b1e756 + 6c032c4 commit 05d5404
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 16 deletions.
32 changes: 21 additions & 11 deletions sacred/observers/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import json
from threading import Lock
import warnings

from sacred.commandline_options import CommandLineOption
from sacred.observers.base import RunObserver
Expand All @@ -13,16 +14,7 @@

# ############################# Observer #################################### #

class SqlObserver(RunObserver):
@classmethod
def create(cls, url, echo=False, priority=DEFAULT_SQL_PRIORITY):
from sqlalchemy.orm import sessionmaker, scoped_session
import sqlalchemy as sa
engine = sa.create_engine(url, echo=echo)
session_factory = sessionmaker(bind=engine)
# make session thread-local to avoid problems with sqlite (see #275)
session = scoped_session(session_factory)
return cls(engine, session, priority)
class PlainSQLObserver(RunObserver):

def __init__(self, engine, session, priority=DEFAULT_SQL_PRIORITY):
self.engine = engine
Expand Down Expand Up @@ -111,13 +103,31 @@ def query(self, _id):
return run.to_json()

def __eq__(self, other):
if isinstance(other, SqlObserver):
if isinstance(other, PlainSQLObserver):
# fixme: this will probably fail to detect two equivalent engines
return (self.engine == other.engine and
self.session == other.session)
return False


class SqlObserver(PlainSQLObserver):
@classmethod
def create(cls, *args, **kwargs):
warnings.warn("Use of the create method is depreciated. Please use"
"SqlObserver(...) instead of SqlObserver.create(...).",
DeprecationWarning)
return cls(*args, **kwargs)

def __init__(self, url, echo=False, priority=DEFAULT_SQL_PRIORITY):
from sqlalchemy.orm import sessionmaker, scoped_session
import sqlalchemy as sa
engine = sa.create_engine(url, echo=echo)
session_factory = sessionmaker(bind=engine)
# make session thread-local to avoid problems with sqlite (see #275)
session = scoped_session(session_factory)
super().__init__(engine, session, priority)


# ######################## Commandline Option ############################### #

class SqlOption(CommandLineOption):
Expand Down
10 changes: 5 additions & 5 deletions tests/test_observers/test_sql_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

sqlalchemy = pytest.importorskip("sqlalchemy")

from sacred.observers.sql import SqlObserver
from sacred.observers.sql import PlainSQLObserver
from sacred.observers.sql_bases import Host, Experiment, Run, Source, Resource


Expand Down Expand Up @@ -46,7 +46,7 @@ def session(engine):

@pytest.fixture
def sql_obs(session, engine):
return SqlObserver(engine, session)
return PlainSQLObserver(engine, session)


@pytest.fixture
Expand Down Expand Up @@ -227,7 +227,7 @@ def test_fs_observer_resource_event(sql_obs, sample_run, session, tmpfile):


def test_fs_observer_doesnt_duplicate_sources(sql_obs, sample_run, session, tmpfile):
sql_obs2 = SqlObserver(sql_obs.engine, session)
sql_obs2 = PlainSQLObserver(sql_obs.engine, session)
sample_run['_id'] = None
sample_run['ex_info']['sources'] = [[tmpfile.name, tmpfile.md5sum]]

Expand All @@ -239,7 +239,7 @@ def test_fs_observer_doesnt_duplicate_sources(sql_obs, sample_run, session, tmpf


def test_fs_observer_doesnt_duplicate_resources(sql_obs, sample_run, session, tmpfile):
sql_obs2 = SqlObserver(sql_obs.engine, session)
sql_obs2 = PlainSQLObserver(sql_obs.engine, session)
sample_run['_id'] = None
sample_run['ex_info']['sources'] = [[tmpfile.name, tmpfile.md5sum]]

Expand All @@ -254,7 +254,7 @@ def test_fs_observer_doesnt_duplicate_resources(sql_obs, sample_run, session, tm


def test_sql_observer_equality(sql_obs, engine, session):
sql_obs2 = SqlObserver(engine, session)
sql_obs2 = PlainSQLObserver(engine, session)
assert sql_obs == sql_obs2

assert not sql_obs != sql_obs2
Expand Down

0 comments on commit 05d5404

Please sign in to comment.