In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn

import os
import time
import json
import munch

import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns

# Use the seaborn look. Matplotlib 3.6 renamed the bundled seaborn styles to
# 'seaborn-v0_8*' and later releases removed the old alias, so choose whichever
# name this installation provides (falling back to the default style).
plt.style.use(next(
    (name for name in ('seaborn-v0_8', 'seaborn') if name in plt.style.available),
    'default',
))

plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

# SimHei is requested as the sans-serif font (the notebook plots Chinese text);
# unicode_minus is disabled so minus signs are drawn with an ASCII hyphen.
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus']=False

### 0x01超参数设置

In [None]:
# Global configuration object; later cells read and extend it via attribute access.
cfg = munch.Munch({
    # input files: vocabularies (JSON) and the tab-separated parallel corpus
    'eng_vocab_path': './dataset/cmneng/eng_vocab_m2.json',
    'cmn_vocab_path': './dataset/cmneng/cmn_vocab_m2_jb.json',
    'data_path': './dataset/cmneng/data_jb.txt',
    
    # model save (path prefixes; the solver appends '-<save_path>.pth')
    'enc_model_path': './enc',
    'dec_model_path': './dec',
    
    # network config
    'embed_size': 512,
    'hidden_size': 1024,
    'num_layers': 2,
    'dropout': 0.1,
    'batch_size': 128,
    'lr': 0.001,
    'num_epoches': 20,
    
    # transformer
    'transformer_hidden_size': 512,
    'num_heads': 4,
    
    # for plot
    'all_losslog': dict(),
})

# Ids of the special tokens (PAD / START / END are compared against and fed to the models below).
cfg.UNK = 0
cfg.PAD = 1
cfg.START = 2
cfg.END = 3

print(cfg)

#### Utils

In [None]:
import re
from matplotlib_inline import backend_inline


def eng_transform(data):
    """Normalise a raw English sentence into lowercase, space-separated tokens.

    Punctuation marks (. ? ! , ") are split off as separate tokens, every
    character that is not an ASCII letter, one of those marks or an apostrophe
    becomes a space, and runs of spaces are collapsed.

    Args:
        data: the raw sentence (str).

    Returns:
        The cleaned, lower-cased sentence (str).
    """
    # Surround each punctuation mark with spaces so it becomes its own token.
    spaced = re.sub(r'([\.\?\!\,\"])', r" \1 ", data)
    # Blank out everything outside the allowed character set.
    filtered = re.sub(r'[^a-zA-Z\.\?\!\,\"\']', r" ", spaced)
    lowered = filtered.strip().lower()
    # Collapse consecutive spaces into one.
    return re.sub(r' +', r" ", lowered)

def show_heatmaps(matries, xlabel, ylabel, titles=None, suptitle=None, figsize=(2.5, 2.5), cmap='Reds'):
    """Draw a grid of heat maps, one subplot per matrix.

    Args:
        matries: array-like whose first two dimensions index the grid
            (rows, columns); ``matries[i][j]`` is the matrix drawn in cell (i, j).
        xlabel: x-axis label, set on the bottom row only.
        ylabel: y-axis label, set on the first column only.
        titles: optional per-column titles, indexed by column.
        suptitle: optional title for the whole figure.
        figsize: figure size passed to ``plt.subplots``.
        cmap: colour map passed to ``imshow``.
    """
    backend_inline.set_matplotlib_formats('svg')

    num_rows, num_cols = matries.shape[0], matries.shape[1]
    fig, axes = plt.subplots(num_rows, num_cols, figsize=figsize,
                             sharex=True, sharey=True, squeeze=False)
    bottom_row = num_rows - 1
    for row in range(num_rows):
        for col in range(num_cols):
            ax = axes[row, col]
            pcm = ax.imshow(matries[row][col], cmap=cmap)
            if row == bottom_row:
                ax.set_xlabel(xlabel)
            if col == 0:
                ax.set_ylabel(ylabel)
            if titles:
                ax.set_title(titles[col])
    if suptitle:
        fig.suptitle(suptitle)
    # One shared colour bar, taken from the last image drawn.
    fig.colorbar(pcm, ax=axes, shrink=0.6)
    plt.show()

### 0x02准备数据

In [None]:
# Load the English vocabulary and keep a shallow copy of the parsed JSON object.
with open(cfg.eng_vocab_path, 'r', encoding='utf-8') as f:
    eng_vocab_dict = json.load(f)
eng_vocab = dict(eng_vocab_dict.items())

# Same for the Chinese vocabulary.
with open(cfg.cmn_vocab_path, 'r', encoding='utf-8') as f:
    cmn_vocab_dict = json.load(f)
cmn_vocab = dict(cmn_vocab_dict.items())

# Expose the vocabularies and their sizes through the shared config.
cfg.eng_vocab, cfg.cmn_vocab = eng_vocab, cmn_vocab
cfg.eng_vocab_size = len(eng_vocab['token_to_idx'])
cfg.cmn_vocab_size = len(cmn_vocab['token_to_idx'])

# Peek at the first ten entries of each token -> index table.
for vocab in (eng_vocab, cmn_vocab):
    print(list(vocab['token_to_idx'].items())[:10])

print(f'eng_vocab_size: {cfg.eng_vocab_size}, cmn_vocab_size: {cfg.cmn_vocab_size}')

In [None]:
# Read the tab-separated parallel corpus without a header row.
# quoting=3 is csv.QUOTE_NONE: quote characters in the text are kept as ordinary characters.
org_data = pd.read_csv(cfg.data_path, sep='\t', header=None, quoting=3)
# Show a few randomly ordered rows.
org_data.sample(frac=1).head()

In [None]:
# Split both columns on whitespace into token lists.
# The columns are overwritten in place, so entries that are already lists are
# left untouched; this keeps the cell safe to re-run without reloading the data.
org_data[0] = org_data[0].apply(lambda x: x if isinstance(x, list) else x.split())
org_data[1] = org_data[1].apply(lambda x: x if isinstance(x, list) else x.split())
org_data.sample(frac=1).head()

In [None]:
# Average number of tokens per sentence in each column.
eng_len = org_data[0].map(len).mean()
cmn_len = org_data[1].map(len).mean()
print(f'eng_avg_len: {eng_len}, cmnavg_len: {cmn_len}')

# Fixed sequence length used for truncation / padding below.
cfg.seq_len = 20

#### 将数据转换为向量表示

In [None]:
def line_to_idx(line, lang):
    """Convert a token list into a fixed-length list of vocabulary indices.

    The tokens are truncated to ``cfg.seq_len - 1`` items, an ``'<END>'`` token
    is appended and the result is right-padded with ``'<PAD>'`` up to
    ``cfg.seq_len``. Tokens missing from the vocabulary map to the index of
    ``'<UNK>'`` (0 if the vocabulary has no such entry).

    Args:
        line: sequence of token strings.
        lang: ``'eng'`` or ``'cmn'``; selects the vocabulary.

    Returns:
        A list of ``cfg.seq_len`` indices.

    Raises:
        ValueError: if ``lang`` is neither ``'eng'`` nor ``'cmn'`` (previously
            the function silently returned None in that case).
    """
    if lang == 'eng':
        token_to_idx = eng_vocab['token_to_idx']
    elif lang == 'cmn':
        token_to_idx = cmn_vocab['token_to_idx']
    else:
        raise ValueError(f"unsupported lang {lang!r}; expected 'eng' or 'cmn'")

    # Leave room for the <END> marker, then pad to the fixed length.
    tokens = list(line[:cfg.seq_len - 1]) + ['<END>']
    tokens += ['<PAD>'] * (cfg.seq_len - len(tokens))

    unk = token_to_idx.get('<UNK>', 0)
    return [token_to_idx.get(t, unk) for t in tokens]
    
# Build the index-encoded copy of the corpus: column 0 with the English
# vocabulary, column 1 with the Chinese one. org_data itself is left unchanged.
data = org_data.copy()
data[0] = org_data[0].apply(lambda x: line_to_idx(x, 'eng'))
data[1] = org_data[1].apply(lambda x: line_to_idx(x, 'cmn'))
data.sample(frac=1).head()

In [None]:
# random_shuffle
# Seed NumPy's global RNG so that data.sample(frac=1) yields a reproducible row order;
# the shuffled frame is stored on cfg and split into train/valid/test by TextDataset.
np.random.seed(443)
cfg.data = data.sample(frac=1)
print(cfg.data.shape)
cfg.data.head()

#### 自定义Dataset和DataLoader

In [None]:
from torch.utils.data import Dataset, DataLoader


class TextDataset(Dataset):
    """One split (train / valid / test) of the index-encoded corpus in ``cfg.data``.

    Rows are split by position: the first ``1 - valid_ratio - test_ratio`` share
    is the training set, the next ``valid_ratio`` share the validation set and
    everything that remains the test set.

    Args:
        mode: ``'train'``, ``'valid'`` or ``'test'``.
        valid_ratio: fraction of rows used for validation.
        test_ratio: fraction of rows used to compute ``test_len`` (the test
            split itself is "all remaining rows").

    Raises:
        ValueError: if ``mode`` is not one of the three supported values
            (previously this surfaced as an AttributeError further down).
    """
    def __init__(self, mode="train", valid_ratio=0.1, test_ratio=0.1):
        if mode not in ('train', 'valid', 'test'):
            raise ValueError(f"unknown mode {mode!r}; expected 'train', 'valid' or 'test'")

        self.df = cfg.data
        total = len(self.df)
        self.train_len = int(total * (1 - valid_ratio - test_ratio))
        self.valid_len = int(total * valid_ratio)
        self.test_len = int(total * test_ratio)

        # Positional row boundaries of every split; the test split is open-ended.
        valid_end = self.train_len + self.valid_len
        bounds = {
            'train': (0, self.train_len),
            'valid': (self.train_len, valid_end),
            'test': (valid_end, None),
        }
        start, stop = bounds[mode]
        rows = self.df.iloc[start:stop]

        # Column 0 holds the English indices, column 1 the Chinese ones.
        self.eng_data = torch.tensor(rows[0].to_list())
        self.cmn_data = torch.tensor(rows[1].to_list())

        print(f"Finish loading {mode} data ({len(self.eng_data)} samples)")

    def __len__(self):
        """Number of sentence pairs in this split."""
        return len(self.eng_data)

    def __getitem__(self, idx):
        """Return the (English, Chinese) index tensors of sample ``idx``."""
        return self.eng_data[idx], self.cmn_data[idx]
    
    
