# CapsWriter-Offline 热词与纠错系统演示

本 Notebook 整合了 `PhonemeCorrector` (音素纠错) 和 `RectificationRAG` (纠错历史检索) 的核心逻辑，是一个完全自包含的演示环境。

需要依赖：

```
pypinyin
numba
numpy
```

In [37]:
import re
import time
import numpy as np
import threading
import logging
from typing import List, Tuple, Dict, Set, Optional, Literal
from dataclasses import dataclass
from difflib import SequenceMatcher
from collections import defaultdict
from pathlib import Path

# 配置日志
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
logger = logging.getLogger("hotword")

# 尝试导入依赖
try:
    from pypinyin import pinyin, Style
    HAS_PYPINYIN = True
except ImportError:
    HAS_PYPINYIN = False
    print("Warning: pypinyin not found. Use 'pip install pypinyin' for full functionality.")

try:
    from numba import njit
    HAS_NUMBA = True
except ImportError:
    HAS_NUMBA = False
    def njit(func): return func
    print("Note: numba not found. Running in pure Python mode.")

## 1. 基础数据结构与音素处理

In [38]:
@dataclass
class Phoneme:
    value: str
    lang: str  # 'zh', 'en', 'num'
    is_word_start: bool = False
    is_word_end: bool = False
    char_start: int = 0
    char_end: int = 0

    @property
    def is_tone(self):
        return self.lang == 'zh' and self.value.isdigit()

    @property
    def info(self):
        return (self.value, self.lang, self.is_word_start, self.is_word_end, self.is_tone, self.char_start, self.char_end)

def normalize_text(text: str) -> str:
    text = re.sub(r'([a-z])([A-Z])', r'\1 \2', text)
    text = re.sub(r'([0-9])([a-zA-Z])', r'\1 \2', text)
    text = re.sub(r'([a-zA-Z])([0-9])', r'\1 \2', text)
    text = re.sub(r'[^a-zA-Z0-9\u4e00-\u9fff\s]', ' ', text)
    return text.lower().strip()

def split_mixed_label(input_str: str) -> List[str]:
    tokens = []
    for part in input_str.split():
        i, n = 0, len(part)
        while i < n:
            if '\u4e00' <= part[i] <= '\u9fff':
                tokens.append(part[i])
                i += 1
            else:
                start = i
                while i < n and not ('\u4e00' <= part[i] <= '\u9fff'):
                    i += 1
                tokens.append(part[start:i])
    return tokens

def _zh_char_to_phonemes(char: str) -> List[Phoneme]:
    if not HAS_PYPINYIN: return [Phoneme(char, 'zh', True, True)]
    try:
        py_initials = pinyin(char, style=Style.INITIALS, strict=False)
        py_finals = pinyin(char, style=Style.FINALS, strict=False)
        py_tone3 = pinyin(char, style=Style.TONE3, strict=False)
        if not py_tone3 or not py_tone3[0][0]:
             return [Phoneme(char, 'zh', True, True)]
        res = []
        has_init = py_initials and py_initials[0][0]
        if has_init: res.append(Phoneme(py_initials[0][0], 'zh', is_word_start=True))
        if py_finals and py_finals[0][0]: res.append(Phoneme(py_finals[0][0], 'zh', is_word_start=not has_init))
        py_val = py_tone3[0][0]
        tone = py_val[-1] if py_val[-1].isdigit() else '0'
        res.append(Phoneme(tone, 'zh', is_word_end=True))
        return res
    except:
        return [Phoneme(char, 'zh', True, True)]

