# CapsWriter-Offline 独立热词与纠错系统 (Portable Version)

本 Notebook 整合了音素处理 (algo_phoneme)、FastRAG 加速检索 (rag_fast)、拼音纠错 (PhonemeCorrector)、规则纠错 (RuleCorrector) 和纠错历史 RAG (RectificationRAG) 的完整逻辑。

**依赖安装（均为可选：缺少 pypinyin 时降级为字符级处理，缺少 numba/numpy 时回退到纯 Python 实现）：**
```bash
pip install pypinyin numba numpy
```

In [1]:
import os
import re
import sys
import time
import threading
import logging
from typing import List, Tuple, Dict, Set, Optional, Literal, NamedTuple
from dataclasses import dataclass
from collections import defaultdict
from difflib import SequenceMatcher

# Configure logging for the whole notebook.
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
logger = logging.getLogger(__name__)

# Try to import Numba and numpy. Both are optional: if either import fails,
# HAS_NUMBA is False and np is None, and the code that checks HAS_NUMBA falls
# back to pure-Python implementations.
try:
    from numba import njit
    import numpy as np
    HAS_NUMBA = True
except ImportError:
    HAS_NUMBA = False
    np = None

# Try to import pypinyin. When it is missing, HAS_PYPINYIN is False and the
# phoneme functions degrade to character-level processing.
try:
    from pypinyin import pinyin, Style
    HAS_PYPINYIN = True
except ImportError:
    HAS_PYPINYIN = False
    pinyin = None
    Style = None
    print("WARNING: pypinyin 未安装，将使用字符级降级处理。")

## 1. 核心算法与数据结构
包含 `Phoneme` 类定义、文本规范化处理和音素转换逻辑。

In [2]:
@dataclass(frozen=True, slots=True)
class Phoneme:
    """带语言属性的音素"""
    value: str
    lang: Literal['zh', 'en', 'num', 'other']
    is_word_start: bool = False
    is_word_end: bool = False
    char_start: int = 0
    char_end: int = 0

    @property
    def is_tone(self) -> bool:
        return self.value.isdigit()

    @property
    def info(self) -> Tuple[str, str, bool, bool, bool, int, int]:
        return (self.value, self.lang, self.is_word_start, self.is_word_end, self.is_tone, self.char_start, self.char_end)

def normalize_text(text: str) -> str:
    """Normalise text: split camelCase and letter/digit boundaries, lowercase.

    Alphanumeric characters (including CJK ideographs) are kept and lowercased.
    A space is inserted at lower->upper, letter->digit and digit->letter
    transitions. Every run of other characters collapses to a single space,
    and the result is stripped.
    """
    pieces = []
    last_kept = ''  # previous kept character; '' right after a separator
    for ch in text:
        keep = ch.isalnum() or '\u4e00' <= ch <= '\u9fff'
        if not keep:
            # Separator: emit at most one space per run and reset the context.
            if pieces and pieces[-1] != ' ':
                pieces.append(' ')
            last_kept = ''
            continue

        at_boundary = (
            (ch.isupper() and last_kept.islower())
            or (ch.isdigit() and last_kept.isalpha())
            or (ch.isalpha() and last_kept.isdigit())
        )
        if at_boundary:
            pieces.append(' ')
        pieces.append(ch.lower())
        last_kept = ch
    return ''.join(pieces).strip()

def split_mixed_label(input_str: str) -> List[str]:
    """Split a mixed-script string into a list of tokens.

    The input is lowercased first. Tokens are, in priority order: a run of
    ``a-z`` letters, a run of ``0-9`` digits, or any other single character.
    Space characters (``' '`` only) separate tokens and are dropped.

    Args:
        input_str: Text to tokenise.

    Returns:
        The tokens in order of appearance; an empty list for empty input.
    """
    # One regex pass instead of repeatedly re-slicing the remaining string,
    # which copied the tail on every token and was quadratic in input length.
    # ``[^ ]`` also matches newlines, matching the old "take one char" branch.
    return re.findall(r'[a-z]+|[0-9]+|[^ ]', input_str.lower())