def get_dataloader():
    """Build the shuffled train / valid / test DataLoaders.

    The train and valid loaders use ``cfg.batch_size``; the test loader yields
    one sample at a time.

    Returns:
        Tuple ``(train_dataloader, valid_dataloader, test_dataloader)``.
    """
    # Datasets are created in train, valid, test order (each prints its size).
    datasets = [TextDataset(mode=mode) for mode in ('train', 'valid', 'test')]
    batch_sizes = (cfg.batch_size, cfg.batch_size, 1)
    loaders = tuple(
        DataLoader(dataset, batch_size=batch_size, shuffle=True)
        for dataset, batch_size in zip(datasets, batch_sizes)
    )
    return loaders[0], loaders[1], loaders[2]

# Build the three loaders once and share them through cfg (the solver reads them from there).
cfg.train_dataloader, cfg.valid_dataloader, cfg.test_dataloader = get_dataloader()


def eng_cmn_to_text(tensor_tuple):
    """Map an (English ids, Chinese ids) pair back to token lists.

    Args:
        tensor_tuple: indexable pair; element 0 holds English indices and
            element 1 Chinese indices.

    Returns:
        Tuple ``(eng_tokens, cmn_tokens)`` of token lists, looked up in the
        ``idx_to_token`` tables of ``cfg.eng_vocab`` / ``cfg.cmn_vocab``.
    """
    eng_lookup = cfg.eng_vocab['idx_to_token']
    eng_tokens = [eng_lookup[t] for t in tensor_tuple[0]]
    cmn_lookup = cfg.cmn_vocab['idx_to_token']
    cmn_tokens = [cmn_lookup[t] for t in tensor_tuple[1]]
    return eng_tokens, cmn_tokens

# Show the first training pair, both as index tensors and decoded back to tokens.
print('\nSample data:')
print(cfg.train_dataloader.dataset[0])
print(eng_cmn_to_text(cfg.train_dataloader.dataset[0]))

### 0x03定义Seq2Seq-RNN网络模型

#### Seq2SeqEncoder

In [None]:
class Seq2SeqEncoder(nn.Module):
    """GRU encoder, optionally bidirectional.

    In bidirectional mode the per-step outputs and the final hidden state
    (both ``2 * hidden_size`` wide after merging directions) are projected back
    to ``hidden_size`` by ``fc1`` and ``fc2`` respectively, so the returned
    shapes are the same in both modes.

    Args:
        vocab_size: size of the source vocabulary.
        embed_size: embedding dimension.
        hidden_size: GRU hidden dimension.
        num_layers: number of stacked GRU layers.
        dropout: dropout between GRU layers.
        withBidirectional: use a bidirectional GRU when True.
    """
    def __init__(self,vocab_size, embed_size, hidden_size, num_layers, dropout=0.1, withBidirectional=False):
        super(Seq2SeqEncoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, hidden_size, num_layers, dropout=dropout, bidirectional=withBidirectional)
        self.withBidirectional = withBidirectional
        if withBidirectional is True:
            self.fc1 = nn.Linear(2*hidden_size, hidden_size)
            self.fc2 = nn.Linear(2*hidden_size, hidden_size)

    def forward(self, x):
        """Encode a batch of index sequences.

        Args:
            x: (batch_size, seq_len) token indices.

        Returns:
            output: (seq_len, batch_size, hidden_size)
            state: (num_layers, batch_size, hidden_size)
        """
        x = self.embedding(x) # (batch_size, seq_len, embed_size)
        x = x.permute(1, 0, 2) # (seq_len, batch_size, embed_size)
        output, state = self.rnn(x)

        if self.withBidirectional is True:
            output = self.fc1(output) # only the attention decoder uses the encoder output

            # h_n of a bidirectional GRU is ordered layer by layer:
            # (l0_fwd, l0_bwd, l1_fwd, l1_bwd, ...). Even rows are therefore the
            # forward states and odd rows the backward states of the same layer.
            # The previous state[:L] / state[L:] split paired different layers
            # whenever num_layers > 1.
            # NOTE: checkpoints trained with the old pairing still load, but fc2
            # was fitted to the old input layout and should be retrained.
            state = torch.cat((state[0::2], state[1::2]), dim=-1)
            state = torch.tanh(self.fc2(state))

        # output: (seq_len, batch_size, hidden_size)
        # state: (num_layers, batch_size, hidden_size)
        return output, state
    
    
# Smoke test: run a small bidirectional encoder on a dummy batch and inspect the output shapes.
encoder_tester = Seq2SeqEncoder(vocab_size=1024, embed_size=128, hidden_size=128, num_layers=2, dropout=0.1, withBidirectional=True)
encoder_tester.eval()
x_test = torch.ones((64, 10), dtype=torch.long)
output, state = encoder_tester(x_test)
output.shape, state.shape
# Expected: (seq_len, batch_size, hidden_size), (num_layers, batch_size, hidden_size)
# The same shapes hold with withBidirectional=True: fc1 / fc2 project the merged
# directions back to hidden_size, and the state is merged to num_layers rows.

#### Seq2SeqDecoder

In [None]:
class Seq2SeqDecoder(nn.Module):
    """GRU decoder that feeds a context vector alongside every input token.

    Args:
        vocab_size: size of the target vocabulary.
        embed_size: embedding dimension.
        hidden_size: GRU hidden dimension.
        num_layers: number of stacked GRU layers.
        dropout: dropout between GRU layers.
    """
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # The context vector is concatenated to the embedded input at every
        # time step, hence the embed_size + hidden_size input width.
        self.rnn = nn.GRU(embed_size+hidden_size, hidden_size, num_layers, dropout=dropout)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def init_state(self, enc_outputs, enc_valid_len):
        """Return the decoder's initial state: the encoder's hidden state.

        Args:
            enc_outputs: ``(output, state)`` pair as returned by the encoder.
            enc_valid_len: not used by this decoder.
        """
        _, enc_state = enc_outputs[0], enc_outputs[1]
        return enc_state

    def forward(self, x, state):
        """Decode ``x`` starting from ``state``.

        Args:
            x: (batch_size, seq_len) token indices.
            state: (num_layers, batch_size, hidden_size)

        Returns:
            output: (batch_size, seq_len, vocab_size) scores.
            state: (num_layers, batch_size, hidden_size) updated hidden state.
        """
        embedded = self.embedding(x).permute(1, 0, 2)  # (seq_len, batch_size, embed_size)
        num_steps = embedded.shape[0]
        # Top-layer state of the incoming hidden state, repeated for every step.
        context = state[-1].repeat(num_steps, 1, 1)  # (seq_len, batch_size, hidden_size)
        rnn_input = torch.cat([embedded, context], dim=-1)
        rnn_output, new_state = self.rnn(rnn_input, state)
        scores = self.fc(rnn_output)  # (seq_len, batch_size, vocab_size)
        return scores.permute(1, 0, 2), new_state
    
# Smoke test: initialise a small decoder from the encoder tester's output and run one forward pass.
decoder_tester = Seq2SeqDecoder(vocab_size=1024, embed_size=128, hidden_size=128, num_layers=2, dropout=0.1)
x_test = torch.ones((64, 10), dtype=torch.long)
# Number of non-PAD positions per row.
x_valid_len = torch.sum(x_test!=cfg.PAD, dim=-1)
state = decoder_tester.init_state(encoder_tester(x_test), x_valid_len)
output, state = decoder_tester(x_test, state)
output.shape, state.shape
# Expected: (batch_size, seq_len, vocab_size), (num_layers, batch_size, hidden_size)

### 0x04定义Seq2SeqSolver

#### MaskedSoftmaxCELoss

In [None]:
# Id of the padding token (same value as assigned when cfg was created).
cfg.PAD = 1

class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):
    """Cross-entropy loss that ignores positions beyond each sequence's valid length."""

    def get_mask(self, target, valid_len):
        """Build a 0/1 mask with ones on the first ``valid_len[i]`` positions of row i.

        Args:
            target: (batch_size, seq_len); provides shape, dtype and device.
            valid_len: (batch_size,) number of valid positions per row.

        Returns:
            Mask tensor shaped and typed like ``target``.
        """
        positions = torch.arange(target.shape[1], dtype=torch.float32, device=target.device)
        keep = positions.unsqueeze(0) < valid_len.unsqueeze(1)
        return keep.to(target.dtype)

    def forward(self, logits, target, valid_len):
        """Per-sequence masked loss.

        Args:
            logits: (batch_size, seq_len, vocab_size)
            target: (batch_size, seq_len)
            valid_len: (batch_size,)

        Returns:
            (batch_size,) tensor: the masked per-token losses averaged over the
            full ``seq_len`` axis (masked positions contribute zero).
        """
        # Per-token losses are needed so that they can be masked individually.
        self.reduction = 'none'
        # CrossEntropyLoss expects the class dimension second: (N, C, d1).
        per_token = super().forward(logits.transpose(1, 2), target)
        weights = self.get_mask(target, valid_len)
        return (per_token * weights).mean(dim=1)
        

# Smoke test of the masked loss on constant logits with four different valid lengths.
loss = MaskedSoftmaxCELoss()
l_input = torch.ones((4, 10, 1024))
l_target = torch.ones((4, 10), dtype=torch.long)
l_valid_len = torch.tensor([10, 7, 5, 3])
print(l_valid_len)

# Row i of the mask has ones on its first l_valid_len[i] positions.
mask = loss.get_mask(l_target, l_valid_len)
print(mask)

# One loss value per sequence.
loss(l_input, l_target, l_valid_len)

#### BLEU
$$
\exp \left(\min \left(0,1-\frac{\operatorname{len}_{\text {label }}}{\operatorname{len}_{\text {pred }}}\right)\right) \prod_{n=1}^{k} p_{n}^{1 / 2^{n}}
$$

In [None]:
import math
import collections

