Skip to content

Commit

Permalink
[V3 Config] Require custom group initialization before usage (#2545)
Browse files Browse the repository at this point in the history
* Require custom group initialization before usage and write that data to disk

* Style

* add tests

* remove custom info update method from drivers

* clean up remnant

* Turn config objects into a singleton to deal with custom group identifiers

* Fix dumbassery

* Stupid stupid stupid
  • Loading branch information
tekulvw committed Apr 5, 2019
1 parent fb722c7 commit 0852d1b
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 5 deletions.
2 changes: 2 additions & 0 deletions redbot/cogs/permissions/permissions.py
Expand Up @@ -100,7 +100,9 @@ def __init__(self, bot: Red):
# Note that GLOBAL rules are denoted by an ID of 0.
self.config = config.Config.get_conf(self, identifier=78631113035100160)
self.config.register_global(version="")
self.config.init_custom(COG, 1)
self.config.register_custom(COG)
self.config.init_custom(COMMAND, 1)
self.config.register_custom(COMMAND)

@commands.group()
Expand Down
1 change: 1 addition & 0 deletions redbot/cogs/reports/reports.py
Expand Up @@ -45,6 +45,7 @@ def __init__(self, bot: Red):
self.bot = bot
self.config = Config.get_conf(self, 78631113035100160, force_registration=True)
self.config.register_guild(**self.default_guild_settings)
self.config.init_custom("REPORT", 2)
self.config.register_custom("REPORT", **self.default_report)
self.antispam = {}
self.user_cache = []
Expand Down
43 changes: 42 additions & 1 deletion redbot/core/config.py
Expand Up @@ -2,6 +2,7 @@
import collections
from copy import deepcopy
from typing import Any, Union, Tuple, Dict, Awaitable, AsyncContextManager, TypeVar, TYPE_CHECKING
import weakref

import discord

Expand All @@ -15,6 +16,8 @@

_T = TypeVar("_T")

_config_cache = weakref.WeakValueDictionary()


class _ValueCtxManager(Awaitable[_T], AsyncContextManager[_T]):
"""Context manager implementation of config values.
Expand Down Expand Up @@ -514,6 +517,19 @@ class Config:
USER = "USER"
MEMBER = "MEMBER"

def __new__(cls, cog_name, unique_identifier, *args, **kwargs):
key = (cog_name, unique_identifier)

if key[0] is None:
raise ValueError("You must provide either the cog instance or a cog name.")

if key in _config_cache:
conf = _config_cache[key]
else:
conf = object.__new__(cls)
_config_cache[key] = conf
return conf

def __init__(
self,
cog_name: str,
Expand All @@ -529,6 +545,8 @@ def __init__(
self.force_registration = force_registration
self._defaults = defaults or {}

self.custom_groups = {}

@property
def defaults(self):
return deepcopy(self._defaults)
Expand Down Expand Up @@ -788,13 +806,32 @@ def register_custom(self, group_identifier: str, **kwargs):
"""
self._register_default(group_identifier, **kwargs)

def init_custom(self, group_identifier: str, identifier_count: int):
"""
Initializes a custom group for usage. This method must be called first!
"""
if group_identifier in self.custom_groups:
raise ValueError(f"Group identifier already registered: {group_identifier}")

self.custom_groups[group_identifier] = identifier_count

def _get_base_group(self, category: str, *primary_keys: str) -> Group:
is_custom = category not in (
self.GLOBAL,
self.GUILD,
self.USER,
self.MEMBER,
self.ROLE,
self.CHANNEL,
)
# noinspection PyTypeChecker
identifier_data = IdentifierData(
uuid=self.unique_identifier,
category=category,
primary_key=primary_keys,
identifiers=(),
custom_group_data=self.custom_groups,
is_custom=is_custom,
)
return Group(
identifier_data=identifier_data,
Expand Down Expand Up @@ -902,6 +939,8 @@ def custom(self, group_identifier: str, *identifiers: str):
The custom group's Group object.
"""
if group_identifier not in self.custom_groups:
raise ValueError(f"Group identifier not initialized: {group_identifier}")
return self._get_base_group(str(group_identifier), *map(str, identifiers))

async def _all_from_scope(self, scope: str) -> Dict[int, Dict[Any, Any]]:
Expand Down Expand Up @@ -1072,7 +1111,9 @@ async def _clear_scope(self, *scopes: str):
"""
if not scopes:
# noinspection PyTypeChecker
identifier_data = IdentifierData(self.unique_identifier, "", (), ())
identifier_data = IdentifierData(
self.unique_identifier, "", (), (), self.custom_groups
)
group = Group(identifier_data, defaults={}, driver=self.driver)
else:
group = self._get_base_group(*scopes)
Expand Down
23 changes: 21 additions & 2 deletions redbot/core/drivers/red_base.py
Expand Up @@ -4,11 +4,21 @@


class IdentifierData:
def __init__(self, uuid: str, category: str, primary_key: Tuple[str], identifiers: Tuple[str]):
def __init__(
self,
uuid: str,
category: str,
primary_key: Tuple[str],
identifiers: Tuple[str],
custom_group_data: dict,
is_custom: bool = False,
):
self._uuid = uuid
self._category = category
self._primary_key = primary_key
self._identifiers = identifiers
self.custom_group_data = custom_group_data
self._is_custom = is_custom

@property
def uuid(self):
Expand All @@ -26,6 +36,10 @@ def primary_key(self):
def identifiers(self):
return self._identifiers

@property
def is_custom(self):
return self._is_custom

def __repr__(self):
return (
f"<IdentifierData uuid={self.uuid} category={self.category} primary_key={self.primary_key}"
Expand All @@ -37,7 +51,12 @@ def add_identifier(self, *identifier: str) -> "IdentifierData":
raise ValueError("Identifiers must be strings.")

return IdentifierData(
self.uuid, self.category, self.primary_key, self.identifiers + identifier
self.uuid,
self.category,
self.primary_key,
self.identifiers + identifier,
self.custom_group_data,
is_custom=self.is_custom,
)

def to_tuple(self):
Expand Down
6 changes: 4 additions & 2 deletions redbot/pytest/core.py
@@ -1,11 +1,13 @@
import random
from collections import namedtuple
from pathlib import Path
import weakref

import pytest
from _pytest.monkeypatch import MonkeyPatch
from redbot.core import Config
from redbot.core.bot import Red
from redbot.core import config as config_module

from redbot.core.drivers import red_json

Expand Down Expand Up @@ -65,26 +67,26 @@ def json_driver(tmpdir_factory):

@pytest.fixture()
def config(json_driver):
config_module._config_cache = weakref.WeakValueDictionary()
conf = Config(
cog_name="PyTest", unique_identifier=json_driver.unique_cog_identifier, driver=json_driver
)
yield conf
conf._defaults = {}


@pytest.fixture()
def config_fr(json_driver):
"""
Mocked config object with force_register enabled.
"""
config_module._config_cache = weakref.WeakValueDictionary()
conf = Config(
cog_name="PyTest",
unique_identifier=json_driver.unique_cog_identifier,
driver=json_driver,
force_registration=True,
)
yield conf
conf._defaults = {}


# region Dpy Mocks
Expand Down
16 changes: 16 additions & 0 deletions tests/core/test_config.py
Expand Up @@ -490,3 +490,19 @@ async def test_cast_str_nested(config):
config.register_global(foo={})
await config.foo.set({123: True, 456: {789: False}})
assert await config.foo() == {"123": True, "456": {"789": False}}


def test_config_custom_noinit(config):
with pytest.raises(ValueError):
config.custom("TEST", 1, 2, 3)


def test_config_custom_init(config):
config.init_custom("TEST", 3)
config.custom("TEST", 1, 2, 3)


def test_config_custom_doubleinit(config):
config.init_custom("TEST", 3)
with pytest.raises(ValueError):
config.init_custom("TEST", 2)

0 comments on commit 0852d1b

Please sign in to comment.