From d048b5bdb6e5a467002699263e45242280ab57ac Mon Sep 17 00:00:00 2001 From: signore662-beep Date: Sat, 14 Mar 2026 08:53:56 +0100 Subject: [PATCH] Create aho_codesick.py. . --- aho_codesick.py. | 280 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 280 insertions(+) create mode 100644 aho_codesick.py. diff --git a/aho_codesick.py. b/aho_codesick.py. new file mode 100644 index 000000000000..3c74c4c5469a --- /dev/null +++ b/aho_codesick.py. @@ -0,0 +1,280 @@ +""" +Aho-Corasick Multi-Pattern String Matching +=========================================== + +Builds a finite-state automaton (trie + failure links) from a set of patterns +and finds **all** occurrences of every pattern in a text in a single O(n + m + z) +pass, where: + + n = length of the text + m = total characters across all patterns + z = total number of matches found + +This beats the naive approach of running a single-pattern algorithm (KMP, +Z-function, Rabin-Karp) once per pattern, which costs O(n * k) for k patterns. + +Typical real-world uses: network intrusion detection (Snort/fgrep), +antivirus signature scanning, spam filtering, DNA motif search. + +References: + - Aho, A.V.; Corasick, M.J. (1975). + "Efficient string matching: an aid to bibliographic search." + Communications of the ACM, 18(6), 333-340. + https://doi.org/10.1145/360825.360855 + - https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm +""" + +from __future__ import annotations + +from collections import deque + + +def _build_trie( + patterns: list[str], +) -> tuple[list[dict[str, int]], list[set[str]]]: + """ + Insert every non-empty, unique pattern into a trie. + + Returns: + goto_table: goto_table[state][char] -> next_state + output_table: output_table[state] -> set of patterns ending at that state + + >>> goto, out = _build_trie(["he", "she", "his", "hers"]) + >>> goto[0]["h"] > 0 + True + >>> "she" in out[_follow(goto, 0, "she")] + True + >>> goto, out = _build_trie([x]) + >>> goto + [{}] + >>> out + [set()] + >>> goto, out = _build_trie(["a", "a", "", "a"]) + >>> sum(len(row) for row in goto) + 1 + """ + goto_table: list[dict[str, int]] = [{}] + output_table: list[set[str]] = [set()] + + def new_state() -> int: + goto_table.append({}) + output_table.append(set()) + return len(goto_table) - 1 + + seen: set[str] = set() + for pattern in patterns: + if not pattern or pattern in seen: + continue + seen.add(pattern) + state = 0 + for char in pattern: + if char not in goto_table[state]: + goto_table[state][char] = new_state() + state = goto_table[state][char] + output_table[state].add(pattern) + + return goto_table, output_table + + +def _follow(goto_table: list[dict[str, int]], state: int, text: str) -> int: + """ + Follow trie transitions from *state* consuming every character in *text*. + + Raises KeyError if any character in *text* has no transition. + Used internally and in tests to locate a specific trie node. + + >>> goto, _ = _build_trie(["abc"]) + >>> _follow(goto, 0, "abc") + 3 + >>> _follow(goto, 0, "ab") + 2 + """ + for char in text: + state = goto_table[state][char] + return state + + +def _build_failure_links( + goto_table: list[dict[str, int]], + output_table: list[set[str]], +) -> list[int]: + """ + Compute failure (suffix) links for every trie state via BFS. + + Also propagates the output sets along the failure chain so every + pattern that is a proper suffix of a longer match is also reported. + + Returns: + fail: fail[state] -> longest proper suffix state reachable by + failure links (0 = root for depth-1 nodes). + + >>> goto, out = _build_trie(["he", "she", "his", "hers"]) + >>> fail = _build_failure_links(goto, out) + >>> fail[0] + 0 + >>> goto, out = _build_trie([]) + >>> _build_failure_links(goto, out) + [0] + """ + num_states = len(goto_table) + fail: list[int] = [0] * num_states + + queue: deque[int] = deque() + + # Depth-1 nodes: failure link points to root. + for child in goto_table[0].values(): + fail[child] = 0 + queue.append(child) + + while queue: + current = queue.popleft() + for char, child in goto_table[current].items(): + # Walk up the failure chain of `current` until we find a state + # that has a transition on `char`, or we reach the root. + fail_state = fail[current] + while fail_state != 0 and char not in goto_table[fail_state]: + fail_state = fail[fail_state] + + fail[child] = goto_table[fail_state].get(char, 0) + + # Prevent self-loop at root. + if fail[child] == child: + fail[child] = 0 + + # Merge outputs from the suffix chain into this node. + output_table[child] |= output_table[fail[child]] + queue.append(child) + + return fail + + +def build_automaton( + patterns: list[str], +) -> tuple[list[dict[str, int]], list[set[str]], list[int]]: + """ + Construct the complete Aho-Corasick automaton from *patterns*. + + Returns a three-tuple ``(goto_table, output_table, fail)`` that can be + passed directly to :func:`search`. Build once; reuse across many texts. + + Time : O(m * |alphabet|) where m = total pattern characters + Space : O(m * |alphabet|) + + >>> goto, out, fail = build_automaton(["foo", "bar"]) + >>> len(goto) > 1 + True + >>> isinstance(fail, list) + True + >>> goto, out, fail = build_automaton([x]) + >>> goto + [{}] + >>> out + [set()] + >>> fail + [0] + """ + goto_table, output_table = _build_trie(patterns) + fail = _build_failure_links(goto_table, output_table) + return goto_table, output_table, fail + + +def search( + text: str, + goto_table: list[dict[str, int]], + output_table: list[set[str]], + fail: list[int], +) -> dict[str, list[int]]: + """ + Find all pattern occurrences in *text* using a pre-built automaton. + + Returns a dict mapping each matched pattern to a sorted list of + 0-based start indices in *text*. + + Time : O(n + z) where n = len(text), z = total matches + + >>> goto, out, fail = build_automaton(["he", "she", "his", "hers"]) + >>> result = search("ushers", goto, out, fail) + >>> result["she"] + [1] + >>> result["he"] + [2] + >>> result["hers"] + [2] + >>> "his" in result + False + + >>> goto, out, fail = build_automaton(["a", "aa", "aaa"]) + >>> sorted(search("aaaa", goto, out, fail).items()) + [('a', [0, 1, 2, 3]), ('aa', [0, 1, 2]), ('aaa', [0, 1])] + + >>> goto, out, fail = build_automaton(["ab", "b"]) + >>> sorted(search("abab", goto, out, fail).items()) + [('ab', [0, 2]), ('b', [1, 3])] + + >>> goto, out, fail = build_automaton(["xyz"]) + >>> search("hello world", goto, out, fail) + {} + + >>> goto, out, fail = build_automaton(["test"]) + >>> search("", goto, out, fail) + {} + """ + matches: dict[str, list[int]] = {} + state = 0 + for index, char in enumerate(text): + # Follow failure links until a valid transition or root is reached. + while state != 0 and char not in goto_table[state]: + state = fail[state] + state = goto_table[state].get(char, 0) + for pattern in output_table[state]: + start = index - len(pattern) + 1 + matches.setdefault(pattern, []).append(start) + return matches + + +def search_all(text: str, patterns: list[str]) -> dict[str, list[int]]: + """ + One-shot convenience wrapper: build automaton from *patterns*, search *text*. + + Prefer :func:`build_automaton` + :func:`search` when querying the same + pattern set across multiple texts (avoids rebuilding the automaton). + + Time : O(m * |alphabet| + n + z) + Space : O(m * |alphabet|) + + >>> search_all("mississippi", ["is", "ss", "ippi"]) + {'is': [1, 4], 'ss': [2, 5], 'ippi': [7]} + + >>> search_all("aababcabcd", ["a", "ab", "abc", "abcd"]) + {'a': [0, 1, 3, 6], 'ab': [1, 3, 6], 'abc': [3, 6], 'abcd': [6]} + + >>> search_all("she sells sea shells", ["she", "sea", "shells"]) + {'she': [0, 14], 'sea': [10], 'shells': [14]} + + >>> search_all("hello", []) + {} + + >>> search_all("", ["a", "b"]) + {} + + >>> search_all("aaaa", ["a"]) + {'a': [0, 1, 2, 3]} + """ + goto_table, output_table, fail = build_automaton(patterns) + return search(text, goto_table, output_table, fail) + + +if __name__ == "__main__": + import doctest + + doctest.testmod() + + demo_text = "she sells sea shells by the seashore" + demo_patterns = ["she", "sells", "sea", "shells", "shore"] + print(f"Text : {demo_text!r}") + print(f"Patterns: {demo_patterns}") + print("Matches :") + for pattern, positions in sorted(search_all(demo_text, demo_patterns).items()): + for pos in positions: + snippet = demo_text[pos : pos + len(pattern)] + print(f" {pattern!r:10s} at index {pos:2d} -> {snippet!r}")