In [65]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import numpy as np

from pathlib import Path
from kaldiio import ReadHelper
import json
import os 
import time
from tqdm import notebook
import copy
from datetime import datetime

In [66]:
espnet_path = Path('espnet')

In [67]:
egs_path = Path('egs/aishell/asr1')
exp_dir = espnet_path / egs_path / 'exp'

In [68]:
dump_path = 'dump'
train_path = 'train_sp'
dev_path = 'dev'
test_path = 'test'

In [69]:
vocab_file = espnet_path / egs_path / 'data/lang_1char/train_sp_units.txt'
vocab = {}
with open(vocab_file) as fp:
    for line in fp:
        word, idx = line.strip().split()
        vocab[word] = idx
        
print(len(vocab))

4231


In [70]:
new_id = 4
token2id = {'<blank>':0, '<pad>':1, '<sos>':2, '<eos>':3}
for word, _ in vocab.items():
    if word not in token2id:
        token2id[word] = new_id
        new_id += 1

id2token = {v:k for k, v in token2id.items()}

In [71]:
token2id

{'<blank>': 0,
 '<pad>': 1,
 '<sos>': 2,
 '<eos>': 3,
 '<unk>': 4,
 '一': 5,
 '丁': 6,
 '七': 7,
 '万': 8,
 '丈': 9,
 '三': 10,
 '上': 11,
 '下': 12,
 '不': 13,
 '与': 14,
 '丐': 15,
 '丑': 16,
 '专': 17,
 '且': 18,
 '世': 19,
 '丘': 20,
 '丙': 21,
 '业': 22,
 '丛': 23,
 '东': 24,
 '丝': 25,
 '丞': 26,
 '丢': 27,
 '两': 28,
 '严': 29,
 '丧': 30,
 '个': 31,
 '丫': 32,
 '中': 33,
 '丰': 34,
 '串': 35,
 '临': 36,
 '丸': 37,
 '丹': 38,
 '为': 39,
 '主': 40,
 '丽': 41,
 '举': 42,
 '乃': 43,
 '久': 44,
 '么': 45,
 '义': 46,
 '之': 47,
 '乌': 48,
 '乍': 49,
 '乎': 50,
 '乏': 51,
 '乐': 52,
 '乒': 53,
 '乓': 54,
 '乔': 55,
 '乖': 56,
 '乘': 57,
 '乙': 58,
 '九': 59,
 '乞': 60,
 '也': 61,
 '习': 62,
 '乡': 63,
 '书': 64,
 '买': 65,
 '乱': 66,
 '乳': 67,
 '乾': 68,
 '了': 69,
 '予': 70,
 '争': 71,
 '事': 72,
 '二': 73,
 '于': 74,
 '亏': 75,
 '云': 76,
 '互': 77,
 '五': 78,
 '井': 79,
 '亚': 80,
 '些': 81,
 '亟': 82,
 '亡': 83,
 '亢': 84,
 '交': 85,
 '亥': 86,
 '亦': 87,
 '产': 88,
 '亨': 89,
 '亩': 90,
 '享': 91,
 '京': 92,
 '亭': 93,
 '亮': 94,
 '亲': 95,
 '亳': 96,
 '亵': 97,
 '人': 98

In [72]:
train_data = {}
dev_data = {}
test_data = {}

for (set_path, dataset) in zip([train_path, dev_path, test_path], [train_data, dev_data, test_data]):
    set_dump_path = espnet_path / egs_path / dump_path / set_path / 'deltafalse' 
    with open(set_dump_path / 'data.json') as fp:
        json_data = json.load(fp)

    feats = {}
    pbar = notebook.tqdm(total=len(list(set_dump_path.glob('feats.*.ark'))))
    for feats_file in set_dump_path.glob('feats.*.ark'):
        with ReadHelper('ark:'+str(feats_file)) as reader:
            for key, numpy_array in reader:
                feats[key] = torch.from_numpy(numpy_array)
          
        pbar.update(1)
        
        break
        
        
        
    
    for key, value in json_data['utts'].items():
        if key in feats:
            feature = feats[key]
            text = json_data['utts'][key]['output'][0]['text']
            token = []
            token_id = []
            for char in text:
                if char == ' ':
                    token.append('<space>')
                else:
                    if char in vocab:
                        token.append(char)
                    else:
                        token.append('<unk>')
            dataset[key] = {'input':feature,
                           'text':text,
                           'token':' '.join(token),}
        

    
print(len(train_data), len(dev_data), len(test_data))

HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

11259 1433 718


In [73]:
class AttrDict(dict):
    """ Access dictionary keys like attribute 
        https://stackoverflow.com/questions/4984647/accessing-dict-keys-like-an-attribute
    """
    def __init__(self, *av, **kav):
        dict.__init__(self, *av, **kav)
        self.__dict__ = self

opts = AttrDict()

# Configure models

opts.ntoken = len(token2id)
opts.feature = 83
opts.ninp = 256
opts.nhead = 4
opts.nhid = 2048
opts.nlayers_enc = 12
opts.nlayers_dec = 6

opts.ce_weight = 0.7
opts.ctc_weight = 0.3

opts.beam_size = 10




# Configure optimization
opts.learning_rate = 5e-5

opts.dropout_rate = 0.1

opts.batch_size = 64
opts.num_workers = int(opts.batch_size / 8) if int(opts.batch_size / 8) < 16 else 16
print(opts.num_workers)
# Configure training
opts.max_seq_len = 512
opts.num_epochs = 300
# opts.warmup_steps = 4000
# opts.gradient_accumulation = 20

# opts.load_pretrain = True

8


In [74]:
class Dataset:
    
    def __init__(self, dataset):
        
        self.names = []
        self.features = []
        self.texts = []
        self.tokens = []
        
        for name, data in dataset.items():
            feature = data['input']
            text = data['text']
            token = data['token'].split(' ')
            token = ['<sos>'] + token + ['<eos>']
            
            self.names.append(name)
            self.features.append(feature)
            self.texts.append(text)
            self.tokens.append(token)
            
    def __len__(self):
        
        return len(self.names)
    
    def __getitem__(self, index):
        name = self.names[index]
        feature = self.features[index]
        text = self.texts[index]
        token = self.tokens[index]
        token_id = self.tokens2ids(token, token2id)
        
        return name, feature, text, token, token_id
    
    def tokens2ids(self, tokens, token2id):
        token_id = [token2id[token] if token in token2id else token2id['<unk>'] for token in tokens]
    
        return token_id
    
    
def collate_fn(data):
    
    def _pad_sequences(seqs):
        lens = [len(seq)-1 for seq in seqs]
        input_seqs = torch.zeros(len(seqs), max(lens)).long().fill_(token2id['<pad>'])
        target_seqs = torch.zeros(len(seqs), max(lens)).long().fill_(token2id['<pad>'])
#         input_seqs_mask = input_seqs.float().masked_fill(input_seqs.float()==0, float('-inf'))
#         input_seqs_mask = input_seqs_mask.masked_fill(input_seqs_mask!=float('-inf'), 0)

        for i, seq in enumerate(seqs):
            input_seqs[i, :len(seq)-1] = torch.LongTensor(seq[:-1])
            target_seqs[i, :len(seq)-1] = torch.LongTensor(seq[1:])
            
        input_seqs_mask = input_seqs == token2id['<pad>']
        
        return input_seqs, input_seqs_mask, target_seqs, lens
    
    def _pad_features(features):
        flens = [len(feature) for feature in features]
        input_features = torch.zeros(len(features), max(flens), opts.feature)
        for i, feature in enumerate(features):
            input_features[i, :len(feature)] = feature
            
        return input_features
    
    def _generate_square_subsequent_mask(sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = ~mask
        
        return mask
        
    
    name, feature, text, token, token_id = zip(*data)
    
    input_seqs, input_seqs_pad_mask, target_seqs, lens = _pad_sequences(token_id)
    
    input_features = _pad_features(feature)
    
    input_seqs_mask = _generate_square_subsequent_mask(input_seqs.size(1))
    input_seqs_mask = input_seqs_mask.repeat(input_seqs.size(0), 1, 1) #為了給dataparallel切 要給他一個batch維
    
    lens = torch.LongTensor(lens)
    
    return name, input_features, text, token, input_seqs, input_seqs_mask, input_seqs_pad_mask, target_seqs, lens

In [75]:
train_dataset = Dataset(train_data)
dev_dataset = Dataset(dev_data)
test_dataset = Dataset(test_data)


In [76]:
print(dev_dataset[0])

('BAC009S0728W0126', tensor([[-0.6655, -0.6084, -0.7678,  ...,  1.4794,  2.3732, -0.1098],
        [-0.4260, -1.0659, -0.8746,  ...,  0.3665,  2.3732,  0.0944],
        [ 0.3370,  0.1265, -0.8746,  ...,  0.7317,  2.3732,  0.0718],
        ...,
        [-0.0736, -0.4559, -1.7645,  ...,  1.7496,  1.9711,  0.1323],
        [ 0.1625, -0.1001, -1.1237,  ...,  0.6556,  1.9711, -0.0341],
        [ 0.2036, -0.1001, -0.9458,  ...,  1.3593,  1.9459,  0.0566]]), '必然先行抛售二三线城市的房产', ['<sos>', '必', '然', '先', '行', '抛', '售', '二', '三', '线', '城', '市', '的', '房', '产', '<eos>'], [2, 1249, 2306, 264, 3366, 1437, 628, 73, 10, 2888, 767, 1124, 2556, 1396, 88, 3])


In [77]:
np.random.seed(20200915)
torch.manual_seed(20200915)
torch.cuda.manual_seed_all(20200915)

train_iter = DataLoader(dataset=train_dataset,
                        batch_size=opts.batch_size,
                        shuffle=True,
                        num_workers=opts.num_workers,
                        collate_fn=collate_fn)

dev_iter = DataLoader(dataset=dev_dataset,
                        batch_size=opts.batch_size,
                        shuffle=False,
                        num_workers=opts.num_workers,
                        collate_fn=collate_fn)

test_iter = DataLoader(dataset=test_dataset,
                        batch_size=opts.batch_size,
                        shuffle=False,
                        num_workers=opts.num_workers,
                        collate_fn=collate_fn)

In [107]:
import copy
from typing import Optional, Any

from torch import Tensor
from torch.nn import functional as F
from torch.nn.modules.module import Module
from torch.nn.modules.activation import MultiheadAttention
from torch.nn.modules.container import ModuleList
from torch.nn.init import xavier_uniform_
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.normalization import LayerNorm

class TransformerEncoderLayer(Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of intermediate layer, relu or gelu (default=relu).

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerEncoderLayer, self).__setstate__(state)

    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        
        src2 = self.norm1(src)
        src2 = self.self_attn(src2, src2, src2, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)

        return src

class TransformerDecoderLayer(Module):
    r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
    This standard decoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of intermediate layer, relu or gelu (default=relu).

    Examples::
        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
        >>> memory = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> out = decoder_layer(tgt, memory)
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(TransformerDecoderLayer, self).__init__()
        self.self_attn_m = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.self_attn_t = MultiheadAttention(d_model, nhead, dropout=dropout)
        
        # Implementation of Feedforward model
        self.linear1_m = Linear(d_model, dim_feedforward)
        self.dropout_m = Dropout(dropout)
        self.linear2_m = Linear(dim_feedforward, d_model)
        
        self.linear1_t = Linear(d_model, dim_feedforward)
        self.dropout_t = Dropout(dropout)
        self.linear2_t = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerDecoderLayer, self).__setstate__(state)

    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the inputs (and mask) through the decoder layer.

        Args:
            tgt: the sequence to the decoder layer (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        
        m_cat_t = torch.cat([memory, tgt], dim=0)
        m_cat_t = self.norm1(m_cat_t)
        
        memory = m_cat_t[:memory.size(0), :, :]
        tgt = m_cat_t[memory.size(0):, :, :]
        
        memory2 = self.self_attn_m(memory, memory, memory, attn_mask=None,
                                 key_padding_mask=None)[0]
        
        m_cat_t_mask = torch.BoolTensor(tgt.size(0), m_cat_t.size(0)).fill_(False).to(tgt.device)
        m_cat_t_mask[:, -tgt_mask.size(0):] = tgt_mask
        m_cat_t_pad_mask = torch.BoolTensor(tgt_key_padding_mask.size(0), m_cat_t.size(0)).fill_(False).to(tgt.device)
        m_cat_t_pad_mask[:, -tgt_key_padding_mask.size(1):] = tgt_key_padding_mask
        
        
#         print(tgt.shape)
#         print(tgt_key_padding_mask.shape)
#         print(m_cat_t.shape)
#         print(m_cat_t_mask.shape)
        tgt2 = self.self_attn_t(tgt, m_cat_t, m_cat_t, attn_mask=m_cat_t_mask,
                               key_padding_mask=m_cat_t_pad_mask)[0]
        
        memory = memory + self.dropout1(memory2)
        tgt = tgt + self.dropout1(tgt2)
        
        m_cat_t = torch.cat([memory, tgt], dim=0)
        m_cat_t = self.norm2(m_cat_t)
        
        memory = m_cat_t[:memory.size(0), :, :]
        tgt = m_cat_t[memory.size(0):, :, :]
        
        memory2 = self.linear2_m(self.dropout_m(self.activation(self.linear1_m(memory))))
        tgt2 = self.linear2_t(self.dropout_t(self.activation(self.linear1_t(tgt))))
        
        memory = memory + self.dropout2(memory2)
        tgt = tgt + self.dropout2(tgt2)
        
        return memory, tgt
    
def _get_clones(module, N):
    return ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu

    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
    
class TransformerDecoder(Module):
    r"""TransformerDecoder is a stack of N decoder layers

    Args:
        decoder_layer: an instance of the TransformerDecoderLayer() class (required).
        num_layers: the number of sub-decoder-layers in the decoder (required).
        norm: the layer normalization component (optional).

    Examples::
        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
        >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
        >>> memory = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> out = transformer_decoder(tgt, memory)
    """
    __constants__ = ['norm']

    def __init__(self, decoder_layer, num_layers, norm=None):
        super(TransformerDecoder, self).__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the inputs (and mask) through the decoder layer in turn.

        Args:
            tgt: the sequence to the decoder (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        output = tgt

        for mod in self.layers:
            memory, output = mod(output, memory, tgt_mask=tgt_mask,
                                 memory_mask=memory_mask,
                                 tgt_key_padding_mask=tgt_key_padding_mask,
                                 memory_key_padding_mask=memory_key_padding_mask)

        if self.norm is not None:
            output = self.norm(output)

        return output

In [108]:
class Conv2dSubsampling(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/4 length).
    :param int idim: input dim
    :param int odim: output dim
    :param flaot dropout_rate: dropout rate
    """

    def __init__(self, idim, odim, dropout_rate=0.5):
        """Construct an Conv2dSubsampling object."""
        super(Conv2dSubsampling, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim),
#             PositionalEncoding(odim, dropout_rate),
        )

    def forward(self, x, x_mask):
        """Subsample x.
        :param torch.Tensor x: input tensor
        :param torch.Tensor x_mask: input mask
        :return: subsampled x and mask
        :rtype Tuple[torch.Tensor, torch.Tensor]
        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return x, None
        return x, x_mask[:, :, :-2:2][:, :, :-2:2]

class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


class EncoderModel(nn.Module):

    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, idim, dropout=0.5):
        super(EncoderModel, self).__init__()
        from torch.nn import TransformerEncoder
        self.model_type = 'Transformer'
        self.src_mask = None
        
        self.embedding = Conv2dSubsampling(idim, ninp)
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.layer_norm = LayerNorm(ninp)
        
        self.ninp = ninp
#         self.decoder = nn.Linear(ninp, ntoken)

#         self.init_weights()

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def init_weights(self):
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
#         self.decoder.bias.data.zero_()
#         self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src):
        
        # src [b, t, f]        
        
        src, self.src_mask = self.embedding(src, self.src_mask)

        src = src.permute(1, 0, 2) # [t, b, f]
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src=src, 
                                          mask=self.src_mask,
                                         )
#         output = self.decoder(output)
        output = self.layer_norm(output)
    
        return output

class DecoderModel(nn.Module):
    
    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(DecoderModel, self).__init__()
#         from torch.nn import TransformerDecoder, TransformerDecoderLayer
        self.model_type = 'Transformer'
        self.tgt_mask = None
        self.memory_mask = None
        
        self.embedding = nn.Embedding(ntoken, ninp)
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        decoder_layers = TransformerDecoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_decoder = TransformerDecoder(decoder_layers, nlayers)
        self.layer_norm = LayerNorm(ninp)
        
        self.ninp = ninp
        
        self.init_weights()
        
    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
    
    def init_weights(self):
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        
    def forward(self, tgt, tgt_mask, tgt_key_padding_mask, memory):
        
        # tgt [b, t], tgt_key_padding_mask [b, t], memory [t, b, f]
        
        tgt = tgt.permute(1, 0) # [t, b]
        
#         if self.tgt_mask is None or self.tgt_mask.size(0) != len(tgt):
#             device = tgt.device
#             mask = self._generate_square_subsequent_mask(len(tgt)).to(device)
#             self.tgt_mask = mask

        # The reason we increase the embedding values before the addition 
        # is to make the positional encoding relatively smaller. 
        # This means the original meaning 
        # in the embedding vector won’t be lost when we add them together.
        # maybe use learned position embedding will not need to do this?  
        
        tgt = self.embedding(tgt) * math.sqrt(self.ninp)
        tgt = self.pos_encoder(tgt)
        
        tgt_mask = tgt_mask[0] ##為了dataparallel給的batch維把它去掉
        
        output = self.transformer_decoder(tgt, memory, 
                                          tgt_mask=tgt_mask, 
                                          tgt_key_padding_mask=tgt_key_padding_mask, 
                                          memory_mask=self.memory_mask, 
                                          memory_key_padding_mask=None,
                                          )
        
        output = self.layer_norm(output)
        
        return output
        
class Model(nn.Module):
    
    def __init__(self, ntoken, ninp, nhead, nhid, nlayers_enc, nlayers_dec, idim, dropout=0.5):
        super(Model, self).__init__()
        
        self.encoder = EncoderModel(ntoken, ninp, nhead, nhid, nlayers_enc, idim, dropout)
        self.ctc_classifier = nn.Linear(ninp, ntoken)
        
        self.decoder = DecoderModel(ntoken, ninp, nhead, nhid, nlayers_dec, dropout)
        self.classifier = nn.Linear(ninp, ntoken)
        
    def forward(self, src, tgt, tgt_mask, tgt_key_padding_mask):
        
        memory = self.encoder(src)
        
        ctc_output = self.ctc_classifier(memory.permute(1, 0, 2))
        
        decoder_output = self.decoder(tgt, tgt_mask, tgt_key_padding_mask, memory)
        
        output = self.classifier(decoder_output.permute(1, 0, 2))

        # return 一定要batch first 不然dataparallel會concat錯維
        return ctc_output, output, memory.permute(1, 0, 2), decoder_output.permute(1, 0, 2)
#         return memory, ctc_output, decoder_output


class Test_Model(nn.Module):
    
    def __init__(self, ntoken, ninp, nhead, nhid, nlayers_enc, nlayers_dec, idim, dropout=0.5):
        super(Test_Model, self).__init__()
        
        self.encoder = EncoderModel(ntoken, ninp, nhead, nhid, nlayers_enc, idim, dropout)
        self.ctc_classifier = nn.Linear(ninp, ntoken)
        
        self.decoder = DecoderModel(ntoken, ninp, nhead, nhid, nlayers_dec, dropout)
        self.classifier = nn.Linear(ninp, ntoken)
        
        self.encoder.eval()
        self.ctc_classifier.eval()
        self.decoder.eval()
        self.classifier.eval()
        
    def forward(self, src, len_limit):
        
        device = src.device
        memory = self.encoder(src)
        
        nbest = []
        beams = []
        beams.append([[token2id['<sos>']], 0])

        for _ in range(2*len_limit):

            results = []

            for beam in beams:

                input_idxs = beam[0]

                input_seqs = torch.LongTensor([input_idxs])
        #         input_seqs = input_seqs.unsqueeze(0)
                input_seqs = input_seqs.to(device)

                decoder_output = self.decoder(input_seqs, None, memory)
                output = self.classifier(decoder_output.permute(1, 0, 2)[:, -1])
                output = output.log_softmax(dim=1)
                output[:, 0] = float('-inf')
                probs, idxs = output.topk(k=opts.beam_size, dim=1)

                for prob, idx in zip(probs.squeeze(0), idxs.squeeze(0)):

                    generate_idxs = input_idxs + [idx.item()]
                    accumulate_prob = beam[1] + prob.item()

                    results.append([generate_idxs, accumulate_prob])


            results.sort(key=lambda x:x[1])
            results = results[::-1]
            results = results[:3]

            beams = []

            for result in results:
                if result[0][-1] == token2id['<eos>']:
                    nbest.append(result)
                else:
                    beams.append(result)

        return nbest, beams
        

In [92]:
print(exp_dir)

RESTORE = False

LOAD_RNNLM = False

if RESTORE:
    experiment_dir = Path(exp_dir) / 'rnnlm_1stlayer_BERT_EN_finetune_increment_hidden300_len35_2020-01-01 11:15:21'
    last_epoch = 29
    print(experiment_dir)
    
else:

    if LOAD_RNNLM:
        experiment_name = 'rnnlm_dictALL_talk-sent_len130_2019-11-09 14:53:52'
        experiment_dir = Path(exp_dir) / experiment_name
        model_dir = experiment_dir / 'best_model'
        print(model_dir)

    last_epoch = -1
    model_name = 'myaishell_selfmix_ce{}_ctc{}'.format(\
                                   opts.ce_weight, opts.ctc_weight)
    now = str(datetime.now()).split('.')[0]
    experiment_name = '{}_{}'.format(model_name, now)
    experiment_dir = Path(exp_dir) / experiment_name
    experiment_dir.mkdir(exist_ok=True, parents=True)
    print(experiment_dir)

/mnt/disk3/m10615110/espnet/egs/aishell/asr1/exp
/mnt/disk3/m10615110/espnet/egs/aishell/asr1/exp/myaishell_selfmix_ce0.7_ctc0.3_2020-09-22 16:12:42


In [93]:
experiment_trainlog = experiment_dir / 'train_log.txt'

def log2file(log_file, msg):
    with open(log_file, 'a') as fw:
        fw.write(msg)
        fw.write('\n')

In [109]:
model = Model(opts.ntoken, opts.ninp, opts.nhead, opts.nhid, opts.nlayers_enc, opts.nlayers_dec, opts.feature, opts.dropout_rate)

print('total parms : ', sum(p.numel() for p in model.parameters()))
print('trainable parms : ', sum(p.numel() for p in model.parameters() if p.requires_grad))

total parms :  36724246
trainable parms :  36724246


In [110]:
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
USE_CUDA = torch.cuda.is_available()
USE_CUDA = False

print("Let's use", torch.cuda.device_count(), "GPUs!")


Let's use 2 GPUs!


In [111]:
if USE_CUDA:
    model = nn.DataParallel(model)

    model.cuda()

In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=opts.learning_rate)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=opts.learning_rate, steps_per_epoch=len(train_iter), epochs=20)

