## symbols

In [None]:
# Grapheme inventory: three special tokens followed by the 26 lowercase
# letters and the apostrophe, hyphen and period.
graphemes = ['<pad>', '<s>', '</s>', *"abcdefghijklmnopqrstuvwxyz'-."]

# Phoneme inventory: the same three special tokens followed by the phoneme
# symbols, one base phoneme (with its stress variants) per row.
# The list mirrors data/cmudict.symbols; it can be regenerated by appending
# each newline-stripped line of that file to the three special tokens.
phonemes = [
    '<pad>', '<s>', '</s>',
    'AA', 'AA0', 'AA1', 'AA2',
    'AE', 'AE0', 'AE1', 'AE2',
    'AH', 'AH0', 'AH1', 'AH2',
    'AO', 'AO0', 'AO1', 'AO2',
    'AW', 'AW0', 'AW1', 'AW2',
    'AY', 'AY0', 'AY1', 'AY2',
    'B', 'CH', 'D', 'DH',
    'EH', 'EH0', 'EH1', 'EH2',
    'ER', 'ER0', 'ER1', 'ER2',
    'EY', 'EY0', 'EY1', 'EY2',
    'F', 'G', 'HH',
    'IH', 'IH0', 'IH1', 'IH2',
    'IY', 'IY0', 'IY1', 'IY2',
    'JH', 'K', 'L', 'M', 'N', 'NG',
    'OW', 'OW0', 'OW1', 'OW2',
    'OY', 'OY0', 'OY1', 'OY2',
    'P', 'R', 'S', 'SH', 'T', 'TH',
    'UH', 'UH0', 'UH1', 'UH2',
    'UW', 'UW0', 'UW1', 'UW2',
    'V', 'W', 'Y', 'Z', 'ZH',
]

# Lookup tables in both directions (index -> symbol and symbol -> index).
graphemes_id2char = {index: symbol for index, symbol in enumerate(graphemes)}
phonemes_id2char = {index: symbol for index, symbol in enumerate(phonemes)}
graphemes_char2id = {symbol: index for index, symbol in enumerate(graphemes)}
phonemes_char2id = {symbol: index for index, symbol in enumerate(phonemes)}


def word2id(word):
    """Return the list of grapheme ids for the characters of ``word``.

    Raises KeyError if a character is missing from ``graphemes_char2id``.
    """
    return [graphemes_char2id[char] for char in word]


def id2word(idx_list):
    """Return the string spelled by the grapheme ids in ``idx_list``.

    Raises KeyError if an id is missing from ``graphemes_id2char``.
    """
    chars = (graphemes_id2char[grapheme_id] for grapheme_id in idx_list)
    return ''.join(chars)


def phoneme2id(phoneme_seq):
    """Return the list of phoneme ids for a whitespace-separated phoneme string.

    Splitting on any run of whitespace (rather than on single spaces) means
    doubled spaces, leading/trailing blanks or a stray carriage return do not
    produce bogus tokens such as '' that would fail the lookup.

    Raises KeyError if a symbol is missing from ``phonemes_char2id``.
    """
    return [phonemes_char2id[symbol] for symbol in phoneme_seq.split()]


def id2phoneme(idx_list):
    """Return the space-separated phoneme string for the ids in ``idx_list``.

    Raises KeyError if an id is missing from ``phonemes_id2char``.
    """
    symbols = (phonemes_id2char[phoneme_id] for phoneme_id in idx_list)
    return ' '.join(symbols)

## config

