# WeNe之aishell数据预处理

In [1]:
import torch
from torch.utils.data import IterableDataset
import torchaudio
import torch.distributed as dist
import random
import yaml
import torchaudio.compliance.kaldi as kaldi
from wenet.utils.file_utils import read_lists
import wenet.dataset.processor as processor
from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols

## 获取数据配置文本

In [2]:
path = r'conf/train_conformer.yaml'

with open(path, 'r') as fin:
    configs = yaml.load(fin, Loader=yaml.FullLoader)

conf = configs["dataset_conf"]

## 获取Dataset

In [3]:
class DistributedSampler:
    def __init__(self, shuffle=True, partition=True):
        self.epoch = -1
        self.update()
        self.shuffle = shuffle
        self.partition = partition

    def update(self):
        assert dist.is_available()
        if dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        else:
            self.rank = 0
            self.world_size = 1
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            self.worker_id = 0
            self.num_workers = 1
        else:
            self.worker_id = worker_info.id
            self.num_workers = worker_info.num_workers
        return dict(rank=self.rank,
                    world_size=self.world_size,
                    worker_id=self.worker_id,
                    num_workers=self.num_workers)
    
    def sample(self, data):
        """ Sample data according to rank/world_size/num_workers

            Args:
                data(List): input data list

            Returns:
                List: data list after sample
        """
        # 采样后信息
        data = list(range(len(data)))
        # TODO(Binbin Zhang): fix this
        # We can not handle uneven data for CV on DDP, so we don't
        # sample data by rank, that means every GPU gets the same
        # and all the CV data
        if self.partition:
            if self.shuffle:
                random.Random(self.epoch).shuffle(data)
            data = data[self.rank::self.world_size]
        data = data[self.worker_id::self.num_workers]
        return data


In [4]:
class DataList(IterableDataset):
    """完成大内存数据读取
    """
    def __init__(self, lists, shuffle=True, partition=True):
        self.lists = lists
        self.sampler = DistributedSampler(shuffle, partition)

    def set_epoch(self, epoch):
        self.sampler.set_epoch(epoch)

    def __iter__(self):
        sampler_info = self.sampler.update()
        indexes = self.sampler.sample(self.lists)
        for index in indexes:
            # yield dict(src=src)
            data = dict(src=self.lists[index])
            data.update(sampler_info)
            yield data

In [5]:
def read_lists(list_file):
    lists = []
    with open(list_file, 'r', encoding='utf8') as fin:
        for line in fin:
            lists.append(line.strip())
    return lists

In [6]:
class Processor(IterableDataset):
    def __init__(self, source, f, *args, **kw):
        assert callable(f)
        self.source = source
        self.f = f
        self.args = args
        self.kw = kw

    def set_epoch(self, epoch):
        self.source.set_epoch(epoch)

    def __iter__(self):
        """ Return an iterator over the source dataset processed by the
            given processor.
        """
        assert self.source is not None
        assert callable(self.f)
        return self.f(iter(self.source), *self.args, **self.kw)

    def apply(self, f):
        assert callable(f)
        return Processor(self, f, *self.args, **self.kw)

In [7]:
list_data = read_lists("./data/train/data.list")

## 预处理操作

### 读取原始数据

In [8]:
dataset = DataList(list_data)

In [9]:
for i, data in enumerate(dataset):
    print(data) 
    if i == 5: break

