# 构建语音识别系统 - 解码与评测

## 如何从模型输出到识别文本

模型输出的结果形状为 (b, len, vocab_size)，分别对应批大小、帧数（时间步数）和词表大小。

### greedy search

每一步选取预测概率最大的词

In [1]:
import torch
def greedy_search(ctc_probs: torch.Tensor, encoder_out_lens: torch.Tensor) -> list:
    """Pick the highest-scoring token id at every frame (greedy decoding).

    No CTC post-processing happens here: repeated ids and blank ids are kept
    exactly as predicted.

    Args:
        ctc_probs: Per-frame scores, indexed as (batch, frame, vocab).
        encoder_out_lens: Number of valid frames for each utterance; it is
            flattened to 1-D before use.

    Returns:
        One list of token ids per entry of ``encoder_out_lens``, each
        truncated to that utterance's valid length.
    """
    # argmax over the vocabulary axis directly yields a (batch, frame) tensor
    # of ids, so no top-k call or reshape is needed.
    best_ids = ctc_probs.argmax(dim=2)
    valid_lens = encoder_out_lens.view(-1).tolist()

    # Drop the padded frames of every utterance.
    return [best_ids[i, :n].tolist() for i, n in enumerate(valid_lens)]

In [2]:
# Load the saved example batch straight onto the CPU. map_location keeps this
# cell working on CPU-only machines even if the tensors were saved from a GPU.
tensordict = torch.load("./example2.pt", map_location="cpu")

# NOTE(review): "pre" is presumably the model output of shape
# (batch, frame, vocab) and "lens" the valid frame count per utterance —
# confirm against the script that produced example2.pt.
pre = tensordict["pre"].to("cpu")
lens = tensordict["lens"].to("cpu")

print(lens)
# Greedy decoding only; repeats and blanks are still present in `res`.
res = greedy_search(pre, lens)

tensor([46, 51, 44, 44, 41, 49, 48, 48, 74, 93, 44, 49, 50, 51, 58, 50])


In [3]:
# Build the tokenizer from the vocabulary file so token ids can be mapped back to tokens.
from tokenizer.tokenizer import Tokenizer
tokenizer = Tokenizer("./tokenizer/vocab.txt")

# Raw greedy-search ids for the utterance at index 1 of the batch.
print(res[1])

# The same ids decoded with ignore_special=False, so special tokens stay visible.
print(tokenizer.decode(res[1], ignore_special=False))

[2, 2, 5, 323, 5, 296, 5, 5, 75, 5, 243, 278, 278, 5, 394, 5, 5, 5, 51, 5, 5, 247, 5, 5, 360, 5, 364, 5, 5, 57, 5, 5, 5, 238, 122, 5, 65, 5, 5, 167, 271, 5, 5, 142, 5, 68, 5, 5, 5, 3, 3]
['<sos>', '<sos>', '<blk>', 'liao', '<blk>', 'dian', '<blk>', '<blk>', 'pao', '<blk>', 'si', 'ji', 'ji', '<blk>', 'kong', '<blk>', '<blk>', '<blk>', 'cuan', '<blk>', '<blk>', 'hen', '<blk>', '<blk>', 'zong', '<blk>', 'shua', '<blk>', '<blk>', 'qie', '<blk>', '<blk>', '<blk>', 'nian', 'tian', '<blk>', 'wai', '<blk>', '<blk>', 'zhou', 'du', '<blk>', '<blk>', 'jiao', '<blk>', 'die', '<blk>', '<blk>', '<blk>', '<eos>', '<eos>']


**TODO：** 请大家根据CTC的解码思路，将模型的输出进行解码，移除重复字符和blank(上面的版本没有移除重复字符和blank)。