def bleu(orig_seq, pred_seq, k):
    """BLEU-style score of two space-separated token strings.

    Trailing ``<PAD>`` / ``<END>`` tokens are stripped from both sequences.
    The score is ``exp(min(0, 1 - len_orig / len_pred))`` multiplied, for
    ``n = 1..k``, by ``(matches_n / (len_orig - n + 1)) ** (0.5 ** n)``, where
    ``matches_n`` is the clipped number of n-grams shared by both sequences.

    NOTE(review): the n-gram ratio is normalised by the n-gram count of
    ``orig_seq``, whereas the formula in the notebook defines p_n as a
    precision, which would normalise by the prediction's n-grams.
    ``Seq2SeqSolver.test`` passes the prediction as the first argument and the
    reference as the second. Reconcile the parameter names, the denominator
    and the call site together; the formula is left unchanged here.

    Args:
        orig_seq: first sequence, tokens separated by single spaces.
        pred_seq: second sequence, tokens separated by single spaces.
        k: largest n-gram order to include.

    Returns:
        The score as a float, or 0 when either sequence has no tokens left
        after stripping (previously an IndexError / ZeroDivisionError) or when
        ``orig_seq`` is shorter than ``k`` tokens.
    """
    def content_tokens(seq):
        """Split on spaces and drop trailing <PAD>/<END> markers."""
        tokens = seq.split(' ')
        while tokens and tokens[-1] in ('<PAD>', '<END>'):
            tokens.pop()
        return tokens

    orig_tokens, pred_tokens = content_tokens(orig_seq), content_tokens(pred_seq)
    len_orig, len_pred = len(orig_tokens), len(pred_tokens)
    if len_orig == 0 or len_pred == 0:
        return 0

    score = math.exp(min(0, 1 - len_orig / len_pred))  # length penalty
    for n in range(1, k + 1):
        num_orig_ngrams = len_orig - n + 1
        if num_orig_ngrams <= 0:
            return 0
        # Multiset of n-grams in orig_seq; each one can be matched only once.
        orig_subs = collections.Counter(
            ' '.join(orig_tokens[i: i + n]) for i in range(num_orig_ngrams)
        )
        match_cnt = 0
        for i in range(len_pred - n + 1):
            pred_sent = ' '.join(pred_tokens[i: i + n])
            if orig_subs[pred_sent] > 0:
                match_cnt += 1
                orig_subs[pred_sent] -= 1
        score *= math.pow(match_cnt / num_orig_ngrams, math.pow(0.5, n))

    return score

#### Seq2SeqSolver

In [None]:
# Special token ids (same values as assigned when cfg was created).
cfg.PAD = 1
cfg.START = 2
cfg.END = 3

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class Seq2SeqSolver(nn.Module):
    '''
    统一的Seq2Seq框架
    '''
    def __init__(self, encoder, decoder, with_attn=False, is_transformer=False, save_path=''):
        """Wrap an encoder/decoder pair with its loss, optimizer and bookkeeping.

        Args:
            encoder: encoder module; moved to ``device``.
            decoder: decoder module; moved to ``device``.
            with_attn: True when the decoder exposes ``attn_weights`` that
                should be collected during prediction.
            is_transformer: True for the transformer encoder/decoder call
                convention (the encoder also receives the valid lengths).
            save_path: tag inserted into the checkpoint file names.
        """
        super().__init__()
        self.encoder = encoder.to(device)
        self.decoder = decoder.to(device)
        self.seq_len = cfg.seq_len

        # The optimizer is created after the sub-modules are registered so
        # that self.parameters() covers both encoder and decoder.
        self.loss = MaskedSoftmaxCELoss()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=cfg.lr)

        self.is_transformer, self.with_attn = is_transformer, with_attn
        self.attn_weights = None
        self.losslog = None
        self.save_path = save_path
        
    def grad_clipping(self, theta):
        '''梯度裁剪'''
        enc_params = [p for p in self.encoder.parameters() if p.requires_grad and p.grad is not None]
        dec_params = [p for p in self.decoder.parameters() if p.requires_grad and p.grad is not None]
        params = enc_params + dec_params     
        norm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))
        if norm > theta:
            for param in params:
                param.grad[:] *= theta / norm
        
    def train(self, mini_train=False):
        """Train encoder and decoder with teacher forcing, then save both state dicts.

        Runs ``cfg.num_epoches`` epochs over ``cfg.train_dataloader`` (a single
        epoch when ``mini_train`` is true). After every epoch the ratio
        "summed per-sequence loss / number of non-PAD target tokens" is printed
        and appended to ``self.losslog``, which ends up as a NumPy array.
        The weights are written to ``cfg.enc_model_path`` / ``cfg.dec_model_path``
        + '-' + ``self.save_path`` + '.pth'.

        NOTE: this method shadows ``nn.Module.train(mode=True)``. PyTorch
        implements ``Module.eval()`` as ``self.train(False)``, so calling
        ``.eval()`` / ``.train()`` on the solver itself starts a training run;
        switch modes on ``self.encoder`` / ``self.decoder`` instead.

        Args:
            mini_train: quick check mode - one epoch, progress printed for the
                first 10 batches, and the loop stops at the batch after those.
        """
        self.encoder.train()
        self.decoder.train()
        self.losslog = []
        
        num_epoches = cfg.num_epoches if not mini_train else 1
        mini_cnt = 0 
        for epoch in range(num_epoches):
            loss_sum = 0
            num_tokens_sum = 0
            
            for i, (x, y) in enumerate(cfg.train_dataloader):
                x, y = x.to(device), y.to(device)
                # Valid lengths = number of non-PAD positions per row.
                x_valid_len = torch.sum(x!=cfg.PAD, dim=-1)
                y_valid_len = torch.sum(y!=cfg.PAD, dim=-1)
                
                self.optimizer.zero_grad()

                ### enc-dec part
                if self.is_transformer is True:
                    state = self.encoder(x, x_valid_len)
                else:
                    state = self.encoder(x)
                init_state = self.decoder.init_state(state, x_valid_len)
                # teacher forcing: the decoder input is the target shifted right by one, starting with START
                start = torch.tensor([cfg.START]*y.shape[0], device=y.device).reshape(-1,1)
                dec_input = torch.cat([start, y[:,:-1]], dim=1)
                dec_outputs, _ = self.decoder(dec_input, init_state)
                ### enc-dec part
                
                loss = self.loss(dec_outputs, y, y_valid_len)
                loss.sum().backward()
                self.grad_clipping(1)   # gradient clipping
                self.optimizer.step()
                
                num_tokens = y_valid_len.sum()
                
                with torch.no_grad():
                    loss_sum += loss.sum().item()
                    # num_tokens is a tensor, so num_tokens_sum becomes a tensor as well
                    num_tokens_sum += num_tokens

                # for test only
                if mini_train is True:
                    if mini_cnt < 10:
                        mini_cnt += 1
                        print(f'epoch {epoch+1}, batch {i+1}, loss {loss_sum/num_tokens_sum:.4f}')
                    else:
                        break

            print(f'epoch: {epoch}, loss: {loss_sum/num_tokens_sum}')
            self.losslog.append(loss_sum/num_tokens_sum)
        # The log entries are tensors (see num_tokens_sum); convert them to a NumPy array.
        self.losslog = np.array([x.cpu().numpy() for x in self.losslog])
        torch.save(self.encoder.state_dict(), cfg.enc_model_path + '-' + self.save_path + '.pth')
        torch.save(self.decoder.state_dict(), cfg.dec_model_path + '-' + self.save_path + '.pth')
        
                    
    def predict(self, en_input, enc_valid_lens):
        """Greedy-decode a batch of source sequences.

        Decoding starts from ``cfg.START`` and runs for at most ``cfg.seq_len``
        steps. It stops early once every sequence in the batch has produced
        ``cfg.END`` at some step (for a batch of one this is "stop at the first
        END", as before; larger batches previously raised on the ambiguous
        tensor comparison). Sequences that finish early keep being decoded
        until the whole batch is done, so their tail after END is not
        meaningful. The forward passes run under ``torch.no_grad()``.

        Args:
            en_input: (batch_size, seq_len) source token indices.
            enc_valid_lens: (batch_size,) valid source lengths.

        Returns:
            outputs: (batch_size, pred_len) predicted token indices.
            attn_weights: ``[]`` for plain decoders; a
                (batch_size, pred_len, num_k) tensor when ``self.with_attn``;
                a list ``[encoder, decoder-self, decoder-inter]`` of attention
                tensors when ``self.is_transformer``.
        """
        self.encoder.eval()
        self.decoder.eval()

        outputs, attn_weights = [], []
        with torch.no_grad():
            if self.is_transformer is True:
                attn_weights = [[], [], []]
                state = self.encoder(en_input, enc_valid_lens)
                attn_weights[0] = self.encoder.attn_weights
                attn_weights[0] = torch.cat(attn_weights[0], dim=0).reshape(cfg.num_layers, cfg.num_heads, -1, cfg.seq_len)
                # attn_weights[0]: (num_layers, num_heads, seq_len, seq_len)
            else:
                state = self.encoder(en_input)
            hidden = self.decoder.init_state(state, enc_valid_lens)
            batch_size = en_input.shape[0]
            dec_input = torch.tensor([cfg.START]*batch_size, device=en_input.device).reshape(-1,1)
            # Tracks which sequences have already emitted END.
            finished = torch.zeros(batch_size, dtype=torch.bool, device=en_input.device)

            pred_len = 0
            for _ in range(cfg.seq_len):
                # dec_input: (batch_size, 1)
                dec_input, hidden = self.decoder(dec_input, hidden)
                # dec_input: (batch_size, 1, vocab_size)
                dec_input = torch.argmax(dec_input, dim=2)
                # dec_input: (batch_size, 1)
                outputs.append(dec_input)
                if self.with_attn == True:
                    # decoder.attn_weights: (batch_size, 1, num_k)
                    attn_weights.append(self.decoder.attn_weights)
                if self.is_transformer is True:
                    attn_weights[1].append(self.decoder.attn_weights[0])
                    attn_weights[2].append(self.decoder.attn_weights[1])
                pred_len += 1
                finished |= (dec_input == cfg.END).reshape(-1)
                if finished.all():
                    break

        if self.with_attn == True:
            attn_weights = torch.cat(attn_weights, dim=1)
            # attn_weights: (batch_size, pred_len, num_k)

        if self.is_transformer is True:
            # Keep the last query row of every head at every step; rows of
            # different lengths are zero-padded via the DataFrame round trip.
            attn_weights[1] = [head[-1].tolist() 
                                for step in attn_weights[1]
                                for block in step
                                for head in block]
            attn_weights[1] = torch.tensor(pd.DataFrame(attn_weights[1]).fillna(0.0).values).reshape(-1, cfg.num_layers, cfg.num_heads, pred_len)
            attn_weights[1] = attn_weights[1].permute(1, 2, 0, 3)
            # attn_weights[1]: (num_layers, num_heads, pred_len, pred_len)
            attn_weights[2] = [head[-1].tolist() 
                                for step in attn_weights[2]
                                for block in step
                                for head in block]
            attn_weights[2] = torch.tensor(pd.DataFrame(attn_weights[2]).fillna(0.0).values).reshape(-1, cfg.num_layers, cfg.num_heads, cfg.seq_len)
            attn_weights[2] = attn_weights[2].permute(1, 2, 0, 3)
            # attn_weights[2]: (num_layers, num_heads, pred_len, num_k)

        outputs = torch.cat(outputs, dim=1)
        return outputs, attn_weights
        
        
    def test(self, show_case=3):
        """Evaluate the saved model on ``cfg.test_dataloader`` and report BLEU-1..4.

        The encoder/decoder weights are reloaded from the checkpoint files
        written by ``train`` (mapped onto ``device`` so that a checkpoint saved
        on a GPU also loads on a CPU-only machine). Each test sample (the
        loader uses batch size 1) is decoded greedily and scored at character
        level.

        Args:
            show_case: number of samples to evaluate and print (with attention
                heat maps when available); 0 evaluates the whole test set
                without printing individual cases.

        Returns:
            List of the four average scores (k = 1..4); zeros when no sample
            was evaluated.
        """
        self.encoder.load_state_dict(torch.load(cfg.enc_model_path + '-' + self.save_path + '.pth', map_location=device))
        self.decoder.load_state_dict(torch.load(cfg.dec_model_path + '-' + self.save_path + '.pth', map_location=device))
        self.encoder.eval()
        self.decoder.eval()

        avg_bleu = [0, 0, 0, 0]
        num_case = 0
        # One title per attention head (heat-map columns).
        head_titles = ['Head %d' % h for h in range(1, cfg.num_heads + 1)]

        for i, (x, y) in enumerate(cfg.test_dataloader):
            if show_case != 0 and i == show_case:
                break

            x, y = x.to(device), y.to(device)
            x_valid_len = torch.sum(x!=cfg.PAD, dim=-1)
            output, attn_weights = self.predict(x, x_valid_len)

            y_valid_len = torch.sum(y!=cfg.PAD, dim=-1)
            x, y, output = x[0], y[0], output[0]
            x, y = x[:x_valid_len], y[:y_valid_len]

            pred_len = len(output)

            origin_x = ' '.join([cfg.eng_vocab['idx_to_token'][_] for _ in x])
            origin_y = ''.join([cfg.cmn_vocab['idx_to_token'][_] for _ in y])
            pred = ''.join([cfg.cmn_vocab['idx_to_token'][_] for _ in output])

            # Score at character level: put a space between all characters,
            # then glue the special tokens back together.
            origin_y = (' '.join(origin_y)).replace('< U N K >', '<UNK>').replace('< E N D >', '<END>')
            pred = ' '.join(pred).replace('< U N K >', '<UNK>').replace('< E N D >', '<END>')

            # NOTE(review): bleu is declared as bleu(orig_seq, pred_seq, k) but is
            # called here with the prediction first; kept as is - confirm the
            # intended argument order together with bleu's normalisation.
            bleu_score = [0, 0, 0, 0]
            for n in range(4):
                bleu_score[n] = bleu(pred,origin_y,k=n+1)
                avg_bleu[n] += bleu_score[n]
            num_case += 1

            if num_case <= show_case:
                print(f"origin [eng]: {origin_x}")
                print(f"origin [cmn]: {origin_y}") 
                print(f"predict[cmn]: {pred}")
                print(f"BLEU: (1){bleu_score[0]} (2){bleu_score[1]} (3){bleu_score[2]} (4){bleu_score[3]}\n")

                if self.with_attn is True:
                    attn_weights = attn_weights.cpu().detach().numpy()
                    attn_df = pd.DataFrame(attn_weights[0][:, :x_valid_len[0]])
                    ax = sns.heatmap(attn_df, annot=True,  cmap="YlGnBu")
                    plt.show()

                if self.is_transformer is True:
                    attn_weights = [attn.cpu().detach().numpy() for attn in attn_weights]

                    attn_enc = attn_weights[0][:, :, :x_valid_len[0], :x_valid_len[0]]
                    show_heatmaps(attn_enc, xlabel="Key positions", ylabel="Query positions", 
                                titles=head_titles, suptitle='Transformer Encoder Attention Weights',
                                figsize=(10,5), cmap="YlGnBu")
                    attn_dec_self = attn_weights[1][:, :, :pred_len, :pred_len]
                    show_heatmaps(attn_dec_self, xlabel="Key positions", ylabel="Query positions",
                                titles=head_titles, suptitle='Transformer Decoder Self Attention Weights',
                                figsize=(10,5), cmap="YlGnBu")
                    attn_dec_inter = attn_weights[2][:, :, :pred_len, :x_valid_len[0]]
                    show_heatmaps(attn_dec_inter, xlabel="Key positions", ylabel="Query positions",
                                titles=head_titles, suptitle='Transformer Decoder Inter Attention Weights',
                                figsize=(10,5), cmap="YlGnBu")

        # Guard against an empty test loader (previously a ZeroDivisionError).
        if num_case:
            for n in range(4): 
                avg_bleu[n] /= num_case
        print(f'Test for {num_case} cases, BLEU: (1){avg_bleu[0]} (2){avg_bleu[1]} (3){avg_bleu[2]} (4){avg_bleu[3]}')

        return avg_bleu
    

    def online_predict(self, en_input, with_attn=False, is_transformer=False):
        """Translate one raw English sentence with the saved model and print the result.

        The checkpoint is reloaded on every call (mapped onto ``device`` so
        that a checkpoint saved on a GPU also loads on a CPU-only machine).
        The sentence is cleaned with ``eng_transform``, encoded with
        ``line_to_idx`` and decoded greedily.

        Args:
            en_input: raw English sentence (str).
            with_attn: plot the decoder attention as a heat map; only valid if
                the solver was built with ``with_attn=True``, since the weights
                returned by ``predict`` depend on the solver's own flag.
            is_transformer: plot the transformer attention maps; only valid if
                the solver was built with ``is_transformer=True``.
        """
        self.encoder.load_state_dict(torch.load(cfg.enc_model_path + '-' + self.save_path + '.pth', map_location=device))
        self.decoder.load_state_dict(torch.load(cfg.dec_model_path + '-' + self.save_path + '.pth', map_location=device))
        self.encoder.eval()
        self.decoder.eval()

        # process input
        en_input = eng_transform(en_input)

        print(f'Input: {en_input}')

        en_input = line_to_idx(en_input.split(), lang='eng')
        en_input = torch.tensor(en_input, dtype=torch.long).reshape(1, -1).to(device)
        input_len = torch.sum(en_input!=cfg.PAD, dim=-1)

        output, attn_weights = self.predict(en_input, input_len)

        output = output[0]
        pred_len = len(output)
        pred = ''.join([cfg.cmn_vocab['idx_to_token'][_] for _ in output])
        # Print character by character, gluing the special tokens back together.
        pred = ' '.join(pred).replace('< U N K >', '<UNK>').replace('< E N D >', '<END>')
        print(f'Pred: {pred}')

        if with_attn is True:
            attn_weights = attn_weights.cpu().detach().numpy()
            attn_df = pd.DataFrame(attn_weights[0][:, :input_len])
            ax = sns.heatmap(attn_df, annot=True,  cmap="YlGnBu")
            plt.show()

        if is_transformer is True:
            attn_weights = [attn.cpu().detach().numpy() for attn in attn_weights]
            # One title per attention head (heat-map columns).
            head_titles = ['Head %d' % h for h in range(1, cfg.num_heads + 1)]

            attn_enc = attn_weights[0][:, :, :input_len, :input_len]
            show_heatmaps(attn_enc, xlabel="Key positions", ylabel="Query positions", 
                        titles=head_titles, suptitle='Transformer Encoder Attention Weights',
                        figsize=(10,5), cmap="YlGnBu")
            attn_dec_self = attn_weights[1][:, :, :pred_len, :pred_len]
            show_heatmaps(attn_dec_self, xlabel="Key positions", ylabel="Query positions",
                        titles=head_titles, suptitle='Transformer Decoder Self Attention Weights',
                        figsize=(10,5), cmap="YlGnBu")
            attn_dec_inter = attn_weights[2][:, :, :pred_len, :input_len]
            show_heatmaps(attn_dec_inter, xlabel="Key positions", ylabel="Query positions",
                        titles=head_titles, suptitle='Transformer Decoder Inter Attention Weights',
                        figsize=(10,5), cmap="YlGnBu")

