Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
280 changes: 280 additions & 0 deletions aho_corasick.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,280 @@
"""
Aho-Corasick Multi-Pattern String Matching
===========================================

Builds a finite-state automaton (trie + failure links) from a set of patterns
and finds **all** occurrences of every pattern in a text in a single O(n + m + z)
pass, where:

n = length of the text
m = total characters across all patterns
z = total number of matches found

This beats the naive approach of running a single-pattern algorithm (KMP,
Z-function, Rabin-Karp) once per pattern, which costs O(n * k) for k patterns.

Typical real-world uses: network intrusion detection (e.g. Snort), the
original fgrep utility, antivirus signature scanning, spam filtering,
and DNA motif search.

References:
- Aho, A.V.; Corasick, M.J. (1975).
"Efficient string matching: an aid to bibliographic search."
Communications of the ACM, 18(6), 333-340.
https://doi.org/10.1145/360825.360855
- https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm
"""

from __future__ import annotations

from collections import deque


def _build_trie(
patterns: list[str],
) -> tuple[list[dict[str, int]], list[set[str]]]:
"""
Insert every non-empty, unique pattern into a trie.

Returns:
goto_table: goto_table[state][char] -> next_state
output_table: output_table[state] -> set of patterns ending at that state

>>> goto, out = _build_trie(["he", "she", "his", "hers"])
>>> goto[0]["h"] > 0
True
>>> "she" in out[_follow(goto, 0, "she")]
True
>>> goto, out = _build_trie([x])
>>> goto
[{}]
>>> out
[set()]
>>> goto, out = _build_trie(["a", "a", "", "a"])
>>> sum(len(row) for row in goto)
1
"""
goto_table: list[dict[str, int]] = [{}]
output_table: list[set[str]] = [set()]

def new_state() -> int:
goto_table.append({})
output_table.append(set())
return len(goto_table) - 1

seen: set[str] = set()
for pattern in patterns:
if not pattern or pattern in seen:
continue
seen.add(pattern)
state = 0
for char in pattern:
if char not in goto_table[state]:
goto_table[state][char] = new_state()
state = goto_table[state][char]
output_table[state].add(pattern)

return goto_table, output_table


def _follow(goto_table: list[dict[str, int]], state: int, text: str) -> int:
"""
Follow trie transitions from *state* consuming every character in *text*.

Raises KeyError if any character in *text* has no transition.
Used internally and in tests to locate a specific trie node.

>>> goto, _ = _build_trie(["abc"])
>>> _follow(goto, 0, "abc")
3
>>> _follow(goto, 0, "ab")
2
"""
for char in text:
state = goto_table[state][char]
return state


def _build_failure_links(
goto_table: list[dict[str, int]],
output_table: list[set[str]],
) -> list[int]:
"""
Compute failure (suffix) links for every trie state via BFS.

Also propagates the output sets along the failure chain so every
pattern that is a proper suffix of a longer match is also reported.

Returns:
fail: fail[state] -> longest proper suffix state reachable by
failure links (0 = root for depth-1 nodes).

>>> goto, out = _build_trie(["he", "she", "his", "hers"])
>>> fail = _build_failure_links(goto, out)
>>> fail[0]
0
>>> goto, out = _build_trie([])
>>> _build_failure_links(goto, out)
[0]
"""
num_states = len(goto_table)
fail: list[int] = [0] * num_states

queue: deque[int] = deque()

# Depth-1 nodes: failure link points to root.
for child in goto_table[0].values():
fail[child] = 0
queue.append(child)

while queue:
current = queue.popleft()
for char, child in goto_table[current].items():
# Walk up the failure chain of `current` until we find a state
# that has a transition on `char`, or we reach the root.
fail_state = fail[current]
while fail_state != 0 and char not in goto_table[fail_state]:
fail_state = fail[fail_state]

fail[child] = goto_table[fail_state].get(char, 0)

# Prevent self-loop at root.
if fail[child] == child:
fail[child] = 0

# Merge outputs from the suffix chain into this node.
output_table[child] |= output_table[fail[child]]
queue.append(child)

return fail