In [4]:
def ctc_decode(hyps, blank_id):
    """Collapse frame-level CTC predictions into token sequences.

    Follows the CTC rule: first merge runs of identical consecutive ids, then
    drop blanks. A blank between two identical ids therefore separates them,
    so ``[a, blank, a]`` decodes to ``[a, a]`` rather than ``[a]``.

    Args:
        hyps: List of per-utterance frame-level id sequences.
        blank_id: Id of the CTC blank symbol.

    Returns:
        A list with one decoded id list per input sequence.
    """
    decoded_hyps = []

    for hyp in hyps:
        decoded = []
        # None can never equal a real id, so the first frame is always
        # treated as the start of a new run.
        prev = None

        for token_id in hyp:
            if token_id != prev and token_id != blank_id:
                decoded.append(token_id)
            # Track every frame, blanks included: a blank must end the
            # current run so a later identical id counts as a new token.
            prev = token_id

        decoded_hyps.append(decoded)

    return decoded_hyps

In [5]:
# Collapse the greedy-search output with ctc_decode, using the tokenizer's blank id.
blank_id = tokenizer.blk_id()
decoded_res = ctc_decode(res, blank_id)

# Utterance 1 before CTC decoding: raw ids, then tokens with special symbols kept.
print("原始预测：")
print(res[1])
print(tokenizer.decode(res[1], ignore_special=False))

# The same utterance after CTC decoding, for side-by-side comparison.
print("\nCTC解码：")
print(decoded_res[1])
print(tokenizer.decode(decoded_res[1], ignore_special=False))

原始预测：
[2, 2, 5, 323, 5, 296, 5, 5, 75, 5, 243, 278, 278, 5, 394, 5, 5, 5, 51, 5, 5, 247, 5, 5, 360, 5, 364, 5, 5, 57, 5, 5, 5, 238, 122, 5, 65, 5, 5, 167, 271, 5, 5, 142, 5, 68, 5, 5, 5, 3, 3]
['<sos>', '<sos>', '<blk>', 'liao', '<blk>', 'dian', '<blk>', '<blk>', 'pao', '<blk>', 'si', 'ji', 'ji', '<blk>', 'kong', '<blk>', '<blk>', '<blk>', 'cuan', '<blk>', '<blk>', 'hen', '<blk>', '<blk>', 'zong', '<blk>', 'shua', '<blk>', '<blk>', 'qie', '<blk>', '<blk>', '<blk>', 'nian', 'tian', '<blk>', 'wai', '<blk>', '<blk>', 'zhou', 'du', '<blk>', '<blk>', 'jiao', '<blk>', 'die', '<blk>', '<blk>', '<blk>', '<eos>', '<eos>']

CTC解码：
[2, 323, 296, 75, 243, 278, 394, 51, 247, 360, 364, 57, 238, 122, 65, 167, 271, 142, 68, 3]
['<sos>', 'liao', 'dian', 'pao', 'si', 'ji', 'kong', 'cuan', 'hen', 'zong', 'shua', 'qie', 'nian', 'tian', 'wai', 'zhou', 'du', 'jiao', 'die', '<eos>']


## 评测识别结果

这里我们采用字错率(CER, character error rate)来评测ASR系统的性能，计算公式如下:

$$CER = \frac{S+D+I}{N}$$

pre 代表模型预测，gt 代表正确识别结果。S、D、I 由 gt 与 pre 之间的最小编辑距离对齐得到（以 gt 为参考）：S 代表 gt 中被 pre 替换成其他字符的数量，D 代表 gt 中存在但 pre 中缺失（被漏掉，即删除）的数量，I 代表 pre 中多出而 gt 中不存在（即插入）的数量，N 代表 gt 的长度。


**TODO：** 根据最小编辑距离求出 S，D，I，N ，完成ASR的CER指标评测

In [10]:
from data.dataloader import get_dataloader
from model.model import CTCModel
import torch
from utils.utils import to_device
from tqdm import tqdm

