In [74]:
import re
from pypinyin import pinyin, Style

# =========================
# 拼音 token
# =========================
# Pinyin initials recognised by get_shengmu. The two-letter initials are
# listed first so that a prefix scan tries "zh"/"ch"/"sh" before "z"/"c"/"s".
INITIALS = ["zh", "ch", "sh", *"bpmfdtnlgkhjqxrzcsyw"]

def get_shengmu(py):
    """Return the initial (shengmu) of the pinyin syllable ``py``.

    The first entry of ``INITIALS`` that ``py`` starts with wins. When no
    entry matches, the first character of ``py`` is returned; an empty
    string is returned unchanged.
    """
    matched = next((ini for ini in INITIALS if py.startswith(ini)), None)
    if matched is not None:
        return matched
    # No known initial: fall back to the leading character ("" stays "").
    return py[:1]

# Fuzzy-pinyin substitutions: each two-letter initial on the left may also be
# typed as the single-letter initial on the right.
FUZZY_MAP = dict(zh="z", ch="c", sh="s")

def expand_pinyin(py, use_initials=True, use_fuzzy=True):
    """Return the set of spellings under which the syllable ``py`` may be typed.

    The result always contains ``py`` itself.

    Args:
        py: A pinyin syllable without tone marks, e.g. ``"zhong"``.
        use_initials: When true, also include the bare initial (``"zh"``) so
            that a syllable can be matched by its initial alone.
        use_fuzzy: When true and the initial is a key of ``FUZZY_MAP``, also
            include the syllable rewritten with the mapped initial
            (``"zong"``).

    Returns:
        A ``set`` of strings.
    """
    res = {py}

    sm = get_shengmu(py)

    # Initial-only spelling.
    if use_initials:
        res.add(sm)

    # Fuzzy spellings.
    if use_fuzzy and sm in FUZZY_MAP:
        fuzzy_sm = FUZZY_MAP[sm]
        res.add(fuzzy_sm + py[len(sm):])  # e.g. zong
        # The bare fuzzy initial is an initial-only spelling as well, so it
        # must respect use_initials; otherwise disabling initials would still
        # let "z" match a "zh" syllable.
        if use_initials:
            res.add(fuzzy_sm)             # e.g. z

    return res



def text_to_tokens(text, use_initials=True, use_fuzzy=True, split_chars=True):
    """Convert ``text`` into a list of per-item pinyin tokens.

    Each token is a dict with two keys: ``"char"`` (the item taken from
    ``text``) and ``"pinyins"`` (the union of ``expand_pinyin`` over every
    reading that ``pinyin(..., heteronym=True)`` reports for that position).

    When ``split_chars`` is true and ``text`` is a string, it is first split
    into a list of single characters.

    NOTE(review): items of ``text`` are paired positionally with the entries
    returned by ``pinyin()``; with ``split_chars=False`` this presumably
    requires ``pinyin()`` to return exactly one entry per item - confirm.
    """
    if isinstance(text, str) and split_chars:
        text = list(text)

    readings_per_item = pinyin(text, style=Style.NORMAL, heteronym=True)

    def spellings(readings):
        # Merge the expansions of all readings of one position.
        merged = set()
        for reading in readings:
            merged.update(expand_pinyin(reading, use_initials, use_fuzzy))
        return merged

    return [
        {"char": item, "pinyins": spellings(readings)}
        for item, readings in zip(text, readings_per_item)
    ]





# =========================
# NFA 状态
# =========================
class State:
    """A single NFA state."""

    def __init__(self):
        # eps:    states reachable without consuming input
        # trans:  label -> set of target states
        # accept: whether reaching this state means the pattern matched
        self.eps, self.trans, self.accept = set(), {}, False


def epsilon_closure(states):
    """Return ``states`` plus every state reachable from them via ``eps`` edges.

    The input collection is not modified; a new set is returned.
    """
    pending = list(states)
    closure = set(states)
    while pending:
        node = pending.pop()
        # Only states not seen before need to be explored further.
        fresh = set(node.eps) - closure
        closure |= fresh
        pending.extend(fresh)
    return closure

# =========================
# NFA 匹配工具
# =========================