{'src': '{"key": "BAC009S0160W0199", "wav": "/home/gavin/Machine/data/asr/aishell/data_aishell/wav/train/S0160/BAC009S0160W0199.wav", "txt": "最大的问题就是提供了某种隐私的背书或担保"}', 'rank': 0, 'world_size': 1, 'worker_id': 0, 'num_workers': 1}
{'src': '{"key": "BAC009S0079W0181", "wav": "/home/gavin/Machine/data/asr/aishell/data_aishell/wav/train/S0079/BAC009S0079W0181.wav", "txt": "几天前有媒体曝光了这片别墅"}', 'rank': 0, 'world_size': 1, 'worker_id': 0, 'num_workers': 1}
{'src': '{"key": "BAC009S0114W0438", "wav": "/home/gavin/Machine/data/asr/aishell/data_aishell/wav/train/S0114/BAC009S0114W0438.wav", "txt": "通州法院判决驳回了潘老太的诉求"}', 'rank': 0, 'world_size': 1, 'worker_id': 0, 'num_workers': 1}
{'src': '{"key": "BAC009S0096W0254", "wav": "/home/gavin/Machine/data/asr/aishell/data_aishell/wav/train/S0096/BAC009S0096W0254.wav", "txt": "未来科技必须回到以人为中心"}', 'rank': 0, 'world_size': 1, 'worker_id': 0, 'num_workers': 1}
{'src': '{"key": "BAC009S0349W0190", "wav": "/home/gavin/Machine/data/asr/aishell/data_aishell/wav/trai

In [10]:
dataset = Processor(dataset, processor.parse_raw)

In [11]:
for i, data in enumerate(dataset):
    print(data) 
    if i == 5: break

{'key': 'BAC009S0160W0199', 'txt': '最大的问题就是提供了某种隐私的背书或担保', 'wav': tensor([[-0.0023, -0.0039, -0.0035,  ..., -0.0048, -0.0050, -0.0044]]), 'sample_rate': 16000}
{'key': 'BAC009S0079W0181', 'txt': '几天前有媒体曝光了这片别墅', 'wav': tensor([[-0.0020, -0.0033, -0.0029,  ..., -0.0013, -0.0013, -0.0015]]), 'sample_rate': 16000}
{'key': 'BAC009S0114W0438', 'txt': '通州法院判决驳回了潘老太的诉求', 'wav': tensor([[-0.0004, -0.0005, -0.0002,  ...,  0.0002,  0.0003,  0.0003]]), 'sample_rate': 16000}
{'key': 'BAC009S0096W0254', 'txt': '未来科技必须回到以人为中心', 'wav': tensor([[ 0.0004,  0.0007,  0.0008,  ..., -0.0008, -0.0007, -0.0006]]), 'sample_rate': 16000}
{'key': 'BAC009S0349W0190', 'txt': '产权市场纷纷抢食这块蛋糕', 'wav': tensor([[0.0004, 0.0007, 0.0007,  ..., 0.0005, 0.0004, 0.0006]]), 'sample_rate': 16000}
{'key': 'BAC009S0422W0354', 'txt': '由西向东步行约一二十米左右', 'wav': tensor([[-0.0013, -0.0017, -0.0008,  ...,  0.0037,  0.0034,  0.0041]]), 'sample_rate': 16000}


### 分词

In [12]:
non_lang_syms = read_non_lang_symbols(non_lang_sym_path=None)
symbol_table = read_symbol_table("./data/dict/lang_char.txt")

In [13]:
dataset = Processor(dataset, processor.tokenize, symbol_table, None,
                    non_lang_syms, conf.get('split_with_space', False))

In [14]:
for i, data in enumerate(dataset):
    print(data) 
    if i == 5: break

{'key': 'BAC009S0160W0199', 'txt': '最大的问题就是提供了某种隐私的背书或担保', 'wav': tensor([[-0.0023, -0.0039, -0.0035,  ..., -0.0048, -0.0050, -0.0044]]), 'sample_rate': 16000, 'tokens': ['最', '大', '的', '问', '题', '就', '是', '提', '供', '了', '某', '种', '隐', '私', '的', '背', '书', '或', '担', '保'], 'label': [1740, 814, 2553, 3925, 4077, 1030, 1694, 1556, 184, 66, 1806, 2723, 3984, 2718, 2553, 3047, 61, 1382, 1450, 205]}
{'key': 'BAC009S0079W0181', 'txt': '几天前有媒体曝光了这片别墅', 'wav': tensor([[-0.0020, -0.0033, -0.0029,  ..., -0.0013, -0.0013, -0.0015]]), 'sample_rate': 16000, 'tokens': ['几', '天', '前', '有', '媒', '体', '曝', '光', '了', '这', '片', '别', '墅'], 'label': [321, 815, 366, 1742, 916, 162, 1731, 262, 66, 3703, 2339, 353, 787]}
{'key': 'BAC009S0114W0438', 'txt': '通州法院判决驳回了潘老太的诉求', 'wav': tensor([[-0.0004, -0.0005, -0.0002,  ...,  0.0002,  0.0003,  0.0003]]), 'sample_rate': 16000, 'tokens': ['通', '州', '法', '院', '判', '决', '驳', '回', '了', '潘', '老', '太', '的', '诉', '求'], 'label': [3733, 1104, 2049, 3970, 350, 306, 4125, 703

### 过滤与重采样

In [15]:
filter_conf = conf.get('filter_conf', {})
dataset = Processor(dataset, processor.filter, **filter_conf)

In [16]:
for i, data in enumerate(dataset):
    print(data) 
    if i == 5: break

{'key': 'BAC009S0160W0199', 'txt': '最大的问题就是提供了某种隐私的背书或担保', 'wav': tensor([[-0.0023, -0.0039, -0.0035,  ..., -0.0048, -0.0050, -0.0044]]), 'sample_rate': 16000, 'tokens': ['最', '大', '的', '问', '题', '就', '是', '提', '供', '了', '某', '种', '隐', '私', '的', '背', '书', '或', '担', '保'], 'label': [1740, 814, 2553, 3925, 4077, 1030, 1694, 1556, 184, 66, 1806, 2723, 3984, 2718, 2553, 3047, 61, 1382, 1450, 205]}
{'key': 'BAC009S0079W0181', 'txt': '几天前有媒体曝光了这片别墅', 'wav': tensor([[-0.0020, -0.0033, -0.0029,  ..., -0.0013, -0.0013, -0.0015]]), 'sample_rate': 16000, 'tokens': ['几', '天', '前', '有', '媒', '体', '曝', '光', '了', '这', '片', '别', '墅'], 'label': [321, 815, 366, 1742, 916, 162, 1731, 262, 66, 3703, 2339, 353, 787]}
{'key': 'BAC009S0114W0438', 'txt': '通州法院判决驳回了潘老太的诉求', 'wav': tensor([[-0.0004, -0.0005, -0.0002,  ...,  0.0002,  0.0003,  0.0003]]), 'sample_rate': 16000, 'tokens': ['通', '州', '法', '院', '判', '决', '驳', '回', '了', '潘', '老', '太', '的', '诉', '求'], 'label': [3733, 1104, 2049, 3970, 350, 306, 4125, 703

In [17]:
resample_conf = conf.get('resample_conf', {})
dataset = Processor(dataset, processor.resample, **resample_conf)

In [18]:
for i, data in enumerate(dataset):
    print(data) 
    if i == 5: break

{'key': 'BAC009S0160W0199', 'txt': '最大的问题就是提供了某种隐私的背书或担保', 'wav': tensor([[-0.0023, -0.0039, -0.0035,  ..., -0.0048, -0.0050, -0.0044]]), 'sample_rate': 16000, 'tokens': ['最', '大', '的', '问', '题', '就', '是', '提', '供', '了', '某', '种', '隐', '私', '的', '背', '书', '或', '担', '保'], 'label': [1740, 814, 2553, 3925, 4077, 1030, 1694, 1556, 184, 66, 1806, 2723, 3984, 2718, 2553, 3047, 61, 1382, 1450, 205]}
{'key': 'BAC009S0079W0181', 'txt': '几天前有媒体曝光了这片别墅', 'wav': tensor([[-0.0020, -0.0033, -0.0029,  ..., -0.0013, -0.0013, -0.0015]]), 'sample_rate': 16000, 'tokens': ['几', '天', '前', '有', '媒', '体', '曝', '光', '了', '这', '片', '别', '墅'], 'label': [321, 815, 366, 1742, 916, 162, 1731, 262, 66, 3703, 2339, 353, 787]}
{'key': 'BAC009S0114W0438', 'txt': '通州法院判决驳回了潘老太的诉求', 'wav': tensor([[-0.0004, -0.0005, -0.0002,  ...,  0.0002,  0.0003,  0.0003]]), 'sample_rate': 16000, 'tokens': ['通', '州', '法', '院', '判', '决', '驳', '回', '了', '潘', '老', '太', '的', '诉', '求'], 'label': [3733, 1104, 2049, 3970, 350, 306, 4125, 703

In [19]:
conf

{'filter_conf': {'max_length': 40960,
  'min_length': 0,
  'token_max_length': 200,
  'token_min_length': 1},
 'resample_conf': {'resample_rate': 16000},
 'speed_perturb': True,
 'fbank_conf': {'num_mel_bins': 80,
  'frame_shift': 10,
  'frame_length': 25,
  'dither': 0.1},
 'spec_aug': True,
 'spec_aug_conf': {'num_t_mask': 2, 'num_f_mask': 2, 'max_t': 50, 'max_f': 10},
 'shuffle': True,
 'shuffle_conf': {'shuffle_size': 1500},
 'sort': True,
 'sort_conf': {'sort_size': 500},
 'batch_conf': {'batch_type': 'static', 'batch_size': 2}}

In [20]:
speed_perturb = conf.get('speed_perturb', False)
if speed_perturb:
    dataset = Processor(dataset, processor.speed_perturb)

In [21]:
for i, data in enumerate(dataset):
    print(data) 
    if i == 5: break

{'key': 'BAC009S0160W0199', 'txt': '最大的问题就是提供了某种隐私的背书或担保', 'wav': tensor([[-0.0022, -0.0039, -0.0035,  ..., -0.0048, -0.0049, -0.0050]]), 'sample_rate': 16000, 'tokens': ['最', '大', '的', '问', '题', '就', '是', '提', '供', '了', '某', '种', '隐', '私', '的', '背', '书', '或', '担', '保'], 'label': [1740, 814, 2553, 3925, 4077, 1030, 1694, 1556, 184, 66, 1806, 2723, 3984, 2718, 2553, 3047, 61, 1382, 1450, 205]}
{'key': 'BAC009S0079W0181', 'txt': '几天前有媒体曝光了这片别墅', 'wav': tensor([[-0.0019, -0.0033, -0.0029,  ..., -0.0014, -0.0012, -0.0016]]), 'sample_rate': 16000, 'tokens': ['几', '天', '前', '有', '媒', '体', '曝', '光', '了', '这', '片', '别', '墅'], 'label': [321, 815, 366, 1742, 916, 162, 1731, 262, 66, 3703, 2339, 353, 787]}
{'key': 'BAC009S0114W0438', 'txt': '通州法院判决驳回了潘老太的诉求', 'wav': tensor([[-0.0004, -0.0005, -0.0002,  ...,  0.0002,  0.0003,  0.0003]]), 'sample_rate': 16000, 'tokens': ['通', '州', '法', '院', '判', '决', '驳', '回', '了', '潘', '老', '太', '的', '诉', '求'], 'label': [3733, 1104, 2049, 3970, 350, 306, 4125, 703

In [22]:
# 获取音频特征
feats_type = conf.get('feats_type', 'fbank')
assert feats_type in ['fbank', 'mfcc']
if feats_type == 'fbank':
    fbank_conf = conf.get('fbank_conf', {})
    dataset = Processor(dataset, processor.compute_fbank, **fbank_conf)
elif feats_type == 'mfcc':
    mfcc_conf = conf.get('mfcc_conf', {})
    dataset = Processor(dataset, processor.compute_mfcc, **mfcc_conf)

In [23]:
for i, data in enumerate(dataset):
    print(data) 
    if i == 5: break

{'key': 'BAC009S0160W0199', 'label': [1740, 814, 2553, 3925, 4077, 1030, 1694, 1556, 184, 66, 1806, 2723, 3984, 2718, 2553, 3047, 61, 1382, 1450, 205], 'feat': tensor([[10.0034,  9.4124,  8.8504,  ...,  3.1760,  3.1335,  3.3529],
        [11.0576, 10.7693,  8.4619,  ...,  3.4684,  4.1584,  4.2449],
        [ 9.8402,  9.7405,  8.2489,  ...,  3.7678,  3.1555,  3.3836],
        ...,
        [11.9243, 11.3285,  7.8447,  ...,  3.8865,  3.7032,  3.0156],
        [12.3342, 11.8614,  9.1529,  ...,  3.5955,  4.2694,  3.5337],
        [10.3848, 10.0204,  7.7809,  ...,  3.2788,  3.5594,  4.1530]])}
{'key': 'BAC009S0079W0181', 'label': [321, 815, 366, 1742, 916, 162, 1731, 262, 66, 3703, 2339, 353, 787], 'feat': tensor([[10.7570, 10.7426,  8.2759,  ..., 10.0635,  9.1715,  8.3663],
        [10.2884,  9.9882,  8.2937,  ...,  9.4569,  9.8060,  8.3763],
        [ 7.2598,  7.7229,  7.4193,  ...,  9.2461,  8.6847,  7.8356],
        ...,
        [ 8.2552,  8.0596,  7.2004,  ...,  8.8967,  9.4172,  7.6320

### 音频特征增强

In [24]:
spec_aug = conf.get('spec_aug', True)
spec_sub = conf.get('spec_sub', False)
spec_trim = conf.get('spec_trim', False)
shuffle = conf.get('shuffle', True)
if spec_aug:
    spec_aug_conf = conf.get('spec_aug_conf', {})
    dataset = Processor(dataset, processor.spec_aug, **spec_aug_conf)
if spec_sub:
    spec_sub_conf = conf.get('spec_sub_conf', {})
    dataset = Processor(dataset, processor.spec_sub, **spec_sub_conf)
if spec_trim:
    spec_trim_conf = conf.get('spec_trim_conf', {})
    dataset = Processor(dataset, processor.spec_trim, **spec_trim_conf)

if shuffle:
    shuffle_conf = conf.get('shuffle_conf', {})
    dataset = Processor(dataset, processor.shuffle, **shuffle_conf)

sort = conf.get('sort', True)
if sort:
    sort_conf = conf.get('sort_conf', {})
    dataset = Processor(dataset, processor.sort, **sort_conf)

In [25]:
for i, data in enumerate(dataset):
    print(data) 
    if i == 5: break

{'key': 'BAC009S0026W0411', 'label': [25, 95, 3014, 524, 62, 271], 'feat': tensor([[ 8.6380,  9.4849,  8.7528,  ..., 10.0332,  9.4069,  8.3837],
        [ 9.7140,  9.2929,  6.8356,  ...,  9.2110,  9.1941,  7.8953],
        [ 9.6509,  9.4898,  9.3455,  ..., 10.0294,  9.5593,  8.1604],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]])}
{'key': 'BAC009S0128W0267', 'label': [25, 427, 75, 2551, 70, 426, 56], 'feat': tensor([[ 8.1434,  8.3520,  6.8038,  ...,  4.2025,  3.3710,  3.6203],
        [ 9.1759,  9.1142,  7.5651,  ...,  3.1498,  3.5437,  3.6296],
        [10.0043,  9.9018,  7.5152,  ...,  3.3204,  2.4338,  3.4348],
        ...,
        [ 8.0451,  7.1918,  6.5447,  ...,  3.7069,  4.1118,  3.2407],
        [ 9.2800,  9.4476,  6.9969,  ...,  3.1703,  3.7083,  4.9803],
        [ 7.5782,  6.7601,  5.9846,  ...,  3.976

### 数据batch化与padding

In [26]:
batch_conf = conf.get('batch_conf', {})
dataset = Processor(dataset, processor.batch, **batch_conf)

In [27]:
for i, data in enumerate(dataset):
    print(data) 
    if i == 5: break

[{'key': 'BAC009S0108W0199', 'label': [1656, 1621, 2807, 331, 506], 'feat': tensor([[12.2109, 11.8598,  7.3374,  ...,  9.4579,  9.4605,  7.6765],
        [11.4925, 10.6988,  7.6509,  ...,  9.6277,  9.5402,  8.7528],
        [ 9.4170,  8.3072,  7.3509,  ...,  9.6820,  9.4914,  8.6989],
        ...,
        [10.5787, 10.3363,  8.3428,  ...,  9.7901,  9.4125,  8.1679],
        [10.4588, 10.0394,  7.9148,  ...,  9.8501,  9.1444,  7.2262],
        [10.5515, 10.2172,  8.3560,  ...,  9.7540,  9.2594,  7.6902]])}, {'key': 'BAC009S0172W0128', 'label': [3608, 1143, 1409, 814], 'feat': tensor([[ 7.7422,  8.8199,  8.5580,  ...,  0.0000,  0.0000,  0.0000],
        [ 7.8527,  7.5003,  6.7255,  ...,  0.0000,  0.0000,  0.0000],
        [ 9.8539, 10.4513,  9.4371,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0

In [28]:
data

[{'key': 'BAC009S0246W0415',
  'label': [3517, 3517, 1176, 2318, 2553, 1107, 167, 95, 560],
  'feat': tensor([[8.7242, 9.3763, 9.3614,  ..., 9.5309, 9.4385, 7.9913],
          [7.9248, 8.1332, 5.7278,  ..., 9.8060, 9.5136, 8.1049],
          [8.5409, 9.5663, 9.3832,  ..., 9.5949, 8.9061, 7.6636],
          ...,
          [7.4998, 7.3673, 6.6385,  ..., 9.5328, 9.3131, 7.5758],
          [7.7944, 8.4934, 7.2805,  ..., 9.7965, 9.4260, 7.0329],
          [7.5611, 8.3587, 8.6900,  ..., 9.7884, 9.1457, 7.5318]])},
 {'key': 'BAC009S0152W0156',
  'label': [3502, 1812, 1011, 3531, 1694, 814, 949, 1970, 19, 2476],
  'feat': tensor([[ 6.6887,  7.8038,  0.0000,  ...,  9.6997,  9.5078,  7.7306],
          [ 6.2701,  5.8444,  0.0000,  ...,  9.8808, 10.8600,  8.5129],
          [ 8.6399,  9.2983,  0.0000,  ...,  9.2540,  9.8663,  7.8412],
          ...,
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     

In [29]:
feat_length = torch.tensor([x["feat"].size(0) for x in data],dtype=torch.int32)

In [30]:
torch.argsort(feat_length)

tensor([0, 1])

In [31]:
# 按照排序结果 返回其对应的索引结果
order = torch.argsort(feat_length, descending=True)

In [32]:
dataset = Processor(dataset, processor.padding)

In [33]:
for i, train_data in enumerate(dataset):
    print(train_data) 
    if i == 5: break

(['BAC009S0128W0267', 'BAC009S0600W0140'], tensor([[[ 8.4586,  8.4750,  5.4352,  ...,  7.2068,  7.6271,  6.4186],
         [10.1365, 10.4012,  8.7876,  ...,  8.2121,  7.3493,  6.0795],
         [ 9.5767,  9.3461,  3.6846,  ...,  8.4466,  7.1648,  6.5419],
         ...,
         [ 8.3635,  8.3586,  6.5479,  ...,  8.9378,  8.3924,  6.7225],
         [ 9.3292,  9.6250,  8.1633,  ...,  8.6924,  8.3276,  7.1249],
         [ 7.4935,  6.0364,  6.7090,  ...,  8.0413,  8.8462,  7.5329]],

        [[ 9.9886,  9.4007,  7.4236,  ...,  8.6365,  9.1940,  7.9169],
         [10.1148,  9.8261,  8.4347,  ...,  9.3652,  9.6718,  8.4966],
         [ 9.4363,  9.2812,  7.2438,  ...,  9.8238,  9.6255,  7.9771],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]]), tensor([[  25,  427,   75, 2551,   70,  426,   56],
        [1740, 4156,

In [34]:
from torch.nn.utils.rnn import pad_sequence

In [35]:
a = torch.ones(2, 4)
b = torch.ones(4, 4)
c = torch.ones(5, 4)
pd_data = pad_sequence([a, b, c], batch_first=True, padding_value=0)

In [36]:
pd_data

tensor([[[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [0., 0., 0., 0.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]]])

## 加载模型

In [37]:
from wenet.utils.init_model import init_model
# from torch.utils import tensorboard

Failed to import k2 and icefall.         Notice that they are necessary for hlg_onebest and hlg_rescore


In [38]:
path = r'exp/conformer/train.yaml'

with open(path, 'r') as fin:
    configs = yaml.load(fin, Loader=yaml.FullLoader)
fin.close()

In [39]:
model = init_model(configs)

##  训练

In [40]:
data

[{'key': 'BAC009S0246W0415',
  'label': [3517, 3517, 1176, 2318, 2553, 1107, 167, 95, 560],
  'feat': tensor([[8.7242, 9.3763, 9.3614,  ..., 9.5309, 9.4385, 7.9913],
          [7.9248, 8.1332, 5.7278,  ..., 9.8060, 9.5136, 8.1049],
          [8.5409, 9.5663, 9.3832,  ..., 9.5949, 8.9061, 7.6636],
          ...,
          [7.4998, 7.3673, 6.6385,  ..., 9.5328, 9.3131, 7.5758],
          [7.7944, 8.4934, 7.2805,  ..., 9.7965, 9.4260, 7.0329],
          [7.5611, 8.3587, 8.6900,  ..., 9.7884, 9.1457, 7.5318]])},
 {'key': 'BAC009S0152W0156',
  'label': [3502, 1812, 1011, 3531, 1694, 814, 949, 1970, 19, 2476],
  'feat': tensor([[ 6.6887,  7.8038,  0.0000,  ...,  9.6997,  9.5078,  7.7306],
          [ 6.2701,  5.8444,  0.0000,  ...,  9.8808, 10.8600,  8.5129],
          [ 8.6399,  9.2983,  0.0000,  ...,  9.2540,  9.8663,  7.8412],
          ...,
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     

In [41]:
train_data

(['BAC009S0044W0426', 'BAC009S0017W0251'],
 tensor([[[ 4.6670,  5.1192,  0.0000,  ...,  9.6385,  9.8300,  7.4352],
          [ 6.8735,  6.8337,  0.0000,  ...,  9.8808,  9.3433,  8.5645],
          [ 7.2836,  7.6655,  0.0000,  ...,  9.3444,  9.6414,  7.6837],
          ...,
          [ 8.3975,  8.1936,  0.0000,  ...,  9.3961,  9.8383,  8.8075],
          [ 7.2510,  7.4284,  0.0000,  ...,  9.1622, 10.1518,  8.2457],
          [ 3.8655,  5.4604,  0.0000,  ...,  9.2287, 10.0109,  8.8139]],
 
         [[ 8.0368,  8.4134,  6.8614,  ...,  9.7831,  9.4139,  7.7385],
          [ 8.0663,  8.7596,  7.9227,  ..., 10.0486,  9.5369,  7.3860],
          [ 6.8677,  8.1551,  8.6109,  ...,  9.9032,  9.3496,  7.8278],
          ...,
          [ 8.4017,  9.0341,  8.2494,  ..., 11.1876, 10.8796,  9.3109],
          [ 8.3789,  8.6269,  6.3244,  ..., 10.8945, 11.1904,  9.0844],
          [ 9.2008,  9.9915,  9.4013,  ..., 10.5520, 10.2344,  9.0678]]]),
 tensor([[ 384,    8, 4083, 2055, 2156, 2156, 3696,  477]

### encoder

In [42]:
asr_encoder = model.encoder

In [43]:
# 构建输入
speech, text,  speech_length, text_length =  torch.tensor(train_data[1]), torch.tensor(train_data[2]),torch.tensor(train_data[3]),torch.tensor(train_data[4])

  speech, text,  speech_length, text_length =  torch.tensor(train_data[1]), torch.tensor(train_data[2]),torch.tensor(train_data[3]),torch.tensor(train_data[4])


In [44]:
encoder_out, encoder_mask = asr_encoder(speech, speech_length)

In [45]:
speech.shape

torch.Size([2, 239, 80])

In [46]:
speech_length.shape

torch.Size([2])

In [47]:
# 训练输出
encoder_out.shape, encoder_mask.shape

(torch.Size([2, 59, 256]), torch.Size([2, 1, 59]))

In [48]:
encoder_mask

tensor([[[True, True, True, True, True, True, True, True, True, True, True,
          True, True, True, True, True, True, True, True, True, True, True,
          True, True, True, True, True, True, True, True, True, True, True,
          True, True, True, True, True, True, True, True, True, True, True,
          True, True, True, True, True, True, True, True, True, True, True,
          True, True, True, True]],

        [[True, True, True, True, True, True, True, True, True, True, True,
          True, True, True, True, True, True, True, True, True, True, True,
          True, True, True, True, True, True, True, True, True, True, True,
          True, True, True, True, True, True, True, True, True, True, True,
          True, True, True, True, True, True, True, True, True, True, True,
          True, True, True, True]]])

In [49]:
# encode_mask_lens
encoder_out_l = encoder_mask.squeeze(1).sum(1)

In [50]:
encoder_out_l

tensor([59, 59])

### Decoder模块

In [51]:
from wenet.utils.common import (IGNORE_ID, add_sos_eos, log_add,
                                remove_duplicates_and_blank, th_accuracy,
                                reverse_pad_list)

In [52]:
sos, eos = configs["output_dim"] - 1, configs["output_dim"] - 1

In [53]:
ys_in_pad, ys_out_pad = add_sos_eos(text, sos, eos, -1)
ys_in_lens = text_length + 1

In [54]:
ys_in_pad, ys_out_pad, ys_out_pad.shape

(tensor([[4232,  384,    8, 4083, 2055, 2156, 2156, 3696,  477],
         [4232, 1544, 3704, 4004, 2005,  338, 2087, 1618, 3738]]),
 tensor([[ 384,    8, 4083, 2055, 2156, 2156, 3696,  477, 4232],
         [1544, 3704, 4004, 2005,  338, 2087, 1618, 3738, 4232]]),
 torch.Size([2, 9]))

In [55]:
# reverse the seq, used for right to left decoder
r_ys_pad = reverse_pad_list(text, text_length, float(-1.0))
r_ys_in_pad, r_ys_out_pad = add_sos_eos(r_ys_pad, sos, eos, -1)

In [56]:
r_ys_in_pad, r_ys_out_pad

(tensor([[4232,  477, 3696, 2156, 2156, 2055, 4083,    8,  384],
         [4232, 3738, 1618, 2087,  338, 2005, 4004, 3704, 1544]]),
 tensor([[ 477, 3696, 2156, 2156, 2055, 4083,    8,  384, 4232],
         [3738, 1618, 2087,  338, 2005, 4004, 3704, 1544, 4232]]))

In [57]:
asr_decoder = model.decoder

In [58]:
decoder_out, r_decoder_out, _ = asr_decoder(encoder_out, encoder_mask,
                                             ys_in_pad, ys_in_lens,
                                             r_ys_in_pad,
                                             reverse_weight=0.5)

In [59]:
decoder_out, decoder_out.shape

(tensor([[[-0.1381, -0.5130,  0.3466,  ..., -1.8519, -0.7287,  0.2848],
          [ 0.6350, -0.6088, -0.4716,  ..., -0.4503, -0.2493, -0.0035],
          [-0.2735,  0.8892, -0.2949,  ..., -0.4250, -0.1103,  0.0653],
          ...,
          [-0.1359,  0.1031, -0.2078,  ...,  0.3344,  0.5821,  0.8283],
          [-0.1494, -0.2119,  0.0921,  ..., -0.8839,  0.1637, -0.6406],
          [-0.0175, -0.4482,  0.5752,  ...,  0.0692, -0.5830,  0.1351]],
 
         [[-0.2971, -0.4606,  0.0773,  ..., -1.7014, -0.5726, -0.2149],
          [-0.8763, -0.8359, -0.2438,  ...,  0.4683,  0.3095,  0.2194],
          [-0.1976,  0.1418,  0.1229,  ..., -0.2470, -0.5907,  1.0407],
          ...,
          [-0.1481,  0.0304, -0.3653,  ...,  0.3072, -0.1863, -0.4751],
          [-0.9997, -0.0741,  0.6237,  ...,  0.7740,  0.6361, -0.1804],
          [-0.3859,  0.7098, -0.1761,  ..., -1.0280,  1.0833, -0.5277]]],
        grad_fn=<ViewBackward0>),
 torch.Size([2, 9, 4233]))

In [60]:
r_decoder_out

tensor(0.)

### 计算decode损失
[标签平滑Label Smoothing](https://blog.csdn.net/AZ_CHEN/article/details/127658050)  
[机器学习：Kullback-Leibler Divergence （KL 散度）](https://blog.csdn.net/matrix_space/article/details/80550561)

In [61]:
ys_out_pad

tensor([[ 384,    8, 4083, 2055, 2156, 2156, 3696,  477, 4232],
        [1544, 3704, 4004, 2005,  338, 2087, 1618, 3738, 4232]])

In [62]:
decoder_out.shape

torch.Size([2, 9, 4233])

In [63]:
# label smoothing 
torch.randn([3,5]).scatter_(1, torch.tensor([[0]
                                             ,[3]
                                             , [4]])
                            , 0)

tensor([[ 0.0000,  1.6013,  0.0455, -0.5443,  1.8900],
        [-0.3198, -0.8287, -0.5038,  0.0000,  0.7186],
        [ 1.4001,  1.2908,  0.8427, -1.9264,  0.0000]])

## 推理

In [283]:
from typing import Dict, List, Optional, Tuple

### attention_beam_search

#### Encoder输出

In [284]:
# config
batch_size = 2
beam_size = 3

device  = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [285]:
# 基本配置
maxlen = encoder_out.size(1)
encoder_dim = encoder_out.size(2)
running_size = batch_size * beam_size
running_size

6

In [286]:
encoder_out, encoder_mask = asr_encoder(speech, speech_length)
encoder_out.shape, encoder_mask.shape

(torch.Size([2, 59, 256]), torch.Size([2, 1, 59]))

In [287]:
# B*N, T, H
encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view(running_size, maxlen, encoder_dim)  # (B*N, maxlen, encoder_dim)
encoder_out.shape

torch.Size([6, 59, 256])

In [288]:
# eg
a = torch.rand(2, 2, 3)
b = a.unsqueeze(1)
c = b.repeat(1, 3, 1, 1)
print(a, "\n", b,"\n", c)

tensor([[[0.7055, 0.2328, 0.0598],
         [0.0949, 0.3555, 0.8259]],

        [[0.8999, 0.8942, 0.7527],
         [0.4604, 0.4847, 0.0536]]]) 
 tensor([[[[0.7055, 0.2328, 0.0598],
          [0.0949, 0.3555, 0.8259]]],


        [[[0.8999, 0.8942, 0.7527],
          [0.4604, 0.4847, 0.0536]]]]) 
 tensor([[[[0.7055, 0.2328, 0.0598],
          [0.0949, 0.3555, 0.8259]],

         [[0.7055, 0.2328, 0.0598],
          [0.0949, 0.3555, 0.8259]],

         [[0.7055, 0.2328, 0.0598],
          [0.0949, 0.3555, 0.8259]]],


        [[[0.8999, 0.8942, 0.7527],
          [0.4604, 0.4847, 0.0536]],

         [[0.8999, 0.8942, 0.7527],
          [0.4604, 0.4847, 0.0536]],

         [[0.8999, 0.8942, 0.7527],
          [0.4604, 0.4847, 0.0536]]]])


In [289]:
encoder_mask = encoder_mask.unsqueeze(1).repeat(
    1, beam_size, 1, 1).view(running_size, 1,
                             maxlen)  # (B*N, 1, max_len)

In [290]:
encoder_mask.shape

torch.Size([6, 1, 59])

In [291]:
hyps = torch.ones([running_size, 1], dtype=torch.long,
                  device=device).fill_(-1)  # (B*N, 1)
hyps

tensor([[-1],
        [-1],
        [-1],
        [-1],
        [-1],
        [-1]], device='cuda:0')

In [292]:
scores = torch.tensor([0.0] + [-float('inf')] * (beam_size - 1),
                      dtype=torch.float)
scores

tensor([0., -inf, -inf])

In [293]:
scores = scores.to(device).repeat([batch_size]).unsqueeze(1).to(
    device)  # (B*N, 1)
scores

tensor([[0.],
        [-inf],
        [-inf],
        [0.],
        [-inf],
        [-inf]], device='cuda:0')

In [294]:
# end_flag: [B*N, 1 ]
end_flag = torch.zeros_like(scores, dtype=torch.bool, device=device)
cache: Optional[List[torch.Tensor]] = None

In [295]:
end_flag

tensor([[False],
        [False],
        [False],
        [False],
        [False],
        [False]], device='cuda:0')

In [296]:
hyps

tensor([[-1],
        [-1],
        [-1],
        [-1],
        [-1],
        [-1]], device='cuda:0')

In [297]:
scores

tensor([[0.],
        [-inf],
        [-inf],
        [0.],
        [-inf],
        [-inf]], device='cuda:0')

#### decoder解码

In [298]:
def mask_finished_scores(score: torch.Tensor,
                         flag: torch.Tensor) -> torch.Tensor:
    """
    If a sequence is finished, we only allow one alive branch. This function
    aims to give one branch a zero score and the rest -inf score.

    Args:
        score (torch.Tensor): A real value array with shape
            (batch_size * beam_size, beam_size).
        flag (torch.Tensor): A bool array with shape
            (batch_size * beam_size, 1).

    Returns:
        torch.Tensor: (batch_size * beam_size, beam_size).
    """
    beam_size = score.size(-1)
    zero_mask = torch.zeros_like(flag, dtype=torch.bool)
    if beam_size > 1:
        unfinished = torch.cat((zero_mask, flag.repeat([1, beam_size - 1])),
                               dim=1)
        finished = torch.cat((flag, zero_mask.repeat([1, beam_size - 1])),
                             dim=1)
    else:
        unfinished = zero_mask
        finished = flag
    print(unfinished, "\n", finished)
    score.masked_fill_(unfinished, -float('inf'))
    score.masked_fill_(finished, 0)
    return score

In [299]:
def mask_finished_preds(pred: torch.Tensor, flag: torch.Tensor,
                        eos: int) -> torch.Tensor:
    """
    If a sequence is finished, all of its branch should be <eos>

    Args:
        pred (torch.Tensor): A int array with shape
            (batch_size * beam_size, beam_size).
        flag (torch.Tensor): A bool array with shape
            (batch_size * beam_size, 1).

    Returns:
        torch.Tensor: (batch_size * beam_size).
    """
    beam_size = pred.size(-1)
    finished = flag.repeat([1, beam_size])
    return pred.masked_fill_(finished, eos)

In [348]:
x = torch.rand(batch_size*beam_size, 59 ,10).to(device)
y = x[:, -1]
log_b = torch.log_softmax(y, dim=-1)

In [349]:
top_k_logp, top_k_index = log_b.topk(beam_size)
top_k_index, top_k_logp

(tensor([[5, 7, 8],
         [8, 7, 0],
         [3, 0, 5],
         [3, 4, 0],
         [1, 3, 4],
         [7, 8, 2]], device='cuda:0'),
 tensor([[-1.8687, -1.9017, -1.9936],
         [-2.0534, -2.0807, -2.0844],
         [-1.9574, -1.9693, -2.0190],
         [-1.8002, -1.9011, -1.9621],
         [-2.0147, -2.0761, -2.1046],
         [-2.0436, -2.0444, -2.0675]], device='cuda:0'))

In [350]:
top_k_logp = mask_finished_scores(top_k_logp, end_flag)

tensor([[False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False]], device='cuda:0') 
 tensor([[False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False]], device='cuda:0')


In [351]:
top_k_index = mask_finished_preds(top_k_index, end_flag, 1000)

In [352]:
# 选择top-k的数据进行存储
scores = scores + top_k_logp
scores

tensor([[-5.7171, -5.7501, -5.8420],
        [-5.9236, -5.9509, -5.9546],
        [-5.9298, -5.9418, -5.9915],
        [-5.6771, -5.7780, -5.8390],
        [-5.9491, -6.0104, -6.0389],
        [-5.9801, -5.9809, -6.0040]], device='cuda:0')

In [353]:
scores = scores.view(batch_size, beam_size*beam_size)
scores

tensor([[-5.7171, -5.7501, -5.8420, -5.9236, -5.9509, -5.9546, -5.9298, -5.9418,
         -5.9915],
        [-5.6771, -5.7780, -5.8390, -5.9491, -6.0104, -6.0389, -5.9801, -5.9809,
         -6.0040]], device='cuda:0')

In [354]:
scores, offset_k_index = scores.topk(k=beam_size)  

In [355]:
scores, offset_k_index

(tensor([[-5.7171, -5.7501, -5.8420],
         [-5.6771, -5.7780, -5.8390]], device='cuda:0'),
 tensor([[0, 1, 2],
         [0, 1, 2]], device='cuda:0'))

In [356]:
cache_index = (offset_k_index // beam_size).view(-1)
cache_index

tensor([0, 0, 0, 0, 0, 0], device='cuda:0')

In [357]:
base_cache_index = (torch.arange(batch_size, device=device).view(
                -1, 1).repeat([1, beam_size]) * beam_size).view(-1) 

In [358]:
base_cache_index

tensor([0, 0, 0, 3, 3, 3], device='cuda:0')

In [359]:
cache_index = base_cache_index + cache_index
cache_index

tensor([0, 0, 0, 3, 3, 3], device='cuda:0')

In [360]:
scores = scores.view(-1, 1)
scores

tensor([[-5.7171],
        [-5.7501],
        [-5.8420],
        [-5.6771],
        [-5.7780],
        [-5.8390]], device='cuda:0')

In [361]:
base_k_index = torch.arange(batch_size, device=device).view(
    -1, 1).repeat([1, beam_size])  # (B, N)
base_k_index = base_k_index * beam_size * beam_size
base_k_index

tensor([[0, 0, 0],
        [9, 9, 9]], device='cuda:0')

In [362]:
best_k_index = base_k_index.view(-1) + offset_k_index.view(
                -1)

In [363]:
scores

tensor([[-5.7171],
        [-5.7501],
        [-5.8420],
        [-5.6771],
        [-5.7780],
        [-5.8390]], device='cuda:0')

In [364]:
best_k_index

tensor([ 0,  1,  2,  9, 10, 11], device='cuda:0')

In [365]:
best_k_index, top_k_index

(tensor([ 0,  1,  2,  9, 10, 11], device='cuda:0'),
 tensor([[5, 7, 8],
         [8, 7, 0],
         [3, 0, 5],
         [3, 4, 0],
         [1, 3, 4],
         [7, 8, 2]], device='cuda:0'))

In [366]:
best_k_pred = torch.index_select(top_k_index.view(-1), dim=-1,index=best_k_index)
best_k_pred

tensor([5, 7, 8, 3, 4, 0], device='cuda:0')

In [367]:
best_hyps_index = best_k_index // beam_size
best_hyps_index

tensor([0, 0, 0, 3, 3, 3], device='cuda:0')

In [368]:
last_best_k_hyps = torch.index_select(
                hyps, dim=0, index=best_hyps_index)  
last_best_k_hyps

tensor([[-1,  7,  5],
        [-1,  7,  5],
        [-1,  7,  5],
        [-1,  0,  7],
        [-1,  0,  7],
        [-1,  0,  7]], device='cuda:0')

In [369]:
hyps = torch.cat((last_best_k_hyps, best_k_pred.view(-1, 1)),
                             dim=1) 
hyps

tensor([[-1,  7,  5,  5],
        [-1,  7,  5,  7],
        [-1,  7,  5,  8],
        [-1,  0,  7,  3],
        [-1,  0,  7,  4],
        [-1,  0,  7,  0]], device='cuda:0')

In [370]:
end_flag = torch.eq(hyps[:, -1], 1000).view(-1, 1)
end_flag

tensor([[False],
        [False],
        [False],
        [False],
        [False],
        [False]], device='cuda:0')

In [371]:
scores

tensor([[-5.7171],
        [-5.7501],
        [-5.8420],
        [-5.6771],
        [-5.7780],
        [-5.8390]], device='cuda:0')

In [372]:
## get_所有可能的序列

In [373]:
scores = scores.view(batch_size, beam_size)

In [374]:
best_score, best_index = scores.max(dim=-1)
best_score, best_index

(tensor([-5.7171, -5.6771], device='cuda:0'), tensor([0, 0], device='cuda:0'))

In [277]:
best_hyps_index = best_index + torch.arange(
            batch_size, dtype=torch.long, device=device) * beam_size
best_hyps_index

tensor([0, 3], device='cuda:0')

In [279]:
best_hyps = torch.index_select(hyps, dim=0, index=best_hyps_index)
best_hyps

tensor([[-1,  1,  9,  5],
        [-1,  9,  3,  7]], device='cuda:0')