In [1]:
from fast_ctc_decode import beam_search, viterbi_search
import numpy as np

alphabet = "NACGT"
posteriors = np.random.rand(100, len(alphabet)).astype(np.float32)

In [2]:
posteriors

array([[6.51403725e-01, 1.61661133e-01, 7.16122389e-01, 6.76492035e-01,
        7.12646186e-01],
       [2.22872019e-01, 7.99192190e-01, 3.48819584e-01, 4.67524767e-01,
        5.25442898e-01],
       [9.26956636e-05, 8.37850153e-01, 4.44912702e-01, 4.39322054e-01,
        5.44991255e-01],
       [5.08055687e-01, 9.15716112e-01, 8.36381197e-01, 6.45499155e-02,
        1.87459681e-02],
       [6.58738732e-01, 3.08092177e-01, 9.05939281e-01, 2.87789077e-01,
        7.01891243e-01],
       [6.45718258e-03, 1.91067383e-01, 3.46876889e-01, 1.67837799e-01,
        1.50810614e-01],
       [3.03525120e-01, 5.69677413e-01, 1.59781054e-01, 9.50392783e-01,
        3.75818610e-02],
       [4.19730663e-01, 1.86281219e-01, 5.05963445e-01, 4.80880558e-01,
        4.10791814e-01],
       [5.03761768e-01, 4.12014693e-01, 1.53879330e-01, 6.86551988e-01,
        5.07341902e-05],
       [3.04291211e-02, 8.89865234e-02, 1.66581005e-01, 6.24630041e-02,
        7.03654826e-01],
       [4.64872211e-01, 5.5318

In [3]:
seq, path = viterbi_search(posteriors, alphabet)

In [4]:
print (seq)
print (path)

CACGCGTTGCATGTTACTGTACATCGCCGTCAGATACGAGATCTCTCAAACAGCTCTCGTAGAGAC
[0, 1, 4, 6, 7, 8, 9, 13, 14, 15, 17, 18, 19, 21, 23, 24, 26, 27, 28, 29, 31, 33, 35, 36, 38, 39, 40, 42, 44, 45, 47, 49, 50, 51, 52, 54, 55, 56, 58, 60, 61, 62, 63, 65, 67, 68, 70, 71, 73, 75, 76, 77, 79, 80, 82, 83, 85, 86, 87, 88, 89, 90, 92, 94, 95, 96]


In [5]:
seq, path = beam_search(posteriors, alphabet, beam_size=5, beam_cut_threshold=0.1)
seq

'CACGTGTCTAGATAGATCTCTCACTAGAGATCTAGACTATGAGAGATC'

In [6]:
import numpy as np

def viterbi_alignment(ctc_probs, text, blank=0):
    B, T, N = ctc_probs.shape  # Batch size, time steps, vocab size
    U = text.shape[1]  # Target sequence length
    
    alignments = []
    for b in range(B):  # Process each batch individually
        # Step 1: Extend target sequence with blanks
        ext_text = [blank]
        for token in text[b]:
            ext_text.append(token)
            ext_text.append(blank)
        ext_U = len(ext_text)

        # Step 2: Initialize DP table
        dp = np.full((T, ext_U), -np.inf)
        dp[0][0] = np.log(ctc_probs[b, 0, blank])
        dp[0][1] = np.log(ctc_probs[b, 0, text[b, 0]])

        # Step 3: Fill DP table
        for t in range(1, T):
            for s in range(ext_U):
                stay = dp[t - 1][s]
                move = dp[t - 1][s - 1] if s > 0 else -np.inf
                skip = dp[t - 1][s - 2] if s > 1 and ext_text[s] != blank and ext_text[s] != ext_text[s - 2] else -np.inf
                
                dp[t][s] = np.log(ctc_probs[b, t, ext_text[s]]) + max(stay, move, skip)

        # Step 4: Backtrace to find alignment
        alignment = []
        s = np.argmax(dp[T - 1])  # Start from the best end state
        for t in range(T - 1, -1, -1):
            alignment.append(ext_text[s])
            if s > 0 and dp[t][s] == dp[t - 1][s - 1] + np.log(ctc_probs[b, t, ext_text[s]]):
                s -= 1
            elif s > 1 and dp[t][s] == dp[t - 1][s - 2] + np.log(ctc_probs[b, t, ext_text[s]]):
                s -= 2
        alignments.append(alignment[::-1])  # Reverse alignment

    return alignments

In [7]:
import torch
import torch.nn.functional as F

def viterbi_alignment_torch(ctc_probs, text, blank=0):
    """
    Viterbi alignment for CTC probabilities using PyTorch.

    Args:
        ctc_probs (torch.Tensor): Shape [B, T, N], CTC probabilities (already log-softmaxed).
        text (torch.Tensor): Shape [B, U], target sequences.
        blank (int): Index of the blank token in the vocabulary.

    Returns:
        List[List[int]]: A list of alignment paths for each batch.
    """
    B, T, N = ctc_probs.size()  # Batch size, time steps, vocab size
    U = text.size(1)  # Target sequence length

    alignments = []

    for b in range(B):  # Process each batch individually
        # Step 1: Extend target sequence with blanks
        ext_text = [blank]
        for token in text[b]:
            ext_text.append(token.item())
            ext_text.append(blank)
        ext_U = len(ext_text)

        # Step 2: Initialize DP table
        dp = torch.full((T, ext_U), -float('inf'), device=ctc_probs.device)  # DP table
        dp[0, 0] = ctc_probs[b, 0, blank]  # Start with blank
        if ext_U > 1:
            dp[0, 1] = ctc_probs[b, 0, text[b, 0]]  # First token

        # Step 3: Fill DP table
        for t in range(1, T):
            for s in range(ext_U):
                # Calculate scores for stay, move, and skip transitions
                stay = dp[t - 1, s]
                move = dp[t - 1, s - 1] if s > 0 else -float('inf')
                skip = dp[t - 1, s - 2] if s > 1 and ext_text[s] != blank and ext_text[s] != ext_text[s - 2] else -float('inf')

                dp[t, s] = ctc_probs[b, t, ext_text[s]] + torch.logsumexp(torch.tensor([stay, move, skip], device=ctc_probs.device), dim=0)

        # Step 4: Backtrace to find alignment
        alignment = []
        s = torch.argmax(dp[T - 1]).item()  # Start from the best end state
        for t in range(T - 1, -1, -1):
            alignment.append(ext_text[s])
            if s > 0 and dp[t, s] == dp[t - 1, s - 1] + ctc_probs[b, t, ext_text[s]]:
                s -= 1
            elif s > 1 and dp[t, s] == dp[t - 1, s - 2] + ctc_probs[b, t, ext_text[s]]:
                s -= 2
        alignments.append(alignment[::-1])  # Reverse alignment

    return alignments

In [8]:
text = np.array([[1, 2], [3, 4]], dtype=np.int32)
ctc_probs = np.random.rand(2, 10, 10).astype(np.float32)
text_tensor = torch.tensor(text)
ctc_probs_tensor = torch.tensor(ctc_probs)
print(viterbi_alignment_torch(ctc_probs_tensor, text_tensor))
print(viterbi_alignment(ctc_probs, text))

[[1, 2, 0, 0, 0, 0, 0, 0, 0, 0], [3, 4, 0, 0, 0, 0, 0, 0, 0, 0]]
[[np.int32(1), 0, 0, 0, np.int32(2), np.int32(2), np.int32(2), np.int32(2), np.int32(2), 0], [0, 0, 0, 0, np.int32(3), np.int32(3), np.int32(3), 0, np.int32(4), np.int32(4)]]


In [9]:
viterbi_alignment(ctc_probs, text)

[[np.int32(1),
  0,
  0,
  0,
  np.int32(2),
  np.int32(2),
  np.int32(2),
  np.int32(2),
  np.int32(2),
  0],
 [0,
  0,
  0,
  0,
  np.int32(3),
  np.int32(3),
  np.int32(3),
  0,
  np.int32(4),
  np.int32(4)]]

In [10]:
import torch
import torch.nn.functional as F

def ctc_viterbi_align(ctc_probs, texts, blank_id=0, pad_id=-1):
    """
    ctc_probs: Tensor [B, T, N]，CTC 后验概率（对数概率）。
    texts: Tensor [B, U]，经过填充的目标序列。
    返回：
    alignments: List，每个元素是长度为 T 的对齐路径（列表形式）。
    """
    B, T, N = ctc_probs.size()
    _, U = texts.size()
    alignments = []

    # 对于批次中的每个样本
    for b in range(B):
        # 1. 获取实际的目标序列，移除填充
        text = texts[b]
        text = text[text != pad_id]  # 移除填充
        text = text.cpu().numpy().tolist()
        
        # 2. 扩展目标序列，在每个标签之间和首尾插入 blank
        extended_seq = []
        for token in text:
            extended_seq.append(blank_id)
            extended_seq.append(token)
        extended_seq.append(blank_id)
        # extended_seq: [blank, l1, blank, l2, blank, ..., ln, blank]
        L = len(extended_seq)
        
        # 3. 获取该样本的 CTC 概率（取对数以避免下溢）
        log_probs = ctc_probs[b].log()  # [T, N]
        log_probs = log_probs.cpu().numpy()
        
        # 4. 初始化动态规划矩阵
        dp = -float('inf') * np.ones((T, L), dtype=np.float32)  # [T, L]
        ptr = -np.ones((T, L), dtype=int)  # 用于回溯路径

        # 初始化第一个时间步
        dp[0, 0] = log_probs[0, blank_id]  # blank
        if L > 1:
            dp[0, 1] = log_probs[0, extended_seq[1]]  # 第一个标签

        # 5. 动态规划递推
        for t in range(1, T):
            for s in range(L):
                current_label = extended_seq[s]
                prob = log_probs[t, current_label]

                # 来自 s
                candidates = [dp[t-1, s]]
                # 来自 s-1
                if s > 0:
                    candidates.append(dp[t-1, s-1])
                # 来自 s-2（避免重复标签）
                if s > 1 and extended_seq[s] != blank_id and extended_seq[s] != extended_seq[s-2]:
                    candidates.append(dp[t-1, s-2])
                
                max_prev = max(candidates)
                dp[t, s] = max_prev + prob

                # 记录回溯指针
                if max_prev == dp[t-1, s]:
                    ptr[t, s] = s
                elif s > 0 and max_prev == dp[t-1, s-1]:
                    ptr[t, s] = s - 1
                elif s > 1 and max_prev == dp[t-1, s-2]:
                    ptr[t, s] = s - 2

        # 6. 回溯找到最优路径
        # 从最后一个时间步和状态开始
        s = np.argmax(dp[T-1, :])
        alignment = []
        for t in range(T-1, -1, -1):
            alignment.append(extended_seq[s])
            s = ptr[t, s]
            if s == -1:
                break
        alignment = alignment[::-1]  # 逆序
        # 将对齐结果扩展到长度为 T
        if len(alignment) < T:
            alignment = [blank_id] * (T - len(alignment)) + alignment

        alignments.append(alignment)

    return alignments

In [11]:
text = np.array([[3, 4, 5, 6]], dtype=np.int32)
ctc_probs = np.random.rand(1, 10, 10).astype(np.float32)
text_tensor = torch.tensor(text)
ctc_probs_tensor = torch.tensor(ctc_probs)
print(viterbi_alignment_torch(ctc_probs_tensor, text_tensor))
print( [[int(x) for x in sublist] for sublist in viterbi_alignment(ctc_probs, text)])
print (ctc_viterbi_align(ctc_probs_tensor, text_tensor, blank_id=0, pad_id=-1))


[[3, 4, 5, 6, 6, 6, 6, 6, 6, 6]]
[[3, 3, 3, 0, 4, 5, 0, 6, 6, 6]]
[[3, 3, 3, 0, 4, 5, 0, 6, 6, 6]]


In [12]:
def ctc_viterbi_align(ctc_prob, text_b, blank_id=0, pad_id=-1):
    T = ctc_prob.shape[0]  # 时间步数
    N = len(text_b)  # 目标序列的长度
    dp = np.zeros((T, N + 1))  # dp[t][i] -> 第 t 个时间步，第 i 个目标标签的最大概率
    path = np.zeros((T, N + 1), dtype=int)  # 用于回溯路径
    
    # 初始状态
    dp[0, 0] = ctc_prob[0, blank_id]  # 初始为空白标签的概率
    for i in range(1, N + 1):
        dp[0, i] = ctc_prob[0, text_b[i - 1]]  # 初始化第一个目标标签
    
    # 递归动态规划
    for t in range(1, T):
        for i in range(N + 1):
            if i == 0:
                dp[t, i] = dp[t - 1, 0] * ctc_prob[t, blank_id]  # 继续选择空白标签
                path[t, i] = 0
            else:
                # 路径1：选择前一个标签对应的对齐路径
                score1 = dp[t - 1, i - 1] * ctc_prob[t, text_b[i - 1]]
                # 路径2：选择空白标签
                score2 = dp[t - 1, i] * ctc_prob[t, blank_id]
                if score1 > score2:
                    dp[t, i] = score1
                    path[t, i] = i - 1
                else:
                    dp[t, i] = score2
                    path[t, i] = i
    
    # 回溯路径
    best_path = []
    t = T - 1
    i = N
    while t >= 0 and i >= 0:
        best_path.append(text_b[i - 1] if i > 0 else blank_id)
        i = path[t, i]
        t -= 1
    
    best_path = best_path[::-1]  # 反转路径
    return best_path

In [13]:
ctc_viterbi_align(ctc_probs[0], text[0])

[np.int32(3),
 np.int32(3),
 np.int32(4),
 np.int32(4),
 np.int32(4),
 np.int32(5),
 np.int32(5),
 np.int32(5),
 np.int32(6),
 np.int32(6)]

In [14]:
def insert_blank(label, blank_id=0):
    """Insert blank token between every two label token."""
    label = np.expand_dims(label, 1)
    blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id
    label = np.concatenate([blanks, label], axis=1)
    label = label.reshape(-1)
    label = np.append(label, label[0])
    return label


def force_align(ctc_probs: torch.Tensor, y: torch.Tensor, blank_id=0) -> list:
    """ctc forced alignment.

    Args:
        torch.Tensor ctc_probs: hidden state sequence, 2d tensor (T, D)
        torch.Tensor y: id sequence tensor 1d tensor (L)
        int blank_id: blank symbol index
    Returns:
        torch.Tensor: alignment result
    """
    ctc_probs = ctc_probs.cpu()
    y = y.cpu()
    y_insert_blank = insert_blank(y, blank_id)

    log_alpha = torch.zeros((ctc_probs.size(0), len(y_insert_blank)))
    log_alpha = log_alpha - float('inf')  # log of zero
    state_path = torch.zeros((ctc_probs.size(0), len(y_insert_blank)),
                             dtype=torch.int16) - 1  # state path

    # init start state
    log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]]
    log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]]

    for t in range(1, ctc_probs.size(0)):
        for s in range(len(y_insert_blank)):
            if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[
                    s] == y_insert_blank[s - 2]:
                candidates = torch.tensor(
                    [log_alpha[t - 1, s], log_alpha[t - 1, s - 1]])
                prev_state = [s, s - 1]
            else:
                candidates = torch.tensor([
                    log_alpha[t - 1, s],
                    log_alpha[t - 1, s - 1],
                    log_alpha[t - 1, s - 2],
                ])
                prev_state = [s, s - 1, s - 2]
            log_alpha[
                t, s] = torch.max(candidates) + ctc_probs[t][y_insert_blank[s]]
            state_path[t, s] = prev_state[torch.argmax(candidates)]

    state_seq = -1 * torch.ones((ctc_probs.size(0), 1), dtype=torch.int16)

    candidates = torch.tensor([
        log_alpha[-1, len(y_insert_blank) - 1],
        log_alpha[-1, len(y_insert_blank) - 2]
    ])
    final_state = [len(y_insert_blank) - 1, len(y_insert_blank) - 2]
    state_seq[-1] = final_state[torch.argmax(candidates)]
    for t in range(ctc_probs.size(0) - 2, -1, -1):
        state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]]

    output_alignment = []
    for t in range(0, ctc_probs.size(0)):
        output_alignment.append(y_insert_blank[state_seq[t, 0]])

    return output_alignment

