Skip to content

Commit

Permalink
Update bool
Browse files Browse the repository at this point in the history
  • Loading branch information
Sebastien Iooss committed Jul 26, 2019
1 parent 1fce105 commit 2b57c9a
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 14 deletions.
26 changes: 12 additions & 14 deletions distribute_config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,12 @@ def define_var(self, name, default, description, type, is_list=False, possible_v
variable = Variable(name, default, description, type, is_list, possible_values)
self.__add_variables(variable)

if type == bool:
if default:
# Add an argument to desactivate the arg
self.parser.add_argument("--no" + variable.name, dest=variable.name, action='store_false', help=variable.description)
else:
# Add an argument to activate the arg
self.parser.add_argument("--" + variable.name, dest=variable.name, action='store_true', help=variable.description)
elif possible_values is not None:
if possible_values is not None:
self.parser.add_argument("--" + variable.name, type=type, help=f"{variable.description}, need to be in {possible_values}")
else:
# Bool is a special case : we want to use str to specify the value.
if type == bool:
type = str
self.parser.add_argument("--" + variable.name, type=type, help=variable.description)

def __add_variables(self, variable: Variable):
Expand Down Expand Up @@ -112,9 +108,6 @@ def define_bool(cls, var_name, default, description):
def define_enum(cls, var_name, default, possible_values, desciption):
cls.__instance.define_var(var_name, default, desciption, str, possible_values=possible_values)




@classmethod
def get_var(cls, name):
path = name.split(".")
Expand Down Expand Up @@ -160,7 +153,7 @@ def load_conf(cls, config_file_name="config.yml", auto_update_yml=True, no_conf_
"""
args = cls.__instance.parser.parse_args()

if args.c :
if args.c:
config_file_name = args.c
if not no_conf_file:
if not os.path.exists(config_file_name):
Expand All @@ -183,16 +176,21 @@ def load_conf(cls, config_file_name="config.yml", auto_update_yml=True, no_conf_
path = ".".join(var.lower().split("__"))
try:
cls.set_var(path, os.environ[var])
logging.info(f"Load env variable {var}")
logging.info(f"Load env variable {var}={os.environ[var]}")
except KeyError:
pass

# 3
for key in vars(args):
if key == "c" or vars(args)[key] is None:
continue

# # Special case for store_true and store_false
# if cls.get_var(key).type == bool:


cls.set_var(key, vars(args)[key])
logging.info(f"Load variable {vars(args)[key]} from command line args")
logging.info(f"Load variable {key} = {vars(args)[key]} from command line args")

@staticmethod
def load_dict(loading_dict, variables):
Expand Down
16 changes: 16 additions & 0 deletions distribute_config/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def test_int(self):
return_value=argparse.Namespace(v1=2, v2=3, c="conf.yml"))
def test_load_conf(self, mock_args):
Config.clear()
os.remove("conf.yml")
Config.define_int("v1", 1, "var")
Config.define_int("v2", 2, "var")
Config.load_conf()
Expand Down Expand Up @@ -128,3 +129,18 @@ def test_load_conf_4(self, mock_args):
self.assertEqual(Config.get_var("n1.v1"), 2)
self.assertEqual(Config.get_var("n1.v2"), 3)
self.assertEqual(Config.get_var("n1.v3"), False)

@mock.patch.dict(os.environ, {"N1__V3": "false"})
@mock.patch('argparse.ArgumentParser.parse_args',
return_value=argparse.Namespace(**{"n1.v1": 2, "n1.v2": 3}, c="conf.yml"))
def test_load_conf_5(self, mock_args):
Config.clear()
os.remove("conf.yml")
with Config.namespace("n1"):
Config.define_int("v1", 1, "var")
Config.define_int("v2", 2, "var")
Config.define_bool("V3", True, "turn me false")
Config.load_conf()
self.assertEqual(Config.get_var("n1.v1"), 2)
self.assertEqual(Config.get_var("n1.v2"), 3)
self.assertEqual(Config.get_var("n1.v3"), False)

0 comments on commit 2b57c9a

Please sign in to comment.