def get_phoneme_info(text: str, ascii_split_char: bool = True) -> List[Phoneme]:
    """Convert *text* into phonemes that carry character offsets.

    Chinese characters (U+4E00..U+9FFF) are expanded to initial / final / tone
    phonemes via pypinyin. Other alphanumeric runs are split at camelCase and
    letter/digit boundaries, lowercased, and tagged ``'num'`` (all digits) or
    ``'en'``. All remaining characters are skipped.

    Args:
        text: Raw input text; offsets in the result index into this string.
        ascii_split_char: When True each character of a Latin/digit token is
            its own phoneme; when False the whole token is one phoneme.

    Returns:
        The phoneme sequence. Without pypinyin every character of *text*
        becomes a single ``'zh'`` phoneme.
    """
    if not HAS_PYPINYIN:
        # Degraded mode: one phoneme per character, each its own word.
        return [
            Phoneme(c, 'zh', is_word_start=True, is_word_end=True, char_start=i, char_end=i + 1)
            for i, c in enumerate(text)
        ]

    phoneme_seq: List[Phoneme] = []
    pos = 0
    while pos < len(text):
        char = text[pos]
        if '\u4e00' <= char <= '\u9fff':
            # Consume the whole run of Chinese characters so pypinyin sees context.
            start = pos
            scan = pos + 1
            while scan < len(text) and '\u4e00' <= text[scan] <= '\u9fff':
                scan += 1
            frag = text[start:scan]
            try:
                # Collect into a local list so a failure midway cannot leave
                # partial output that the fallback would then duplicate.
                frag_phonemes = []
                py_inits = pinyin(frag, style=Style.INITIALS, strict=False, errors='ignore')
                py_fins = pinyin(frag, style=Style.FINALS, strict=False, errors='ignore')
                py_tones = pinyin(frag, style=Style.TONE3, neutral_tone_with_five=True, errors='ignore')
                for i in range(min(len(frag), len(py_inits), len(py_fins), len(py_tones))):
                    idx = start + i
                    init, fin, tone = py_inits[i][0], py_fins[i][0], py_tones[i][0]
                    if init:
                        frag_phonemes.append(Phoneme(init, 'zh', is_word_start=True, char_start=idx, char_end=idx + 1))
                    if fin:
                        # A final opens the syllable only when there is no initial.
                        frag_phonemes.append(Phoneme(fin, 'zh', is_word_start=not init, char_start=idx, char_end=idx + 1))
                    if tone and tone[-1].isdigit():
                        frag_phonemes.append(Phoneme(tone[-1], 'zh', is_word_end=True, char_start=idx, char_end=idx + 1))
            except Exception:
                # Best-effort fallback: treat each character as its own phoneme.
                frag_phonemes = [
                    Phoneme(c, 'zh', is_word_start=True, is_word_end=True, char_start=start + i, char_end=start + i + 1)
                    for i, c in enumerate(frag)
                ]
            phoneme_seq.extend(frag_phonemes)
            pos = scan
        elif char.isalnum():
            start = pos
            # str.isalnum() is also True for CJK ideographs, so stop explicitly
            # at the first Chinese character; otherwise Chinese text directly
            # following Latin text is swallowed into the 'en' token.
            while pos < len(text) and text[pos].isalnum() and not ('\u4e00' <= text[pos] <= '\u9fff'):
                if pos > start:
                    p, c = text[pos - 1], text[pos]
                    if (p.islower() and c.isupper()) or (p.isalpha() and c.isdigit()) or (p.isdigit() and c.isalpha()):
                        break
                pos += 1
            token = text[start:pos].lower()
            lang = 'num' if token.isdigit() else 'en'
            if ascii_split_char:
                last = len(token) - 1
                for i, c in enumerate(token):
                    phoneme_seq.append(Phoneme(c, lang, is_word_start=(i == 0), is_word_end=(i == last), char_start=start + i, char_end=start + i + 1))
            else:
                phoneme_seq.append(Phoneme(token, lang, is_word_start=True, is_word_end=True, char_start=start, char_end=pos))
        else:
            pos += 1
    return phoneme_seq