#### Training test

In [None]:
# %%time
# baseline_encoder = Seq2SeqEncoder(
#     vocab_size=cfg.eng_vocab_size,
#     embed_size=cfg.embed_size,
#     hidden_size=cfg.hidden_size,
#     num_layers=cfg.num_layers,
#     dropout=cfg.dropout,
#     withBidirectional=False,
# )
# baseline_decoder = Seq2SeqDecoder(
#     vocab_size=cfg.cmn_vocab_size,
#     embed_size=cfg.embed_size,
#     hidden_size=cfg.hidden_size,
#     num_layers=cfg.num_layers,
#     dropout=cfg.dropout,
# )
# solver_baseline = Seq2SeqSolver(baseline_encoder, baseline_decoder, save_path='baseline')
# solver_baseline.train(mini_train=True)

In [None]:
# %%time
# bidRNN_encoder = Seq2SeqEncoder(
#     vocab_size=cfg.eng_vocab_size,
#     embed_size=cfg.embed_size,
#     hidden_size=cfg.hidden_size,
#     num_layers=cfg.num_layers,
#     dropout=cfg.dropout,
#     withBidirectional=True,
# )
# bidRNN_decoder = Seq2SeqDecoder(
#     vocab_size=cfg.cmn_vocab_size,
#     embed_size=cfg.embed_size,
#     hidden_size=cfg.hidden_size,
#     num_layers=cfg.num_layers,
#     dropout=cfg.dropout,
# )
# solver_bidRNN = Seq2SeqSolver(bidRNN_encoder, bidRNN_decoder, save_path='bidRNN')
# solver_bidRNN.train(mini_train=True)

### 0x05加入Attention机制 - AdditiveAttention

#### AttnMaskedSoftmax

In [None]:
def sequence_mask(target, valid_len, value=0):
    """Return a copy of ``target`` with positions beyond each row's valid length set to ``value``.

    The input tensor is left untouched (it used to be overwritten in place,
    which also altered the caller's tensor and any tensor it was a view of).

    Args:
        target: (batch_size, seq_len)
        valid_len: (batch_size,) number of leading positions to keep per row.
        value: fill value for the masked positions.

    Returns:
        New tensor with the same shape and dtype as ``target``.
    """
    bool_mask = torch.arange((target.shape[1]), dtype=torch.float32, device=target.device)[None, :] < valid_len[:, None]
    masked = target.clone()
    masked[~bool_mask] = value
    return masked

# Demo: mask a random integer matrix with four different valid lengths.
target = torch.randint(1, 10, (4, 10))
valid_len = torch.tensor([10, 7, 5, 3])
mask_target = sequence_mask(target, valid_len)
mask_target

In [None]:

