Skip to content

Commit

Permalink
Merge 988d9f5 into d3fe102
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrieldemarmiesse committed Aug 5, 2019
2 parents d3fe102 + 988d9f5 commit 7d27f79
Show file tree
Hide file tree
Showing 7 changed files with 141 additions and 136 deletions.
138 changes: 70 additions & 68 deletions sacred/commandline_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,34 +33,37 @@ class CommandLineOption:
if the packages are not available.
"""

_enabled = True

short_flag = None
""" The (one-letter) short form (defaults to first letter of flag) """

arg = None
""" Name of the argument (optional) """

arg_description = None
""" Description of the argument (optional) """
def __init__(self, enabled=True, short_flag=None,
arg=None, arg_description=None):
""""
Parameters
----------
short_flag : str
The (one-letter) short form (defaults to first letter of flag)
arg : str
Name of the argument
arg_description : str
Description of the argument
"""
self.enabled = enabled
self.short_flag = short_flag
self.arg = arg
self.arg_description = arg_description

@classmethod
def get_flag(cls):
def get_flag(self):
# Get the flag name from the class name
flag = cls.__name__
flag = self.__class__.__name__
if flag.endswith("Option"):
flag = flag[:-6]
return '--' + convert_camel_case_to_snake_case(flag)

@classmethod
def get_short_flag(cls):
if cls.short_flag is None:
return '-' + cls.get_flag()[2]
def get_short_flag(self):
if self.short_flag is None:
return '-' + self.get_flag()[2]
else:
return '-' + cls.short_flag
return '-' + self.short_flag

@classmethod
def get_flags(cls):
def get_flags(self):
"""
Return the short and the long version of this option.
Expand All @@ -78,10 +81,9 @@ def get_flags(cls):
tuple of short-flag, and long-flag
"""
return cls.get_short_flag(), cls.get_flag()
return self.get_short_flag(), self.get_flag()

@classmethod
def apply(cls, args, run):
def apply(self, args, run):
"""
Modify the current Run base on this command-line option.
Expand All @@ -98,16 +100,16 @@ def apply(cls, args, run):
The current run to be modified
"""
pass
raise NotImplementedError


def gather_command_line_options(filter_disabled=None):
"""Get a sorted list of all CommandLineOption subclasses."""
if filter_disabled is None:
filter_disabled = not SETTINGS.COMMAND_LINE.SHOW_DISABLED_OPTIONS
options = [opt for opt in get_inheritors(CommandLineOption)
options = [opt() for opt in get_inheritors(CommandLineOption)
if not filter_disabled or opt._enabled]
return sorted(options, key=lambda opt: opt.__name__)
return sorted(options, key=lambda opt: opt.__class__.__name__)


class HelpOption(CommandLineOption):
Expand All @@ -121,31 +123,32 @@ class DebugOption(CommandLineOption):
Also enables usage with ipython --pdb.
"""

@classmethod
def apply(cls, args, run):
def apply(self, args, run):
"""Set this run to debug mode."""
run.debug = True


class PDBOption(CommandLineOption):
"""Automatically enter post-mortem debugging with pdb on failure."""

short_flag = 'D'
def __init__(self):
super().__init__(short_flag='D')

@classmethod
def apply(cls, args, run):
def apply(self, args, run):
run.pdb = True


class LoglevelOption(CommandLineOption):
"""Adjust the loglevel."""

arg = 'LEVEL'
arg_description = 'Loglevel either as 0 - 50 or as string: DEBUG(10), ' \
'INFO(20), WARNING(30), ERROR(40), CRITICAL(50)'
def __init__(self):
super().__init__(
arg='LEVEL',
arg_description='Loglevel either as 0 - 50 or as string: '
'DEBUG(10), INFO(20), WARNING(30), '
'ERROR(40), CRITICAL(50)')

@classmethod
def apply(cls, args, run):
def apply(self, args, run):
"""Adjust the loglevel of the root-logger of this run."""
# TODO: sacred.initialize.create_run already takes care of this

Expand All @@ -159,63 +162,63 @@ def apply(cls, args, run):
class CommentOption(CommandLineOption):
"""Adds a message to the run."""

arg = 'COMMENT'
arg_description = 'A comment that should be stored along with the run.'
def __init__(self):
super().__init__(arg='COMMENT',
arg_description='A comment that should be stored '
'along with the run.')

