In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

import numpy as np

from pathlib import Path
from kaldiio import ReadHelper
import json
import os 
import time
from tqdm import tqdm_notebook, notebook
import copy

In [None]:
# Root directory that the recipe and data paths in later cells are joined onto.
espnet_path = Path('espnet')

In [None]:
# Recipe sub-directory, joined onto espnet_path when building data paths.
egs_path = Path('egs/aishell/asr1')

In [None]:
# Directory names used to locate the dumped features of each data split.
dump_path = 'dump'
train_path = 'train_sp'  # not read by the loading loop (the training split is commented out)
dev_path = 'dev'
test_path = 'test'

In [None]:
# Load the unit vocabulary: every non-empty line is expected to be "<token> <index>".
vocab_file = espnet_path / egs_path / 'data/lang_1char/train_sp_units.txt'
vocab = {}
# Read as UTF-8 explicitly instead of relying on the platform default encoding;
# the unit file presumably contains non-ASCII (Chinese) characters.
with open(vocab_file, encoding='utf-8') as fp:
    for line in fp:
        fields = line.split()
        if not fields:
            # Skip blank lines instead of failing on the tuple unpacking below.
            continue
        word, idx = fields
        # idx is kept as the raw string read from the file; only the keys are
        # used when the token -> id table is built.
        vocab[word] = idx

print(len(vocab))

In [None]:
# Ids 0-3 are reserved for the special symbols; every other vocabulary entry
# receives the next free id, in vocabulary order.
token2id = {'<blank>': 0, '<pad>': 1, '<sos>': 2, '<eos>': 3}
new_id = 4
for word in vocab:
    if word in token2id:
        continue
    token2id[word] = new_id
    new_id += 1

# Reverse table, used to turn predicted ids back into tokens.
id2token = {idx: tok for tok, idx in token2id.items()}

In [None]:
# Display the complete token -> id table (the output has one entry per token).
token2id

In [None]:
# Load the dev and test splits into memory.
# Each entry maps an utterance id to {'input': feature tensor, 'text': str,
# 'token': space-separated tokens}.
# train_data = {}
dev_data = {}
test_data = {}

for (set_path, dataset) in zip([dev_path, test_path], [dev_data, test_data]):
    set_dump_path = espnet_path / egs_path / dump_path / set_path / 'deltafalse'
    # Explicit UTF-8: do not depend on the platform default encoding for the
    # transcripts stored in data.json.
    with open(set_dump_path / 'data.json', encoding='utf-8') as fp:
        json_data = json.load(fp)

    # Read every feature archive of this split, keyed by utterance id.
    feats = {}
    feats_files = list(set_dump_path.glob('feats.*.ark'))
    # Wrapping the iterable lets tqdm close the bar when the loop finishes.
    for feats_file in notebook.tqdm(feats_files):
        with ReadHelper('ark:' + str(feats_file)) as reader:
            for key, numpy_array in reader:
                feats[key] = torch.from_numpy(numpy_array)

    for key, value in json_data['utts'].items():
        if key not in feats:
            # Utterances without dumped features are skipped.
            continue

        text = value['output'][0]['text']
        token = []
        for char in text:
            if char == ' ':
                token.append('<space>')
            elif char in vocab:
                token.append(char)
            else:
                token.append('<unk>')

        dataset[key] = {'input': feats[key],
                        'text': text,
                        'token': ' '.join(token)}

print(len(dev_data), len(test_data))

In [None]:
class AttrDict(dict):
    """Dictionary whose entries can also be read and written as attributes.

    Recipe from
    https://stackoverflow.com/questions/4984647/accessing-dict-keys-like-an-attribute
    """

    def __init__(self, *av, **kav):
        super().__init__(*av, **kav)
        # Using the mapping itself as the instance dict makes `d.key` and
        # `d['key']` refer to the same storage.
        self.__dict__ = self

# All hyper-parameters live in one attribute-style dictionary.
opts = AttrDict(
    # Configure models
    ntoken=len(token2id),
    feature=83,
    ninp=256,
    nhead=4,
    nhid=2048,
    nlayers_enc=12,
    nlayers_dec=6,
    # Loss weights (not read anywhere in this notebook)
    ce_weight=0.7,
    ctc_weight=0.3,
    # RNN language model
    emb_size=650,
    hidden_size=650,
    num_layers=2,
    tie_weight=True,
    # Beam search
    beam_size=10,
    lm_weight=0.4,
    temperature=1.0,
    # Configure optimization
    learning_rate=5e-5,
    dropout_rate=0.1,
    batch_size=64,
)

# One loader worker per eight samples of a batch, capped at 16.
opts.num_workers = min(int(opts.batch_size / 8), 16)
print(opts.num_workers)

# Configure training
opts.max_seq_len = 512
opts.num_epochs = 300
# opts.warmup_steps = 4000
# opts.gradient_accumulation = 20

# opts.load_pretrain = True