def attn_masked_softmax(X, valid_len):
    """Softmax over the last axis that ignores key positions beyond the valid length.

    Args:
        X: (batch_size, num_q, num_k) attention scores (num_k equals seq_len
            in this notebook).
        valid_len: None for a plain softmax; a 1-D tensor with one length per
            batch row (shared by all queries of that row); or a tensor with one
            length per (row, query), which is flattened.

    Returns:
        Tensor shaped like ``X`` holding the attention weights.
    """
    if valid_len is None:
        return nn.functional.softmax(X, dim=-1)

    shape = X.shape
    if valid_len.dim() == 1:
        # One length per batch row: repeat it for each of the num_q queries.
        flat_len = torch.repeat_interleave(valid_len, shape[1])
    else:
        flat_len = valid_len.reshape(-1)
    # Masked entries get a very large negative score so their softmax weight is ~0.
    flat_scores = X.reshape(-1, shape[-1])
    masked = sequence_mask(flat_scores, flat_len, value=-1e10)
    return nn.functional.softmax(masked.reshape(shape), dim=-1)
    
# Demo: with valid lengths 6 and 3, weights beyond those key positions come out as ~0.
x = torch.randn((2,4,10))
valid_len = torch.tensor([6,3])
attn_weight = attn_masked_softmax(x, valid_len)
print(attn_weight)

#### AdditiveAttention
$$
\begin{aligned}
a(\mathbf{q}, \mathbf{k}) &= \mathbf{w}_{v}^{\top} \tanh \left(\mathbf{W}_{q} \mathbf{q}+\mathbf{W}_{k} \mathbf{k}\right) \\
\operatorname{Attention}(\mathbf{q}) &= \sum_{i=1}^{m} \operatorname{softmax}\left(a(\mathbf{q}, \mathbf{k}_{i})\right) \mathbf{v}_{i}
\end{aligned}
$$

In [None]:

class AdditiveAttention(nn.Module):
    """Additive attention: score(q, k) = w_v(tanh(w_q(q) + w_k(k))).

    Args:
        k_size: feature size of the keys.
        q_size: feature size of the queries.
        hidden_size: size of the shared projection space.
        dropout: dropout applied to the attention weights.
    """
    def __init__(self, k_size, q_size, hidden_size, dropout=0.1):
        super().__init__()
        self.w_q = nn.Linear(q_size, hidden_size, bias=False)
        self.w_k = nn.Linear(k_size, hidden_size, bias=False)
        self.w_v = nn.Linear(hidden_size, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

        # Attention weights of the most recent forward pass.
        self.attn_weight = None

    def forward(self, Q, K, V, valid_len):
        """Attend over ``V`` with queries ``Q`` and keys ``K``.

        Args:
            Q: (batch_size, num_q, q_size)
            K: (batch_size, num_k, k_size)
            V: (batch_size, num_k, hidden_size)
            valid_len: valid key lengths, forwarded to ``attn_masked_softmax``.

        Returns:
            (batch_size, num_q, hidden_size) weighted sum of the values.
        """
        # Project, then broadcast-add every query against every key:
        # (batch, num_q, 1, hidden) + (batch, 1, num_k, hidden)
        queries = self.w_q(Q).unsqueeze(2)
        keys = self.w_k(K).unsqueeze(1)
        features = torch.tanh(queries + keys)  # (batch, num_q, num_k, hidden)
        scores = self.w_v(features).squeeze(-1)  # (batch, num_q, num_k)
        weights = attn_masked_softmax(scores, valid_len)  # (batch, num_q, num_k)
        self.attn_weight = weights
        return torch.bmm(self.dropout(weights), V)
    

# Smoke test: one query per batch row attending over ten keys with random valid lengths.
test_attn = AdditiveAttention(k_size=128, q_size=128, hidden_size=128)
Q = torch.ones((64, 1, 128))
K = torch.ones((64, 10, 128))
V = torch.ones((64, 10, 128))
valid_len = torch.randint(2,10,(64,))
output = test_attn(Q,K,V,valid_len)
# Expected: (batch_size, num_q, hidden_size)
output.shape

#### Seq2SeqAttentionDecoder

In [None]:

class Seq2SeqAttentionDecoder(nn.Module):
    '''GRU decoder that attends over the encoder outputs at every time step.

    At each step the top-layer hidden state is the query of an additive
    attention over the encoder outputs. The resulting context vector is
    concatenated with the current token embedding and fed to the GRU.
    The attention weights of the last forward pass are stored in
    ``self.attn_weights`` with shape (batch_size, seq_len, num_k).
    '''
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers, dropout=0.1):
        super(Seq2SeqAttentionDecoder, self).__init__()
        # Queries (decoder hidden state) and keys (encoder outputs) both have hidden_size features.
        self.attention = AdditiveAttention(k_size=hidden_size, q_size=hidden_size, hidden_size=hidden_size, dropout=dropout)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # The GRU input is the token embedding concatenated with the attention context.
        self.rnn = nn.GRU(embed_size+hidden_size, hidden_size, num_layers, dropout=dropout)
        self.fc = nn.Linear(hidden_size, vocab_size)

        self.attn_weights = None

    def init_state(self, enc_outputs, enc_valid_len):
        '''Build the decoder state from the encoder result.

        enc_outputs: tuple (enc_output, hidden)
            enc_output: (seq_len, batch_size, hidden_size)
            hidden: (num_layers, batch_size, hidden_size)
        returns: (enc_output as (batch_size, seq_len, hidden_size), hidden, enc_valid_len)
        '''
        enc_output, hidden = enc_outputs
        # Switch to batch-first so the encoder outputs can serve as attention keys/values.
        enc_output = enc_output.permute(1, 0, 2)
        return (enc_output, hidden, enc_valid_len)

    def forward(self, x, state):
        '''
        x: (batch_size, seq_len) token ids
        state: tuple as returned by init_state()
        returns: (outputs, state) with outputs of shape (batch_size, seq_len, vocab_size)
        '''
        enc_output, hidden, enc_valid_len = state
        K, V = enc_output, enc_output
        x = self.embedding(x).permute(1, 0, 2)  # (seq_len, batch_size, embed_size)
        seq_len = x.shape[0]

        outputs, attn_weights = [], []

        for i in range(seq_len):
            # Query = hidden state of the last GRU layer: (batch_size, 1, hidden_size)
            Q = torch.unsqueeze(hidden[-1], dim=1)
            attn_context = self.attention(Q, K, V, enc_valid_len)  # (batch_size, 1, hidden_size)
            x_context = torch.cat((torch.unsqueeze(x[i], dim=1), attn_context), dim=-1).permute(1, 0, 2)
            # x_context: (1, batch_size, embed_size+hidden_size)
            dec_output, hidden = self.rnn(x_context, hidden)
            # dec_output: (1, batch_size, hidden_size)
            # hidden: (num_layers, batch_size, hidden_size)
            outputs.append(dec_output)
            # Per-step weights: (batch_size, 1, num_k)
            attn_weights.append(self.attention.attn_weight)

        # torch.cat (not the newer torch.concat alias) keeps this consistent with
        # the rest of the notebook and works on older PyTorch releases.
        self.attn_weights = torch.cat(attn_weights, dim=1)
        # self.attn_weights: (batch_size, seq_len, num_k)
        outputs = torch.cat(outputs, dim=0)
        outputs = self.fc(outputs).permute(1, 0, 2)
        # outputs: (batch_size, seq_len, vocab_size)
        state = (enc_output, hidden, enc_valid_len)
        return outputs, state
        
    
# Smoke test of the attention decoder on random token ids.
attn_decoder_tester = Seq2SeqAttentionDecoder(vocab_size=1024, embed_size=128, hidden_size=128, num_layers=2, dropout=0.1)
x_test = torch.randint(1, 100, (64, 10))
x_valid_len = torch.randint(2, 10, (64, ))

# encoder_tester is created in an earlier cell.
enc_outputs = encoder_tester(x_test)
state = attn_decoder_tester.init_state(enc_outputs, x_valid_len)

output, state = attn_decoder_tester(x_test, state)
# Printed explicitly: a bare expression in the middle of a cell is never displayed by Jupyter.
print(output.shape, len(state), state[0].shape, state[1].shape, state[2].shape)
# (batch_size, seq_len, vocab_size), 3, (batch_size, seq_len, hidden_size), (num_layers, batch_size, hidden_size), (batch_size, )

# show attn_weights
attn_weights = attn_decoder_tester.attn_weights.detach().numpy()
print(attn_weights[0].shape)

# Keep only the first x_valid_len[0] key columns of the first sample.
attn_df = pd.DataFrame(attn_weights[0][:, :x_valid_len[0]])

ax = sns.heatmap(attn_df, annot=True,  cmap="YlGnBu")
plt.show()

#### Training test

In [None]:
# %%time
# attention_encoder = Seq2SeqEncoder(
#     vocab_size=cfg.eng_vocab_size,
#     embed_size=cfg.embed_size,
#     hidden_size=cfg.hidden_size,
#     num_layers=cfg.num_layers,
#     dropout=cfg.dropout,
#     withBidirectional=False,
# )
# attention_decoder = Seq2SeqAttentionDecoder(
#     vocab_size=cfg.cmn_vocab_size,
#     embed_size=cfg.embed_size,
#     hidden_size=cfg.hidden_size,
#     num_layers=cfg.num_layers,
#     dropout=cfg.dropout,
# )
# solver_attention = Seq2SeqSolver(attention_encoder, attention_decoder, with_attn=True, save_path='attention')
# solver_attention.train(mini_train=True)

### 0x06Transformer

#### Scaled Dot-Product Attention
$$
Attention(Q,K,V) = \text{dropout}\bigg(\operatorname{softmax}\left(\frac{\mathbf{Q K}^{\top}}{\sqrt{d}}\right)\bigg) \mathbf{V}
$$

In [None]:
class ScaledDotProductAttention(nn.Module):
    '''Scaled dot-product attention: dropout(softmax(Q K^T / sqrt(d))) V.

    The (undropped) attention weights of the latest call are stored in
    ``self.attn_weights``.
    '''
    def __init__(self, dropout) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, Q, K, V, valid_lens=None):
        '''
        Q: (batch_size, num_query, d)
        K: (batch_size, num_key, d)
        V: (batch_size, num_key, v)
        valid_lens: None or (batch_size, ) or (batch_size, num_query)
        returns: (batch_size, num_query, v)
        '''
        # Scale by the square root of the query feature size.
        scale = np.sqrt(Q.shape[-1])
        # scores: (batch_size, num_query, num_key)
        scores = torch.bmm(Q, K.transpose(-2, -1)) / scale
        weights = attn_masked_softmax(scores, valid_lens)
        self.attn_weights = weights
        # Weighted sum of the values; dropout is applied to the weights only.
        return torch.bmm(self.dropout(weights), V)

