In [18]:
import os
import sys
# 计算
import torch 
import numpy as np
# 画图
import matplotlib.pyplot as plt
# 音频处理
import soundfile as sf
# 字符串处理
import codecs
import re
import string
import zhconv
import sentencepiece as spm
from zhon import hanzi


class audioReader(object):
    """Audio model reader
    """
    def __init__(self, dict_path, spm_model_path):
        """
        """
        self._dict_path = dict_path
        self._sp = spm.SentencePieceProcessor(spm_model_path)
        self._sp.Load(spm_model_path)

        self._dict_word2id = {}
        self._dict_id2word = {}
        
    def _buid_dict(self):
        """ build dict btw word and id
        """
        with codecs.open(self._dict_path, "r", "utf-8") as dict_handle:
            for tmp_line in dict_handle:
                tmp_word, tmp_id = tmp_line.strip().split()
                self._dict_id2word[tmp_id] = tmp_word
                self._dict_word2id[tmp_word] = tmp_id

    @staticmethod
    def read_pcm(file_path, sample_rate = 16000):
        """read audio
        """
        data, sample_rate = sf.read(file_path, samplerate = sample_rate, channels = 1, format="RAW", subtype="PCM_16")

        return data, sample_rate

    def _del_cn_spaces(self, text):
        """del chinese spaces
        """
        pattern =re.compile(r'(?<=[\u4e00-\u9fa5])\s+(?=[\u4e00-\u9fa5])')
        out_text = pattern.sub(r'', text)
        return out_text
    
    def _del_punc(self, text):
        """del punc
        """
        # punctuation = r"!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､　、〃〈〉《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏﹑﹔·！？｡。""
        punctuation = string.punctuation + hanzi.punctuation
        dicts = {i:'' for i in punctuation}
        punc_table = str.maketrans(dicts)
        out_text = text.translate(punc_table)
        return out_text
    
    def _tra2simple(self, text):
        """trans traditional to simplified Chinese
        """
        out_text = zhconv.convert(text, 'zh-cn')
        return out_text      
    
    def _cn_encode(self, text, n_char = 1):
        """split chinese text by n_char
        """
        n = n_char 
        text_split = [text[j : j + n] for j in range(0, len(text), n)]
        text_flat = []
        for tmp_chars in text_split:
            text_flat.append("".join(tmp_chars))
        
        cn_token_str = " ".join(text_flat)
        return cn_token_str
    
    def _en_snp(self, text):
        """split english text by spm model
        """
        token = self._sp.EncodeAsPieces(text)
        en_token_str = "".join(str(i) for i in token)
        en_token_str = " " + en_token_str.strip() + " "

        return en_token_str
   
    def _text2token(self, text):
        """
        """
        result_en = re.finditer(r'[a-z_A-Z-\.!@#\$%\\\^&\*\)\(\+=\{\}\[\]\/",\'<>~\·`\?:;][a-z_A-Z-\.!@#\$%\\\^&\*\)\(\+=\{\}\[\]\/",\'<>~\·`\?:;|\s]*',text)
        add_pos_en = 0
        for i in result_en:
            en_text = i.group().strip()
            en_token = self._en_snp(en_text.upper())
            start_pos = i.start() + add_pos_en
            text = text[:start_pos] + text[start_pos:].replace(en_text, en_token, 1)
            add_pos_en += (len(en_token)-len(en_text))

        result_cn = re.finditer(r'([\u4e00-\u9fa5][\u4e00-\u9fa5\s]*)',text)
        add_pos_cn = 0
        for j in result_cn:
            cn_text = j.group()
            cn_token = self._cn_encode(cn_text)
            print(cn_token + "\n")
            strat_pos = j.start() + add_pos_cn
            text = text[:start_pos] + text[start_pos:].replace(cn_text, cn_token, 1)
            add_pos_cn += (len(cn_token)-len(cn_text))
        return text

    def _sym2id(self, token):
        """trans token 2 id, base on self dict
        """
        token_list = token.strip().split()
        for i in range(len(token_list)):
            char_token = token_list[i]
            try:
                token_list[i] = self._word_map[char_token]
            except Exception as e:
                logging.error(token +  "\t" + token_list[i]+" is replace 1 ")
                token_list[i] = "1"

        tokenid = " ".join(token_list)
        return tokenid

    def trans_char2id(self, text):
        """trans character 2 index
        """
        text = self._del_cn_spaces(text)
        text = self._del_punc(text)
        text = self._tra2simple(text)
        seq_ids = self._text2token(text)
        
        return seq_ids

    def read_pcm_text(input_line):
        """read pcm and text, split by \t or space
        """
        line = input_line.strip()
        pcm_path, text = line.split()

        data, sample_rate = self.read(pcm_path)
        

dict_file = "../Data/cnen_dict_14323_units.txt"
smp_model = "../Data/cnen_spm_unigram5000.model"
test_text = "hello, world. 说什么随叫随到啊就这样的话就不要说了, by the way"

test_reader = audioReader(dict_file, smp_model)

test_text_ids = test_reader.trans_char2id(test_text)

        

hello, world. 说什么随叫随到啊就这样的话就不要说了, by the way
说 什 么 随 叫 随 到 啊 就 这 样 的 话 就 不 要 说 了    

 ▁HELLO▁WORLD  说什么随叫随到啊就这样的话就不要说了  ▁BY▁THE▁WAY 