In [None]:
class Dataset:
    """In-memory utterance collection that can be indexed by position.

    ``dataset`` maps an utterance name to a dict holding ``'input'`` (the
    features), ``'text'`` and ``'token'`` (tokens joined by single spaces).
    Each token list is stored wrapped in ``<sos>`` / ``<eos>``.
    """

    def __init__(self, dataset):
        self.names = []
        self.features = []
        self.texts = []
        self.tokens = []

        for name, data in dataset.items():
            feature, text = data['input'], data['text']
            token = ['<sos>', *data['token'].split(' '), '<eos>']

            self.names.append(name)
            self.features.append(feature)
            self.texts.append(text)
            self.tokens.append(token)

    def __len__(self):
        """Number of stored utterances."""
        return len(self.names)

    def __getitem__(self, index):
        """Return ``(name, feature, text, token, token_id)`` for one utterance.

        The ids are looked up in the module-level ``token2id`` table.
        """
        token = self.tokens[index]
        return (
            self.names[index],
            self.features[index],
            self.texts[index],
            token,
            self.tokens2ids(token, token2id),
        )

    def tokens2ids(self, tokens, token2id):
        """Map tokens to ids; unknown tokens get the id of ``<unk>``."""
        token_id = []
        for token in tokens:
            if token in token2id:
                token_id.append(token2id[token])
            else:
                token_id.append(token2id['<unk>'])
        return token_id
    
    
def collate_fn(data):
    """Turn a list of ``Dataset`` items into one padded batch.

    Args:
        data: list of ``(name, feature, text, token, token_id)`` tuples.

    Returns:
        Tuple ``(name, input_features, text, token, input_seqs,
        input_seqs_mask, input_seqs_pad_mask, target_seqs, lens)`` where

        * ``input_features`` is a zero-padded float tensor
          ``[batch, max_frames, opts.feature]``,
        * ``input_seqs`` / ``target_seqs`` are the id sequences without their
          last / first element, padded with the ``<pad>`` id,
        * ``input_seqs_mask`` is the additive causal mask repeated along a
          leading batch dimension ``[batch, T, T]``,
        * ``input_seqs_pad_mask`` is True exactly at padded positions,
        * ``lens`` holds the unpadded length of every input sequence.
    """
    pad_id = token2id['<pad>']

    def _pad_sequences(seqs):
        # Inputs drop the last id and targets drop the first, hence len - 1.
        lens = [len(seq) - 1 for seq in seqs]
        input_seqs = torch.full((len(seqs), max(lens)), pad_id, dtype=torch.long)
        target_seqs = torch.full((len(seqs), max(lens)), pad_id, dtype=torch.long)

        for i, seq in enumerate(seqs):
            input_seqs[i, :len(seq) - 1] = torch.LongTensor(seq[:-1])
            target_seqs[i, :len(seq) - 1] = torch.LongTensor(seq[1:])

        # Compare against the <pad> id itself. The previous
        # `~input_seqs.bool()` flagged id 0 (<blank>) instead of id 1 (<pad>),
        # so padded positions were never masked.
        input_seqs_mask = input_seqs == pad_id

        return input_seqs, input_seqs_mask, target_seqs, lens

    def _pad_features(features):
        flens = [len(feature) for feature in features]
        input_features = torch.zeros(len(features), max(flens), opts.feature)
        for i, feature in enumerate(features):
            input_features[i, :len(feature)] = feature

        return input_features

    def _generate_square_subsequent_mask(sz):
        # 0.0 on and below the diagonal, -inf above it.
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    name, feature, text, token, token_id = zip(*data)

    input_seqs, input_seqs_pad_mask, target_seqs, lens = _pad_sequences(token_id)

    input_features = _pad_features(feature)

    input_seqs_mask = _generate_square_subsequent_mask(input_seqs.size(1))
    # DataParallel splits its inputs along dim 0, so the mask gets a batch
    # dimension of its own.
    input_seqs_mask = input_seqs_mask.repeat(input_seqs.size(0), 1, 1)

    lens = torch.LongTensor(lens)

    return name, input_features, text, token, input_seqs, input_seqs_mask, input_seqs_pad_mask, target_seqs, lens

In [None]:
# Wrap the loaded splits in Dataset objects (the training split is disabled).
# train_dataset = Dataset(train_data)
dev_dataset = Dataset(dev_data)
test_dataset = Dataset(test_data)


In [None]:
# Sanity check: print the first dev item as (name, feature, text, token, token_id).
print(dev_dataset[0])

In [None]:
# Seed every random number generator used by the notebook.
SEED = 20200908
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# train_iter = DataLoader(dataset=train_dataset,
#                         batch_size=opts.batch_size,
#                         shuffle=True,
#                         num_workers=opts.num_workers,
#                         collate_fn=collate_fn)

# Evaluation loaders: one utterance per batch, in dataset order.
eval_loader_kwargs = dict(batch_size=1,
                          shuffle=False,
                          num_workers=1,
                          collate_fn=collate_fn)