def match_label(label, ch_org, ch):
    """Decide whether a transition labelled ``label`` accepts the character ``ch``.

    Args:
        label: The transition label. One of: the string ``'.'`` (wildcard), a
            backslash escape string, a ``frozenset`` of characters (positive
            class), a tuple ``('NEG', frozenset)`` (negated class), or a
            plain character.
        ch_org: The original text item the current token came from. Only the
            ``\\w`` escape looks at it.
        ch: The character currently being fed to the NFA.

    Returns:
        ``True`` if the transition may be taken, ``False`` otherwise.
    """
    # Wildcard: accepts anything.
    if label == '.':
        return True

    # Backslash escapes.
    if isinstance(label, str) and len(label) > 1 and label.startswith('\\'):
        if label == r'\d':
            return ch.isdigit()
        if label == r'\w':
            # Tested against the ORIGINAL item, not the pinyin character.
            return ch_org.isascii() and ch_org.isalnum()
        if label == r'\s':
            return ch.isspace()
        if label == r'\p':
            return True
        # Any other escape (\. \* \( \\ ...) stands for the escaped character
        # itself. Comparing ch with the full two-character label could never
        # succeed, which made escaped literals unmatchable.
        return ch == label[1:]

    # Positive character class.
    if isinstance(label, frozenset):
        return ch in label

    # Negated character class.
    if isinstance(label, tuple) and label[0] == 'NEG':
        return ch not in label[1]

    # Plain character.
    return ch == label



# =========================
# Thompson 构造
# =========================
class Frag:
    """An NFA fragment identified by its entry state and its exit state."""

    def __init__(self, start, end):
        self.start, self.end = start, end


def literal_frag(label):
    """Build a two-state fragment joined by a single ``label`` transition."""
    src = State()
    dst = State()
    src.trans[label] = {dst}
    return Frag(src, dst)


def concat_frag(a, b):
    """Chain ``a`` and ``b``: the end of ``a`` gets an epsilon edge to the start of ``b``."""
    entry, exit_state = a.start, b.end
    a.end.eps.add(b.start)
    return Frag(entry, exit_state)


def alt_frag(a, b):
    """Build the alternation ``a|b`` with a fresh fork state and a fresh join state."""
    fork, join = State(), State()
    for branch in (a, b):
        fork.eps.add(branch.start)   # fork -> branch
        branch.end.eps.add(join)     # branch -> join
    return Frag(fork, join)


def star_frag(a):
    """Build ``a*``: zero or more repetitions of ``a``."""
    entry, exit_state = State(), State()
    # Both the new entry and the end of ``a`` may either (re-)enter ``a``
    # or leave through the new exit.
    entry.eps.update((a.start, exit_state))
    a.end.eps.update((a.start, exit_state))
    return Frag(entry, exit_state)


def plus_frag(a):
    """Build ``a+``: one or more repetitions of ``a``."""
    entry, exit_state = State(), State()
    # The entry must go through ``a`` at least once; afterwards ``a`` may
    # repeat or leave through the new exit.
    entry.eps.add(a.start)
    a.end.eps.update((a.start, exit_state))
    return Frag(entry, exit_state)


def question_frag(a):
    """Build ``a?``: ``a`` is optional."""
    entry, exit_state = State(), State()
    # The entry may run ``a`` or skip straight to the exit.
    entry.eps.update((a.start, exit_state))
    a.end.eps.add(exit_state)
    return Frag(entry, exit_state)


