In [1]:
import torch
import numpy as np
import time
from ay2.tools import TimerContextManager
import random

In [2]:
# Synthetic batch used by the experiments below: B audio samples with T frames each.
B = 64
T = 149

# Random stand-ins for the real inputs. Every tensor is moved to the GPU with .cuda(),
# so this cell requires a CUDA device. No random seed is set.
hidden_states= torch.randn(B, T, 768).cuda()  # (B, T, 768) per-frame features
audio_lengths = torch.randint(T-10, T, (B,)).cuda()  # one value per sample, drawn from [T-10, T)
phoneme_ids = torch.randint(0, 10, (B, T)).type(torch.int64).cuda()  # (B, T) ids drawn from [0, 10)
languages = torch.randint(0, 3, (B,)).cuda()  # one id per sample, drawn from [0, 3)

In [3]:
## Append a -1 sentinel to every audio so that unique_consecutive cannot merge the last
## phoneme of one audio with the first phoneme of the next one.
## The sentinel uses phoneme_ids' own dtype: `torch.ones(...) * -1` is a float tensor and
## concatenating it would silently promote the integer ids to float.
sentinel_ids = torch.full((phoneme_ids.shape[0], 1), -1, dtype=phoneme_ids.dtype, device=phoneme_ids.device)
phoneme_ids = torch.concat([phoneme_ids, sentinel_ids], dim=-1)
## Matching dummy frame of ones; the feature size is taken from hidden_states itself.
sentinel_frames = torch.ones(
    hidden_states.shape[0], 1, hidden_states.shape[-1],
    dtype=hidden_states.dtype, device=hidden_states.device,
)
hidden_states = torch.concat([hidden_states, sentinel_frames], dim=1)

### phoneme_ids is 2D, but unique_consecutive flattens it to 1D before reducing.
### note!!!, inverse is 2D (same shape as phoneme_ids)
reduced_phoneme_ids, inverse, counts = phoneme_ids.unique_consecutive(return_inverse=True, return_counts=True)
reduced_phoneme_ids = reduced_phoneme_ids.type(torch.int64)
### cumsum_counts[i] is the number of flattened frames covered by reduced phonemes 0..i
cumsum_counts = torch.cumsum(counts, 0)

### id_pairs[i, j] denotes whether the reduced phonemes i and j have the same id
id_pairs = (reduced_phoneme_ids[:, None] == reduced_phoneme_ids[None, :])

#### languages_pairs[i, j] denotes whether the samples i and j are of the same language
languages_pairs = (languages[:, None] == languages[None, :])

## Step 1

每个audio的frames个数都是$T$，为了避免unique_consecutive把不同的audio的phoneme给合并了，给每个audio padding了一个frame，因此每个audio的frames个数变成了$T+1$。

设对batch的phoneme ids进行unique_consecutive之后，共有N个phonemes，那么`get_phonme_id_mapping`的作用就是计算出，对于reduced的第$i$个phoneme，其对应的：
1. 音频样本的id，$0 \le id < B$
2. 在音频样本中，该phoneme对应哪些音频frames，$s:e$
3. 它是音频样本的第几个phoneme