dev_iter = DataLoader(dataset=dev_dataset, **eval_loader_kwargs)
test_iter = DataLoader(dataset=test_dataset, **eval_loader_kwargs)

In [None]:
class RNNLM(nn.Module):
    """LSTM language model: embedding -> LSTM -> tanh -> linear projection.

    Args:
        vocab_size: number of token ids.
        emb_size: embedding dimension.
        hidden_size: LSTM hidden dimension.
        num_layers: number of stacked LSTM layers.
        tie_weight: share the projection weight with the embedding matrix;
            requires ``emb_size == hidden_size``.
        dropout: accepted for interface compatibility but not used; the LSTM
            dropout is fixed at 0.1 (kept as in the original code).

    Raises:
        ValueError: if ``tie_weight`` is set and the two sizes differ.
    """

    def __init__(self, vocab_size, emb_size, hidden_size, num_layers, tie_weight=False, dropout=0.2):
        super(RNNLM, self).__init__()

        if tie_weight and emb_size != hidden_size:
            raise ValueError(
                'tie_weight requires emb_size == hidden_size, got {} and {}'.format(emb_size, hidden_size))

        self.embedding = nn.Embedding(vocab_size, emb_size)

        self.lstm = nn.LSTM(emb_size,
                            hidden_size,
                            num_layers=num_layers,
                            batch_first=True,
                            bidirectional=False,
                            dropout=0.1)

        # The projection consumes LSTM outputs, so its input width is
        # hidden_size. It was emb_size before, which only worked when the two
        # sizes happened to be equal.
        self.h2w = nn.Linear(hidden_size, vocab_size, bias=True)

        self.tanh = nn.Tanh()

        if tie_weight:
            self.h2w.weight = self.embedding.weight

    def forward(self, input_seqs, seqs_len):
        """Return next-token logits for every position.

        Args:
            input_seqs: long tensor of token ids, ``[batch, time]``.
            seqs_len: unpadded length of every sequence (tensor or list).

        Returns:
            Float tensor ``[batch, max(seqs_len), vocab_size]``.
        """
        emb = self.embedding(input_seqs)

        # pack_padded_sequence wants the lengths as a CPU int64 tensor;
        # recent PyTorch releases reject lengths that live on the GPU.
        lengths = torch.as_tensor(seqs_len, dtype=torch.int64).cpu()

        packed = pack_padded_sequence(emb, lengths, batch_first=True, enforce_sorted=False)
        packed, _ = self.lstm(packed)
        hidden_outputs, _ = pad_packed_sequence(packed, batch_first=True)

        hidden_outputs = self.tanh(hidden_outputs)

        return self.h2w(hidden_outputs)

In [None]:
import copy
from typing import Optional, Any

from torch import Tensor
from torch.nn import functional as F
from torch.nn.modules.module import Module
from torch.nn.modules.activation import MultiheadAttention
from torch.nn.modules.container import ModuleList
from torch.nn.init import xavier_uniform_
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.normalization import LayerNorm


class TransformerEncoderLayer(Module):
    r"""Pre-norm Transformer encoder layer: self-attention plus feed-forward network.

    The structure follows the encoder layer of "Attention Is All You Need"
    (Vaswani et al., 2017), except that LayerNorm is applied to the *input* of
    each sub-block and the residual is added to the un-normalised tensor.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of intermediate layer, relu or gelu (default=relu).

    Raises:
        RuntimeError: if ``activation`` is neither ``"relu"`` nor ``"gelu"``.

    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        # Resolve the activation here: the `_get_activation_fn` helper used
        # before is neither defined nor imported in this notebook.
        if activation == "relu":
            self.activation = F.relu
        elif activation == "gelu":
            self.activation = F.gelu
        else:
            raise RuntimeError("activation should be relu/gelu, not {}".format(activation))

    def __setstate__(self, state):
        # Objects pickled without an activation fall back to ReLU.
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerEncoderLayer, self).__setstate__(state)

    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                is_causal: bool = False) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            is_causal: accepted and ignored; newer ``nn.TransformerEncoder``
                releases appear to pass this keyword to every layer. Masking is
                fully determined by ``src_mask``.

        Shape:
            see the docs in Transformer class.
        """
        # Self-attention sub-block (norm -> attention -> dropout -> residual).
        src2 = self.norm1(src)
        src2 = self.self_attn(src2, src2, src2, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)

        # Feed-forward sub-block (norm -> FFN -> dropout -> residual).
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)

        return src