# =========================
# 正则解析（简单递归下降）
# =========================
class Parser:
    """Recursive-descent parser that turns a pattern string into an NFA fragment.

    Supported syntax: alternation ``a|b``, concatenation, the postfix
    quantifiers ``*`` ``+`` ``?``, grouping ``(...)``, the wildcard ``.``,
    backslash escapes, and character classes ``[abc]``, ``[a-z]``,
    ``[^abc]``.

    A missing closing ``)`` or ``]`` at the end of the pattern is tolerated.
    A backslash that ends the pattern, or an unmatched ``)``, raises
    ``ValueError``.
    """

    def __init__(self, pattern):
        self.p = pattern  # pattern text being parsed
        self.i = 0        # index of the next unread character

    def peek(self):
        """Return the next character without consuming it, or None at the end."""
        return self.p[self.i] if self.i < len(self.p) else None

    def get(self):
        """Consume and return the next character, or None at the end.

        The position is not advanced once the end has been reached.
        """
        c = self.peek()
        if c is not None:
            self.i += 1
        return c

    def parse(self):
        """Parse the whole pattern and return its fragment.

        Raises:
            ValueError: If the pattern is malformed (unmatched ``)`` or a
                backslash at the very end).
        """
        frag = self.parse_alt()
        if self.peek() is not None:
            # At top level parse_alt only stops early on a ')' that no '('
            # opened; ignoring it would silently drop the rest of the pattern.
            raise ValueError(
                f"unbalanced ')' at position {self.i} in pattern {self.p!r}"
            )
        return frag

    def parse_alt(self):
        """Parse ``seq ('|' seq)*``."""
        left = self.parse_seq()
        while self.peek() == '|':
            self.get()
            right = self.parse_seq()
            left = alt_frag(left, right)
        return left

    def parse_seq(self):
        """Parse a (possibly empty) run of repeated atoms up to ``|``, ``)`` or the end."""
        frags = []
        while True:
            c = self.peek()
            if c is None or c in '|)':
                break
            frags.append(self.parse_repeat())
        if not frags:
            # Empty sequence: a single state that is both entry and exit.
            s = State()
            return Frag(s, s)
        res = frags[0]
        for f in frags[1:]:
            res = concat_frag(res, f)
        return res

    def parse_repeat(self):
        """Parse an atom followed by any number of ``*``, ``+`` or ``?`` quantifiers."""
        atom = self.parse_atom()
        builders = {'*': star_frag, '+': plus_frag, '?': question_frag}
        # peek() returns None at the end, which is not a key of builders.
        while self.peek() in builders:
            atom = builders[self.get()](atom)
        return atom

    def parse_atom(self):
        """Parse one escape, character class, group, wildcard or plain character."""
        c = self.get()
        if c is None:
            raise ValueError(f"unexpected end of pattern {self.p!r}")

        if c == '\\':
            escaped = self.get()
            if escaped is None:
                raise ValueError(
                    f"pattern {self.p!r} ends with a dangling backslash"
                )
            # The label keeps the backslash so the matcher can tell an
            # escape apart from a plain character.
            return literal_frag('\\' + escaped)

        if c == '[':
            return self.parse_charclass()

        if c == '(':
            frag = self.parse_alt()
            # A missing ')' at the end of the pattern is tolerated.
            if self.peek() == ')':
                self.get()
            return frag

        # Plain characters and the wildcard '.' both become a one-character label.
        return literal_frag(c)

    def parse_charclass(self):
        """Parse the inside of ``[...]``; the opening ``[`` is already consumed.

        Returns a fragment labelled with a ``frozenset`` of characters, or
        with ``('NEG', frozenset)`` when the class starts with ``^``.
        """
        negate = False
        chars = set()

        if self.peek() == '^':
            negate = True
            self.get()

        while self.peek() is not None and self.peek() != ']':
            c = self.get()

            # A range such as a-z needs a '-' followed by an end character
            # other than ']'. The bounds check keeps a '-' that ends the
            # pattern from indexing past the end of the string.
            is_range = (
                self.peek() == '-'
                and self.i + 1 < len(self.p)
                and self.p[self.i + 1] != ']'
            )
            if is_range:
                self.get()  # '-'
                end = self.get()
                chars.update(chr(code) for code in range(ord(c), ord(end) + 1))
            else:
                chars.add(c)

        # Consume the closing ']' if present; an unterminated class is tolerated.
        if self.peek() == ']':
            self.get()

        if negate:
            return literal_frag(('NEG', frozenset(chars)))
        return literal_frag(frozenset(chars))



def compile_regex(pattern):
    """Compile ``pattern`` into an NFA and return its start state.

    The end state of the parsed fragment is marked as accepting.
    """
    nfa = Parser(pattern).parse()
    nfa.end.accept = True
    return nfa.start


# =========================
# 让 NFA 吃一个字符串
# =========================
def advance_states(states, ch_org, s):
    """Feed every character of ``s`` to the NFA, starting from ``states``.

    Args:
        states: The states the NFA is in before reading ``s``.
        ch_org: The original text item ``s`` was derived from; passed through
            to ``match_label``.
        s: The string to consume, character by character.

    Returns:
        The epsilon-closed set of states reached after the whole of ``s`` was
        consumed, or an empty set as soon as no state can continue.
    """
    active = epsilon_closure(states)

    for ch in s:
        # Union of the targets of every transition that accepts ch.
        reached = set().union(*(
            targets
            for state in active
            for label, targets in state.trans.items()
            if match_label(label, ch_org, ch)
        ))
        active = epsilon_closure(reached)
        if not active:
            return active

    return active