In [None]:
# Hyperparameter configuration
# yaml
class Hyperparameter:
    """Static configuration: data paths, model structure and training settings.

    All values are class attributes; a single shared instance is created
    right after the class as ``HP``.
    """
    # ################################################################
    #                             Data
    # ################################################################
    # Device string that tensors and the model are moved to.
    device = 'cuda'
    data_root = './data/'
    # Source pronunciation dictionary read by the preprocessing step.
    origin_dict_path = '../input/cmu-pronouncing-dictionary/cmudict.dict'
    # JSON files written by the preprocessing step.
    trainset_path = './data/data_train.json'
    testset_path = './data/data_test.json'
    devset_path = './data/data_val.json'

    seed = 1234  # random seed

    # ################################################################
    #                             Model Structure
    # ################################################################

    encoder_layer_num = 6
    encoder_dim = 128
    encoder_drop_prob = 0.1
    # Vocabulary size of the grapheme embedding.
    graphemes_size = len(graphemes_char2id)
    # Length of the encoder's positional-encoding table.
    encoder_max_input = 30

    # Number of attention heads, used by encoder and decoder layers alike.
    nhead = 4

    encoder_feed_forward_dim = 1024
    decoder_feed_forward_dim = 1024
    # Dropout inside the feed-forward blocks of both encoder and decoder.
    feed_forward_drop_prob = 0.3

    decoder_layer_num = 6
    decoder_dim = 128
    decoder_drop_prob = 0.1
    # Vocabulary size of the phoneme embedding and of the output projection.
    phoneme_size = len(phonemes_char2id)
    # Length of the decoder's positional-encoding table.
    MAX_DECODE_STEP = 50

    # Ids of the special tokens in the grapheme and phoneme vocabularies.
    ENCODER_SOS_IDX = graphemes_char2id['<s>']
    ENCODER_EOS_IDX = graphemes_char2id['</s>']
    ENCODER_PAD_IDX = graphemes_char2id['<pad>']
    DECODER_SOS_IDX = phonemes_char2id['<s>']
    DECODER_EOS_IDX = phonemes_char2id['</s>']
    DECODER_PAD_IDX = phonemes_char2id['<pad>']

    # ################################################################
    #                             Experiment
    # ################################################################
    batch_size = 128
    init_lr = 1e-4
    epochs = 100
    # Evaluate on the dev loader every `verbose_step` training steps.
    verbose_step = 100
    # Write a checkpoint every `save_step` training steps.
    save_step = 500
    # Maximum gradient norm passed to clip_grad_norm_.
    grad_clip_thresh = 1.


# Shared configuration instance used throughout the rest of the file.
HP = Hyperparameter()


## preprocess

In [None]:
import json
import os
import random

# Seed the shuffle so the train/dev/test split is reproducible.
random.seed(HP.seed)

# Create the output folders; exist_ok avoids the check-then-create race.
for foldername in ['data', 'log', 'model_save']:
    os.makedirs(foldername, exist_ok=True)

train_ratio, eval_ratio, test_ratio = 0.8, 0.1, 0.1

# Read every dictionary line (newline stripped) inside a context manager so
# the file handle is closed as soon as reading finishes.
with open(HP.origin_dict_path) as dict_file:
    lines = [line.strip('\n') for line in dict_file]
random.shuffle(lines)

length = len(lines)

def lines2dict(lines):
    """Build a ``{word: phoneme string}`` dict from raw dictionary lines.

    Each line is split at its first space into a word and the remainder.
    A line is skipped when:
      * it contains no space at all (blank or malformed line); this used to
        raise IndexError,
      * the word ends with ')',
      * the word contains the character '1',
      * the remainder contains '#'.

    If a word occurs more than once, the last occurrence wins.

    NOTE(review): the ')' and '1' filters presumably drop alternate
    pronunciations and entries with digits; confirm against the dictionary.
    """
    the_dict = {}
    for line in lines:
        word, separator, pronunciation = line.partition(' ')
        if not separator:
            # No space: there is no pronunciation part to store.
            continue
        if word.endswith(')') or '1' in word or '#' in pronunciation:
            continue
        the_dict[word] = pronunciation
    return the_dict

# Slice boundaries of the shuffled lines: [0, train_end) is train,
# [train_end, eval_end) is dev and the remainder is test.
train_end = int(length * train_ratio)
eval_end = int(length * (train_ratio + eval_ratio))