class TransformerDecoderLayer(Module):
    r"""Pre-norm Transformer decoder layer: self-attention, encoder-decoder attention and feed-forward network.

    The structure follows the decoder layer of "Attention Is All You Need"
    (Vaswani et al., 2017), except that LayerNorm is applied to the *input* of
    each sub-block and the residual is added to the un-normalised tensor.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of intermediate layer, relu or gelu (default=relu).

    Raises:
        RuntimeError: if ``activation`` is neither ``"relu"`` nor ``"gelu"``.

    Examples::
        >>> decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8)
        >>> memory = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> out = decoder_layer(tgt, memory)
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(TransformerDecoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.norm3 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)

        # Resolve the activation here: the `_get_activation_fn` helper used
        # before is neither defined nor imported in this notebook.
        if activation == "relu":
            self.activation = F.relu
        elif activation == "gelu":
            self.activation = F.gelu
        else:
            raise RuntimeError("activation should be relu/gelu, not {}".format(activation))

    def __setstate__(self, state):
        # Objects pickled without an activation fall back to ReLU.
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerDecoderLayer, self).__setstate__(state)

    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None,
                tgt_is_causal: bool = False, memory_is_causal: bool = False) -> Tensor:
        r"""Pass the inputs (and mask) through the decoder layer.

        Args:
            tgt: the sequence to the decoder layer (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).
            tgt_is_causal: accepted and ignored; newer ``nn.TransformerDecoder``
                releases appear to pass this keyword to every layer.
            memory_is_causal: accepted and ignored, for the same reason.

        Shape:
            see the docs in Transformer class.
        """
        # Masked self-attention over the target sequence.
        tgt2 = self.norm1(tgt)
        tgt2 = self.self_attn(tgt2, tgt2, tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)

        # Attention from the target (queries) to the encoder memory (keys/values).
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(tgt2, memory, memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)

        # Position-wise feed-forward network.
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)

        return tgt

In [None]:
class Conv2dSubsampling(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/4 length).

    Two 3x3 convolutions with stride 2 followed by a linear projection.

    :param int idim: input dim
    :param int odim: output dim
    :param float dropout_rate: dropout rate (accepted but not used by this class)
    """

    def __init__(self, idim, odim, dropout_rate=0.5):
        """Construct an Conv2dSubsampling object."""
        super().__init__()
        conv_stack = [
            torch.nn.Conv2d(1, odim, kernel_size=3, stride=2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, kernel_size=3, stride=2),
            torch.nn.ReLU(),
        ]
        self.conv = torch.nn.Sequential(*conv_stack)
        # Size of the feature axis after both stride-2 convolutions.
        reduced_dim = ((idim - 1) // 2 - 1) // 2
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * reduced_dim, odim),
        )

    def forward(self, x, x_mask):
        """Subsample x.

        :param torch.Tensor x: input tensor
        :param torch.Tensor x_mask: input mask
        :return: subsampled x and mask
        :rtype Tuple[torch.Tensor, torch.Tensor]
        """
        conv_out = self.conv(x.unsqueeze(1))  # (b, c, t, f)
        b, c, t, f = conv_out.size()
        # Fold the channel and feature axes together for the projection.
        flat = conv_out.permute(0, 2, 1, 3).reshape(b, t, c * f)
        projected = self.out(flat)
        if x_mask is not None:
            # Apply the same two stride-2 reductions to the mask's last axis.
            return projected, x_mask[:, :, :-2:2][:, :, :-2:2]
        return projected, None

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position codes to a time-first input, then apply dropout.

    The table is stored in the buffer ``pe`` with shape ``[max_len, 1, d_model]``.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term

        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)  # even feature indices
        table[:, 1::2] = torch.cos(angles)  # odd feature indices
        # The singleton middle axis broadcasts over the batch dimension.
        self.register_buffer('pe', table.unsqueeze(1))

    def forward(self, x):
        """Return ``dropout(x + pe[:len(x)])``; dim 0 of ``x`` indexes positions."""
        seq_len = x.size(0)
        return self.dropout(x + self.pe[:seq_len, :])


class EncoderModel(nn.Module):
    """Acoustic encoder: conv subsampling, positional encoding, Transformer encoder stack and a final LayerNorm.

    ``ntoken`` is accepted but not used by this class.
    """

    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, idim, dropout=0.5):
        super().__init__()
        from torch.nn import TransformerEncoder

        self.model_type = 'Transformer'
        self.src_mask = None

        self.embedding = Conv2dSubsampling(idim, ninp)
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        layer_template = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(layer_template, nlayers)
        self.layer_norm = LayerNorm(ninp)

        self.ninp = ninp

    def _generate_square_subsequent_mask(self, sz):
        """Return a ``[sz, sz]`` float mask: 0.0 on and below the diagonal, -inf above."""
        return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)

    def init_weights(self):
        # NOTE(review): never called from __init__. `self.embedding` is a
        # Conv2dSubsampling, which defines no `weight` attribute, so calling
        # this looks like it would raise AttributeError - confirm before use.
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)

    def forward(self, src):
        """Encode ``src`` ``[batch, time, feature]`` into ``[time', batch, ninp]``."""
        subsampled, self.src_mask = self.embedding(src, self.src_mask)

        # The Transformer stack works time-first: [t, b, f].
        time_first = self.pos_encoder(subsampled.transpose(0, 1))
        encoded = self.transformer_encoder(src=time_first, mask=self.src_mask)

        return self.layer_norm(encoded)

