In [1]:
import re
import random
import string

In [36]:
# Alphabet from which the random patterns and texts in this notebook are
# drawn: the first two lowercase ASCII letters, i.e. "ab".
chars = string.ascii_lowercase[0:2]

# 朴素算法

In [52]:
def naive_search(p, s):
    """Find the first occurrence of pattern ``p`` in text ``s`` by brute force.

    Every candidate start offset in ``s`` is tried in turn and the pattern is
    compared against the text one element at a time, so the worst case costs
    on the order of ``len(p) * len(s)`` comparisons.

    Args:
        p: Pattern to look for (an indexable sequence, typically ``str``).
        s: Text to search in.

    Returns:
        The index in ``s`` where the first match starts, or ``None`` when
        ``p`` does not occur in ``s``. An empty pattern matches at index 0,
        the same answer ``str.find`` and ``re.search`` give.
    """
    if len(p) == 0:
        # With an empty pattern the comparison loop below has nothing to
        # iterate over, so the trivial match at offset 0 is reported here.
        return 0
    # When the pattern is longer than the text the range is empty and the
    # function falls through to ``return None``.
    for start in range(len(s) - len(p) + 1):
        for offset, ch in enumerate(p):
            if ch != s[start + offset]:
                break
        else:
            # The inner loop finished without a mismatch: full match.
            return start
    return None

# KMP算法

In [42]:
def prefix(p):
    """Build the KMP prefix (failure) table for pattern ``p``.

    Entry ``k`` of the result is the length of the longest proper prefix of
    ``p[:k + 1]`` that is also a suffix of it.

    Args:
        p: Pattern (an indexable sequence, typically ``str``).

    Returns:
        A list of ``len(p)`` integers; an empty list for an empty pattern.
    """
    table = [0] * len(p)
    for idx in range(1, len(p)):
        # Start from the border of the previous prefix and keep shrinking it
        # until it can be extended by p[idx] or there is nothing left.
        border = table[idx - 1]
        while border > 0 and p[idx] != p[border]:
            border = table[border - 1]
        # Extending the border by one element gives length border + 1
        # (note the +1: indices and lengths differ by one).
        if p[idx] == p[border]:
            border += 1
        table[idx] = border
    return table

In [27]:
def kmp_search(p, s):
    """Find the first occurrence of pattern ``p`` in text ``s`` using KMP.

    The text index only ever moves forward; on a mismatch the pattern index
    falls back according to the table returned by ``prefix(p)`` instead of
    restarting the comparison from the next text position.

    Args:
        p: Pattern to look for (an indexable sequence, typically ``str``).
        s: Text to search in.

    Returns:
        The index in ``s`` where the first match starts, or ``None`` when
        ``p`` does not occur in ``s``. An empty pattern matches at index 0,
        the same answer ``str.find`` and ``re.search`` give.
    """
    if len(p) == 0:
        # Guard before any indexing: ``p[0]`` below would raise IndexError
        # for an empty pattern.
        return 0
    fallback = prefix(p)
    last = len(p) - 1  # index of the final pattern element (length - 1)
    i = j = 0  # i walks the pattern, j walks the text
    while j < len(s):
        if p[i] == s[j]:
            if i == last:
                # Whole pattern matched; the match began i positions back.
                return j - i
            i += 1
            j += 1
        elif i == 0:
            # Mismatch on the first pattern element: just advance the text.
            j += 1
        else:
            # Reuse the longest border of the part matched so far.
            i = fallback[i - 1]
    return None

# 正确性检验

In [43]:
# Randomised cross-check of kmp_search against the regex engine: each trial
# draws a fresh text and pattern and compares the two reported positions.
times = 10000
in_lst = []  # start indices of every trial in which the pattern was found
for trial in range(times):
    s = ''.join(random.choices(chars, k=10000))
    p = ''.join(random.choices(chars, k=20))
    kmp_res = kmp_search(p, s)
    # Escape the pattern so it is always matched literally, even if the
    # alphabet is later changed to include regex metacharacters.
    re_search = re.search(re.escape(p), s)
    re_res = re_search.start() if re_search is not None else None
    # Single-line progress bar redrawn in place via the carriage return.
    done = trial + 1
    print('\r' + '■'*(done//500) + '□'*((times - done)//500) + '\t' + '{}/{}'.format(done, times), end='')
    if kmp_res is not None:
        in_lst.append(kmp_res)
    if kmp_res != re_res:
        # Disagreement: show the offending pattern and text, then stop.
        print()
        print(p)
        print(s)
        break

■■■■■■■■■■■■■■■■■■■■	10000/10000

# 效率检验

In [89]:
%timeit naive_search(''.join(random.choices(chars, k=100)), ''.join(random.choices(chars, k=100000)))

135 ms ± 1.94 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [90]:
%timeit kmp_search(''.join(random.choices(chars, k=100)), ''.join(random.choices(chars, k=100000)))

114 ms ± 1.28 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [82]:
%timeit re.search(''.join(random.choices(chars, k=100)), ''.join(random.choices(chars, k=100000)))

27.7 ms ± 172 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [83]:
%timeit ''.join(random.choices(chars, k=100)) in ''.join(random.choices(chars, k=100000))

27.3 ms ± 276 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
