Skip to content

Commit

Permalink
filter: add types and improve unit testing coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
lowell80 committed Oct 20, 2023
1 parent acaf77d commit 5847e95
Showing 1 changed file with 57 additions and 34 deletions.
91 changes: 57 additions & 34 deletions ksconf/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import sys
from collections import Counter
from pathlib import Path
from typing import Callable, Dict, Optional, Sequence, Type, Union

from ksconf.conf.parser import GLOBAL_STANZA
from ksconf.util.file import splglob_to_regex
Expand All @@ -27,16 +28,16 @@ class FilteredList:
INVERT = 2
VERBOSE = 4

def __init__(self, flags: int = 0, default: bool = True):
    """
    Set up an empty filter list.

    :param flags: Bitmask built from the class-level flag constants
        (for example ``INVERT`` or ``VERBOSE``).
    :param default: Result to report while no patterns have been fed.
        (True => match everything)
    """
    # Raw patterns, in the order they were fed
    self.data = []
    # Matcher-specific form of ``data``; filled in by ``_pre_match()``
    self.rules = None
    # Per-rule hit counts
    self.counter: Counter = Counter()
    self.flags: int = flags
    # Nothing has been fed yet, so there is no prep-work outstanding
    self._prep = True
    # If no patterns defined, return default. (True => match everything)
    self.default = default

def _feed_from_file(self, path):
def _feed_from_file(self, path: Union[str, Path]):
items = []
with open(path) as f:
for line in f:
Expand All @@ -48,39 +49,61 @@ def _feed_from_file(self, path):
sys.stderr.write(f"Loaded {len(items)} patterns from {path}\n")
return items

def feed(self, item: str, filter: Optional[Callable[[str], str]] = None):
    """
    Feed a new rule into the list.

    Use :py:obj:`filter` to enable calling a transformation function on :py:obj:`item` at
    the last minute.
    This allows item to contain ``file://...`` entries to include a list from the filesystem and
    yet still manipulate item in some pre-defined way.

    :param item: A single pattern, or ``file://<path>`` to load patterns from ``<path>``.
        :py:class:`~pathlib.Path` objects are accepted and converted to plain strings.
    :param filter: Optional callable applied to each individual pattern just before it is
        stored.  It is never applied to a ``file://`` reference itself.
    """
    if isinstance(item, Path):
        item = os.fspath(item)

    file_prefix = "file://"
    if item.startswith(file_prefix):
        # File ingestion mode
        filename = item[len(file_prefix):]
        # Technically we allow 'file://' inside a file.  No recursion depth limits in place,
        # so a file that (directly or indirectly) includes itself will overflow the stack.
        self.feedall(self._feed_from_file(filename), filter)
    else:
        if filter:
            item = filter(item)
        self.data.append(item)
        # New items added. Mark prep-work as incomplete
        self._prep = False

def feedall(self, iterable: Sequence[str], filter: Optional[Callable[[str], str]] = None):
    """
    Feed every entry of :py:obj:`iterable` through :py:meth:`feed`.

    :param iterable: Patterns to add.  A falsy value (``None``, empty sequence) adds nothing.
    :param filter: Passed along unchanged to :py:meth:`feed` for each entry.
    :return: ``self``, so that calls can be chained.
    """
    if iterable:
        for entry in iterable:
            self.feed(entry, filter)
    return self

def prep(self):
    """
    Prepare for matching activities.

    Called automatically by :py:meth:`match`, but it could be helpful to call directly to ensure
    there are no user input errors (which is accomplished by calling :py:meth:`_pre_match`).
    Does nothing when no new rules were fed since the previous call.
    """
    # Kick off any first-time preparatory activities
    if self._prep is False:
        self._pre_match()
        # Rebuild the counter only after _pre_match() has run
        self.counter = self.init_counter()
        self._prep = True

def _pre_match(self):  # pragma: no cover
    """ Hook run by :py:meth:`prep` before matching; the base version does nothing. """
    pass

def match(self, item):
def match(self, item: str) -> bool:
""" See if given item matches any of the given patterns. If no patterns were provided,
:py:obj:`default` will be returned.
"""
if self.data:
# Kick off any first-time preparatory activities
if self._prep is False:
self._pre_match()
self.reset_counters()
self._prep = True

self.prep()
ret = self._match(item)
if ret:
self.counter[ret] += 1
Expand All @@ -96,31 +119,32 @@ def match(self, item):
else:
return result