def get_phoneme_seq(text: str) -> List[Phoneme]:
    """Return a position-free phoneme sequence for *text*.

    The text is normalised with ``normalize_text`` and tokenised with
    ``split_mixed_label``. Latin/digit tokens become a single ``'en'`` or
    ``'num'`` phoneme. Every other single-character token is expanded through
    pypinyin into initial / final / tone phonemes, where the tone is ``'0'``
    if the TONE3 reading carries no trailing digit.

    Without pypinyin, each whitespace-separated chunk of the normalised text
    becomes one ``'zh'`` phoneme.
    """
    normalized = normalize_text(text)
    if not HAS_PYPINYIN:
        return [Phoneme(c, 'zh', is_word_start=True, is_word_end=True) for c in normalized.split()]

    seq = []
    for token in split_mixed_label(normalized):
        if re.match(r'^[a-z0-9]+$', token):
            lang = 'num' if token.isdigit() else 'en'
            seq.append(Phoneme(token, lang, is_word_start=True, is_word_end=True))
        elif len(token) == 1:
            try:
                py_t3 = pinyin(token, style=Style.TONE3, strict=False)
                if not py_t3:
                    seq.append(Phoneme(token, 'zh', is_word_start=True, is_word_end=True))
                    continue
                init = pinyin(token, style=Style.INITIALS, strict=False)[0][0]
                fin = pinyin(token, style=Style.FINALS, strict=False)[0][0]
                tone = py_t3[0][0][-1] if py_t3[0][0][-1].isdigit() else '0'
                if init:
                    seq.append(Phoneme(init, 'zh', is_word_start=True))
                if fin:
                    # A final opens the syllable only when there is no initial.
                    seq.append(Phoneme(fin, 'zh', is_word_start=not init))
                seq.append(Phoneme(tone, 'zh', is_word_end=True))
            except Exception:
                # Best-effort fallback to the raw character. ``Exception`` rather
                # than a bare ``except`` so that KeyboardInterrupt and SystemExit
                # still propagate.
                seq.append(Phoneme(token, 'zh', is_word_start=True, is_word_end=True))
    return seq

## 2. 相似度与检索算法
包含模糊子串匹配得分以及 FastRAG 高性能粗筛模块。

In [3]:
# Groups of pinyin initials/finals that are easily confused; substituting one
# member of a group for another costs 0.5 instead of 1.0.
SIMILAR_PHONEMES = [{'an','ang'},{'en','eng'},{'in','ing'},{'ian','iang'},{'uan','uang'},{'z','zh'},{'c','ch'},{'s','sh'},{'l','n'},{'f','h'},{'ai','ei'}]


def _get_tuple_cost(t1: Tuple, t2: Tuple) -> float:
    """Substitution cost in [0, 1] between two phoneme info tuples.

    Only element 0 (value) and element 1 (language) of each tuple are read.
    Different languages cost 1.0 and equal values cost 0.0. Chinese phonemes
    from the same ``SIMILAR_PHONEMES`` group cost 0.5. English values cost
    ``1 - LCS / max(len)``. Everything else costs 1.0.
    """
    value_a, lang_a = t1[0], t1[1]
    value_b, lang_b = t2[0], t2[1]

    if lang_a != lang_b:
        return 1.0
    if value_a == value_b:
        return 0.0

    if lang_a == 'zh':
        confusable = any(value_a in group and value_b in group for group in SIMILAR_PHONEMES)
        return 0.5 if confusable else 1.0

    if lang_a != 'en':
        return 1.0

    len_a, len_b = len(value_a), len(value_b)
    if len_a == 0 or len_b == 0:
        return 1.0
    # Longest common subsequence, keeping only the previous DP row.
    above = [0] * (len_b + 1)
    for ch_a in value_a:
        row = [0]
        for col, ch_b in enumerate(value_b, 1):
            if ch_a == ch_b:
                row.append(above[col - 1] + 1)
            else:
                row.append(max(above[col], row[col - 1]))
        above = row
    return 1.0 - (above[len_b] / max(len_a, len_b))


def fuzzy_substring_distance(hw_info: List[Tuple], input_info: List[Tuple]) -> float:
    """Minimum weighted edit distance from *hw_info* to any substring of *input_info*.

    Insertions and deletions cost 1.0; substitutions cost ``_get_tuple_cost``.
    Returns 0.0 for an empty *hw_info* and ``len(hw_info)`` for an empty
    *input_info*.
    """
    pattern_len, text_len = len(hw_info), len(input_info)
    if pattern_len == 0:
        return 0.0
    if text_len == 0:
        return float(pattern_len)

    # The first row is all zeros, so a match may start anywhere in the input.
    above = [0.0] * (text_len + 1)
    for row_no, hw_item in enumerate(hw_info, 1):
        row = [float(row_no)]
        for col, in_item in enumerate(input_info, 1):
            substitute = above[col - 1] + _get_tuple_cost(hw_item, in_item)
            row.append(min(above[col] + 1.0, row[col - 1] + 1.0, substitute))
        above = row
    # The match may end anywhere in the input as well.
    return min(above)