In [4]:
def get_phonme_id_mapping(cumsum_counts, T):
    """
    Map every reduced phoneme to its audio sample and frame range.

    The batch is assumed to be flattened from shape (B, T + 1): every audio has T real
    frames followed by one padding frame, so a reduced phoneme never crosses an audio
    boundary.

    Args:
        cumsum_counts: 1-D integer tensor of length L, the cumulative sum of the frame
            counts returned by ``unique_consecutive`` on the flattened phoneme ids.
        T: number of real frames per audio (without the padding frame).

    Returns:
        id_to_index_range: int32 tensor of shape (L, 4) on the device of
        ``cumsum_counts``, where each row (a, b, c, d) holds
        - a: the audio sample id
        - b, c: the frame range [b, c) of the phoneme inside its audio; one phoneme may
          span several frames
        - d: the position of the phoneme inside its audio (0 for the first phoneme)
    """
    frames_per_sample = T + 1
    num_phonemes = len(cumsum_counts)
    id_to_index_range = torch.zeros(num_phonemes, 4, dtype=torch.int32, device=cumsum_counts.device)
    if num_phonemes == 0:
        # Nothing to map; also avoids reducing over an empty tensor below.
        return id_to_index_range

    # Flattened index of the last frame of every phoneme.
    last_frame = cumsum_counts - 1
    id_to_index_range[:, 0] = last_frame // frames_per_sample
    # A phoneme starts where the previous one ended; row 0 keeps its initial start of 0.
    id_to_index_range[1:, 1] = cumsum_counts[:-1] % frames_per_sample
    id_to_index_range[:, 2] = last_frame % frames_per_sample + 1

    # Position inside the audio = global position minus the global position of the first
    # phoneme of the same audio. This belongs in column 3; writing it to column 2 would
    # overwrite the end frame computed above.
    _, nums = id_to_index_range[:, 0].unique_consecutive(return_counts=True)
    group_starts = torch.cumsum(nums, 0) - nums
    positions = torch.arange(num_phonemes, device=nums.device)
    id_to_index_range[:, 3] = positions - torch.repeat_interleave(group_starts, nums)
    return id_to_index_range

测试

In [5]:
# Quick timing test of get_phonme_id_mapping on a fresh random batch (needs CUDA).
B = 64
T = 149
# Ids are drawn from {0, 1} here.
phoneme_ids = torch.randint(0, 2, (B, T)).type(torch.int64).cuda()
# Append the -1 sentinel column. NOTE(review): torch.ones(...) * -1 is a float tensor,
# so this concat promotes phoneme_ids from int64 to float.
phoneme_ids = torch.concat([ phoneme_ids, torch.ones(B, 1, device=phoneme_ids.device)*-1], dim=-1)
reduced_phoneme_ids, inverse, counts = phoneme_ids.unique_consecutive(return_inverse=True, return_counts=True)
# print(phoneme_ids)
%time phoneme_id_to_index_range = get_phonme_id_mapping(cumsum_counts=torch.cumsum(counts, 0), T=T)

CPU times: user 48.5 ms, sys: 28.7 ms, total: 77.1 ms
Wall time: 76.5 ms


## Step 2

对于一个batch，计算每个音频样本$i$，在reduced的phoneme ids中，phoneme id的起始和结束的index范围是多少：

In [82]:
def get_sample_index_range_in_phonemes(samples):
    """
    Compute, for every run of equal values in ``samples``, its first and last index.

    Args:
        samples: 1-D tensor holding the audio sample id of every reduced phoneme
            (column 0 of the tensor returned by ``get_phonme_id_mapping``).

    Returns:
        dict mapping sample id -> (start, end) with plain Python ints. ``start`` is the
        index of the first reduced phoneme of the sample and ``end`` the index of its
        last one. In this notebook the last one is the padding phoneme, so slicing with
        ``[start:end]`` skips the padding. If a value occurs in several runs, the last
        run wins. An empty input yields an empty dict.
    """
    num_phonemes = len(samples)
    if num_phonemes == 0:
        return {}

    # Positions where the sample id differs from its predecessor, i.e. run starts.
    change_indices = torch.nonzero(samples[1:] != samples[:-1], as_tuple=False)[:, 0] + 1

    # Move everything to Python lists once instead of calling .item() per element,
    # which would force one device synchronisation per value on CUDA tensors.
    starts = [0] + change_indices.tolist()
    # A run ends right before the next run starts; the final run ends at the last index.
    ends = [next_start - 1 for next_start in starts[1:]] + [num_phonemes - 1]
    values = samples[starts].tolist()
    return {value: (start, end) for value, start, end in zip(values, starts, ends)}

测试

In [95]:
# End-to-end test of Step 1 + Step 2 on a fresh random batch (needs CUDA).
B, T = 64, 149
# B, T = 3, 10