# NOTE(review): `queries` is not used anywhere else in this cell — confirm whether it can be dropped.
queries = torch.normal(0, 1, (2, 1, 2))
test_attn = ScaledDotProductAttention(dropout=0.5)
Q = torch.ones((64, 10, 128))
K = torch.ones((64, 10, 128))
V = torch.ones((64, 10, 128))
# One random length in [2, 10) per batch element.
valid_len = torch.randint(2,10,(64,))
# eval() disables dropout so the check is deterministic.
test_attn.eval()
output = test_attn(Q,K,V,valid_len)
output.shape

#### Multi-Head Attention
$$
\begin{aligned}
Y_i &= \text{dropout}\bigg(\operatorname{softmax}\bigg(\frac{(XQ_i)(XK_i)^\top}{\sqrt{d/h}}\bigg)\bigg)(XV_i) \\
Y &= [Y_1;\dots;Y_h]W_o
\end{aligned}
$$

In [None]:

def transpose_QKV(x, num_heads):
    '''Split the feature axis into heads and fold the heads into the batch axis.

    x: (batch_size, num_query or num_key, hidden_size)
    returns: (batch_size * num_heads, num_query or num_key, hidden_size / num_heads)

    Rows are batch-major: output row ``b * num_heads + h`` is head ``h`` of sample ``b``.
    '''
    batch, steps = x.shape[0], x.shape[1]
    # (batch, steps, heads, head_dim) -> (batch, heads, steps, head_dim)
    per_head = x.reshape(batch, steps, num_heads, -1).permute(0, 2, 1, 3)
    return per_head.reshape(-1, steps, per_head.shape[3])

def transpose_QKV_inv(x, num_heads):
    '''Undo transpose_QKV: unfold the heads from the batch axis and concatenate them.

    x: (batch_size * num_heads, num_query or num_key, hidden_size / num_heads)
    returns: (batch_size, num_query or num_key, hidden_size)
    '''
    steps, head_dim = x.shape[1], x.shape[2]
    # (batch, heads, steps, head_dim) -> (batch, steps, heads, head_dim)
    merged = x.reshape(-1, num_heads, steps, head_dim).permute(0, 2, 1, 3)
    # Flatten (heads, head_dim) back into a single feature axis.
    return merged.reshape(merged.shape[0], steps, -1)


class MultiHeadAttention(nn.Module):
    '''Multi-head attention built on ScaledDotProductAttention.

    Queries, keys and values are projected to ``hidden_size`` features, split
    into ``num_heads`` heads that are folded into the batch axis, attended
    independently, concatenated again and projected by ``W_o``.
    '''
    def __init__(self, query_size, key_size, value_size, hidden_size, num_heads, dropout, bias=False) -> None:
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.attention = ScaledDotProductAttention(dropout=dropout)
        # Projections of the query, key and value, plus the output projection.
        self.W_q = nn.Linear(query_size, hidden_size, bias=bias)
        self.W_k = nn.Linear(key_size, hidden_size, bias=bias)
        self.W_v = nn.Linear(value_size, hidden_size, bias=bias)
        self.W_o = nn.Linear(hidden_size, hidden_size, bias=bias)

    @staticmethod
    def _expand_valid_lens(valid_lens, num_heads):
        '''Repeat each sample's valid length(s) once per head, batch-major.

        transpose_QKV orders the folded batch axis as ``b * num_heads + h``,
        so the entry of sample ``b`` must occupy ``num_heads`` consecutive
        rows. ``Tensor.repeat`` would tile the whole tensor instead
        (``h * batch_size + b``) and pair heads with the lengths of other
        samples. Works for both (batch_size, ) and (batch_size, num_query).
        '''
        if valid_lens is None:
            return None
        return torch.repeat_interleave(valid_lens, repeats=num_heads, dim=0)

    def forward(self, Q, K, V, valid_lens=None):
        '''
        Q: (batch_size, num_query, query_size)
        K: (batch_size, num_key, key_size)
        V: (batch_size, num_key, value_size)
        valid_lens: None or (batch_size, ) or (batch_size, num_query)
        returns: (batch_size, num_query, hidden_size)
        '''
        # Project, then fold the heads into the batch axis:
        # (batch_size * num_heads, num_query or num_key, hidden_size / num_heads)
        Q = transpose_QKV(self.W_q(Q), self.num_heads)
        K = transpose_QKV(self.W_k(K), self.num_heads)
        V = transpose_QKV(self.W_v(V), self.num_heads)

        # Align the mask lengths with the folded (batch, head) rows.
        valid_lens = self._expand_valid_lens(valid_lens, self.num_heads)

        # All heads are attended in a single batched call.
        out = self.attention(Q, K, V, valid_lens)
        # Concatenate the heads back: (batch_size, num_query, hidden_size)
        out_concat = transpose_QKV_inv(out, self.num_heads)
        # Output projection.
        out = self.W_o(out_concat)
        return out

# Smoke test: one query per sample over 10 key/value positions, 4 heads.
# Note that this rebinds the notebook-level names hidden_size and num_heads.
hidden_size, num_heads = 128, 4
test_attn = MultiHeadAttention(query_size=hidden_size, key_size=hidden_size, value_size=hidden_size, hidden_size=hidden_size, num_heads=num_heads, dropout=0.5)
# eval() disables dropout so the check is deterministic.
test_attn.eval()
Q = torch.ones((64, 1, hidden_size))
K = torch.ones((64, 10, hidden_size))
V = torch.ones((64, 10, hidden_size))
valid_len = torch.randint(2,10,(64,))
test_attn(Q, K, V, valid_len).shape


#### PositionWiseFFN
$$
\operatorname{FFN}(x)=\max \left(0, x W_{1}+b_{1}\right) W_{2}+b_{2}
$$

In [None]:
class PositionWiseFFN(nn.Module):
    '''Two-layer feed-forward network applied along the last axis: fc2(relu(fc1(x))).'''
    def __init__(self, input_size, hidden_size, output_size) -> None:
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        '''x: (..., input_size) -> (..., output_size)'''
        hidden = self.fc1(x)
        activated = self.relu(hidden)
        return self.fc2(activated)

#### AddNorm

In [None]:
class AddNorm(nn.Module):
    '''Residual connection followed by layer normalization: LayerNorm(dropout(y) + x).'''
    def __init__(self, normalized_shape, dropout) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape=normalized_shape)

    def forward(self, x, y):
        '''
        x: residual input, y: sub-layer output.
        x.shape = y.shape = (batch_size, seq_len, hidden_size)
        '''
        # Dropout is applied to the sub-layer output only, never to the residual.
        residual_sum = self.dropout(y) + x
        return self.ln(residual_sum)

# Smoke test: AddNorm must keep the input shape.
test_addnorm = AddNorm(normalized_shape=128, dropout=0.5)
# eval() disables dropout.
test_addnorm.eval()
x = torch.ones((64, 1, 128))
y = torch.ones((64, 1, 128))
test_addnorm(x, y).shape

#### PositionalEncoding
$$
\begin{aligned}
P E_{(p o s, 2 i)} &=\sin \left(p o s / 10000^{2 i / d_{\text {model }}}\right) \\
P E_{(p o s, 2 i+1)} &=\cos \left(p o s / 10000^{2 i / d_{\text {model }}}\right)
\end{aligned}
$$

In [None]:
class PositionalEncoding(nn.Module):
    '''Fixed sinusoidal positional encoding added to the input, followed by dropout.

    Even feature indices hold sin(pos / 10000^(2i/d)) and odd indices hold
    cos of the same angle. The table is registered as the buffer ``pe`` with
    shape (1, max_len, embed_dim).
    '''
    def __init__(self, embed_dim, dropout=0.1, max_len=100) -> None:
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        assert embed_dim % 2 == 0, "Embedding dimension must be even."
        # One denominator per (sin, cos) feature pair; the sine at index 2i and
        # the cosine at index 2i+1 share the same angle.
        denominators = [10000 ** (i / embed_dim) for i in range(0, embed_dim, 2)]
        flat_angles = [pos / denom for pos in range(max_len) for denom in denominators]
        angles = torch.tensor(flat_angles)

        table = torch.zeros(1, max_len, embed_dim)
        table[0, :, 0::2] = torch.sin(angles).reshape(max_len, -1)
        table[0, :, 1::2] = torch.cos(angles).reshape(max_len, -1)
        self.register_buffer('pe', table)

    def forward(self, x):
        '''
        x: (batch_size, seq_len, embed_dim)
        '''
        seq_len, feat_dim = x.shape[1], x.shape[2]
        # Slice the table down to the input's length and width; broadcasts over the batch.
        x = x + self.pe[:, :seq_len, :feat_dim]
        return self.dropout(x)

# Fixed seed so both the random input and the dropout mask are reproducible.
torch.manual_seed(231)
x = torch.randn(1, 2, 6)

test_pe = PositionalEncoding(embed_dim=6, dropout=0.1)
# test_pe.eval() # eval() will disable dropout
# The module is left in training mode, so dropout is active in this call.
output = test_pe(x)
print(output)

# NOTE(review): the -0.0000 entry looks like an element zeroed by dropout, so these
# reference values presumably depend on the RNG stream for seed 231 — confirm that
# they still match on the PyTorch version in use.
expected_pe_output = np.asarray([[[-1.2340,  1.1127,  1.6978, -0.0865, -0.0000,  1.2728],
                                  [ 0.9028, -0.4781,  0.5535,  0.8133,  1.2644,  1.7034]]])

def rel_error(x, y):
    """Return the largest element-wise relative error |x - y| / max(1e-8, |x| + |y|)."""
    abs_diff = np.abs(x - y)
    # The 1e-8 floor prevents division by zero when both entries are 0.
    magnitude = np.maximum(1e-8, np.abs(x) + np.abs(y))
    return np.max(abs_diff / magnitude)

# The reference values are rounded to 4 decimals, so a small non-zero error is expected.
print('pe_output error: ', rel_error(expected_pe_output, output.detach().numpy()))


#### Encoder