# ce_criterion = torch.nn.KLDivLoss()
ce_criterion = torch.nn.CrossEntropyLoss(reduction='mean', ignore_index=token2id['<pad>'])
ctc_criterion = torch.nn.CTCLoss(reduction='mean', blank=token2id['<blank>'])

In [None]:
last_epoch = -1

for k,v in opts.items():
    log_msg = '- {}: {}'.format(k, v)
    log2file(str(experiment_trainlog), log_msg)
    print(log_msg)

pbar_train = notebook.tqdm(total=len(train_iter))
pbar_dev = notebook.tqdm(total=len(dev_iter))
pbar_test = notebook.tqdm(total=len(test_iter))

log_msg = '='*50
print(log_msg)
log2file(str(experiment_trainlog), log_msg)
log_msg = 'optim : \n' + str(optimizer)
print(log_msg)   
log2file(str(experiment_trainlog), log_msg)

for epoch in range(last_epoch+1,  opts.num_epochs, 1):
    
    pbar_train.reset()
    pbar_dev.reset()
    pbar_test.reset()
    
    loss_tracker = []
    celoss_tracker = []
    ctcloss_tracker = []
    time_tracker = []
    time_tracker.append(time.time())
    total = 0
    correct = 0
    
    for iteration, batch in enumerate(train_iter):
        
        model.train()
        
        name, input_features, text, token, input_seqs, input_seqs_mask, input_seqs_pad_mask, target_seqs, lens = batch
        
        batch_size = input_features.size(0)
        
        if USE_CUDA:
            input_features = input_features.cuda()
            input_seqs = input_seqs.cuda()
            input_seqs_mask = input_seqs_mask.cuda()
            input_seqs_pad_mask = input_seqs_pad_mask.cuda()
            target_seqs = target_seqs.cuda()
        
        ctc_output, output, memory, decoder_output = model(input_features, input_seqs, input_seqs_mask, input_seqs_pad_mask)
        
        total += (target_seqs.view(-1) != token2id['<pad>']).sum().item()
        _, predicted = torch.max(output.view(-1, opts.ntoken).data, 1)
        correct += ((predicted == target_seqs.view(-1)) * (target_seqs.view(-1) != token2id['<pad>'])).sum().item()
        
        optimizer.zero_grad()
        
        ce_loss = ce_criterion(output.view(-1, opts.ntoken), target_seqs.view(-1))
        
        ctc_output = ctc_output.permute(1, 0, 2).log_softmax(2)
        
        input_lengths = torch.full(size=(ctc_output.size(1),), fill_value=ctc_output.size(0), dtype=torch.long)
        target_lengths = lens
        
        if USE_CUDA:
            input_lengths = input_lengths.cuda()
            target_lengths = target_lengths.cuda()
        
        ctc_loss = ctc_criterion(ctc_output, target_seqs, input_lengths, target_lengths)
        
        loss = opts.ce_weight*ce_loss + opts.ctc_weight*ctc_loss
        
        celoss_tracker.append(ce_loss.item()*batch_size)
        ctcloss_tracker.append(ctc_loss.item()*batch_size)
        loss_tracker.append(loss.item()*batch_size)
        
        loss.backward()
        optimizer.step()
        scheduler.step()
        
        pbar_train.update(1)
        
    
        
        
    time_tracker.append(time.time())
    log_msg = "{} | Epoch {:d}/{:d} | Mean CE / CTC / ALL Loss {:5.2f} / {:5.2f} / {:5.2f} | acc {:5.5f} % | time cost {:d} s"\
            .format('train'.upper(), epoch, opts.num_epochs, \
                    np.mean(celoss_tracker), np.mean(ctcloss_tracker), np.mean(loss_tracker), \
                    float(correct)/float(total)*100, int(time_tracker[-1] - time_tracker[-2]))
    print(log_msg)
    log2file(str(experiment_trainlog), log_msg)
    
    loss_tracker = []
    celoss_tracker = []
    ctcloss_tracker = []
    time_tracker = []
    time_tracker.append(time.time())
    total = 0
    correct = 0
    
    
    for iteration, batch in enumerate(dev_iter):
        
        model.eval()
        
        name, input_features, text, token, input_seqs, input_seqs_mask, input_seqs_pad_mask, target_seqs, lens = batch
        
        batch_size = input_features.size(0)
        
        if USE_CUDA:
            input_features = input_features.cuda()
            input_seqs = input_seqs.cuda()
            input_seqs_mask = input_seqs_mask.cuda()
            input_seqs_pad_mask = input_seqs_pad_mask.cuda()
            target_seqs = target_seqs.cuda()
        
        ctc_output, output, memory, decoder_output = model(input_features, input_seqs, input_seqs_mask, input_seqs_pad_mask)
        
        total += (target_seqs.view(-1) != token2id['<pad>']).sum().item()
        _, predicted = torch.max(output.view(-1, opts.ntoken).data, 1)
        correct += ((predicted == target_seqs.view(-1)) * (target_seqs.view(-1) != token2id['<pad>'])).sum().item()
        
        optimizer.zero_grad()
        
        ce_loss = ce_criterion(output.view(-1, opts.ntoken), target_seqs.view(-1))
        
        ctc_output = ctc_output.permute(1, 0, 2).log_softmax(2)
        
        input_lengths = torch.full(size=(ctc_output.size(1),), fill_value=ctc_output.size(0), dtype=torch.long)
        target_lengths = lens
        
        if USE_CUDA:
            input_lengths = input_lengths.cuda()
            target_lengths = target_lengths.cuda()
        
        ctc_loss = ctc_criterion(ctc_output, target_seqs, input_lengths, target_lengths)
        
        loss = opts.ce_weight*ce_loss + opts.ctc_weight*ctc_loss
        
        celoss_tracker.append(ce_loss.item()*batch_size)
        ctcloss_tracker.append(ctc_loss.item()*batch_size)
        loss_tracker.append(loss.item()*batch_size)
        
        pbar_dev.update(1)
        
    time_tracker.append(time.time())
    log_msg = "{} | Epoch {:d}/{:d} | Mean CE / CTC / ALL Loss {:5.2f} / {:5.2f} / {:5.2f} | acc {:5.5f} % | time cost {:d} s"\
            .format('dev  '.upper(), epoch, opts.num_epochs, \
                    np.mean(celoss_tracker), np.mean(ctcloss_tracker), np.mean(loss_tracker), \
                    float(correct)/float(total)*100, int(time_tracker[-1] - time_tracker[-2]))
    print(log_msg)
    log2file(str(experiment_trainlog), log_msg)
    
    loss_tracker = []
    celoss_tracker = []
    ctcloss_tracker = []
    time_tracker = []
    time_tracker.append(time.time())
    total = 0
    correct = 0
    
    for iteration, batch in enumerate(test_iter):
        
        model.eval()
        
        name, input_features, text, token, input_seqs, input_seqs_mask, input_seqs_pad_mask, target_seqs, lens = batch
        
        batch_size = input_features.size(0)
        
        if USE_CUDA:
            input_features = input_features.cuda()
            input_seqs = input_seqs.cuda()
            input_seqs_mask = input_seqs_mask.cuda()
            input_seqs_pad_mask = input_seqs_pad_mask.cuda()
            target_seqs = target_seqs.cuda()
        
        ctc_output, output, memory, decoder_output = model(input_features, input_seqs, input_seqs_mask, input_seqs_pad_mask)
        
        total += (target_seqs.view(-1) != token2id['<pad>']).sum().item()
        _, predicted = torch.max(output.view(-1, opts.ntoken).data, 1)
        correct += ((predicted == target_seqs.view(-1)) * (target_seqs.view(-1) != token2id['<pad>'])).sum().item()
        
        optimizer.zero_grad()
        
        ce_loss = ce_criterion(output.view(-1, opts.ntoken), target_seqs.view(-1))
        
        ctc_output = ctc_output.permute(1, 0, 2).log_softmax(2)
        
        input_lengths = torch.full(size=(ctc_output.size(1),), fill_value=ctc_output.size(0), dtype=torch.long)
        target_lengths = lens
        
        if USE_CUDA:
            input_lengths = input_lengths.cuda()
            target_lengths = target_lengths.cuda()
        
        ctc_loss = ctc_criterion(ctc_output, target_seqs, input_lengths, target_lengths)
        
        loss = opts.ce_weight*ce_loss + opts.ctc_weight*ctc_loss
        
        celoss_tracker.append(ce_loss.item()*batch_size)
        ctcloss_tracker.append(ctc_loss.item()*batch_size)
        loss_tracker.append(loss.item()*batch_size)
        
        pbar_test.update(1)
        
    time_tracker.append(time.time())
    log_msg = "{} | Epoch {:d}/{:d} | Mean CE / CTC / ALL Loss {:5.2f} / {:5.2f} / {:5.2f} | acc {:5.5f} % | time cost {:d} s"\
            .format('test '.upper(), epoch, opts.num_epochs, \
                    np.mean(celoss_tracker), np.mean(ctcloss_tracker), np.mean(loss_tracker), \
                    float(correct)/float(total)*100, int(time_tracker[-1] - time_tracker[-2]))
    print(log_msg)
    log2file(str(experiment_trainlog), log_msg)
    
    checkpoint = {
        "net": model.state_dict(),
        'optimizer':optimizer.state_dict(),
        "epoch": epoch
    }
    
    torch.save(checkpoint, experiment_dir / 'epoch_{}.ckpt'.format(epoch))
    
    log_msg = '='*50
    print(log_msg)
    log2file(str(experiment_trainlog), log_msg)
    
    