def fuzzy_substring_score(hw_info: List[Tuple], input_info: List[Tuple]) -> float:
    """Similarity in [0, 1]: ``1 - distance / len(hw_info)``, floored at 0.

    Returns 0.0 when *hw_info* is empty.
    """
    pattern_len = len(hw_info)
    if pattern_len == 0:
        return 0.0
    distance = fuzzy_substring_distance(hw_info, input_info)
    return max(0.0, 1.0 - (distance / pattern_len))

if HAS_NUMBA:
    # Numba-compiled substring edit distance over integer phoneme codes.
    # Only defined when numba and numpy imported successfully.
    # Returns the minimum unit-cost edit distance between sub_codes and any
    # substring of main_codes that ends at position 1..m.
    @njit(cache=True)
    def _fuzzy_substring_distance_numba(main_codes: np.ndarray, sub_codes: np.ndarray) -> float:
        # n = length of the pattern (sub_codes), m = length of the searched sequence.
        n, m = len(sub_codes), len(main_codes)
        # Either sequence empty: the distance is the pattern length.
        if n == 0 or m == 0: return float(n)
        # Row 0 stays all zeros so a match may start at any position of main_codes.
        dp = np.zeros((n + 1, m + 1), dtype=np.float32)
        for i in range(1, n + 1): dp[i, 0] = float(i)
        for i in range(1, n + 1):
            for j in range(1, m + 1):
                # Exact code comparison only: substitution costs 0 or 1.
                cost = 0.0 if sub_codes[i-1] == main_codes[j-1] else 1.0
                dp[i, j] = min(dp[i-1, j] + 1.0, dp[i, j-1] + 1.0, dp[i-1, j-1] + cost)
        # Minimum over the last row, columns 1..m (column 0 is excluded).
        min_dist = dp[n, 1]
        for j in range(2, m + 1): 
            if dp[n, j] < min_dist: min_dist = dp[n, j]
        return min_dist

class PhonemeEncoder:
    """Maps phoneme strings to integer codes, assigned on first sight from 1."""

    def __init__(self):
        self.p2c = {}   # phoneme string -> integer code
        self.next = 1   # code handed out to the next unseen phoneme

    def encode(self, p: str):
        """Return the code for *p*, allocating a new one if it is unseen."""
        code = self.p2c.get(p)
        if code is None:
            code = self.next
            self.p2c[p] = code
            self.next = code + 1
        return code

    def encode_seq(self, ps: List[str]):
        """Encode a phoneme list: an int32 array with numba, else a plain list."""
        codes = [self.encode(p) for p in ps]
        if HAS_NUMBA:
            return np.array(codes, dtype=np.int32)
        return codes

