In [151]:
import torch
import numpy as np

# reduce and combine hidden states

:::{note}
在音素识别时，一个音素可能会跨多个时间帧，因此需要合并连续且相同的音素，构造为一个音素，从而预测整句话的音素，并与ground truth进行比较。
:::

那使用音素特征识别时，那么也许合并连续且相同的音素特征帧。设音素特征为`(B, T, 768)`，那么合并后的特征为`(B, T', 768)`。

In [95]:
def segment_means(tensor, segment_sizes):
    assert tensor.size(0) == segment_sizes.sum(), "Sum of segment sizes must equal the tensor's first dimension size."

    # Create an indices tensor that maps each row in the tensor to its corresponding segment
    indices = torch.repeat_interleave(torch.arange(len(segment_sizes), device=tensor.device), segment_sizes)

    # Create a tensor to hold the sum of each segment
    segment_sums = torch.zeros(len(segment_sizes), tensor.size(1), device=tensor.device)

    # Scatter and sum the inputs into the segment_sums tensor
    segment_sums.scatter_add_(0, indices.unsqueeze(1).expand(-1, tensor.size(1)), tensor)

    # Calculate the mean of each segment
    segment_means = segment_sums / segment_sizes.unsqueeze(1)

    return segment_means

In [139]:
def reduce_feat(hidden_states, audio_lengths, phoneme_ids):
    reduced_hidden_states = []
    reduced_audio_lengths = []
    reduced_phoneme_ids = []

    phoneme_counts = []
    print(hidden_states.shape, audio_lengths.shape, phoneme_ids.shape)

    for i in range(len(audio_lengths)):
        _phoneme_ids = phoneme_ids[i, : audio_lengths[i]]
        _h = hidden_states[i, : audio_lengths[i]]

        unique_ids, _phoneme_counts = _phoneme_ids.unique_consecutive(return_counts=True)
        # segments = torch.split(_h, _phoneme_counts.tolist())
        # _h_reduced = torch.stack([seg.mean(dim=0) for seg in segments])
        phoneme_counts += _phoneme_counts.tolist()

        reduced_audio_lengths.append(len(unique_ids))
        reduced_phoneme_ids.append(unique_ids)

    reduced_audio_lengths = torch.tensor(reduced_audio_lengths)
    reduced_phoneme_ids = torch.nn.utils.rnn.pad_sequence(reduced_phoneme_ids, batch_first=True)
    h = torch.concat([hidden_states[i, :_len, :] for i, _len in enumerate(audio_lengths)], dim=0)
    reduced_hidden_states = segment_means(h, torch.tensor(phoneme_counts, device=hidden_states.device))

    print(reduced_audio_lengths.shape, reduced_phoneme_ids.shape, reduced_hidden_states.shape)
    
    return reduced_hidden_states, reduced_audio_lengths, reduced_phoneme_ids

In [145]:
B = 64
T = 149
hidden_states = torch.randn(B, T, 768).cuda().requires_grad_(True)
phoneme_ids = torch.randint(0, 2, (B, T)).cuda()
audio_lengths = torch.randint(140, 149, (B,))

In [146]:
reduced_hidden_states, reduced_audio_lengths, reduced_phoneme_ids = reduce_feat(
    hidden_states, audio_lengths, phoneme_ids
)
print(reduced_audio_lengths.shape, reduced_phoneme_ids.shape, reduced_hidden_states.shape)

torch.Size([64, 149, 768]) torch.Size([64]) torch.Size([64, 149])
torch.Size([64]) torch.Size([64, 88]) torch.Size([4694, 768])
torch.Size([64]) torch.Size([64, 88]) torch.Size([4694, 768])


# Generate edges

设，一个audio的所有帧（T）对应的phoneme ids是，并给出了audio长度为10：
```python　
predict_ids = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 2, 2])
audio_lengths = [10]
```

那么，建立单项边时，有两个操作：
1. 添加所有邻接边：`0->1, 1->2, 2->3, ...`
2. 对于一个phoneme，对接下来的N个phoneme都建立边：
    - 以第1个node为例（index从0开始），设N=1，那么会新加边：`1->2, 1->3, 1->4`
    - 以第3个node为例（index从0开始），设N=1，那么会新加边：`3->5,6,7,8,9`

### Step 1 

使用arange可以很快地生成所有邻接边。

In [149]:
def get_adj_edges(L: int, use_np=False):
    if use_np:
        adj_edges = np.stack([np.arange(L - 1), np.arange(1, L)])
    else:
        adj_edges = torch.stack([torch.arange(L - 1), torch.arange(1, L)])
    return adj_edges

