Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow mocking preferences #217

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions app/airq/controllers/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from flask import request

from airq import commands
from airq.commands.base import MessageResponse
from airq.config import csrf
from airq.lib.client_preferences import ClientPreferencesRegistry, InvalidPrefValue
from airq.models.clients import ClientIdentifierType


Expand Down Expand Up @@ -34,12 +36,29 @@ def sms_reply(locale: str) -> str:
def test_command(locale: str) -> str:
supported_locale = _get_supported_locale(locale)
g.locale = supported_locale
command = request.args.get("command", "").strip()

if request.headers.getlist("X-Forwarded-For"):
ip = request.headers.getlist("X-Forwarded-For")[0]
else:
ip = request.remote_addr
response = commands.handle_command(
command, ip, ClientIdentifierType.IP, supported_locale
)

args = request.args.copy()
command = args.pop("command", "").strip()
overrides = {}
for k, v in args.items():
pref = ClientPreferencesRegistry.get_by_name(k)
if pref:
try:
overrides[pref] = pref.validate(v)
except InvalidPrefValue as e:
msg = str(e)
if not msg:
msg = '{}: Invalid value "{}"'.format(pref.name, v)
return MessageResponse().write(msg).as_html()

with ClientPreferencesRegistry.register_overrides(overrides):
response = commands.handle_command(
command, ip, ClientIdentifierType.IP, supported_locale
)

return response.as_html()
7 changes: 5 additions & 2 deletions app/airq/lib/choices.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@ def display(self) -> str:
...

@classmethod
def from_value(cls: typing.Type[T], value: typing.Any) -> T:
return cls(value)
def from_value(cls: typing.Type[T], value: typing.Any) -> typing.Optional[T]:
for m in cls:
if m.value == value:
return m
return None


class IntChoicesEnum(int, ChoicesEnum):
Expand Down
113 changes: 82 additions & 31 deletions app/airq/lib/client_preferences.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import abc
import collections
import contextlib
import typing

from flask import g
from flask import has_app_context
from flask_babel import gettext
from sqlalchemy.orm.attributes import flag_modified

Expand All @@ -19,6 +22,7 @@ class InvalidPrefValue(Exception):
"""This pref value is invalid."""


TClientPreference = typing.TypeVar("TClientPreference", bound="ClientPreference")
TPreferenceValue = typing.TypeVar(
"TPreferenceValue", bound=typing.Union[int, str, ChoicesEnum]
)
Expand All @@ -41,16 +45,36 @@ def __init__(
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.name}, {self.display_name}, {self.description}, {self.default})"

@typing.overload
def __get__(
self, instance: "Client", owner: typing.Type["Client"]
self: TClientPreference, instance: "Client", owner: typing.Type["Client"]
) -> TPreferenceValue:
if instance is not None:
preferences = instance.preferences or {}
value = preferences.get(self.name)
if value is not None:
return self._cast(value)
...

@typing.overload
def __get__(
self: TClientPreference, instance: None, owner: typing.Type["Client"]
) -> TClientPreference:
...

def __get__(
self: TClientPreference,
instance: typing.Optional["Client"],
owner: typing.Type["Client"],
) -> typing.Union[TPreferenceValue, TClientPreference]:
if instance is None:
return self

# Check for override. This is used for QA.
override = ClientPreferencesRegistry.get_override(self.name)
if override is not None:
return override

preferences = instance.preferences or {}
value = preferences.get(self.name)
if value is None:
return self.default
return self
return self.validate(value)

