Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion common_utility/configLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from argparse import ArgumentParser, Action, Namespace
from configparser import ConfigParser
from pathlib import Path
from typing import Any, cast
from typing import Any, cast, TextIO

from context_logger import get_logger

Expand All @@ -16,6 +16,9 @@ class IConfigLoader(object):
def load(self, argument_parser: ArgumentParser) -> Namespace:
raise NotImplementedError()

def dump(self, argument_parser: ArgumentParser, config: Namespace, file: TextIO = sys.stdout) -> None:
raise NotImplementedError()


class ConfigLoader(IConfigLoader):

Expand Down Expand Up @@ -57,6 +60,24 @@ def load(self, argument_parser: ArgumentParser) -> Namespace:

return Namespace(**configuration)

def dump(self, argument_parser: ArgumentParser, config: Namespace, file: TextIO = sys.stdout) -> None:
for group in argument_parser._action_groups:
section = group.title if group.title else 'DEFAULT'
values = {}

for action in group._group_actions:
if not action.dest or action.dest == "help":
continue
value = getattr(config, action.dest, None)
if value is None:
continue
values[action.dest] = str(value)

if values:
self._config_parser[section] = values

self._config_parser.write(file)

def _get_cli_overrides(self, parser: ArgumentParser, arguments: Namespace) -> dict[str, Any]:
cli_overrides: dict[str, Any] = {}
argv_tokens = set(sys.argv[1:])
Expand Down
84 changes: 83 additions & 1 deletion tests/configLoaderTest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import sys
import unittest
from argparse import ArgumentParser
from argparse import ArgumentParser, Namespace
from configparser import ConfigParser
from io import StringIO
from pathlib import Path
from unittest import TestCase
from unittest.mock import patch
Expand Down Expand Up @@ -211,6 +213,86 @@ def test_load_config_when_invalid_numeric_values_present_then_keep_original(self
self.assertEqual('invalid', result.retry_count)
self.assertEqual('invalid', result.timeout)

def test_dump_when_values_present_then_write_config_sections(self):
# Given
config_loader = ConfigLoader(Path(DEFAULT_CONFIG_FILE))
argument_parser = ArgumentParser()
network_group = argument_parser.add_argument_group('network')
network_group.add_argument('--host')
network_group.add_argument('--port')
runtime_group = argument_parser.add_argument_group('runtime')
runtime_group.add_argument('--debug')
config = Namespace(host='localhost', port=8080, debug=True)
output = StringIO()

# When
config_loader.dump(argument_parser, config, output)

# Then
parser = ConfigParser(interpolation=None)
parser.read_string(output.getvalue())
self.assertEqual('localhost', parser['network']['host'])
self.assertEqual('8080', parser['network']['port'])
self.assertEqual('True', parser['runtime']['debug'])

def test_dump_when_value_is_none_then_skip_value(self):
# Given
config_loader = ConfigLoader(Path(DEFAULT_CONFIG_FILE))
argument_parser = ArgumentParser()
runtime_group = argument_parser.add_argument_group('runtime')
runtime_group.add_argument('--timeout')
runtime_group.add_argument('--retries')
config = Namespace(timeout=None, retries=3)
output = StringIO()

# When
config_loader.dump(argument_parser, config, output)

# Then
parser = ConfigParser(interpolation=None)
parser.read_string(output.getvalue())
self.assertEqual('3', parser['runtime']['retries'])
self.assertNotIn('timeout', parser['runtime'])
self.assertNotIn('help', output.getvalue())

def test_dump_when_all_values_in_group_are_none_then_omit_section(self):
# Given
config_loader = ConfigLoader(Path(DEFAULT_CONFIG_FILE))
argument_parser = ArgumentParser()
secret_group = argument_parser.add_argument_group('secret')
secret_group.add_argument('--token')
config = Namespace(token=None)
output = StringIO()

# When
config_loader.dump(argument_parser, config, output)

# Then
self.assertNotIn('[secret]', output.getvalue())

def test_dump_when_group_has_no_title_then_uses_default_section(self):
# Given
config_loader = ConfigLoader(Path(DEFAULT_CONFIG_FILE))
argument_parser = ArgumentParser(add_help=False)
region_action = argument_parser.add_argument('--region')

class DummyGroup(object):
def __init__(self):
self.title = None
self._group_actions = [region_action]

argument_parser._action_groups = [DummyGroup()]
config = Namespace(region='eu-central')
output = StringIO()

# When
config_loader.dump(argument_parser, config, output)

# Then
parser = ConfigParser(interpolation=None)
parser.read_string(output.getvalue())
self.assertEqual('eu-central', parser.defaults()['region'])


if __name__ == '__main__':
unittest.main()
Loading