class FastRAG:
    """Coarse hotword pre-filter over integer-encoded phoneme sequences.

    Hotwords are indexed by the code of their first phoneme (the first two for
    hotwords starting with an English phoneme). ``search`` gathers the hotwords
    whose index code occurs in the input and scores them with a unit-cost
    substring edit distance.
    """

    def __init__(self, threshold: float = 0.6):
        self.threshold = threshold            # minimum score kept by search()
        self.encoder = PhonemeEncoder()       # phoneme string -> int code
        self.index = defaultdict(list)        # code -> [(hotword, codes), ...]
        self.hotword_count = 0

    def add_hotwords(self, hotwords: Dict[str, List[Phoneme]]):
        """Encode and index *hotwords*; entries with no phonemes are skipped."""
        for hw, phons in hotwords.items():
            if not phons:
                continue
            codes = self.encoder.encode_seq([p.value for p in phons])
            idx_pos = [0]
            if phons[0].lang == 'en':
                # English hotwords are also reachable through their second phoneme.
                idx_pos = list(range(min(len(phons), 2)))
            for p in {codes[i] for i in idx_pos if i < len(codes)}:
                self.index[p].append((hw, codes))
            self.hotword_count += 1

    @staticmethod
    def _substring_distance(sub_codes, main_codes) -> float:
        """Pure-Python unit-cost substring edit distance (two-row DP).

        The minimum is taken over end positions 1..len(main_codes), the same
        as the numba implementation.
        """
        above = [0.0] * (len(main_codes) + 1)
        for row_no, sub_code in enumerate(sub_codes, 1):
            row = [float(row_no)]
            for col, main_code in enumerate(main_codes, 1):
                cost = 0.0 if sub_code == main_code else 1.0
                row.append(min(above[col] + 1.0, row[col - 1] + 1.0, above[col - 1] + cost))
            above = row
        return min(above[1:])

    def search(self, input_phons: List[Phoneme], top_k: int = 10) -> List[Tuple[str, float]]:
        """Return up to *top_k* ``(hotword, score)`` pairs, best score first.

        Only hotwords scoring at least ``self.threshold`` are returned. Returns
        an empty list for empty input.
        """
        if not input_phons:
            return []

        # Read-only lookup: encoding the input through the inserting encoder
        # would grow the shared vocabulary on every query. Code 0 is never
        # assigned to a hotword phoneme (codes start at 1), so unknown
        # phonemes simply never match.
        known = self.encoder.p2c
        raw_codes = [known.get(p.value, 0) for p in input_phons]
        input_codes = np.array(raw_codes, dtype=np.int32) if HAS_NUMBA else raw_codes

        # Gather candidates in order of first appearance in the input, so that
        # ties between equal scores are resolved deterministically.
        candidates = []
        seen = set()
        for code in dict.fromkeys(raw_codes):
            for hw, codes in self.index.get(code, ()):
                if hw not in seen:
                    seen.add(hw)
                    candidates.append((hw, codes))

        results = []
        input_len = len(raw_codes)
        for hw, hw_codes in candidates:
            # Hotwords much longer than the input are skipped without scoring.
            if len(hw_codes) > input_len + 3:
                continue
            if HAS_NUMBA:
                dist = _fuzzy_substring_distance_numba(input_codes, hw_codes)
            else:
                dist = self._substring_distance(hw_codes, input_codes)
            score = 1.0 - (dist / len(hw_codes))
            if score >= self.threshold:
                results.append((hw, round(score, 3)))
        results.sort(key=lambda x: x[1], reverse=True)
        return results[:top_k]

## 3. 纠错组件
包含 `PhonemeCorrector`, `RuleCorrector` 和 `RectificationRAG` 三个主要纠错类。

In [4]:
class MatchResult(NamedTuple):
    """One scored hotword hit over a character span of the input text."""
    # Character offset where the matched span starts.
    start: int
    # Character offset one past the end of the matched span.
    end: int
    # Similarity score of the hotword against the span.
    score: float
    # The hotword that matched.
    hotword: str

class CorrectionResult(NamedTuple):
    """Outcome of a correction pass."""
    # Text after hotword replacement.
    text: str
    # (hotword, score) pairs that were applied as replacements.
    matchs: List[Tuple[str, float]]
    # (hotword, score) pairs offered as similar suggestions.
    similars: List[Tuple[str, float]]