hidden_states= torch.randn(B, T, 768).cuda()
audio_lengths = torch.randint(T-10, T, (B,)).cuda()
phoneme_ids = torch.randint(0, 10, (B, T)).type(torch.int64).cuda()
languages = torch.randint(0, 3, (B,)).cuda()
# Append one padding frame (all ones) and one -1 sentinel id to every audio.
# NOTE(review): torch.ones(...) * -1 is a float tensor, so phoneme_ids (and therefore
# reduced_phoneme_ids) become float after this concat.
hidden_states = torch.concat([ hidden_states, torch.ones(B, 1, 768, device=hidden_states.device)], dim=1)
phoneme_ids = torch.concat([ phoneme_ids, torch.ones(B, 1, device=phoneme_ids.device)*-1], dim=-1)
reduced_phoneme_ids, inverse, counts = phoneme_ids.unique_consecutive(return_inverse=True, return_counts=True)
# print(phoneme_ids)
phoneme_id_to_index_range = get_phonme_id_mapping(cumsum_counts=torch.cumsum(counts, 0), T=T)

# Column 0 of the mapping is the audio sample id of every reduced phoneme.
%time index_ranges = get_sample_index_range_in_phonemes(phoneme_id_to_index_range[:, 0])
print(reduced_phoneme_ids, index_ranges)