trainset_dict = lines2dict(lines[:train_end])
evalset_dict = lines2dict(lines[train_end:eval_end])
testset_dict = lines2dict(lines[eval_end:])

# Write each split inside a context manager so the JSON is flushed and the
# handle closed deterministically.
for split_dict, split_path in (
    (trainset_dict, HP.trainset_path),
    (evalset_dict, HP.devset_path),
    (testset_dict, HP.testset_path),
):
    with open(split_path, 'w') as split_file:
        json.dump(split_dict, split_file)


## dataset_g2p

In [None]:
from torch.utils.data import Dataset
import json
import torch


class G2PDataset(Dataset):
    """Dataset of (word, phoneme string) pairs loaded from a JSON dict file."""

    def __init__(self, dataset_path):
        """Load the ``{word: phoneme string}`` mapping stored at ``dataset_path``."""
        # Context manager so the file handle is closed right after parsing.
        with open(dataset_path, 'r') as dataset_file:
            data_dict = json.load(dataset_file)
        self.data_pairs = list(data_dict.items())

    def __getitem__(self, index):
        """Return ``(grapheme ids, phoneme ids)`` for the pair at ``index``."""
        word, phone_seq = self.data_pairs[index]
        return word2id(word), phoneme2id(phone_seq)

    def __len__(self):
        """Return the number of pairs."""
        return len(self.data_pairs)


def collate_fn(iter_batch):
    """Collate ``(grapheme ids, phoneme ids)`` pairs into padded batch tensors.

    Every sequence is wrapped in its vocabulary's '<s>' / '</s>' ids, rows are
    ordered by descending wrapped word length, and shorter rows are padded
    with 0.

    Args:
        iter_batch: sequence of ``(grapheme ids, phoneme ids)`` pairs.

    Returns:
        word_padded: long tensor of shape (N, max word length).
        word_lengths: long tensor of shape (N,), sorted descending.
        phoneme_padded: long tensor of shape (N, max phoneme length).
        phoneme_lengths: long tensor of shape (N,), in the same row order.
    """
    N = len(iter_batch)
    grapheme_sos, grapheme_eos = graphemes_char2id['<s>'], graphemes_char2id['</s>']
    phoneme_sos, phoneme_eos = phonemes_char2id['<s>'], phonemes_char2id['</s>']

    # Build new lists instead of inserting into the caller's lists in place.
    word_indexes = [[grapheme_sos, *word_ids, grapheme_eos] for word_ids, _ in iter_batch]
    phoneme_indexs = [[phoneme_sos, *phoneme_ids, phoneme_eos] for _, phoneme_ids in iter_batch]

    word_lengths, sort_index = torch.sort(
        torch.tensor([len(it) for it in word_indexes]).long(), descending=True)
    max_word_len = int(word_lengths[0])
    max_phoneme_len = max(len(it) for it in phoneme_indexs)

    # Zero-initialised, so untouched positions hold the padding value 0.
    word_padded = torch.zeros(N, max_word_len, dtype=torch.long)
    phoneme_padded = torch.zeros(N, max_phoneme_len, dtype=torch.long)
    phoneme_lengths = torch.zeros(N, dtype=torch.long)

    for row, source_row in enumerate(sort_index.tolist()):
        word_ids = word_indexes[source_row]
        phoneme_ids = phoneme_indexs[source_row]
        word_padded[row, :len(word_ids)] = torch.tensor(word_ids, dtype=torch.long)
        phoneme_padded[row, :len(phoneme_ids)] = torch.tensor(phoneme_ids, dtype=torch.long)
        phoneme_lengths[row] = len(phoneme_ids)

    return word_padded, word_lengths, phoneme_padded, phoneme_lengths