In [None]:
class EncoderBlock(nn.Module):
    '''One Transformer encoder layer: self-attention and a position-wise FFN,
    each wrapped in a residual connection with layer normalization.'''
    def __init__(self, Q_size, K_size, V_size, hidden_size, norm_shape,
                ffn_input_size, ffn_hidden_size, num_heads, dropout, use_bias=False) -> None:
        super().__init__()
        self.attention = MultiHeadAttention(
            query_size=Q_size,
            key_size=K_size,
            value_size=V_size,
            hidden_size=hidden_size,
            num_heads=num_heads,
            dropout=dropout,
            bias=use_bias,
        )
        self.addnorm1 = AddNorm(normalized_shape=norm_shape, dropout=dropout)
        self.ffn = PositionWiseFFN(
            input_size=ffn_input_size,
            hidden_size=ffn_hidden_size,
            output_size=hidden_size,
        )
        self.addnorm2 = AddNorm(normalized_shape=norm_shape, dropout=dropout)

    def forward(self, x, valid_lens):
        '''
        x: (batch_size, seq_len, hidden_size)
        valid_lens: (batch_size, )
        returns: tensor with the same shape as x
        '''
        # Self-attention: x serves as query, key and value.
        attended = self.attention(x, x, x, valid_lens)
        out = self.addnorm1(x, attended)
        # Position-wise feed-forward network with its own residual + norm.
        transformed = self.ffn(out)
        return self.addnorm2(out, transformed)

# Smoke test for a single encoder block; rebinds the notebook-level hidden_size and num_heads.
hidden_size, num_heads = 1024, 2
x = torch.ones((64, 10, hidden_size))
# norm_shape=[10, hidden_size] makes LayerNorm normalize over both the sequence and feature axes here.
test_enc_block = \
    EncoderBlock(Q_size=hidden_size, K_size=hidden_size, V_size=hidden_size, hidden_size=hidden_size, norm_shape=[10, hidden_size],
                 ffn_input_size=hidden_size, ffn_hidden_size=hidden_size*2, num_heads=num_heads, dropout=0.5)
# eval() disables dropout.
test_enc_block.eval()
valid_len = torch.randint(2,10,(64,))
test_enc_block(x, valid_len).shape
# 输入输出维度不变