class DecoderModel(nn.Module):
    """Token decoder: embedding, positional encoding, Transformer decoder stack and a final LayerNorm."""

    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(DecoderModel, self).__init__()
        # Import only the container. The layer class must remain the
        # TransformerDecoderLayer defined in this notebook, so it is
        # deliberately not imported from torch.nn here.
        from torch.nn import TransformerDecoder
        self.model_type = 'Transformer'
        self.tgt_mask = None
        self.memory_mask = None

        self.embedding = nn.Embedding(ntoken, ninp)
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        decoder_layers = TransformerDecoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_decoder = TransformerDecoder(decoder_layers, nlayers)
        self.layer_norm = LayerNorm(ninp)

        self.ninp = ninp

        self.init_weights()

    def _generate_square_subsequent_mask(self, sz):
        """Return a ``[sz, sz]`` float mask: 0.0 on and below the diagonal, -inf above."""
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def init_weights(self):
        """Initialise the embedding uniformly in [-0.1, 0.1]."""
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)

    def forward(self, tgt, tgt_mask, tgt_key_padding_mask, memory):
        """Decode target ids against the encoder memory.

        Args:
            tgt: token ids, ``[b, t]``.
            tgt_mask: additive self-attention mask. Either ``[b, t, t]`` (the
                batch dimension exists only so DataParallel can split it; the
                first slice is used), ``[t, t]``, or ``None`` to use a causal
                mask built for the current target length.
            tgt_key_padding_mask: ``[b, t]`` padding mask or ``None``.
            memory: encoder output, ``[t_mem, b, f]``.

        Returns:
            Tensor ``[t, b, ninp]``.
        """
        tgt = tgt.permute(1, 0)  # [t, b]

        # The reason we increase the embedding values before the addition
        # is to make the positional encoding relatively smaller.
        # This means the original meaning
        # in the embedding vector won't be lost when we add them together.
        # maybe use learned position embedding will not need to do this?
        tgt = self.embedding(tgt) * math.sqrt(self.ninp)
        tgt = self.pos_encoder(tgt)

        if tgt_mask is None:
            # Callers that pass None previously crashed on `tgt_mask[0]`.
            # Fall back to the causal mask, which is the mask shape that
            # collate_fn builds for batched input.
            tgt_mask = self._generate_square_subsequent_mask(tgt.size(0)).to(tgt.device)
        elif tgt_mask.dim() == 3:
            # Drop the batch dimension that was added for DataParallel.
            tgt_mask = tgt_mask[0]

        output = self.transformer_decoder(tgt, memory,
                                          tgt_mask=tgt_mask,
                                          tgt_key_padding_mask=tgt_key_padding_mask,
                                          memory_mask=self.memory_mask,
                                          memory_key_padding_mask=None,
                                          )

        output = self.layer_norm(output)

        return output
        
class Model(nn.Module):
    """Encoder-decoder model with two output heads: one on the encoder memory and one on the decoder states."""

    def __init__(self, ntoken, ninp, nhead, nhid, nlayers_enc, nlayers_dec, idim, dropout=0.5):
        super().__init__()

        self.encoder = EncoderModel(ntoken, ninp, nhead, nhid, nlayers_enc, idim, dropout)
        self.ctc_classifier = nn.Linear(ninp, ntoken)

        self.decoder = DecoderModel(ntoken, ninp, nhead, nhid, nlayers_dec, dropout)
        self.classifier = nn.Linear(ninp, ntoken)

    def forward(self, src, tgt, tgt_mask, tgt_key_padding_mask):
        """Return ``(ctc_output, output, memory, decoder_output)``, all batch-first."""
        memory = self.encoder(src)
        memory_batch_first = memory.permute(1, 0, 2)
        ctc_output = self.ctc_classifier(memory_batch_first)

        decoder_output = self.decoder(tgt, tgt_mask, tgt_key_padding_mask, memory)
        decoder_batch_first = decoder_output.permute(1, 0, 2)
        output = self.classifier(decoder_batch_first)

        # Everything returned must be batch-first, otherwise DataParallel
        # concatenates the replicas along the wrong dimension.
        return ctc_output, output, memory_batch_first, decoder_batch_first