@classmethod
def apply(cls, args, run):
def apply(self, args, run):
"""Add a comment to this run."""
run.meta_info['comment'] = args


class BeatIntervalOption(CommandLineOption):
"""Control the rate of heartbeat events."""

arg = 'BEAT_INTERVAL'
arg_description = "Time between two heartbeat events measured in seconds."
def __init__(self):
super().__init__(arg='BEAT_INTERVAL',
arg_description="Time between two heartbeat "
"events measured in seconds.")

@classmethod
def apply(cls, args, run):
def apply(self, args, run):
"""Set the heart-beat interval for this run."""
run.beat_interval = float(args)


class UnobservedOption(CommandLineOption):
"""Ignore all observers for this run."""

@classmethod
def apply(cls, args, run):
def apply(self, args, run):
"""Set this run to unobserved mode."""
run.unobserved = True


class QueueOption(CommandLineOption):
"""Only queue this run, do not start it."""

@classmethod
def apply(cls, args, run):
def apply(self, args, run):
"""Set this run to queue only mode."""
run.queue_only = True


class ForceOption(CommandLineOption):
"""Disable warnings about suspicious changes for this run."""

@classmethod
def apply(cls, args, run):
def apply(self, args, run):
"""Set this run to not warn about suspicous changes."""
run.force = True


class PriorityOption(CommandLineOption):
"""Sets the priority for a queued up experiment."""

short_flag = 'P'
arg = 'PRIORITY'
arg_description = 'The (numeric) priority for this run.'
def __init__(self):
super().__init__(
short_flag='P',
arg='PRIORITY',
arg_description='The (numeric) priority for this run.')

@classmethod
def apply(cls, args, run):
def apply(self, args, run):
"""Add priority info for this run."""
try:
priority = float(args)
Expand All @@ -228,8 +231,7 @@ def apply(cls, args, run):
class EnforceCleanOption(CommandLineOption):
"""Fail if any version control repository is dirty."""

@classmethod
def apply(cls, args, run):
def apply(self, args, run):
try:
import git # NOQA
except ImportError:
Expand All @@ -252,31 +254,31 @@ def apply(cls, args, run):
class PrintConfigOption(CommandLineOption):
"""Always print the configuration first."""

@classmethod
def apply(cls, args, run):
def apply(self, args, run):
print_config(run)
print('-' * 79)


class NameOption(CommandLineOption):
"""Set the name for this run."""

arg = 'NAME'
arg_description = 'Name for this run.'
def __init__(self):
super().__init__(arg='NAME',
arg_description='Name for this run.')

@classmethod
def apply(cls, args, run):
def apply(self, args, run):
run.experiment_info['name'] = args
run.run_logger = run.root_logger.getChild(args)


class CaptureOption(CommandLineOption):
"""Control the way stdout and stderr are captured."""

short_flag = 'C'
arg = 'CAPTURE_MODE'
arg_description = "stdout/stderr capture mode. One of [no, sys, fd]"
def __init__(self):
super().__init__(short_flag='C',
arg='CAPTURE_MODE',
arg_description="stdout/stderr capture mode. "
"One of [no, sys, fd]")