class PhonemeCorrector:
    """Replaces phonetically similar spans of text with configured hotwords.

    ``threshold`` is the minimum score for a replacement. ``similar_threshold``
    (``threshold - 0.2``) is used here only to derive the FastRAG pre-filter
    threshold (``similar_threshold - 0.1``).
    """

    def __init__(self, threshold: float = 0.7):
        self.threshold = threshold
        self.similar_threshold = threshold - 0.2
        self.hotwords: Dict[str, List[Phoneme]] = {}
        self.fast_rag = FastRAG(threshold=self.similar_threshold - 0.1)
        self._lock = threading.Lock()

    def update_hotwords(self, hotword_text: str) -> int:
        """Replace the hotword set from newline-separated text.

        Blank lines and lines starting with ``#`` are ignored, as are entries
        that yield no phonemes. Returns the number of hotwords loaded.
        """
        new_hotwords = {}
        for raw_line in hotword_text.splitlines():
            entry = raw_line.strip()
            if not entry or entry.startswith('#'):
                continue
            phonemes = get_phoneme_info(entry)
            if phonemes:
                new_hotwords[entry] = phonemes

        with self._lock:
            self.hotwords = new_hotwords
            self.fast_rag = FastRAG(threshold=self.similar_threshold - 0.1)
            self.fast_rag.add_hotwords(new_hotwords)
        return len(new_hotwords)

    def correct(self, text: str) -> CorrectionResult:
        """Correct *text* and report applied matches plus the top-5 similar hotwords.

        Returns the text unchanged (with empty lists) when the text is empty,
        no hotwords are loaded, or the text yields no phonemes.
        """
        if not text or not self.hotwords:
            return CorrectionResult(text, [], [])
        input_phons = get_phoneme_info(text)
        if not input_phons:
            return CorrectionResult(text, [], [])

        with self._lock:
            fast_results = self.fast_rag.search(input_phons, top_k=100)
            input_infos = [p.info for p in input_phons]
            matches = []
            similars = []
            total = len(input_infos)
            for hw, _ in fast_results:
                hw_compare = [p.info[:5] for p in self.hotwords[hw]]
                width = len(hw_compare)
                if width > total:
                    continue
                # Slide a window of the hotword's phoneme length over the input.
                for begin in range(total - width + 1):
                    stop = begin + width
                    window = input_infos[begin:stop]
                    first, last = window[0], window[-1]
                    # Non-English windows must share the hotword's first phoneme.
                    if first[1] != 'en' and first[0] != hw_compare[0][0]:
                        continue
                    # The window must begin on a word start.
                    if not first[2]:
                        continue
                    # It must end on a word end, or immediately before a Chinese
                    # tone phoneme that is itself a word end.
                    ends_on_boundary = last[3]
                    if not ends_on_boundary and stop < total:
                        following = input_infos[stop]
                        ends_on_boundary = following[1] == 'zh' and following[4] and following[3]
                    if not ends_on_boundary:
                        continue
                    score = fuzzy_substring_score(hw_compare, window)
                    hit = MatchResult(first[5], last[6], score, hw)
                    similars.append(hit)
                    if score >= self.threshold:
                        matches.append(hit)

        # Best score per hotword, highest first.
        best_by_hotword = {}
        for hit in sorted(similars, key=lambda x: x.score, reverse=True):
            best_by_hotword.setdefault(hit.hotword, hit.score)
        top_sim = list(best_by_hotword.items())

        # Resolve overlaps greedily: higher score first, then longer span.
        ranked = sorted(matches, key=lambda x: (x.score, x.end - x.start), reverse=True)
        final_matches = []
        occupied = []
        for m in ranked:
            overlaps = any(m.end > span_start and m.start < span_end for span_start, span_end in occupied)
            if overlaps:
                continue
            # A span already equal to the hotword is reserved but not reported.
            if text[m.start:m.end] != m.hotword:
                final_matches.append(m)
            occupied.append((m.start, m.end))

        # Apply right-to-left so earlier offsets stay valid.
        final_matches.sort(key=lambda x: x.start, reverse=True)
        chars = list(text)
        for m in final_matches:
            chars[m.start:m.end] = list(m.hotword)
        applied = [(m.hotword, m.score) for m in final_matches]
        return CorrectionResult("".join(chars), applied, top_sim[:5])

class RuleCorrector:
    """Applies user-defined regex substitution rules to text."""

    def __init__(self):
        self.patterns: Dict[str, str] = {}
        self._lock = threading.Lock()

    def update_rules(self, rule_text: str) -> int:
        """Replace the rule set from text with one ``pattern = replacement`` per line.

        Blank lines and lines starting with ``#`` are skipped. A line is
        accepted only if it contains exactly one ``=``; other lines are
        ignored. Returns the number of rules loaded.
        """
        new_patterns = {}
        for line in rule_text.splitlines():
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            parts = line.split('=')
            if len(parts) == 2:
                new_patterns[parts[0].strip()] = parts[1].strip()
        with self._lock:
            self.patterns = new_patterns
        return len(new_patterns)

    def substitute(self, text: str) -> str:
        """Apply every rule in insertion order with ``re.sub``.

        A rule that fails (for example an invalid pattern or a bad group
        reference) is skipped with a logged warning, so one broken rule does
        not block the rest.
        """
        if not text or not self.patterns:
            return text
        with self._lock:
            patterns = self.patterns.copy()
        result = text
        for pattern, replacement in patterns.items():
            try:
                result = re.sub(pattern, replacement, result)
            except Exception as exc:
                # Stay best-effort, but no longer silent, and let
                # KeyboardInterrupt / SystemExit propagate.
                logging.getLogger(__name__).warning(
                    "Skipping rule %r -> %r: %s", pattern, replacement, exc
                )
        return result