In [None]:
class TransformerEncoder(nn.Module):
    '''Transformer encoder: scaled token embedding + positional encoding,
    followed by ``num_layers`` EncoderBlocks.

    After each forward pass ``self.attn_weights[i]`` holds the self-attention
    weights of block ``i``.
    '''
    def __init__(self, vocab_size, num_layers, Q_size, K_size, V_size, hidden_size, norm_shape,
                ffn_input_size, ffn_hidden_size, num_heads, dropout, use_bias=False) -> None:
        super(TransformerEncoder, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.pos_encoding = PositionalEncoding(hidden_size, dropout)
        # Blocks are registered by name in an nn.Sequential, which keeps the
        # parameter names 'encoder_blocks.encoder_block_<i>.*'. forward() iterates
        # over it manually because every block needs valid_lens as a second argument.
        self.encoder_blocks = nn.Sequential()
        self.num_layers = num_layers
        for i in range(num_layers):
            self.encoder_blocks.add_module(
                f'encoder_block_{i}',
                EncoderBlock(Q_size=Q_size, K_size=K_size, V_size=V_size, hidden_size=hidden_size,
                            norm_shape=norm_shape, ffn_input_size=ffn_input_size, ffn_hidden_size=ffn_hidden_size,
                            num_heads=num_heads, dropout=dropout, use_bias=use_bias))

    def forward(self, x, valid_lens):
        '''
        x: (batch_size, seq_len) token ids
        valid_lens: (batch_size, )
        returns: (batch_size, seq_len, hidden_size)
        '''
        # Imported locally so the class does not rely on another notebook cell
        # having imported `math` beforehand.
        import math

        # Scale the embeddings by sqrt(hidden_size) before adding the positional encoding.
        x = self.embedding(x) * math.sqrt(self.hidden_size)
        x = self.pos_encoding(x)
        self.attn_weights = [None] * self.num_layers
        for i, block in enumerate(self.encoder_blocks):
            x = block(x, valid_lens)
            # block.attention is the MultiHeadAttention; its inner .attention
            # stores the weights of the latest call.
            self.attn_weights[i] = block.attention.attention.attn_weights

        return x

# Smoke test for the full encoder; rebinds the notebook-level vocab_size, num_layers and hidden_size.
vocab_size, num_layers = 200, 2
hidden_size = 128
# Random token ids of shape (batch_size=64, seq_len=10).
x = torch.randint(0, vocab_size, (64, 10))
print(x.shape)
test_enc = TransformerEncoder(vocab_size=vocab_size, num_layers=num_layers, Q_size=hidden_size, K_size=hidden_size, V_size=hidden_size, hidden_size=hidden_size, norm_shape=[10, hidden_size],
                            ffn_input_size=hidden_size, ffn_hidden_size=hidden_size*2, num_heads=8, dropout=0.5)
# eval() disables dropout.
test_enc.eval()
valid_len = torch.randint(2,10,(64,))
test_enc(x, valid_len).shape

#### Decoder

In [None]:
class DecoderBlock(nn.Module):
    '''One Transformer decoder layer: decoder self-attention, encoder-decoder
    attention and a position-wise FFN, each followed by AddNorm.'''
    def __init__(self, Q_size, K_size, V_size, hidden_size, norm_shape,
                ffn_input_size, ffn_hidden_size, num_heads, dropout, block_idx) -> None:
        super().__init__()
        self.block_idx = block_idx
        # Decoder self-attention.
        self.attention1 = MultiHeadAttention(
            query_size=Q_size, key_size=K_size, value_size=V_size,
            hidden_size=hidden_size, num_heads=num_heads, dropout=dropout)
        self.addnorm1 = AddNorm(normalized_shape=norm_shape, dropout=dropout)
        # Encoder-decoder attention.
        self.attention2 = MultiHeadAttention(
            query_size=Q_size, key_size=K_size, value_size=V_size,
            hidden_size=hidden_size, num_heads=num_heads, dropout=dropout)
        self.addnorm2 = AddNorm(normalized_shape=norm_shape, dropout=dropout)
        self.ffn = PositionWiseFFN(
            input_size=ffn_input_size, hidden_size=ffn_hidden_size, output_size=hidden_size)
        self.addnorm3 = AddNorm(normalized_shape=norm_shape, dropout=dropout)

    def forward(self, x, state):
        '''
        x: (batch_size, seq_len, hidden_size)
        state: tuple (enc_out of shape (batch_size, seq_len, hidden_size),
                      enc_valid_lens of shape (batch_size, ))
        returns: (output with the same shape as x, state unchanged)
        '''
        dec_valid_lens = None
        if self.training:
            # In training mode, query position t gets valid length t + 1,
            # giving a (batch_size, seq_len) table of per-query lengths.
            batch_size, seq_len = x.shape[0], x.shape[1]
            steps = torch.arange(1, seq_len + 1, device=x.device)
            dec_valid_lens = steps.repeat(batch_size, 1)

        # Decoder self-attention + add & norm.
        self_attended = self.attention1(x, x, x, dec_valid_lens)
        out = self.addnorm1(x, self_attended)
        # Encoder-decoder attention + add & norm: queries come from the decoder,
        # keys and values from the encoder output.
        enc_out, enc_valid_lens = state[0], state[1]
        cross_attended = self.attention2(out, enc_out, enc_out, enc_valid_lens)
        out = self.addnorm2(out, cross_attended)
        # Position-wise feed-forward network + add & norm.
        out = self.addnorm3(out, self.ffn(out))
        return out, state


# Smoke test for a decoder block fed by an encoder block.
# Rebinds the notebook-level hidden_size, num_heads, x and test_enc_block.
hidden_size, num_heads = 128, 2
x = torch.ones((64, 10, hidden_size))
test_enc_block = \
    EncoderBlock(Q_size=hidden_size, K_size=hidden_size, V_size=hidden_size, hidden_size=hidden_size, norm_shape=[10, hidden_size],
                 ffn_input_size=hidden_size, ffn_hidden_size=hidden_size*2, num_heads=num_heads, dropout=0.5)
test_dec_block = \
    DecoderBlock(Q_size=hidden_size, K_size=hidden_size, V_size=hidden_size, hidden_size=hidden_size, norm_shape=[10, hidden_size],
                 ffn_input_size=hidden_size, ffn_hidden_size=hidden_size*2, num_heads=num_heads, dropout=0.5, block_idx=0)
# Both blocks stay in training mode here, so dropout is active and the
# decoder block builds its per-position valid lengths.
# test_dec_block.eval() # self.training = False

valid_len = torch.randint(2,10,(64,))
# Decoder state = (encoder output, encoder valid lengths).
state = (test_enc_block(x, valid_len), valid_len)
output = test_dec_block(x, state)
# output is the tuple (decoder output, state); show the decoder output's shape.
output[0].shape

In [None]:
class TransformerDecoder(nn.Module):
    '''Transformer decoder: scaled token embedding + positional encoding,
    ``num_layers`` DecoderBlocks and a final projection onto the vocabulary.

    In eval mode the decoder keeps every token it has been fed since the last
    ``init_state`` call in ``self.preX``, re-runs the whole prefix on each
    call and returns only the logits of the newest position.

    After a forward pass ``self.attn_weights[0][i]`` holds the decoder
    self-attention weights of block ``i`` and ``self.attn_weights[1][i]`` its
    encoder-decoder attention weights.
    '''
    def __init__(self, vocab_size, num_layers, Q_size, K_size, V_size, hidden_size, norm_shape,
                ffn_input_size, ffn_hidden_size, num_heads, dropout) -> None:
        super(TransformerDecoder, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.pos_encoding = PositionalEncoding(hidden_size, dropout)
        self.decoder_blocks = nn.Sequential()
        for i in range(num_layers):
            self.decoder_blocks.add_module(
                f'decoder_block_{i}',
                DecoderBlock(Q_size=Q_size, K_size=K_size, V_size=V_size, hidden_size=hidden_size, norm_shape=norm_shape,
                            ffn_input_size=ffn_input_size, ffn_hidden_size=ffn_hidden_size, num_heads=num_heads, dropout=dropout, block_idx=i))
        self.fc = nn.Linear(hidden_size, vocab_size)
        # Defined here so an eval-mode forward() before init_state() does not
        # fail with AttributeError; init_state() resets it for each new sequence.
        self.preX = None

    def init_state(self, enc_out, enc_valid_lens):
        '''Reset the cached decoding prefix and build the decoder state.

        enc_out: (batch_size, seq_len, hidden_size)
        enc_valid_lens: (batch_size, )
        returns: (enc_out, enc_valid_lens)
        '''
        self.preX = None
        return enc_out, enc_valid_lens

    def forward(self, x, state):
        '''
        x: (batch_size, seq_len) token ids
        state: tuple as returned by init_state()
        returns: (logits, state). logits is (batch_size, seq_len, vocab_size) in
                 training mode and (batch_size, 1, vocab_size) in eval mode.
        '''
        # Imported locally so the class does not rely on another notebook cell
        # having imported `math` beforehand.
        import math

        if not self.training:
            # Append the new token(s) to the cached prefix and decode the whole prefix again.
            self.preX = x if self.preX is None else torch.cat((self.preX, x), dim=1)
            x = self.preX

        # Scale the embeddings by sqrt(hidden_size), then add the positional encoding.
        x = self.embedding(x) * math.sqrt(self.hidden_size)
        x = self.pos_encoding(x)
        self.attn_weights = [[None] * self.num_layers for _ in range(2)]
        for i, dec_block in enumerate(self.decoder_blocks):
            x, state = dec_block(x, state)
            # decoder self-attention weights
            self.attn_weights[0][i] = dec_block.attention1.attention.attn_weights
            # encoder-decoder attention weights
            self.attn_weights[1][i] = dec_block.attention2.attention.attn_weights
        if not self.training:
            # Only the newest position is needed to predict the next token.
            x = x[:, -1:, :]

        return self.fc(x), state
        

#### Training test

In [None]:
# %%time

# H_size = cfg.transformer_hidden_size

# transformer_encoder = \
#     TransformerEncoder(vocab_size=cfg.eng_vocab_size, num_layers=cfg.num_layers, 
#                         Q_size=H_size, K_size=H_size, V_size=H_size, hidden_size=H_size, 
#                         norm_shape=[H_size], ffn_input_size=H_size, ffn_hidden_size=H_size*2, 
#                         num_heads=cfg.num_heads, dropout=0.1)
# transformer_decoder = \
#     TransformerDecoder(vocab_size=cfg.cmn_vocab_size, num_layers=cfg.num_layers, 
#                         Q_size=H_size, K_size=H_size, V_size=H_size, hidden_size=H_size, 
#                         norm_shape=[H_size], ffn_input_size=H_size, ffn_hidden_size=H_size*2, 
#                         num_heads=cfg.num_heads, dropout=0.1)

# solver_transformer = Seq2SeqSolver(
#     transformer_encoder, transformer_decoder, 
#     with_attn=True, is_transformer=True, save_path='transformer')
# solver_transformer.train(mini_train=True)

### 0x07训练

In [None]:
%%time
# Baseline: unidirectional encoder + plain decoder.
# Wall-clock time is measured by hand because it is reused in the results table below.
st = time.time()
baseline_encoder = Seq2SeqEncoder(
    vocab_size=cfg.eng_vocab_size,
    embed_size=cfg.embed_size,
    hidden_size=cfg.hidden_size,
    num_layers=cfg.num_layers,
    dropout=cfg.dropout,
    withBidirectional=False,
)
baseline_decoder = Seq2SeqDecoder(
    vocab_size=cfg.cmn_vocab_size,
    embed_size=cfg.embed_size,
    hidden_size=cfg.hidden_size,
    num_layers=cfg.num_layers,
    dropout=cfg.dropout,
)
solver_baseline = Seq2SeqSolver(baseline_encoder, baseline_decoder, save_path='baseline')
solver_baseline.train()
# Elapsed seconds for building and training this model.
time_baseline = time.time() - st

In [None]:
%%time
# Variant: bidirectional encoder + plain decoder.
st = time.time()
bidRNN_encoder = Seq2SeqEncoder(
    vocab_size=cfg.eng_vocab_size,
    embed_size=cfg.embed_size,
    hidden_size=cfg.hidden_size,
    num_layers=cfg.num_layers,
    dropout=cfg.dropout,
    withBidirectional=True,
)
bidRNN_decoder = Seq2SeqDecoder(
    vocab_size=cfg.cmn_vocab_size,
    embed_size=cfg.embed_size,
    hidden_size=cfg.hidden_size,
    num_layers=cfg.num_layers,
    dropout=cfg.dropout,
)
solver_bidRNN = Seq2SeqSolver(bidRNN_encoder, bidRNN_decoder, save_path='bidRNN')
solver_bidRNN.train()
# Elapsed seconds for building and training this model.
time_bidRNN = time.time() - st

In [None]:
%%time
# Variant: unidirectional encoder + additive-attention decoder.
st = time.time()
attention_encoder = Seq2SeqEncoder(
    vocab_size=cfg.eng_vocab_size,
    embed_size=cfg.embed_size,
    hidden_size=cfg.hidden_size,
    num_layers=cfg.num_layers,
    dropout=cfg.dropout,
    withBidirectional=False,
)
attention_decoder = Seq2SeqAttentionDecoder(
    vocab_size=cfg.cmn_vocab_size,
    embed_size=cfg.embed_size,
    hidden_size=cfg.hidden_size,
    num_layers=cfg.num_layers,
    dropout=cfg.dropout,
)
solver_attention = Seq2SeqSolver(attention_encoder, attention_decoder, with_attn=True, save_path='attention')
solver_attention.train()
# Elapsed seconds for building and training this model.
time_attention = time.time() - st

In [None]:
%%time
# Variant: bidirectional encoder + additive-attention decoder.
# NOTE(review): the attention decoder expects encoder outputs with hidden_size
# features — confirm the bidirectional encoder reduces its outputs accordingly.
st = time.time()
bidRNN_attention_encoder = Seq2SeqEncoder(
    vocab_size=cfg.eng_vocab_size,
    embed_size=cfg.embed_size,
    hidden_size=cfg.hidden_size,
    num_layers=cfg.num_layers,
    dropout=cfg.dropout,
    withBidirectional=True,
)
bidRNN_attention_decoder = Seq2SeqAttentionDecoder(
    vocab_size=cfg.cmn_vocab_size,
    embed_size=cfg.embed_size,
    hidden_size=cfg.hidden_size,
    num_layers=cfg.num_layers,
    dropout=cfg.dropout,
)
solver_bidRNN_attention = Seq2SeqSolver(bidRNN_attention_encoder, bidRNN_attention_decoder, with_attn=True, save_path='bidRNN_attention')
solver_bidRNN_attention.train()
# Elapsed seconds for building and training this model.
time_bidRNN_attention = time.time() - st

In [None]:
%%time
# Variant: Transformer encoder/decoder.
st = time.time()

# All query/key/value/FFN sizes share the single Transformer hidden size.
H_size = cfg.transformer_hidden_size

# norm_shape=[H_size]: LayerNorm normalizes over the feature axis only.
transformer_encoder = \
    TransformerEncoder(vocab_size=cfg.eng_vocab_size, num_layers=cfg.num_layers, 
                        Q_size=H_size, K_size=H_size, V_size=H_size, hidden_size=H_size, 
                        norm_shape=[H_size], ffn_input_size=H_size, ffn_hidden_size=H_size*2, 
                        num_heads=cfg.num_heads, dropout=0.1)
transformer_decoder = \
    TransformerDecoder(vocab_size=cfg.cmn_vocab_size, num_layers=cfg.num_layers, 
                        Q_size=H_size, K_size=H_size, V_size=H_size, hidden_size=H_size, 
                        norm_shape=[H_size], ffn_input_size=H_size, ffn_hidden_size=H_size*2, 
                        num_heads=cfg.num_heads, dropout=0.1)

# NOTE(review): the commented-out trial cell for the transformer passes
# with_attn=True as well; confirm that omitting it here is intentional.
solver_transformer = Seq2SeqSolver(
    transformer_encoder, transformer_decoder, 
    is_transformer=True, save_path='transformer')
solver_transformer.train()
# Elapsed seconds for building and training this model.
time_transformer = time.time() - st

In [None]:
# Collect the loss history of every trained model in the shared config.
cfg.all_losslog['baseline'] = solver_baseline.losslog
cfg.all_losslog['bidRNN'] = solver_bidRNN.losslog
cfg.all_losslog['attention'] = solver_attention.losslog
cfg.all_losslog['bidRNN_attention'] = solver_bidRNN_attention.losslog
cfg.all_losslog['transformer'] = solver_transformer.losslog

# One colour per model; dict order defines the plotting and legend order.
curve_colors = {
    'baseline': 'black',
    'bidRNN': 'blue',
    'attention': 'orange',
    'bidRNN_attention': 'red',
    'transformer': 'green',
}

plt.xlabel('epoch')
plt.ylabel('loss')
for model_name, curve_color in curve_colors.items():
    model_losslog = cfg.all_losslog[model_name]
    # The x axis follows each log's own length, so a run with a different number
    # of recorded entries than cfg.num_epoches still plots instead of raising
    # a shape-mismatch error.
    plt.plot(np.arange(len(model_losslog)), model_losslog,
             label=model_name, linewidth=2.0, color=curve_color)
plt.legend()
plt.show()

### 0x08测试

In [None]:
# Evaluate every trained model; the returned values feed the results table in the next cell.
# NOTE(review): show_case=0 presumably suppresses printing of example translations — confirm in Seq2SeqSolver.test.
baseline_bleu = solver_baseline.test(show_case=0)
bidRNN_bleu = solver_bidRNN.test(show_case=0)
attention_bleu = solver_attention.test(show_case=0)
bidRNN_attention_bleu = solver_bidRNN_attention.test(show_case=0)
transformer_bleu = solver_transformer.test(show_case=0)

In [None]:
# One row per model: its BLEU scores followed by its training time.
# Assumes each *_bleu value is a list of four scores (list concatenation is used,
# and the table has four BLEU columns) — TODO confirm against Seq2SeqSolver.test.
all_metric = [baseline_bleu+[time_baseline], bidRNN_bleu+[time_bidRNN], 
            attention_bleu+[time_attention], bidRNN_attention_bleu+[time_bidRNN_attention], 
            transformer_bleu+[time_transformer]]
all_metric = np.array(all_metric)
cols = ["1-BLEU", "2-BLEU", "3-BLEU", "4-BLEU", "Training Time (seconds)"]
rows = ['baseline', 'bidRNN', 'attention', 'bidRNN_attention', 'transformer']
bleu_results = pd.DataFrame(data=all_metric, columns=cols, index=rows)
# Displayed rounded to two decimals; bleu_results itself keeps full precision.
bleu_results.round(2)

In [None]:
# Re-run the evaluation of the bidRNN + attention model with show_case=3
# (presumably prints three example translations — confirm in Seq2SeqSolver.test).
test_bleu = solver_bidRNN_attention.test(show_case=3)

In [None]:
# Same evaluation with show_case=3 for the transformer model; overwrites test_bleu from the previous cell.
test_bleu = solver_transformer.test(show_case=3)

In [None]:
# Run the bidRNN + attention model on a single hand-written English sentence.
solver_bidRNN_attention.online_predict("this is an interesting test .", with_attn=True)

In [None]:
# Run the transformer model on the same sentence for comparison.
solver_transformer.online_predict("this is an interesting test .", is_transformer=True)