## model

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import math


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a (batch, seq, d_model) input.

    Even feature indices hold sin terms and odd indices hold cos terms. The
    table is precomputed for ``max_len`` positions and stored in the buffer
    ``pe`` of shape (1, max_len, d_model).
    """

    def __init__(self, d_model, max_len=10000):
        """Precompute the encoding table.

        Args:
            d_model: feature dimension; may be even or odd.
            max_len: number of positions in the table.
        """
        super(PositionalEncoding, self).__init__()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.) / d_model))
        position = torch.arange(max_len).unsqueeze(1)
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(angles)
        # For odd d_model there is one cos column fewer than sin columns, so
        # trim the angles; for even d_model the slice is a no-op.
        pe[0, :, 1::2] = torch.cos(angles[:, :d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        ``x.size(1)`` must not exceed the ``max_len`` given at construction.
        """
        x = x + self.pe[:, :x.size(1), :]
        return x


class Encoder(nn.Module):
    """Grapheme encoder: scaled token embedding, positional encoding, dropout
    and a stack of ``HP.encoder_layer_num`` encoder layers."""

    def __init__(self):
        super().__init__()
        dim = HP.encoder_dim

        self.token_embedding = nn.Embedding(HP.graphemes_size, dim)
        self.pe = PositionalEncoding(d_model=dim, max_len=HP.encoder_max_input)
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(HP.encoder_layer_num)])
        self.drop = nn.Dropout(HP.encoder_drop_prob)
        # sqrt(encoder_dim), multiplied into the embeddings in forward().
        self.register_buffer('scale', torch.sqrt(torch.tensor(dim).float()))

    def forward(self, inputs, inputs_mask):
        """Encode grapheme ids ``inputs``; ``inputs_mask`` is handed to every layer."""
        hidden = self.token_embedding(inputs) * self.scale
        hidden = self.drop(self.pe(hidden))

        for layer in self.layers:
            hidden = layer(hidden, inputs_mask)

        return hidden


class EncoderLayer(nn.Module):
    """One encoder layer: self-attention then feed-forward, each followed by
    dropout, a residual connection and layer normalisation."""

    def __init__(self):
        super().__init__()
        dim = HP.encoder_dim

        self.self_att_layer_norm = nn.LayerNorm(dim)
        self.pff_layer_norm = nn.LayerNorm(dim)

        self.self_att = MultiHeadAttentionLayer(dim, HP.nhead)
        self.pff = PointWiseFeedForwardLayer(dim, HP.encoder_feed_forward_dim, HP.feed_forward_drop_prob)

        self.dropout = nn.Dropout(HP.encoder_drop_prob)

    def forward(self, inputs, inputs_mask):
        """Return the layer output for ``inputs`` under ``inputs_mask``."""
        # Self-attention sub-layer; the attention weights are not needed here.
        attended, _ = self.self_att(inputs, inputs, inputs, inputs_mask)
        hidden = self.self_att_layer_norm(inputs + self.dropout(attended))

        # Feed-forward sub-layer.
        transformed = self.pff(hidden)
        return self.pff_layer_norm(hidden + self.dropout(transformed))