CPU times: user 6.18 ms, sys: 594 µs, total: 6.78 ms
Wall time: 6.79 ms
tensor([ 7.,  5.,  3.,  ...,  9.,  1., -1.], device='cuda:0') {0: (0, 127), 1: (128, 265), 2: (266, 403), 3: (404, 539), 4: (540, 674), 5: (675, 809), 6: (810, 942), 7: (943, 1078), 8: (1079, 1223), 9: (1224, 1354), 10: (1355, 1485), 11: (1486, 1617), 12: (1618, 1754), 13: (1755, 1887), 14: (1888, 2018), 15: (2019, 2156), 16: (2157, 2294), 17: (2295, 2427), 18: (2428, 2562), 19: (2563, 2699), 20: (2700, 2835), 21: (2836, 2971), 22: (2972, 3109), 23: (3110, 3244), 24: (3245, 3377), 25: (3378, 3514), 26: (3515, 3648), 27: (3649, 3784), 28: (3785, 3923), 29: (3924, 4050), 30: (4051, 4188), 31: (4189, 4320), 32: (4321, 4453), 33: (4454, 4583), 34: (4584, 4718), 35: (4719, 4850), 36: (4851, 4987), 37: (4988, 5125), 38: (5126, 5264), 39: (5265, 5402), 40: (5403, 5535), 41: (5536, 5666), 42: (5667, 5803), 43: (5804, 5934), 44: (5935, 6070), 45: (6071, 6205), 46: (6206, 6344), 47: (6345, 6480), 48: (6481, 6613), 49: (6614,

In [73]:
# Sanity check: the run lengths in counts should add up to the number of padded frames,
# i.e. hidden_states.shape[0] * hidden_states.shape[1].
hidden_states.shape, torch.sum(counts)

(torch.Size([3, 11, 768]), tensor(33, device='cuda:0'))

In [48]:
# Inspect the reduced id sequence; every -1 is the sentinel that closes one audio sample.
reduced_phoneme_ids

tensor([ 1.,  2.,  6.,  8.,  7.,  9.,  5.,  7.,  9., -1.,  0.,  9.,  3.,  5.,
         3.,  9.,  5.,  1., -1.,  1.,  9.,  7.,  8.,  7.,  8.,  3.,  6., -1.],
       device='cuda:0')

## 查找每个phoneme同类的index

In [49]:
def find_same_phoneme_ids(reduced_phoneme_ids):
    """
    Group the positions of a reduced phoneme sequence by phoneme id.

    Args:
        reduced_phoneme_ids: 1-D tensor of phoneme ids (may live on any device).

    Returns:
        dict mapping each id (as a NumPy scalar, in order of first appearance) to the
        ascending list of positions at which that id occurs.
    """
    values = reduced_phoneme_ids.cpu().numpy()
    positions_by_id = {}
    for position in range(len(values)):
        key = values[position]
        if key not in positions_by_id:
            positions_by_id[key] = []
        positions_by_id[key].append(position)
    return positions_by_id

In [96]:
# Time the grouping of reduced phoneme positions by id, then dump the whole mapping.
%time same_phoneme_id_index = find_same_phoneme_ids(reduced_phoneme_ids)
print(same_phoneme_id_index)

CPU times: user 0 ns, sys: 2.96 ms, total: 2.96 ms
Wall time: 2.68 ms
{7.0: [0, 9, 24, 32, 36, 41, 50, 54, 73, 83, 106, 108, 113, 129, 134, 175, 183, 186, 191, 194, 196, 200, 208, 210, 226, 231, 241, 246, 253, 261, 264, 279, 289, 295, 304, 306, 308, 316, 318, 325, 336, 357, 363, 370, 385, 400, 411, 413, 435, 438, 446, 451, 459, 465, 471, 511, 513, 517, 523, 538, 541, 547, 550, 554, 562, 577, 583, 590, 606, 622, 629, 636, 653, 661, 664, 699, 716, 722, 748, 759, 768, 776, 781, 785, 792, 796, 800, 805, 815, 824, 829, 836, 844, 849, 854, 860, 862, 871, 874, 886, 888, 897, 899, 915, 924, 926, 937, 943, 947, 954, 968, 991, 997, 1001, 1004, 1012, 1014, 1023, 1028, 1033, 1040, 1044, 1065, 1085, 1104, 1116, 1131, 1134, 1167, 1177, 1184, 1203, 1227, 1247, 1251, 1265, 1267, 1291, 1293, 1309, 1320, 1336, 1345, 1367, 1374, 1379, 1385, 1390, 1392, 1399, 1403, 1416, 1418, 1423, 1425, 1440, 1450, 1475, 1481, 1499, 1507, 1509, 1515, 1522, 1530, 1532, 1537, 1547, 1550, 1569, 1594, 1599, 1602, 1611, 1618

In [97]:
# Positions of the reduced phonemes whose id is <= 4 (the -1 sentinel entries match too).
np.nonzero(reduced_phoneme_ids.cpu().numpy()<=4)[0]

array([   2,    7,   10, ..., 8651, 8653, 8654])

In [75]:
# Experiment: for every reduced phoneme with id < 4, pick another position that carries the
# same id and record it in replaced_index (positions that are not replaced map to themselves).
# NOTE(review): np.random.rand draws from [0, 1), so with p = 5 the test `probs[i] <= p`
# below is always true — confirm whether a probability such as 0.5 was intended.
p = 5
with TimerContextManager():
    _phoneme_ids = reduced_phoneme_ids.cpu().numpy()
    L = len(_phoneme_ids)
    # Identity mapping by default: position i keeps its own segment.
    replaced_index = np.arange(L)
    probs = np.random.rand(L)
    # NOTE(review): `< 4` also selects the -1 sentinel entries (see the printed rows with
    # x == -1.0) — confirm whether they should be excluded.
    filtered_index = np.nonzero(_phoneme_ids<4)[0]

    # One random integer per position, reduced modulo the candidate count below.
    choices = np.random.randint(100, 1000, L)
    
    for i in filtered_index:
        if probs[i] <= p:
            # replaced_index[i] = random.choice(filter_list)
            x = _phoneme_ids[i]
            # All positions sharing this phoneme id (includes i itself).
            _list = same_phoneme_id_index[x]
            _len = len(_list)
            _replace = choices[i] % _len
            # Avoid picking i itself by stepping to the next candidate (cyclically);
            # when i is the only candidate the position still maps to itself.
            if _list[_replace] == i:
                replaced_index[i] = _list[(_replace + 1) % _len]
            else:
                replaced_index[i] = _list[_replace]
            print(i, x, same_phoneme_id_index[x], replaced_index[i])

1 1.0 [1, 15] 15
3 2.0 [3, 7, 29] 7
6 3.0 [6, 11, 13, 16, 18, 27] 18
7 2.0 [3, 7, 29] 29
10 -1.0 [10, 21, 32] 21
11 3.0 [6, 11, 13, 16, 18, 27] 6
13 3.0 [6, 11, 13, 16, 18, 27] 27
15 1.0 [1, 15] 1
16 3.0 [6, 11, 13, 16, 18, 27] 6
18 3.0 [6, 11, 13, 16, 18, 27] 11
21 -1.0 [10, 21, 32] 10
22 0.0 [22, 25] 25
25 0.0 [22, 25] 22
27 3.0 [6, 11, 13, 16, 18, 27] 11
29 2.0 [3, 7, 29] 7
32 -1.0 [10, 21, 32] 21
 代码执行时间: 0.0016355514526367188 秒


In [90]:
# Inspect the mapping: entry i is the position whose segment is used at position i.
replaced_index

array([ 0, 15,  2,  7,  4,  5, 18, 29,  8,  9, 21,  6, 12, 27, 14,  1,  6,
       17, 11, 19, 20, 10, 25, 23, 24, 22, 26, 11, 28,  7, 30, 31, 21])

In [78]:
# Scratch check of list replication. Note that all ten entries refer to the same list object.
[[]]*10

[[], [], [], [], [], [], [], [], [], []]

In [93]:
# Cut the flattened (frame, feature) matrix into one chunk per reduced phoneme.
flat_hidden_states = hidden_states.view(-1, 768)
chunk_sizes = tuple(counts.cpu().numpy())
split_hidden_states = list(torch.split(flat_hidden_states, chunk_sizes))

# Rebuild every audio from the (possibly substituted) chunks named by replaced_index.
# index_ranges[i] = (s, e); the slice [s:e] leaves out the sample's trailing padding chunk.
splits = []
for i in range(B):
    s, e = index_ranges[i]
    chosen = replaced_index[s:e]
    print(s, e, chosen)
    chunks = [split_hidden_states[j] for j in chosen]
    splits.append(torch.concat(chunks))
    # print(splits[i].shape)

# Zero-pad the rebuilt audios to a common length: (B, max_len, 768).
aug_hidden_states = torch.nn.utils.rnn.pad_sequence(splits, batch_first=True)

0 9 [ 0 15  2  7  4  5 18 29  8]
10 20 [21  6 12 27 14  1  6 17 11 19]
21 31 [10 25 23 24 22 26 11 28  7 30]


In [94]:
# Shape of the padded, rebuilt batch. NOTE: the name aug_hidden_states is reused for a
# function further down in this notebook.
aug_hidden_states.shape

torch.Size([3, 10, 768])

In [69]:
# Total number of frames covered by all reduced phonemes.
sum(tuple(counts.cpu().numpy()))

33

In [19]:
# Scratch check: np.random.rand returns uniform samples from [0, 1).
np.random.rand(10)

array([0.94259596, 0.04583652, 0.79635097, 0.3748233 , 0.43596316,
       0.72562881, 0.30926741, 0.6197601 , 0.49392744, 0.08592097])

In [190]:
# Scratch check: list.remove deletes the first matching element in place.
x = [1, 2, 6]
x.remove(1)
x

[2, 6]

In [84]:
# def find_first_larger_element(X, y):
#     """
#     Finds the first element in sorted tensor X that is larger than y.
    
#     Args:
#     - X (torch.Tensor): Sorted 1D tensor.
#     - y (float or int): Value to compare against.
    
#     Returns:
#     - int: Index of the first element in X that is larger than y.
#       Returns -1 if no such element exists.
#     """
#     greater_mask = X > y
#     indices = torch.nonzero(greater_mask).squeeze()
    
#     if indices.numel() > 0:
#         return indices[0].item()
#     else:
#         return len(X)
        
# def find_phoneme_ids_index(sample_id, cumsum_counts, T):

#     # each sample has T+1 unreduced phoneme ids
#     s = sample_id * (T + 1)
#     e = s + (T + 1)
#     s = find_first_larger_element(cumsum_counts, s)
#     e = find_first_larger_element(cumsum_counts, e) - 1
#     return s, e


# def get_sample_phoneme_ids_indices(cumsum_counts, B, T):
#     sample_phoneme_ids_index = {}
#     for i in range(B):
#         s, e = find_phoneme_ids_index(i, cumsum_counts, T)
#         sample_phoneme_ids_index[i] = (s, e)
#         # print(i, s, e, reduced_phoneme_ids[s:e])
#     return sample_phoneme_ids_index

# %time sample_phoneme_ids_index = get_sample_phoneme_ids_indices(cumsum_counts, B, T)

## Step 3

In [107]:
def aug_hidden_states(hidden_states,audio_lengths ,phoneme_ids, languages=None, N=5):
    """
    Augment a batch by swapping phoneme segments between different audio samples.

    A phoneme segment is a run of consecutive frames of one audio that share the same
    phoneme id. For every audio, segments are visited in random order and up to ``N`` of
    them are replaced by the frames of a randomly chosen segment with the same phoneme id
    taken from a *different* audio of the batch. Segments whose id lies in [0, 5) are
    never replaced. Because the donor segment may have a different number of frames, the
    augmented audios can differ in length and are zero-padded to the longest one.

    Args:
        hidden_states: tensor of shape (B, T, D) with per-frame features.
        audio_lengths: currently unused; kept for interface compatibility.
        phoneme_ids: tensor of shape (B, T) with per-frame phoneme ids. -1 is used
            internally as end-of-audio sentinel, so real ids are assumed to be >= 0.
        languages: currently unused; kept for interface compatibility.
        N: maximum number of replaced segments per audio.

    Returns:
        (aug_hidden_states, labels) where
        - aug_hidden_states has shape (B, T', D), zero-padded by ``pad_sequence``;
        - labels is a 1-D int64 tensor with one entry per phoneme segment of the batch
          (sentinels removed), 1 if the segment was replaced and 0 otherwise.
    """
    with TimerContextManager('init'):
        B, T, feature_dim = hidden_states.shape

        # Append a -1 sentinel id and a dummy frame to every audio so that
        # unique_consecutive cannot merge phonemes across audio boundaries.
        # The sentinel uses the ids' own dtype: `torch.ones(...) * -1` is float and
        # would promote the integer ids to float on concatenation.
        sentinel_ids = torch.full((B, 1), -1, dtype=phoneme_ids.dtype, device=phoneme_ids.device)
        phoneme_ids = torch.concat([phoneme_ids, sentinel_ids], dim=-1)
        sentinel_frames = torch.ones(B, 1, feature_dim, dtype=hidden_states.dtype, device=hidden_states.device)
        hidden_states = torch.concat([hidden_states, sentinel_frames], dim=1)

        # phoneme_ids is 2D, but unique_consecutive flattens it to 1D before reducing.
        reduced_phoneme_ids, counts = phoneme_ids.unique_consecutive(return_counts=True)
        reduced_phoneme_ids = reduced_phoneme_ids.type(torch.int64)
        cumsum_counts = torch.cumsum(counts, 0)

        labels = torch.zeros_like(reduced_phoneme_ids)

    with TimerContextManager('compute mapping'):
        id_to_index_range = get_phonme_id_mapping(cumsum_counts.cpu(), T)
        sample_phoneme_ids_index = get_sample_index_range_in_phonemes(id_to_index_range[:, 0])

    # Plain Python copies so the loops below do not trigger one device sync per lookup.
    reduced_ids = reduced_phoneme_ids.tolist()
    sample_of_phoneme = id_to_index_range[:, 0].tolist()

    # One chunk of frames per reduced phoneme. The chunks are views and are never
    # modified in place, so one list serves as both source and donor pool.
    segments = list(torch.split(hidden_states.reshape(-1, feature_dim), counts.tolist()))

    splits = []
    for i in range(B):
        # [start_index, end_index) covers the phonemes of audio i without its sentinel.
        start_index, end_index = sample_phoneme_ids_index[i]
        _split = segments[start_index:end_index]

        find_substitution = 0
        with TimerContextManager(f'sample {i}'):
            for j in (np.random.permutation(end_index - start_index) + start_index):  # visit phonemes in random order
                # Phoneme ids 0..4 are excluded from replacement.
                if 0 <= reduced_ids[j] < 5:
                    continue
                # All positions in the batch with the same phoneme id (includes j itself).
                same_ids = (reduced_phoneme_ids == reduced_ids[j]).nonzero().flatten().tolist()
                if len(same_ids) <= 1:
                    continue
                for k in torch.randperm(len(same_ids)).tolist():
                    candidate = same_ids[k]
                    # Only take donors from other audios.
                    if sample_of_phoneme[candidate] == i:
                        continue
                    # Index the donor chunk directly by its global position instead of
                    # going through the per-audio position column of the mapping.
                    _split[j - start_index] = segments[candidate]
                    labels[j] = 1
                    find_substitution += 1
                    break

                if find_substitution >= N:
                    break
            splits.append(_split)

    with TimerContextManager(f'final process'):
        splits = [torch.concat(chunks) for chunks in splits]
        augmented = torch.nn.utils.rnn.pad_sequence(splits, batch_first=True)
        # Drop the sentinel entries so labels line up with the real phoneme segments.
        labels = labels[reduced_phoneme_ids != -1]
    return augmented, labels

In [95]:
# Fresh synthetic batch for the Step 3 run: B audio samples with T frames each.
B = 64
T = 149

# Random stand-ins for the real inputs; .cuda() means this cell requires a CUDA device.
hidden_states= torch.randn(B, T, 768).cuda()  # (B, T, 768) per-frame features
audio_lengths = torch.randint(T-10, T, (B,)).cuda()  # one value per sample, drawn from [T-10, T)
phoneme_ids = torch.randint(0, 10, (B, T)).type(torch.int64).cuda()  # (B, T) ids drawn from [0, 10)
languages = torch.randint(0, 3, (B,)).cuda()  # one id per sample, drawn from [0, 3)

In [108]:
# Run the augmentation on the synthetic batch; the cell displays the returned
# (augmented hidden states, labels) tuple and the timers print one line per stage/sample.
aug_hidden_states(hidden_states, audio_lengths, phoneme_ids, languages)

init 代码执行时间: 0.008322715759277344 秒
compute mapping 代码执行时间: 0.00443577766418457 秒
sample 0 代码执行时间: 0.011241912841796875 秒
sample 1 代码执行时间: 0.008610010147094727 秒
sample 2 代码执行时间: 0.007604122161865234 秒
sample 3 代码执行时间: 0.006790637969970703 秒
sample 4 代码执行时间: 0.006298542022705078 秒
sample 5 代码执行时间: 0.0065000057220458984 秒
sample 6 代码执行时间: 0.006036996841430664 秒
sample 7 代码执行时间: 0.0061419010162353516 秒
sample 8 代码执行时间: 0.006064891815185547 秒
sample 9 代码执行时间: 0.006111621856689453 秒
sample 10 代码执行时间: 0.00619959831237793 秒
sample 11 代码执行时间: 0.0061910152435302734 秒
sample 12 代码执行时间: 0.006241559982299805 秒
sample 13 代码执行时间: 0.006371259689331055 秒
sample 14 代码执行时间: 0.006404876708984375 秒
sample 15 代码执行时间: 0.006104946136474609 秒
sample 16 代码执行时间: 0.006050825119018555 秒
sample 17 代码执行时间: 0.006089448928833008 秒
sample 18 代码执行时间: 0.006044149398803711 秒
sample 19 代码执行时间: 0.006255626678466797 秒
sample 20 代码执行时间: 0.00638890266418457 秒
sample 21 代码执行时间: 0.0058956146240234375 秒
sample 22 代码执行时间: 0.0061

(tensor([[[-1.9122e+00,  1.7298e-01,  1.6296e+00,  ..., -1.1782e+00,
           -1.6140e-01, -4.6088e-01],
          [-8.2703e-01,  3.4944e-01, -1.6509e+00,  ..., -1.4307e+00,
           -2.8071e+00, -5.8551e-01],
          [ 2.9599e-01, -4.7831e-01, -7.7205e-01,  ..., -1.0221e+00,
            1.1022e+00,  6.6595e-01],
          ...,
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00]],
 
         [[-1.3124e+00,  1.4327e+00,  8.2652e-01,  ..., -1.0508e+00,
            3.1194e-01, -1.5754e+00],
          [ 1.0310e+00, -1.2389e+00, -1.0692e+00,  ...,  2.5045e-01,
           -4.9238e-01,  6.8962e-01],
          [ 4.2484e-01,  1.3397e+00,  1.5697e+00,  ...,  6.9813e-01,
           -1.2801e+00,  1.6349e+00],
          ...,
    