class Test_Model(nn.Module):
    """Inference wrapper that runs beam search with the encoder-decoder model only.

    The sub-module names match those of ``Model``. All sub-modules are put in
    eval mode at construction time.
    """

    def __init__(self, ntoken, ninp, nhead, nhid, nlayers_enc, nlayers_dec, idim, dropout=0.5):
        super(Test_Model, self).__init__()

        self.encoder = EncoderModel(ntoken, ninp, nhead, nhid, nlayers_enc, idim, dropout)
        self.ctc_classifier = nn.Linear(ninp, ntoken)

        self.decoder = DecoderModel(ntoken, ninp, nhead, nhid, nlayers_dec, dropout)
        self.classifier = nn.Linear(ninp, ntoken)

        self.encoder.eval()
        self.ctc_classifier.eval()
        self.decoder.eval()
        self.classifier.eval()

    def forward(self, src, len_limit):
        """Beam-search decode a single utterance.

        Args:
            src: features ``[1, time, feature]``.
            len_limit: the search runs for at most ``2 * len_limit`` steps.

        Returns:
            ``(nbest, beams)``: finished hypotheses (ending in ``<eos>``) and
            the hypotheses still open when the search stopped. Every entry is
            ``[token_ids, accumulated_log_prob]``.
        """
        device = src.device
        eos_id = token2id['<eos>']
        blank_id = token2id['<blank>']
        pad_id = token2id['<pad>']

        nbest = []
        beams = [[[token2id['<sos>']], 0]]

        # Pure inference: do not build an autograd graph for every step.
        with torch.no_grad():
            memory = self.encoder(src)

            for _ in range(2 * len_limit):

                results = []

                for input_idxs, score in beams:
                    input_seqs = torch.LongTensor([input_idxs]).to(device)

                    # Pass an explicit causal mask (with the leading batch
                    # dimension that DecoderModel strips): passing None here
                    # used to fail on `tgt_mask[0]` inside the decoder.
                    tgt_mask = self.decoder._generate_square_subsequent_mask(len(input_idxs))
                    tgt_mask = tgt_mask.unsqueeze(0).to(device)

                    decoder_output = self.decoder(input_seqs, tgt_mask, None, memory)
                    output = self.classifier(decoder_output.permute(1, 0, 2)[:, -1])

                    # <blank> and <pad> can never be emitted.
                    output[:, blank_id] = float('-inf')
                    output[:, pad_id] = float('-inf')
                    output = output.log_softmax(dim=1)

                    probs, idxs = output.topk(k=opts.beam_size, dim=1)

                    for prob, idx in zip(probs.squeeze(0), idxs.squeeze(0)):
                        results.append([input_idxs + [idx.item()], score + prob.item()])

                # Keep the beam_size best candidates (highest score first).
                results.sort(key=lambda x: x[1])
                results = results[::-1][:opts.beam_size]

                beams = []
                for result in results:
                    if result[0][-1] == eos_id:
                        nbest.append(result)
                    else:
                        beams.append(result)

                if len(beams) == 0:
                    break

        return nbest, beams

    
class Test_Model_pal(nn.Module):
    """Inference wrapper that runs batched beam search and adds RNN LM scores at every step.

    All live beams are decoded in one forward pass. The per-step score is
    ``log_softmax(decoder logits) + opts.lm_weight * log_softmax(lm logits / opts.temperature)``.
    """

    def __init__(self, ntoken, ninp, nhead, nhid, nlayers_enc, nlayers_dec, idim, \
                 emb_size, hidden_size, num_layers, tie_weight, dropout=0.5):

        super(Test_Model_pal, self).__init__()

        self.encoder = EncoderModel(ntoken, ninp, nhead, nhid, nlayers_enc, idim, dropout)
        self.ctc_classifier = nn.Linear(ninp, ntoken)

        self.decoder = DecoderModel(ntoken, ninp, nhead, nhid, nlayers_dec, dropout)
        self.classifier = nn.Linear(ninp, ntoken)

        self.rnnlm = RNNLM(ntoken, emb_size, hidden_size, num_layers, tie_weight, dropout)

        self.encoder.eval()
        self.ctc_classifier.eval()
        self.decoder.eval()
        self.classifier.eval()
        self.rnnlm.eval()

    def forward(self, src, len_limit):
        """Beam-search decode a single utterance.

        Args:
            src: features ``[1, time, feature]``.
            len_limit: maximum number of decoding steps. The value was
                previously ignored in favour of a hard-coded
                ``int(src.size(1) * 0.1)``, which is what the evaluation loop
                in this notebook passes in.

        Returns:
            ``(nbest, beams)``: finished hypotheses (ending in ``<eos>``) and
            the hypotheses still open when the search stopped. Every entry is
            ``[token_ids, accumulated_score]``.
        """
        device = src.device
        pad_id = token2id['<pad>']
        blank_id = token2id['<blank>']
        eos_id = token2id['<eos>']

        nbest = []
        beams = [[[token2id['<sos>']], 0]]

        # Pure inference: do not build an autograd graph for every step.
        with torch.no_grad():
            memory = self.encoder(src)

            for _ in range(len_limit):

                results = []

                input_inxss = [beam[0] for beam in beams]
                lens = [len(input_idxs) for input_idxs in input_inxss]

                input_seqs = torch.full((len(input_inxss), max(lens)), pad_id, dtype=torch.long)
                for i, input_idxs in enumerate(input_inxss):
                    input_seqs[i, :len(input_idxs)] = torch.LongTensor(input_idxs)
                input_seqs = input_seqs.to(device)

                # Lengths stay on the CPU: pack_padded_sequence (used inside
                # the LM) rejects GPU lengths on recent PyTorch releases.
                seqs_len = torch.LongTensor(lens)

                input_seqs_mask = input_seqs == pad_id

                # Pass an explicit causal mask (with the leading batch
                # dimension that DecoderModel strips): passing None here used
                # to fail on `tgt_mask[0]` inside the decoder.
                tgt_mask = self.decoder._generate_square_subsequent_mask(input_seqs.size(1))
                tgt_mask = tgt_mask.unsqueeze(0).to(device)

                # One copy of the encoder memory per live beam.
                this_memory = memory.repeat(1, input_seqs.size(0), 1)

                decoder_output = self.decoder(input_seqs, tgt_mask, input_seqs_mask, this_memory)

                # Every live beam has the same length here (each step extends
                # every beam by exactly one token), so the last position is the
                # newest token for all rows.
                output = self.classifier(decoder_output.permute(1, 0, 2)[:, -1, :])

                lm_output = self.rnnlm(input_seqs, seqs_len)
                lm_output = lm_output[:, -1, :]
                lm_output = lm_output / opts.temperature

                # <pad> and <blank> can never be emitted.
                output[:, pad_id] = float('-inf')
                output[:, blank_id] = float('-inf')
                lm_output[:, pad_id] = float('-inf')
                lm_output[:, blank_id] = float('-inf')

                output = output.log_softmax(dim=1)
                lm_output = lm_output.log_softmax(dim=1)

                total_output = output + (opts.lm_weight * lm_output)

                probs, idxs = total_output.topk(k=opts.beam_size, dim=1)

                for i, (batch_probs, batch_idxs) in enumerate(zip(probs, idxs)):
                    for prob, idx in zip(batch_probs, batch_idxs):
                        generate_idxs = beams[i][0] + [idx.item()]
                        accumulate_prob = beams[i][1] + prob.item()
                        results.append([generate_idxs, accumulate_prob])

                # Keep the beam_size best candidates (highest score first).
                results.sort(key=lambda x: x[1])
                results = results[::-1][:opts.beam_size]

                beams = []
                for result in results:
                    if result[0][-1] == eos_id:
                        nbest.append(result)
                    else:
                        beams.append(result)

                if len(beams) == 0:
                    break

        return nbest, beams

