In [1]:
# Reload imported modules (e.g. ay2.tools) automatically so edits to them are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import random
import time

import numpy as np
import torch
from ay2.tools import TimerContextManager

每个audio的frames个数都是$T$。为了避免unique_consecutive把相邻两个audio的phoneme合并在一起，给每个audio末尾padding了一个frame（其phoneme id为$-1$），因此每个audio的frames个数变成了$T+1$。

设对整个batch（展平后）的phoneme ids进行unique_consecutive之后，共有$N$个phonemes，那么`get_phonme_id_mapping`的作用就是计算出，对于reduced的第$i$个phoneme，其对应的：
1. 音频样本的id，$0 \le id < B$
2. 在该音频样本中，该phoneme覆盖的音频frames范围$[s, e)$（左闭右开；一个phoneme可能覆盖多个frames）
3. 它是该音频样本中的第几个phoneme（从0开始计数）

In [71]:
def get_phonme_id_mapping(cumsum_counts, T):
    """Map every reduced (run-length collapsed) phoneme to its audio sample and frame span.

    The frame-level phoneme ids of a batch are expected to have been padded with one separator
    frame per sample (so each sample spans ``T + 1`` frames), flattened, and collapsed with
    ``unique_consecutive(return_counts=True)``. ``cumsum_counts`` is the cumulative sum of those
    counts, i.e. ``cumsum_counts[k]`` is the exclusive end of the k-th reduced phoneme in
    flattened-frame coordinates.

    Args:
        cumsum_counts: 1-D integer tensor of length L (number of reduced phonemes).
        T: number of real frames per sample (not counting the separator frame).

    Returns:
        id_to_index_range: int32 tensor of shape (L, 4) on ``cumsum_counts.device``. Row k is
        ``(a, b, c, d)`` where
        - a: the audio sample id the phoneme belongs to,
        - b, c: the phoneme's frame range ``[b, c)`` inside that sample (one phoneme may
          cover several frames),
        - d: the phoneme's ordinal within its sample (0 for each sample's first phoneme).
        An empty ``cumsum_counts`` yields a (0, 4) tensor.
    """
    L = len(cumsum_counts)
    frames_per_sample = T + 1  # T real frames + 1 separator frame
    id_to_index_range = torch.zeros(L, 4, dtype=torch.int32, device=cumsum_counts.device)
    if L == 0:
        return id_to_index_range

    # Sample id: which (T + 1)-frame block the phoneme's last frame falls into.
    id_to_index_range[:, 0] = (cumsum_counts - 1) // frames_per_sample
    # Start frame: end of the previous phoneme folded into per-sample coordinates.
    # Row 0 starts at frame 0, which the zero initialisation already provides.
    id_to_index_range[1:, 1] = cumsum_counts[:-1] % frames_per_sample
    # Exclusive end frame, written as (last frame) + 1 so that a phoneme ending exactly on a
    # sample boundary maps to T + 1 instead of wrapping around to 0.
    id_to_index_range[:, 2] = (cumsum_counts - 1) % frames_per_sample + 1

    # Ordinal within the sample = global position - position of the sample's first phoneme.
    # Vectorised so it stays on-device (no Python loop over samples, no per-count sync).
    _, nums = id_to_index_range[:, 0].unique_consecutive(return_counts=True)
    sample_starts = torch.cumsum(nums, 0) - nums
    positions = torch.arange(L, device=cumsum_counts.device)
    id_to_index_range[:, 3] = positions - torch.repeat_interleave(sample_starts, nums)
    return id_to_index_range

测试

In [72]:
# B, T = 64, 149
B, T = 3, 10
phoneme_ids = torch.randint(0, 2, (B, T)).type(torch.int64).cuda()
phoneme_ids = torch.concat([phoneme_ids, torch.ones(B, 1, device=phoneme_ids.device) * -1], dim=-1)
reduced_phoneme_ids, inverse, counts = phoneme_ids.unique_consecutive(return_inverse=True, return_counts=True)
# print(phoneme_ids)
%time phoneme_id_to_index_range = get_phonme_id_mapping(cumsum_counts=torch.cumsum(counts, 0), T=T)
print(phoneme_id_to_index_range)

CPU times: user 5.03 ms, sys: 266 µs, total: 5.3 ms
Wall time: 5.3 ms
tensor([[ 0,  0,  1,  0],
        [ 0,  1,  2,  1],
        [ 0,  2,  3,  2],
        [ 0,  3,  6,  3],
        [ 0,  6,  8,  4],
        [ 0,  8, 10,  5],
        [ 0, 10, 11,  6],
        [ 1,  0,  2,  0],
        [ 1,  2,  3,  1],
        [ 1,  3, 10,  2],
        [ 1, 10, 11,  3],
        [ 2,  0,  1,  0],
        [ 2,  1,  2,  1],
        [ 2,  2,  5,  2],
        [ 2,  5,  6,  3],
        [ 2,  6,  7,  4],
        [ 2,  7, 10,  5],
        [ 2, 10, 11,  6]], device='cuda:0', dtype=torch.int32)


