In [1]:
import torch

# reduce and combine hidden states

:::{note}
在音素识别时，一个音素可能会跨多个时间帧，因此需要合并连续且相同的音素，构造为一个音素，从而预测整句话的音素，并与ground truth进行比较。
:::

那使用音素特征识别时，那么也许合并连续且相同的音素特征帧。设音素特征为`(B, T, 768)`，那么合并后的特征为`(B, T', 768)`。

In [649]:
B = 64
T = 149
hidden_states = torch.randn(B, T, 768).cuda()
phoneme_ids = torch.randint(0, 3, (B, T)).cuda()
audio_lengths = torch.randint(70, 149, (B,)).cuda()

In [640]:
def segment_means(tensor, segment_sizes):
    assert tensor.size(0) == segment_sizes.sum(), "Sum of segment sizes must equal the tensor's first dimension size."

    # Create an indices tensor that maps each row in the tensor to its corresponding segment
    indices = torch.repeat_interleave(torch.arange(len(segment_sizes), device=tensor.device), segment_sizes)

    # Create a tensor to hold the sum of each segment
    segment_sums = torch.zeros(len(segment_sizes), tensor.size(1), device=tensor.device)

    # Scatter and sum the inputs into the segment_sums tensor
    segment_sums.scatter_add_(0, indices.unsqueeze(1).expand(-1, tensor.size(1)), tensor)

    # Calculate the mean of each segment
    segment_means = segment_sums / segment_sizes.unsqueeze(1)

    return segment_means

In [651]:
def reduce_feat(hidden_states, audio_lengths, phoneme_ids):
    reduced_hidden_states = []
    reduced_audio_lengths = []
    reduced_phoneme_ids = []

    phoneme_counts = []

    for i in range(len(audio_lengths)):
        _phoneme_ids = phoneme_ids[i, : audio_lengths[i]]
        _h = hidden_states[i, : audio_lengths[i]]

        unique_ids, _phoneme_counts = _phoneme_ids.unique_consecutive(return_counts=True)
        # segments = torch.split(_h, _phoneme_counts.tolist())
        # _h_reduced = torch.stack([seg.mean(dim=0) for seg in segments])
        phoneme_counts += _phoneme_counts.tolist()

        reduced_audio_lengths.append(len(unique_ids))
        reduced_phoneme_ids.append(unique_ids)

    reduced_audio_lengths = torch.tensor(reduced_audio_lengths)
    reduced_phoneme_ids = torch.nn.utils.rnn.pad_sequence(reduced_phoneme_ids, batch_first=True)
    h = torch.concat([hidden_states[i, :_len, :] for i, _len in enumerate(audio_lengths)], dim=0)
    reduced_hidden_states = segment_means(h, torch.tensor(phoneme_counts, device=hidden_states.device))
    return reduced_hidden_states, reduced_audio_lengths, reduced_phoneme_ids

In [652]:
reduced_hidden_states, reduced_audio_lengths, reduced_phoneme_ids = reduce_feat(
    hidden_states, audio_lengths, phoneme_ids
)
print(reduced_audio_lengths.shape, reduced_phoneme_ids.shape, reduced_hidden_states.shape)

torch.Size([64]) torch.Size([64, 102]) torch.Size([4633, 768])


# Generate edges

设，一个audio的所有帧（T）对应的phoneme ids是，并给出了audio长度为10：
```python　
predict_ids = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 2, 2])
audio_lengths = [10]
```

那么，建立单项边时，有两个操作：
1. 添加所有邻接边：`0->1, 1->2, 2->3, ...`
2. 对于一个phoneme，对接下来的N个phoneme都建立边：
    - 以第1个node为例（index从0开始），设N=1，那么会新加边：`1->2, 1->3, 1->4`
    - 以第3个node为例（index从0开始），设N=1，那么会新加边：`3->5,6,7,8,9`

### Step 1 

使用arange可以很快地生成所有邻接边。

In [4]:
L = 149
adj_edges = torch.stack([torch.arange(L - 1), torch.arange(1, L)])
adj_edges