In [25]:
text = np.array([[3, 4, 5, 6, 9]], dtype=np.int32)
ctc_probs = np.random.rand(1, 20, 10).astype(np.float32)
ctc_probs_tensor = torch.softmax(torch.tensor(ctc_probs), dim=-1)
ctc_probs = ctc_probs_tensor.numpy()
text_tensor = torch.tensor(text)
ctc_probs_tensor = torch.tensor(ctc_probs)
print(viterbi_alignment_torch(ctc_probs_tensor, text_tensor))
print( [[int(x) for x in sublist] for sublist in viterbi_alignment(ctc_probs, text)])
print ([int(x) for x in ctc_viterbi_align(ctc_probs[0], text[0])])
print ([int(x) for x in force_align(ctc_probs_tensor[0], text_tensor[0])])

[[3, 4, 5, 6, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
[[3, 3, 3, 0, 4, 0, 0, 0, 0, 5, 5, 5, 5, 5, 0, 6, 6, 6, 9, 9]]
[0, 0, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 6, 6, 6, 9, 9]
[3, 3, 3, 0, 4, 0, 0, 0, 0, 5, 5, 5, 5, 5, 0, 6, 6, 6, 9, 9]


In [26]:
ctc_prob = torch.load("/ssd/zhuang/code/FunASR/funasr/models/paraformerV2/ctc_prob")
text_b = torch.load("/ssd/zhuang/code/FunASR/funasr/models/paraformerV2/text_b")

  ctc_prob = torch.load("/ssd/zhuang/code/FunASR/funasr/models/paraformerV2/ctc_prob")
  text_b = torch.load("/ssd/zhuang/code/FunASR/funasr/models/paraformerV2/text_b")


In [27]:
ctc_prob

tensor([[-8.0503, -8.2395, -9.9503,  ..., -7.9820, -7.8188, -8.6158],
        [-7.4608, -9.0117, -9.1419,  ..., -7.0696, -7.5999, -8.3954],
        [-7.8047, -8.6365, -9.6139,  ..., -7.0698, -7.8144, -8.3241],
        ...,
        [-8.3635, -9.0758, -9.6072,  ..., -8.0879, -9.0771, -8.6355],
        [-8.2764, -8.7710, -9.4807,  ..., -8.2481, -8.0597, -8.5509],
        [-7.8216, -9.5307, -9.6108,  ..., -8.0596, -8.9597, -8.7208]])

In [28]:
print ([int(x) for x in force_align(ctc_prob, text_b)])

[1127, 0, 0, 1749, 0, 433, 37, 0, 0, 3066, 0, 0, 0, 814, 0, 0, 0, 0, 1380, 37, 37, 4068, 4068, 4068, 4068, 4068, 4068, 4068, 4068, 4068, 4068, 4068, 4068, 4068, 4068, 4068, 4068, 4068, 4068, 4068, 4068, 4068, 4068, 4068, 4068, 4068, 0, 1015, 1276, 2554, 597, 2342, 2342, 2342, 2342, 2342, 2342, 2342, 2342, 2342, 2342, 2342, 2342, 2342, 2342, 2342, 2342, 2342, 0, 0, 1127, 1749, 433, 37, 37, 3066, 3066, 814, 0, 0, 0, 0, 0, 1380, 1380, 1380, 1380, 0, 0, 37, 4068, 4068, 4068, 4068, 4068, 4068, 4068, 4068, 4068, 1015, 1276, 2554, 597, 2342, 2342, 2342, 2342, 2342, 2342, 2342, 2342, 0, 0]


In [35]:
import torch

def average_repeats(ctc_prob, alignment):
    """
    Averages the repeated frames based on alignment without merging distinct occurrences of the same token.

    Args:
        ctc_prob (torch.Tensor): Tensor of shape [T, VocabSize + 1] representing frame-wise CTC posteriors.
        alignment (torch.Tensor): Tensor of shape [T,] representing the target alignment from Viterbi algorithm.

    Returns:
        torch.Tensor: Compressed CTC posterior with repeated frames averaged.
    """
    unique_probs = []
    current_sum = ctc_prob[0]
    current_count = 1

    for t in range(1, alignment.size(0)):
        token = alignment[t].item()
        prev_token = alignment[t - 1].item()
        prob = ctc_prob[t]

        if token == prev_token:
            current_sum += prob
            current_count += 1
        else:
            unique_probs.append(current_sum / current_count)
            current_sum = prob
            current_count = 1

    # Append the last averaged probability
    unique_probs.append(current_sum / current_count)

    return torch.stack(unique_probs)

def remove_blanks(ctc_prob, alignment):
    """
    Removes blank tokens from the alignment and returns the corresponding CTC probabilities.

    Args:
        ctc_prob (torch.Tensor): Tensor of shape [U', VocabSize + 1] representing compressed CTC posteriors.
        alignment (torch.Tensor): Tensor of shape [T,] representing the target alignment from Viterbi algorithm.

    Returns:
        torch.Tensor: Compressed CTC posterior with blanks removed.
    """
    non_blank_probs = []
    non_blank_indices = []
    idx = 0

    for t in range(alignment.size(0)):
        token = alignment[t].item()
        if token != 0:  # 0 is assumed to be the blank token
            non_blank_indices.append(idx)
        if t == alignment.size(0) - 1 or alignment[t] != alignment[t + 1]:
            idx += 1

    for idx in non_blank_indices:
        non_blank_probs.append(ctc_prob[idx])

    return torch.stack(non_blank_probs)

# Example usage
ctc_prob = torch.randn(15, 100)  # Assume 15 frames, 100 vocabulary size + 1 blank
alignment = torch.tensor([3797, 11, 3727, 3143, 72, 71, 4009, 4009, 4009, 1150, 2554, 3015, 1338, 339, 1452])

compressed_ctc_prob = average_repeats(ctc_prob, alignment)
non_blank_ctc_prob = remove_blanks(compressed_ctc_prob, alignment)
print(non_blank_ctc_prob.size())

torch.Size([15, 100])