def match_path(self, path) -> bool:
    """ Same as :py:meth:`match` except with special handling of path normalization.
    Patterns must be given with unix-style paths.

    :param path: Path to test, as a string or :py:class:`~pathlib.Path`.
    """
    if isinstance(path, Path):
        path = os.fspath(path)
    # Patterns use "/"; convert the platform separator when it differs
    if os.path.sep != "/":
        path = path.replace(os.path.sep, "/")
    return self.match(path)

def match_stanza(self, stanza) -> bool:
    """ Same as :py:meth:`match`, but handle GLOBAL_STANZA gracefully.

    The ``GLOBAL_STANZA`` marker is replaced by the literal name ``"default"`` before matching.
    """
    if stanza is GLOBAL_STANZA:
        stanza = "default"
    return self.match(stanza)

def init_counter(self) -> Counter:
    """ Build a fresh hit counter holding one zeroed entry per fed pattern. """
    # Set all the counters to 0, so the caller can know which filters had 0 hits
    return Counter({n: 0 for n in self.data})

@property
def has_rules(self) -> bool:
    """ True when at least one pattern has been fed. """
    return bool(self.data)

def _match(self, item) -> Union[str, bool]:
    """ Return name of rule, indicating a match or not.

    Must be overridden; the base version always raises :py:exc:`NotImplementedError`.
    """
    raise NotImplementedError  # pragma: no cover


class FilteredListString(FilteredList):
Expand All @@ -134,17 +158,17 @@ def _pre_match(self):
self.rules = set(self.data)
return self.rules

def _match(self, item: str) -> Union[str, bool]:
    """ Exact-string lookup of :py:obj:`item` in the prepared rule set.

    :return: The item itself (lower-cased when the ``IGNORECASE`` flag is set) if it is a
        known rule, otherwise ``False``.
    """
    if self.flags & self.IGNORECASE:
        item = item.lower()
    if item in self.rules:
        return item
    return False

def init_counter(self) -> Counter:
    """ Build a fresh hit counter holding one zeroed entry per prepared rule. """
    # Set all the counters to 0, so the caller can know which filters had 0 hits
    return Counter({n: 0 for n in self.rules})


class FilteredListRegex(FilteredList):
Expand All @@ -162,15 +186,14 @@ def _pre_match(self):
# XXX: Add better error handling here for friendlier user feedback
self.rules = [(pattern, re.compile(pattern, re_flags)) for pattern in self.data]

def _match(self, item: str) -> Union[str, bool]:
    """ Test :py:obj:`item` against each compiled rule, in order.

    :return: The name (original pattern text) of the first rule whose regex matches at the
        start of :py:obj:`item`, or ``False`` when none match.
    """
    for name, pattern_re in self.rules:
        if pattern_re.match(item):
            return name
    return False

def init_counter(self) -> Counter:
    """ Build a fresh hit counter holding one zeroed entry per fed pattern. """
    # Keyed by the raw pattern text, which is also the rule name reported on a match
    return Counter({n: 0 for n in self.data})


class FilteredListWildcard(FilteredListRegex):
Expand All @@ -195,17 +218,17 @@ def _pre_match(self):
self.rules = [(wc, splglob_to_regex(wc, re_flags)) for wc in self.data]


# Maps each supported match-mode name to the FilteredList subclass implementing it
class_mapping: Dict[str, Type[FilteredList]] = {
    "string": FilteredListString,
    "wildcard": FilteredListWildcard,
    "regex": FilteredListRegex,
    "splunk": FilteredListSplunkGlob,
}


def create_filtered_list(match_mode: str, flags: int = 0, default=True) -> FilteredList:
    """
    Instantiate the filter class registered for :py:obj:`match_mode`.

    :param match_mode: A key of ``class_mapping``.
    :param flags: Passed through as the first constructor argument.
    :param default: Passed through as the second constructor argument.
    :raises NotImplementedError: If :py:obj:`match_mode` is not a known mode.
    """
    try:
        class_ = class_mapping[match_mode]
    except KeyError:  # pragma: no cover
        # The KeyError adds nothing beyond the message below, so suppress the chained context
        raise NotImplementedError(f"Matching mode {match_mode!r} undefined") from None
    return class_(flags, default)

0 comments on commit 5847e95

Please sign in to comment.