# =========================
# 拼音驱动 NFA
# =========================
def run_pinyin_regex(start_state, tokens):
    """Run the NFA over ``tokens`` and report whether the pattern matches.

    Every token offers several alternative spellings (``token["pinyins"]``);
    each one is fed to the NFA as a whole from the current state set. The
    start closure is re-added after every token, so a match may begin at any
    token.

    Returns:
        ``True`` as soon as an accepting state is reached after some
        spelling of a token, otherwise whether the final state set contains
        an accepting state.
    """
    restart = epsilon_closure({start_state})
    live = set(restart)

    for token in tokens:
        reached = set()

        for spelling in token["pinyins"]:
            after = advance_states(live, token["char"], spelling)
            # An accepting state inside this token already decides the match.
            if any(state.accept for state in after):
                return True
            reached |= after

        # Allow a new match attempt to start at the next token.
        live = reached | restart

    return any(state.accept for state in live)



# =========================
# 对外接口
# =========================
def pinyin_regex_match(
    pattern,
    text,
    use_initials=True,
    use_fuzzy=True,
    split_chars=True,
):
    """Return whether ``pattern`` matches somewhere in the pinyin of ``text``.

    The pattern is lower-cased before compilation (this also lower-cases the
    letter of any backslash escape). The three keyword options are forwarded
    unchanged to ``text_to_tokens``.
    """
    nfa_start = compile_regex(pattern.lower())
    token_options = {
        "use_initials": use_initials,
        "use_fuzzy": use_fuzzy,
        "split_chars": split_chars,
    }
    return run_pinyin_regex(nfa_start, text_to_tokens(text, **token_options))

In [75]:
def run_tests():
    """Run the basic matching cases and print one line per case plus a summary.

    Each case is ``(pattern, text, expected)``; a case passes when
    ``pinyin_regex_match(pattern, text)`` equals ``expected``.
    """
    tests = [
        # --- full pinyin ---
        ("yinyue", "音乐", True),
        ("yinle", "音乐", True),
        ("yue", "音乐", True),
        ("le", "音乐", True),

        # --- heteronyms ---
        ("chongqing", "重庆", True),  # chong qing
        ("zhongqing", "重庆", True),  # zhong qing
        ("cq", "重庆", True),         # initials

        # --- initials mode ---
        ("yy", "音乐", True),
        ("yl", "音乐", True),
        ("yq", "重庆", False),

        # --- zh / ch / sh fuzzy spellings ---
        ("zong", "中", True),   # zhong -> zong
        ("zhong", "中", True),
        ("shi", "是", True),
        ("si", "是", True),
        ("cheng", "成", True),
        ("ceng", "成", True),

        # --- regex: alternation ---
        ("yin(yue|le)", "音乐", True),
        ("(yin|zhong)", "音乐", True),
        ("(zhong|yin)", "音乐", True),

        # --- regex: * ---
        ("yin.*le", "音乐", True),
        (".*le", "音乐", True),
        ("yin.*", "音乐", True),

        # --- regex: + ---
        ("y.+e", "音乐", True),
        ("y.+", "音乐", True),

        # --- regex: ? ---
        ("yi?n", "音", True),
        ("zh?ong", "中", True),

        # --- substring matching (no ^ / $ anchoring) ---
        ("yue", "我的音乐很好听", True),
        ("yin", "纯音乐", True),

        # --- must not match ---
        ("bei", "音乐", False),
        ("shanghai", "北京", False),
        ("zzz", "音乐", False),

        # --- edge cases ---
        ("", "音乐", True),        # empty pattern
        (".*", "音乐", True),
    ]

    passed = 0
    for pattern, text, expected in tests:
        result = pinyin_regex_match(pattern, text)
        is_ok = result == expected
        mark = "✅" if is_ok else "❌"
        print(f"{mark}  {pattern!r} vs {text!r} → {result} (expected {expected})")
        passed += int(is_ok)

    print(f"\n通过 {passed}/{len(tests)} 项测试")


run_tests()