In [157]:
L = 149
adj_edges = get_adj_edges(L, use_np=0)
adj_edges_np = get_adj_edges(L, use_np=1)
print(adj_edges, adj_edges.numpy()-adj_edges_np)

tensor([[  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
          14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
          28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
          42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
          56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
          70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
          84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
          98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
         112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
         126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,
         140, 141, 142, 143, 144, 145, 146, 147],
        [  1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,  14,
          15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,  28,
  

## Step 2

使用 `unique_consecutive`可以快速地查找所有不同的phoneme id，并定位其index范围。

### 遍历　

下面这种遍历长度L的方法，耗时太长了

#### torch

In [179]:
# predict_ids = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3])
predict_ids = torch.randint(0, 3, (149,))

N = 10

output, inverse, counts = predict_ids.unique_consecutive(return_inverse=True, return_counts=True)
cumsum_counts = torch.cumsum(counts, 0)
# output, inverse, conuts, cumsum_counts

edges = []
for i in range(L):
    unique_id = inverse[i]  # 0, 1, 2, 3,
    unique_id_end_index = cumsum_counts[unique_id]
    if unique_id == len(output) - 1:
        break
    next_id = min(len(output) - 1, unique_id + N)
    next_end_index = cumsum_counts[next_id]
    _edges = torch.stack(
        [torch.full((next_end_index - unique_id_end_index,), i), torch.arange(unique_id_end_index, next_end_index)]
    )
    edges.append(_edges)
edges = torch.concat(edges, 1)

In [160]:
edges

tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  2,  2,  2,  2,  2,  2,  2,  3,  3,  3,  3,  3,  3,  3,  4,  4,
          4,  4,  4,  4,  4,  5,  5,  6,  6,  7,  7,  8,  8,  9,  9],
        [ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11,  2,  3,  4,  5,  6,  7,  8,  9,
         10, 11,  5,  6,  7,  8,  9, 10, 11,  5,  6,  7,  8,  9, 10, 11,  5,  6,
          7,  8,  9, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11]])

#### numpy 

首先，由于numpy中没有unique_consecutive这个函数，因此需要先实现他。

In [165]:
predict_ids = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3])
output, inverse, counts = predict_ids.unique_consecutive(return_inverse=True, return_counts=True)
output, inverse, counts

(tensor([0, 1, 2, 3]),
 tensor([0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3]),
 tensor([2, 3, 5, 2]))

In [None]:
def unique_consecutive(x: np.ndarray):

    output = [x[0]]
    inverse = np.zeros_like(x)
    counts = [1]

    for i in range(1, len(x)):
        if x[i] == output[-1]:
            counts[-1] += 1
        else:
            output.append(x[i])
            counts.append(1)
        inverse[i] = len(output) -1
    return np.array(output), inverse, np.array(counts)

In [None]:
## 可以看到，和tensor版本的输出结果是一样的
unique_consecutive(predict_ids.numpy())

(array([0, 1, 2, 3]),
 array([0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3]),
 array([2, 3, 5, 2]))

In [205]:
L = 149 * 64
predict_ids = torch.randint(0, 3, (L,)).numpy()
N = 10

output, inverse, counts = unique_consecutive(predict_ids)
cumsum_counts = np.cumsum(counts, 0)
# output, inverse, conuts, cumsum_counts

edges = []
for i in range(L):
    unique_id = inverse[i]  # 0, 1, 2, 3,
    unique_id_end_index = cumsum_counts[unique_id]
    if unique_id == len(output) - 1:
        break
    next_id = min(len(output) - 1, unique_id + N)
    next_end_index = cumsum_counts[next_id]
    _edges = np.stack(
        [np.full((next_end_index - unique_id_end_index,), i), np.arange(unique_id_end_index, next_end_index)]
    )
    edges.append(_edges)
edges = np.concatenate(edges, 1)

### tensor方法　