def __set__(self, client: "Client", value: TPreferenceValue):
self._set(client, value)
Expand All @@ -72,7 +96,7 @@ def set_from_user_input(
return value

def _set(self, client: "Client", value: TPreferenceValue):
self._validate(value)
value = self.validate(value)
if client.preferences is None:
client.preferences = {}
client.preferences[self.name] = value # type: ignore
Expand All @@ -86,10 +110,6 @@ def _set(self, client: "Client", value: TPreferenceValue):
def __set_name__(self, owner: typing.Type["Client"], name: str) -> None:
ClientPreferencesRegistry.register_pref(name, self)

@abc.abstractmethod
def _cast(self, value: typing.Any) -> TPreferenceValue:
pass

@property
def name(self) -> str:
return ClientPreferencesRegistry.get_name(self)
Expand All @@ -99,7 +119,7 @@ def clean(self, value: str) -> typing.Optional[TPreferenceValue]:
"""Coerce user input to a valid value for this pref, or throw an error."""

@abc.abstractmethod
def _validate(self, value: TPreferenceValue):
def validate(self, value: typing.Any) -> TPreferenceValue:
"""Ensure that the raw value is valid for this pref."""

@abc.abstractmethod
Expand All @@ -125,9 +145,6 @@ def __init__(
def _get_choices(self) -> typing.List[TChoicesEnum]:
return list(self._choices)

def _cast(self, value: typing.Any) -> TChoicesEnum:
return self._choices.from_value(value)

def format_value(self, value: TChoicesEnum) -> str:
return value.display

Expand All @@ -141,8 +158,11 @@ def clean(self, user_input: str) -> typing.Optional[TChoicesEnum]:
except (IndexError, TypeError, ValueError):
return None

def _validate(self, _value: TChoicesEnum):
pass # Valid by definition
def validate(self, value: typing.Any) -> TChoicesEnum:
value = self._choices.from_value(value)
if value is None:
raise InvalidPrefValue()
return value

def get_prompt(self) -> str:
prompt = [gettext("Select one of")]
Expand Down Expand Up @@ -180,23 +200,22 @@ def __init__(
def format_value(self, value: int) -> str:
return str(value)

def _cast(self, value: typing.Any) -> int:
assert isinstance(value, int)
return value

def clean(self, user_input: str) -> typing.Optional[int]:
try:
value = int(user_input)
self._validate(value)
except (TypeError, ValueError, InvalidPrefValue):
return self.validate(user_input)
except InvalidPrefValue:
return None
return value

def _validate(self, value: int):
def validate(self, value: typing.Any) -> int:
try:
value = int(value)
except (TypeError, ValueError):
raise InvalidPrefValue()
if self._min_value is not None and value < self._min_value:
raise InvalidPrefValue()
if self._max_value is not None and value > self._max_value:
raise InvalidPrefValue()
return value

def get_prompt(self) -> str:
if self._min_value is not None and self._max_value is not None:
Expand All @@ -220,35 +239,67 @@ def get_prompt(self) -> str:

class ClientPreferencesRegistry:
_prefs: typing.MutableMapping[str, ClientPreference] = collections.OrderedDict()
_overrides: typing.Dict[str, typing.Any] = {}

@classmethod
def register_pref(cls, name: str, pref: ClientPreference) -> None:
"""Register a client pref."""
assert name is not None, "Name unexpectedly None"
if name in cls._prefs:
raise RuntimeError("Can't double-register pref {}".format(pref.name))
cls._prefs[name] = pref

@classmethod
def _get_overrides(cls) -> typing.Dict[str, typing.Any]:
"""Get the overrides in a thread-safe manner."""
if has_app_context():
if not "_pref_overrides" in g:
g._pref_overrides = {}
return g._pref_overrides
else:
return cls._overrides

@classmethod
@contextlib.contextmanager
def register_overrides(
cls,
overrides: typing.Mapping[ClientPreference[TPreferenceValue], TPreferenceValue],
):
"""Override preference values for the duration of the current request."""
current_overrides = cls._get_overrides()
for pref, value in overrides.items():
current_overrides[pref.name] = value
try:
yield
finally:
current_overrides.clear()

@classmethod
def get_override(cls, name: str) -> typing.Any:
"""Get the overriden value for a pref, if any."""
return cls._get_overrides().get(name)

@classmethod
def get_name(cls, pref: ClientPreference) -> str:
"""Get the name of a registered preference."""
for name, p in cls._prefs.items():
if p is pref:
return name
raise RuntimeError("%s is not registered", pref)

@classmethod
def get_by_name(cls, name: str) -> ClientPreference:
"""Get the preference by the given name."""
return cls._prefs[name]

@classmethod
def get_default(cls, name: str) -> typing.Union[str, int]:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unused 💀

return cls.get_by_name(name).default

@classmethod
def iter_with_index(cls) -> typing.Iterator[typing.Tuple[int, ClientPreference]]:
"""Enumerate all registered preferences along with their index."""
return enumerate(cls._prefs.values(), start=1)

@classmethod
def get_by_index(cls, index: int) -> typing.Optional[ClientPreference]:
"""Get a preference by its index."""
for i, pref in cls.iter_with_index():
if i == index:
return pref
Expand Down
4 changes: 2 additions & 2 deletions app/tests/test_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def test_maybe_notify(self):

last_pm25 = zipcode.pm25
client = self._make_client(last_pm25=last_pm25)
client.alert_threshold = Pm25.GOOD.value
client.alert_threshold = Pm25.GOOD
self.db.session.commit()

self.assertFalse(client.maybe_notify())
Expand Down Expand Up @@ -146,7 +146,7 @@ def test_maybe_notify(self):

def test_maybe_notify_with_alerting_threshold_set(self):
client = self._make_client()
client.alert_threshold = Pm25.MODERATE.value
client.alert_threshold = Pm25.MODERATE
self.db.session.commit()
zipcode = client.zipcode

Expand Down