✅  'yinyue' vs '音乐' → True (expected True)
✅  'yinle' vs '音乐' → True (expected True)
✅  'yue' vs '音乐' → True (expected True)
✅  'le' vs '音乐' → True (expected True)
✅  'chongqing' vs '重庆' → True (expected True)
✅  'zhongqing' vs '重庆' → True (expected True)
✅  'cq' vs '重庆' → True (expected True)
✅  'yy' vs '音乐' → True (expected True)
✅  'yl' vs '音乐' → True (expected True)
✅  'yq' vs '重庆' → False (expected False)
✅  'zong' vs '中' → True (expected True)
✅  'zhong' vs '中' → True (expected True)
✅  'shi' vs '是' → True (expected True)
✅  'si' vs '是' → True (expected True)
✅  'cheng' vs '成' → True (expected True)
✅  'ceng' vs '成' → True (expected True)
✅  'yin(yue|le)' vs '音乐' → True (expected True)
✅  '(yin|zhong)' vs '音乐' → True (expected True)
✅  '(zhong|yin)' vs '音乐' → True (expected True)
✅  'yin.*le' vs '音乐' → True (expected True)
✅  '.*le' vs '音乐' → True (expected True)
✅  'yin.*' vs '音乐' → True (expected True)
✅  'y.+e' vs '音乐' → True (expected True)
✅  'y.+' vs '音乐' → True (expected True)

In [78]:
def run_regex_feature_tests():
    """Run the regex-feature cases and print one line per case plus a summary.

    Each case is ``(pattern, text, expected)``; a case passes when
    ``pinyin_regex_match(pattern, text)`` equals ``expected``.
    """
    tests = [
        # --- \w letter class ---
        (r"y\w+e", "音乐", False),      # y + any letters + e
        (r"\w+", "音乐", False),        # any pinyin
        (r"\w+le", "音乐", False),

        # --- \d digit class (simulating tone numbers) ---
        (r"yin\d", "音1", True),
        (r"\d+", "123", True),
        (r"\d+", "音乐", False),

        # --- \p any pinyin token ---
        (r"\p\p", "音乐", True),       # two Chinese characters
        (r"\p+", "音乐", True),

        # --- . ---
        (r"y.n", "音", True),          # yin
        (r".+", "音乐", True),

        # --- character set [abc] ---
        (r"[yl]in", "音", True),       # yin
        (r"[abc]in", "音", False),

        # --- range [a-z] ---
        (r"y[a-z]+e", "音乐", True),
        (r"[a-z]+", "音乐", True),

        # --- negated set [^abc] ---
        (r"[^z]hong", "中", False),   # zhong starts with z
        (r"[^a]in", "音", True),

        # --- combinations ---
        (r"(yin|zhong)\w*", "音乐", True),
        (r"[yz]\w+", "音乐", False),
    ]

    passed = 0
    for pattern, text, expected in tests:
        result = pinyin_regex_match(pattern, text)
        is_ok = result == expected
        mark = "✅" if is_ok else "❌"
        print(f"{mark} {pattern!r} vs {text!r} → {result} (expected {expected})")
        passed += int(is_ok)

    print(f"\n通过 {passed}/{len(tests)} 项测试")


run_regex_feature_tests()


✅ 'y\\w+e' vs '音乐' → False (expected False)
✅ '\\w+' vs '音乐' → False (expected False)
✅ '\\w+le' vs '音乐' → False (expected False)
✅ 'yin\\d' vs '音1' → True (expected True)
✅ '\\d+' vs '123' → True (expected True)
✅ '\\d+' vs '音乐' → False (expected False)
✅ '\\p\\p' vs '音乐' → True (expected True)
✅ '\\p+' vs '音乐' → True (expected True)
✅ 'y.n' vs '音' → True (expected True)
✅ '.+' vs '音乐' → True (expected True)
✅ '[yl]in' vs '音' → True (expected True)
✅ '[abc]in' vs '音' → False (expected False)
✅ 'y[a-z]+e' vs '音乐' → True (expected True)
✅ '[a-z]+' vs '音乐' → True (expected True)
✅ '[^z]hong' vs '中' → False (expected False)
✅ '[^a]in' vs '音' → True (expected True)
✅ '(yin|zhong)\\w*' vs '音乐' → True (expected True)
✅ '[yz]\\w+' vs '音乐' → False (expected False)

通过 18/18 项测试