In [61]:
def generate_multple_sequences(ns, as_, bs):
    """
    quickly generate n1 sequences that ranging from a1 to b1,
            generate n2 sequences that ranging from a2 to b2,
            generate n3 sequences that ranging from a3 to b3,
            ....
    and finally combine these sequences


    ```python
    ns = torch.tensor([5, 3, 4])
    as_ = torch.tensor([1, 11, 16])
    bs = torch.tensor([10, 15, 20])
    ```
    > tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9,
               1,  2,  3,  4,  5,  6,  7,  8,  9,
               1,  2,  3,  4,  5,  6,  7,  8,  9,
               1,  2,  3,  4,  5,  6,  7,  8,  9,
               1,  2,  3,  4,  5,  6,  7,  8,  9,
               11, 12, 13, 14,
               11, 12, 13, 14,
               11, 12, 13, 14,
               16, 17, 18, 19,
               16, 17, 18, 19,
               16, 17, 18, 19,
               16, 17, 18, 19])
    Args:
        ns: the repeat number of sequences
        as_: the start number for each seq
        bs: the end number for each seq

    Returns:
        tensor: a 1D tensor for the combined seq.
    """
    # print(ns, as_, bs)
    device = ns.device
    # The maximum value in bs determines the tensor width for uniformity
    max_length = torch.max(bs - as_ + 1)
    # Generate a tensor where each row is a sequence from 0 to max_length
    seq_tensor = torch.arange(max_length).unsqueeze(0).repeat(ns.sum(), 1).to(device)
    seq_tensor = torch.repeat_interleave(as_, ns)[:, None] + seq_tensor
    nums = torch.repeat_interleave(bs - as_, ns)
    mask = torch.arange(seq_tensor.size(1)).expand_as(seq_tensor).to(device) < nums.unsqueeze(1)

    return seq_tensor[mask]

In [60]:
def get_phoneme_edges2(predict_ids: torch.Tensor, N=1):
    """
    Args:
        predict_ids: a tensor with shape of (L,) that represents the phoneme id for each audio frame.
        N: the number of looking forward phonemes
    Returns:
        torch.Tensor: the edges with shape of (2, n_edges)
    """

    device = predict_ids.device
    
    output, inverse, counts = predict_ids.unique_consecutive(return_inverse=True, return_counts=True)
    cumsum_counts = torch.cumsum(counts, 0).to(device)
    # print(output, inverse, counts, cumsum_counts)
    if len(output) == 1:
        return torch.zeros((2, 0))
    # both start and end are length L
    start_indices = torch.cat([torch.tensor([0], device=device), cumsum_counts[:-1]])
    end_indices = cumsum_counts
    # print("start", start_indices, "end", end_indices)

    edge_start_indices = start_indices[1:]
    edge_end_indices = end_indices[torch.clamp(torch.arange(len(output) - 1) + N, max=len(end_indices) - 1)]
    # print(edge_start_indices, edge_end_indices)

    # print(edge_end_indices.device)
    x = torch.repeat_interleave(
        torch.arange(cumsum_counts[-2]).to(device), (edge_end_indices - edge_start_indices)[inverse[: cumsum_counts[-2]]], dim=0
    ).to(device)
    y = generate_multple_sequences(ns=counts[:-1], as_=edge_start_indices, bs=edge_end_indices).to(device)
    edges = torch.stack([x, y])
    # print(x.shape, y.shape, edges.shape)
    return edges

使用一个短的ids进行验证，和上面的遍历方法产生的结果是一样的：

In [51]:
predict_ids = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3]).cuda()
get_phoneme_edges2(predict_ids)

cuda:0


tensor([[ 0,  0,  0,  1,  1,  1,  2,  2,  2,  2,  2,  3,  3,  3,  3,  3,  4,  4,
          4,  4,  4,  5,  5,  6,  6,  7,  7,  8,  8,  9,  9],
        [ 2,  3,  4,  2,  3,  4,  5,  6,  7,  8,  9,  5,  6,  7,  8,  9,  5,  6,
          7,  8,  9, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11]], device='cuda:0')

#### numpy

## Step 3

In [62]:
def generate_edges(input_audio_lengths: torch.Tensor, input_phoneme_ids: torch.Tensor, N=2):
    start_id = 0
    edge_index = []

    # audio_lengths = input_audio_lengths.cpu()
    audio_lengths = input_audio_lengths
    # phoneme_ids = input_phoneme_ids.cpu()
    phoneme_ids = input_phoneme_ids

    cumsum_audio_lengths = torch.cumsum(audio_lengths, 0)

    device = audio_lengths.device
    
    total_edges = []
    for i in range(len(audio_lengths)):
        _audio_len = audio_lengths[i]
        _phoneme_ids = phoneme_ids[i, :_audio_len]
        _start_index = cumsum_audio_lengths[i - 1] if i > 0 else 0

        adj_edges = get_adj_edges(_audio_len).to(device)
        phoneme_edges = get_phoneme_edges2(_phoneme_ids, N=N).to(device)
        _edges = torch.concat([adj_edges, phoneme_edges], dim=1) + _start_index
        total_edges.append(_edges)
    total_edges = torch.concat(total_edges, dim=1)
    total_edges = torch.unique(total_edges, dim=1)
    return total_edges.type(torch.int64)