def get_phoneme_info(text: str) -> List[Phoneme]:
    phoneme_seq = []
    pos = 0
    while pos < len(text):
        char = text[pos]
        if '\u4e00' <= char <= '\u9fff':
            phons = _zh_char_to_phonemes(char)
            for p in phons:
                p.char_start, p.char_end = pos, pos + 1
                phoneme_seq.append(p)
            pos += 1
        elif char.isalnum():
            start = pos
            while pos < len(text) and text[pos].isalnum(): pos += 1
            token = text[start:pos].lower()
            lang = 'num' if token.isdigit() else 'en'
            # 简单处理：英文也切分为字母以便模糊匹配
            for i, c in enumerate(token):
                phoneme_seq.append(Phoneme(c, lang, i==0, i==len(token)-1, start, pos))
        else: pos += 1
    return phoneme_seq

def get_phoneme_seq(text: str) -> List[Phoneme]:
    # 为 Rectification 使用的简单版本
    normalized = normalize_text(text)
    res = []
    for token in split_mixed_label(normalized):
        if re.match(r'^[a-z0-9]+$', token):
            lang = 'num' if token.isdigit() else 'en'
            res.extend([Phoneme(c, lang, True, True) for c in token])
        else:
            res.extend(_zh_char_to_phonemes(token))
    return res

## 2. 核心算法 (编辑距离与打分)

In [39]:
@njit
def _fuzzy_substring_distance_numba(main_codes: np.ndarray, sub_codes: np.ndarray) -> float:
    n, m = len(sub_codes), len(main_codes)
    if n == 0: return 0.0
    dp = np.zeros((n + 1, m + 1))
    for i in range(n + 1): dp[i, 0] = i
    for j in range(m + 1): dp[0, j] = 0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            cost = 0.0 if sub_codes[i-1] == main_codes[j-1] else 1.0
            dp[i, j] = min(dp[i-1, j] + 1.0, dp[i, j-1] + 1.0, dp[i-1, j-1] + cost)
    return np.min(dp[n, :])

def fast_substring_score(hw_compare: List[Tuple], sub_seg: List[Tuple]) -> float:
    n, m = len(hw_compare), len(sub_seg)
    if n == 0: return 0.0
    hw_vals = [x[0] for x in hw_compare]
    sub_vals = [x[0] for x in sub_seg]
    dp = [[0.0]*(m+1) for _ in range(n+1)]
    for i in range(n+1): dp[i][0] = i
    for j in range(m+1): dp[0][j] = 0
    for i in range(1, n+1):
        for j in range(1, m+1):
            cost = 0.0 if hw_vals[i-1] == sub_vals[j-1] else 1.0
            dp[i][j] = min(dp[i-1][j]+1, dp[i][j-1]+1, dp[i-1][j-1]+cost)
    min_dist = min(dp[n])
    return 1.0 - (min_dist / n)

def fuzzy_substring_distance(main_seq: List[Phoneme], sub_seq: List[Phoneme]) -> float:
    m_vals = np.array([hash(p.value) % 1000000 for p in main_seq], dtype=np.int32)
    s_vals = np.array([hash(p.value) % 1000000 for p in sub_seq], dtype=np.int32)
    return _fuzzy_substring_distance_numba(m_vals, s_vals)

def extract_diff_fragments(wrong: str, right: str) -> List[str]:
    # 简单版的差异提取实现
    s1 = split_mixed_label(normalize_text(wrong))
    s2 = split_mixed_label(normalize_text(right))
    matcher = SequenceMatcher(None, s1, s2)
    fragments = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag in ('replace', 'delete'):
            fragments.append(" ".join(s1[i1:i2]))
        if tag in ('replace', 'insert'):
            fragments.append(" ".join(s2[j1:j2]))
    return [f for f in fragments if f.strip()]

## 3. RAG 检索器与纠错器

In [40]:
class PhonemeIndex:
    def __init__(self):
        self.index = defaultdict(list)
    def add(self, hw: str, phonemes: List[Phoneme]):
        if not phonemes: return
        key = phonemes[0].value
        self.index[key].append((hw, [p.value for p in phonemes]))
    def get_candidates(self, input_phonemes: List[Phoneme]):
        seen_keys = {p.value for p in input_phonemes}
        candidates = []
        for k in seen_keys:
            candidates.extend(self.index.get(k, []))
        return candidates