In [None]:
# Build the inference model; keyword arguments make the long positional
# signature explicit.
t_model = Test_Model_pal(
    ntoken=opts.ntoken,
    ninp=opts.ninp,
    nhead=opts.nhead,
    nhid=opts.nhid,
    nlayers_enc=opts.nlayers_enc,
    nlayers_dec=opts.nlayers_dec,
    idim=opts.feature,
    emb_size=opts.emb_size,
    hidden_size=opts.hidden_size,
    num_layers=opts.num_layers,
    tie_weight=opts.tie_weight,
    dropout=opts.dropout_rate,
)


In [None]:
# Display the module tree of the inference model.
t_model

In [None]:
from collections import OrderedDict


def _strip_module_prefix(state_dict, prefix='module.'):
    """Return a copy of ``state_dict`` with ``prefix`` removed from the keys that start with it.

    Keys without the prefix are kept unchanged. The previous unconditional
    ``k[7:]`` slice would have cut seven characters off such keys.
    """
    stripped = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        stripped[key] = value
    return stripped


# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
e2e_ckpt = torch.load("exp/**myaishell_ce0.7_ctc0.3_2020-09-15 17:15:39/epoch_16.ckpt", map_location='cpu')
e2e_parms = e2e_ckpt['net']

lm_ckpt = torch.load("exp/RNNLM_2layers_650hidden_2020-09-16 08:51:44/epoch_26.ckpt", map_location='cpu')
lm_parms = lm_ckpt['net']

# Remove the `module.` prefix from the checkpoint keys.
new_e2e_parms = _strip_module_prefix(e2e_parms)
new_lm_parms = _strip_module_prefix(lm_parms)

# strict=False: mismatching keys do not raise. Check the printed
# missing/unexpected key report to confirm that the weights were picked up.
print(t_model.load_state_dict(new_e2e_parms, strict=False))
print(t_model.rnnlm.load_state_dict(new_lm_parms, strict=False))


In [None]:
class error_stats:
    """Edit-operation counts accumulated along one alignment path."""

    def __init__(self):
        self.ins_num = 0
        self.del_num = 0
        self.sub_num = 0
        self.total_cost = 0