In [443]:
def generate_edges_by_combine_and_split(input_audio_lengths: torch.Tensor, input_phoneme_ids: torch.Tensor, N=2):
    edge_index = []

    audio_lengths = input_audio_lengths
    padding = torch.arange(1, N+1, dtype=input_phoneme_ids.dtype, device=input_phoneme_ids.device) * -1
    phoneme_ids = torch.concat([torch.concat([input_phoneme_ids[i, :_audio_len], padding]) for i, _audio_len in enumerate(audio_lengths)])
    device = audio_lengths.device
    
    adj_edges = get_adj_edges(len(phoneme_ids)).to(device)
    phoneme_edges = get_phoneme_edges2(phoneme_ids, N=N).to(device)
    _edges = torch.concat([adj_edges, phoneme_edges], dim=1)
    total_edges = torch.unique(_edges, dim=1)

    audio_lengths = audio_lengths.cpu()
    actual_id = torch.ones((torch.sum(audio_lengths + N),))
    total_len = 0
    for i, _len in enumerate(audio_lengths):
        x = torch.arange(_len + N) + torch.sum(audio_lengths[:i])
        actual_id[total_len: total_len + _len] = x[:_len]
        actual_id[total_len +_len : total_len+_len+N] = -1
        total_len += _len + N

    actual_id= actual_id.to(device)
    total_edges = actual_id[total_edges]
    mask = ~(total_edges == -1).any(dim=0)
    total_edges = total_edges[:, mask]
    
    return total_edges.type(torch.int64)

In [457]:
B = 64
T = 70
N = 20
phoneme_ids = torch.randint(0, 199, (B, T)).cuda()
audio_lengths = torch.randint(T-10, T, (B,)).cuda()
print(phoneme_ids, audio_lengths)

tensor([[139,  47,  72,  ...,  36, 179, 168],
        [148, 171,  81,  ...,  32,  21, 197],
        [ 51,  47,  52,  ..., 189, 126,  29],
        ...,
        [147, 173,   3,  ..., 179,  33,  52],
        [105, 126, 147,  ..., 116,  93, 112],
        [ 40, 149, 147,  ..., 174, 123, 174]], device='cuda:0') tensor([66, 60, 68, 65, 69, 63, 62, 66, 60, 67, 69, 66, 64, 64, 60, 65, 62, 62,
        69, 65, 62, 67, 60, 60, 66, 68, 62, 61, 63, 65, 62, 62, 60, 68, 68, 62,
        60, 66, 61, 69, 60, 65, 69, 69, 61, 60, 63, 60, 65, 64, 67, 62, 65, 69,
        61, 64, 62, 66, 62, 66, 62, 69, 63, 60], device='cuda:0')


In [449]:
e2 = generate_edges(audio_lengths, phoneme_ids, N=N)
e2.shape, e2

(torch.Size([2, 69040]),
 tensor([[   0,    0,    0,  ..., 4104, 4104, 4105],
         [   1,    2,    3,  ..., 4105, 4106, 4106]], device='cuda:0'))

In [458]:
with torch.no_grad():
    e3 = generate_edges_by_combine_and_split(audio_lengths, phoneme_ids, N=N)
print(e3.shape, e3)

torch.Size([2, 68872]) tensor([[   0,    0,    0,  ..., 4095, 4095, 4096],
        [   1,    2,    3,  ..., 4096, 4097, 4097]], device='cuda:0')


In [363]:
audio_lengths = audio_lengths.cpu()
actual_id = torch.ones((torch.sum(audio_lengths + N),))
total_len = 0
for i, _len in enumerate(audio_lengths):
    x = torch.arange(_len + N) + torch.sum(audio_lengths[:i])
    actual_id[total_len: total_len + _len] = x[:_len]
    actual_id[total_len +_len : total_len+_len+N] = -1
    total_len += _len + N

In [364]:
e3.shape, actual_id.shape, actual_id

(torch.Size([2, 34]),
 torch.Size([12]),
 tensor([ 0.,  1., -1., -1., -1.,  2.,  3.,  4.,  5., -1., -1., -1.]))