class FastRAG:
    def __init__(self, threshold=0.6):
        self.threshold = threshold
        self.index = PhonemeIndex()
        self.hotwords_data = {} 
    def add_hotword(self, hw: str):
        phons = get_phoneme_info(hw)
        self.index.add(hw, phons)
        self.hotwords_data[hw] = phons
    def search(self, input_phonemes: List[Phoneme]):
        input_vals = np.array([hash(p.value) % 1000000 for p in input_phonemes], dtype=np.int32)
        candidates = self.index.get_candidates(input_phonemes)
        results = []
        for hw, hw_phone_vals in candidates:
            hw_np = np.array([hash(v) % 1000000 for v in hw_phone_vals], dtype=np.int32)
            dist = _fuzzy_substring_distance_numba(input_vals, hw_np)
            score = 1.0 - (dist / len(hw_phone_vals))
            if score >= self.threshold: results.append((hw, score))
        return sorted(results, key=lambda x: x[1], reverse=True)

class PhonemeCorrector:
    def __init__(self, threshold=0.7):
        self.rag = FastRAG(threshold=threshold - 0.1)
        self.threshold = threshold
    def update_hotwords(self, hotwords_text: str):
        """支持多行字符串格式的热词加载"""
        for line in hotwords_text.split('\n'):
            hw = line.strip()
            if hw and not hw.startswith('#'):
                self.rag.add_hotword(hw)
    def load_hotwords_file(self, file_path: str):
        path = Path(file_path)
        if path.exists():
            self.update_hotwords(path.read_text(encoding='utf-8'))
    def correct(self, text: str):
        input_phons = get_phoneme_info(text)
        candidates = self.rag.search(input_phons)
        input_processed = [p.info for p in input_phons]
        matches = []
        for hw, _ in candidates:
            hw_phons = self.rag.hotwords_data[hw]
            hw_compare = [p.info[:5] for p in hw_phons]
            target_len = len(hw_compare)
            for i in range(len(input_processed) - target_len + 1):
                sub = input_processed[i:i+target_len]
                if sub[0][1] != 'en' and sub[0][0] != hw_compare[0][0]: continue
                score = fast_substring_score(hw_compare, sub)
                if score >= self.threshold:
                    matches.append({'start': sub[0][5], 'end': sub[-1][6], 'hw': hw, 'score': score})
        matches.sort(key=lambda x: x['score'], reverse=True)
        result_text = list(text)
        used = [False] * len(text)
        final_matched = []
        for m in matches:
            if any(used[m['start']:m['end']]): continue
            for i in range(m['start'], m['end']): used[i] = True
            result_text[m['start']:m['end']] = list(m['hw'])
            final_matched.append((m['hw'], m['score']))
        return "".join(result_text), final_matched

## 4. Rectification (纠错历史检索)

In [41]:
class RectificationRAG:
    def __init__(self, threshold=0.5):
        self.records = []
        self.threshold = threshold
    def add_history(self, wrong: str, right: str):
        fragments = extract_diff_fragments(wrong, right)
        if not fragments: fragments = [wrong]
        record = {
            'wrong': wrong, 
            'right': right, 
            'fragment_phonemes': {f: get_phoneme_seq(f) for f in fragments}
        }
        self.records.append(record)
    def load_rectify_text(self, text: str):
        """支持 --- 分隔的项目纠错历史格式"""
        blocks = text.split('---')
        for block in blocks:
            lines = [l.strip() for l in block.strip().split('\n') if l.strip() and not l.strip().startswith('#')]
            if len(lines) >= 2:
                self.add_history(lines[0], lines[1])
    def load_rectify_file(self, file_path: str):
        path = Path(file_path)
        if path.exists():
            self.load_rectify_text(path.read_text(encoding='utf-8'))
    def search(self, text: str):
        input_phons = get_phoneme_seq(text)
        results = []
        for rec in self.records:
            best_score = 0.0
            for f, f_phons in rec['fragment_phonemes'].items():
                if not f_phons: continue
                score = 1.0 - (fuzzy_substring_distance(input_phons, f_phons) / len(f_phons))
                best_score = max(best_score, score)
            if best_score >= self.threshold: results.append((rec['wrong'], rec['right'], best_score))
        return sorted(results, key=lambda x: x[2], reverse=True)