tensor([[  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
          14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
          28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
          42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
          56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
          70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
          84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
          98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
         112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
         126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,
         140, 141, 142, 143, 144, 145, 146, 147],
        [  1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,  14,
          15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,  28,
  

In [6]:
def get_adj_edges(L: int):
    adj_edges = torch.stack([torch.arange(L - 1), torch.arange(1, L)])
    return adj_edges

### Step 2

使用 `unique_consecutive`可以快速地查找所有不同的phoneme id，并定位其index范围。

In [8]:
predict_ids = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 2, 2])
N = 1

output, inverse, counts = predict_ids.unique_consecutive(return_inverse=True, return_counts=True)
cumsum_counts = torch.cumsum(counts, 0)
# output, inverse, conuts, cumsum_counts

edges = []
for i in range(L):
    unique_id = inverse[i]  # 0, 1, 2, 3,
    unique_id_end_index = cumsum_counts[unique_id]
    if unique_id == len(output) - 1:
        break
    next_id = min(len(output) - 1, unique_id + N)
    next_end_index = cumsum_counts[next_id]
    _edges = torch.stack(
        [torch.full((next_end_index - unique_id_end_index,), i), torch.arange(unique_id_end_index, next_end_index)]
    )
    edges.append(_edges)
edges = torch.concat(edges, 1)

In [9]:
torch.zeros((2, 1), dtype=torch.int)

tensor([[0],
        [0]], dtype=torch.int32)

In [10]:
def get_phoneme_edges(predict_ids: torch.Tensor, N=1):
    """
    Args:
        predict_ids: a tensor with shape of (L,) that represents the phoneme id for each audio frame.
        N: the number of looking forward phonemes
    Returns:
        torch.Tensor: the edges with shape of (2, n_edges)
    """
    output, inverse, counts = predict_ids.unique_consecutive(return_inverse=True, return_counts=True)
    cumsum_counts = torch.cumsum(counts, 0)
    # output, inverse, conuts, cumsum_counts

    if len(output) == 1:
        return torch.zeros((2, 0))

    s = time.time()
    edges = []
    for i in range(len(predict_ids)):
        unique_id = inverse[i]  # 0, 1, 2, 3,
        unique_id_end_index = cumsum_counts[unique_id]
        if unique_id == len(output) - 1:
            break
        next_id = min(len(output) - 1, unique_id + N)
        next_end_index = cumsum_counts[next_id]
        _edges = torch.stack(
            [torch.full((next_end_index - unique_id_end_index,), i), torch.arange(unique_id_end_index, next_end_index)]
        )
        edges.append(_edges)
    edges = torch.concat(edges, 1)

    e = time.time()
    print(e - s)

    return edges

In [35]:
def get_phoneme_edges2(predict_ids: torch.Tensor, N=1):
    """
    Args:
        predict_ids: a tensor with shape of (L,) that represents the phoneme id for each audio frame.
        N: the number of looking forward phonemes
    Returns:
        torch.Tensor: the edges with shape of (2, n_edges)
    """

    device = predict_ids.device
    
    output, inverse, counts = predict_ids.unique_consecutive(return_inverse=True, return_counts=True)
    cumsum_counts = torch.cumsum(counts, 0).to(device)
    # print(output, inverse, counts, cumsum_counts)

    # both start and end are length L
    start_indices = torch.cat([torch.tensor([0], device=device), cumsum_counts[:-1]])
    end_indices = cumsum_counts
    # print("start", start_indices, "end", end_indices)

    edge_start_indices = start_indices[1:]
    edge_end_indices = end_indices[torch.clamp(torch.arange(len(output) - 1) + N, max=len(end_indices) - 1)]
    # print(edge_start_indices, edge_end_indices)

    print(edge_end_indices.device)
    x = torch.repeat_interleave(
        torch.arange(cumsum_counts[-2]).to(device), (edge_end_indices - edge_start_indices)[inverse[: cumsum_counts[-2]]], dim=0
    ).to(device)
    y = generate_multple_sequences(ns=counts[:-1], as_=edge_start_indices, bs=edge_end_indices).to(device)
    edges = torch.stack([x, y])
    # print(x.shape, y.shape, edges.shape)
    return edges

In [40]:
predict_ids = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3]).cuda()
get_phoneme_edges2(predict_ids)

cuda:0


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