对于一个batch，计算每个音频样本$i$在reduced的phoneme ids中的起始和结束index：`start`是该样本第一个phoneme的index，`end`是该样本末尾padding phoneme的index，因此切片`[start:end]`恰好覆盖该样本的真实phonemes（不含padding）：

In [7]:
def get_sample_index_range_in_phonemes(samples):
    """Return, for every audio sample, the span of its phonemes in the reduced phoneme sequence.

    Args:
        samples: 1-D tensor with the sample id of each reduced phoneme (column 0 of
            ``get_phonme_id_mapping``'s output). Equal ids are expected to be contiguous; if an
            id re-appears later, its last run wins (same as before).

    Returns:
        dict ``sample_id -> (start, end)`` with Python values. ``start`` is the index of the
        sample's first phoneme and ``end`` the index of its last phoneme. When ``samples``
        comes from the padded pipeline in this notebook (one padding phoneme at the end of
        every sample), ``end`` is that padding phoneme's index, so ``[start:end]`` slices
        exactly the sample's real phonemes. Empty input returns ``{}``.
    """
    sample_ids, run_lengths = samples.unique_consecutive(return_counts=True)
    run_ends = torch.cumsum(run_lengths, 0)  # exclusive end of each run
    # Convert once per column instead of three `.item()` device syncs per sample.
    starts = (run_ends - run_lengths).tolist()
    # Index of each run's last element: the padding phoneme, i.e. an exclusive bound for the
    # real phonemes of that sample.
    ends = (run_ends - 1).tolist()
    return {sample_id: (start, end) for sample_id, start, end in zip(sample_ids.tolist(), starts, ends)}

测试

In [8]:
%time index_ranges = get_sample_index_range_in_phonemes(phoneme_id_to_index_range[:, 0])
print(index_ranges)

CPU times: user 51.1 ms, sys: 26.1 ms, total: 77.2 ms
Wall time: 76 ms
{0: (0, 6), 1: (7, 15), 2: (16, 20)}


In [9]:
# For every phoneme id, find all positions in the reduced sequence that carry the same id.
def find_same_phoneme_ids(reduced_phoneme_ids):
    """Group positions of ``reduced_phoneme_ids`` by their phoneme id.

    Returns a dict ``phoneme_id -> [positions...]``. Keys are the NumPy scalars obtained from
    the host copy of the tensor, in first-appearance order. Each position list is ascending.
    """
    positions_by_id = {}
    ids_on_host = reduced_phoneme_ids.cpu().numpy()
    for position in range(len(ids_on_host)):
        phoneme_id = ids_on_host[position]
        if phoneme_id not in positions_by_id:
            positions_by_id[phoneme_id] = []
        positions_by_id[phoneme_id].append(position)
    return positions_by_id

测试

In [10]:
%time same_phoneme_id_index = find_same_phoneme_ids(reduced_phoneme_ids)
print(same_phoneme_id_index)

CPU times: user 113 µs, sys: 191 µs, total: 304 µs
Wall time: 320 µs
{1.0: [0, 2, 4, 7, 9, 11, 13, 17, 19], 0.0: [1, 3, 5, 8, 10, 12, 14, 16, 18], -1.0: [6, 15, 20]}


# 整合

In [76]:
# Realistic batch for the end-to-end augmentation check.
B, T = 64, 149
HIDDEN_DIM = 768
device = "cuda" if torch.cuda.is_available() else "cpu"  # don't require a GPU

hidden_states = torch.randn(B, T, HIDDEN_DIM, device=device)
audio_lengths = torch.randint(T - 10, T, (B,), device=device)  # valid lengths in [T-10, T)
phoneme_ids = torch.randint(0, 10, (B, T), device=device)  # int64 frame-level ids in [0, 10)
languages = torch.randint(0, 3, (B,), device=device)