## 5. 综合演示 (Demo)

In [42]:
hotwords_data = """
    Claude
    Bilibili
    Microsoft
    买当劳
    肯德基
    # 这是一个注释
    VsCode
"""
    
rectify_data = """
# 纠错历史演示
把那个锯子给我
把那个句子给我
---
cloud code is good
Claude Code is good
---
今天天其不错
今天天气不错
"""



In [None]:
# 初始化纠错器和检索器
corrector = PhonemeCorrector(threshold=0.8)
rectifier = RectificationRAG(threshold=0.5)

# 从字符串加载热词
corrector.update_hotwords(hotwords_data)
rectifier.load_rectify_text(rectify_data)

# 从文本文件加载热词
# corrector.load_hotwords_file("hot.txt")
# rectifier.load_rectify_file("hot-rectify.txt")

In [44]:
test_cases_text = """
我想去吃买当劳和肯得鸡
Hello klaude
喜欢刷Bili Bili
请把那个锯子发给我一下
今天天及真的很好
I think claud code is very good
使用vs code编写代码
"""

cases = [l.strip() for l in test_cases_text.strip().split('\n') if l.strip()]


In [45]:


print("\n" + "="*40)
print("【综合纠错与检索演示】")
print("="*40)

for i, t in enumerate(cases):
    print(f"\nCase {i+1}: '{t}'")
    
    # 步骤 1: 拼音纠错
    res, matched = corrector.correct(t)
    print(f"  [纠错后] {res}")
    if matched: print(f"  [匹配热词] {matched}")
    
    # 步骤 2: RAG 历史检索
    rag_results = rectifier.search(t)
    if rag_results:
        print(f"  [RAG 历史匹配]")
        for w, r, s in rag_results:
            print(f"    - '{w}' => '{r}' (相似度: {s:.3f})")



【综合纠错与检索演示】

Case 1: '我想去吃买当劳和肯得鸡'
  [纠错后] 我想去吃买当劳和肯德基
  [匹配热词] [('肯德基', 1.0), ('买当劳', 1.0)]
  [RAG 历史匹配]
    - '把那个锯子给我' => '把那个句子给我' (相似度: 0.667)
    - '今天天其不错' => '今天天气不错' (相似度: 0.667)

Case 2: 'Hello klaude'
  [纠错后] Hello klaude
  [RAG 历史匹配]
    - 'cloud code is good' => 'Claude Code is good' (相似度: 0.833)

Case 3: '喜欢刷Bili Bili'
  [纠错后] 喜欢刷Bilibili
  [匹配热词] [('Bilibili', 1.0)]

Case 4: '请把那个锯子发给我一下'
  [纠错后] 请把那个锯子发给我一下
  [RAG 历史匹配]
    - '把那个锯子给我' => '把那个句子给我' (相似度: 1.000)

Case 5: '今天天及真的很好'
  [纠错后] 今天天及真的很好
  [RAG 历史匹配]
    - '今天天其不错' => '今天天气不错' (相似度: 0.667)

Case 6: 'I think claud code is very good'
  [纠错后] I Claude code is very good
  [匹配热词] [('Claude', 0.8333333333333334)]
  [RAG 历史匹配]
    - 'cloud code is good' => 'Claude Code is good' (相似度: 0.833)

Case 7: '使用vs code编写代码'
  [纠错后] 使用VsCode
  [匹配热词] [('VsCode', 1.0)]
  [RAG 历史匹配]
    - 'cloud code is good' => 'Claude Code is good' (相似度: 0.600)