In [39]:
def generate_multple_sequences(ns, as_, bs):
    """
    quickly generate n1 sequences that ranging from a1 to b1,
            generate n2 sequences that ranging from a2 to b2,
            generate n3 sequences that ranging from a3 to b3,
            ....
    and finally combine these sequences


    ```python
    ns = torch.tensor([5, 3, 4])
    as_ = torch.tensor([1, 11, 16])
    bs = torch.tensor([10, 15, 20])
    ```
    > tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9,
               1,  2,  3,  4,  5,  6,  7,  8,  9,
               1,  2,  3,  4,  5,  6,  7,  8,  9,
               1,  2,  3,  4,  5,  6,  7,  8,  9,
               1,  2,  3,  4,  5,  6,  7,  8,  9,
               11, 12, 13, 14,
               11, 12, 13, 14,
               11, 12, 13, 14,
               16, 17, 18, 19,
               16, 17, 18, 19,
               16, 17, 18, 19,
               16, 17, 18, 19])
    Args:
        ns: the repeat number of sequences
        as_: the start number for each seq
        bs: the end number for each seq

    Returns:
        tensor: a 1D tensor for the combined seq.
    """
    # print(ns, as_, bs)
    device = ns.device
    # The maximum value in bs determines the tensor width for uniformity
    max_length = torch.max(bs - as_ + 1)
    # Generate a tensor where each row is a sequence from 0 to max_length
    seq_tensor = torch.arange(max_length).unsqueeze(0).repeat(ns.sum(), 1).to(device)
    seq_tensor = torch.repeat_interleave(as_, ns)[:, None] + seq_tensor
    nums = torch.repeat_interleave(bs - as_, ns)
    mask = torch.arange(seq_tensor.size(1)).expand_as(seq_tensor) < nums.unsqueeze(1)

    return seq_tensor[mask]

### Step 3

In [27]:
def generate_edges(input_audio_lengths: torch.Tensor, input_phoneme_ids: torch.Tensor, N=2):
    start_id = 0
    edge_index = []

    # audio_lengths = input_audio_lengths.cpu()
    audio_lengths = input_audio_lengths
    # phoneme_ids = input_phoneme_ids.cpu()
    phoneme_ids = input_phoneme_ids

    cumsum_audio_lengths = torch.cumsum(audio_lengths, 0)

    device = audio_lengths.device
    
    total_edges = []
    for i in range(len(audio_lengths)):
        _audio_len = audio_lengths[i]
        _phoneme_ids = phoneme_ids[i, :_audio_len]
        _start_index = cumsum_audio_lengths[i - 1] if i > 0 else 0

        adj_edges = get_adj_edges(_audio_len).to(device)
        phoneme_edges = get_phoneme_edges2(_phoneme_ids, N=N).to(device)
        _edges = torch.concat([adj_edges, phoneme_edges], dim=1) + _start_index
        total_edges.append(_edges)
    total_edges = torch.concat(total_edges, dim=1)
    total_edges = torch.unique(total_edges, dim=1)
    return total_edges

In [445]:
audio_lengths = torch.tensor([5, 8])
phoneme_ids = torch.randint(0, 3, (2, 10))

In [446]:
phoneme_ids

tensor([[1, 1, 1, 2, 0, 1, 0, 1, 0, 2],
        [2, 1, 1, 1, 1, 2, 2, 0, 0, 2]])

In [426]:
B = 64
T = 149
phoneme_ids = torch.randint(0, 199, (B, T))
audio_lengths = torch.tensor([T] * B)

In [447]:
cumsum_audio_lengths = torch.cumsum(audio_lengths, 0)

total_edges = []
for i in range(len(audio_lengths)):
    _audio_len = audio_lengths[i]
    _phoneme_ids = phoneme_ids[i, :_audio_len]
    _start_index = cumsum_audio_lengths[i - 1] if i > 0 else 0

    adj_edges = get_adj_edges(_audio_len)
    phoneme_edges = get_phoneme_edges(_phoneme_ids, N=2)
    _edges = torch.concat([adj_edges, phoneme_edges], dim=1) + _start_index
    total_edges.append(_edges)
total_edges = torch.concat(total_edges, dim=1)
total_edges = torch.unique(total_edges, dim=1)

total_edges

tensor([[ 0,  0,  0,  1,  1,  1,  2,  2,  3,  5,  5,  5,  5,  5,  5,  6,  6,  6,
          6,  7,  7,  7,  7,  8,  8,  8,  8,  9,  9,  9, 10, 10, 11],
        [ 1,  3,  4,  2,  3,  4,  3,  4,  4,  6,  7,  8,  9, 10, 11,  7, 10, 11,
         12,  8, 10, 11, 12,  9, 10, 11, 12, 10, 11, 12, 11, 12, 12]])