In [83]:
def aug_hidden_states(hidden_states, audio_lengths, phoneme_ids, languages=None, p=0.2, *args, max_replace_id=4, **kwargs):
    """Randomly swap phoneme segments of ``hidden_states`` with other segments of the same phoneme.

    Frame-level phoneme ids are collapsed into segments (runs of equal ids) with
    ``unique_consecutive``. Each eligible segment is replaced with probability ``p``; eligible
    means a real (non-separator) segment whose id is below ``max_replace_id``. The replacement
    is another segment with the same phoneme id, drawn (not exactly uniformly) from anywhere in
    the batch. The chosen segment's frames and frame-level ids are copied in, so per-sample
    lengths may change. Randomness comes from NumPy's global RNG.

    Args:
        hidden_states: tensor indexed as (B, T, D); D is read from the last dimension.
        audio_lengths: only used as a dtype/device template for ``new_audio_lengths``.
            NOTE(review): the valid lengths are not used to keep padded frames out of the
            replacement pool — confirm this is intended.
        phoneme_ids: (B, T) frame-level phoneme ids; -1 is reserved for the separator frame.
        languages: unused.
        p: per-segment replacement probability.
        max_replace_id: only ids in ``[0, max_replace_id)`` (really: ``!= -1`` and below this)
            may be replaced. Defaults to 4, the previously hard-coded value.

    Returns:
        aug_states: (B, T', D) augmented hidden states, zero-padded to the longest sample,
            with the dtype of ``hidden_states``.
        labels: one value per real segment, 1 if it was replaced else 0, in the default
            floating dtype (unchanged from before).
        new_audio_lengths: frame count of each augmented sample.
        new_phoneme_ids: (B, T') frame-level ids with the dtype of ``phoneme_ids``, zero-padded.
    """
    B, T, D = hidden_states.shape[0], hidden_states.shape[1], hidden_states.shape[-1]

    # One separator frame per sample (phoneme id -1) keeps unique_consecutive from merging the
    # last phoneme of a sample with the first phoneme of the next one. new_ones/new_full keep
    # the input dtypes; `torch.ones(...) * -1` was float32 and promoted the ids to float.
    hidden_states = torch.concat([hidden_states, hidden_states.new_ones((B, 1, D))], dim=1)
    phoneme_ids = torch.concat([phoneme_ids, phoneme_ids.new_full((B, 1), -1)], dim=-1)

    reduced_phoneme_ids, counts = phoneme_ids.unique_consecutive(return_counts=True)
    phoneme_id_to_index_range = get_phonme_id_mapping(cumsum_counts=torch.cumsum(counts, 0), T=T)
    sample_index_ranges = get_sample_index_range_in_phonemes(phoneme_id_to_index_range[:, 0])
    same_phoneme_id_index = find_same_phoneme_ids(reduced_phoneme_ids)

    # Host copies: the loops below index these element-wise, which on GPU tensors would force
    # a device sync per access.
    _phoneme_ids = reduced_phoneme_ids.cpu().numpy()
    _counts = counts.cpu().numpy()
    L = len(_phoneme_ids)

    with TimerContextManager(debug=False):
        replaced_index = np.arange(L)
        probs = np.random.rand(L)
        choices = np.random.randint(100, 1000, L)
        # Separator segments are never eligible (they are never read back anyway).
        eligible = np.nonzero((_phoneme_ids != -1) & (_phoneme_ids < max_replace_id))[0]
        for i in eligible:
            if probs[i] > p:
                continue
            candidates = same_phoneme_id_index[_phoneme_ids[i]]
            k = choices[i] % len(candidates)
            # Skip the segment itself; a phoneme that occurs only once stays unchanged.
            if candidates[k] == i:
                k = (k + 1) % len(candidates)
            replaced_index[i] = candidates[k]

    # One label per real segment: 1 if it was swapped for another segment.
    is_real = _phoneme_ids != -1
    is_replaced = replaced_index != np.arange(L)
    labels = torch.as_tensor(is_replaced[is_real], dtype=torch.get_default_dtype(), device=reduced_phoneme_ids.device)

    segments = torch.split(hidden_states.reshape(-1, D), _counts.tolist())
    new_hidden, new_ids, new_lengths = [], [], []
    for i in range(B):
        # [s:e] covers the sample's real segments; e is the index of its separator segment.
        s, e = sample_index_ranges[i]
        picked = replaced_index[s:e]
        sample_hidden = torch.concat([segments[j] for j in picked])
        new_hidden.append(sample_hidden)
        new_lengths.append(sample_hidden.shape[0])
        # Expand segment ids back to frame level using the picked segments' own lengths.
        new_ids.append(torch.from_numpy(np.repeat(_phoneme_ids[picked], _counts[picked])))

    aug_states = torch.nn.utils.rnn.pad_sequence(new_hidden, batch_first=True)
    new_audio_lengths = torch.tensor(new_lengths, dtype=audio_lengths.dtype, device=audio_lengths.device)
    new_phoneme_ids = torch.nn.utils.rnn.pad_sequence(new_ids, batch_first=True).to(phoneme_ids.device)
    return aug_states, labels, new_audio_lengths, new_phoneme_ids

In [84]:
# Each eligible phoneme segment is replaced with probability p=0.2. The augmented states go to
# `_hidden_states` so the original `hidden_states` stays available to later cells.
_hidden_states, labels, new_audio_lengths, new_phoneme_ids = aug_hidden_states(
    hidden_states=hidden_states,
    audio_lengths=audio_lengths,
    phoneme_ids=phoneme_ids,
    p=0.2,
)

In [85]:
# Sanity check: collapse every augmented sample (up to its new length) into phoneme runs and
# total the run lengths; the sum equals new_audio_lengths.sum().
phoneme_counts = []
for ids_row, valid_len in zip(new_phoneme_ids, new_audio_lengths):
    _, run_lengths = ids_row[:valid_len].unique_consecutive(return_counts=True)
    phoneme_counts.extend(run_lengths.tolist())

sum(phoneme_counts)

9537