Skip to content

Commit

Permalink
Merge pull request #7663 from drew2a/refactoring/settings_endpoint
Browse files Browse the repository at this point in the history
Extract `parse_settings` to the `TriblerConfig`
  • Loading branch information
drew2a committed Nov 6, 2023
2 parents e020e21 + e0a8c9f commit 077ace6
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 32 deletions.
37 changes: 9 additions & 28 deletions src/tribler/core/components/restapi/rest/settings_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class SettingsEndpoint(RESTEndpoint):

def __init__(self, tribler_config: TriblerConfig, download_manager: DownloadManager = None):
super().__init__()
self.tribler_config = tribler_config
self.config = tribler_config
self.download_manager = download_manager

def setup_routes(self):
Expand All @@ -40,7 +40,7 @@ def setup_routes(self):
async def get_settings(self, request):
self._logger.info(f'Get settings. Request: {request}')
return RESTResponse({
"settings": self.tribler_config.dict(),
"settings": self.config.dict(),
"ports": list(default_network_utils.ports_in_use)
})

Expand All @@ -55,31 +55,12 @@ async def get_settings(self, request):
)
@json_schema(schema(UpdateTriblerSettingsRequest={}))
async def update_settings(self, request):
settings_dict = await request.json()
await self.parse_settings_dict(settings_dict)
self.tribler_config.write()
return RESTResponse({"modified": True})

async def parse_setting(self, section, option, value):
"""
Set a specific Tribler setting. Throw a ValueError if this setting is not available.
"""
# if section in self.config.config and option in self.config.config[section]:
self.tribler_config.__getattribute__(section).__setattr__(option, value)
# else:
# raise ValueError(f"Section {section} with option {option} does not exist")
settings = await request.json()
self._logger.info(f'Received settings: {settings}')
self.config.update_from_dict(settings)
self.config.write()

# Perform some actions when specific keys are set
if section == "libtorrent" and (option == "max_download_rate" or option == "max_upload_rate"):
if self.download_manager:
self.download_manager.update_max_rates_from_config()
if self.download_manager:
self.download_manager.update_max_rates_from_config()

async def parse_settings_dict(self, settings_dict, depth=1, root_key=None):
"""
Parse the settings dictionary.
"""
for key, value in settings_dict.items():
if isinstance(value, dict):
await self.parse_settings_dict(value, depth=depth + 1, root_key=key)
else:
await self.parse_setting(root_key, key, value)
return RESTResponse({"modified": True})
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
from aiohttp import ClientOSError, ServerDisconnectedError
from aiohttp.web_protocol import RequestHandler
from pydantic import ValidationError

from tribler.core.components.restapi.rest.aiohttp_patch import get_transport_is_none_counter, patch_make_request
from tribler.core.components.restapi.rest.base_api_test import do_real_request
Expand Down Expand Up @@ -95,10 +96,10 @@ async def test_unhandled_exception(rest_manager, api_port):
handler.unhandled_error_observer.assert_called_once()
exception_dict = handler.unhandled_error_observer.call_args.args[1]
assert exception_dict['should_stop'] is False
assert isinstance(exception_dict['exception'], TypeError)
assert isinstance(exception_dict['exception'], ValidationError)
assert response_dict
assert not response_dict['error']['handled']
assert response_dict['error']['code'] == "TypeError"
assert response_dict['error']['code'] == "ValidationError"


async def test_patch_make_request():
Expand Down
35 changes: 35 additions & 0 deletions src/tribler/core/config/tests/test_tribler_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,38 @@ def test_invalid_config_recovers(tmpdir):
# should work without the reset flag.
config = TriblerConfig.load(file=default_config_file, state_dir=tmpdir)
assert not config.error


def test_update_from_dict(tmpdir):
""" Test that update_from_dict updates config with correct values"""

config = TriblerConfig(state_dir=tmpdir)
config.api.http_port = 1234

config.update_from_dict(
{
'api':
{
'key': 'key value'
}
}
)

assert config.api.http_port == 1234
assert config.api.key == 'key value'


def test_update_from_dict_wrong_key(tmpdir):
""" Test that update_from_dict raises ValueError when wrong key is passed"""
config = TriblerConfig(state_dir=tmpdir)
with pytest.raises(ValueError):
config.update_from_dict({'wrong key': 'any value'})


def test_validate_config(tmpdir):
""" Test that validate_config raises ValueError when config is invalid"""
config = TriblerConfig(state_dir=tmpdir)
config.general = 'invalid value'

with pytest.raises(ValueError):
config.validate_config()
26 changes: 24 additions & 2 deletions src/tribler/core/config/tribler_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import logging
import traceback
from pathlib import Path
from typing import Optional
from typing import Any, Dict, Optional, Union

import configobj
from configobj import ParseError
from pydantic import BaseSettings, Extra, PrivateAttr
from pydantic import BaseSettings, Extra, PrivateAttr, validate_model

from tribler.core.components.bandwidth_accounting.settings import BandwidthAccountingSettings
from tribler.core.components.gigachannel.community.settings import ChantSettings
Expand Down Expand Up @@ -152,6 +152,28 @@ def write(self, file: Path = None):
conf.filename = str(file)
conf.write()

def update_from_dict(self, config: Dict):
""" Update (patch) current config from dictionary"""

def update_recursively(settings: BaseSettings, attribute_name: str, attribute_value: Union[Any, Dict]):
""" Update setting recursively from dictionary"""
if isinstance(attribute_value, dict):
for k, v in attribute_value.items():
update_recursively(getattr(settings, attribute_name), k, v)
else:
setattr(settings, attribute_name, attribute_value)

for key, value in config.items():
update_recursively(self, key, value)

self.validate_config()

def validate_config(self):
""" Validate config and raise an exception in case of an error"""
*_, error = validate_model(self.__class__, self.__dict__)
if error:
raise error

@property
def error(self) -> Optional[str]:
return self._error
Expand Down

0 comments on commit 077ace6

Please sign in to comment.