def build_automaton(
    patterns: list[str],
) -> tuple[list[dict[str, int]], list[set[str]], list[int]]:
    """
    Construct the complete Aho-Corasick automaton from *patterns*.

    Returns a three-tuple ``(goto_table, output_table, fail)`` that can be
    passed directly to :func:`search`. Build once; reuse across many texts.

    Time : O(m * |alphabet|) where m = total pattern characters
    Space : O(m * |alphabet|)

    >>> goto, out, fail = build_automaton(["foo", "bar"])
    >>> len(goto) > 1
    True
    >>> isinstance(fail, list)
    True
    >>> goto, out, fail = build_automaton([])
    >>> goto
    [{}]
    >>> out
    [set()]
    >>> fail
    [0]
    """
    goto_table, output_table = _build_trie(patterns)
    fail = _build_failure_links(goto_table, output_table)
    return goto_table, output_table, fail


def search(
    text: str,
    goto_table: list[dict[str, int]],
    output_table: list[set[str]],
    fail: list[int],
) -> dict[str, list[int]]:
    """
    Scan *text* once with a pre-built automaton and report every occurrence.

    Returns a dict mapping each matched pattern to a sorted list of
    0-based start indices in *text*.

    Time : O(n + z) where n = len(text), z = total matches

    >>> goto, out, fail = build_automaton(["he", "she", "his", "hers"])
    >>> result = search("ushers", goto, out, fail)
    >>> result["she"]
    [1]
    >>> result["he"]
    [2]
    >>> result["hers"]
    [2]
    >>> "his" in result
    False

    >>> goto, out, fail = build_automaton(["a", "aa", "aaa"])
    >>> sorted(search("aaaa", goto, out, fail).items())
    [('a', [0, 1, 2, 3]), ('aa', [0, 1, 2]), ('aaa', [0, 1])]

    >>> goto, out, fail = build_automaton(["ab", "b"])
    >>> sorted(search("abab", goto, out, fail).items())
    [('ab', [0, 2]), ('b', [1, 3])]

    >>> goto, out, fail = build_automaton(["xyz"])
    >>> search("hello world", goto, out, fail)
    {}

    >>> goto, out, fail = build_automaton(["test"])
    >>> search("", goto, out, fail)
    {}
    """
    found: dict[str, list[int]] = {}
    state = 0
    for position, symbol in enumerate(text):
        # Fall back along failure links until a transition exists (or root).
        while symbol not in goto_table[state] and state != 0:
            state = fail[state]
        state = goto_table[state].get(symbol, 0)

        # Every pattern ending at this state matched; record its start index.
        for hit in output_table[state]:
            found.setdefault(hit, []).append(position - len(hit) + 1)
    return found


def search_all(text: str, patterns: list[str]) -> dict[str, list[int]]:
    """
    One-shot convenience wrapper: build automaton from *patterns*, search *text*.

    Prefer :func:`build_automaton` + :func:`search` when querying the same
    pattern set across multiple texts (avoids rebuilding the automaton).

    Time : O(m * |alphabet| + n + z)
    Space : O(m * |alphabet|)

    >>> search_all("mississippi", ["is", "ss", "ippi"])
    {'is': [1, 4], 'ss': [2, 5], 'ippi': [7]}

    >>> search_all("aababcabcd", ["a", "ab", "abc", "abcd"])
    {'a': [0, 1, 3, 6], 'ab': [1, 3, 6], 'abc': [3, 6], 'abcd': [6]}

    >>> search_all("she sells sea shells", ["she", "sea", "shells"])
    {'she': [0, 14], 'sea': [10], 'shells': [14]}

    >>> search_all("hello", [])
    {}

    >>> search_all("", ["a", "b"])
    {}

    >>> search_all("aaaa", ["a"])
    {'a': [0, 1, 2, 3]}
    """
    automaton = build_automaton(patterns)
    return search(text, *automaton)


if __name__ == "__main__":
    import doctest

    doctest.testmod()

    # Small demonstration: print every match with its index and the
    # slice of the text it covers.
    sample_text = "she sells sea shells by the seashore"
    sample_patterns = ["she", "sells", "sea", "shells", "shore"]
    print(f"Text : {sample_text!r}")
    print(f"Patterns: {sample_patterns}")
    print("Matches :")
    for word, starts in sorted(search_all(sample_text, sample_patterns).items()):
        for start in starts:
            shown = sample_text[start : start + len(word)]
            print(f"  {word!r:10s} at index {start:2d} -> {shown!r}")