def wer(ref, hyp):
    """Align ``hyp`` against ``ref`` with unit-cost edit distance.

    Example::

        ref = ['聽', '說', '馬', '上', '就', '要', '放', '假', '了']
        hyp = ['你', '聽', '說', '要', '放', '假', '了']

    Args:
        ref: reference token sequence.
        hyp: hypothesis token sequence.

    Returns:
        ``(Ins, Del, Sub, Cost, N)``: insertions, deletions and substitutions
        along the selected minimum-cost alignment, its total cost, and
        ``N = len(ref)``.
    """
    N = len(ref)

    # Row for the empty hypothesis: reaching ref[:i] costs i deletions.
    prev_row = []
    for i in range(N + 1):
        cell = error_stats()
        cell.del_num = i
        cell.total_cost = i
        prev_row.append(cell)

    for hyp_index in range(1, len(hyp) + 1):
        # Against an empty reference prefix every hypothesis token is an insertion.
        first = copy.copy(prev_row[0])
        first.ins_num += 1
        first.total_cost += 1

        # Build each row from fresh instances. The previous code pre-filled
        # the row with the `error_stats` class object itself and relied on
        # every slot being overwritten before it was read.
        cur_row = [first]

        for ref_index in range(1, N + 1):
            mismatch = hyp[hyp_index - 1] != ref[ref_index - 1]

            ins_err = prev_row[ref_index].total_cost + 1
            del_err = cur_row[ref_index - 1].total_cost + 1
            sub_err = prev_row[ref_index - 1].total_cost + (1 if mismatch else 0)

            # Tie-breaking: substitution/match only when strictly best, then
            # deletion when strictly better than insertion, else insertion.
            if sub_err < ins_err and sub_err < del_err:
                cell = copy.copy(prev_row[ref_index - 1])
                if mismatch:
                    cell.sub_num += 1
                cell.total_cost = sub_err
            elif del_err < ins_err:
                cell = copy.copy(cur_row[ref_index - 1])
                cell.total_cost = del_err
                cell.del_num += 1
            else:
                cell = copy.copy(prev_row[ref_index])
                cell.total_cost = ins_err
                cell.ins_num += 1

            cur_row.append(cell)

        prev_row = cur_row

    final = prev_row[-1]
    return final.ins_num, final.del_num, final.sub_num, final.total_cost, N

In [None]:
# Expose only the GPU with index 1 to this process.
# NOTE(review): presumably this must run before CUDA is first initialised to
# have any effect - confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
USE_CUDA = torch.cuda.is_available()
# USE_CUDA = False

# Move the inference model to the GPU when one is available.
if USE_CUDA:
    t_model.cuda()

In [None]:
# Decode every test utterance and collect its n-best list together with the
# reference tokens.
nbests = []
refs = []

# Switch to inference mode once instead of on every iteration.
t_model.eval()

# No gradients are needed for decoding; this avoids keeping an autograd graph
# alive for every beam-search step.
with torch.no_grad():
    # Wrapping the loader lets tqdm close the bar when the loop finishes.
    for batch in notebook.tqdm(test_iter):

        name, input_features, text, token, input_seqs, input_seqs_mask, input_seqs_pad_mask, target_seqs, lens = batch

        # Decoding step budget: 10% of the number of input frames.
        len_limit = int(input_features.size(1) * 0.1)

        if USE_CUDA:
            input_features = input_features.cuda()

        nbest, beams = t_model(input_features, len_limit)

        # Fall back to the unfinished beams when no hypothesis reached <eos>.
        if len(nbest) == 0:
            nbest = beams

        # Best hypothesis first.
        nbest = sorted(nbest, key=lambda x: x[1])[::-1]

        nbests.append(nbest)
        refs.append(token)
    

    

In [None]:
# Show the tokens of the first n-best entry for the utterance at index 10.
' '.join([id2token[token] for token in nbests[10][0][0]])

In [None]:
# Score the best hypothesis of every utterance against its reference and
# report the aggregated error rate over tokens.
print(len(refs), len(nbests))

totalN = 0
totalIns = 0
totalDel = 0
totalSub = 0

for ref, nbest in zip(refs, nbests):

    if nbest:
        # Pick the top-scoring hypothesis without re-sorting `nbests` in
        # place, so this cell can be re-run without changing the stored lists.
        best = sorted(nbest, key=lambda x: x[1])[-1]
        hyp = [id2token[token] for token in best[0]]
    else:
        # Nothing was decoded for this utterance: score an empty hypothesis.
        hyp = []

    # Strip the sentence markers; the guards cover empty hypotheses.
    if hyp and hyp[0] == '<sos>':
        hyp = hyp[1:]
    if hyp and hyp[-1] == '<eos>':
        hyp = hyp[:-1]

#     print(' '.join(hyp))

    # Each stored reference is a one-element batch; drop <sos>/<eos> as well.
    ref = ref[0]
    ref = ref[1:-1]

    Ins, Del, Sub, Cost, N = wer(ref, hyp)
    totalN += N
    totalIns += Ins
    totalDel += Del
    totalSub += Sub


print(totalN, totalIns, totalDel, totalSub)

# Avoid a ZeroDivisionError when there are no reference tokens at all.
error_rate = (totalIns + totalDel + totalSub) / totalN * 100 if totalN > 0 else float('nan')
print('wer : {}'.format(error_rate))