@dataclass
class RectifyRecord:
    """One correction-history entry: a wrong sentence, its fix, and the differing fragments."""
    # The sentence as originally (mis)recognised.
    wrong: str
    # The corrected sentence.
    right: str
    # Fragments that differ between ``wrong`` and ``right``.
    fragments: List[str]
    # Phoneme sequence of each fragment, keyed by the fragment text.
    frag_phons: Dict[str, List[Phoneme]]

class RectificationRAG:
    """Retrieves past corrections whose changed fragments sound like the input."""

    def __init__(self, threshold: float = 0.5):
        self.threshold = threshold
        self.records: List[RectifyRecord] = []
        self._lock = threading.Lock()

    def load_rectify_text(self, text: str):
        """Replace the history from text made of ``---``-separated blocks.

        In each block, blank lines and lines starting with ``#`` are dropped.
        The first remaining line is the wrong text and the second the
        corrected text; blocks with fewer than two lines are ignored. The
        word-level differences between the two become the record's fragments,
        or the whole wrong text when there are none.
        """
        new_records = []
        for block in text.split('---'):
            lines = [l.strip() for l in block.splitlines() if l.strip() and not l.strip().startswith('#')]
            if len(lines) < 2:
                continue
            wrong, right = lines[0], lines[1]
            w_b = self._get_word_bounds(wrong)
            r_b = self._get_word_bounds(right)
            sm = SequenceMatcher(None, [b[2] for b in w_b], [b[2] for b in r_b])
            frags = []
            for tag, i1, i2, j1, j2 in sm.get_opcodes():
                if tag in ('replace', 'delete') and i2 > i1:
                    frags.append(wrong[w_b[i1][0]: w_b[i2 - 1][1]])
                if tag in ('replace', 'insert') and j2 > j1:
                    frags.append(right[r_b[j1][0]: r_b[j2 - 1][1]])
            if not frags:
                frags = [wrong]
            frags = list(dict.fromkeys(frags))  # de-duplicate, keep order
            frag_phons = {f: get_phoneme_seq(f) for f in frags}
            new_records.append(RectifyRecord(wrong, right, frags, frag_phons))
        with self._lock:
            self.records = new_records

    def _get_word_bounds(self, text: str):
        """Split *text* into words and return ``(start, end, word)`` triples.

        Each CJK ideograph (U+4E00..U+9FFF) is its own word. Other
        alphanumeric runs are split at lower->upper camelCase boundaries.
        Non-alphanumeric characters are separators.
        """
        bounds = []
        i = 0
        n = len(text)
        while i < n:
            if not (text[i].isalnum() or '\u4e00' <= text[i] <= '\u9fff'):
                i += 1
                continue
            s = i
            if '\u4e00' <= text[i] <= '\u9fff':
                i += 1
            else:
                # str.isalnum() is True for CJK too, so stop the Latin/digit run
                # at the first ideograph instead of absorbing it into the word.
                while i < n and text[i].isalnum() and not ('\u4e00' <= text[i] <= '\u9fff'):
                    if i > s and text[i].isupper() and text[i - 1].islower():
                        break
                    i += 1
            bounds.append((s, i, text[s:i]))
        return bounds

    def search(self, text: str, top_k: int = 5) -> List[Tuple[str, str, float]]:
        """Return up to *top_k* ``(wrong, right, score)`` records, best first.

        A record's score is the best ``fuzzy_substring_score`` of any of its
        fragments against the input phonemes. Records scoring below
        ``self.threshold`` are dropped.
        """
        if not text or not self.records:
            return []
        input_phons = [p.info for p in get_phoneme_seq(text)]
        if not input_phons:
            return []
        results = []
        with self._lock:
            for rec in self.records:
                score = 0.0
                for f_phon in rec.frag_phons.values():
                    if not f_phon:
                        continue
                    score = max(score, fuzzy_substring_score([p.info for p in f_phon], input_phons))
                if score >= self.threshold:
                    results.append((rec.wrong, rec.right, round(score, 3)))
        results.sort(key=lambda x: x[2], reverse=True)
        return results[:top_k]