In [21]:
B = 64
T = 149
phoneme_ids = torch.randint(0, 199, (B, T)).cuda()
audio_lengths = torch.tensor([T] * B).cuda()

In [30]:
e2 = generate_edges(audio_lengths, phoneme_ids)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

In [770]:
e1 = generate_edges(audio_lengths, phoneme_ids)
torch.sum(e1 - e2)

0.007739543914794922
0.00764155387878418
0.008351802825927734
0.007665395736694336
0.007583141326904297
0.007695436477661133
0.007594585418701172
0.007628440856933594
0.0076253414154052734
0.00757908821105957
0.007592916488647461
0.007593870162963867
0.007617473602294922
0.007589817047119141
0.007596731185913086
0.0075910091400146484
0.0075571537017822266
0.0075647830963134766
0.00755000114440918
0.007582902908325195
0.007536888122558594
0.00753474235534668
0.007597684860229492
0.0075991153717041016
0.007602214813232422
0.007534980773925781
0.008611440658569336
0.007569313049316406
0.0076029300689697266
0.008214473724365234
0.0077092647552490234
0.007634878158569336
0.007620096206665039
0.0076367855072021484
0.007555723190307617
0.007611751556396484
0.007607936859130859
0.00765228271484375
0.007688999176025391
0.007696866989135742
0.007664680480957031
0.0076525211334228516
0.0075719356536865234
0.007580280303955078
0.00758671760559082
0.007586240768432617
0.0075800418853759766
0.007592

tensor(0)

# Weighted hidden states

设，一个audio的所有帧（T）的特征为(B, T, C), 而T帧对应的phoneme ids是：
```python　
predict_ids = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 2, 2])
```

- 正常预测时，会在T帧上进行平均，得到`(B, T, C) -> (B, C)`，
- 而weighted加权，会为每一帧分配一个权重，相邻的相同phoneme_id越多，那么权重越大。

```python
## 根据长度计算权重：　
tensor([2, 2, 3, 3, 3, 5, 5, 5, 5, 5])

## 归一化权重，使得sum=1.0：　
tensor([0.0526, 0.0526, 0.0789, 0.0789, 0.0789, 0.1316, 0.1316, 0.1316, 0.1316,
        0.1316])
```


In [310]:
def calculate_sequence_weights(predict_ids: torch.Tensor):
    """
    Calculate the normalized weights for each position in a sequence based on
    the sequence lengths of consecutive identical elements.

    ```python
        predict_ids = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 2, 2])

        output: tensor([0.0526, 0.0526, 0.0789, 0.0789, 0.0789, 0.1316, 0.1316, 0.1316, 0.1316,
        0.1316])
    ```

    Args:
        predict_ids (torch.Tensor): the ids with shape (L)

    Returns:
        torch.Tensor: the normalizad weights.

    """

    output, inverse, counts = predict_ids.unique_consecutive(return_inverse=True, return_counts=True)
    sequence_lengths = counts[inverse]
    print(sequence_lengths)
    normalized_weights = sequence_lengths.float() / sequence_lengths.sum().float()
    return normalized_weights.squeeze()


def get_weighted_hidden_state(hidden_states, phoneme_logits):
    """
    Computes weighted hidden states efficiently when predict_ids is 1D.
    """
    B, T, C = hidden_states.size()
    weighted_hidden_states = torch.zeros(B, C, dtype=hidden_states.dtype, device=hidden_states.device)
    for i, (_h, _l) in enumerate(zip(hidden_states, phoneme_logits)):
        predict_ids = torch.argmax(_l, dim=1)  # Get most likely phoneme ID sequence from logits
        weights = calculate_sequence_weights(predict_ids).unsqueeze(1)
        weighted_hidden_states[i, :] = torch.sum(_h * weights, dim=0)

    return weighted_hidden_states

In [305]:
hidden_states = torch.randn(2, 149, 768)
phoneme_logits = torch.randn(2, 149, 200)
get_weighted_hidden_state(hidden_states, phoneme_logits)

tensor([[ 0.0385, -0.0323,  0.0318,  ..., -0.1140,  0.0664,  0.0711],
        [-0.1164, -0.1143,  0.0586,  ..., -0.0409,  0.0053, -0.0723]])