In [365]:
e4 = actual_id[e3.cpu()]
mask = ~(e4 == -1).any(dim=0)
e5 = e4[:, mask]
print(e4.shape, e5.shape)

torch.Size([2, 34]) torch.Size([2, 7])


In [366]:
e2, e3, e4, e5

(tensor([[0, 2, 2, 2, 3, 3, 4],
         [1, 3, 4, 5, 4, 5, 5]], device='cuda:0'),
 tensor([[ 0,  0,  0,  1,  1,  1,  2,  2,  2,  3,  3,  3,  3,  4,  4,  4,  4,  5,
           5,  5,  5,  6,  6,  6,  6,  7,  7,  7,  8,  8,  8,  9,  9, 10],
         [ 1,  2,  3,  2,  3,  4,  3,  4,  5,  4,  5,  6,  7,  5,  6,  7,  8,  6,
           7,  8,  9,  7,  8,  9, 10,  8,  9, 10,  9, 10, 11, 10, 11, 11]],
        device='cuda:0'),
 tensor([[ 0.,  0.,  0.,  1.,  1.,  1., -1., -1., -1., -1., -1., -1., -1., -1.,
          -1., -1., -1.,  2.,  2.,  2.,  2.,  3.,  3.,  3.,  3.,  4.,  4.,  4.,
           5.,  5.,  5., -1., -1., -1.],
         [ 1., -1., -1., -1., -1., -1., -1., -1.,  2., -1.,  2.,  3.,  4.,  2.,
           3.,  4.,  5.,  3.,  4.,  5., -1.,  4.,  5., -1., -1.,  5., -1., -1.,
          -1., -1., -1., -1., -1., -1.]]),
 tensor([[0., 2., 2., 2., 3., 3., 4.],
         [1., 3., 4., 5., 4., 5., 5.]]))

# Weighted hidden states

设，一个audio的所有帧（T）的特征为(B, T, C), 而T帧对应的phoneme ids是：
```python　
predict_ids = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 2, 2])
```

- 正常预测时，会在T帧上进行平均，得到`(B, T, C) -> (B, C)`，
- 而weighted加权，会为每一帧分配一个权重，相邻的相同phoneme_id越多，那么权重越大。

```python
## 根据长度计算权重：　
tensor([2, 2, 3, 3, 3, 5, 5, 5, 5, 5])

## 归一化权重，使得sum=1.0：　
tensor([0.0526, 0.0526, 0.0789, 0.0789, 0.0789, 0.1316, 0.1316, 0.1316, 0.1316,
        0.1316])
```


In [310]:
def calculate_sequence_weights(predict_ids: torch.Tensor):
    """
    Calculate the normalized weights for each position in a sequence based on
    the sequence lengths of consecutive identical elements.

    ```python
        predict_ids = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 2, 2])

        output: tensor([0.0526, 0.0526, 0.0789, 0.0789, 0.0789, 0.1316, 0.1316, 0.1316, 0.1316,
        0.1316])
    ```

    Args:
        predict_ids (torch.Tensor): the ids with shape (L)

    Returns:
        torch.Tensor: the normalizad weights.

    """

    output, inverse, counts = predict_ids.unique_consecutive(return_inverse=True, return_counts=True)
    sequence_lengths = counts[inverse]
    print(sequence_lengths)
    normalized_weights = sequence_lengths.float() / sequence_lengths.sum().float()
    return normalized_weights.squeeze()


def get_weighted_hidden_state(hidden_states, phoneme_logits):
    """
    Computes weighted hidden states efficiently when predict_ids is 1D.
    """
    B, T, C = hidden_states.size()
    weighted_hidden_states = torch.zeros(B, C, dtype=hidden_states.dtype, device=hidden_states.device)
    for i, (_h, _l) in enumerate(zip(hidden_states, phoneme_logits)):
        predict_ids = torch.argmax(_l, dim=1)  # Get most likely phoneme ID sequence from logits
        weights = calculate_sequence_weights(predict_ids).unsqueeze(1)
        weighted_hidden_states[i, :] = torch.sum(_h * weights, dim=0)

    return weighted_hidden_states

In [305]:
hidden_states = torch.randn(2, 149, 768)
phoneme_logits = torch.randn(2, 149, 200)
get_weighted_hidden_state(hidden_states, phoneme_logits)

tensor([[ 0.0385, -0.0323,  0.0318,  ..., -0.1140,  0.0664,  0.0711],
        [-0.1164, -0.1143,  0.0586,  ..., -0.0409,  0.0053, -0.0723]])