Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 93 additions & 107 deletions src/string/rabin_karp.rs
Original file line number Diff line number Diff line change
@@ -1,137 +1,123 @@
const MODULUS: u16 = 101;
const BASE: u16 = 256;

pub fn rabin_karp(target: &str, pattern: &str) -> Vec<usize> {
// Quick exit
if target.is_empty() || pattern.is_empty() || pattern.len() > target.len() {
//! This module implements the Rabin-Karp string searching algorithm.
//! It uses a rolling hash technique to find all occurrences of a pattern
//! within a target string efficiently.

const MOD: usize = 101;
const RADIX: usize = 256;

/// Finds all starting indices where the `pattern` appears in the `text`.
///
/// # Arguments
/// * `text` - The string where the search is performed.
/// * `pattern` - The substring pattern to search for.
///
/// # Returns
/// A vector of starting indices where the pattern is found.
pub fn rabin_karp(text: &str, pattern: &str) -> Vec<usize> {
if text.is_empty() || pattern.is_empty() || pattern.len() > text.len() {
return vec![];
}

let pattern_hash = hash(pattern);
let pat_hash = compute_hash(pattern);
let mut radix_pow = 1;

// Pre-calculate BASE^(n-1)
let mut pow_rem: u16 = 1;
// Compute RADIX^(n-1) % MOD
for _ in 0..pattern.len() - 1 {
pow_rem *= BASE;
pow_rem %= MODULUS;
radix_pow = (radix_pow * RADIX) % MOD;
}

let mut rolling_hash = 0;
let mut ret = vec![];
for i in 0..=target.len() - pattern.len() {
let mut result = vec![];
for i in 0..=text.len() - pattern.len() {
rolling_hash = if i == 0 {
hash(&target[0..pattern.len()])
compute_hash(&text[0..pattern.len()])
} else {
recalculate_hash(target, i - 1, i + pattern.len() - 1, rolling_hash, pow_rem)
update_hash(text, i - 1, i + pattern.len() - 1, rolling_hash, radix_pow)
};
if rolling_hash == pattern_hash && pattern[..] == target[i..i + pattern.len()] {
ret.push(i);
if rolling_hash == pat_hash && pattern[..] == text[i..i + pattern.len()] {
result.push(i);
}
}
ret
result
}

// hash(s) is defined as BASE^(n-1) * s_0 + BASE^(n-2) * s_1 + ... + BASE^0 * s_(n-1)
fn hash(s: &str) -> u16 {
let mut res: u16 = 0;
for &c in s.as_bytes().iter() {
res = (res * BASE % MODULUS + c as u16) % MODULUS;
}
res
/// Calculates the hash of a string using the Rabin-Karp formula.
///
/// # Arguments
/// * `s` - The string to calculate the hash for.
///
/// # Returns
/// The hash value of the string modulo `MOD`.
fn compute_hash(s: &str) -> usize {
let mut hash_val = 0;
for &byte in s.as_bytes().iter() {
hash_val = (hash_val * RADIX + byte as usize) % MOD;
}
hash_val
}

// new_hash = (old_hash - BASE^(n-1) * s_(i-n)) * BASE + s_i
fn recalculate_hash(
/// Updates the rolling hash when shifting the search window.
///
/// # Arguments
/// * `s` - The full text where the search is performed.
/// * `old_idx` - The index of the character that is leaving the window.
/// * `new_idx` - The index of the new character entering the window.
/// * `old_hash` - The hash of the previous substring.
/// * `radix_pow` - The precomputed value of RADIX^(n-1) % MOD.
///
/// # Returns
/// The updated hash for the new substring.
fn update_hash(
s: &str,
old_index: usize,
new_index: usize,
old_hash: u16,
pow_rem: u16,
) -> u16 {
old_idx: usize,
new_idx: usize,
old_hash: usize,
radix_pow: usize,
) -> usize {
let mut new_hash = old_hash;
let (old_ch, new_ch) = (
s.as_bytes()[old_index] as u16,
s.as_bytes()[new_index] as u16,
);
new_hash = (new_hash + MODULUS - pow_rem * old_ch % MODULUS) % MODULUS;
new_hash = (new_hash * BASE + new_ch) % MODULUS;
let old_char = s.as_bytes()[old_idx] as usize;
let new_char = s.as_bytes()[new_idx] as usize;
new_hash = (new_hash + MOD - (old_char * radix_pow % MOD)) % MOD;
new_hash = (new_hash * RADIX + new_char) % MOD;
new_hash
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn hi_hash() {
let hash_result = hash("hi");
assert_eq!(hash_result, 65);
}

#[test]
fn abr_hash() {
let hash_result = hash("abr");
assert_eq!(hash_result, 4);
}

#[test]
fn bra_hash() {
let hash_result = hash("bra");
assert_eq!(hash_result, 30);
}

// Attribution to @pgimalac for his tests from Knuth-Morris-Pratt
#[test]
fn each_letter_matches() {
let index = rabin_karp("aaa", "a");
assert_eq!(index, vec![0, 1, 2]);
}

#[test]
fn a_few_separate_matches() {
let index = rabin_karp("abababa", "ab");
assert_eq!(index, vec![0, 2, 4]);
}

#[test]
fn one_match() {
let index = rabin_karp("ABC ABCDAB ABCDABCDABDE", "ABCDABD");
assert_eq!(index, vec![15]);
}

#[test]
fn lots_of_matches() {
let index = rabin_karp("aaabaabaaaaa", "aa");
assert_eq!(index, vec![0, 1, 4, 7, 8, 9, 10]);
}

#[test]
fn lots_of_intricate_matches() {
let index = rabin_karp("ababababa", "aba");
assert_eq!(index, vec![0, 2, 4, 6]);
}

#[test]
fn not_found0() {
let index = rabin_karp("abcde", "f");
assert_eq!(index, vec![]);
}

#[test]
fn not_found1() {
let index = rabin_karp("abcde", "ac");
assert_eq!(index, vec![]);
}

#[test]
fn not_found2() {
let index = rabin_karp("ababab", "bababa");
assert_eq!(index, vec![]);
macro_rules! test_cases {
($($name:ident: $inputs:expr,)*) => {
$(
#[test]
fn $name() {
let (text, pattern, expected) = $inputs;
assert_eq!(rabin_karp(text, pattern), expected);
}
)*
};
}

#[test]
fn empty_string() {
let index = rabin_karp("", "abcdef");
assert_eq!(index, vec![]);
test_cases! {
single_match_at_start: ("hello world", "hello", vec![0]),
single_match_at_end: ("hello world", "world", vec![6]),
single_match_in_middle: ("abc def ghi", "def", vec![4]),
multiple_matches: ("ababcabc", "abc", vec![2, 5]),
overlapping_matches: ("aaaaa", "aaa", vec![0, 1, 2]),
no_match: ("abcdefg", "xyz", vec![]),
pattern_is_entire_string: ("abc", "abc", vec![0]),
target_is_multiple_patterns: ("abcabcabc", "abc", vec![0, 3, 6]),
empty_text: ("", "abc", vec![]),
empty_pattern: ("abc", "", vec![]),
empty_text_and_pattern: ("", "", vec![]),
pattern_larger_than_text: ("abc", "abcd", vec![]),
large_text_small_pattern: (&("a".repeat(1000) + "b"), "b", vec![1000]),
single_char_match: ("a", "a", vec![0]),
single_char_no_match: ("a", "b", vec![]),
large_pattern_no_match: ("abc", "defghi", vec![]),
repeating_chars: ("aaaaaa", "aa", vec![0, 1, 2, 3, 4]),
special_characters: ("abc$def@ghi", "$def@", vec![3]),
numeric_and_alphabetic_mix: ("abc123abc456", "123abc", vec![3]),
case_sensitivity: ("AbcAbc", "abc", vec![]),
}
}