# Dev-set dataloader built from the wav.scp and pinyin files, with shuffling disabled.
# NOTE(review): the positional 32 is presumably the batch size — confirm against get_dataloader.
dev_dataloader = get_dataloader("./dataset/split/dev/wav.scp", "./dataset/split/dev/pinyin", 32, tokenizer, shuffle=False)

In [11]:
def edit_distance(ref, hyp):
    """Return the edit distance between two sequences with a per-type breakdown.

    Errors are counted relative to the reference: a deletion is a reference
    item with no counterpart in the hypothesis, an insertion is a hypothesis
    item with no counterpart in the reference.

    Args:
        ref: Reference sequence (ground truth).
        hyp: Hypothesis sequence (recognition result).

    Returns:
        A tuple ``(distance, substitutions, deletions, insertions)``.
    """
    ref_len, hyp_len = len(ref), len(hyp)

    # cost[r][h] is the edit distance between ref[:r] and hyp[:h].
    cost = [[0] * (hyp_len + 1) for _ in range(ref_len + 1)]
    for r in range(ref_len + 1):
        cost[r][0] = r
    for h in range(hyp_len + 1):
        cost[0][h] = h

    for r, ref_item in enumerate(ref, start=1):
        above, row = cost[r - 1], cost[r]
        for h, hyp_item in enumerate(hyp, start=1):
            if ref_item == hyp_item:
                row[h] = above[h - 1]
            else:
                # substitution, deletion, insertion — each costs one edit
                row[h] = 1 + min(above[h - 1], above[h], row[h - 1])

    # Walk back from the bottom-right corner to classify each edit. When
    # several moves are optimal the priority is: match, substitution,
    # deletion, insertion.
    subs = dels = ins = 0
    r, h = ref_len, hyp_len
    while r > 0 or h > 0:
        inside = r > 0 and h > 0
        if inside and ref[r - 1] == hyp[h - 1]:
            r, h = r - 1, h - 1
            continue

        here = cost[r][h]
        if inside and here == cost[r - 1][h - 1] + 1:
            subs += 1
            r, h = r - 1, h - 1
        elif r > 0 and here == cost[r - 1][h] + 1:
            dels += 1
            r -= 1
        else:
            ins += 1
            h -= 1

    return cost[ref_len][hyp_len], subs, dels, ins

def calculate_cer(references, hypotheses):
    """Compute the character error rate (CER) over paired sequences.

    References and hypotheses are paired positionally with ``zip``. The CER is
    (substitutions + deletions + insertions) / total reference length, and is
    reported as 1.0 when the total reference length is zero.

    Args:
        references: List of reference sequences.
        hypotheses: List of recognised sequences.

    Returns:
        A tuple ``(cer, substitutions, deletions, insertions, ref_length)``
        where the counts are totals over all pairs.
    """
    sub_total = del_total = ins_total = 0
    ref_tokens = 0

    for reference, hypothesis in zip(references, hypotheses):
        _, n_sub, n_del, n_ins = edit_distance(reference, hypothesis)
        sub_total += n_sub
        del_total += n_del
        ins_total += n_ins
        ref_tokens += len(reference)

    if ref_tokens > 0:
        cer = (sub_total + del_total + ins_total) / ref_tokens
    else:
        cer = 1.0

    return cer, sub_total, del_total, ins_total, ref_tokens