@classmethod
def apply(cls, args, run):
def apply(self, args, run):
run.capture_mode = args
4 changes: 2 additions & 2 deletions sacred/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,8 +450,8 @@ def _create_run(self, command_name=None, config_updates=None,

run = create_run(self, command_name, config_updates,
named_configs=named_configs,
force=options.get(ForceOption.get_flag(), False),
log_level=options.get(LoglevelOption.get_flag(),
force=options.get(ForceOption().get_flag(), False),
log_level=options.get(LoglevelOption().get_flag(),
None))
if info is not None:
run.info.update(info)
Expand Down
81 changes: 42 additions & 39 deletions sacred/observers/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,48 +352,51 @@ def __eq__(self, other):
class MongoDbOption(CommandLineOption):
"""Add a MongoDB Observer to the experiment."""

arg = 'DB'
arg_description = "Database specification. Can be " \
"[host:port:]db_name[.collection[:id]][!priority]"

RUN_ID_PATTERN = r"(?P<overwrite>\d{1,12})"
PORT1_PATTERN = r"(?P<port1>\d{1,5})"
PORT2_PATTERN = r"(?P<port2>\d{1,5})"
PRIORITY_PATTERN = r"(?P<priority>-?\d+)?"
DB_NAME_PATTERN = r"(?P<db_name>[_A-Za-z]" \
r"[0-9A-Za-z#%&'()+\-;=@\[\]^_{}]{0,63})"
COLL_NAME_PATTERN = r"(?P<collection>[_A-Za-z]" \
r"[0-9A-Za-z#%&'()+\-;=@\[\]^_{}]{0,63})"
HOSTNAME1_PATTERN = r"(?P<host1>" \
r"[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}[0-9A-Za-z])?" \
r"(?:\.[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}" \
r"[0-9A-Za-z])?)*)"
HOSTNAME2_PATTERN = r"(?P<host2>" \
r"[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}[0-9A-Za-z])?" \
r"(?:\.[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}" \
r"[0-9A-Za-z])?)*)"

HOST_ONLY = r"^(?:{host}:{port})$".format(host=HOSTNAME1_PATTERN,
port=PORT1_PATTERN)
FULL = r"^(?:{host}:{port}:)?{db}(?:\.{collection}(?::{rid})?)?" \
r"(?:!{priority})?$".format(host=HOSTNAME2_PATTERN,
port=PORT2_PATTERN,
db=DB_NAME_PATTERN,
collection=COLL_NAME_PATTERN,
rid=RUN_ID_PATTERN,
priority=PRIORITY_PATTERN)

PATTERN = r"{host_only}|{full}".format(host_only=HOST_ONLY, full=FULL)

@classmethod
def apply(cls, args, run):
kwargs = cls.parse_mongo_db_arg(args)
def __init__(self):
run_id_pattern = r"(?P<overwrite>\d{1,12})"
port1_pattern = r"(?P<port1>\d{1,5})"
port2_pattern = r"(?P<port2>\d{1,5})"
priority_pattern = r"(?P<priority>-?\d+)?"
db_name_pattern = r"(?P<db_name>[_A-Za-z]" \
r"[0-9A-Za-z#%&'()+\-;=@\[\]^_{}]{0,63})"
coll_name_pattern = r"(?P<collection>[_A-Za-z]" \
r"[0-9A-Za-z#%&'()+\-;=@\[\]^_{}]{0,63})"
hostname1_pattern = (
r"(?P<host1>"
r"[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}[0-9A-Za-z])?"
r"(?:\.[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}"
r"[0-9A-Za-z])?)*)")
hostname2_pattern = (
r"(?P<host2>"
r"[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}[0-9A-Za-z])?"
r"(?:\.[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}"
r"[0-9A-Za-z])?)*)")

host_only = r"^(?:{host}:{port})$".format(host=hostname1_pattern,
port=port1_pattern)
full = r"^(?:{host}:{port}:)?{db}(?:\.{collection}(?::{rid})?)?" \
r"(?:!{priority})?$".format(host=hostname2_pattern,
port=port2_pattern,
db=db_name_pattern,
collection=coll_name_pattern,
rid=run_id_pattern,
priority=priority_pattern)

self.pattern = r"{host_only}|{full}".format(host_only=host_only,
full=full)
super().__init__(
arg='DB',
arg_description="Database specification. Can be "
"[host:port:]db_name"
"[.collection[:id]][!priority]")

def apply(self, args, run):
kwargs = self.parse_mongo_db_arg(args)
mongo = MongoObserver.create(**kwargs)
run.observers.append(mongo)

@classmethod
def parse_mongo_db_arg(cls, mongo_db):
g = re.match(cls.PATTERN, mongo_db).groupdict()
def parse_mongo_db_arg(self, mongo_db):
g = re.match(self.pattern, mongo_db).groupdict()
if g is None:
raise ValueError('mongo_db argument must have the form "db_name" '
'or "host:port[:db_name]" but was {}'
Expand Down
11 changes: 6 additions & 5 deletions sacred/observers/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,11 @@ def __eq__(self, other):
class SqlOption(CommandLineOption):
"""Add a SQL Observer to the experiment."""

arg = 'DB_URL'
arg_description = \
"The typical form is: dialect://username:password@host:port/database"
def __init__(self):
super().__init__(arg='DB_URL',
arg_description="The typical form is: "
"dialect://username:password@host"
":port/database")

@classmethod
def apply(cls, args, run):
def apply(self, args, run):
run.observers.append(SqlObserver.create(args))
Loading

0 comments on commit 7d27f79

Please sign in to comment.