class MultiHeadAttentionLayer(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are linearly projected, split into ``nhead``
    heads of size ``hidden_dim // nhead``, attended per head, merged again
    and passed through an output projection.
    """

    def __init__(self, hidden_dim, nhead):
        """
        Args:
            hidden_dim: model dimension; must be divisible by ``nhead``.
            nhead: number of attention heads.
        """
        super(MultiHeadAttentionLayer, self).__init__()
        self.hidden_dim = hidden_dim
        self.nhead = nhead
        assert self.hidden_dim % self.nhead == 0, 'hidden_dim must be divisible by nhead'
        self.head_dim = self.hidden_dim // self.nhead

        self.fc_q = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.fc_k = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.fc_v = nn.Linear(self.hidden_dim, self.hidden_dim)

        self.fc_o = nn.Linear(self.hidden_dim, self.hidden_dim)
        # Scaled dot-product attention divides the logits by sqrt(d_k), where
        # d_k is the per-head key size, not the full model dimension.
        # NOTE: `scale` is a persistent buffer, so a state_dict saved by the
        # earlier code restores the earlier value when it is loaded.
        self.register_buffer('scale', torch.sqrt(torch.tensor(self.head_dim).float()))

    def forward(self, query, key, value, input_mask=None):
        """Attend from ``query`` over ``key`` / ``value``.

        Args:
            query, key, value: tensors of shape (batch, length, hidden_dim).
            input_mask: optional tensor broadcastable to
                (batch, nhead, query length, key length); positions where it
                equals 0 are excluded from attention.

        Returns:
            out: (batch, query length, hidden_dim) attended values.
            attention: (batch, nhead, query length, key length) weights.
        """
        bn = query.size(0)
        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)

        # Split into heads: (batch, nhead, length, head_dim).
        Q = Q.view(bn, -1, self.nhead, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(bn, -1, self.nhead, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(bn, -1, self.nhead, self.head_dim).permute(0, 2, 1, 3)

        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        if input_mask is not None:
            # A large negative logit gives masked positions ~0 weight.
            energy = energy.masked_fill(input_mask == 0, -1.e10)
        attention = F.softmax(energy, dim=-1)

        # Merge the heads back into (batch, length, hidden_dim).
        out = torch.matmul(attention, V)
        out = out.permute(0, 2, 1, 3).contiguous()
        out = out.view(bn, -1, self.hidden_dim)
        out = self.fc_o(out)
        return out, attention


class PointWiseFeedForwardLayer(nn.Module):
    """Feed-forward block: Linear -> ReLU -> Dropout -> Linear.

    Maps ``hidden_dim`` features to ``pff_dim`` and back to ``hidden_dim``.
    """

    def __init__(self, hidden_dim, pff_dim, pff_drop_prob):
        super().__init__()
        self.hidden_dim, self.pff_dim, self.pff_drop_prob = hidden_dim, pff_dim, pff_drop_prob

        self.fc1 = nn.Linear(hidden_dim, pff_dim)
        self.fc2 = nn.Linear(pff_dim, hidden_dim)
        self.dropout = nn.Dropout(pff_drop_prob)

    def forward(self, x):
        """Apply the block to the last dimension of ``x``."""
        expanded = F.relu(self.fc1(x))
        return self.fc2(self.dropout(expanded))


class Decoder(nn.Module):
    """Phoneme decoder: token embedding, positional encoding, dropout, a stack
    of ``HP.decoder_layer_num`` decoder layers and an output projection to
    the phoneme vocabulary.

    NOTE(review): the ``scale`` buffer is registered but never applied to the
    embeddings here, whereas Encoder multiplies its embeddings by its
    ``scale``. Confirm whether the decoder was meant to do the same; the
    behaviour is left unchanged.
    """

    def __init__(self):
        super(Decoder, self).__init__()
        self.token_embedding = nn.Embedding(HP.phoneme_size, HP.decoder_dim)
        self.pe = PositionalEncoding(d_model=HP.decoder_dim, max_len=HP.MAX_DECODE_STEP)
        self.layers = nn.ModuleList([DecoderLayer() for _ in range(HP.decoder_layer_num)])
        self.fc_out = nn.Linear(HP.decoder_dim, HP.phoneme_size)
        self.drop = nn.Dropout(HP.decoder_drop_prob)
        self.register_buffer('scale', torch.sqrt(torch.tensor(HP.decoder_dim).float()))

    def forward(self, target, enc_src, target_mask, src_mask):
        """Decode ``target`` ids against the encoder output ``enc_src``.

        Returns:
            out: logits over the phoneme vocabulary for each target position.
            attention: the encoder-decoder attention weights returned by the
                last decoder layer (None if there are no layers).
        """
        token_embbed = self.token_embedding(target)
        pos_embbed = self.pe(token_embbed)
        target = self.drop(pos_embbed)

        # Plain local variable: the attention weights must not leak into (or
        # be read from) the module's global namespace.
        attention = None
        for layer in self.layers:
            target, attention = layer(target, enc_src, target_mask, src_mask)
        out = self.fc_out(target)
        return out, attention


class DecoderLayer(nn.Module):
    """One decoder layer: masked self-attention, encoder-decoder attention
    and feed-forward, each followed by dropout, a residual connection and
    layer normalisation."""

    def __init__(self):
        super().__init__()
        dim = HP.decoder_dim

        self.mask_self_att = MultiHeadAttentionLayer(dim, HP.nhead)
        self.mask_self_norm = nn.LayerNorm(dim)

        self.mha = MultiHeadAttentionLayer(dim, HP.nhead)
        self.mha_norm = nn.LayerNorm(dim)

        self.pff = PointWiseFeedForwardLayer(dim, HP.decoder_feed_forward_dim, HP.feed_forward_drop_prob)
        self.pff_norm = nn.LayerNorm(dim)

        self.dropout = nn.Dropout(HP.decoder_drop_prob)

    def forward(self, target, enc_src, target_mask, src_mask):
        """Return ``(layer output, encoder-decoder attention weights)``."""
        # Masked self-attention over the target sequence.
        self_attended, _ = self.mask_self_att(target, target, target, target_mask)
        hidden = self.mask_self_norm(target + self.dropout(self_attended))

        # Attention from the target positions over the encoder output.
        cross_attended, attention = self.mha(hidden, enc_src, enc_src, src_mask)
        hidden = self.mha_norm(hidden + self.dropout(cross_attended))

        # Feed-forward sub-layer.
        transformed = self.pff(hidden)
        hidden = self.pff_norm(hidden + self.dropout(transformed))

        return hidden, attention


class Transformer(nn.Module):
    """Encoder-decoder model mapping grapheme ids to phoneme logits."""

    def __init__(self):
        super(Transformer, self).__init__()

        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, src, target):
        """Run encoder and decoder.

        Args:
            src: (batch, source length) grapheme ids.
            target: (batch, target length) phoneme ids fed to the decoder.

        Returns:
            output: decoder logits.
            attention: attention weights returned by the decoder.
        """
        src_mask = self.create_src_mask(src)
        target_mask = self.create_target_mask(target)

        enc_src = self.encoder(src, src_mask)
        output, attention = self.decoder(target, enc_src, target_mask, src_mask)
        return output, attention

    def infer(self, src):
        """Placeholder for inference; not implemented and returns None."""
        pass

    @staticmethod
    def create_src_mask(src):
        """Return a (batch, 1, 1, source length) bool mask, True at non-pad ids.

        The mask lives on the device of ``src`` rather than on the global
        ``HP.device``, so it always matches the tensors it is combined with.
        """
        mask = (src != HP.ENCODER_PAD_IDX).unsqueeze(1).unsqueeze(2)
        return mask

    @staticmethod
    def create_target_mask(target):
        """Return a (batch, 1, length, length) bool mask for the decoder.

        A position is True only when the attended-to id is not padding and it
        does not lie after the querying position (lower-triangular part).
        The mask is built on the device of ``target``.
        """
        target_length = target.size(1)
        pad_mask = (target != HP.DECODER_PAD_IDX).unsqueeze(1).unsqueeze(2)
        sub_mask = torch.tril(
            torch.ones(target_length, target_length, dtype=torch.uint8, device=target.device)
        ).bool()
        target_mask = pad_mask & sub_mask
        return target_mask


## trainer

In [None]:
import os.path
import random
import torch
import numpy as np
from tensorboardX import SummaryWriter
from torch import nn
from torch import optim
from torch.utils.data import DataLoader

logger = SummaryWriter('./log')

# Seed initialisation: seed every random number generator in use with the
# configured seed so that runs are reproducible.
for _seed_fn in (torch.manual_seed, random.seed, np.random.seed, torch.cuda.manual_seed):
    _seed_fn(HP.seed)


def evaluate(model, devloader, crit):
    """Return the mean per-batch loss of ``model`` over ``devloader``.

    The model is put in eval mode with gradients disabled for the pass, and
    its previous training/eval mode is restored afterwards, even if a batch
    raises.

    Args:
        model: called as ``model(word ids, phoneme ids without last step)``.
        devloader: yields ``(word ids, word lengths, phoneme ids, phoneme lengths)``.
        crit: loss taking ``(flattened logits, flattened target ids)``.

    Raises:
        ZeroDivisionError: if ``devloader`` yields no batches.
    """
    was_training = model.training
    model.eval()
    sum_loss = 0.
    try:
        with torch.no_grad():
            for batch in devloader:
                words_idxs, _, phoneme_seqs_idxs, _ = batch
                words_idxs = words_idxs.to(HP.device)
                phoneme_seqs_idxs = phoneme_seqs_idxs.to(HP.device)

                # Decoder input drops the last step; the target drops the first.
                output_post, _ = model(words_idxs, phoneme_seqs_idxs[:, :-1])
                out = output_post.view(-1, output_post.size(-1))
                target = phoneme_seqs_idxs[:, 1:].contiguous().view(-1)
                sum_loss += crit(out, target).item()
    finally:
        # Restore whichever mode the caller had instead of forcing train mode.
        model.train(was_training)

    return sum_loss / len(devloader)


def save_checkpoint(model, epoch, opt, save_path):
    """Save the epoch number plus model and optimizer state dicts to ``save_path``."""
    checkpoint = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=opt.state_dict(),
    )
    torch.save(checkpoint, save_path)


def train():
    """Train the Transformer on the training split.

    Every step the training loss is logged. Every ``HP.verbose_step`` steps
    the model is evaluated on the dev split, and every ``HP.save_step`` steps
    a checkpoint is written to the ``model_save`` folder.
    """
    model = Transformer().to(HP.device)

    # Padding positions do not contribute to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=HP.DECODER_PAD_IDX)

    opt = optim.Adam(model.parameters(), lr=HP.init_lr)

    trainset = G2PDataset(HP.trainset_path)
    train_loader = DataLoader(trainset, batch_size=HP.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn)

    # Validate on the dev split; the test split stays held out.
    # Evaluation order does not need shuffling.
    devset = G2PDataset(HP.devset_path)
    dev_loader = DataLoader(devset, batch_size=HP.batch_size, shuffle=False, drop_last=False, collate_fn=collate_fn)

    start_epoch, step = 0, 0
    # Defined up front so the progress print never depends on step 0 having
    # triggered an evaluation.
    eval_loss = float('nan')

    model.train()

    for epoch in range(start_epoch, HP.epochs):
        print('Start Epoch: %d, Steps: %d' % (epoch, len(train_loader)))
        for batch in train_loader:
            words_idxs, _, phoneme_seqs_idxs, _ = batch
            words_idxs = words_idxs.to(HP.device)
            phoneme_seqs_idxs = phoneme_seqs_idxs.to(HP.device)

            opt.zero_grad()
            # Decoder input drops the last step; the target drops the first.
            output_post, _ = model(words_idxs, phoneme_seqs_idxs[:, :-1])
            out = output_post.view(-1, output_post.size(-1))
            target = phoneme_seqs_idxs[:, 1:].contiguous().view(-1)
            loss = criterion(out, target)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), HP.grad_clip_thresh)
            opt.step()

            # Log a plain float rather than a tensor still attached to the graph.
            train_loss = loss.item()
            logger.add_scalar('Loss/Train', train_loss, step)

            if step % HP.verbose_step == 0:
                eval_loss = evaluate(model, dev_loader, criterion)
                logger.add_scalar('Loss/Dev', eval_loss, step)

            if step % HP.save_step == 0:
                model_path = 'model_%d_%d.model' % (epoch, step)
                save_checkpoint(model, epoch, opt, os.path.join('model_save', model_path))

            step += 1
            logger.flush()
            print('Epoch:[%d/%d], step:%d, Train Loss:%.5f, Dev Loss:%.5f' % (
                epoch, HP.epochs, step, train_loss, eval_loss))

    logger.close()


## training

In [None]:
# Run the training loop defined by train() above.
train()