In [13]:
def evaluate_model(dataloader, model, tokenizer, device='cpu'):
    """Decode every batch of ``dataloader`` with ``model`` and report the CER.

    The model is put in eval mode and run without gradients. Each batch is
    decoded with greedy search followed by CTC collapsing, the references are
    cut to their valid lengths, and the CER statistics plus up to five
    reference/prediction pairs are printed.

    Args:
        dataloader: Iterable of batches providing 'audios', 'audio_lens',
            'texts' and 'text_lens'.
        model: Model called as ``model(audios, audio_lens, texts, text_lens)``;
            the first and third returned values feed greedy search.
        tokenizer: Provides ``blk_id()`` and ``decode()``.
        device: Device each batch is moved to before the forward pass.

    Returns:
        The CER over all evaluated utterances.
    """
    model.eval()
    references, predictions = [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="评估中"):
            batch = to_device(batch, device)
            audios, audio_lens = batch['audios'], batch['audio_lens']
            texts, text_lens = batch['texts'], batch['text_lens']

            encoder_out, _, encoder_out_lens = model(audios, audio_lens, texts, text_lens)
            collapsed = ctc_decode(
                greedy_search(encoder_out, encoder_out_lens),
                tokenizer.blk_id(),
            )

            # Keep only the valid (unpadded) part of every reference.
            for row, length in enumerate(text_lens):
                references.append(texts[row, :length].tolist())
                predictions.append(collapsed[row])

    # Aggregate error counts over the whole set.
    cer, subs, dels, ins, ref_len = calculate_cer(references, predictions)

    print("评测结果:")
    print(f"替换(S): {subs}, 删除(D): {dels}, 插入(I): {ins}, 参考长度(N): {ref_len}")
    print(f"CER: {cer:.4f} ({subs+dels+ins}/{ref_len})")

    # Show the first few pairs for a quick visual check.
    print("\n样本对比:")
    for ref_ids, hyp_ids in zip(references[:5], predictions[:5]):
        print(f"参考: {tokenizer.decode(ref_ids)}")
        print(f"预测: {tokenizer.decode(hyp_ids)}")
        print()

    return cer

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): 80 and 256 are presumably the input feature dimension and the
# model width — confirm against CTCModel's constructor.
model = CTCModel(80, 256, tokenizer.size(), tokenizer.blk_id()).to(device)

# Load the checkpoint onto the chosen device and restore the weights stored under its 'model' key.
checkpoint = torch.load("./model.pt", map_location=device)
model.load_state_dict(checkpoint['model'])

# Evaluate on the dev set; evaluate_model prints the statistics and returns the CER.
cer = evaluate_model(dev_dataloader, model, tokenizer, device)

评估中: 100%|██████████| 32/32 [00:02<00:00, 12.13it/s]


评测结果:
替换(S): 1672, 删除(D): 445, 插入(I): 76, 参考长度(N): 19252
CER: 0.1139 (2193/19252)

样本对比:
参考: ['yi', 'ge', 'nan', 'ren', 'tui', 'ran', 'de', 'zuo', 'zai', 'pang', 'bian', 'mu', 'guang', 'dai', 'zhi']
预测: ['yi', 'gen', 'nan', 'ren', 'tui', 'ran', 'de', 'zuo', 'zai', 'pang', 'bian', 'mu', 'guan', 'dai', 'zhi']

参考: ['xi', 'huan', 'ba', 'li', 'ao', 'de', 'shu', 'cha', 'zai', 'niu', 'zai', 'ku', 'de', 'qian', 'mian']
预测: ['xi', 'huan', 'ba', 'li', 'de', 'shu', 'cha', 'zai', 'niu', 'zai', 'ku', 'de', 'qian', 'mian']

参考: ['zha', 'yan', 'yi', 'kan', 'xiang', 'qi', 'de', 'shi', 'guang', 'zhou', 'de', 'qu', 'hao', 'ling', 'e', 'er', 'ling']
预测: ['zhan', 'yan', 'yi', 'kan', 'xiang', 'qi', 'de', 'shi', 'guang', 'zhou', 'de', 'xu', 'hao', 'liu', 'e', 'er', 'ling']

参考: ['ci', 'qian', 'qing', 'hua', 'zi', 'guang', 'jiu', 'ceng', 'zao', 'yu', 'guo', 'tong', 'yang', 'de', 'wei', 'ji']
预测: ['ci', 'qian', 'qing', 'hua', 'zi', 'guang', 'jiu', 'cun', 'zao', 'yu', 'guo', 'tong', 'yao', 'de', 'wei', 'ji']