## 4. 综合演示
准备演示数据并执行纠错测试。

In [5]:
# --- A. Data preparation ---

# Hotword list, one per line. Lines starting with '#' inside the string are
# skipped by PhonemeCorrector.update_hotwords, and surrounding whitespace is
# stripped.
# NOTE(review): '买当劳' may be a typo for '麦当劳' — confirm the intended hotword.
hotwords_data = """
    Claude
    Bilibili
    Microsoft
    买当劳
    肯德基
    # 这是一个注释
    VsCode
    VsCodes
"""
    
# Correction history: blocks separated by '---'. In each block the first
# non-comment line is the wrong text and the second is the corrected text.
rectify_data = """
# 纠错历史演示
把那个锯子给我
把那个句子给我
---
cloud code is good
Claude Code is good
---
今天天其不错
今天天气不错
"""


# Test sentences, one per line.
test_cases_text = """
我想去吃买当劳和肯得鸡
Hello klaude
喜欢刷Bili Bili
请把那个锯子发给我一下
今天天及真的很好
I think klaud code is very good
"""
# Drop blank lines and surrounding whitespace.
cases = [l.strip() for l in test_cases_text.strip().split('\n') if l.strip()]

In [6]:
# --- B. System initialisation and data loading ---

# Create the corrector and the history retriever
corrector = PhonemeCorrector(threshold=0.8)
rectifier = RectificationRAG(threshold=0.5)

# Load hotwords and correction history from the in-memory strings
corrector.update_hotwords(hotwords_data)
rectifier.load_rectify_text(rectify_data)

# Load from text files instead.
# NOTE(review): load_hotwords_file / load_rectify_file are not defined on the
# classes in this notebook, so the two calls below would fail if uncommented.
# corrector.load_hotwords_file("hot.txt")
# rectifier.load_rectify_file("hot-rectify.txt")

In [7]:
# --- C. Run the end-to-end correction demo ---
print("\n" + "="*50)
print("【 CapsWriter-Offline 综合纠错系统演示 】")
print("="*50)

for i, t in enumerate(cases):
    print(f"\nCase {i+1}: '{t}'")
    # CorrectionResult unpacks as (text, matchs, similars).
    res, matched, similars = corrector.correct(t)
    print(f"  [纠错结果] {res}")
    if matched: print(f"  [匹配热词] {matched}")
    if similars: print(f"  [相似推荐] {similars}")
    # Look up similar entries in the correction history for the same sentence.
    rag_results = rectifier.search(t)
    if rag_results:
        print(f"  [RAG 相似历史]")
        for wrong, right, score in rag_results:
            print(f"    - '{wrong}' => '{right}' (相似度: {score:.3f})")


【 CapsWriter-Offline 综合纠错系统演示 】

Case 1: '我想去吃买当劳和肯得鸡'
  [纠错结果] 我想去吃买当劳和肯德基
  [匹配热词] [('肯德基', 1.0)]
  [相似推荐] [('买当劳', 1.0), ('肯德基', 1.0)]
  [RAG 相似历史]
    - '把那个锯子给我' => '把那个句子给我' (相似度: 0.667)
    - '今天天其不错' => '今天天气不错' (相似度: 0.667)

Case 2: 'Hello klaude'
  [纠错结果] Hello Claude
  [匹配热词] [('Claude', 0.8333333333333334)]
  [相似推荐] [('Claude', 0.8333333333333334)]

Case 3: '喜欢刷Bili Bili'
  [纠错结果] 喜欢刷Bilibili
  [匹配热词] [('Bilibili', 1.0)]
  [相似推荐] [('Bilibili', 1.0)]

Case 4: '请把那个锯子发给我一下'
  [纠错结果] 请把那个锯子发给我一下
  [RAG 相似历史]
    - '把那个锯子给我' => '把那个句子给我' (相似度: 1.000)

Case 5: '今天天及真的很好'
  [纠错结果] 今天天及真的很好
  [RAG 相似历史]
    - '今天天其不错' => '今天天气不错' (相似度: 0.667)

Case 6: 'I think klaud code is very good'
  [纠错结果] I think klaud code is very good
  [相似推荐] [('VsCode', 0.6666666666666667), ('Claude', 0.5)]
  [RAG 相似历史]
    - 'cloud code is good' => 'Claude Code is good' (相似度: 0.833)
